Implicit Language Q Learning Code Review

이두현 · March 17, 2024

scripts/train/bc_train_loop.py

raw_dataset_train = load_item(cfg['train_dataset'], system_cfg['device'])

src/load_objects.py

def load_item(config, *args, verbose=True):
    config = config.copy()
    name = config.pop('name')
    if name not in registry:
        raise NotImplementedError
    if 'cache_id' in config:
        if (name, config['cache_id']) in cache:
            if verbose:
                print(f'loading from cache ({name}, {config["cache_id"]})')
            return cache[(name, config['cache_id'])]
    if verbose:
        print(f'loading {name}: {config}')
    **item = registry[name](config, *args, verbose=verbose)**
    if 'cache_id' in config:
        print(f'saving to cache ({name}, {config["cache_id"]})')
        cache[(name, config['cache_id'])] = item
    return item
  • name is vis_dial_list_dataset, so the item is built by the function registered under that name in src/visdial/load_objects.py (a minimal sketch of the registry/cache pattern follows below)
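
A minimal standalone sketch of the registry + cache pattern, with a hypothetical 'toy_dataset' entry (names not taken from the repo):

# Hypothetical minimal version of the registry/cache machinery behind load_item.
registry = {}
cache = {}

def register(name):
    def wrapper(fn):
        registry[name] = fn
        return fn
    return wrapper

@register('toy_dataset')
def load_toy_dataset(config, *args, verbose=True):
    # config has already had 'name' popped by load_item
    return list(range(config['size']))

def load_item(config, *args, verbose=True):
    config = config.copy()
    name = config.pop('name')
    if 'cache_id' in config and (name, config['cache_id']) in cache:
        return cache[(name, config['cache_id'])]
    item = registry[name](config, *args, verbose=verbose)
    if 'cache_id' in config:
        cache[(name, config['cache_id'])] = item
    return item

# Two loads with the same cache_id return the same object.
a = load_item({'name': 'toy_dataset', 'size': 3, 'cache_id': 'x'})
b = load_item({'name': 'toy_dataset', 'size': 3, 'cache_id': 'x'})
assert a is b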

src/visdial/load_objects.py

@register('vis_dial_list_dataset')
def load_vis_list_dataset(config, device, verbose=True):
    **vd = load_item(config['data'], verbose=verbose)**
    token_reward = load_item(config['token_reward'], device, verbose=verbose)
    return VisDialListDataset(vd, max_len=config['max_len'], 
                              token_reward=token_reward, 
                              top_p=config['top_p'], 
                              bottom_p=config['bottom_p'])
  • After load_item, vd is assigned the result of load_vis_dial in src/visdial/load_objects.py

src/visdial/load_objects.py

def load_vis_dial(config, verbose=True):
    if config['additional_scenes'] is not None:
        with open(convert_path(config['additional_scenes']), 'rb') as f:
            config['additional_scenes'] = pkl.load(f)
    if config['cutoff_rule'] is not None:
        **config['cutoff_rule'] = load_item(config['cutoff_rule'], verbose=verbose)**
    return VisDialogueData(convert_path(config['data_path']), 
                           convert_path(config['img_feat_path']), 
                           config['split'], 
                           reward_cache=convert_path(config['reward_cache']), 
                           norm_img_feats=config['norm_img_feats'], 
                           reward_shift=config['reward_shift'], 
                           reward_scale=config['reward_scale'], 
                           addition_scenes=config['additional_scenes'], 
                           mode=config['mode'], 
                           cutoff_rule=config['cutoff_rule'], 
                           yn_reward=config['yn_reward'], 
                           yn_reward_kind=config['yn_reward_kind'])
  • percentile_cutoff_rule is loaded via src/visdial/load_objects.py

src/visdial/visdial_base.py

class PercentileCutoffRule:
    def __init__(self, goal_value: float, percentile: float):
        self.goal_value = goal_value
        self.percentile = percentile

    def apply_rule(self, scene: Scene, event: Event):
        progress = sum([ev.progress for ev in event.get_events()]) / (self.goal_value-scene.initial_val)
        return progress >= self.percentile
  • goal_value 1.0, percentile 0.5
  • A class that decides how the environment terminates an episode (it is passed as an arg to the VisDialogueData class); a numeric sketch follows below
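
A minimal numeric sketch of what apply_rule computes, with hypothetical per-event progress values standing in for the real Scene/Event objects:

# Hypothetical numbers: goal_value=1.0, percentile=0.5, initial scene value 0.0.
goal_value, percentile, initial_val = 1.0, 0.5, 0.0
event_progress = [0.1, 0.15, 0.2, 0.1]  # progress contributed by the events so far

progress = sum(event_progress) / (goal_value - initial_val)  # 0.55
print(progress >= percentile)  # True -> the environment cuts the episode off here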

Continued

def load_vis_dial(config, verbose=True):
    if config['additional_scenes'] is not None:
        with open(convert_path(config['additional_scenes']), 'rb') as f:
            config['additional_scenes'] = pkl.load(f)
    if config['cutoff_rule'] is not None:
        config['cutoff_rule'] = load_item(config['cutoff_rule'], verbose=verbose)
    return **VisDialogueData(convert_path(config['data_path']), 
                           convert_path(config['img_feat_path']), 
                           config['split'], 
                           reward_cache=convert_path(config['reward_cache']), 
                           norm_img_feats=config['norm_img_feats'], 
                           reward_shift=config['reward_shift'], 
                           reward_scale=config['reward_scale'], 
                           addition_scenes=config['additional_scenes'], 
                           mode=config['mode'], 
                           cutoff_rule=config['cutoff_rule'], 
                           yn_reward=config['yn_reward'], 
                           yn_reward_kind=config['yn_reward_kind'])**
  • The VisDialogueData class stores the task dialogues plus their rewards
  • reward_cache: path to the file holding the per-dialogue rewards (data/vis_dialogue/processed/visdial_0.5/train_rank_reward_cache1.json)

src/visdial/visdial_base.py (VisDialogueData class init 과정)

class VisDialogueData:
    def __init__(self, data_path: str, img_feat_path: str, 
                 split: str, reward_cache: Optional[str]=None, 
                 norm_img_feats: bool=True, reward_shift: float=0.0, 
                 reward_scale: float=1.0, 
                 addition_scenes: Optional[List[Scene]]=None, 
                 mode: str='env_stops', 
                 cutoff_rule: Optional[CutoffRule]=None, 
                 yn_reward: float=-2.0, yn_reward_kind: str='none'):
        assert mode in ['agent_stops', 'env_stops', '10_stop']
        assert yn_reward_kind in yn_reward_fs
        if mode == 'env_stops':
            if cutoff_rule is None:
                cutoff_rule = PercentileCutoffRule(1.0, 0.5)
            assert reward_cache is not None
        yn_reward_f = yn_reward_fs[yn_reward_kind]
        self.norm_img_feats = norm_img_feats
        with open(data_path, 'r') as f:
            data = json.load(f)
        if reward_cache is not None:
            with open(reward_cache, 'r') as f:
                reward = json.load(f)
            progress = reward
            reward = [[item * reward_scale + reward_shift for item in rs[1:]] for rs in reward]
        else:
            progress = None
            reward = None
        img_feats = h5py.File(img_feat_path, 'r')['images_%s' % (split)]
        if self.norm_img_feats:
            img_feats = normalize(img_feats, axis=1, norm='l2')
        assert len(img_feats) == len(data)
        if mode == 'agent_stops':
            self.scenes = sum([Scene.from_json_stops(data[i], img_feats[i], 
                                                     reward if reward is None else reward[i], 
                                                     progress[i] if progress is not None else None) for i in range(len(data))], [])
        **elif mode == 'env_stops':
            # maybe make reward 0 or -1 here
            self.scenes = [Scene.from_json_cuttoff(data[i], img_feats[i], 
                                                   progress[i] if progress is not None else None, 
                                                   cutoff_rule, yn_reward, yn_reward_f) for i in range(len(data))]**
        elif mode == '10_stop':
            self.scenes = [Scene.from_json(data[i], img_feats[i], 
                                           reward if reward is None else reward[i], 
                                           progress[i] if progress is not None else None) for i in range(len(data))]
        else:
            raise NotImplementedError
        if addition_scenes is not None:
            self.scenes += addition_scenes
  • data_path: '/home/doolee13/Implicit-Language-Q-Learning/src/utils/../../data/vis_dialogue/raw/visdial_0.5/visdial_0.5_train.json'
  • For the data, len(data[0]['dialog']) is 10
    • Each element holds one (question, answer) pair
  • len(reward[0]) is 11 (a reward for the scene itself at the start, not tied to an event, plus one reward per event)
  • progress keeps the original reward values, while reward is overwritten per dialogue with scale + shift applied to the [1:] slice (see the sketch after this list)
    • In this mode the transformed reward ends up unused; the original values, i.e. progress, are what get used
  • data_path must point to the sequentially preprocessed data file
  • The 'question' slot should hold a sentence assembled from the (k, v) metadata, and the 'answer' slot the reviewText sentence, stored in dictionary form
  • The reward cache file is covered later
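
A small sketch of the progress/reward transform noted above, with made-up numbers:

# Hypothetical per-dialogue reward list: index 0 is the scene-level reward,
# indices 1..10 correspond to the 10 (question, answer) turns.
reward = [[0.3, 0.05, 0.1, 0.0, 0.2, 0.1, 0.05, 0.0, 0.1, 0.05, 0.05]]
reward_scale, reward_shift = 1.0, 0.0

progress = reward  # the untouched values are kept as "progress"
reward = [[item * reward_scale + reward_shift for item in rs[1:]] for rs in reward]

print(len(progress[0]), len(reward[0]))  # 11 10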

src/visdial/visdial_base.py

@dataclass
class Scene:
    caption: str
    img_feat: np.ndarray
    events: List[Event]
    initial_val: Optional[float]
    cutoff_rule: Optional[CutoffRule]

@classmethod
    def from_json_cuttoff(cls, scene_json, img_feat, progress, cutoff_rule, yn_reward, yn_reward_f):
        caption = scene_json['caption']
        events = []
        for i in range(len(scene_json['dialog'])):
            events.append(QuestionEvent(scene_json['dialog'][i]['question']+'?', 
                                        0.0, None, None, None))
            events.append(AnswerEvent(scene_json['dialog'][i]['answer'], 
                                      -1.0 + (yn_reward if yn_reward_f is not None and yn_reward_f(scene_json['dialog'][i]['answer']) else 0.0), 
                                      0.0 if progress is None else progress[i+1], 
                                      None, None, None))
        scene = cls(caption, img_feat, events, 0.0 if progress is None else progress[0], cutoff_rule)
        for p, n in zip(events[:-1], events[1:]):
            p.next = n
            n.prev = p
        for ev in events:
            ev.scene = scene
        for i, ev in enumerate(scene.events):
            if isinstance(ev, AnswerEvent) and ev.is_final():
                scene.events = scene.events[:(i+1)]
                scene.events[-1].next = None
                ev.reward = 0.0 + (yn_reward if yn_reward_f is not None and yn_reward_f(scene.events[-1].answer) else 0.0)
        return scene
  • caption: just a plain sentence
  • progress is one element longer than scene_json['dialog']
  • QuestionEvent (question, progress, scene, prev, next)
  • AnswerEvent (answer, reward, progress, scene, prev, next)
    • The answer is the environment's side, and rewards are attached only to it
  • If an event is final and is an AnswerEvent: 1) the scene's events list is truncated to index i, 2) the last element's next is set to None, and 3) the reward is recomputed from the answer text alone (the -1 step penalty is dropped since the episode is over)
  • A scene holds a list of events, the events are linked to each other via prev/next, and each event keeps the scene it belongs to as an attribute (a minimal sketch of this linking follows below).
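
A minimal sketch of the prev/next linking and the backwards walk that get_events (shown later) relies on; the Event class here is a hypothetical stand-in for QuestionEvent/AnswerEvent:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Event:  # hypothetical stand-in for QuestionEvent/AnswerEvent
    text: str
    prev: Optional["Event"] = None
    next: Optional["Event"] = None

events = [Event("q1"), Event("a1"), Event("q2"), Event("a2")]
for p, n in zip(events[:-1], events[1:]):  # same linking as in from_json_cuttoff
    p.next = n
    n.prev = p

# Walk backwards from the last event, then reverse (what get_events(direction='prev') does).
ev, collected = events[-1], []
while ev is not None:
    collected.append(ev)
    ev = ev.prev
collected.reverse()
print([e.text for e in collected])  # ['q1', 'a1', 'q2', 'a2']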

Continued

@register('vis_dial_list_dataset')
def load_vis_list_dataset(config, device, verbose=True):
    vd = load_item(config['data'], verbose=verbose)
    **token_reward = load_item(config['token_reward'], device, verbose=verbose)**
    **return VisDialListDataset(vd, max_len=config['max_len'], 
                              token_reward=token_reward, 
                              top_p=config['top_p'], 
                              bottom_p=config['bottom_p'])**
  • vd receives the VisDialogueData instance
  • token_reward is a reward given per token (said to be 0 in every experiment → then why have it,,?)
  • VisDialListDataset wraps VisDialogueData and converts it into DataPoints so it can be used for offline RL
  • max_len is the maximum sequence length

src/visdial/visdial_dataset.py

class VisDialListDataset(List_RL_Dataset):
    def __init__(self, data: VisDialogueData, 
                 max_len: Optional[int], 
                 token_reward: TokenReward, 
                 top_p: Optional[float]=None, 
                 bottom_p: Optional[float]=None, 
                ) -> None:
        **tokenizer = VisDialTokenizer()**
        super().__init__(tokenizer, token_reward, max_len)
        self.data = data
        self.datapoints = []

src/visdial/visdial_tokenizer.py

class VisDialTokenizer(Tokenizer):
    def __init__(self):
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['</a>', '<a>', '<stop>', '</eod>'], 
                                           'bos_token': '<s>', 
                                           'sep_token': '</s>', 
                                           'pad_token': '<|pad|>'})
        super().__init__(self.tokenizer.convert_tokens_to_ids('<|pad|>'), 
                         self.tokenizer.convert_tokens_to_ids('</s>'), 
                         self.tokenizer.convert_tokens_to_ids('</a>'), 
                         self.tokenizer.convert_tokens_to_ids('<s>'), 
                         self.tokenizer.convert_tokens_to_ids('<a>'), 
                         self.tokenizer.convert_tokens_to_ids('</eod>'))
        self.stop_token = self.tokenizer.convert_tokens_to_ids('<stop>')
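
A quick sanity check of how those special tokens map to ids with the HF tokenizer (added tokens get ids right after GPT-2's original 50257-token vocab):

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained('gpt2')
tok.add_special_tokens({'additional_special_tokens': ['</a>', '<a>', '<stop>', '</eod>'],
                        'bos_token': '<s>', 'sep_token': '</s>', 'pad_token': '<|pad|>'})

for t in ['</a>', '<a>', '<stop>', '</eod>', '<s>', '</s>', '<|pad|>']:
    print(t, tok.convert_tokens_to_ids(t))
print(len(tok))  # 50264: vocab size after adding the 7 special tokens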

Continued

class VisDialListDataset(List_RL_Dataset):
    def __init__(self, data: VisDialogueData, 
                 max_len: Optional[int], 
                 token_reward: TokenReward, 
                 top_p: Optional[float]=None, 
                 bottom_p: Optional[float]=None, 
                ) -> None:
        tokenizer = VisDialTokenizer()
        super().__init__(tokenizer, token_reward, max_len)
        self.data = data
        self.datapoints = []
        **for item in self.data:
            obs = VDObservation(item, item.events[-1])
            self.datapoints.append(DataPoint.from_obs(obs, self.tokenizer, self.token_reward))**
  • item here comes from self.scenes of the VisDialogueData instance

src/data/rl_data.py

@classmethod
    def from_obs(cls, obs: Language_Observation, tokenizer: Tokenizer, token_reward: TokenReward, meta: Optional[Dict[str, Any]]=None):
        **sequence, terminal = obs.to_sequence()**

src/visdial/visdial_env.py

def to_sequence(self) -> Tuple[List[Tuple[str, Optional[float]]], bool]:
        if self.event is None:
            return [(self.scene.caption, None)], False
        **evs = self.event.get_events()**
        sequence = [(self.scene.caption, None)]
        sequence += [(str(evs[i]), evs[i+1].reward if isinstance(evs[i+1], AnswerEvent) else None) for i in range(len(evs)-1)]
        sequence += [(str(evs[-1]), 0.0 if isinstance(evs[-1], StopEvent) else None)]
        terminal = self.event.is_final()
        return sequence, terminal
  • In the first scene self.event is None and only self.scene exists
  • Example of a completed sequence: [('Bikes parked on the side of a street beside a fence.', None), ('how many bikes there?', -1.0), ('3', None), ('what color are bikes?', -1.0), ('i see green red and white', None), ('are they parked on stock parking?', -3.0), ('no', None), ('are there any people?', 0.0), ('2', None)]

src/visdial/visdial_base.py

def get_events(self, direction="prev"):
        if direction == "prev":
            func = lambda ev: ev.prev
        elif direction == "next":
            func = lambda ev: ev.next
        else:
            raise NotImplementedError
        events = []
        ev = self
        while ev is not None:
            events.append(ev)
            ev = func(ev)
        if direction == 'prev':
            events.reverse()
        return events
  • The default direction is 'prev', so the events are collected by walking backwards and then reversed
  • The returned events list is assigned to the evs variable

Continued

def to_sequence(self) -> Tuple[List[Tuple[str, Optional[float]]], bool]:
        if self.event is None:
            return [(self.scene.caption, None)], False
        evs = self.event.get_events()
        **sequence = [(self.scene.caption, None)]
        sequence += [(str(evs[i]), evs[i+1].reward if isinstance(evs[i+1], AnswerEvent) else None) for i in range(len(evs)-1)]
        sequence += [(str(evs[-1]), 0.0 if isinstance(evs[-1], StopEvent) else None)]
        terminal = self.event.is_final()**
        return sequence, terminal
  • sequence[0] stores the scene caption
  • The following entries store (question text, reward of the following answer) and (answer text, None) pairs

Continued

def from_obs(cls, obs: Language_Observation, tokenizer: Tokenizer, token_reward: TokenReward, meta: Optional[Dict[str, Any]]=None):
        sequence, terminal = obs.to_sequence()
        obs_meta = obs.metadata()
        if meta is not None and obs_meta is not None:
            meta = {**obs_meta, **meta}
        **elif obs_meta is not None:
            meta = obs_meta**
        if len(sequence) == 0 or sequence[0][1] is not None:
            raw_str = tokenizer.id_to_token(tokenizer.boa_token_id)
        **else:
            raw_str = tokenizer.id_to_token(tokenizer.bos_token_id)**
  • The bold branch runs; the bos token is '<s>', the token marking the start of the sequence
        action_rewards = []
        for s, r in sequence:
            raw_str += s
            if r is None:
                raw_str += tokenizer.id_to_token(tokenizer.eos_token_id)
            else:
                raw_str += tokenizer.id_to_token(tokenizer.eoa_token_id)
                action_rewards.append(r)
        if terminal:
            raw_str += tokenizer.id_to_token(tokenizer.eod_token_id)
        tokens = tokenizer.encode(raw_str)[0]
        token_rewards = token_reward.get_token_reward(tokens)
  • r is None when the entry is an answer to a question
    • In that case the '</s>' (eos) token is appended
  • r is not None when the entry is a question
    • In that case the '</a>' (eoa) token is appended and r is collected into action_rewards
  • The assembled raw string is encoded into ids and stored in tokens
  • Note that token_rewards are all 0
  • Example of the generated raw_str (the special tokens are not visible here): Bikes parked on the side of a street beside a fence.how many bikes there?3what color are bikes?i see green red and whiteare they parked on stock parking?noare there any people?2
       for i, t in enumerate(tokens):
            if t == tokenizer.eos_token_id:
                curr_idx = i
            elif t == tokenizer.eoa_token_id:
                action_idxs.extend(list(range(curr_idx, i)))
                state_idxs.extend(list(range(curr_idx, i)))
                reward.extend([token_rewards[x] for x in range(curr_idx, i)])
                reward[-1] += action_rewards[curr_action_idx]
                utterance_action_idxs.append(i)
                utterance_state_idxs.append(curr_idx)
                utterance_rewards.append(action_rewards[curr_action_idx]+sum([token_rewards[x] for x in range(curr_idx, i)]))
                curr_idx = i
                curr_action_idx += 1
        **state_idxs.append(len(tokens)-1)**
        utterance_state_idxs.append(len(tokens)-1)
        terminals = ([0] * (len(state_idxs)-1))+[int(terminal)]
        utterance_terminals = ([0] * (len(utterance_state_idxs)-1))+[int(terminal)]
        return cls(raw_str, tokens, state_idxs, action_idxs, reward, terminals, 
                   utterance_state_idxs, utterance_action_idxs, 
                   utterance_rewards, utterance_terminals, meta=meta)
  • action_idxs covers the indices from curr_idx (inclusive) up to the '</a>' position (exclusive)
  • e.g. how many bikes there?
  • state_idxs and reward are filled the same way, except reward adds the matching action_rewards entry to its last element
  • utterance_action_idxs records one index per utterance (the position of each '</a>' token)
  • utterance_state_idxs records one index per utterance (the curr_idx where the utterance starts)
  • The loop never records the final state, so the bold line appends it explicitly (see the toy walk-through below)
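
A toy walk-through of the index-extraction loop, with hypothetical token ids (10 standing in for '</s>', 11 for '</a>'); the list initializations that the excerpt above elides are included here:

# tokens for "<s> cap </s> q1 q2 </a> a1 </s> q3 </a>" with hypothetical ids:
# 1=<s>, 10=</s> (eos), 11=</a> (eoa), other numbers = ordinary word tokens
tokens = [1, 5, 10, 6, 7, 11, 8, 10, 9, 11]
eos_id, eoa_id = 10, 11
action_rewards = [-1.0, 0.0]          # one reward per utterance ending in </a>
token_rewards = [0.0] * len(tokens)   # all zero, as noted above

state_idxs, action_idxs, reward = [], [], []
curr_idx, curr_action_idx = 0, 0
for i, t in enumerate(tokens):
    if t == eos_id:
        curr_idx = i                                   # a new state starts at the last </s>
    elif t == eoa_id:
        action_idxs.extend(range(curr_idx, i))         # tokens of the agent utterance
        state_idxs.extend(range(curr_idx, i))
        reward.extend(token_rewards[x] for x in range(curr_idx, i))
        reward[-1] += action_rewards[curr_action_idx]  # utterance reward on its last token
        curr_idx = i
        curr_action_idx += 1
state_idxs.append(len(tokens) - 1)                     # the final state, added explicitly

print(action_idxs)  # [2, 3, 4, 7, 8]
print(state_idxs)   # [2, 3, 4, 7, 8, 9]
print(reward)       # [0.0, 0.0, -1.0, 0.0, 0.0]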

Continued

def load_item(config, *args, verbose=True):
    config = config.copy()
    name = config.pop('name')
    if name not in registry:
        raise NotImplementedError
    if 'cache_id' in config:
        if (name, config['cache_id']) in cache:
            if verbose:
                print(f'loading from cache ({name}, {config["cache_id"]})')
            return cache[(name, config['cache_id'])]
    if verbose:
        print(f'loading {name}: {config}')
    item = registry[name](config, *args, verbose=verbose)
    if 'cache_id' in config:
        print(f'saving to cache ({name}, {config["cache_id"]})')
        cache[(name, config['cache_id'])] = item
    **return item**
  • The VisDialListDataset is what gets returned into raw_dataset_train in scripts/train/bc_train_loop.py

scripts/train/bc_train_loop.py

    raw_dataset_train = load_item(cfg['train_dataset'], system_cfg['device'])
    raw_dataset_eval = load_item(cfg['eval_dataset'], system_cfg['device'])
    if isinstance(raw_dataset_train, Iterable_RL_Dataset):
        dataset_train = GeneralIterDataset(raw_dataset_train, 'cpu')
    else:
        **dataset_train = GeneralDataset(raw_dataset_train, 'cpu')**
    if isinstance(raw_dataset_eval, Iterable_RL_Dataset):
        dataset_eval = GeneralIterDataset(raw_dataset_eval, 'cpu')
    else:
        **dataset_eval = GeneralDataset(raw_dataset_eval, 'cpu')**
  • raw_dataset_train: a VisDialListDataset instance

src/data/torch_datasets.py

class GeneralDataset(Dataset):
    def __init__(self, rl_dataset: List_RL_Dataset, 
                 device: Union[torch.device, str]):
        self.rl_dataset = rl_dataset
        self.device = device
    
    def __len__(self):
        return self.rl_dataset.size()
    
    def __getitem__(self, i):
        return self.rl_dataset.get_item(i)

    def collate(self, items):
        return self.rl_dataset.collate(items, self.device)
    
    def collate_simple(self, items):
        return items

src/data/rl_data.py

class RL_Dataset(ABC):
    def collate(self, items: List[DataPoint], device):
        tokens, state_idxs, action_idxs, rewards, terminals, u_state_idxs, u_action_idxs, u_rewards, u_terminals = zip(*map(lambda x: x.to_tensors(device, self.max_len), items))
        tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attn_mask = (tokens != self.tokenizer.pad_token_id).float()
        state_idxs = torch.nn.utils.rnn.pad_sequence(state_idxs, batch_first=True, padding_value=0)
        action_idxs = torch.nn.utils.rnn.pad_sequence(action_idxs, batch_first=True, padding_value=0)
        terminals = torch.nn.utils.rnn.pad_sequence(terminals, batch_first=True, padding_value=1)
        rewards = torch.nn.utils.rnn.pad_sequence(rewards, batch_first=True, padding_value=0.0)
        u_state_idxs = torch.nn.utils.rnn.pad_sequence(u_state_idxs, batch_first=True, padding_value=0)
        u_action_idxs = torch.nn.utils.rnn.pad_sequence(u_action_idxs, batch_first=True, padding_value=0)
        u_terminals = torch.nn.utils.rnn.pad_sequence(u_terminals, batch_first=True, padding_value=1)
        u_rewards = torch.nn.utils.rnn.pad_sequence(u_rewards, batch_first=True, padding_value=0.0)
        return {'tokens': tokens, 'attn_mask': attn_mask, 
                'state_idxs': state_idxs, 'action_idxs': action_idxs, 
                'rewards': rewards, 'terminals': terminals, 
                'u_state_idxs': u_state_idxs, 'u_action_idxs': u_action_idxs, 
                'u_rewards': u_rewards, 'u_terminals': u_terminals}
  • When multiple tensors are stacked along the batch dimension, each one is padded up to the longest sequence so the sizes match
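
A small sketch of what this padding does for the tokens/attn_mask pair, with a hypothetical pad id:

import torch

pad_token_id = 50260  # hypothetical pad id for the sketch
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

tokens = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)
attn_mask = (tokens != pad_token_id).float()

print(tokens)     # tensor([[    5,     6,     7], [    8,     9, 50260]])
print(attn_mask)  # tensor([[1., 1., 1.], [1., 1., 0.]])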

Model loading

scripts/train/bc_train_loop.py

model = load_item(cfg['model'], system_cfg['device'])

src/load_objects.py

@register('bc_lm')
def load_bc_lm(config, device, verbose=True):
    **gpt2 = load_item(config['gpt2'], verbose=verbose)**
    dataset = load_item(config['dataset'], device, verbose=verbose)
    model = BC_LM(gpt2, dataset, device, config['transition_weight'])
    return load_model(config['load'], model, device, verbose=verbose)
@register('gpt2')
def load_gpt2(config, verbose=True):
    obj = GPT2LMHeadModel if config['lm_head'] else GPT2Model
    if config['from_pretrained']:
        return obj.from_pretrained(config['gpt2_type'])
    config = GPT2Config.from_pretrained(config['gpt2_type'])
    return obj(config)
  • config['lm_head'] = True → obj: GPT2LMHeadModel
  • config['gpt2_type'] = 'gpt2'
  • So the gpt2 variable above ends up holding a GPT2LMHeadModel
  • dataset is the VisDialListDataset

src/models/bc_lm.py

class BC_LM(BaseTransformer):
    def __init__(self, 
                 model: PreTrainedModel, 
                 dataset: RL_Dataset, 
                 device: Union[torch.device, str] = "cuda", 
                 transition_weight: float=0.0, 
                ):
        assert isinstance(model, GPT2LMHeadModel)
        super().__init__(model, dataset, device)
        self.h_dim  = self.model.config.n_embd
        self.transition_weight = transition_weight
  • self.h_dim = 768
  • self.transition_weight = 0.0

BC_LM inherits from these classes in src/models/base.py

class BaseModel(ABC, nn.Module):
    def __init__(self, 
                 dataset: RL_Dataset, 
                 device: Union[torch.device, str]) -> None:
        super().__init__()
        self.dataset = dataset
        self.device = device
        self.max_len = self.dataset.max_len

    def prepare_inputs(self, items: InputType):
        if isinstance(items, dict):
            return items
        return to(self.dataset.collate(items, self.device), self.device)

    @abstractmethod
    def get_loss(self, items: InputType, **kwargs):
        pass

class BaseTransformer(BaseModel):
    def __init__(self, 
                 pretrained_model: PreTrainedModel, 
                 dataset: RL_Dataset, 
                 device: Union[torch.device, str]) -> None:
        super().__init__(dataset, device)
        self.model = pretrained_model
        self.model.resize_token_embeddings(self.dataset.tokenizer.num_tokens())
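
A minimal sketch of the resize_token_embeddings call above, assuming the 50264-token count from the tokenizer section (50257 GPT-2 tokens + 7 added special tokens):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')
print(model.get_input_embeddings().weight.shape)  # torch.Size([50257, 768])

model.resize_token_embeddings(50264)  # 50257 GPT-2 tokens + 7 added special tokens
print(model.get_input_embeddings().weight.shape)  # torch.Size([50264, 768]), matching the logits dim seen later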

Loading the model is a real pain..

Turns out it was just the huggingface implementation copied over as-is


Back to the training part

scripts/train/bc_train_loop.py

for epoch in tqdm(range(train_cfg['epochs']), disable=not accelerator.is_local_main_process):
        for items in tqdm(data_loader, disable=not accelerator.is_local_main_process):
            items = to(items, system_cfg['device'])
            **loss, logs, postproc_fs = accelerator.unwrap_model(model).get_loss(items, **train_cfg['loss'])**

src/models/bc_lm.py

def get_loss(self, 
                 items: InputType):
        **prepared_inputs = self.prepare_inputs(items)**
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']

src/models/base.py

def prepare_inputs(self, items: InputType):
        **if isinstance(items, dict):
            return items**
        return to(self.dataset.collate(items, self.device), self.device)
  • The items are already a dictionary, so they are returned as-is

Continued

def get_loss(self, 
                 items: InputType):
        prepared_inputs = self.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        a_idx = prepared_inputs['action_idxs']
        model_outputs = self(tokens, attn_mask, 
                             output_attentions=True)
        logs = {}
        transformer_logs = get_transformer_logs(model_outputs.attentions, 
                                                self.model, 
                                                attn_mask)
        n = attn_mask.sum().item()
        **weights = self.get_weights(tokens, a_idx)**
  • tokens shape: [2, 110]
  • (just for reference) model_outputs.logits shape: [2, 110, 50264]
  • n: 181? (in the first batch)
def get_weights(self, 
                    tokens: torch.Tensor, 
                    action_idxs: torch.Tensor):
        weights = torch.full(tokens.shape, self.transition_weight).to(self.device)
        if action_idxs.shape[1] == 0:
            n = torch.zeros((tokens.shape[0],)).long().to(self.device)
        else:
            n = torch.argmax(action_idxs, dim=1)+1
        for i in range(tokens.shape[0]):
            weights[i] = torch.scatter(weights[i], dim=0, index=action_idxs[i, :n[i]], src=torch.full((n[i].item(),), 1.0).to(self.device))
        return weights
  • self.transition_weight: 0
  • action_idxs.shape: [2, 62]
  • torch.argmax finds, per batch element, the position of the largest (i.e. last non-padding) action index; +1 turns it into a count n so the slice [:n] covers only the real actions
  • The last line scatters 1.0 into the positions named by those action indices, returning weights with the same shape as tokens: 1.0 at action positions and transition_weight everywhere else (see the sketch below).
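
A toy reproduction of get_weights with made-up shapes (transition_weight = 0.0, two real action positions and one padded entry):

import torch

transition_weight = 0.0
tokens = torch.zeros(1, 6)                 # only the shape matters here
action_idxs = torch.tensor([[2, 4, 0]])    # last entry is padding (pad value 0)

weights = torch.full(tokens.shape, transition_weight)
n = torch.argmax(action_idxs, dim=1) + 1   # position of the largest index + 1 -> 2 real actions
for i in range(tokens.shape[0]):
    weights[i] = torch.scatter(weights[i], dim=0,
                               index=action_idxs[i, :n[i]],
                               src=torch.full((n[i].item(),), 1.0))
print(weights)  # tensor([[0., 0., 1., 0., 1., 0.]])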
def get_loss(self, 
                 items: InputType):
        prepared_inputs = self.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        a_idx = prepared_inputs['action_idxs']
        model_outputs = self(tokens, attn_mask, 
                             output_attentions=True)
        logs = {}
        transformer_logs = get_transformer_logs(model_outputs.attentions, 
                                                self.model, 
                                                attn_mask)
        n = attn_mask.sum().item()
        weights = self.get_weights(tokens, a_idx)
        **token_loss = self.awac_loss(tokens, attn_mask, model_outputs.logits, weights)**
def awac_loss(self, tokens, attn_mask, logits, w):
        w = w.detach()
        losses = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1), reduction='none')
        losses = losses.reshape(tokens.shape[0], tokens.shape[1]-1)
        return (losses * w[:, :-1] * attn_mask[:, 1:]).sum() / attn_mask[:, 1:].sum()
  • After the reshape, logits shape: [2 * 109, 50264], tokens shape: [2 * 109]
  • After cross_entropy the shape is [2 * 109]
  • Reshaped back to [2, 109]
  • w (weights) has the same shape as tokens and attn_mask: [2, 110]
  • Why w[:, :-1] and attn_mask[:, 1:]: w weights the prediction of action tokens, and there is nothing left to predict after the last token
  • attn_mask[:, 1:] because no prediction loss is assigned to the first token (a small sketch follows below)
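
A tiny sketch of the shifted, masked, weighted cross-entropy in awac_loss, using random logits and hypothetical shapes:

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 11                       # hypothetical batch, length, vocab
logits = torch.randn(B, T, V)
tokens = torch.randint(0, V, (B, T))
attn_mask = torch.ones(B, T)
w = torch.zeros(B, T); w[:, 2:4] = 1.0   # pretend positions 2-3 are action tokens

# predict token t+1 from position t: drop the last logit and the first token
losses = F.cross_entropy(logits[:, :-1, :].reshape(-1, V),
                         tokens[:, 1:].reshape(-1), reduction='none')
losses = losses.reshape(B, T - 1)
loss = (losses * w[:, :-1] * attn_mask[:, 1:]).sum() / attn_mask[:, 1:].sum()
print(loss)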

Now to review the evaluator part

scripts/train/bc_train_loop.py

if cfg['evaluator'] is not None:
        evaluator = load_item(cfg['evaluator'], system_cfg['device'])

src/load_objects.py

@register('bc_evaluator')
def load_bc_evaluator(config, device, verbose=True):
    env = load_item(config['env'], device, verbose=verbose)
    return BC_Evaluator(env, config['env'], config['kind'], **config['generation_kwargs'])
  • The BC_Evaluator class is in src/models/bc_lm.py
  • Evaluation runs every 4096 steps

scripts/train/bc_train_loop.py

                with torch.no_grad():
                    for i, eval_items in enumerate(eval_data_loader):
                        eval_items = to(eval_items, system_cfg['device'])
                        if i >= train_cfg['eval_batches']:
                            break
                        _, logs, postproc_fs = accelerator.unwrap_model(model).get_loss(eval_items, **train_cfg['loss'])
                        **if evaluator is not None:
                            evaluator_logs = evaluator.evaluate(accelerator.unwrap_model(model), eval_items)**
                            if evaluator_logs is not None:
                                logs['evaluation'] = evaluator_logs
                        eval_logs.accum_logs(logs)
                eval_label = 'eval'
                eval_total_logs = eval_logs.log(*postproc_fs, 
                                                partial(label_logs, label=eval_label), 
                                                iteration=step, epoch=epoch)
  • evaluator is the BC_Evaluator class
  • accelerator.unwrap_model(model) returns the BC_LM instance

src/models/bc_lm.py

def evaluate(self, model: BC_LM, items: InputType) -> Optional[Dict[str, Any]]:
        policy = BC_Policy(model, self.kind, **self.generation_kwargs)
        tokens = model.prepare_inputs(items)['tokens']
        n = tokens.shape[0]
        total_token_reward = 0
        total_env_reward = 0
        for i in range(n):
            **result, sequence = interact_environment(self.env, policy, None)**
  • model.prepare_inputs just returns items unchanged when they are already a dictionary
  • Example tokens.shape: [32, 179]
  • But tokens seems to be used only to compute the batch size and nowhere afterwards..

src/data/language_environment.py

def interact_environment(env: Language_Environment, policy: Policy, obs: Optional[Language_Observation]=None):
    obs_sequence = []
    if obs is None:
        obs = env.reset()
    while not env.is_terminal():
        **action = policy.act(obs)**
  • Example env.reset(): a elongated pizza with some tongs to handle it

src/models/bc_lm.py

# from BC_Policy class 
def act(self, obs: Language_Observation) -> str:
        item = DataPoint.from_obs(obs, self.bc_lm.dataset.tokenizer, self.bc_lm.dataset.token_reward)
        **generations, probs = self.generate([item], always_terminate, **self.generation_kwargs)**
        sorted_outputs = list(zip(*sorted(zip(generations[0][1], probs[0]), key=lambda x: -x[1])))[0]
        return sorted_outputs[0]
  • obs is 'a elongated pizza with some tongs to handle it' and it goes through from_obs as described above
def generate(self, items: InputType, 
                 termination_condition: Callable[[np.ndarray], bool], **kwargs):
        prepared_inputs = self.bc_lm.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        if self.kind == 'beam':
            method = self.beam_raw
        **elif self.kind == 'sample':
            method = self.sample_raw**
        else:
            raise NotImplementedError
        generations, probs = method(tokens, attn_mask, 
                                    termination_condition, 
                                    **kwargs)
        return generations, probs
  • The default setting is self.kind == 'sample', so the sample_raw method is used
def sample_raw(self, 
                   tokens: torch.Tensor, attn_mask: torch.Tensor, 
                   termination_condition: Callable[[np.ndarray], bool], 
                   num_generations=1, max_generation_len=None, 
                   temp=1.0, top_k=None, top_p=None, 
                   prefix_embs: Optional[torch.Tensor]=None, 
                   prefix_attn_mask: Optional[torch.Tensor]=None, 
                   remove_prefix_position_embs: bool=False):
        tokenizer = self.bc_lm.dataset.tokenizer
        max_length = self.bc_lm.dataset.max_len
        if max_length is None:
            max_length = self.bc_lm.model.config.n_positions
        max_length = min(max_length, self.bc_lm.model.config.n_positions)
        device = self.bc_lm.device
        bsize = tokens.shape[0]
        n = bsize * num_generations
        if max_generation_len is None:
            max_generation_len = max_length+1
        input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
        model_outputs = self.bc_lm(tokens, attn_mask, prefix_embs=prefix_embs, 
                                   prefix_attn_mask=prefix_attn_mask, 
                                   remove_prefix_position_embs=remove_prefix_position_embs, 
                                   use_cache=True)
        dialogue_kvs = model_outputs.past_key_values
        dialogue_lens = attn_mask.sum(dim=1)
        tokens = pad_sequence(torch.repeat_interleave(tokens, num_generations, dim=0), max_length, tokenizer.pad_token_id, device, 1)
        dialogue_lens = torch.repeat_interleave(dialogue_lens, num_generations, dim=0)
        **dialogue_kvs = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), dialogue_kvs)**
        log_probs = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
        termination_mask = torch.full((dialogue_lens.shape[0],), 1).to(device)
        t = torch.min(dialogue_lens).int()
  • num_generations seems to decide how many responses to generate
  • max_length = self.bc_lm.model.config.n_positions = 1024
  • tokens.shape = [1, 13]
  • Example input_strs after decoding: [' a elongated pizza with some tongs to handle it ']
  • prefix_embs defaults to None, prefix_t = 0
  • dialogue_lens = [13.]
  • After pad_sequence, tokens is padded up to max_length (1024) and re-assigned: [1, 1024]
  • dialogue_lens keeps its shape since num_generations = 1
  • log_probs looks like [0.]
  • termination_mask looks like [1]
  • t = 13 (in this case)

src/utils/sampling_utils.py

map_all_kvs = lambda f, kvs: tuple([tuple(map(f, items)) for items in kvs])
  • kvs must be a doubly nested structure like [[], [], ..., []]
  • f is first applied to every element of each inner list
  • and the results are collected back into nested tuples
dialogue_kvs = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), dialogue_kvs)
  • dialogue_kvs is a set of tensors of shape [1, 12, 16, 64] wrapped in nested tuples (a nested structure, but it looks like [[]] here since the length is 1)
  • repeat_interleave leaves them unchanged in this case
  • pad_sequence is then applied to those tensors (along dimension 2)
  • The result is tensors of shape [1, 12, 1024, 64] wrapped in the same nested tuples (a toy version below)
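
A toy demonstration of map_all_kvs on a past_key_values-shaped nested tuple (shapes made up, and torch.cat stands in for the repo's pad_sequence helper):

import torch

map_all_kvs = lambda f, kvs: tuple([tuple(map(f, items)) for items in kvs])

# Fake past_key_values: 2 layers, each a (key, value) pair of [batch, heads, seq, head_dim]
kvs = tuple((torch.zeros(1, 12, 13, 64), torch.zeros(1, 12, 13, 64)) for _ in range(2))

padded = map_all_kvs(lambda x: torch.cat([x, torch.zeros(1, 12, 1024 - 13, 64)], dim=2), kvs)
print(padded[0][0].shape)  # torch.Size([1, 12, 1024, 64]) -- same role as the repo's padding along dim 2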

Continuing the sample_raw function in src/models/bc_lm.py

while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
            curr_token = tokens[:, t-1].unsqueeze(1)
            curr_dialogue_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], dialogue_kvs)
            transformer_outputs = self.bc_lm(curr_token, None, past_key_values=curr_dialogue_kvs, use_cache=True)
            logits = transformer_outputs.logits
            logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
            logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
            logits = process_logits(transformer_outputs.logits, temp=temp, top_k=top_k, top_p=top_p)
            cat_dist = torch.distributions.categorical.Categorical(logits=logits[:, 0])
            new_tokens = cat_dist.sample()
            log_probs += cat_dist.log_prob(new_tokens)
            tokens[:, t] = new_tokens
            dialogue_kvs = update_kvs(dialogue_kvs, transformer_outputs.past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
            for idx in range(n):
                if tokens[idx, t] == tokenizer.eoa_token_id and t >= dialogue_lens[idx]:
                    termination_mask[idx] *= (1 - int(termination_condition(tokenizer.decode(tokens[idx, :].tolist(), 
                                                                                             clean_up_tokenization_spaces=False))))
            t += 1
            termination_mask *= ((t-dialogue_lens) < max_generation_len).int()
  • curr_token: the last token of the tokenized input item
  • curr_dialogue_kvs: since dialogue_kvs holds tensors inside nested tuples, this slices the cached k/v up to t-2, one step before the current timestep t-1 (on the first iteration)
  • logits shape is [1, 1, 50264]
  • The torch.where line: once a sequence has terminated it assigns 1e7 to the padding token, so generation effectively ends and only padding follows; active sequences get -inf for the pad token instead (see the sketch after this list)
  • The next line: there can be several generations and t starts at the minimum of dialogue_lens, so if t is still inside an already given dialogue, 1e7 is assigned to that ground-truth token so it is always the one emitted
  • process_logits applies temp, top_k, top_p → no change under the default settings
  • cat_dist is built over logits[:, 0] of shape [1, 50264]
  • new_tokens shape: [1]
  • These are log probs, so they are summed to get the total generation probability of the sentence
  • The t-th token was generated from the (t-1)-th token using k/v up to (t-2), so update_kvs writes the (t-1)-th k/v into the cache
  • termination_condition (always_terminate) always returns True, so generation stops once the produced token is the eoa token
    • termination_mask is set to 0
  • If the number of generated tokens (excluding the original prompt) exceeds max_generation_len, termination_mask is also set to 0
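
A small sketch of the two logit-masking tricks in the loop (terminated rows forced to emit padding, rows still inside the prompt forced to the ground-truth token), using a made-up tiny vocab:

import torch

V, pad_id = 8, 7                         # hypothetical vocab size and pad id
logits = torch.randn(2, 1, V)            # two parallel generations, one step
termination_mask = torch.tensor([1, 0])  # row 1 has already terminated

# terminated rows: pad gets +1e7 so sampling returns pad; active rows: pad gets -inf
logits[:, 0, pad_id] = torch.where(termination_mask == 1,
                                   torch.tensor(float('-inf')), torch.tensor(1e7))

# rows still inside the given prompt: force the ground-truth token with +1e7
forced_token = torch.tensor([3, 3])
still_in_prompt = torch.tensor([True, False])
logits[torch.arange(2), 0, forced_token] = logits[torch.arange(2), 0, forced_token].masked_fill(still_in_prompt, 1e7)

cat_dist = torch.distributions.Categorical(logits=logits[:, 0])
new_tokens = cat_dist.sample()
print(new_tokens)  # row 0 is forced to 3, row 1 comes out as pad (7)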
        output_strs = [tokenizer.decode(tokens[i, :].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        processed_outputs = []
        for i in range(len(input_strs)):
            temp_outputs = []
            for x in range(num_generations):
                processed_str = output_strs[i*num_generations+x][len(input_strs[i]):].strip()
                if tokenizer.id_to_token(tokenizer.pad_token_id) in processed_str:
                    processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.pad_token_id))].strip()
                if tokenizer.id_to_token(tokenizer.eoa_token_id) in processed_str:
                    processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.eoa_token_id))].strip()
                temp_outputs.append(processed_str)
            processed_outputs.append(temp_outputs)
        return list(zip(input_strs, processed_outputs)), log_probs.reshape(-1, num_generations)
  • The generated output is cut off before the first pad token (and before the first eoa token)
  • When there are multiple input_strs, the zipped (input, processed_outputs) pairs and log_probs are returned

src/models/bc_lm.py

# from BC_Policy class 
def act(self, obs: Language_Observation) -> str:
        item = DataPoint.from_obs(obs, self.bc_lm.dataset.tokenizer, self.bc_lm.dataset.token_reward)
        generations, probs = self.generate([item], always_terminate, **self.generation_kwargs)
        sorted_outputs = list(zip(*sorted(zip(generations[0][1], probs[0]), key=lambda x: -x[1])))[0]
        return sorted_outputs[0]
  • When several generations are produced, the one with the highest log prob is chosen as the action

checkpoint save directory

/home/doolee13/Implicit-Language-Q-Learning/src/utils/../../outputs/visual_dialogue/visdial_bc_test1/


Model conversion command

python convert_bc.py --load ../../outputs/visual_dialogue/visdial_bc_test1/model.pkl --save ../../outputs/visual_dialogue/visdial_bc_test1/model_converted.pkl

scripts/data/convert_bc.py

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', type=str)
    parser.add_argument('--save', type=str)
    args = parser.parse_args()

    state_dict = torch.load(args.load, map_location=torch.device('cpu'))
    for k in list(state_dict.keys()):
        if 'model.' in k:
            if k.startswith('model'):
                state_dict['lm_policy'+k[len('model'):]] = state_dict.pop(k)
    torch.save(state_dict, args.save)
  • State dict keys of the form model.(...) are renamed to lm_policy.(...)

odict_keys(['lm_policy.transformer.wte.weight', 'lm_policy.transformer.wpe.weight', 'lm_policy.transformer.h.0.ln_1.weight', 'lm_policy.transformer.h.0.ln_1.bias', 'lm_policy.transformer.h.0.attn.bias', 'lm_policy.transformer.h.0.attn.masked_bias', 'lm_policy.transformer.h.0.attn.c_attn.weight', 'lm_policy.transformer.h.0.attn.c_attn.bias', 'lm_policy.transformer.h.0.attn.c_proj.weight', 'lm_policy.transformer.h.0.attn.c_proj.bias', 'lm_policy.transformer.h.0.ln_2.weight', 'lm_policy.transformer.h.0.ln_2.bias', 'lm_policy.transformer.h.0.mlp.c_fc.weight', 'lm_policy.transformer.h.0.mlp.c_fc.bias', 'lm_policy.transformer.h.0.mlp.c_proj.weight', 'lm_policy.transformer.h.0.mlp.c_proj.bias', 'lm_policy.transformer.h.1.ln_1.weight', 'lm_policy.transformer.h.1.ln_1.bias', 'lm_policy.transformer.h.1.attn.bias', 'lm_policy.transformer.h.1.attn.masked_bias', 'lm_policy.transformer.h.1.attn.c_attn.weight', 'lm_policy.transformer.h.1.attn.c_attn.bias', 'lm_policy.transformer.h.1.attn.c_proj.weight', 'lm_policy.transformer.h.1.attn.c_proj.bias', 'lm_policy.transformer.h.1.ln_2.weight', 'lm_policy.transformer.h.1.ln_2.bias', 'lm_policy.transformer.h.1.mlp.c_fc.weight', 'lm_policy.transformer.h.1.mlp.c_fc.bias', 'lm_policy.transformer.h.1.mlp.c_proj.weight', 'lm_policy.transformer.h.1.mlp.c_proj.bias', 'lm_policy.transformer.h.2.ln_1.weight', 'lm_policy.transformer.h.2.ln_1.bias', 'lm_policy.transformer.h.2.attn.bias', 'lm_policy.transformer.h.2.attn.masked_bias', 'lm_policy.transformer.h.2.attn.c_attn.weight', 'lm_policy.transformer.h.2.attn.c_attn.bias', 'lm_policy.transformer.h.2.attn.c_proj.weight', 'lm_policy.transformer.h.2.attn.c_proj.bias', 'lm_policy.transformer.h.2.ln_2.weight', 'lm_policy.transformer.h.2.ln_2.bias', 'lm_policy.transformer.h.2.mlp.c_fc.weight', 'lm_policy.transformer.h.2.mlp.c_fc.bias', 'lm_policy.transformer.h.2.mlp.c_proj.weight', 'lm_policy.transformer.h.2.mlp.c_proj.bias', 'lm_policy.transformer.h.3.ln_1.weight', 'lm_policy.transformer.h.3.ln_1.bias', 'lm_policy.transformer.h.3.attn.bias', 'lm_policy.transformer.h.3.attn.masked_bias', 'lm_policy.transformer.h.3.attn.c_attn.weight', 'lm_policy.transformer.h.3.attn.c_attn.bias', 'lm_policy.transformer.h.3.attn.c_proj.weight', 'lm_policy.transformer.h.3.attn.c_proj.bias', 'lm_policy.transformer.h.3.ln_2.weight', 'lm_policy.transformer.h.3.ln_2.bias', 'lm_policy.transformer.h.3.mlp.c_fc.weight', 'lm_policy.transformer.h.3.mlp.c_fc.bias', 'lm_policy.transformer.h.3.mlp.c_proj.weight', 'lm_policy.transformer.h.3.mlp.c_proj.bias', 'lm_policy.transformer.h.4.ln_1.weight', 'lm_policy.transformer.h.4.ln_1.bias', 'lm_policy.transformer.h.4.attn.bias', 'lm_policy.transformer.h.4.attn.masked_bias', 'lm_policy.transformer.h.4.attn.c_attn.weight', 'lm_policy.transformer.h.4.attn.c_attn.bias', 'lm_policy.transformer.h.4.attn.c_proj.weight', 'lm_policy.transformer.h.4.attn.c_proj.bias', 'lm_policy.transformer.h.4.ln_2.weight', 'lm_policy.transformer.h.4.ln_2.bias', 'lm_policy.transformer.h.4.mlp.c_fc.weight', 'lm_policy.transformer.h.4.mlp.c_fc.bias', 'lm_policy.transformer.h.4.mlp.c_proj.weight', 'lm_policy.transformer.h.4.mlp.c_proj.bias', 'lm_policy.transformer.h.5.ln_1.weight', 'lm_policy.transformer.h.5.ln_1.bias', 'lm_policy.transformer.h.5.attn.bias', 'lm_policy.transformer.h.5.attn.masked_bias', 'lm_policy.transformer.h.5.attn.c_attn.weight', 'lm_policy.transformer.h.5.attn.c_attn.bias', 'lm_policy.transformer.h.5.attn.c_proj.weight', 'lm_policy.transformer.h.5.attn.c_proj.bias', 'lm_policy.transformer.h.5.ln_2.weight', 
'lm_policy.transformer.h.5.ln_2.bias', 'lm_policy.transformer.h.5.mlp.c_fc.weight', 'lm_policy.transformer.h.5.mlp.c_fc.bias', 'lm_policy.transformer.h.5.mlp.c_proj.weight', 'lm_policy.transformer.h.5.mlp.c_proj.bias', 'lm_policy.transformer.h.6.ln_1.weight', 'lm_policy.transformer.h.6.ln_1.bias', 'lm_policy.transformer.h.6.attn.bias', 'lm_policy.transformer.h.6.attn.masked_bias', 'lm_policy.transformer.h.6.attn.c_attn.weight', 'lm_policy.transformer.h.6.attn.c_attn.bias', 'lm_policy.transformer.h.6.attn.c_proj.weight', 'lm_policy.transformer.h.6.attn.c_proj.bias', 'lm_policy.transformer.h.6.ln_2.weight', 'lm_policy.transformer.h.6.ln_2.bias', 'lm_policy.transformer.h.6.mlp.c_fc.weight', 'lm_policy.transformer.h.6.mlp.c_fc.bias', 'lm_policy.transformer.h.6.mlp.c_proj.weight', 'lm_policy.transformer.h.6.mlp.c_proj.bias', 'lm_policy.transformer.h.7.ln_1.weight', 'lm_policy.transformer.h.7.ln_1.bias', 'lm_policy.transformer.h.7.attn.bias', 'lm_policy.transformer.h.7.attn.masked_bias', 'lm_policy.transformer.h.7.attn.c_attn.weight', 'lm_policy.transformer.h.7.attn.c_attn.bias', 'lm_policy.transformer.h.7.attn.c_proj.weight', 'lm_policy.transformer.h.7.attn.c_proj.bias', 'lm_policy.transformer.h.7.ln_2.weight', 'lm_policy.transformer.h.7.ln_2.bias', 'lm_policy.transformer.h.7.mlp.c_fc.weight', 'lm_policy.transformer.h.7.mlp.c_fc.bias', 'lm_policy.transformer.h.7.mlp.c_proj.weight', 'lm_policy.transformer.h.7.mlp.c_proj.bias', 'lm_policy.transformer.h.8.ln_1.weight', 'lm_policy.transformer.h.8.ln_1.bias', 'lm_policy.transformer.h.8.attn.bias', 'lm_policy.transformer.h.8.attn.masked_bias', 'lm_policy.transformer.h.8.attn.c_attn.weight', 'lm_policy.transformer.h.8.attn.c_attn.bias', 'lm_policy.transformer.h.8.attn.c_proj.weight', 'lm_policy.transformer.h.8.attn.c_proj.bias', 'lm_policy.transformer.h.8.ln_2.weight', 'lm_policy.transformer.h.8.ln_2.bias', 'lm_policy.transformer.h.8.mlp.c_fc.weight', 'lm_policy.transformer.h.8.mlp.c_fc.bias', 'lm_policy.transformer.h.8.mlp.c_proj.weight', 'lm_policy.transformer.h.8.mlp.c_proj.bias', 'lm_policy.transformer.h.9.ln_1.weight', 'lm_policy.transformer.h.9.ln_1.bias', 'lm_policy.transformer.h.9.attn.bias', 'lm_policy.transformer.h.9.attn.masked_bias', 'lm_policy.transformer.h.9.attn.c_attn.weight', 'lm_policy.transformer.h.9.attn.c_attn.bias', 'lm_policy.transformer.h.9.attn.c_proj.weight', 'lm_policy.transformer.h.9.attn.c_proj.bias', 'lm_policy.transformer.h.9.ln_2.weight', 'lm_policy.transformer.h.9.ln_2.bias', 'lm_policy.transformer.h.9.mlp.c_fc.weight', 'lm_policy.transformer.h.9.mlp.c_fc.bias', 'lm_policy.transformer.h.9.mlp.c_proj.weight', 'lm_policy.transformer.h.9.mlp.c_proj.bias', 'lm_policy.transformer.h.10.ln_1.weight', 'lm_policy.transformer.h.10.ln_1.bias', 'lm_policy.transformer.h.10.attn.bias', 'lm_policy.transformer.h.10.attn.masked_bias', 'lm_policy.transformer.h.10.attn.c_attn.weight', 'lm_policy.transformer.h.10.attn.c_attn.bias', 'lm_policy.transformer.h.10.attn.c_proj.weight', 'lm_policy.transformer.h.10.attn.c_proj.bias', 'lm_policy.transformer.h.10.ln_2.weight', 'lm_policy.transformer.h.10.ln_2.bias', 'lm_policy.transformer.h.10.mlp.c_fc.weight', 'lm_policy.transformer.h.10.mlp.c_fc.bias', 'lm_policy.transformer.h.10.mlp.c_proj.weight', 'lm_policy.transformer.h.10.mlp.c_proj.bias', 'lm_policy.transformer.h.11.ln_1.weight', 'lm_policy.transformer.h.11.ln_1.bias', 'lm_policy.transformer.h.11.attn.bias', 'lm_policy.transformer.h.11.attn.masked_bias', 'lm_policy.transformer.h.11.attn.c_attn.weight', 
'lm_policy.transformer.h.11.attn.c_attn.bias', 'lm_policy.transformer.h.11.attn.c_proj.weight', 'lm_policy.transformer.h.11.attn.c_proj.bias', 'lm_policy.transformer.h.11.ln_2.weight', 'lm_policy.transformer.h.11.ln_2.bias', 'lm_policy.transformer.h.11.mlp.c_fc.weight', 'lm_policy.transformer.h.11.mlp.c_fc.bias', 'lm_policy.transformer.h.11.mlp.c_proj.weight', 'lm_policy.transformer.h.11.mlp.c_proj.bias', 'lm_policy.transformer.ln_f.weight', 'lm_policy.transformer.ln_f.bias', 'lm_policy.lm_head.weight'])

They end up renamed like this: it seems they were all originally of the form model.(...)?


Starting IQL

Before diving in

1) python train_iql.py model.load.checkpoint_path=outputs/visdial/model_converted.pkl model.load.strict_load=false train.loss.awac_weight=0.0

trains on top of the already-trained BC model, and

2) plain python train_iql.py (awac_weight = 1.0) trains BC and offline RL at the same time

From here on, assume the case of training both together

The command actually used

python train_iql.py model.load.checkpoint_path=outputs/visual_dialogue/visdial_bc_test1/model_converted.pkl model.load.strict_load=false train.loss.awac_weight=0.0

For now, assume the case where BC training has finished and its checkpoint is loaded

The model loading part is where it differs

src/load_objects.py

@register('per_token_iql')
def load_per_token_iql(config, device, verbose=True):
    gpt2 = load_item(config['gpt2'], verbose=verbose)
    dataset = load_item(config['dataset'], device, verbose=verbose)
    model = PerTokenIQL(gpt2, dataset, device, config['alpha'], config['gamma'], 
                        config['beta'], config['transition_weight'], config['clip_weight'], 
                        config['value_max'], config['value_min'], config['detach_v'], 
                        config['detach_pi'], config['detach_q'], config['double_q'], 
                        config['tau'], config['seperate_policy'], config['seperate_target'], 
                        config['exp_weights'], config['dm_margin'], config['advanced_mlp'], 
                        config['cql_temp'])
    return load_model(config['load'], model, device, verbose=verbose)
  • gpt2 is a GPT2LMHeadModel (same as in BC)
  • PerTokenIQL (in src/models/iql_model.py) uses a 3-layer MLP for the q and v heads if advanced_mlp is True, and a 2-layer MLP if it is False (default = False)
  • If self.double_q is True, a network is assigned to self.q2 as well (in that case a target_q2 network is also created)
  • A target_q network is always created regardless (with double_q = True, two targets are created)
  • The q network and target q network weights are initialized to be identical (see the sketch below)
  • If self.seperate_target is True, self.lm_target is created (with the same weights as self.model); I don't see the intent yet
  • Likewise, if self.seperate_policy is True, self.lm_policy is created (also initialized with the same weights as self.model), intent unclear → this one is later loaded with the self.model (GPT2LMHeadModel) weights trained with BC
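
A minimal sketch of the 'target starts equal to online' initialization for the q heads; the 2-layer MLP here is a hypothetical stand-in for the advanced_mlp=False head:

import copy
import torch.nn as nn

h_dim = 768
q = nn.Sequential(nn.Linear(h_dim, h_dim * 2), nn.ReLU(), nn.Linear(h_dim * 2, 1))  # hypothetical head
target_q = copy.deepcopy(q)               # same architecture
target_q.load_state_dict(q.state_dict())  # identical weights at initialization

# during training the target is then periodically synced / softly updated toward q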

Having built the PerTokenIQL instance,

src/load_objects.py

def load_model(config, model, device, verbose=True):
    model = model.to(device)
    if config['checkpoint_path'] is not None:
        if verbose:
            print('loading %s state dict from: %s' % (config['name'], convert_path(config["checkpoint_path"])))
        model.load_state_dict(torch.load(convert_path(config['checkpoint_path']), map_location='cpu'), strict=config['strict_load'])
        if verbose:
            print('loaded.')
    return model
  • The checkpoint at checkpoint_path is loaded into PerTokenIQL's lm_policy
    • (Earlier, all the state dict keys were renamed from model to lm_policy)
@register('per_token_iql')
def load_per_token_iql(config, device, verbose=True):
    gpt2 = load_item(config['gpt2'], verbose=verbose)
    dataset = load_item(config['dataset'], device, verbose=verbose)
    model = PerTokenIQL(gpt2, dataset, device, config['alpha'], config['gamma'], 
                        config['beta'], config['transition_weight'], config['clip_weight'], 
                        config['value_max'], config['value_min'], config['detach_v'], 
                        config['detach_pi'], config['detach_q'], config['double_q'], 
                        config['tau'], config['seperate_policy'], config['seperate_target'], 
                        config['exp_weights'], config['dm_margin'], config['advanced_mlp'], 
                        config['cql_temp'])
    **return load_model(config['load'], model, device, verbose=verbose)**
  • The load described above is what runs here

scripts/train/iql_train_loop.py

**loss, logs, postproc_fs = accelerator.unwrap_model(model).get_loss(items, **train_cfg['loss'])**

src/models/iql_model.py

def get_loss(self, 
                 items: InputType, 
                 awac_weight=0.0, 
                 v_loss_weight=0.0, # 1.0
                 q_loss_weight=0.0, # 1.0
                 cql_loss_weight=0.0, # 1.0
                 dm_loss_weight=0.0, 
                 mc_returns=False):
        prepared_inputs = self.prepare_inputs(items)
        a_idx = prepared_inputs['action_idxs']
        **get_qvs_outputs = self.get_qvs(items, 
                                       qv_kwargs={'output_attentions': True}, 
                                       policy_kwargs={'output_attentions': True}, 
                                       target_kwargs={'output_attentions': True}, 
                                       skip_policy_on_train=(awac_weight == 0.0), 
                                      )**
        tokens, attn_mask, model_outputs = get_qvs_outputs['tokens'], get_qvs_outputs['attn_mask'], get_qvs_outputs['model_outputs']
        vs, qs = get_qvs_outputs['vs'], get_qvs_outputs['qs']
        vns, target_qs, rs = get_qvs_outputs['vns'], get_qvs_outputs['target_qs'], get_qvs_outputs['rs']
        terminals, logits, weights = get_qvs_outputs['terminals'], get_qvs_outputs['logits'], get_qvs_outputs['weights']
  • prepared_inputs uses the prepare_inputs described earlier
  • self.prepare_inputs actually does nothing and returns the items as-is
def get_qvs(self, items: InputType, 
                prefix_embs: Optional[torch.Tensor]=None, 
                prefix_attn_mask: Optional[torch.Tensor]=None, 
                remove_prefix_position_embs: bool=False, 
                qv_kwargs=None, policy_kwargs=None, target_kwargs=None, 
                **kwargs):
        prepared_inputs = self.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        s_idx, a_idx = prepared_inputs['state_idxs'], prepared_inputs['action_idxs']
        rs, terminals = prepared_inputs['rewards'], prepared_inputs['terminals']
        **self_outputs = self(tokens, attn_mask, s_idx, a_idx, 
                            prefix_embs, prefix_attn_mask, 
                            remove_prefix_position_embs, 
                            qv_kwargs, policy_kwargs, target_kwargs, 
                            **kwargs)**
def forward(self, 
                tokens: torch.Tensor, 
                attn_mask: Optional[torch.Tensor], 
                state_idxs: torch.Tensor, 
                action_idxs: torch.Tensor, 
                prefix_embs: Optional[torch.Tensor]=None, 
                prefix_attn_mask: Optional[torch.Tensor]=None, 
                remove_prefix_position_embs: bool=False, 
                qv_kwargs=None, policy_kwargs=None, target_kwargs=None, 
                skip_policy_on_train=False, 
                detach_full_policy=False):
        if qv_kwargs is None:
            qv_kwargs = {}
        if target_kwargs is None:
            target_kwargs = {}
        if policy_kwargs is None:
            policy_kwargs = {}
        if self.lm_target is None: # not taken
            qv_kwargs.update(target_kwargs)
        if self.lm_policy is None: # not taken
            qv_kwargs.update(policy_kwargs)
        if attn_mask is None: # not taken
            attn_mask = torch.ones(tokens.shape, dtype=torch.long).to(self.device)
        **if prefix_embs is None:
            prefix_embs = torch.empty((tokens.shape[0], 0, self.h_dim)).to(self.device)**
        prefix_t = prefix_embs.shape[1]
        set_pos_ids = prefix_attn_mask is not None
        if prefix_attn_mask is None:
            prefix_attn_mask = torch.ones(prefix_embs.shape[:2]).to(self.device)
        input_attn_mask = torch.cat((prefix_attn_mask, attn_mask), dim=1)
        position_ids = torch.cumsum(input_attn_mask, dim=1)-1 if set_pos_ids else None
  • qv_kwargs: {'output_attentions': True}
  • target_kwargs: {'output_attentions': True}
  • policy_kwargs: {'output_attentions': True}
  • self.lm_target and self.lm_policy hold the same weights as self.model (GPT2LMHeadModel)
  • attn_mask is already given
  • prefix_embs is None, so after the if branch it becomes a (1, 0, 768) tensor
  • prefix_t: 0, set_pos_ids = False
  • prefix_attn_mask: shape [1, 0] → effectively empty
  • input_attn_mask is identical to attn_mask
  • position_ids = None
        # def forward
        if isinstance(self.model, GPT2Model):
            transformer = self.model
            if self.lm_target is not None:
                target_transformer = self.lm_target
            if self.lm_policy is not None:
                policy_transformer = self.lm_policy
        **elif isinstance(self.model, GPT2LMHeadModel):
            transformer = self.model.transformer
            if self.lm_target is not None:
                target_transformer = self.lm_target.transformer
            if self.lm_policy is not None:
                policy_transformer = self.lm_policy.transformer**
        else:
            raise NotImplementedError
        **if self.lm_target is not None:
            target_prefix_embs = prefix_embs.clone() # [1, 0, 768]
        if self.lm_policy is not None:
            policy_prefix_embs = prefix_embs.clone() # [1, 0, 768]**
        if remove_prefix_position_embs:
            prefix_embs -= transformer.wpe(position_ids[:, :prefix_embs.shape[1]])
        input_embeddings = torch.cat((prefix_embs, transformer.wte(tokens)), dim=1)
        model_outputs = self.model(inputs_embeds=input_embeddings, 
                                   attention_mask=input_attn_mask, 
                                   position_ids=position_ids, #None
                                   output_hidden_states=True, 
                                   **qv_kwargs)
        all_model_outputs = {
                                'qv_model_outputs': model_outputs, 
                                'policy_model_outputs': model_outputs, 
                                'target_model_outputs': model_outputs
                            }
  • GPT2LMHeadModel is GPT2Model + an LM head; here the transformer variable is assigned self.model.transformer
  • lm_target and lm_policy are both not None, so their transformers are assigned to target_transformer and policy_transformer
  • prefix_embs is copied into target_prefix_embs and policy_prefix_embs
  • remove_prefix_position_embs is not used
  • input_embeddings equals transformer.wte(tokens)
  • i.e. instead of passing (tokens, attn_mask) as before, we pass (wte(tokens), attn_mask); a sanity check follows below
    • when the token shape is [1, 112], input_embeddings is [1, 112, 768]
  • the keys of model_outputs are ['logits', 'past_key_values', 'hidden_states', 'attentions']
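
A quick standalone sanity check (Hugging Face GPT-2, not repo code) that passing inputs_embeds=wte(tokens) is equivalent to passing input_ids=tokens, which is exactly the swap this forward performs:

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2').eval()  # eval() turns dropout off
tokens = torch.tensor([[50256, 318, 257, 1332]])         # arbitrary token ids

with torch.no_grad():
    out_ids = model(input_ids=tokens)
    embeds = model.transformer.wte(tokens)               # the same lookup the repo does
    out_emb = model(inputs_embeds=embeds)

# identical logits: position embeddings are added inside the model in both cases
print(torch.allclose(out_ids.logits, out_emb.logits, atol=1e-5))  # True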
        # def forward
        if self.advanced_mlp:
            hidden_states = model_outputs.hidden_states[-2][:, prefix_t:, :]
        else:
            **hidden_states = model_outputs.hidden_states[-1][:, prefix_t:, :]**
        if self.lm_target is None:
            target_hidden_states = hidden_states
        else:
            if remove_prefix_position_embs:
                target_prefix_embs -= target_transformer.wpe(position_ids[:, :prefix_embs.shape[1]])
            **target_input_embeddings = torch.cat((target_prefix_embs, target_transformer.wte(tokens)), dim=1)**
            **with torch.no_grad():
                target_outputs = self.lm_target(inputs_embeds=target_input_embeddings, 
                                                attention_mask=input_attn_mask, 
                                                position_ids=position_ids, 
                                                output_hidden_states=True, 
                                                **target_kwargs)**
            all_model_outputs['target_model_outputs'] = target_outputs
            if self.advanced_mlp:
                target_hidden_states = target_outputs.hidden_states[-2][:, prefix_t:, :]
            else:
                **target_hidden_states = target_outputs.hidden_states[-1][:, prefix_t:, :]**
  • model_outputs.hidden_states is a tuple of the embedding output plus each layer's output
    • each element has shape (batch_size, sequence_len, hidden_dim)
    • it is only returned when output_hidden_states=True is passed
  • the bold branch executes
  • self.lm_target is run forward (without gradients), just like self.model
  • shape of hidden_states : [1, seq_len, 768]
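
To see the hidden_states structure concretely, a standalone check against stock GPT-2 small (not repo code): output_hidden_states=True returns the embedding output plus one tensor per block, and hidden_states[-2] is what the advanced_mlp branch picks.

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
tokens = torch.randint(0, 50257, (1, 112))              # [batch, seq_len]

with torch.no_grad():
    out = model(input_ids=tokens, output_hidden_states=True)

print(len(out.hidden_states))        # 13 = embedding output + 12 blocks
print(out.hidden_states[-1].shape)   # torch.Size([1, 112, 768])
# hidden_states[-2] is the input to the last block, i.e. the advanced_mlp choice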
        if self.lm_policy is None:
            if isinstance(self.model, GPT2Model):
                policy_hidden_states = hidden_states
            else:
                policy_hidden_states = model_outputs.hidden_states[-1][:, prefix_t:, :]
        **else:**
            if skip_policy_on_train and self.training:
                **policy_hidden_states = hidden_states**
            else:
                if remove_prefix_position_embs:
                    policy_prefix_embs -= policy_transformer.wpe(position_ids[:, :prefix_embs.shape[1]])
                policy_input_embeddings = torch.cat((policy_prefix_embs, policy_transformer.wte(tokens)), dim=1)
                if detach_full_policy:
                    with torch.no_grad():
                        policy_outputs = self.lm_policy(inputs_embeds=policy_input_embeddings, 
                                                        attention_mask=input_attn_mask, 
                                                        position_ids=position_ids, 
                                                        output_hidden_states=True, 
                                                        **policy_kwargs)
                else:
                    policy_outputs = self.lm_policy(inputs_embeds=policy_input_embeddings, 
                                                        attention_mask=input_attn_mask, 
                                                        position_ids=position_ids, 
                                                        output_hidden_states=True, 
                                                        **policy_kwargs)
                all_model_outputs['policy_model_outputs'] = policy_outputs
                if isinstance(self.model, GPT2Model):
                    if self.advanced_mlp:
                        policy_hidden_states = policy_outputs.hidden_states[-2][:, prefix_t:, :]
                    else:
                        policy_hidden_states = policy_outputs.hidden_states[-1][:, prefix_t:, :]
                else:
                    policy_hidden_states = policy_outputs.hidden_states[-1][:, prefix_t:, :]
  • skip_policy_on_train and self.training are both True
        state_hidden_states = torch.gather(input=hidden_states, dim=1, index=state_idxs.unsqueeze(2).repeat(1, 1, self.h_dim))
        action_hidden_states = torch.gather(input=hidden_states, dim=1, index=action_idxs.unsqueeze(2).repeat(1, 1, self.h_dim))
        action_target_hidden_states = torch.gather(input=target_hidden_states, dim=1, index=action_idxs.unsqueeze(2).repeat(1, 1, self.h_dim))
        vs = self.v(state_hidden_states.detach() if self.detach_v else state_hidden_states).squeeze(2)
        qs = self.q(action_hidden_states.detach() if self.detach_q else action_hidden_states)

ipdb> print(state_idxs)
tensor([[ 11, 12, 13, 14, 15, 22, 23, 24, 25, 26, 27, 30, 31, 32,
33, 34, 35, 38, 39, 40, 41, 42, 43, 44, 52, 53, 54, 55,
56, 57, 58, 59, 60, 63, 64, 65, 66, 67, 68, 71, 72, 73,
74, 75, 76, 79, 80, 81, 82, 83, 84, 85, 86, 87, 90, 91,
92, 93, 94, 95, 98, 99, 100, 101, 102, 103, 107]],
device='cuda:0')
ipdb> print(action_idxs)
tensor([[ 11, 12, 13, 14, 15, 22, 23, 24, 25, 26, 27, 30, 31, 32,
33, 34, 35, 38, 39, 40, 41, 42, 43, 44, 52, 53, 54, 55,
56, 57, 58, 59, 60, 63, 64, 65, 66, 67, 68, 71, 72, 73,
74, 75, 76, 79, 80, 81, 82, 83, 84, 85, 86, 87, 90, 91,
92, 93, 94, 95, 98, 99, 100, 101, 102, 103]], device='cuda:0')

  • state_idxs and action_idxs are identical except that state_idxs has one extra final index
  • state_hidden_states.shape : [1, 67, 768] // action_hidden_states.shape : [1, 66, 768]
  • detach_v and detach_q are both False (see the toy gather sketch below)
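
A toy version (made-up sizes) of the gather pattern used for state_hidden_states / action_hidden_states: the index is expanded to [batch, n_idx, h_dim] so that whole hidden vectors are copied out at the chosen positions.

import torch

h_dim = 4                                 # toy hidden size (768 in the repo)
hidden_states = torch.arange(2 * 5 * h_dim, dtype=torch.float).reshape(2, 5, h_dim)
state_idxs = torch.tensor([[0, 2, 4],     # per-batch sequence positions to pull out
                           [1, 3, 3]])

index = state_idxs.unsqueeze(2).repeat(1, 1, h_dim)          # [2, 3, 4]
state_hidden_states = torch.gather(hidden_states, dim=1, index=index)

print(state_hidden_states.shape)                             # torch.Size([2, 3, 4])
print(torch.equal(state_hidden_states[0, 1], hidden_states[0, 2]))  # True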
        if self.double_q:
            qs2 = self.q2(action_hidden_states.detach() if self.detach_q else action_hidden_states)
        with torch.no_grad():
            target_qs = self.target_q(action_target_hidden_states)
            if self.double_q:
                target_qs2 = self.target_q2(action_target_hidden_states)
        if skip_policy_on_train and self.training and self.lm_policy is not None:
            **logits = torch.zeros((policy_hidden_states.shape[0],policy_hidden_states.shape[1],self.dataset.tokenizer.num_tokens(),)).to(self.device)**
        else:
            if detach_full_policy:
                with torch.no_grad():
                    logits = self.pi(policy_hidden_states.detach() if self.detach_pi else policy_hidden_states)
            else:
                logits = self.pi(policy_hidden_states.detach() if self.detach_pi else policy_hidden_states)
        return  {
                    'model_outputs': all_model_outputs, 
                    'vs': vs, 
                    'target_vs': vs, 
                    'qs': (qs, qs2,) if self.double_q else qs, 
                    'target_qs': self.clip_values(torch.minimum(target_qs, target_qs2) if self.double_q else target_qs), 
                    'logits': logits, 
                }
  • double_q = True
  • logits shape = [1, seq_len, 50264], but all zeros in this case
  • because policy_hidden_states was simply set to hidden_states earlier (skip_policy_on_train branch)

Back to get_qvs

        # def get_qvs
        self_outputs = self(tokens, attn_mask, s_idx, a_idx, 
                            prefix_embs, prefix_attn_mask, 
                            remove_prefix_position_embs, 
                            qv_kwargs, policy_kwargs, target_kwargs, 
                            **kwargs)
        model_outputs, vs, qs = self_outputs['model_outputs'], self_outputs['vs'], self_outputs['qs']
        target_qs, logits = self_outputs['target_qs'], self_outputs['logits']
        vt = vs[:, :-1] # [1, 66]
        vtp1 = vs[:, 1:] 
        select_tokens = torch.gather(tokens[:, 1:], dim=1, index=a_idx)
        cql_term = self.get_cql_loss(qs, select_tokens, terminals)
    
  • select_tokens : from tokens[:, 1:] of shape [1, 107], indexing with the [1, 66] a_idx extracts a [1, 66] tensor
  • per p.5 of the paper, the CQL loss is 'adding a small amount of NLL loss to the Q values'; get_cql_loss computes the cross entropy between qs and select_tokens
  • select_tokens is the device that turns the Q network output of shape [batch_size, seq_len, action_dim] into [batch_size, seq_len] (see the summary below if this is fuzzy); a toy sketch follows the code
def get_cql_loss(self, qs, action_tokens, terminals):
        n = (1 - terminals[:, :-1]).sum()
        if self.double_q:
            q1, q2 = qs
            b, t, d = q1.shape
            return ((F.cross_entropy(q1.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction='none').reshape(b, t) * (1 - terminals[:, :-1])) + (F.cross_entropy(q2.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction='none').reshape(b, t) * (1 - terminals[:, :-1]))).sum() / max(n.item(), 1.0)
        b, t, d = qs.shape
        return (F.cross_entropy(qs.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction='none').reshape(b, t) * (1 - terminals[:, :-1])).sum() / max(n.item(), 1.0)
  • qs[0].shape : [1, 66, 50264] // action_tokens.shape : [1, 66]
  • terminals : [1, 67] // only the last position carries a 1
  • the result is a scalar
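
A minimal sketch of the CQL term with toy shapes (mirroring the single-Q branch above, not the double-Q path): the per-token Q head output is treated as logits and pushed via cross-entropy toward the action token actually taken, with terminal steps masked out.

import torch
import torch.nn.functional as F

b, t, d = 1, 66, 50264                   # shapes from the bullets above
qs = torch.randn(b, t, d)                # Q estimates over the whole vocab
action_tokens = torch.randint(0, d, (b, t))
terminals = torch.zeros(b, t + 1)        # only the last step is terminal
terminals[:, -1] = 1.0
cql_temp = 1.0

mask = 1 - terminals[:, :-1]             # [b, t]; drops terminal positions
n = mask.sum()
ce = F.cross_entropy(qs.reshape(-1, d) / cql_temp,
                     action_tokens.reshape(-1),
                     reduction='none').reshape(b, t)
cql_term = (ce * mask).sum() / max(n.item(), 1.0)
print(cql_term)                          # a scalar, as noted above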
        full_qs = qs
        if self.double_q:
            q1, q2 = qs
            q1 = torch.gather(q1, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
            q2 = torch.gather(q2, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
            # tok_seq = [self.dataset.tokenizer.id_to_token(token) for token in select_tokens[0].detach().cpu().tolist()][:(1-terminals[0, :-1]).sum()]
            # max_q_seq = torch.max(q1, q2)[0, :(1-terminals[0, :-1]).sum()].detach().cpu().tolist()
            # print(self.dataset.tokenizer.decode(tokens[0, :][:attn_mask[0, :].sum().long()].tolist(), clean_up_tokenization_spaces=False))
            # print(list(zip(tok_seq, max_q_seq)))
            # print(rs)
            qs = (q1, q2,)
        else:
            qs = torch.gather(qs, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
        dm_term = self.get_dm_loss(full_qs, qs, terminals, self.dm_margin)
        target_qs = torch.gather(target_qs, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
        with torch.no_grad():
            weights = self.get_weights(tokens, vt, target_qs, s_idx, a_idx, terminals)
        return {
                    'tokens': tokens, 
                    'attn_mask': attn_mask, 
                    'model_outputs': model_outputs, 
                    'vs': vt, 
                    'qs': qs, 
                    'vns': vtp1, 
                    'target_vs': vt, 
                    'target_qs': target_qs, 
                    'target_vns': vtp1, 
                    'rs': rs, 
                    'terminals': terminals, 
                    'logits': logits, 
                    'weights': weights, 
                    'cql_term': cql_term, 
                    'dm_term': dm_term, 
                }
  • q1 shape : [1, 66, 50264] // select_tokens : [1, 66]
  • the final q1 has shape [1, 66] and holds the predicted Q values of the chosen (action_idx) tokens
  • dm_loss appears to be a regularizing term that squares and sums the difference between the Q estimates over the full vocabulary and the Q estimate of the action actually taken (a margin-style penalty)
  • target_qs is likewise reduced to the Q values of the actually chosen actions, [1, 66]
  • weights are skipped for now: they are not needed when BC is run separately beforehand
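
The [batch, seq_len, vocab] → [batch, seq_len] reduction in one toy line: gather along the vocab dimension with the token that was actually taken (made-up shapes, not repo code).

import torch

b, t, vocab = 1, 66, 50264
q_full = torch.randn(b, t, vocab)                 # Q estimate for every candidate token
select_tokens = torch.randint(0, vocab, (b, t))   # the action token actually taken

q_taken = torch.gather(q_full, dim=2,
                       index=select_tokens.unsqueeze(2)).squeeze(2)
print(q_taken.shape)                              # torch.Size([1, 66]); same shape as V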

def get_loss()

def get_loss(self, 
                 items: InputType, 
                 awac_weight=0.0, 
                 v_loss_weight=0.0, 
                 q_loss_weight=0.0, 
                 cql_loss_weight=0.0, 
                 dm_loss_weight=0.0, 
                 mc_returns=False):
        prepared_inputs = self.prepare_inputs(items)
        a_idx = prepared_inputs['action_idxs']
        get_qvs_outputs = self.get_qvs(items, 
                                       qv_kwargs={'output_attentions': True}, 
                                       policy_kwargs={'output_attentions': True}, 
                                       target_kwargs={'output_attentions': True}, 
                                       skip_policy_on_train=(awac_weight == 0.0), 
                                      )
        tokens, attn_mask, model_outputs = get_qvs_outputs['tokens'], get_qvs_outputs['attn_mask'], get_qvs_outputs['model_outputs']
        vs, qs = get_qvs_outputs['vs'], get_qvs_outputs['qs']
        vns, target_qs, rs = get_qvs_outputs['vns'], get_qvs_outputs['target_qs'], get_qvs_outputs['rs']
        terminals, logits, weights = get_qvs_outputs['terminals'], get_qvs_outputs['logits'], get_qvs_outputs['weights']
        n = (1 - terminals[:, :-1]).sum().item()
        rs_downstream = self.get_downstream_rs(rs, self.gamma)
        if mc_returns:
            v_loss = self.get_v_loss(vs, rs_downstream, terminals)
        **else:
            v_loss = self.get_v_loss(vs, target_qs, terminals)**
        q_loss = self.get_q_loss(vns, qs, rs, self.gamma, terminals)
        cql_loss = get_qvs_outputs['cql_term']
        dm_loss = get_qvs_outputs['dm_term']
        token_loss = self.awac_loss(tokens, attn_mask, logits, weights)
        loss = awac_weight * token_loss + v_loss_weight * v_loss + q_loss_weight * q_loss + cql_loss_weight * cql_loss + dm_loss_weight * dm_loss
        return loss, logs, [postproc_f, hist_f]
  • the logging part is skipped here
  • rs_downstream is the reward propagated with the gamma discount factor (the downstream return; same shape as rs)
  • v_loss corresponds to the expectile value loss from the paper (the equation screenshot did not survive the export; a transcription follows below)
  • q_loss corresponds to the Bellman backup loss from the paper (screenshot likewise missing)
  • awac_weight : 0, dm_loss_weight : 0, the remaining weights are 1
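
Since the screenshots are gone, here is my transcription (from the IQL/ILQL papers) of the two losses that get_v_loss and get_q_loss implement, with expectile \tau and discount \gamma:

L_V(\psi) = \mathbb{E}_{(s,a) \sim \mathcal{D}}\left[ L_2^{\tau}\!\left(\hat{Q}_{\hat\theta}(s,a) - V_\psi(s)\right) \right], \qquad L_2^{\tau}(u) = \left|\tau - \mathbb{1}(u < 0)\right| u^2

L_Q(\theta) = \mathbb{E}_{(s,a,s') \sim \mathcal{D}}\left[ \left(r(s,a) + \gamma V_\psi(s') - Q_\theta(s,a)\right)^2 \right]

With mc_returns=True, the regression target in L_V is swapped for the discounted Monte-Carlo return rs_downstream, matching the if branch above.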

Evaluation part

iql_train_loop.py

evaluator_logs = evaluator.evaluate(accelerator.unwrap_model(model), eval_items)

src/models/iql_model.py

def evaluate(self, model: PerTokenIQL, items: InputType) -> Optional[Dict[str, Any]]:
        policy = IQL_Policy(model, self.kind, **self.generation_kwargs)
        tokens = model.prepare_inputs(items)['tokens']
        total_token_reward = 0
        total_env_reward = 0
        for i in range(tokens.shape[0]):
            **result, sequence = interact_environment(self.env, policy, None)**

src/data/language_environment.py

def interact_environment(env: Language_Environment, policy: Policy, obs: Optional[Language_Observation]=None):
    obs_sequence = []
    if obs is None:
        obs = env.reset()
    while not env.is_terminal():
        **action = policy.act(obs)**

src/models/iql_model.py

def act(self, obs: Language_Observation) -> str:
        item = DataPoint.from_obs(obs, self.iql_model.dataset.tokenizer, self.iql_model.dataset.token_reward)
        **generations, logprobs, kls = self.generate([item], always_terminate, **self.generation_kwargs)**
        self.kls_all.append(kls[0, 0].item())
        self.logprobs_all.append(logprobs[0, 0].item())
        return generations[0][1][0]
def generate(self, items: InputType, 
                 termination_condition: Callable[[np.ndarray], bool], **kwargs):
        prepared_inputs = self.iql_model.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        state_idxs, action_idxs = prepared_inputs['state_idxs'], prepared_inputs['action_idxs']
        if self.kind == 'beam':
            method = self.beam_raw
        **elif self.kind == 'sample':
            method = self.sample_raw**
        else:
            raise NotImplementedError
        generations, info, kls = method(tokens, attn_mask, 
                                             state_idxs, action_idxs, 
                                             termination_condition, 
                                             **kwargs)
        return generations, info, kls
def sample_raw(self, 
                   tokens: torch.Tensor, attn_mask: torch.Tensor, 
                   state_idxs: torch.Tensor, action_idxs: torch.Tensor, 
                   termination_condition: Callable[[np.ndarray], bool], 
                   num_generations=1, max_generation_len=None, 
                   temp=1.0, top_k=None, top_p=None, 
                   exp_adv=False, adv_weight=0.0, adv_clip=None, 
                   include_logits=True, include_adv=True, 
                   rerank_log_prob_weight: float=0.0, 
                   rerank_advantage_weight: float=0.0, 
                   prefix_embs: Optional[torch.Tensor]=None, 
                   prefix_attn_mask: Optional[torch.Tensor]=None, 
                   remove_prefix_position_embs: bool=False):
        assert include_logits or include_adv
        
        tokenizer = self.iql_model.dataset.tokenizer
        max_length = self.iql_model.dataset.max_len
        if max_length is None:
            max_length = self.iql_model.model.config.n_positions
        max_length = min(max_length, self.iql_model.model.config.n_positions)
        device = self.iql_model.device
        bsize = tokens.shape[0]
        n = bsize * num_generations
        if max_generation_len is None:
            max_generation_len = max_length+1
        input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
        model_outputs = self.iql_model(tokens, attn_mask, 
                                       state_idxs, action_idxs, 
                                       prefix_embs=prefix_embs, 
                                       prefix_attn_mask=prefix_attn_mask, 
                                       remove_prefix_position_embs=remove_prefix_position_embs, 
                                       qv_kwargs={'use_cache': True}, 
                                       policy_kwargs={'use_cache': True}, 
                                       target_kwargs={'use_cache': True})['model_outputs']
  • max_len : 1024, bsize : 1, num_generations : 1
  • ex input_strs : [' this is a school bus driving in the rain ']
  • prefix_t = 0
        kvs = {'qv': model_outputs['qv_model_outputs'].past_key_values}
        **if self.iql_model.lm_target is not None:
            kvs['target'] = model_outputs['target_model_outputs'].past_key_values
        if self.iql_model.lm_policy is not None:
            kvs['policy'] = model_outputs['policy_model_outputs'].past_key_values**
        dialogue_lens = attn_mask.sum(dim=1)
        tokens = pad_sequence(torch.repeat_interleave(tokens, num_generations, dim=0), max_length, tokenizer.pad_token_id, device, 1)
        dialogue_lens = torch.repeat_interleave(dialogue_lens, num_generations, dim=0)
        kvs['qv'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), kvs['qv'])
        if 'target' in kvs:
            kvs['target'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), kvs['target'])
        if 'policy' in kvs:
            kvs['policy'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), kvs['policy'])
        log_probs = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
        kls = torch.full((dialogue_lens.shape[0],), math.log(num_generations)-((num_generations-1)/num_generations)).to(device)
        advantages = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
        termination_mask = torch.full((dialogue_lens.shape[0],), 1).to(device)
        state_idxs_temp, action_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device), torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device)
        t = torch.min(dialogue_lens).int()
        base_logits = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
  • kvs['qv'] holds self.model's past_key_values; the same goes for 'target' and 'policy'
  • dialogue_lens : [11.]
  • tokens : [1, 11] → [1, 1024]
  • dialogue_lens shape : [1], unchanged
  • for map_all_kvs see the earlier sample_raw notes (it applies pad_sequence over the nested key/value list)
  • log_probs, kls, advantages and termination_mask are prepared with one entry per generation (toy sketch below)
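
A toy illustration of the repeat_interleave-plus-padding bookkeeping (using F.pad instead of the repo's pad_sequence helper, so treat it only as a sketch):

import torch
import torch.nn.functional as F

tokens = torch.tensor([[5, 6, 7]])          # [bsize=1, cur_len=3]
num_generations, max_length, pad_id = 2, 8, 0

# each prompt row is duplicated num_generations times along the batch dim
rep = torch.repeat_interleave(tokens, num_generations, dim=0)    # [2, 3]

# right-pad the sequence dimension up to max_length with the pad token
padded = F.pad(rep, (0, max_length - rep.shape[1]), value=pad_id)
print(padded.shape)   # torch.Size([2, 8])
print(padded)         # [[5, 6, 7, 0, 0, 0, 0, 0], [5, 6, 7, 0, 0, 0, 0, 0]]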
     while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
            curr_token = tokens[:, t-1].unsqueeze(1)
            curr_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['qv'])
            curr_target_kvs, curr_policy_kvs = curr_kvs, curr_kvs
            if 'target' in kvs:
                curr_target_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['target'])
            if 'policy' in kvs:
                curr_policy_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['policy'])
            iql_outputs = self.iql_model(curr_token, None, state_idxs_temp, action_idxs_temp, 
                                         qv_kwargs={'use_cache': True, 'past_key_values': curr_kvs}, 
                                         policy_kwargs={'use_cache': True, 'past_key_values': curr_policy_kvs}, 
                                         target_kwargs={'use_cache': True, 'past_key_values': curr_target_kvs})
            model_outputs, logits = iql_outputs['model_outputs'], iql_outputs['logits']
            
            logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
            logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
            edited_logits = process_logits(logits.clone(), temp=temp, top_k=top_k, top_p=top_p)

            vs, qs = iql_outputs['target_vs'], iql_outputs['target_qs']
            **if exp_adv:
                adv_logits = adv_weight * (qs - vs.unsqueeze(2))**
            else:
                adv_sign = ((qs - vs.unsqueeze(2)) > 0.0).float()
                adv_logits = adv_weight * adv_sign + (1 - adv_weight) * (1 - adv_sign)
                adv_logits = torch.log(adv_logits)
            if adv_clip is not None: # no running
                adv_logits = torch.clip(adv_logits, max=adv_clip)
            adv_logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
            adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)

            full_logits = (edited_logits if include_logits else 0.0) + (adv_logits if include_adv else 0.0) + base_logits.unsqueeze(1).unsqueeze(2)

            cat_dist = torch.distributions.categorical.Categorical(logits=full_logits[:, 0])
            original_cat_dist = torch.distributions.categorical.Categorical(logits=logits[:, 0])

            new_tokens = cat_dist.sample()
            log_probs += cat_dist.log_prob(new_tokens)
            kls += (cat_dist.log_prob(new_tokens) - original_cat_dist.log_prob(new_tokens))
            qs_chosen = torch.gather(qs.squeeze(1), dim=1, index=new_tokens.unsqueeze(1)).squeeze(1)
            advantages += (qs_chosen - vs.squeeze(1))
            tokens[:, t] = new_tokens
            kvs['qv'] = update_kvs(kvs['qv'], model_outputs['qv_model_outputs'].past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
            if 'target' in kvs:
                kvs['target'] = update_kvs(kvs['target'], model_outputs['target_model_outputs'].past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
            if 'policy' in kvs:
                kvs['policy'] = update_kvs(kvs['policy'], model_outputs['policy_model_outputs'].past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
            for idx in range(n):
                if tokens[idx, t] == tokenizer.eoa_token_id and t >= dialogue_lens[idx]:
                    termination_mask[idx] *= (1 - int(termination_condition(tokenizer.decode(tokens[idx, :].tolist(), 
                                                                                             clean_up_tokenization_spaces=False))))
            t += 1
            termination_mask *= ((t-dialogue_lens) < max_generation_len).int()
  • the forward pass uses curr_kvs, curr_policy_kvs and curr_target_kvs
  • if the termination mask is 1, the pad token gets a -inf logit so the sequence cannot end yet; otherwise it gets a huge logit so pad is forced and the finished sequence stays terminated
  • the line after that: since t starts at the minimum of dialogue_lens (and there can be several generations), positions still inside an already-given dialogue get a 1e7 logit so the original token is reproduced deterministically
  • process_logits applies temp, top_k and top_p → with the default settings nothing changes
  • vs : [1, 1], qs : [1, 1, 50264], adv_weight : 16
  • the advantage logits go through the same masking steps as the logits
  • the two are combined into full_logits; full_logits[:, 0].shape → [1, 50264]
  • cat_dist is built from full_logits, original_cat_dist from the raw logits (see the sketch below)
  • accumulate the log-probabilities and the per-step KL contribution
  • for the selected token, compute its Q value and the advantage against V
  • after writing the token into tokens, update the qv / target / policy key-value caches
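
The decoding step condensed into a standalone toy snippet (made-up shapes and values; only the exp_adv=True branch, with adv_weight playing the role of beta): the LM logits are shifted by beta*(Q - V) before sampling, and the KL against the unperturbed distribution is tracked.

import torch

vocab, adv_weight = 10, 16.0                  # toy vocab; adv_weight=16 as noted above
logits = torch.randn(1, vocab)                # base LM logits for the next token
qs = torch.randn(1, vocab)                    # Q(s, a) for every candidate token
vs = torch.randn(1, 1)                        # V(s)

adv_logits = adv_weight * (qs - vs)           # beta * (Q - V), broadcast over the vocab
full_logits = logits + adv_logits             # perturbed next-token distribution

cat_dist = torch.distributions.Categorical(logits=full_logits)
original_cat_dist = torch.distributions.Categorical(logits=logits)
new_token = cat_dist.sample()                           # shape [1]
log_prob = cat_dist.log_prob(new_token)                 # accumulated into log_probs
kl_step = cat_dist.log_prob(new_token) - original_cat_dist.log_prob(new_token)
advantage = qs[0, new_token] - vs[0, 0]                 # accumulated into advantages
print(new_token.item(), log_prob.item(), kl_step.item(), advantage.item())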
        scores = ((advantages * rerank_advantage_weight) + (log_probs * rerank_log_prob_weight)).reshape(-1, num_generations)
        order = torch.argsort(-scores, dim=1)
        output_strs = [tokenizer.decode(tokens[i, :].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        processed_outputs = []
        for i in range(len(input_strs)):
            temp_outputs = []
            for x in range(num_generations):
                processed_str = output_strs[i*num_generations+order[i, x]][len(input_strs[i]):].strip()
                if tokenizer.id_to_token(tokenizer.pad_token_id) in processed_str:
                    processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.pad_token_id))].strip()
                if tokenizer.id_to_token(tokenizer.eoa_token_id) in processed_str:
                    processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.eoa_token_id))].strip()
                temp_outputs.append(processed_str)
            processed_outputs.append(temp_outputs)
        scores = torch.gather(scores, dim=1, index=order)
        log_probs = torch.gather(log_probs.reshape(-1, num_generations), dim=1, index=order)
        kls = torch.gather(kls.reshape(-1, num_generations), dim=1, index=order)
        return list(zip(input_strs, processed_outputs)), log_probs.reshape(-1, num_generations), kls
  • rerank_advantage_weight : 1.0, rerank_log_prob_weight : 0
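
For completeness, a stripped-down toy version of the rerank bookkeeping at the end of sample_raw (scores are made up):

import torch

advantages = torch.tensor([[0.2, 1.5, -0.3]])   # [bsize=1, num_generations=3], toy values
log_probs  = torch.tensor([[-4.0, -7.0, -2.0]])

scores = advantages * 1.0 + log_probs * 0.0     # rerank_advantage_weight=1, log_prob_weight=0
order = torch.argsort(-scores, dim=1)           # best candidate first
print(order)                                    # tensor([[1, 0, 2]])
print(torch.gather(scores, dim=1, index=order)) # scores reordered to match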

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50264, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50264, bias=False)
)

Writing up a summary

What forward does
Pass the given item's token embeddings and attention mask through self.model (GPT2LMHeadModel) to produce hidden states

  • this runs for both self.model and self.lm_target (the latter under no_grad)
  • producing hidden_states and target_hidden_states

If the policy is not going to be trained, policy_hidden_states is simply set to hidden_states
From the hidden_states produced above, find where the action and state tokens sit and use torch.gather to build state_hidden_states and action_hidden_states
action_target_hidden_states is built from target_hidden_states the same way
state_hidden_states goes into the value network and action_hidden_states into the Q network, producing vs and qs
action_target_hidden_states goes into the target Q network, producing target_qs
logits is just a zero tensor of shape [batch, seq_len, total_token]
output : vs, target_vs (identical to vs), qs, target_qs (the element-wise minimum of target_qs and target_qs2, then clipped), logits (zero tensor)

<lm_policy does not really do anything here>

What get_qvs does
Pass the item through forward
Compute the CQL loss between the qs from forward and the tokens corresponding to the t+1-step actions
It is worth thinking again about the Q network here

  • it takes the action-position hidden states as input and produces a Q estimate for every possible action
  • we already know the actual action that comes at the next step
  • gathering with that action shrinks the [batch, seq_len, action_dim] Q-network output (passed on from forward) to [batch, seq_len, 1], giving it the same dimension as V

The model does not fit on a single GPU, so trying DeepSpeed

accelerate configuration saved at /home/doolee13/.cache/huggingface/accelerate/default_config.yaml

Command

accelerate launch train_iql_deepspeed.py

"fp16": {
"enabled": "true",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 15,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.0001
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 1.000000e+08,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1.000000e+08,
"contiguous_gradients": false
},
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false,
"gradient_accumulation_steps": 1,
"steps_per_print": inf,
"bf16": {
"enabled": false
}
}

ILQL beam_raw part

        tokenizer = self.iql_model.dataset.tokenizer
        max_length = self.iql_model.dataset.max_len
        if max_length is None:
            max_length = self.iql_model.model.config.n_positions
        max_length = min(max_length, self.iql_model.model.config.n_positions)
        device = self.iql_model.device
        # (2, 50264)
        bsize, vocab_size = tokens.shape[0], tokenizer.num_tokens()
        # 20
        n = bsize * beam_width
        if max_generation_len is None:
            max_generation_len = max_length+1
        input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
        model_outputs = self.iql_model(tokens, attn_mask, 
                                       state_idxs, action_idxs, 
                                       prefix_embs=prefix_embs, 
                                       prefix_attn_mask=prefix_attn_mask, 
                                       remove_prefix_position_embs=remove_prefix_position_embs, 
                                       qv_kwargs={'use_cache': True}, 
                                       policy_kwargs={'use_cache': True}, 
                                       target_kwargs={'use_cache': True})['model_outputs']
        kvs = {'qv': model_outputs['qv_model_outputs'].past_key_values}
        if self.iql_model.lm_target is not None:
            kvs['target'] = model_outputs['target_model_outputs'].past_key_values
        if self.iql_model.lm_policy is not None:
            kvs['policy'] = model_outputs['policy_model_outputs'].past_key_values
        # [bsize, ] : [158, 284]
        original_dialogue_lens = attn_mask.sum(dim=1)
        # [bsize, beam_width]
        # [0,0,0,0 .. ,0]
        # [1,1,1,1 .. ,1]
        # [bsize-1, ..., bsize-1]
        batch_indicator = torch.stack(beam_width*[torch.arange(0, bsize).to(device)], dim=1)
  • if the configured max_len is larger than GPT-2's, it is clamped to GPT-2's n_positions
  • input_strs decodes, per batch element, only the tokens covered by attn_mask (the sum of its 1s), and the model is then run forward
  • the past_key_values are stored so the key/value cache is ready for predicting the next token
  • original_dialogue_lens : [batch_size]
  • batch_indicator : [batch_size, beam_width]
    • arange(0, bsize) repeated beam_width times
        # [bsize, max among] -> [bsize * beam_width, max_len]
        # [2, 284] -> [20, 1024]
        tokens = pad_sequence(torch.repeat_interleave(tokens, beam_width, dim=0), max_length, tokenizer.pad_token_id, device, 1)
        # [bsize * beam_width]
        # [158, 158, 158, ..., 158, 284, 284, ..., 284]
        dialogue_lens = torch.repeat_interleave(original_dialogue_lens, beam_width, dim=0)
        kvs['qv'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2), kvs['qv'])
        if 'target' in kvs:
            kvs['target'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2), kvs['target'])
        if 'policy' in kvs:
            kvs['policy'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2), kvs['policy'])
        # [bsize, beam_width] = [2, 10]
        curr_scores = torch.zeros(bsize, beam_width).to(device)  # (batch, k)
        # [bsize, beam_width] = [2, 10]
        logit_scores = torch.zeros(bsize, beam_width).to(device)  # (batch, k)
        # [bsize*beam_width]
        termination_mask = torch.full((n,), 1).to(device)
        state_idxs_temp, action_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device), torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device)
        t = torch.min(dialogue_lens).int()
        # [bsize * beam_width]
        base_logits = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
  • torch.repeat_interleave repeats each of the [bsize, len] token rows beam_width times
    • tokens : result of shape [bsize * beam_width, len]
  • original_dialogue_lens is [bsize,] and is likewise repeated beam_width times
    • dialogue_lens has shape [bsize * beam_width,]
        while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
            # [bsize*beam_width, 1] = [20, 1]
            # looks like [a, a, a, a, a, ..., b, b, b, b]
            curr_token = tokens[:, t-1].unsqueeze(1)
            curr_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['qv'])
            curr_target_kvs, curr_policy_kvs = curr_kvs, curr_kvs
            if 'target' in kvs:
                curr_target_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['target'])
            if 'policy' in kvs:
                curr_policy_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['policy'])
            iql_outputs = self.iql_model(curr_token, None, state_idxs_temp, action_idxs_temp, 
                                         qv_kwargs={'use_cache': True, 'past_key_values': curr_kvs}, 
                                         policy_kwargs={'use_cache': True, 'past_key_values': curr_policy_kvs}, 
                                         target_kwargs={'use_cache': True, 'past_key_values': curr_target_kvs})
            model_outputs, logits = iql_outputs['model_outputs'], iql_outputs['logits']
            # logits shape : [20, 1, 50264]

            logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
            logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
            edited_logits = process_logits(logits.clone(), temp=temp, top_k=top_k, top_p=top_p)
  • t is the smallest of the dialogue_lens values
    • so at the first step only the shortest dialogue's curr_token actually points at its true last token
  • logits has shape [bsize * beam_width, 1, vocab] and we are working on timestep 0 of that slice
  • if the termination mask is 1, the pad token gets float('-inf') so the beam cannot end yet,
    • and if the termination mask is 0 and the beam must stay finished, the pad-token logit becomes 1e7 so pad is always chosen
  • masked_fill_ makes sentences whose given dialogue has not run out yet reproduce their original token
            vs, qs = iql_outputs['target_vs'], iql_outputs['target_qs']
            if exp_adv:
                adv_logits = adv_weight * (qs - vs.unsqueeze(2))
            else:
                adv_sign = ((qs - vs.unsqueeze(2)) > 0.0).float()
                adv_logits = adv_weight * adv_sign + (1 - adv_weight) * (1 - adv_sign)
                adv_logits = torch.log(adv_logits)
            if adv_clip is not None:
                adv_logits = torch.clip(adv_logits, max=adv_clip)
            adv_logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
            adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)

            full_logits = (edited_logits if include_logits else 0.0) + (adv_logits if include_adv else 0.0) + base_logits.unsqueeze(1).unsqueeze(2)
  • the same procedure is repeated with the trained vs and qs
  • the end result is full_logits = edited_logits + adv_logits : [bsize*beam_width, 1, vocab]
            scores = (torch.log(F.softmax(full_logits, dim=-1)).reshape(1, bsize, beam_width, -1).permute(3, 0, 1, 2) + curr_scores).permute(1, 2, 3, 0).reshape(1, bsize, -1)  # (time, batch, k*vocab)
            curr_scores, top_k_ = torch.topk(scores[0, :, :], k=beam_width, dim=1)  # (batch, k), (batch, k)
            tokens = tokens[(batch_indicator * beam_width + (top_k_ // vocab_size)).reshape(-1), :]
            tokens[:, t] = top_k_.reshape(-1) % vocab_size
  • full_logits : [20, 1, 50264] → reshaped to (1, 2, 10, 50264) → permuted to (50264, 1, 2, 10), curr_scores added → permuted back to (1, 2, 10, 50264) → reshaped to (1, 2, 502640)
  • curr_scores : [bsize, beam_width] = (2, 10), the running score of each of the k beams per batch
  • (50264, 1, 2, 10) is the scoreboard of what each beam would score for each candidate word
  • (1, 2, 502640) is, per batch, the score of every (word, beam) pair
    • the 50264 words for (batch = 0, k = 0), then the 50264 words for (batch = 0, k = 1), …
      • even if the top-k picks two entries from the same beam group it is fine: the past tokens are updated to that same beam and the step-t token differs via % vocab_size
  • top_k_ // vocab_size divides the index of each winning (word, beam) pair by the vocab size, so it recovers the parent beam
    • the beams for batch 0 come first, so batch_indicator supplies the per-batch row offset
    • batch_indicator is laid out as [0, 0, ..., 0, 1, 1, ..., 1], so after computing the top-k parent beams per [batch, k], reshape(-1) flattens everything into a single row-index list
    • tokens is [bsize * k, max_len]; since each batch proposed its k best beams, tokens is re-indexed with those rows (toy sketch below)
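
A toy version of that (beam, token) decoding trick (made-up vocab and beam sizes, not repo code): topk runs over the flattened k*vocab axis, integer division recovers the parent beam, and the modulo recovers the new token.

import torch

bsize, beam_width, vocab_size = 2, 3, 5
# per-beam, per-token cumulative scores, already flattened to [bsize, beam_width * vocab_size]
scores = torch.randn(bsize, beam_width * vocab_size)

curr_scores, top_k_ = torch.topk(scores, k=beam_width, dim=1)   # both [bsize, beam_width]
parent_beam = top_k_ // vocab_size            # which of the k beams each winner extends
new_token = top_k_ % vocab_size               # which token it appends at step t

batch_indicator = torch.stack(beam_width * [torch.arange(bsize)], dim=1)   # [bsize, beam_width]
flat_row = (batch_indicator * beam_width + parent_beam).reshape(-1)        # rows of the [bsize*k, len] token tensor
print(parent_beam)
print(new_token)
print(flat_row)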