PEBBLE (pbrl) 코드리뷰

이두현·2024년 3월 17일
0

TODO (2022.09.22)

1) actor sample distribution 함수 상세히 보기

run_PEBBLE.sh

for seed in 12345 23451 34512 45123 51234 67890 78906 89067 90678 6789; do
    python train_PEBBLE.py  env=quadruped_walk seed=$seed agent.params.actor_lr=0.0001 agent.params.critic_lr=0.0001 gradient_update=1 activation=tanh num_unsup_steps=9000 num_train_steps=1000000 num_interact=30000 max_feedback=1000 reward_batch=100 reward_update=50 feed_type=$1 teacher_beta=-1 teacher_gamma=1 teacher_eps_mistake=0 teacher_eps_skip=0 teacher_eps_equal=0.1
done
  • train_PEBBLE.py 함수에 config 를 넘겨주면서 코드 시작

train_PEBBLE.py

  • main 함수는 class Workspace 의 run 함수를 호출하며 시작
def __init__(self, cfg):
				# reproduce same result with same seed 
				utils.set_seed_everywhere(cfg.seed)
				cfg.agent.params.obs_dim = self.env.observation_space.shape[0]
        cfg.agent.params.action_dim = self.env.action_space.shape[0]
        cfg.agent.params.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
				# Replaybuffer and Rewardmodel initialize
  • 같은 seed를 줄 경우 같은 실험결과를 주기 위해 set_seed_everywhere 함수를 정의

https://hoya012.github.io/blog/reproducible_pytorch/

https://pytorch.org/docs/stable/notes/randomness.html

함수 내부에서는 torch, cuda, python random 에 대한 seed 고정을 하고 있음

# main 함수 계속
# get from B_Pref/agent/sac.py
self.agent = hydra.utils.instantiate(cfg.agent)
# instantiate replay_buffer and reward_model
self.replay_buffer = ReplayBuffer()
self.reward_model = RewardModel()

  • step 별 run 함수 동작

전체제약 : self.step < self.cfg.num_train_steps

1) self.cfg.num_seed_steps 를 기준으로 action 선택 방식을 결정

self.step < self.cfg.num_seed_steps → action_space에서 random sampling

self.step ≥ self.cfg.num_seed_steps → evaluation mode에서 agent가 action 선택

num_seed_steps : 1000 // num_unsup_steps : 5000

2) [0, num_seed_steps] // [num_seed_steps, num_unsup_steps] // [num_unsup_steps, …] 기준별 행동양식

2-1) self.step == (self.cfg.num_seed_steps + self.cfg.num_unsup_steps)

reward_model의 change_batch 함수를 실행

reward_model의 teacher threshold를 수정

learn_reward를 처음 받음

replay_buffer를 relabel 함

unsupervised exploration 을 했기 때문에 Q function(critic)을 reset함 (self.agent.reset() 과의 차이?)

agent update after reset을 진행 (그냥 agent update와의 차이?)

2-2) self.step > self.cfg.num_seed_steps + self.cfg.num_unsup_steps

total_feedback 과 max_feedback 사이의 관계에 따라

reward_model의 chage_batch

reward_model의 teacher threshold

reward_model의 set_batch 수행

learn_reward 를 받음(처음 받는 것 아님!)

replay_buffer를 relabel 함

agent update를 진행

2-3) self.step > self.cfg.num_seed_steps

agent 의 update_state_ent 함수가 수행됨

종합)

  • self.step 이 self.cfg.num_seed_steps 보다 작을 경우 : action random sampling

  • self.step 이 self.cfg.num_seed_steps 과 self.cfg.num_seed_steps + self.cfg.num_unsup_steps 사이인 경우 : action 은 agent에 의해 선택되며, agent의 update_state_ent 함수가 수행됨 2-3) 과정

  • self.step 이 self.cfg.num_seed_steps + self.cfg.num_unsup_steps 보다 크거나 같을 경우 같은 경우나 큰 경우에 따라 2-1) 이나 2-2) 절차에 따라 수행

  • 위와 step 수와 관계없이 env step, reward_hat 생성, reward_model과 replay_buffer에 정보 저장은 반드시 이뤄짐

