Setting the environment setup aside for now, let's first follow the overall training flow…
main.py
def main():
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
    )
    actor_critic.to(device)
model.py
class Policy(nn.Module):
    def __init__(self, obs_shape, action_space, base=None, base_kwargs=None):
        if action_space.__class__.__name__ == "Discrete":
            num_outputs = action_space.n
            self.dist = Categorical(self.base.output_size, num_outputs)
→ Discrete: when the space is given as Discrete(n), the action is a single integer in [0, n-1]
→ Box: a real-valued vector of the given shape, with each dimension bounded between low and high
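For reference, a quick illustration of the two space types using the gym API (the sizes and bounds here are arbitrary example values):

import numpy as np
from gym import spaces

# Discrete(4): the action is a single integer in {0, 1, 2, 3}
disc = spaces.Discrete(4)
print(disc.sample())    # e.g. 2

# Box(low=-1, high=1, shape=(3,)): a real-valued 3-dim vector, each entry in [-1, 1]
box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
print(box.sample())     # e.g. [ 0.13 -0.87  0.42]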
distributions.py
class FixedCategorical(torch.distributions.Categorical):
    def sample(self):
        return super().sample().unsqueeze(-1)

    def log_probs(self, actions):
        return (
            super()
            .log_prob(actions.squeeze(-1))
            .view(actions.size(0), -1)
            .sum(-1)
            .unsqueeze(-1)
        )

    def mode(self):
        return self.probs.argmax(dim=-1, keepdim=True)
class Categorical(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Categorical, self).__init__()

        init_ = lambda m: init(
            m,
            nn.init.orthogonal_,
            lambda x: nn.init.constant_(x, 0),
            gain=0.01)

        self.linear = init_(nn.Linear(num_inputs, num_outputs))

    def forward(self, x):
        x = self.linear(x)
        return FixedCategorical(logits=x)
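To sanity-check the shapes these two classes produce, here is a small sketch (assuming Categorical and FixedCategorical are importable from a2c_ppo_acktr.distributions; the hidden size 64 and batch size 8 are arbitrary):

import torch
from a2c_ppo_acktr.distributions import Categorical

head = Categorical(num_inputs=64, num_outputs=4)   # the dist head defined above
features = torch.randn(8, 64)                      # (batch, hidden) actor features
dist = head(features)                              # FixedCategorical over 4 actions

action = dist.sample()          # shape (8, 1) because of the unsqueeze(-1)
logp = dist.log_probs(action)   # shape (8, 1)
greedy = dist.mode()            # shape (8, 1), argmax of the probabilities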
main.py
def main():
    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)
main.py
def main():
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)
main.py
def main():
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
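For example, with illustrative values num_env_steps = 10,000,000, num_steps = 128 and num_processes = 8, each update consumes 128 * 8 = 1024 environment steps, so num_updates = 10,000,000 // 128 // 8 = 9765.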
main.py
for j in range(num_updates):
    if args.use_linear_lr_decay:
        # decrease learning rate linearly
        utils.update_linear_schedule(
            agent.optimizer, j, num_updates,
            agent.optimizer.lr if args.algo == "acktr" else args.lr)
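The helper simply rescales the optimizer's learning rate linearly toward zero over training; a minimal sketch of what utils.update_linear_schedule does (paraphrasing the repo's utils.py):

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Decrease the learning rate linearly from initial_lr down to 0."""
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr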
main.py
for step in range(args.num_steps):
    # Sample actions
    with torch.no_grad():
        value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
            rollouts.obs[step], rollouts.recurrent_hidden_states[step],
            rollouts.masks[step])
model.py
def act(self, inputs, rnn_hxs, masks, deterministic=False):
    value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks)
model.py
Structure of MLPBase
The actor and the critic are each a 2-layer MLP, and there is an additional single-layer critic_linear.
critic_linear passes the features coming out of the critic through one more linear layer, producing the scalar state value.
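A sketch of those layers (paraphrasing MLPBase.__init__ in the repo: hidden_size defaults to 64, Tanh activations, and init_ follows the same orthogonal-init helper pattern as above):

class MLPBase(NNBase):
    def __init__(self, num_inputs, recurrent=False, hidden_size=64):
        super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size)
        ...
        self.actor = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(),
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())
        self.critic = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(),
            init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())
        self.critic_linear = init_(nn.Linear(hidden_size, 1))   # scalar state value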
def forward(self, inputs, rnn_hxs, masks):
    x = inputs

    if self.is_recurrent:
        x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks)

    hidden_critic = self.critic(x)
    hidden_actor = self.actor(x)

    return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs
So forward returns (value, actor_features, rnn_hxs), which is exactly what self.base(...) is unpacked into in act above.
model.py
dist = self.dist(actor_features)

if deterministic:
    action = dist.mode()
else:
    action = dist.sample()

action_log_probs = dist.log_probs(action)
dist_entropy = dist.entropy().mean()

return value, action, action_log_probs, rnn_hxs
distributions.py
class FixedCategorical(torch.distributions.Categorical):
    def sample(self):
        return super().sample().unsqueeze(-1)

    def log_probs(self, actions):
        return (
            super()
            .log_prob(actions.squeeze(-1))
            .view(actions.size(0), -1)
            .sum(-1)
            .unsqueeze(-1)
        )

    def mode(self):
        return self.probs.argmax(dim=-1, keepdim=True)
main.py
episode_rewards = deque(maxlen=10)

obs, reward, done, infos = envs.step(action)

for info in infos:
    if 'episode' in info.keys():
        episode_rewards.append(info['episode']['r'])
→ The reward itself seems to be recorded in the rollout and used for training, while episode_rewards looks like logging-only data.
main.py
masks = torch.FloatTensor(
    [[0.0] if done_ else [1.0] for done_ in done])
bad_masks = torch.FloatTensor(
    [[0.0] if 'bad_transition' in info.keys() else [1.0]
     for info in infos])
rollouts.insert(obs, recurrent_hidden_states, action,
                action_log_prob, value, reward, masks, bad_masks)
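A toy illustration of how done/infos turn into these masks, with two workers where the first episode ended normally and the second was cut off by the TimeLimit wrapper (the values are made up):

done  = [True, True]
infos = [{'episode': {'r': 21.0}},    # real terminal state (the Monitor wrapper adds the episode stats)
         {'bad_transition': True}]    # termination caused only by the time limit

# masks     -> [[0.0], [0.0]]  both episodes just ended
# bad_masks -> [[1.0], [0.0]]  only the second ending is "artificial"

With use_proper_time_limits on, compute_returns below multiplies the accumulated GAE by bad_masks, so a time-limit cut-off falls back to the value estimate instead of a misleading truncated return.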
main.py
with torch.no_grad():
    next_value = actor_critic.get_value(
        rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
        rollouts.masks[-1]).detach()
storage.py
def compute_returns(self,
                    next_value,
                    use_gae,
                    gamma,
                    gae_lambda,
                    use_proper_time_limits=True):
    if use_proper_time_limits:
        if use_gae:
            self.value_preds[-1] = next_value
            gae = 0
            for step in reversed(range(self.rewards.size(0))):
                delta = (self.rewards[step]
                         + gamma * self.value_preds[step + 1] * self.masks[step + 1]
                         - self.value_preds[step])
                gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
                gae = gae * self.bad_masks[step + 1]
                self.returns[step] = gae + self.value_preds[step]
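Written out (in LaTeX notation), the loop implements GAE with the episode mask m and the time-limit mask b folded in:

\delta_t = r_t + \gamma\, m_{t+1} V(s_{t+1}) - V(s_t)
\hat{A}_t = \big(\delta_t + \gamma \lambda\, m_{t+1} \hat{A}_{t+1}\big)\, b_{t+1}
R_t = \hat{A}_t + V(s_t)

So when b_{t+1} = 0 (an episode cut off by the time limit), the return collapses to V(s_t) and the truncated tail is not treated as a real terminal outcome.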
main.py
rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                         args.gae_lambda, args.use_proper_time_limits)
algo/ppo.py
def update(self, rollouts):
    # normalize the advantages before handing them to the data generator
    advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
    advantages = (advantages - advantages.mean()) / (
        advantages.std() + 1e-5)

    for e in range(self.ppo_epoch):
        if self.actor_critic.is_recurrent:
            data_generator = rollouts.recurrent_generator(
                advantages, self.num_mini_batch)
        else:
            data_generator = rollouts.feed_forward_generator(
                advantages, self.num_mini_batch)
storage.py
def feed_forward_generator(self, advantages, num_mini_batch=None, mini_batch_size=None):
    sampler = BatchSampler(
        SubsetRandomSampler(range(batch_size)),
        mini_batch_size,
        drop_last=True)
sampler.py (torch.utils.data)
class SubsetRandomSampler  # samples elements randomly from a given list of indices
class BatchSampler(Sampler[List[int]])  # wraps another sampler to yield a mini-batch of indices
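A tiny example of what this combined sampler yields; here batch_size = 8 flattened (step, process) slots are split into mini-batches of 4 (the order is random):

from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

sampler = BatchSampler(SubsetRandomSampler(range(8)), batch_size=4, drop_last=True)
for indices in sampler:
    print(indices)
# e.g. [5, 0, 7, 2]
#      [1, 6, 3, 4]    -> each list indexes one mini-batch of flattened slots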
storage.py
for indices in sampler:
    obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
    recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(
        -1, self.recurrent_hidden_states.size(-1))[indices]
    actions_batch = self.actions.view(-1,
                                      self.actions.size(-1))[indices]
    value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
    return_batch = self.returns[:-1].view(-1, 1)[indices]
    masks_batch = self.masks[:-1].view(-1, 1)[indices]
    old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[indices]
    if advantages is None:
        adv_targ = None
    else:
        adv_targ = advantages.view(-1, 1)[indices]

    yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ
ppo.py
for sample in data_generator:
    obs_batch, recurrent_hidden_states_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \
        adv_targ = sample

    # Reshape to do in a single forward pass for all steps
    values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
        obs_batch, recurrent_hidden_states_batch, masks_batch,
        actions_batch)
ppo.py
ratio = torch.exp(action_log_probs -
                  old_action_log_probs_batch)
surr1 = ratio * adv_targ
surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                    1.0 + self.clip_param) * adv_targ
action_loss = -torch.min(surr1, surr2).mean()
So the ratio is computed exactly as written out above.
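For reference, this is the clipped surrogate objective from the PPO paper, with the ratio computed in log space (in LaTeX notation):

r_t(\theta) = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\big)
L^{\text{CLIP}}(\theta) = \mathbb{E}_t\big[\min\big(r_t(\theta)\,\hat{A}_t,\ \text{clip}(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon)\,\hat{A}_t\big)\big]

action_loss is the negative of this (gradient ascent written as descent), with clip_param playing the role of \epsilon.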
ppo.py
if self.use_clipped_value_loss:
    value_pred_clipped = value_preds_batch + \
        (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
    value_losses = (values - return_batch).pow(2)
    value_losses_clipped = (
        value_pred_clipped - return_batch).pow(2)
    value_loss = 0.5 * torch.max(value_losses,
                                 value_losses_clipped).mean()
else:
    value_loss = 0.5 * (return_batch - values).pow(2).mean()
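The clipped branch mirrors the policy clipping (a detail taken from the OpenAI baselines PPO implementation rather than the paper itself) and reuses clip_param as the clipping range:

L^{V} = \tfrac{1}{2}\,\mathbb{E}_t\Big[\max\Big(\big(V_\theta(s_t) - R_t\big)^2,\ \big(\text{clip}(V_\theta(s_t),\ V_{\text{old}}(s_t)-\epsilon,\ V_{\text{old}}(s_t)+\epsilon) - R_t\big)^2\Big)\Big]

where value_pred_clipped above is exactly clip(V_\theta, V_old - \epsilon, V_old + \epsilon).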
ppo.py
self.optimizer.zero_grad()
(value_loss * self.value_loss_coef + action_loss -
 dist_entropy * self.entropy_coef).backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                         self.max_grad_norm)
self.optimizer.step()
ppo.py
value_loss_epoch += value_loss.item()
action_loss_epoch += action_loss.item()
dist_entropy_epoch += dist_entropy.item()

# from here on we have exited the outermost ppo_epoch loop
num_updates = self.ppo_epoch * self.num_mini_batch

value_loss_epoch /= num_updates
action_loss_epoch /= num_updates
dist_entropy_epoch /= num_updates

Each accumulated loss is then divided by ppo_epoch * num_mini_batch, giving the average of each loss term per mini-batch update.
main.py
rollouts.after_update()
storage.py
def after_update(self):
    self.obs[0].copy_(self.obs[-1])
    self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1])
    self.masks[0].copy_(self.masks[-1])
    self.bad_masks[0].copy_(self.bad_masks[-1])