๐Ÿ™ˆPytorch ignite๊ฐ€ ๋ญ์•ผ๐Ÿ™‰!!

k-sanaยท2022๋…„ 4์›” 12์ผ
2

pytorch

๋ชฉ๋ก ๋ณด๊ธฐ
1/2
post-thumbnail

๐Ÿค” Boilerplate๋ž€?

ํ…œํ”Œ๋ฆฟ(Template)๊ฐ™์€ ๋Š๋‚Œ์œผ๋กœ, ๋ฐ˜๋ณต์ ์ธ ์ฝ”๋“œ๋ฅผ ํƒ€์ดํ•‘ํ•  ํ•„์š”์—†์ด ๋ฐ˜๋ณต์ ์ธ ์ผ๋“ค์„ ํ•˜์ง€ ์•Š๋„๋ก ๋„์™€์ฃผ๋Š” ๊ฒƒ์ด ๋ฐ”๋กœ ๋ณด์ผ๋Ÿฌํ”Œ๋ ˆ์ดํŠธ(Boilerplate)์ด๋‹ค. pytorch ignite๋Š” ๋”ฅ๋Ÿฌ๋‹ ๋ถ„์•ผ์—์„œ ๋ชจ๋ธ์„ ์ฝ”๋”ฉํ•˜๋Š” ์‹œ๊ฐ„๋ณด๋‹ค ๋ถ€์ˆ˜์ ์ธ ์š”์†Œ์˜ ์ฝ”๋”ฉ(trainer, dataset ๋“ฑ)์— ๋” ๋งŽ์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋˜๊ธฐ ๋•Œ๋ฌธ์— ์žฌ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์ฝ”๋“œ๋ฅผ ๋งŒ๋“ค์ž๋Š” ์˜๋ฏธ์ด๋‹ค.

(์ขŒ) ignite๋ฅผ ์‚ฌ์šฉํ•œ ์ฝ”๋“œ (์šฐ) ์ผ๋ฐ˜์ ์ธ ์ฝ”๋“œ

๐Ÿคฉ IGNITE YOUR NETWORKS!

Pytorch-ignite๋ž€ model์„ ํ›ˆ๋ จ์‹œํ‚ค๊ณ  ์—…๋ฐ์ดํŠธํ•˜๋Š” ๋ชจ๋“  ๊ณผ์ •์„ ์ด์™€ ๊ด€๋ จ๋œ ๋ฉ”์†Œ๋“œ๋ฅผ ์ œ๊ณตํ•˜์—ฌ ๊น”๋”ํ•˜๊ณ  ์žฌ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋„์™€์ฃผ๋Š” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ด๋‹ค. ignite์˜ ๋ณธ์งˆ์€ ignite.engine.Engine์œผ๋กœ ์ด๋ฃจ์–ด์ ธ์žˆ๋Š”๋ฐ, Engine์€ ์ž…๋ ฅ ๋ฐ›์€ ์—ฐ์‚ฐ์„ ๊ณ„์†ํ•ด์„œ ๋ฐ˜๋ณต ์ˆ˜ํ–‰ํ•˜๋Š” ์—ญํ• ์„ ํ•œ๋‹ค.

while epoch < max_epochs:
    # run an epoch on data
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)

        if iter_counter == epoch_length:
            break

์œ„์˜ ์ฝ”๋“œ๋Š” Engine ์˜ ์˜๋ฏธ๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ์ฝ”๋“œ์ด๋‹ค.

๐Ÿš— Engine

Engine์€ ์‰ฝ๊ฒŒ ์ƒ๊ฐํ•ด ํ•™์Šตํ•˜๋Š” ๋ถ€๋ถ„์„ ๊ณ„์† ๋Œ๋ ค์ฃผ๋Š” ๊ฒƒ์ด๋‹ค. ๋งˆ์น˜ ์ž๋™์ฐจ ์—”์ง„์ฒ˜๋Ÿผ ๋ง์ด๋‹ค. Engine์˜ ์‚ฌ์šฉ๋ฒ•์€ ignite.engine.engine.Engine(process_function)์ด๋‹ค. process_function์€ ์‚ฌ์šฉ์ž๊ฐ€ ์ง์ ‘ ์ฝ”๋”ฉํ•˜๋Š” ๋ถ€๋ถ„์€ feed-forward, loss ๊ณ„์‚ฐ, ์—ญ์ „ํŒŒ ๊ณ„์‚ฐ, Gradient Descent ์ˆ˜ํ–‰ ๋“ฑ ์ด๋‹ค. ์•„๋ž˜ ์ฝ”๋“œ๋Š” ๊ธฐ๋ณธ ํŠธ๋ ˆ์ด๋„ˆ๋ฅผ ๋งŒ๋“  ์˜ˆ์‹œ์ด๋‹ค.

