Pytorch-Lightning Commom Use Cases 01 - Child Modules

한건우·2021년 10월 20일
0
  • Research 프로젝트는 같은 데이터셋에 대해서 다른 접근법을 취하는 경향이 있음
  • Pytorch Lightning에서는 이 부분을 상속을 통해서 아주 쉽게 처리함

  • 예를 들면, MNIST 이미지에서 feature를 뽑으려고 AutoEncoder를 학습시킨다고 할 때, dataloader 설정이 되어있는 LitMNIST라는 모듈을 확장해서 사용할 수 있음
  • Autoencoder 모델에서 init/forward/training/validation/test step만 변경해주면 됨

class Encoder(torch.nn.Module):
    pass


class Decoder(torch.nn.Module):
    pass


class AutoEncoder(LitMNIST):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.metric = MSE()

    def forward(self, x):
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        x, _ = batch

        representation = self.encoder(x)
        x_hat = self.decoder(representation)

        loss = self.metric(x, x_hat)
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "test")

    def _shared_eval(self, batch, batch_idx, prefix):
        x, _ = batch
        representation = self.encoder(x)
        x_hat = self.decoder(representation)

        loss = self.metric(x, x_hat)
        self.log(f"{prefix}_loss", loss)
  • 이 경우, 기존 사용하던 같은 trainer 인스턴스로 학습이 가능함

autoencoder = AutoEncoder()
trainer = Trainer()
trainer.fit(autoencoder)
  • Lightning Module을 Pytorch model 처럼 사용하고 싶은 경우, 꼭 forward method를 구현해 놓아야함
some_images = torch.Tensor(32, 1, 28, 28)
representations = autoencoder(some_images)
profile
아마추어 GAN잽이

0개의 댓글