Pytorch-Lightning Commom Use Cases 04 - Hyperparameters

한건우·2021년 10월 21일
0
  • Pytorch LightningArgumentParser랑 상호작용 가능한 기능을 포함하고 있어, Hyperparameter 최적화 framework와 호환 가능함

ArgumentParser

  • Pytorch-Lightning은 내장 Python ArgumentParser와 호환되도록 디자인됨
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--layer_1_dim", type=int, default=128)
args = parser.parse_args()
  • 이 경우 다음과 같이 프로그램을 실행할 수 있음
python trainer.py --layer_1_dim 64

Argparser Best Practices

  • best practice
    1. Trainer args (gpus, num_nodes, etc...)
    2. Model specific arguments (layer_dim, num_layers, learning_rate , etc ...)
    3. Program arguments (data_path, cluster_email, etc ...)
class LitModel(LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("LitModel")
        parser.add_argument("--encoder_layers", type=int, default=12)
        parser.add_argument("--data_path", type=str, default="/some/path")
        return parent_parser
# ----------------
# trainer_main.py
# ----------------
from argparse import ArgumentParser

parser = ArgumentParser()

# add PROGRAM level args
parser.add_argument("--conda_env", type=str, default="some_name")
parser.add_argument("--notification_email", type=str, default="will@email.com")

# add model specific args
parser = LitModel.add_model_specific_args(parser)

# add all the available trainer options to argparse
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
parser = Trainer.add_argparse_args(parser)

args = parser.parse_args()
  • 커맨드라인 예시
python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
  • 학습 코드 예시
# init the trainer like this
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)

# NOT like this
trainer = Trainer(gpus=hparams.gpus, ...)

# init the model with Namespace directly
model = LitModel(args)

# or init the model with all the key-value pairs
dict_args = vars(args)
model = LitModel(**dict_args)

LightningModule hyperparameters

  • 종종 많은 버젼의 모델을 학습시켜야할 때가 있음
  • 몇달 뒤에 다시 이 모델들을 봤을 때 어떻게 학습시켰는지 알 수 있도록 hyperparameter 정보를 저장할 수 있음

  1. save_hyperparameters() method를 사용하려는 LightningModule__init__ 추가하면, self.hparams 속성에 추가됨. 이 hyperparameters는 checkpoint파일에 저장되며 쉽게 세팅가능함
class LitMNIST(LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2, **kwargs):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint
        self.save_hyperparameters()

        # equivalent
        self.save_hyperparameters("layer_1_dim", "learning_rate")

        # Now possible to access layer_1_dim from hparams
        self.hparams.layer_1_dim
  1. 전체다 저장하고 싶지 않고 일부만 저장하고 싶을 때는 다음과 같이 save_hyperparameters()를 사용할 것
class LitMNIST(LightningModule):
    def __init__(self, loss_fx, generator_network, layer_1_dim=128 ** kwargs):
        super().__init__()
        self.layer_1_dim = layer_1_dim
        self.loss_fx = loss_fx

        # call this to save (layer_1_dim=128) to the checkpoint
        self.save_hyperparameters("layer_1_dim")


# to load specify the other args
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())
  1. dictnamespace 형태도 한번에 hparams에 넘겨줄 수 있음
class LitMNIST(LightningModule):
    def __init__(self, conf: Optional[Union[Dict, Namespace, DictConfig]] = None, **kwargs):
        super().__init__()
        # save the config and any extra arguments
        self.save_hyperparameters(conf)
        self.save_hyperparameters(kwargs)

        self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim)
        self.layer_2 = nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
        self.layer_3 = nn.Linear(self.hparams.layer_2_dim, 10)


conf = {...}
# OR
# conf = parser.parse_args()
# OR
# conf = OmegaConf.create(...)
model = LitMNIST(conf=conf, anything=10)

# Now possible to access any stored variables from hparams
model.hparams.anything

# for this to work, you need to access with `self.hparams.layer_1_dim`, not `conf.layer_1_dim`
model = LitMNIST.load_from_checkpoint(PATH)
profile
아마추어 GAN잽이

0개의 댓글