Pytorch Lightning hooks 톺아보기

yslee·2022년 2월 13일
0

Pytorch Lightning

목록 보기
2/2
post-thumbnail

이번 글에서는 Pytorch Lightning(as PL)의 두 가지 핵심이라 생각되는 컴포넌트 중 하나인 LightningModule (as PL.LM)에서 training loop를 구성하는 내용에 대해 알아보려 한다.

PL를 처음 사용하면서 가장 난감 했던 부분은 on_* 으로 시작되는 각 함수가 어디서 정확하게 실행되는지를 파악하는 부분이었다.

기존 pytorch의 경우 하나하나 사용자가 명시적으로 작성했기 때문에 불편한 코드를 읽을 수만 있으면 파악하는 것은 문서를 읽지 않고 확인할 수 있었다. 하지만 PL의 경우 추상화되어 있는 코드가 무엇인지 확인하고 각 training loop 내부에서 어떻게 실행되는지 먼저 파악할 필요가 있어 확실히 직관적이지는 않다는 것을 확인 할 수 있었다.

때문에 각 hook의 실행순서 및 PL 학습 시 내부 구조를 추상적으로 파악하는 데 도움이 되었으면 싶어 글을 남긴다.

# 에포크 루프 
 while self.epoch < hyperparametsers['epoch']:
	p_bar = tqdm(train_set, total=len(train_set))
    # 학습 루프
	for batch in p_bar:
    	# 학습 로직
		loss = self._train_step(batch)
	
      	# val loop, 로직
      	if self.itr % hyperparametsers['sampling_interval'] == 0:
        	self._test_step(sample_batch)

      	# logging & msg print
      	msg = 'E:%d, Itr:%d, Loss:%0.4f' % (
        	  self.epoch + 1, self.itr, loss)
	    p_bar.set_description(msg)
      	self.itr += 1
		
    # 여러 필요한 작업들 
	... 
    ...
    
    # 최적화 함수
	self.__opti_dis_scheduler.step()
	self.__opti_gen_scheduler.step()
	self.epoch += 1

도메인, 작성자의 취향, 테스트마다 조금씩 순서나 방식이 다르겠지만 우리는 위와 같은 학습에 필요한 작업을 하나씩 작성해 진행할 것이다. 필요에 따라 기존에 작업해 놓은 코드를 가져와 수정해 새로운 학습을 작성하는 경우도 많으리라 생각한다.

하지만 PL에서는 위 모든 것으로 fit() 함수 하나로 동작해 PL.LM의 training_step() 과 같은 메서드를 호출하는 방식으로 작동된다.

하지만

  • backword 전후로 특정 작업을 추가하고 싶다면?
  • gradient clipping 작업을 진행하고 싶다면?
  • 여러 모델을 사용해 배치 데이터가 학습 모델로 들어오기 전 데이터에 처리가 필요하다면?
  • ETC..

다양한 학습 상황을 만족하게 하기에는 training_step() 메서드 하나로는 힘들어질 수 있다. PL은 이런 다양한 상황을 고려해 내부에서 동작하는 학습을 custom 할 수 있는 hooks를 제공한다.

여기서부터 제공되는 코드는 PL문서에 포함된 Hooks의 실행 순서를 확인할 수 있는 의사코드(pseudocode)이며 실제로는 저렇게 단순하게 작동하지는 않지만, PL 학습 구조를 파악하기에는 큰 문제가 없다고 생각해 사용한다.

def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)

먼저 fit이 실행되면 datamodule의 prepare_data()를 호출해 데이터를 준비하고 학습을 진행한다.

def train_on_device(model):
    # called PER DEVICE
    on_fit_start() 			
    setup("fit")
    configure_optimizers()

    on_pretrain_routine_start()
    on_pretrain_routine_end()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:  	# 학습이 시작되는 지점
        train_loop()		# 실제 학습 루프가 여기서 호출
    on_train_end()

    on_fit_end()
    teardown("fit")

여기서부터 나오는 on_* 메서드들이 PL.LM에서 오버라이딩해 사용할 수 있는 Hooks 들이다. 학습 중 필요한 영역에서 적절한 Hooks를 정의해 사용할 수 있다.

def train_loop():
    on_epoch_start()
    on_train_epoch_start()
	
    # datamodule로 받은 train_dataloader가 여기서 
    # 베치 데이터로 변환되어 학습 루프를 시작
    for batch in train_dataloader():	
        on_train_batch_start()

        on_before_batch_transfer()
        
        # 데이터를 지정한 GPU로 이동시키는 영역으로 생각
        transfer_batch_to_device()
        on_after_batch_transfer()
		
        # PL.LM에서 작성한 training_step()을 호출
        training_step()
		
        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        
        # training_step에서 반환한 LOSS의 
        # gradient를 계산하는 것으로 보임
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        
        # 정의한 optimizer를 사용한 weight 업데이트
        optimizer_step()
		
        on_train_batch_end()
			
        # should_check_val을 조정하면 
        # looping 내부에서 validation이 가능할 것으로 보임
        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()
    on_epoch_end()
  • transfer_batch_to_device()
  • backward()
  • optimizer_step()
    을 확인해보자
# PL.core.lightning.py (hooks)
def transfer_batch_to_device(
    self,
    batch: Any,
    device: torch.device,
    dataloader_idx: int,
) -> Any:
    return move_data_to_device(batch, device)

분산 처리 등이 있으므로 내부 구현은 복잡하지만, move_data_to_device를 통해 데이터를 GPU로 보내는 것을 확인할 수 있다.

# PL.core.lightning.py (hooks)
def backward(
    self,
    loss: Tensor,
    optimizer: Optional[Optimizer],
    optimizer_idx: Optional[int],
    *args,
    **kwargs
) -> None:
    loss.backward(*args, **kwargs)

loss.backward를 한번 감싼 형태를 가지고 있다.

# PL.core.lightning.py (hooks)
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Union[Optimizer, LightningOptimizer],
    optimizer_idx: int = 0,
    optimizer_closure: Optional[Callable[[], Any]] = None,
    on_tpu: bool = False,
    using_native_amp: bool = False,
    using_lbfgs: bool = False,
) -> None:
    optimizer.step(closure=optimizer_closure)

# PL.plugins.precision.precision_plugin.py 
def optimizer_step(
        self,
        model: Union["pl.LightningModule", Module],
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> None:
        """Hook to run the optimizer step."""
        if isinstance(model, pl.LightningModule):
            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        optimizer.step(closure=closure, **kwargs)

중간에 AMP와 같은 작업이 있어 레이어가 나뉘어 있지만 결국 입력한 optimizer을 구동시키는 것을 확인할 수 있다.

def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False) # off

    on_validation_start()
    on_epoch_start()
    on_validation_epoch_start()

    for batch in val_dataloader():
        on_validation_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        validation_step()

        on_validation_batch_end()
    validation_epoch_end()

    on_validation_epoch_end()
    on_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True) # on

validation loop는 역시 torch.set_grad_enabled()를 사용해 gradient 추적을 on/off 시킨다. 이외 다른 hooks는 training loop와 같은 구조를 가지는 것으로 보인다.

TF에서 pytorch로 처음 넘어왔을 때 저거 까먹어 실수하는 경우가 많이 있었는데...

이번 글에서는 PL.LM에서 사용 가능한 hooks를 확인했다.
Trainer의 경우 생성자 파라미터가 다양하게 있는 것을 빼면 사용자가 직접 수정해서 사용할 일은 없어 보이지만 다음 글에서는 Trainer에서 설정 할 수 있는 기능을 확인해 보고자 한다.

profile
지식보다 지혜를

0개의 댓글