Pytorch - with torch.no_grad()와 model.eval()

권규보·2022년 10월 21일
0

boostcamp

목록 보기
8/9

모델의 성능을 측정할 때 코드에 꼭 들어가는 내용이 있다.
with torch.no_grad()model.eval()
둘 다 성능을 평가할 때 써줘야 하는 것은 알았지만 왜 써야하는지는 잘 알지 못했으며, 굳이 둘 다 해줘야 하는지도 몰랐기 때문에 정리하기로 했다.

1. with torch.no_grad()

이 명령어는 다음과 같이 사용한다.

with torch.no_grad():
	for i batch_in,batch_out in data_iter:
    ...

pytorch의 autograd engine을 비활성화 시킨다. 즉, 이 명령어 후 들여쓰기 한 부분에서는 더이상 gradient를 트래킹하지 않는다.
따라서 메모리가 줄어들고 계산 속도가 증가하는 이점이 있다.
사실 loss.backward()를 안 쓰면 되는 것 아니냐는 의문이 들 수 있는데, loss.backward()를 쓰지도 않을거면서 gradient를 트래킹하는 것이 비효율적이기 때문에 이 명령어를 사용해 준다.

2. model.eval()

항상 train 전에는 model.train(), eval 전에는 model.eval() 을 해야한다고 배웠다.
이 명령어는 학습할 때와 test 시에 다르게 동작하는 Layer들 때문에 필요하다. batchnorm이나 dropout 같은 경우 학습시에는 동작하지만 test 시에는 작동하지 않아야 한다. 이런 layer의 on, off를 해주는 역할을 한다.

여담

with torch.no_grad():
	model.eval()
    ...

혹은

model.eval()
with torch.no_grad():
    ...

둘의 순서 차이는 상관 없다.

profile
기록장

0개의 댓글