RL에서 log_softmax, distribution.Categorical()

JTDK·2021년 6월 28일
0

log_softmax?

torch.nn.functional.log_softmax(input, dim=None, _stacklevel=3, dtype=None)

  • input (Tensor) – input
  • dim (int) – A dimension along which log_softmax will be computed.
  • dtype (torch.dtype, optional) – the desired data type of returned tensor. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

너무 명시적이라 달리 할말이 없네. 그냥 소프트맥스값을 로그 취하는 함수다.
다만, 두가지를 따로할때보다 정밀도가 높은데, 그 이유는 컴퓨터가 실수를 표현할때 유한개의 소숫점을 사용하기 때문이다. 즉 컴터를 거치면 소프트맥스 값 자체가 근사값으로 나오기때문에, 그걸 구한걸 다시 로그를 하면 원래 값이랑 오차가 생긴다.

Why?

SoftMax야 뭐 output값을 확률로 받아야 여러모로 편리하니 그렇다치는데, log는 왜 하는걸까?

찾아보니 계산상 편리 및 퍼포먼스를 위해서 하는거란다. 유한소수로 실수를 표현하는 컴퓨터의 특성 상, 너무 작은 확률이나 너무 큰 확률이 들어오면 이를 제대로 반영하지 못한다. 따라서 log를 취해 0~1사이의 값을 -∞~0의 값으로 바꿔서 충분히 표현되게 한다. 또한 곱셈을 덧셈으로 표현할 수 있다는 점에서 계산상 편의도 존재한다.

  • 계산할때 편리하다
  • 0~1사이의 값을 -∞~0사이로 바꿔서 제대로 값이 표현되게 한다.

Implementation

import torch.nn as nn
import torch as T
import torch.nn.functional as F

class model(nn.Module){

	...
    
    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=0) ## 여기
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic
    	
}	

	...
    
    def run_episode(worker_env, worker_model):
    
   	...
            
    	action_dist = torch.distributions.Categorical(logits=logits)
       	action = action_dist.sample()
            
    	...
        
    

위 예제처럼 action에대한 최종 output을 받을때 확률의 로그 값인 log_softmax를 이용한다.

torch.distributions.Categorical()

torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)

  • logit 값이나, 확률값을 인풋으로 받아서(둘중하나만) Categorical 객체를 생성한다

Categorical 객체는 카테고리 확률분포를 가지는 객체인데, 카테고리 확률분포는
출력 (원핫)벡터 x=(x1,x2,,xK)x=(x1,x2,…,xK), 모수 벡터 μ=(μ1,,μK)μ=(μ1,…,μK) 에 대하여 Cat(x;μ)Cat(x;μ)로 표현할 수 있다. 쉽게말해 Categorical()xx의 확률(μμ)분포를 가지는 객체로, 다음과 같이 구현한다.

m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
m.sample()  # equal probability of 0, 1, 2, 3

보통 실제 구현에서는 저 위의 예제처럼 log_softmax와 세트로 사용한다

profile
RL, 퀀트 투자 공부 정리

0개의 댓글