zero_grad()의 이해

JTDK·2021년 6월 28일
0

Neural Network에서 parameter들을 업데이트 할때 우리는 zero_grad()를 습관적으로 사용한다.

import torch as T
import torch,optimizer as optim

optimizer = optim.Adam(lr=lr, params=self.parameters())

...

optimizer.zero_grad()
loss.backward()
optimizer.step()

zero_grad()의 의미가 뭔지, 왜 사용하는지 알아보장

TORCH.OPTIM.OPTIMIZER.ZERO_GRAD

Optimizer.zero_grad(set_to_none=False)[source]
Sets the gradients of all optimized torch.Tensor s to zero.

엄청 심플하게 적혀있는데 좀 풀어써보자면,
보통 딥러닝에서는 미니배치+루프 조합을 사용해서 parameter들을 업데이트하는데,
한 루프에서 업데이트를 위해 loss.backward()를 호출하면 각 파라미터들의 .grad 값에 변화도가 저장이 된다.

이후 다음 루프에서 zero_grad()를 하지않고 역전파를 시키면 이전 루프에서 .grad에 저장된 값이 다음 루프의 업데이트에도 간섭을 해서 원하는 방향으로 학습이 안된다고 한다.

따라서 루프가 한번 돌고나서 역전파를 하기전에 반드시 zero_grad().grad 값들을 0으로 초기화시킨 후 학습을 진행해야 한다.

쓸 일은 없지만 zero_grad() 에 인자로 넘겨주는 set_to_none=False을 잠깐 보자.

보통 학습에서는 디폴트 값을 그냥 사용하지만, set_to_none=True로 인자를 넘기면 메모리 사용량이 적어져서 퍼포먼스가 소폭 향상하는 장점이 있다. 하지만 유저가 manual하게 .grad 값에 접근해서 수정을 할때 다르게 작동한다거나 하는 점 때문에 쓰이진 않는다.

profile
RL, 퀀트 투자 공부 정리

0개의 댓글