# train_PEBBLE.py 계속
def run(self):
						if done:
								obs = self.env.reset()
                self.agent.reset()
                done = False
                episode_reward = 0
                avg_train_true_return.append(true_episode_reward)
                true_episode_reward = 0
                if self.log_success:
                    episode_success = 0
                episode_step = 0
                episode += 1

						if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=True)
# agent/sac.py
def act(self, obs, sample=False):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)
        dist = self.actor(obs)
				# dist : torch.distribution class 반환
        action = dist.sample() if sample else dist.mean
        action = action.clamp(*self.action_range) # action_range 사이에 값을 맞춤
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])
  • utils.eval_mode(self.agent) 는 agent 의 이전 model.training 상태를 저장하고 enter 시 model.train(False) 설정해준 후 exit 시에는 다시 model.train(원래 상태) 로 되돌려준다
  • action 선택에 사용되는 actor distribution 에 대해서는 추후 더 자세히 봐야할 듯!
# # train_PEBBLE.py 계속
if self.step == (self.cfg.num_seed_steps + self.cfg.num_unsup_steps):
                # update schedule
                if self.cfg.reward_schedule == 1:
                    frac = (self.cfg.num_train_steps-self.step) / self.cfg.num_train_steps
                    if frac == 0:
                        frac = 0.01
                elif self.cfg.reward_schedule == 2:
                    frac = self.cfg.num_train_steps / (self.cfg.num_train_steps-self.step +1)
                else:
                    frac = 1
                self.reward_model.change_batch(frac)

Untitled

  • B-pref 논문에 설명된 것과 같이 reward_schedule ==1 인 경우 step에 따라 fraction 감소, reward_schedule==2 인 경우 step에따라 fraction 증가
  • else는 uniform scheduling
  • self.reward_model.change_batch(frac) 은 reward_model의 self.mb_size 를 frac 만큼 변화시키는 함수
# # train_PEBBLE.py 계속
# update margin --> not necessary / will be updated soon
                new_margin = np.mean(avg_train_true_return) * (self.cfg.segment / self.env._max_episode_steps)
                self.reward_model.set_teacher_thres_skip(new_margin)
                self.reward_model.set_teacher_thres_equal(new_margin)
# first learn reward
                self.learn_reward(first_flag=1)
  • margin은 중요하지 않다? → 나중에 더 보기
# train_PEBBLE.py
def learn_reward(self, first_flag=0):
				if first_flag == 1:
						# return the number of labeled queries
            labeled_queries = self.reward_model.uniform_sampling()
  • first_flag 가 켜지면 무조건 uniform_sampling 을 하게되며 이는 Bpref 논문에 의하면 아래와 같이 정의된다

Untitled

# reward_model.py
def uniform_sampling(self):
        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 =  self.get_queries(
            mb_size=self.mb_size)
            
        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        
        return len(labels)
