<reward model section>
The embedding is passed through a linear layer called to_pred, and the resulting output (either binned or scalar is possible) is trained against the given label with an MSE loss or cross entropy.
So the reward model is trained with plain supervised learning rather than the Bradley-Terry model?
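My reading is that it is indeed per-example supervised learning, not a pairwise loss. A rough sketch of what the to_pred head amounts to (the dimensions, the head names, and using MSE for the scalar case are my own assumptions, not the repo's exact code):

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, num_bins, batch = 512, 10, 8            # illustrative sizes
to_pred_scalar = nn.Linear(dim, 1)           # scalar-output head
to_pred_binned = nn.Linear(dim, num_bins)    # binned-output head

embedding = torch.randn(batch, dim)          # pooled embedding per sequence

# scalar case: regress the predicted reward onto a float label
scalar_label = torch.randn(batch)
scalar_loss = F.mse_loss(to_pred_scalar(embedding).squeeze(-1), scalar_label)

# binned case: treat the label as one of num_bins classes and use cross entropy
binned_label = torch.randint(0, num_bins, (batch,))
binned_loss = F.cross_entropy(to_pred_binned(embedding), binned_label)

Either way each (embedding, label) pair is scored on its own, which is what makes it plain supervised learning rather than a Bradley-Terry loss over preference pairs.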
For the palm part, the elements of the features we can feed as input are neither in the same space nor sequentially meaningful, so keeping the feature extraction of the dae we were already using seems appropriate.
PrefPPO code review
for seed in 12345 23451 34512 45123 51234 67890 78906 89067 90678 6789; do
    python train_PrefPPO.py --env quadruped_walk --seed $seed \
        --lr 0.00005 --batch-size 128 --n-envs 16 --ent-coef 0.0 \
        --n-steps 500 --total-timesteps 4000000 \
        --num-layer 3 --hidden-dim 256 --clip-init 0.4 --gae-lambda 0.9 \
        --re-feed-type 1 --re-num-interaction $1 \
        --teacher-beta -1 --teacher-gamma 1 --teacher-eps-mistake 0 \
        --teacher-eps-skip 0 --teacher-eps-equal 0 \
        --re-segment 50 --unsuper-step 32000 --unsuper-n-epochs 50 \
        --re-max-feed 1000 --re-batch 100
done
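The only positional argument, $1, fills --re-num-interaction, which appears to control how often new preference feedback is requested. Assuming the loop above is saved as run_prefppo.sh (hypothetical file name), a run would look like:

bash run_prefppo.sh 30000   # $1 -> --re-num-interaction 30000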
train_PrefPPO.py
# pick the vectorized env factory depending on the benchmark
if metaworld_flag:
    env = make_vec_metaworld_env(
        args.env,
        n_envs=args.n_envs,
        monitor_dir=args.tensorboard_log,
        seed=args.seed)
else:
    env = make_vec_dmcontrol_env(
        args.env,
        n_envs=args.n_envs,
        monitor_dir=args.tensorboard_log,
        seed=args.seed)
# instantiate the reward model
reward_model = RewardModel(
    env.envs[0].observation_space.shape[0],   # observation dim
    env.envs[0].action_space.shape[0],        # action dim
    size_segment=args.re_segment,
    activation=args.re_act,
    lr=args.re_lr,
    mb_size=args.re_batch,
    # scripted-teacher settings (how the simulated preference labels are generated)
    teacher_beta=args.teacher_beta,
    teacher_gamma=args.teacher_gamma,
    teacher_eps_mistake=args.teacher_eps_mistake,
    teacher_eps_skip=args.teacher_eps_skip,
    teacher_eps_equal=args.teacher_eps_equal,
    large_batch=args.re_large_batch)
# network arch
net_arch = [dict(pi=[args.hidden_dim]*args.num_layer,
vf=[args.hidden_dim]*args.num_layer)]
policy_kwargs = dict(net_arch=net_arch)
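With the flags used above (--num-layer 3, --hidden-dim 256), this expands to separate three-layer, 256-unit MLPs for the policy and value networks:

net_arch = [dict(pi=[256, 256, 256], vf=[256, 256, 256])]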
model = PPO_REWARD(
    reward_model,
    MlpPolicy, env,
    tensorboard_log=args.tensorboard_log,
    seed=args.seed,
    learning_rate=args.lr,
    batch_size=args.batch_size,
    n_steps=args.n_steps,
    ent_coef=args.ent_coef,
    policy_kwargs=policy_kwargs,
    use_sde=use_sde,
    sde_sample_freq=args.sde_freq,
    gae_lambda=args.gae_lambda,
    clip_range=clip_range,
    n_epochs=args.n_epochs,
    # preference-feedback schedule
    num_interaction=args.re_num_interaction,
    num_feed=args.re_num_feed,
    feed_type=args.re_feed_type,
    re_update=args.re_update,
    metaworld_flag=metaworld_flag,
    max_feed=args.re_max_feed,
    # unsupervised pre-training phase
    unsuper_step=args.unsuper_step,
    unsuper_n_epochs=args.unsuper_n_epochs,
    size_segment=args.re_segment,
    max_ep_len=max_ep_len,
    verbose=1)

model.learn(total_timesteps=args.total_timesteps, unsuper_flag=1)
ppo_with_reward.py
if unsuper_flag == 1:
    return super(PPO_REWARD, self).learn_unsuper(
        total_timesteps=total_timesteps,
        callback=callback,
        log_interval=log_interval,
        eval_env=eval_env,
        eval_freq=eval_freq,
        n_eval_episodes=n_eval_episodes,
        tb_log_name=tb_log_name,
        eval_log_path=eval_log_path,
        reset_num_timesteps=reset_num_timesteps,
    )
→ In our environment, the unsupervised pre-training phase does not seem necessary.
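Presumably, when unsuper_flag is not 1, learn() falls through to the ordinary PPO-with-learned-reward path, so in our case the call would simply be something like the line below (whether unsuper_step also has to be set to 0 still needs to be checked in the code):

model.learn(total_timesteps=args.total_timesteps, unsuper_flag=0)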
on_policy_with_reward_algorithm.py
def learn_unsuper(self, total_timesteps, callback, log_interval, eval_env,
                  eval_freq, n_eval_episodes, tb_log_name, eval_log_path,
                  reset_num_timesteps):
    total_timesteps, callback = self._setup_learn(
        total_timesteps, eval_env, callback, eval_freq, n_eval_episodes,
        eval_log_path, reset_num_timesteps, tb_log_name
    )
    callback.on_training_start(locals(), globals())

    while self.num_timesteps < total_timesteps:
        if self.num_timesteps < self.unsuper_step:
            # phase 1: unsupervised pre-training, rollouts collected with an intrinsic reward
            continue_training = self.collect_rollouts_unsuper(
                self.env, callback, self.rollout_buffer,
                n_rollout_steps=self.n_steps, replay_buffer=self.unsuper_buffer)
        else:
            if self.first_reward_train == 0:
                # first switch-over: fit the reward model once, then reset counters and the value head
                self.learn_reward()
                self.num_interactions = 0
                self.first_reward_train = 2
                self.policy.reset_value()
            continue_training = self.collect_rollouts(
                self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
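Side note on learn_reward(): unlike the <reward model section> at the top, the reward model here is trained on preference pairs, presumably with the usual Bradley-Terry style cross-entropy over summed segment rewards. A minimal sketch of that objective (the function name, shapes, and batch size are my own illustration, not the repo's code):

import torch
import torch.nn.functional as F

def preference_loss(r_hat_1, r_hat_2, labels):
    # r_hat_1, r_hat_2: (batch, segment_len, 1) predicted per-step rewards for the two segments
    # labels: (batch,) index of the segment the teacher preferred (0 or 1)
    logits = torch.cat([r_hat_1.sum(dim=1), r_hat_2.sum(dim=1)], dim=-1)  # (batch, 2) segment returns as logits
    return F.cross_entropy(logits, labels)

r1 = torch.randn(32, 50, 1)              # 32 segment pairs, segment length 50 (matches --re-segment 50)
r2 = torch.randn(32, 50, 1)
prefs = torch.randint(0, 2, (32,))
loss = preference_loss(r1, r2, prefs)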
def collect_rollouts_unsuper(
    self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer,
    n_rollout_steps: int, replay_buffer: EntReplayBuffer
) -> bool:
    n_steps = 0
    rollout_buffer.reset()
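From the arguments, collect_rollouts_unsuper presumably swaps the task reward for a particle-based state-entropy bonus computed against the states stored in the EntReplayBuffer, as in PEBBLE's unsupervised pre-training. A rough sketch of that k-NN estimate, with k and the tensor shapes as assumptions rather than the repo's exact implementation:

import torch

def knn_state_entropy(obs: torch.Tensor, buffer_obs: torch.Tensor, k: int = 5) -> torch.Tensor:
    # obs:        (batch, obs_dim) observations from the current rollout
    # buffer_obs: (n, obs_dim) states accumulated so far in the entropy replay buffer
    dists = torch.cdist(obs, buffer_obs)              # (batch, n) pairwise L2 distances
    knn_dist = dists.kthvalue(k, dim=1).values        # distance to the k-th nearest stored state
    return torch.log(knn_dist + 1.0)                  # larger distance -> more novel -> higher intrinsic reward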
# RolloutBuffer class description (fields set in __init__)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.dones, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()
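The stored gamma and gae_lambda are what the buffer later uses to turn the collected rewards and value predictions into advantages and returns. A rough sketch of that GAE recursion (names and the numpy version are mine; SB3's compute_returns_and_advantage does the equivalent on the buffer's arrays):

import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.9):
    # rewards, values, dones: (n_steps,) arrays for one rollout; dones[t] == 1 if the episode ended at step t.
    # last_value bootstraps the value of the state following the final step.
    n_steps = len(rewards)
    advantages = np.zeros(n_steps, dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(n_steps)):
        next_value = last_value if t == n_steps - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns

The gae_lambda=0.9 default here just mirrors the --gae-lambda 0.9 flag in the training command above.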