모델 클래스 내부에 freeze_bn
, freeze_extractor
등등, 함수를 구현해두고 필요할 때 호출하는 식으로 사용하면 편하다!
for m in self.modules():
if isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
m.eval()
sequential하게 묶여있으면, 묶여있는 연산들이 하나의 모듈로 처리가 된다. 그래서 Sequential로 묶여있는 애들 하나하나를 보고싶으면 또 그 하위 단위로 내려가야 함.
아래 코드를 보면 for 문이 두번이다.
for m in self.parameters():
m.requires_grad = False
for name, module in self.named_modules():
if name.split(".")[0] in ['motion_encoder', 'gru', 'flow_head', 'mask']:
for n, param in module.named_parameters():
param.requires_grad = True
이 과정에서 .eval()
과 requires_grad = False
의 차이점이 궁금했다. -> 더 알아볼 것!
validation phase 시, 보통
model.eval()
with torch.no_grad():
# Evaluate model
# ...
model.train()
위와 같은 구조를 가지게 된다.
이때, model.train()시 내부 모든 파라미터가 trainable해지기 때문에, 다시 freeze 해줘야 한다.
model.eval()
with torch.no_grad():
# Evaluate model
# ...
model.train()
model.module.freeze_bn()