# reward_model.py
# self.input 에는 obs + action concat 정보가 (flatten 시켜 append)
# self.target 에는 reward 를 의미하는 r_t가 들어가있다 
def get_queries(self, mb_size=20):
        len_traj, max_len = len(self.inputs[0]), len(self.inputs)
        img_t_1, img_t_2 = None, None
        
				# 만약 마지막 obs+action 정보가 다 차있지 않으면 마지막 정보는 drop
        if len(self.inputs[-1]) < len_traj:
            max_len = max_len - 1
        
        # get train traj
				# 마지막 정보 drop 여부를 결정한 input, target 정보
        train_inputs = np.array(self.inputs[:max_len])
        train_targets = np.array(self.targets[:max_len])
   
				# mb_size 만큼의 random number를 추출하고 뽑은 수를 다시넣고 뽑는 복원추출을 진행한다 
        batch_index_2 = np.random.choice(max_len, size=mb_size, replace=True)
        sa_t_2 = train_inputs[batch_index_2] # Batch x T x dim of s&a
        r_t_2 = train_targets[batch_index_2] # Batch x T x 1
        
        batch_index_1 = np.random.choice(max_len, size=mb_size, replace=True)
        sa_t_1 = train_inputs[batch_index_1] # Batch x T x dim of s&a
        r_t_1 = train_targets[batch_index_1] # Batch x T x 1
                
        sa_t_1 = sa_t_1.reshape(-1, sa_t_1.shape[-1]) # (Batch x T) x dim of s&a
        r_t_1 = r_t_1.reshape(-1, r_t_1.shape[-1]) # (Batch x T) x 1
        sa_t_2 = sa_t_2.reshape(-1, sa_t_2.shape[-1]) # (Batch x T) x dim of s&a
        r_t_2 = r_t_2.reshape(-1, r_t_2.shape[-1]) # (Batch x T) x 1

        # Generate time index 
        time_index = np.array([list(range(i*len_traj,
                                            i*len_traj+self.size_segment)) for i in range(mb_size)])
        time_index_2 = time_index + np.random.choice(len_traj-self.size_segment, size=mb_size, replace=True).reshape(-1,1)
        time_index_1 = time_index + np.random.choice(len_traj-self.size_segment, size=mb_size, replace=True).reshape(-1,1)
        
        sa_t_1 = np.take(sa_t_1, time_index_1, axis=0) # Batch x size_seg x dim of s&a
        r_t_1 = np.take(r_t_1, time_index_1, axis=0) # Batch x size_seg x 1
        sa_t_2 = np.take(sa_t_2, time_index_2, axis=0) # Batch x size_seg x dim of s&a
        r_t_2 = np.take(r_t_2, time_index_2, axis=0) # Batch x size_seg x 1
                
        return sa_t_1, sa_t_2, r_t_1, r_t_2
  • len_traj 변수에 (T * dim of s&a) 만큼이 할당된 것으로 생각
  • sa_t_x, r_t_x 는 dimension 설명에 되어있는것과 같이 (mb_size * len_traj) 만큼의 s&a concatenated 저장되어있음
  • time_index 변수에는 [[len_traj, len_traj + size_seg], [len_traj 2, len_traj 2 + size_seg],…] 가 저장되어있고 각 list는 size_seg 만큼의 길이이므로 (0, len_traj - size_seg) 중에서 random 하게 고른 변수를 더하면 각 list에서 [len_traj i + random, len_traj i + size_seg + random] 이 범위 안에 들어오도록 time_index_x 변수를 맞출 수 있다
# reward_model.py
# receive two trajectories and return teacher's preference
def get_label(self, sa_t_1, sa_t_2, r_t_1, r_t_2):
				# Batch x 1 return 할 것 
        sum_r_t_1 = np.sum(r_t_1, axis=1)
        sum_r_t_2 = np.sum(r_t_2, axis=1)
        
        # skip the query
				# skipping threshold 가 설정되어있는 경우
        if self.teacher_thres_skip > 0: 
						# np.maximum 은 array 각 위치별 둘 중 최대값을 반환
            max_r_t = np.maximum(sum_r_t_1, sum_r_t_2)
            max_index = (max_r_t > self.teacher_thres_skip).reshape(-1)
						# threhold를 넘는 reward가 하나도 없는 경우 
            if sum(max_index) == 0:
                return None, None, None, None, []

						# 모든 input 변수에 대해 threshold를 넘긴 값들만 취급하여 새로운 값 설정
            sa_t_1 = sa_t_1[max_index]
            sa_t_2 = sa_t_2[max_index]
            r_t_1 = r_t_1[max_index]
            r_t_2 = r_t_2[max_index]
						# margin_index를 구분하기 위한 과정 
            sum_r_t_1 = np.sum(r_t_1, axis=1)
            sum_r_t_2 = np.sum(r_t_2, axis=1)
        
        # equally preferable
				# 둘 중 선호도 구분이 불가능한 경우
				# 후에 이 index 에 해당하는 label은 -1로 설정됨
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) < self.teacher_thres_equal).reshape(-1)
        
        # perfectly rational
				# r_t_1.shape => Batch x seg_size x 1 
        seg_size = r_t_1.shape[1]
        temp_r_t_1 = r_t_1.copy()
        temp_r_t_2 = r_t_2.copy()
				# discount factor 고려해서 값 재할당
        for index in range(seg_size-1):
            temp_r_t_1[:,:index+1] *= self.teacher_gamma
            temp_r_t_2[:,:index+1] *= self.teacher_gamma
				# seg_size 방향으로 더해 Batch x 1 값 도출
        sum_r_t_1 = np.sum(temp_r_t_1, axis=1)
        sum_r_t_2 = np.sum(temp_r_t_2, axis=1)
            
				
        rational_labels = 1*(sum_r_t_1 < sum_r_t_2)
        if self.teacher_beta > 0: # Bradley-Terry rational model
						# batch 방향으로 두개 concat 
            r_hat = torch.cat([torch.Tensor(sum_r_t_1), 
                               torch.Tensor(sum_r_t_2)], axis=-1)
            r_hat = r_hat*self.teacher_beta
						# batch 당 두 개 r 값에 대한 softmax 진행
            ent = F.softmax(r_hat, dim=-1)[:, 1]
						# Bernoulli distribution 에 따라 0,1 generate
            labels = torch.bernoulli(ent).int().numpy().reshape(-1, 1)
        else:
						# 아니라면 단순 값비교한 결과가 label 로 설정
            labels = rational_labels
        
        # making a mistake
        len_labels = labels.shape[0]
				# len_labels 만큼의 random number 생성
        rand_num = np.random.rand(len_labels)
        noise_index = rand_num <= self.teacher_eps_mistake
				# 주어진 조건만큼 noise 주기
        labels[noise_index] = 1 - labels[noise_index]
 
        # equally preferable
				# preference 구분불가 항목에는 -1 부여
        labels[margin_index] = -1 
        
        return sa_t_1, sa_t_2, r_t_1, r_t_2, labels
  • 받은 trajectory 들을 통해 label을 return 하는 과정
