hook이란 (feat. PyTorch)

Cammie·2022년 10월 17일
0

PyTorch

목록 보기
3/4
post-thumbnail

hook

  • 패키지화된 코드에서 다른 프로그래머가 custom 코드를 중간에 실행시킬 수 있도록 만들어놓은 인터페이스
  • self.hooks 에 등록된 함수가 있으면 실행하게 된다.
  • 아래 코드와 같이 활용할 수 있다.
def hook_custom(x):
    print(f'current value is {x}')
클래스객체.hooks = []
클래스객체.hooks.append(hook_custom)

Tensor에 hook을 적용할 때

  • register_hook으로 등록
tensor = torch.rand(1, requires_grad=True)

def tensor_hook(grad):
    pass

tensor.register_hook(tensor_hook)

# tensor는 backward hook만 있다.
tensor._backward_hooks

Module에 hook을 적용할 때

  • register_forward_hook, register_forward_pre_hook, register_full_backward_hook으로 등록할 수 있다.

  • register_forward_hook : forward pass를 하는 동안 (output이 계산할 때 마다) 만들어놓은 hook function을 호출.
    이렇게 등록한 함수에 인자로 모듈이 실행되기 전 입력값과 실행 후 출력값을 받음 (input, output)

  • register_forward_pre_hook : forward pass를 하기 직전에 hook function을 호출.
    이렇게 등록한 함수에 인자로 모듈이 실행되기 전 입력값만을 받음 (input)

  • register_full_backward_hook : backward pass를 하는 동안 (gradient가 계산될 때마다) hook function을 호출.
    이렇게 등록한 함수에 인자로 backpropagation에서의 gradient 값들을 받음 (grad_input, grad_output)

class Model(nn.Module):
    def __init__(self):
        super().__init__()

def module_hook(grad):
    pass

model = Model()
model.register_forward_pre_hook(module_hook)
model.register_forward_hook(module_hook)
model.register_full_backward_hook(module_hook)

모델의 특정 layer에 대해서만 해당 hook 함수를 적용하고 싶을 때

# model.get_model_shortcuts ; 모델 각각의 모듈
for name, module in model.get_model_shortcuts():
	if(name == 'target_layer_name'): # 만약 해당 모듈이름이 우리가 원하는 모듈 네임이면
		module.register_forward_hook(module_hook) # 해당 모듈에 hook 등록

hook 지우기

hook을 등록할 때, 이를 특정 변수에 지정하여 등록한 후, 해당 변수.remove()를 하게 되면 등록된 hook을 지우고 사용할 수 있다.


* 관련 유튜브 자료 : https://www.youtube.com/watch?v=syLFCVYua6Q



0개의 댓글