[PyTorch] Finetuning - model layer별로 freeze하기

·2022년 6월 27일
0

Freezing

모델 클래스 내부에 freeze_bn, freeze_extractor 등등, 함수를 구현해두고 필요할 때 호출하는 식으로 사용하면 편하다!

특정 연산을 모두 freeze하고 싶은 경우

for m in self.modules():
    if isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
        m.eval()

nn.Sequential로 묶여있는 한 단위를 freeze하고 싶은 경우

sequential하게 묶여있으면, 묶여있는 연산들이 하나의 모듈로 처리가 된다. 그래서 Sequential로 묶여있는 애들 하나하나를 보고싶으면 또 그 하위 단위로 내려가야 함.
아래 코드를 보면 for 문이 두번이다.

for m in self.parameters():
    m.requires_grad = False
for name, module in self.named_modules():
    if name.split(".")[0] in ['motion_encoder', 'gru', 'flow_head', 'mask']:
        for n, param in module.named_parameters():
            param.requires_grad = True

이 과정에서 .eval()requires_grad = False의 차이점이 궁금했다. -> 더 알아볼 것!

주의점

validation phase 시, 보통

model.eval()
with torch.no_grad():
  # Evaluate model
  # ...
model.train()

위와 같은 구조를 가지게 된다.
이때, model.train()시 내부 모든 파라미터가 trainable해지기 때문에, 다시 freeze 해줘야 한다.

model.eval()
with torch.no_grad():
  # Evaluate model
  # ...
model.train()
model.module.freeze_bn()
profile
튼튼

0개의 댓글