# reward_model.py
# receive two trajectories and return teacher's preference
def put_queries(self, sa_t_1, sa_t_2, labels):
		    # batch size 알아내기  
				total_sample = sa_t_1.shape[0]
				# query를 모두 넣었을 때 buffer index 
        next_index = self.buffer_index + total_sample
        if next_index >= self.capacity:
            self.buffer_full = True
            maximum_index = self.capacity - self.buffer_index
            np.copyto(self.buffer_seg1[self.buffer_index:self.capacity], sa_t_1[:maximum_index])
            np.copyto(self.buffer_seg2[self.buffer_index:self.capacity], sa_t_2[:maximum_index])
            np.copyto(self.buffer_label[self.buffer_index:self.capacity], labels[:maximum_index])

						# 넘어가서 buffer 앞에부터 다시 채움
            remain = total_sample - (maximum_index)
            if remain > 0:
                np.copyto(self.buffer_seg1[0:remain], sa_t_1[maximum_index:])
                np.copyto(self.buffer_seg2[0:remain], sa_t_2[maximum_index:])
                np.copyto(self.buffer_label[0:remain], labels[maximum_index:])

            self.buffer_index = remain
        else:
						# 단순 복사 후 index update
            np.copyto(self.buffer_seg1[self.buffer_index:next_index], sa_t_1)
            np.copyto(self.buffer_seg2[self.buffer_index:next_index], sa_t_2)
            np.copyto(self.buffer_label[self.buffer_index:next_index], labels)
            self.buffer_index = next_index
  • sa_t_x 를 buffer_segx , labels를 buffer_label에 넣어주는 과정
# train_PEBBLE.py
# learn_reward 이어서
				self.total_feedback += self.reward_model.mb_size
				# uniform_sampling 함수는 len(labels)를 반환
				self.labeled_feedback += labeled_queries
				train_acc = 0
        if self.labeled_feedback > 0:
            # update reward
            for epoch in range(self.cfg.reward_update):
                if self.cfg.label_margin > 0 or self.cfg.teacher_eps_equal > 0:
                    train_acc = self.reward_model.train_soft_reward()
                else:
                    train_acc = self.reward_model.train_reward()
                total_acc = np.mean(train_acc)
                
                if total_acc > 0.97:
                    break;
  • train_soft_reward 함수의 호출 조건에 대한 탐색 필요!
