Pytorch lightning๋ ์ ๋ฒ์ ๋ฆฌ๋ทฐํ ignite์ ๋น์ทํ, ๊ทธ๋์ ๋น๊ต๋๋ ์คํ์์ค ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ด๋ค. ignite๋ pytorch์ ๊ณต์์ ์ธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ผ๊ณ ๋ ํ์ง๋ง, lightning์ด ํ๊ตญ์ด๋ก ๋ ์๋ฃ๊ฐ ๋์ฑ ๋ง์ ๊ฒ ๊ฐ๋ค. ignite์ ํต์ฌ์ด Engine
์ด์๋ ๊ฒ์ฒ๋ผ lightning์ ํต์ฌ์ Trainer
์ lightningmodule
์ด๋ค.
lightning module์ ๋๋ถ๋ถ์ ํ์ต ์๊ณ ๋ฆฌ์ฆ์ด ์ ์๋๋ ํด๋์ค์ด๋ค. lightning module์ ๊ตฌํํ๊ธฐ ์ํด์๋ LightningModule ํด๋์ค๋ฅผ ์์๋ฐ์์ผํ๋ค.
class MyClass(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = ๋ชจ๋ธ
def forward(self, x):
pass
def training_step(self, batch, batch_idx):
pass
def validation_step(self, batch, batch_idx):
pass
def test_step(self, batch, batch_idx):
pass
def configure_optimizers(self):
pass
LightningModule์์ฒด๊ฐ pytorch์ nn.module
์ ์์๋ฐ์ ํด๋์ค์ด๊ธฐ ๋๋ฌธ์ nn.module
์์ ์ฌ์ฉํ ์ ์๋ ๊ฒ๋ค์ ๋ค ์ธ ์ ์๋ค.
Trainer๋ ๊ธฐ๋ณธ์ ์ผ๋ก optimizer step, backward, logging, ๋ถ์ฐํ์ต๋ฑ์ ๋ค๋ฃจ๋ ๋ถ๋ถ์ด๋ค. ๊ทธ๋์ ์ ์ ๊ฐ ์ง์ ์์ ํด์ ์ฌ์ฉํ์ง๋ ์๊ณ ๊ตฌํ๋ ๋ถ๋ถ์ Trainer์์ ๊ฐ์ ธ์ ์ด๋ค๊ณ ํ๋ค.
์ ์ ๋ค์ ์ฝ๋์คํ์ผ์ด ๋น์ทํด์ง๋ค๋ ์ฅ์ ์ด ๊ต์ฅํ ํฐ ๊ฒ ๊ฐ๋ค. ์๋ฌด๋๋ ๊นํ์์ ๋ค๋ฅธ ๊ฐ๋ฐ์์ ์ฝ๋๋ฅผ ๋ดค์๋ "์ด ๋ถ๋ถ์ ์ด๋์๋๊ฑฐ์ง?" "์ด๊ฒ ๋ญ์ง?"์ด๋ฐ ์ง๋ฌธ์ ๋์ก๋ ๋์๊ฒ๋ ignite๋ lightning๊ฐ์ ์คํ์์ค๊ฐ ํ์ฑํ ๋ฌ์ผ๋ฉด ์ข๊ฒ ๋ค๋ ๋ฐ๋จ์ด ์๋ค. ๋ด๊ฐ ์งํํ๋ ํ๋ก์ ํธ์๋ lightning์ ์ ์ฉํด ๋ณด์์ผ๊ฒ ๋ค.