Pytorch Lightning
은 ArgumentParser
랑 상호작용 가능한 기능을 포함하고 있어, Hyperparameter 최적화 framework와 호환 가능함
ArgumentParser
Pytorch-Lightning
은 내장 Python ArgumentParser
와 호환되도록 디자인됨
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--layer_1_dim", type=int, default=128)
args = parser.parse_args()
- 이 경우 다음과 같이 프로그램을 실행할 수 있음
python trainer.py --layer_1_dim 64
Argparser Best Practices
- best practice
- Trainer args (
gpus
, num_nodes
, etc...)
- Model specific arguments (
layer_dim
, num_layers
, learning_rate
, etc ...)
- Program arguments (
data_path
, cluster_email
, etc ...)
class LitModel(LightningModule):
@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitModel")
parser.add_argument("--encoder_layers", type=int, default=12)
parser.add_argument("--data_path", type=str, default="/some/path")
return parent_parser
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--conda_env", type=str, default="some_name")
parser.add_argument("--notification_email", type=str, default="will@email.com")
parser = LitModel.add_model_specific_args(parser)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
trainer = Trainer(gpus=hparams.gpus, ...)
model = LitModel(args)
dict_args = vars(args)
model = LitModel(**dict_args)
LightningModule hyperparameters
- 종종 많은 버젼의 모델을 학습시켜야할 때가 있음
- 몇달 뒤에 다시 이 모델들을 봤을 때 어떻게 학습시켰는지 알 수 있도록 hyperparameter 정보를 저장할 수 있음
save_hyperparameters()
method를 사용하려는 LightningModule
의 __init__
추가하면, self.hparams
속성에 추가됨. 이 hyperparameters는 checkpoint파일에 저장되며 쉽게 세팅가능함
class LitMNIST(LightningModule):
def __init__(self, layer_1_dim=128, learning_rate=1e-2, **kwargs):
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters("layer_1_dim", "learning_rate")
self.hparams.layer_1_dim
- 전체다 저장하고 싶지 않고 일부만 저장하고 싶을 때는 다음과 같이
save_hyperparameters()
를 사용할 것
class LitMNIST(LightningModule):
def __init__(self, loss_fx, generator_network, layer_1_dim=128 ** kwargs):
super().__init__()
self.layer_1_dim = layer_1_dim
self.loss_fx = loss_fx
self.save_hyperparameters("layer_1_dim")
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())
dict
나 namespace
형태도 한번에 hparams에 넘겨줄 수 있음
class LitMNIST(LightningModule):
def __init__(self, conf: Optional[Union[Dict, Namespace, DictConfig]] = None, **kwargs):
super().__init__()
self.save_hyperparameters(conf)
self.save_hyperparameters(kwargs)
self.layer_1 = nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = nn.Linear(self.hparams.layer_2_dim, 10)
conf = {...}
model = LitMNIST(conf=conf, anything=10)
model.hparams.anything
model = LitMNIST.load_from_checkpoint(PATH)