# reward_model.py
def train_reward(self):
				# self.de == ensemble_size
        ensemble_losses = [[] for _ in range(self.de)]
        ensemble_acc = np.array([0 for _ in range(self.de)])
        
        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = []
        for _ in range(self.de):
						# [0, max_len] 랜덤배열 generate
            total_batch_index.append(np.random.permutation(max_len))
        
				# np.ceil 은 올림함수
        num_epochs = int(np.ceil(max_len/self.train_batch_size))
        list_debug_loss1, list_debug_loss2 = [], []
        total = 0
        
        for epoch in range(num_epochs):
            self.opt.zero_grad()
            loss = 0.0
            
						# 해당 epoch에서 사용할 가장 마지막 index
            last_index = (epoch+1)*self.train_batch_size
            if last_index > max_len:
                last_index = max_len
            
						# range in ensemble_size    
            for member in range(self.de):
                
                # get random batch
								# total_batch_index[member]는 [0, max_len]의 임의 배열 가리킴
								# train batch 개수 만큼의 데이터를 idxs로 뽑아옴
                idxs = total_batch_index[member][epoch*self.train_batch_size:last_index]
                sa_t_1 = self.buffer_seg1[idxs]
                sa_t_2 = self.buffer_seg2[idxs]
                labels = self.buffer_label[idxs]
                labels = torch.from_numpy(labels.flatten()).long().to(device)
                
                if member == 0:
                    total += labels.size(0)
                
                # get logits
                r_hat1 = self.r_hat_member(sa_t_1, member=member)
                r_hat2 = self.r_hat_member(sa_t_2, member=member)
                r_hat1 = r_hat1.sum(axis=1)
                r_hat2 = r_hat2.sum(axis=1)
								# label 생성때와 마찬가지로 batch 방향 두개 concat
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

                # compute loss
								# nn.CrossEntropyLoss : LogSoftmax + NLLLoss
                curr_loss = self.CEloss(r_hat, labels)
                loss += curr_loss
								# curr_loss.item() -> detach into python float
                ensemble_losses[member].append(curr_loss.item())
                
                # compute acc
								# torch.max returns (max val, max indicies) // 1 indicates dimension
                _, predicted = torch.max(r_hat.data, 1)
                correct = (predicted == labels).sum().item()
                ensemble_acc[member] += correct
                
            loss.backward()
            self.opt.step()
        
        ensemble_acc = ensemble_acc / total
        
        return ensemble_acc

Untitled

  • 논문에서 제시된 이 식을 바탕으로 reward network parameter 를 학습하는 과정
# reward_model.py
# parameterized reward 의 network 구조
def gen_net(in_size=1, out_size=1, H=128, n_layers=3, activation='tanh'):
    net = []
    for i in range(n_layers):
        net.append(nn.Linear(in_size, H))
        net.append(nn.LeakyReLU())
        in_size = H
    net.append(nn.Linear(in_size, out_size))
    if activation == 'tanh':
        net.append(nn.Tanh())
    elif activation == 'sig':
        net.append(nn.Sigmoid())
    else:
        net.append(nn.ReLU())

    return

# self.de == ensemble_size 만큼의 reward model을 만들어서 
# self.ensemble 에 추가
def construct_ensemble(self):
        for i in range(self.de):
            model = nn.Sequential(*gen_net(in_size=self.ds+self.da, 
                                           out_size=1, H=256, n_layers=3, 
                                           activation=self.activation)).float().to(device)
            self.ensemble.append(model)
            self.paramlst.extend(model.parameters())
            
        self.opt = torch.optim.Adam(self.paramlst, lr = self.lr)

# 지정한 member에서 sa_t_x를 input으로 받아 p 예측값을 return 함
def r_hat_member(self, x, member=-1):
        # the network parameterizes r hat in eqn 1 from the paper
        return self.ensemble[member](torch.from_numpy(x).float().to(device))

def r_hat_batch(self, x):
        # they say they average the rewards from each member of the ensemble, but I think this only makes sense if the rewards are already normalized
        # but I don't understand how the normalization should be happening right now :(
        r_hats = []
        for member in range(self.de):
            r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
        r_hats = np.array(r_hats)

        return np.mean(r_hats, axis=0)
  • train_reward 함수를 수행하는 과정에서 reward_model.py 에서 사용된 함수들 정리
# train_PEBBLE.py
# def run 함수 계속

# reward network 업데이트 이후 off-policy learning을 위한 relabeling 과정
self.replay_buffer.relabel_with_predictor(self.reward_model)
# replay_buffer.py