def update_model(engine, batch): # process function
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model) # ignite.engine.engine.Engine()์˜ ํ˜•์‹

@trainer.on(Events.ITERATION_COMPLETED(every=100)) # event
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss}, lr: {lr}")

trainer.run(data_loader, max_epochs=5) # Engine์€ ๊ฐ„๋‹จํ•˜๊ฒŒ .run()์œผ๋กœ ๋Œ๋ฆด ์ˆ˜ ์žˆ๋‹ค.

> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01

์•„๋ž˜ ์ฝ”๋“œ๋Š” evaluator ์˜ˆ์‹œ ์ฝ”๋“œ์ด๋‹ค.

from ignite.metrics import Accuracy

def predict_on_batch(engine, batch)
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)

    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)

๐Ÿฅณ EVENT

๊ธฐ๋ณธ์ ์ธ train process

Pytorch ignite์—๋Š” Engine์˜ ํšจ์œจ, ์œ ์—ฐ์„ฑ์„ ํ–ฅ์ƒ์‹œํ‚ค๊ธฐ ์œ„ํ•ด EVENT ์‹œ์Šคํ…œ์ด ๋„์ž…๋ฌ๋‹ค. ์˜ˆ๋ฅผ๋“ค๋ฉด

  • STARTED : ์—”์ง„ ์‹คํ–‰์ด ์‹œ์ž‘๋  ๋•Œ ๋ฐœ์ƒํ•˜๋Š” ์ด๋ฒคํŠธ
  • EPOCH_STARTED : Epoch๊ฐ€ ์‹œ์ž‘๋  ๋•Œ ๋ฐœ์ƒํ•˜๋Š” ์ด๋ฒคํŠธ
  • GET_BATCH_STARTED : ๋‹ค์Œ ๋ฐฐ์น˜๋ฅผ ๊ฐ€์ ธ์˜ค๊ธฐ ์ „์— ๋ฐœ์ƒํ•˜๋Š” ์ด๋ฒคํŠธ
    ๋“ฑ, ๋งŽ์€ ์ด๋ฒคํŠธ๊ฐ€ ์กด์žฌํ•œ๋‹ค. ๊ทธ๋ž˜์„œ ์‚ฌ์šฉ์ž๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ์ •์˜ํ•œ ์ฝ”๋“œ๋ฅผ Event handler๋กœ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ๋‹ค. handler๋Š” lambda, function, class method ๋“ฑ๊ณผ ๊ฐ™์€ ๋ชจ๋“  ํ•จ์ˆ˜๊ฐ€ ๋  ์ˆ˜ ์žˆ๋‹ค. Pytorch ignite๋Š” ๋งŽ์€ ์ด๋ฒคํŠธ๊ฐ€ ์กด์žฌํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•จ์ˆ˜๋ฅผ ๋“ฑ๋ก๋งŒ ํ•ด์ฃผ๋ฉด ์‰ฝ๊ฒŒ ์ง„ํ–‰ํ•  ์ˆ˜ ์žˆ๋‹ค.

Event Handler๋Š” add_event_handler ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด์„œ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค. ์•„๋ž˜๋Š” ์˜ˆ์‹œ ์ฝ”๋“œ์ด๋‹ค.

def run_validation(engine, validation_engine, valid_loader):
validation_engine.run(valid_loader, max_epoch=1)

train_engine.add_event_handler(
	Events.EPOCH_COMPLETED,
    run_validation,
    validation_engine,
    valid_loader,

์•„๋ž˜๋Š” decorator๋ฅผ ํ™œ์šฉํ•˜์—ฌ Event call-back ํ•จ์ˆ˜๋ฅผ ์ž‘์„ฑํ•œ ์˜ˆ์‹œ์ด๋‹ค.

@train_engine.on(Events.EPOCH_COMPLETED)
def print_train_logs(engine):
	avg_p_norm = engine.state.metrics['|param|']
    avg_g_norm = engine.state.metrics['|g_param|']
    avg_loss = engine.state.metrics['loss']
    avg_accuracy = engine.state.metrics['accuracy']
    
    print('Epoch {} - |param|={:.2e} |g_param|={:.2e} loss={}, accuracy={}'
    engine.state.epoch,
    avg_p_norm,
    avg_g_norm,
    avg_loss,
    avg_accuracy,
    ))

๋‹ค์Œ์—๋Š” ๋น„์Šทํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์ธ pytorch lighting์— ๋Œ€ํ•ด ์•Œ์•„๋ณด๋„๋ก ํ•˜๊ฒ ๋‹ค.

profile
I'm bamboo.

0๊ฐœ์˜ ๋Œ“๊ธ€