def relabel_with_predictor(self, predictor):
        batch_size = 200
				# self.idx 는 replay buffer에 들어있는 마지막 trajectory index
        total_iter = int(self.idx/batch_size)
        
				# ceil(올림) 함수 대신 total_iter 에 대한 정수값 처리
        if self.idx > batch_size*total_iter:
            total_iter += 1
            
        for index in range(total_iter):
            last_index = (index+1)*batch_size
            if (index+1)*batch_size > self.idx:
                last_index = self.idx
                
            obses = self.obses[index*batch_size:last_index]
            actions = self.actions[index*batch_size:last_index]
            inputs = np.concatenate([obses, actions], axis=-1)
            
						# predictor == reward_model
						# r_hat_batch 함수는 ensemble로 얻은 예측값 mean을 return
            pred_reward = predictor.r_hat_batch(inputs)
						# 원래 replay_buffer 값을 업데이트
            self.rewards[index*batch_size:last_index] = pred_reward
# train_PEBBLE.py
# def run 함수 계속

# 이제 unsupervised exploration 이 끝났으니 초기화 후 다시 학습 시작
# critic, critic_target instantiate
self.agent.reset_critic()
                
self.agent.update_after_reset(
                    self.replay_buffer, self.logger, self.step, 
                    gradient_update=self.cfg.reset_update, 
                    policy_update=True)
# agent/sac.py

def update_after_reset(self, replay_buffer, logger, step, gradient_update=1, policy_update=True):
        for index in range(gradient_update):
						# replay_buffer에서 batch size 만큼의 정보들을 random 하게 tensor 형태로 sample
            obs, action, reward, next_obs, not_done, not_done_no_max = replay_buffer.sample(
                self.batch_size)

            print_flag = False
            if index == gradient_update -1:
                logger.log('train/batch_reward', reward.mean(), step)
                print_flag = True
                
            self.update_critic(obs, action, reward, next_obs, not_done_no_max,
                               logger, step, print_flag)

						# 
            if index % self.actor_update_frequency == 0 and policy_update:
                self.update_actor_and_alpha(obs, logger, step, print_flag)

						# 내분점으로 soft update 하는 방식
            if index % self.critic_target_update_frequency == 0:
                utils.soft_update_params(self.critic, self.critic_target,
                                         self.critic_tau)

def update_critic(self, obs, action, reward, next_obs, 
                      not_done, logger, step, print_flag=True):
        
        dist = self.actor(next_obs)
				# stochastic policy gradient 를 implement 하는 reparameterize trick 
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
				# double q learning 을 우한 두 개의 q값 반환
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
				# log_prob 유도 부분까지 이해한됨... 
        target_V = torch.min(target_Q1,
                             target_Q2) - self.alpha.detach() * log_prob
        target_Q = reward + (not_done * self.discount * target_V)
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        
        if print_flag:
            logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        self.critic.log(logger, step)

def update_actor_and_alpha(self, obs, logger, step, print_flag=False):
        dist = self.actor(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        actor_Q1, actor_Q2 = self.critic(obs, action)

        actor_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()
        if print_flag:
            logger.log('train_actor/loss', actor_loss, step)
            logger.log('train_actor/target_entropy', self.target_entropy, step)
            logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (self.alpha *
                          (-log_prob - self.target_entropy).detach()).mean()
            if print_flag:
                logger.log('train_alpha/loss', alpha_loss, step)
                logger.log('train_alpha/value', self.alpha, step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()
  • reparameterization trick은 아래의 링크 참조

Probability distributions - torch.distributions - PyTorch 1.12 documentation

  • sac 관련 부분 : update_critic, update_actor_and_alpha 부분은 다시 봐야 할듯
# train_PEBBLE.py
# def run 함수 계속

# unsupervised exploration 단계를 벗어남
elif self.step > self.cfg.num_seed_steps + self.cfg.num_unsup_steps:
				# 아래 두 조건 만족 시 frac 변수 설정 후 reward model의 mb_size 변화시킴 
				# total_feedback 변수는 learn_reward 함수 호출 시 매번 mb_size 만큼 증가
				# interact_count 변수는 run 함수의 while step 구문 안에서 1씩 증가
				if self.total_feedback < self.cfg.max_feedback:
                    if interact_count == self.cfg.num_interact:
										self.reward_model.change_batch(frac)

										# line 274 ~ 280 minor details
												# first_flag 끄고 함수 실행, 내부 호출 내용은 이미 설명
												self.learn_reward()
                        self.replay_buffer.relabel_with_predictor(self.reward_model)
												# 위에서 self.step == (self.cfg.num_seed_steps + self.cfg.num_unsup_steps) 조건에서 초기화 등장
												# while 문에서 step 마다 1씩 증가하므로 step이 위의 조건식보다 커지면 cfg.num_interact 주기마다 해석중인 elif 문을 구동하길 원하는듯
                        interact_count = 0
				
				# 위에서 본 update_after_reset과의 차이는 for문에서 index 가 아닌 self.step 에 따라 actor 와 critic 업데이트를 한다는 것이다
				self.agent.update(self.replay_buffer, self.logger, self.step, 1)
# train_PEBBLE.py
# def run 함수 계속

# step 이 num_seed_steps 이상, num_seed_steps + num_unsup_steps 사이에서 행동을 정의
# unsupervised exploration 이 이뤄짐
elif self.step > self.cfg.num_seed_steps:
								# self.update과의 차이는 update_critic 대신 update_critic_state_ent를 수행한다는 것이다
                self.agent.update_state_ent(self.replay_buffer, self.logger, self.step, 
                                            gradient_update=1, K=self.cfg.topK)
# agent/sac.py

def update_critic_state_ent(
        self, obs, full_obs, action, next_obs, not_done, logger,
        step, K=5, print_flag=True):
        
        dist = self.actor(next_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
        target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_prob
        
        # compute state entropy
        state_entropy = compute_state_entropy(obs, full_obs, k=K)
       
        self.s_ent_stats.update(state_entropy)
        norm_state_entropy = state_entropy / self.s_ent_stats.std
                
        if self.normalize_state_entropy:
            state_entropy = norm_state_entropy
        
        target_Q = state_entropy + (not_done * self.discount * target_V)
        target_Q = target_Q.detach()

        # get current Q estimates
				# 이 부분부터는 update_critic과 동일
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)
        
        if print_flag:
            logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)
  • KNN 에 의해 엔트로피를 계산하고 이를 이용해 critic update 하는 것은 다시 봐야 할 듯!
  • cmpute_state_entropy() ??
# train_PEBBLE.py
# def run 함수 계속
 						
						next_obs, reward, done, extra = self.env.step(action)
						# 입력 정보를 받아 reward ensemble 의 평균을 반환
            reward_hat = self.reward_model.r_hat(np.concatenate([obs, action], axis=-1))

            # allow infinite bootstrap
            done = float(done)
            done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward_hat
            true_episode_reward += reward
            
            if self.log_success:
                episode_success = max(episode_success, extra['success'])
                
            # adding data to the reward training data
						# add_data 에서 self.target 과 self.input을 get_queries 함수를 위해 변경
						# actor, critic update를 위해 replay_buffer 수정
            self.reward_model.add_data(obs, action, reward, done)
            self.replay_buffer.add(
                obs, action, reward_hat, 
                next_obs, done, done_no_max)
# reward_model.py
def add_data(self, obs, act, rew, done):
        sa_t = np.concatenate([obs, act], axis=-1)
        r_t = rew
        
        flat_input = sa_t.reshape(1, self.da+self.ds)
        r_t = np.array(r_t)
        flat_target = r_t.reshape(1, 1)

        init_data = len(self.inputs) == 0
        if init_data:
            self.inputs.append(flat_input)
            self.targets.append(flat_target)
        elif done:
            self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
            self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
            # FIFO
            if len(self.inputs) > self.max_size:
                self.inputs = self.inputs[1:]
                self.targets = self.targets[1:]
            self.inputs.append([])
            self.targets.append([])
        else:
            if len(self.inputs[-1]) == 0:
                self.inputs[-1] = flat_input
                self.targets[-1] = flat_target
            else:
                self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
                self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
# train_PEBBLE.py						
						obs = next_obs
            episode_step += 1
            self.step += 1
            interact_count += 1
  • 마무리로 매 step마다 다음과 같은 업데이트 진행
profile
0100101

0개의 댓글