scripts/train/bc_train_loop.py
raw_dataset_train = load_item(cfg['train_dataset'], system_cfg['device'])
src/load_objects.py
def load_item(config, *args, verbose=True):
config = config.copy()
name = config.pop('name')
if name not in registry:
raise NotImplementedError
if 'cache_id' in config:
if (name, config['cache_id']) in cache:
if verbose:
print(f'loading from cache ({name}, {config["cache_id"]})')
return cache[(name, config['cache_id'])]
if verbose:
print(f'loading {name}: {config}')
**item = registry[name](config, *args, verbose=verbose)**
if 'cache_id' in config:
print(f'saving to cache ({name}, {config["cache_id"]})')
cache[(name, config['cache_id'])] = item
return item
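For context, load_item is driven by a simple registry/cache pattern: each loader function is registered under a string name, a config dict selects a loader via its 'name' key, and 'cache_id' lets repeated configs reuse the same object. A minimal sketch of that pattern, paired with the load_item shown above (the register decorator and the module-level registry/cache dicts are inferred from how load_item uses them; the 'toy_dataset' loader is made up):

# minimal sketch of the registry pattern that load_item relies on
registry = {}
cache = {}

def register(name):
    def wrapper(fn):
        registry[name] = fn
        return fn
    return wrapper

@register('toy_dataset')
def load_toy_dataset(config, device, verbose=True):
    # stand-in loader; the real loaders build datasets, models, rewards, etc.
    return {'split': config['split'], 'device': device}

# 'name' picks the loader, the remaining keys are its config,
# and extra positional args (here the device) are passed through
item = load_item({'name': 'toy_dataset', 'split': 'train', 'cache_id': 'd0'}, 'cpu')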
src/visdial/load_objects.py
@register('vis_dial_list_dataset')
def load_vis_list_dataset(config, device, verbose=True):
**vd = load_item(config['data'], verbose=verbose)**
token_reward = load_item(config['token_reward'], device, verbose=verbose)
return VisDialListDataset(vd, max_len=config['max_len'],
token_reward=token_reward,
top_p=config['top_p'],
bottom_p=config['bottom_p'])
src/visdial/load_objects.py
def load_vis_dial(config, verbose=True):
if config['additional_scenes'] is not None:
with open(convert_path(config['additional_scenes']), 'rb') as f:
config['additional_scenes'] = pkl.load(f)
if config['cutoff_rule'] is not None:
**config['cutoff_rule'] = load_item(config['cutoff_rule'], verbose=verbose)**
return VisDialogueData(convert_path(config['data_path']),
convert_path(config['img_feat_path']),
config['split'],
reward_cache=convert_path(config['reward_cache']),
norm_img_feats=config['norm_img_feats'],
reward_shift=config['reward_shift'],
reward_scale=config['reward_scale'],
addition_scenes=config['additional_scenes'],
mode=config['mode'],
cutoff_rule=config['cutoff_rule'],
yn_reward=config['yn_reward'],
yn_reward_kind=config['yn_reward_kind'])
src/visdial/visdial_base.py
class PercentileCutoffRule:
def __init__(self, goal_value: float, percentile: float):
self.goal_value = goal_value
self.percentile = percentile
def apply_rule(self, scene: Scene, event: Event):
progress = sum([ev.progress for ev in event.get_events()]) / (self.goal_value-scene.initial_val)
return progress >= self.percentile
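In plain numbers, apply_rule stops the dialogue once the cumulative per-event progress covers at least `percentile` of the gap between the scene's initial value and the goal. A toy illustration of just the arithmetic (plain floats, not the real Scene/Event objects):

# toy illustration of PercentileCutoffRule.apply_rule
goal_value, percentile = 1.0, 0.5      # the defaults used below
initial_val = 0.2                      # scene.initial_val
event_progress = [0.10, 0.15, 0.20]    # progress of the events seen so far
progress = sum(event_progress) / (goal_value - initial_val)   # 0.45 / 0.8 = 0.5625
print(progress >= percentile)          # True -> the environment cuts the episode off here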
def load_vis_dial(config, verbose=True):
if config['additional_scenes'] is not None:
with open(convert_path(config['additional_scenes']), 'rb') as f:
config['additional_scenes'] = pkl.load(f)
if config['cutoff_rule'] is not None:
config['cutoff_rule'] = load_item(config['cutoff_rule'], verbose=verbose)
return **VisDialogueData(convert_path(config['data_path']),
convert_path(config['img_feat_path']),
config['split'],
reward_cache=convert_path(config['reward_cache']),
norm_img_feats=config['norm_img_feats'],
reward_shift=config['reward_shift'],
reward_scale=config['reward_scale'],
addition_scenes=config['additional_scenes'],
mode=config['mode'],
cutoff_rule=config['cutoff_rule'],
yn_reward=config['yn_reward'],
yn_reward_kind=config['yn_reward_kind'])**
src/visdial/visdial_base.py (VisDialogueData class __init__)
class VisDialogueData:
def __init__(self, data_path: str, img_feat_path: str,
split: str, reward_cache: Optional[str]=None,
norm_img_feats: bool=True, reward_shift: float=0.0,
reward_scale: float=1.0,
addition_scenes: Optional[List[Scene]]=None,
mode: str='env_stops',
cutoff_rule: Optional[CutoffRule]=None,
yn_reward: float=-2.0, yn_reward_kind: str='none'):
assert mode in ['agent_stops', 'env_stops', '10_stop']
assert yn_reward_kind in yn_reward_fs
if mode == 'env_stops':
if cutoff_rule is None:
cutoff_rule = PercentileCutoffRule(1.0, 0.5)
assert reward_cache is not None
yn_reward_f = yn_reward_fs[yn_reward_kind]
self.norm_img_feats = norm_img_feats
with open(data_path, 'r') as f:
data = json.load(f)
if reward_cache is not None:
with open(reward_cache, 'r') as f:
reward = json.load(f)
progress = reward
reward = [[item * reward_scale + reward_shift for item in rs[1:]] for rs in reward]
else:
progress = None
reward = None
img_feats = h5py.File(img_feat_path, 'r')['images_%s' % (split)]
if self.norm_img_feats:
img_feats = normalize(img_feats, axis=1, norm='l2')
assert len(img_feats) == len(data)
if mode == 'agent_stops':
self.scenes = sum([Scene.from_json_stops(data[i], img_feats[i],
reward if reward is None else reward[i],
progress[i] if progress is not None else None) for i in range(len(data))], [])
**elif mode == 'env_stops':
# maybe make reward 0 or -1 here
self.scenes = [Scene.from_json_cuttoff(data[i], img_feats[i],
progress[i] if progress is not None else None,
cutoff_rule, yn_reward, yn_reward_f) for i in range(len(data))]**
elif mode == '10_stop':
self.scenes = [Scene.from_json(data[i], img_feats[i],
reward if reward is None else reward[i],
progress[i] if progress is not None else None) for i in range(len(data))]
else:
raise NotImplementedError
if addition_scenes is not None:
self.scenes += addition_scenes
src/visdial/visdial_base.py
@dataclass
class Scene:
caption: str
img_feat: np.ndarray
events: List[Event]
initial_val: Optional[float]
cutoff_rule: Optional[CutoffRule]
@classmethod
def from_json_cuttoff(cls, scene_json, img_feat, progress, cutoff_rule, yn_reward, yn_reward_f):
caption = scene_json['caption']
events = []
for i in range(len(scene_json['dialog'])):
events.append(QuestionEvent(scene_json['dialog'][i]['question']+'?',
0.0, None, None, None))
events.append(AnswerEvent(scene_json['dialog'][i]['answer'],
-1.0 + (yn_reward if yn_reward_f is not None and yn_reward_f(scene_json['dialog'][i]['answer']) else 0.0),
0.0 if progress is None else progress[i+1],
None, None, None))
scene = cls(caption, img_feat, events, 0.0 if progress is None else progress[0], cutoff_rule)
for p, n in zip(events[:-1], events[1:]):
p.next = n
n.prev = p
for ev in events:
ev.scene = scene
for i, ev in enumerate(scene.events):
if isinstance(ev, AnswerEvent) and ev.is_final():
scene.events = scene.events[:(i+1)]
scene.events[-1].next = None
ev.reward = 0.0 + (yn_reward if yn_reward_f is not None and yn_reward_f(scene.events[-1].answer) else 0.0)
return scene
@register('vis_dial_list_dataset')
def load_vis_list_dataset(config, device, verbose=True):
vd = load_item(config['data'], verbose=verbose)
**token_reward = load_item(config['token_reward'], device, verbose=verbose)**
**return VisDialListDataset(vd, max_len=config['max_len'],
token_reward=token_reward,
top_p=config['top_p'],
bottom_p=config['bottom_p'])**
src/visdial/visdial_dataset.py
class VisDialListDataset(List_RL_Dataset):
def __init__(self, data: VisDialogueData,
max_len: Optional[int],
token_reward: TokenReward,
top_p: Optional[float]=None,
bottom_p: Optional[float]=None,
) -> None:
**tokenizer = VisDialTokenizer()**
super().__init__(tokenizer, token_reward, max_len)
self.data = data
self.datapoints = []
src/visdial/visdial_tokenizer.py
class VisDialTokenizer(Tokenizer):
def __init__(self):
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.tokenizer.add_special_tokens({'additional_special_tokens': ['</a>', '<a>', '<stop>', '</eod>'],
'bos_token': '<s>',
'sep_token': '</s>',
'pad_token': '<|pad|>'})
super().__init__(self.tokenizer.convert_tokens_to_ids('<|pad|>'),
self.tokenizer.convert_tokens_to_ids('</s>'),
self.tokenizer.convert_tokens_to_ids('</a>'),
self.tokenizer.convert_tokens_to_ids('<s>'),
self.tokenizer.convert_tokens_to_ids('<a>'),
self.tokenizer.convert_tokens_to_ids('</eod>'))
self.stop_token = self.tokenizer.convert_tokens_to_ids('<stop>')
class VisDialListDataset(List_RL_Dataset):
def __init__(self, data: VisDialogueData,
max_len: Optional[int],
token_reward: TokenReward,
top_p: Optional[float]=None,
bottom_p: Optional[float]=None,
) -> None:
tokenizer = VisDialTokenizer()
super().__init__(tokenizer, token_reward, max_len)
self.data = data
self.datapoints = []
**for item in self.data:
obs = VDObservation(item, item.events[-1])
self.datapoints.append(DataPoint.from_obs(obs, self.tokenizer, self.token_reward))**
src/data/rl_data.py
@classmethod
def from_obs(cls, obs: Language_Observation, tokenizer: Tokenizer, token_reward: TokenReward, meta: Optional[Dict[str, Any]]=None):
**sequence, terminal = obs.to_sequence()**
src/visdial/visdial_env.py
def to_sequence(self) -> Tuple[List[Tuple[str, Optional[float]]], bool]:
if self.event is None:
return [(self.scene.caption, None)], False
**evs = self.event.get_events()**
sequence = [(self.scene.caption, None)]
sequence += [(str(evs[i]), evs[i+1].reward if isinstance(evs[i+1], AnswerEvent) else None) for i in range(len(evs)-1)]
sequence += [(str(evs[-1]), 0.0 if isinstance(evs[-1], StopEvent) else None)]
terminal = self.event.is_final()
return sequence, terminal
src/visdial/visdial_base.py
def get_events(self, direction="prev"):
if direction == "prev":
func = lambda ev: ev.prev
elif direction == "next":
func = lambda ev: ev.next
else:
raise NotImplementedError
events = []
ev = self
while ev is not None:
events.append(ev)
ev = func(ev)
if direction == 'prev':
events.reverse()
return events
def to_sequence(self) -> Tuple[List[Tuple[str, Optional[float]]], bool]:
if self.event is None:
return [(self.scene.caption, None)], False
**evs = self.event.get_events()
**sequence = [(self.scene.caption, None)]
sequence += [(str(evs[i]), evs[i+1].reward if isinstance(evs[i+1], AnswerEvent) else None) for i in range(len(evs)-1)]
sequence += [(str(evs[-1]), 0.0 if isinstance(evs[-1], StopEvent) else None)]
terminal = self.event.is_final()**
return sequence, terminal
def from_obs(cls, obs: Language_Observation, tokenizer: Tokenizer, token_reward: TokenReward, meta: Optional[Dict[str, Any]]=None):
sequence, terminal = obs.to_sequence()
obs_meta = obs.metadata()
if meta is not None and obs_meta is not None:
meta = {**obs_meta, **meta}
**elif obs_meta is not None:
meta = obs_meta**
if len(sequence) == 0 or sequence[0][1] is not None:
raw_str = tokenizer.id_to_token(tokenizer.boa_token_id)
**else:
raw_str = tokenizer.id_to_token(tokenizer.bos_token_id)**
action_rewards = []
for s, r in sequence:
raw_str += s
if r is None:
raw_str += tokenizer.id_to_token(tokenizer.eos_token_id)
else:
raw_str += tokenizer.id_to_token(tokenizer.eoa_token_id)
action_rewards.append(r)
if terminal:
raw_str += tokenizer.id_to_token(tokenizer.eod_token_id)
tokens = tokenizer.encode(raw_str)[0]
token_rewards = token_reward.get_token_reward(tokens)
for i, t in enumerate(tokens):
if t == tokenizer.eos_token_id:
curr_idx = i
elif t == tokenizer.eoa_token_id:
action_idxs.extend(list(range(curr_idx, i)))
state_idxs.extend(list(range(curr_idx, i)))
reward.extend([token_rewards[x] for x in range(curr_idx, i)])
reward[-1] += action_rewards[curr_action_idx]
utterance_action_idxs.append(i)
utterance_state_idxs.append(curr_idx)
utterance_rewards.append(action_rewards[curr_action_idx]+sum([token_rewards[x] for x in range(curr_idx, i)]))
curr_idx = i
curr_action_idx += 1
**state_idxs.append(len(tokens)-1)**
utterance_state_idxs.append(len(tokens)-1)
terminals = ([0] * (len(state_idxs)-1))+[int(terminal)]
utterance_terminals = ([0] * (len(utterance_state_idxs)-1))+[int(terminal)]
return cls(raw_str, tokens, state_idxs, action_idxs, reward, terminals,
utterance_state_idxs, utterance_action_idxs,
utterance_rewards, utterance_terminals, meta=meta)
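Putting to_sequence and from_obs together: state utterances (reward None) are closed with </s>, agent utterances (reward present) are closed with </a>, and a terminal observation additionally gets </eod>. A rough sketch of the resulting raw_str for a made-up sequence, using the special-token strings from VisDialTokenizer above:

# sketch of how from_obs serializes a to_sequence() output into raw_str
sequence = [("a cat sitting on a sofa", None),   # caption -> state, closed with </s>
            ("is the cat sleeping ?", -1.0),     # agent question -> action, closed with </a>
            ("yes it is", None)]                 # env answer -> state, closed with </s>
terminal = False

raw_str = "<s>"                                  # bos, since the first element is a state
for s, r in sequence:
    raw_str += s + ("</s>" if r is None else "</a>")
if terminal:
    raw_str += "</eod>"
print(raw_str)
# <s>a cat sitting on a sofa</s>is the cat sleeping ?</a>yes it is</s>

The index bookkeeping that follows then records, for each agent utterance, the token positions running from the preceding separator token up to (but not including) its closing </a> as both state and action indices, which is what state_idxs/action_idxs end up holding.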
def load_item(config, *args, verbose=True):
config = config.copy()
name = config.pop('name')
if name not in registry:
raise NotImplementedError
if 'cache_id' in config:
if (name, config['cache_id']) in cache:
if verbose:
print(f'loading from cache ({name}, {config["cache_id"]})')
return cache[(name, config['cache_id'])]
if verbose:
print(f'loading {name}: {config}')
item = registry[name](config, *args, verbose=verbose)
if 'cache_id' in config:
print(f'saving to cache ({name}, {config["cache_id"]})')
cache[(name, config['cache_id'])] = item
**return item**
scripts/train/bc_train_loop.py
raw_dataset_train = load_item(cfg['train_dataset'], system_cfg['device'])
raw_dataset_eval = load_item(cfg['eval_dataset'], system_cfg['device'])
if isinstance(raw_dataset_train, Iterable_RL_Dataset):
dataset_train = GeneralIterDataset(raw_dataset_train, 'cpu')
else:
**dataset_train = GeneralDataset(raw_dataset_train, 'cpu')**
if isinstance(raw_dataset_eval, Iterable_RL_Dataset):
dataset_eval = GeneralIterDataset(raw_dataset_eval, 'cpu')
else:
**dataset_eval = GeneralDataset(raw_dataset_eval, 'cpu')**
src/data/torch_datasets.py
class GeneralDataset(Dataset):
def __init__(self, rl_dataset: List_RL_Dataset,
device: Union[torch.device, str]):
self.rl_dataset = rl_dataset
self.device = device
def __len__(self):
return self.rl_dataset.size()
def __getitem__(self, i):
return self.rl_dataset.get_item(i)
def collate(self, items):
return self.rl_dataset.collate(items, self.device)
def collate_simple(self, items):
return items
src/data/rl_data.py
class RL_Dataset(ABC):
def collate(self, items: List[DataPoint], device):
tokens, state_idxs, action_idxs, rewards, terminals, u_state_idxs, u_action_idxs, u_rewards, u_terminals = zip(*map(lambda x: x.to_tensors(device, self.max_len), items))
tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=self.tokenizer.pad_token_id)
attn_mask = (tokens != self.tokenizer.pad_token_id).float()
state_idxs = torch.nn.utils.rnn.pad_sequence(state_idxs, batch_first=True, padding_value=0)
action_idxs = torch.nn.utils.rnn.pad_sequence(action_idxs, batch_first=True, padding_value=0)
terminals = torch.nn.utils.rnn.pad_sequence(terminals, batch_first=True, padding_value=1)
rewards = torch.nn.utils.rnn.pad_sequence(rewards, batch_first=True, padding_value=0.0)
u_state_idxs = torch.nn.utils.rnn.pad_sequence(u_state_idxs, batch_first=True, padding_value=0)
u_action_idxs = torch.nn.utils.rnn.pad_sequence(u_action_idxs, batch_first=True, padding_value=0)
u_terminals = torch.nn.utils.rnn.pad_sequence(u_terminals, batch_first=True, padding_value=1)
u_rewards = torch.nn.utils.rnn.pad_sequence(u_rewards, batch_first=True, padding_value=0.0)
return {'tokens': tokens, 'attn_mask': attn_mask,
'state_idxs': state_idxs, 'action_idxs': action_idxs,
'rewards': rewards, 'terminals': terminals,
'u_state_idxs': u_state_idxs, 'u_action_idxs': u_action_idxs,
'u_rewards': u_rewards, 'u_terminals': u_terminals}
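So collate just right-pads every field to the longest item in the batch and derives attn_mask from the pad positions. A minimal sketch of that padding behaviour with torch alone (the token values and the pad id here are made up):

import torch

pad_token_id = 50263                                   # assumed id for <|pad|>
tokens = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=pad_token_id)
attn_mask = (tokens != pad_token_id).float()
print(tokens)     # [[    5,     6,     7,     8],
                  #  [    9,    10, 50263, 50263]]
print(attn_mask)  # [[1., 1., 1., 1.],
                  #  [1., 1., 0., 0.]]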
Model loading
scripts/train/bc_train_loop.py
model = load_item(cfg['model'], system_cfg['device'])
src/load_objects.py
@register('bc_lm')
def load_bc_lm(config, device, verbose=True):
**gpt2 = load_item(config['gpt2'], verbose=verbose)**
dataset = load_item(config['dataset'], device, verbose=verbose)
model = BC_LM(gpt2, dataset, device, config['transition_weight'])
return load_model(config['load'], model, device, verbose=verbose)
@register('gpt2')
def load_gpt2(config, verbose=True):
obj = GPT2LMHeadModel if config['lm_head'] else GPT2Model
if config['from_pretrained']:
return obj.from_pretrained(config['gpt2_type'])
config = GPT2Config.from_pretrained(config['gpt2_type'])
return obj(config)
src/models/bc_lm.py
class BC_LM(BaseTransformer):
def __init__(self,
model: PreTrainedModel,
dataset: RL_Dataset,
device: Union[torch.device, str] = "cuda",
transition_weight: float=0.0,
):
assert isinstance(model, GPT2LMHeadModel)
super().__init__(model, dataset, device)
self.h_dim = self.model.config.n_embd
self.transition_weight = transition_weight
This inherits from src/models/base.py:
class BaseModel(ABC, nn.Module):
def __init__(self,
dataset: RL_Dataset,
device: Union[torch.device, str]) -> None:
super().__init__()
self.dataset = dataset
self.device = device
self.max_len = self.dataset.max_len
def prepare_inputs(self, items: InputType):
if isinstance(items, dict):
return items
return to(self.dataset.collate(items, self.device), self.device)
@abstractmethod
def get_loss(self, items: InputType, **kwargs):
pass
class BaseTransformer(BaseModel):
def __init__(self,
pretrained_model: PreTrainedModel,
dataset: RL_Dataset,
device: Union[torch.device, str]) -> None:
super().__init__(dataset, device)
self.model = pretrained_model
self.model.resize_token_embeddings(self.dataset.tokenizer.num_tokens())
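resize_token_embeddings matters because VisDialTokenizer added special tokens on top of the base GPT-2 vocabulary, so the embedding matrix (and tied lm_head) has to grow to cover the new ids; this is also why the architecture printout later shows Embedding(50264, 768) instead of GPT-2's default 50257. A small sketch of the same idea with plain Hugging Face objects:

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer.add_special_tokens({'additional_special_tokens': ['</a>', '<a>', '<stop>', '</eod>'],
                              'bos_token': '<s>', 'sep_token': '</s>', 'pad_token': '<|pad|>'})
model.resize_token_embeddings(len(tokenizer))   # grows the wte / lm_head rows to cover the new token ids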
Model loading is pretty heavy going...
It turns out it is essentially the Hugging Face implementation copied over as-is.
Back to the training part.
scripts/train/bc_train_loop.py
for epoch in tqdm(range(train_cfg['epochs']), disable=not accelerator.is_local_main_process):
for items in tqdm(data_loader, disable=not accelerator.is_local_main_process):
items = to(items, system_cfg['device'])
**loss, logs, postproc_fs = accelerator.unwrap_model(model).get_loss(items, **train_cfg['loss'])**
src/models/bc_lm.py
def get_loss(self,
items: InputType):
**prepared_inputs = self.prepare_inputs(items)**
tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
src/models/base.py
def prepare_inputs(self, items: InputType):
**if isinstance(items, dict):
return items**
return to(self.dataset.collate(items, self.device), self.device)
def get_loss(self,
items: InputType):
prepared_inputs = self.prepare_inputs(items)
tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
a_idx = prepared_inputs['action_idxs']
model_outputs = self(tokens, attn_mask,
output_attentions=True)
logs = {}
transformer_logs = get_transformer_logs(model_outputs.attentions,
self.model,
attn_mask)
n = attn_mask.sum().item()
**weights = self.get_weights(tokens, a_idx)**
def get_weights(self,
tokens: torch.Tensor,
action_idxs: torch.Tensor):
weights = torch.full(tokens.shape, self.transition_weight).to(self.device)
if action_idxs.shape[1] == 0:
n = torch.zeros((tokens.shape[0],)).long().to(self.device)
else:
n = torch.argmax(action_idxs, dim=1)+1
for i in range(tokens.shape[0]):
weights[i] = torch.scatter(weights[i], dim=0, index=action_idxs[i, :n[i]], src=torch.full((n[i].item(),), 1.0).to(self.device))
return weights
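get_weights builds a per-token weight tensor that is transition_weight everywhere and 1.0 at the action positions, relying on action_idxs being zero-padded and increasing so that argmax finds the last real index. A toy numeric sketch:

import torch

transition_weight = 0.0
tokens = torch.zeros((1, 8), dtype=torch.long)   # only the shape matters here
action_idxs = torch.tensor([[3, 4, 5]])          # (zero-padded) action positions for one item

weights = torch.full(tokens.shape, transition_weight)
n = torch.argmax(action_idxs, dim=1) + 1         # number of real action positions
weights[0] = torch.scatter(weights[0], dim=0, index=action_idxs[0, :n[0]],
                           src=torch.full((n[0].item(),), 1.0))
print(weights)   # tensor([[0., 0., 0., 1., 1., 1., 0., 0.]])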
def get_loss(self,
items: InputType):
prepared_inputs = self.prepare_inputs(items)
tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
a_idx = prepared_inputs['action_idxs']
model_outputs = self(tokens, attn_mask,
output_attentions=True)
logs = {}
transformer_logs = get_transformer_logs(model_outputs.attentions,
self.model,
attn_mask)
n = attn_mask.sum().item()
weights = self.get_weights(tokens, a_idx)
**token_loss = self.awac_loss(tokens, attn_mask, model_outputs.logits, weights)**
def awac_loss(self, tokens, attn_mask, logits, w):
w = w.detach()
losses = F.cross_entropy(logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1), reduction='none')
losses = losses.reshape(tokens.shape[0], tokens.shape[1]-1)
return (losses * w[:, :-1] * attn_mask[:, 1:]).sum() / attn_mask[:, 1:].sum()
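So the BC objective is plain next-token cross-entropy, shifted by one position, masked with attn_mask, and scaled by the weights above (1.0 on action tokens, transition_weight elsewhere). A compact sketch of the shifting on toy tensors:

import torch
import torch.nn.functional as F

B, T, V = 1, 5, 11                               # toy batch / length / vocab sizes
logits = torch.randn(B, T, V)
tokens = torch.randint(0, V, (B, T))
attn_mask = torch.ones(B, T)
w = torch.tensor([[0., 0., 1., 1., 0.]])         # e.g. the output of get_weights

# position t predicts token t+1, then weight and mask before averaging
losses = F.cross_entropy(logits[:, :-1, :].reshape(-1, V), tokens[:, 1:].reshape(-1), reduction='none')
losses = losses.reshape(B, T - 1)
loss = (losses * w[:, :-1] * attn_mask[:, 1:]).sum() / attn_mask[:, 1:].sum()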
Now to review the evaluator part.
scripts/train/bc_train_loop.py
if cfg['evaluator'] is not None:
evaluator = load_item(cfg['evaluator'], system_cfg['device'])
src/load_objects.py
@register('bc_evaluator')
def load_bc_evaluator(config, device, verbose=True):
env = load_item(config['env'], device, verbose=verbose)
return BC_Evaluator(env, config['env'], config['kind'], **config['generation_kwargs'])
scripts/train/bc_train_loop.py
with torch.no_grad():
for i, eval_items in enumerate(eval_data_loader):
eval_items = to(eval_items, system_cfg['device'])
if i >= train_cfg['eval_batches']:
break
_, logs, postproc_fs = accelerator.unwrap_model(model).get_loss(eval_items, **train_cfg['loss'])
**if evaluator is not None:
evaluator_logs = evaluator.evaluate(accelerator.unwrap_model(model), eval_items)**
if evaluator_logs is not None:
logs['evaluation'] = evaluator_logs
eval_logs.accum_logs(logs)
eval_label = 'eval'
eval_total_logs = eval_logs.log(*postproc_fs,
partial(label_logs, label=eval_label),
iteration=step, epoch=epoch)
src/models/bc_lm.py
def evaluate(self, model: BC_LM, items: InputType) -> Optional[Dict[str, Any]]:
policy = BC_Policy(model, self.kind, **self.generation_kwargs)
tokens = model.prepare_inputs(items)['tokens']
n = tokens.shape[0]
total_token_reward = 0
total_env_reward = 0
for i in range(n):
**result, sequence = interact_environment(self.env, policy, None)**
src/data/language_environment.py
def interact_environment(env: Language_Environment, policy: Policy, obs: Optional[Language_Observation]=None):
obs_sequence = []
if obs is None:
obs = env.reset()
while not env.is_terminal():
**action = policy.act(obs)**
src/models/bc_lm.py
# from BC_Policy class
def act(self, obs: Language_Observation) -> str:
item = DataPoint.from_obs(obs, self.bc_lm.dataset.tokenizer, self.bc_lm.dataset.token_reward)
**generations, probs = self.generate([item], always_terminate, **self.generation_kwargs)**
sorted_outputs = list(zip(*sorted(zip(generations[0][1], probs[0]), key=lambda x: -x[1])))[0]
return sorted_outputs[0]
def generate(self, items: InputType,
termination_condition: Callable[[np.ndarray], bool], **kwargs):
prepared_inputs = self.bc_lm.prepare_inputs(items)
tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
if self.kind == 'beam':
method = self.beam_raw
**elif self.kind == 'sample':
method = self.sample_raw**
else:
raise NotImplementedError
generations, probs = method(tokens, attn_mask,
termination_condition,
**kwargs)
return generations, probs
def sample_raw(self,
tokens: torch.Tensor, attn_mask: torch.Tensor,
termination_condition: Callable[[np.ndarray], bool],
num_generations=1, max_generation_len=None,
temp=1.0, top_k=None, top_p=None,
prefix_embs: Optional[torch.Tensor]=None,
prefix_attn_mask: Optional[torch.Tensor]=None,
remove_prefix_position_embs: bool=False):
tokenizer = self.bc_lm.dataset.tokenizer
max_length = self.bc_lm.dataset.max_len
if max_length is None:
max_length = self.bc_lm.model.config.n_positions
max_length = min(max_length, self.bc_lm.model.config.n_positions)
device = self.bc_lm.device
bsize = tokens.shape[0]
n = bsize * num_generations
if max_generation_len is None:
max_generation_len = max_length+1
input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
model_outputs = self.bc_lm(tokens, attn_mask, prefix_embs=prefix_embs,
prefix_attn_mask=prefix_attn_mask,
remove_prefix_position_embs=remove_prefix_position_embs,
use_cache=True)
dialogue_kvs = model_outputs.past_key_values
dialogue_lens = attn_mask.sum(dim=1)
tokens = pad_sequence(torch.repeat_interleave(tokens, num_generations, dim=0), max_length, tokenizer.pad_token_id, device, 1)
dialogue_lens = torch.repeat_interleave(dialogue_lens, num_generations, dim=0)
**dialogue_kvs = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), dialogue_kvs)**
log_probs = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
termination_mask = torch.full((dialogue_lens.shape[0],), 1).to(device)
t = torch.min(dialogue_lens).int()
src/utils/sampling_utils.py
map_all_kvs = lambda f, kvs: tuple([tuple(map(f, items)) for items in kvs])
dialogue_kvs = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), dialogue_kvs)
src/models/bc_lm.py, continuing the sample_raw function
while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
curr_token = tokens[:, t-1].unsqueeze(1)
curr_dialogue_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], dialogue_kvs)
transformer_outputs = self.bc_lm(curr_token, None, past_key_values=curr_dialogue_kvs, use_cache=True)
logits = transformer_outputs.logits
logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
logits = process_logits(transformer_outputs.logits, temp=temp, top_k=top_k, top_p=top_p)
cat_dist = torch.distributions.categorical.Categorical(logits=logits[:, 0])
new_tokens = cat_dist.sample()
log_probs += cat_dist.log_prob(new_tokens)
tokens[:, t] = new_tokens
dialogue_kvs = update_kvs(dialogue_kvs, transformer_outputs.past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
for idx in range(n):
if tokens[idx, t] == tokenizer.eoa_token_id and t >= dialogue_lens[idx]:
termination_mask[idx] *= (1 - int(termination_condition(tokenizer.decode(tokens[idx, :].tolist(),
clean_up_tokenization_spaces=False))))
t += 1
termination_mask *= ((t-dialogue_lens) < max_generation_len).int()
output_strs = [tokenizer.decode(tokens[i, :].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
processed_outputs = []
for i in range(len(input_strs)):
temp_outputs = []
for x in range(num_generations):
processed_str = output_strs[i*num_generations+x][len(input_strs[i]):].strip()
if tokenizer.id_to_token(tokenizer.pad_token_id) in processed_str:
processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.pad_token_id))].strip()
if tokenizer.id_to_token(tokenizer.eoa_token_id) in processed_str:
processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.eoa_token_id))].strip()
temp_outputs.append(processed_str)
processed_outputs.append(temp_outputs)
return list(zip(input_strs, processed_outputs)), log_probs.reshape(-1, num_generations)
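process_logits itself isn't quoted here, but from its arguments it presumably applies temperature scaling plus top-k / top-p (nucleus) filtering before the categorical sample. A hedged sketch of that kind of filter (not the repo's exact implementation):

import torch

def process_logits_sketch(logits, temp=1.0, top_k=None, top_p=None):
    # temperature scaling
    logits = logits / temp
    if top_k is not None:
        # keep only the top_k highest logits per position
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    if top_p is not None:
        # nucleus filtering: keep the smallest set of tokens whose cumulative prob exceeds top_p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()   # always keep the single highest-prob token
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float('-inf'))
    return logits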
src/models/bc_lm.py
# from BC_Policy class
def act(self, obs: Language_Observation) -> str:
item = DataPoint.from_obs(obs, self.bc_lm.dataset.tokenizer, self.bc_lm.dataset.token_reward)
generations, probs = self.generate([item], always_terminate, **self.generation_kwargs)
sorted_outputs = list(zip(*sorted(zip(generations[0][1], probs[0]), key=lambda x: -x[1])))[0]
return sorted_outputs[0]
checkpoint save directory
/home/doolee13/Implicit-Language-Q-Learning/src/utils/../../outputs/visual_dialogue/visdial_bc_test1/
Model conversion command
python convert_bc.py --load ../../outputs/visual_dialogue/visdial_bc_test1/model.pkl --save ../../outputs/visual_dialogue/visdial_bc_test1/model_converted.pkl
scripts/data/convert_bc.py
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--load', type=str)
parser.add_argument('--save', type=str)
args = parser.parse_args()
state_dict = torch.load(args.load, map_location=torch.device('cpu'))
for k in list(state_dict.keys()):
if 'model.' in k:
if k.startswith('model'):
state_dict['lm_policy'+k[len('model'):]] = state_dict.pop(k)
torch.save(state_dict, args.save)
odict_keys(['lm_policy.transformer.wte.weight', 'lm_policy.transformer.wpe.weight', 'lm_policy.transformer.h.0.ln_1.weight', 'lm_policy.transformer.h.0.ln_1.bias', 'lm_policy.transformer.h.0.attn.bias', 'lm_policy.transformer.h.0.attn.masked_bias', 'lm_policy.transformer.h.0.attn.c_attn.weight', 'lm_policy.transformer.h.0.attn.c_attn.bias', 'lm_policy.transformer.h.0.attn.c_proj.weight', 'lm_policy.transformer.h.0.attn.c_proj.bias', 'lm_policy.transformer.h.0.ln_2.weight', 'lm_policy.transformer.h.0.ln_2.bias', 'lm_policy.transformer.h.0.mlp.c_fc.weight', 'lm_policy.transformer.h.0.mlp.c_fc.bias', 'lm_policy.transformer.h.0.mlp.c_proj.weight', 'lm_policy.transformer.h.0.mlp.c_proj.bias', ..., 'lm_policy.transformer.ln_f.weight', 'lm_policy.transformer.ln_f.bias', 'lm_policy.lm_head.weight'])
(the same block of per-layer keys repeats for lm_policy.transformer.h.1 through lm_policy.transformer.h.11; elided here)
The keys end up renamed like this; originally they were all under the model.* prefix.
Starting IQL
Before diving in:
1) python train_iql.py model.load.checkpoint_path=outputs/visdial/model_converted.pkl model.load.strict_load=false train.loss.awac_weight=0.0
trains on top of an already-trained BC checkpoint, and
2) plain python train_iql.py (awac_weight=1.0) trains BC and offline RL jointly.
Command actually used:
python train_iql.py model.load.checkpoint_path=outputs/visual_dialogue/visdial_bc_test1/model_converted.pkl model.load.strict_load=false train.loss.awac_weight=0.0
So from here on, assume case 1: BC training has finished and its converted checkpoint is loaded.
The model loading part is where things differ.
src/load_objects.py
@register('per_token_iql')
def load_per_token_iql(config, device, verbose=True):
gpt2 = load_item(config['gpt2'], verbose=verbose)
dataset = load_item(config['dataset'], device, verbose=verbose)
model = PerTokenIQL(gpt2, dataset, device, config['alpha'], config['gamma'],
config['beta'], config['transition_weight'], config['clip_weight'],
config['value_max'], config['value_min'], config['detach_v'],
config['detach_pi'], config['detach_q'], config['double_q'],
config['tau'], config['seperate_policy'], config['seperate_target'],
config['exp_weights'], config['dm_margin'], config['advanced_mlp'],
config['cql_temp'])
return load_model(config['load'], model, device, verbose=verbose)
This builds the PerTokenIQL model, and then
src/load_objects.py
def load_model(config, model, device, verbose=True):
model = model.to(device)
if config['checkpoint_path'] is not None:
if verbose:
print('loading %s state dict from: %s' % (config['name'], convert_path(config["checkpoint_path"])))
model.load_state_dict(torch.load(convert_path(config['checkpoint_path']), map_location='cpu'), strict=config['strict_load'])
if verbose:
print('loaded.')
return model
@register('per_token_iql')
def load_per_token_iql(config, device, verbose=True):
gpt2 = load_item(config['gpt2'], verbose=verbose)
dataset = load_item(config['dataset'], device, verbose=verbose)
model = PerTokenIQL(gpt2, dataset, device, config['alpha'], config['gamma'],
config['beta'], config['transition_weight'], config['clip_weight'],
config['value_max'], config['value_min'], config['detach_v'],
config['detach_pi'], config['detach_q'], config['double_q'],
config['tau'], config['seperate_policy'], config['seperate_target'],
config['exp_weights'], config['dm_margin'], config['advanced_mlp'],
config['cql_temp'])
**return load_model(config['load'], model, device, verbose=verbose)**
scripts/train/iql_train_loop.py
**loss, logs, postproc_fs = accelerator.unwrap_model(model).get_loss(items, **train_cfg['loss'])**
src/models/iql_model.py
def get_loss(self,
items: InputType,
awac_weight=0.0,
v_loss_weight=0.0, # 1.0
q_loss_weight=0.0, # 1.0
cql_loss_weight=0.0, # 1.0
dm_loss_weight=0.0,
mc_returns=False):
prepared_inputs = self.prepare_inputs(items)
a_idx = prepared_inputs['action_idxs']
**get_qvs_outputs = self.get_qvs(items,
qv_kwargs={'output_attentions': True},
policy_kwargs={'output_attentions': True},
target_kwargs={'output_attentions': True},
skip_policy_on_train=(awac_weight == 0.0),
)**
tokens, attn_mask, model_outputs = get_qvs_outputs['tokens'], get_qvs_outputs['attn_mask'], get_qvs_outputs['model_outputs']
vs, qs = get_qvs_outputs['vs'], get_qvs_outputs['qs']
vns, target_qs, rs = get_qvs_outputs['vns'], get_qvs_outputs['target_qs'], get_qvs_outputs['rs']
terminals, logits, weights = get_qvs_outputs['terminals'], get_qvs_outputs['logits'], get_qvs_outputs['weights']
def get_qvs(self, items: InputType,
prefix_embs: Optional[torch.Tensor]=None,
prefix_attn_mask: Optional[torch.Tensor]=None,
remove_prefix_position_embs: bool=False,
qv_kwargs=None, policy_kwargs=None, target_kwargs=None,
**kwargs):
prepared_inputs = self.prepare_inputs(items)
tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
s_idx, a_idx = prepared_inputs['state_idxs'], prepared_inputs['action_idxs']
rs, terminals = prepared_inputs['rewards'], prepared_inputs['terminals']
**self_outputs = self(tokens, attn_mask, s_idx, a_idx,
prefix_embs, prefix_attn_mask,
remove_prefix_position_embs,
qv_kwargs, policy_kwargs, target_kwargs,
**kwargs)**
def forward(self,
tokens: torch.Tensor,
attn_mask: Optional[torch.Tensor],
state_idxs: torch.Tensor,
action_idxs: torch.Tensor,
prefix_embs: Optional[torch.Tensor]=None,
prefix_attn_mask: Optional[torch.Tensor]=None,
remove_prefix_position_embs: bool=False,
qv_kwargs=None, policy_kwargs=None, target_kwargs=None,
skip_policy_on_train=False,
detach_full_policy=False):
if qv_kwargs is None:
qv_kwargs = {}
if target_kwargs is None:
target_kwargs = {}
if policy_kwargs is None:
policy_kwargs = {}
if self.lm_target is None: # not the case in this run
qv_kwargs.update(target_kwargs)
if self.lm_policy is None: # not the case in this run
qv_kwargs.update(policy_kwargs)
if attn_mask is None: # not the case in this run
attn_mask = torch.ones(tokens.shape, dtype=torch.long).to(self.device)
**if prefix_embs is None:
prefix_embs = torch.empty((tokens.shape[0], 0, self.h_dim)).to(self.device)**
prefix_t = prefix_embs.shape[1]
set_pos_ids = prefix_attn_mask is not None
if prefix_attn_mask is None:
prefix_attn_mask = torch.ones(prefix_embs.shape[:2]).to(self.device)
input_attn_mask = torch.cat((prefix_attn_mask, attn_mask), dim=1)
position_ids = torch.cumsum(input_attn_mask, dim=1)-1 if set_pos_ids else None
# def forward
if isinstance(self.model, GPT2Model):
transformer = self.model
if self.lm_target is not None:
target_transformer = self.lm_target
if self.lm_policy is not None:
policy_transformer = self.lm_policy
**elif isinstance(self.model, GPT2LMHeadModel):
transformer = self.model.transformer
if self.lm_target is not None:
target_transformer = self.lm_target.transformer
if self.lm_policy is not None:
policy_transformer = self.lm_policy.transformer**
else:
raise NotImplementedError
**if self.lm_target is not None:
target_prefix_embs = prefix_embs.clone() # [1, 0, 768]
if self.lm_policy is not None:
policy_prefix_embs = prefix_embs.clone() # [1, 0, 768]**
if remove_prefix_position_embs:
prefix_embs -= transformer.wpe(position_ids[:, :prefix_embs.shape[1]])
input_embeddings = torch.cat((prefix_embs, transformer.wte(tokens)), dim=1)
model_outputs = self.model(inputs_embeds=input_embeddings,
attention_mask=input_attn_mask,
position_ids=position_ids, #None
output_hidden_states=True,
**qv_kwargs)
all_model_outputs = {
'qv_model_outputs': model_outputs,
'policy_model_outputs': model_outputs,
'target_model_outputs': model_outputs
}
# def forward
if self.advanced_mlp:
hidden_states = model_outputs.hidden_states[-2][:, prefix_t:, :]
else:
**hidden_states = model_outputs.hidden_states[-1][:, prefix_t:, :]**
if self.lm_target is None:
target_hidden_states = hidden_states
else:
if remove_prefix_position_embs:
target_prefix_embs -= target_transformer.wpe(position_ids[:, :prefix_embs.shape[1]])
**target_input_embeddings = torch.cat((target_prefix_embs, target_transformer.wte(tokens)), dim=1)**
**with torch.no_grad():
target_outputs = self.lm_target(inputs_embeds=target_input_embeddings,
attention_mask=input_attn_mask,
position_ids=position_ids,
output_hidden_states=True,
**target_kwargs)**
all_model_outputs['target_model_outputs'] = target_outputs
if self.advanced_mlp:
target_hidden_states = target_outputs.hidden_states[-2][:, prefix_t:, :]
else:
**target_hidden_states = target_outputs.hidden_states[-1][:, prefix_t:, :]**
if self.lm_policy is None:
if isinstance(self.model, GPT2Model):
policy_hidden_states = hidden_states
else:
policy_hidden_states = model_outputs.hidden_states[-1][:, prefix_t:, :]
**else:**
if skip_policy_on_train and self.training:
**policy_hidden_states = hidden_states**
else:
if remove_prefix_position_embs:
policy_prefix_embs -= policy_transformer.wpe(position_ids[:, :prefix_embs.shape[1]])
policy_input_embeddings = torch.cat((policy_prefix_embs, policy_transformer.wte(tokens)), dim=1)
if detach_full_policy:
with torch.no_grad():
policy_outputs = self.lm_policy(inputs_embeds=policy_input_embeddings,
attention_mask=input_attn_mask,
position_ids=position_ids,
output_hidden_states=True,
**policy_kwargs)
else:
policy_outputs = self.lm_policy(inputs_embeds=policy_input_embeddings,
attention_mask=input_attn_mask,
position_ids=position_ids,
output_hidden_states=True,
**policy_kwargs)
all_model_outputs['policy_model_outputs'] = policy_outputs
if isinstance(self.model, GPT2Model):
if self.advanced_mlp:
policy_hidden_states = policy_outputs.hidden_states[-2][:, prefix_t:, :]
else:
policy_hidden_states = policy_outputs.hidden_states[-1][:, prefix_t:, :]
else:
policy_hidden_states = policy_outputs.hidden_states[-1][:, prefix_t:, :]
state_hidden_states = torch.gather(input=hidden_states, dim=1, index=state_idxs.unsqueeze(2).repeat(1, 1, self.h_dim))
action_hidden_states = torch.gather(input=hidden_states, dim=1, index=action_idxs.unsqueeze(2).repeat(1, 1, self.h_dim))
action_target_hidden_states = torch.gather(input=target_hidden_states, dim=1, index=action_idxs.unsqueeze(2).repeat(1, 1, self.h_dim))
vs = self.v(state_hidden_states.detach() if self.detach_v else state_hidden_states).squeeze(2)
qs = self.q(action_hidden_states.detach() if self.detach_q else action_hidden_states)
ipdb> print(state_idxs)
tensor([[ 11, 12, 13, 14, 15, 22, 23, 24, 25, 26, 27, 30, 31, 32,
33, 34, 35, 38, 39, 40, 41, 42, 43, 44, 52, 53, 54, 55,
56, 57, 58, 59, 60, 63, 64, 65, 66, 67, 68, 71, 72, 73,
74, 75, 76, 79, 80, 81, 82, 83, 84, 85, 86, 87, 90, 91,
92, 93, 94, 95, 98, 99, 100, 101, 102, 103, 107]],
device='cuda:0')
ipdb> print(action_idxs)
tensor([[ 11, 12, 13, 14, 15, 22, 23, 24, 25, 26, 27, 30, 31, 32,
33, 34, 35, 38, 39, 40, 41, 42, 43, 44, 52, 53, 54, 55,
56, 57, 58, 59, 60, 63, 64, 65, 66, 67, 68, 71, 72, 73,
74, 75, 76, 79, 80, 81, 82, 83, 84, 85, 86, 87, 90, 91,
92, 93, 94, 95, 98, 99, 100, 101, 102, 103]], device='cuda:0')
if self.double_q:
qs2 = self.q2(action_hidden_states.detach() if self.detach_q else action_hidden_states)
with torch.no_grad():
target_qs = self.target_q(action_target_hidden_states)
if self.double_q:
target_qs2 = self.target_q2(action_target_hidden_states)
if skip_policy_on_train and self.training and self.lm_policy is not None:
**logits = torch.zeros((policy_hidden_states.shape[0],policy_hidden_states.shape[1],self.dataset.tokenizer.num_tokens(),)).to(self.device)**
else:
if detach_full_policy:
with torch.no_grad():
logits = self.pi(policy_hidden_states.detach() if self.detach_pi else policy_hidden_states)
else:
logits = self.pi(policy_hidden_states.detach() if self.detach_pi else policy_hidden_states)
return {
'model_outputs': all_model_outputs,
'vs': vs,
'target_vs': vs,
'qs': (qs, qs2,) if self.double_q else qs,
'target_qs': self.clip_values(torch.minimum(target_qs, target_qs2) if self.double_q else target_qs),
'logits': logits,
}
Back to get_qvs
# def get_qvs
self_outputs = self(tokens, attn_mask, s_idx, a_idx,
prefix_embs, prefix_attn_mask,
remove_prefix_position_embs,
qv_kwargs, policy_kwargs, target_kwargs,
**kwargs)
model_outputs, vs, qs = self_outputs['model_outputs'], self_outputs['vs'], self_outputs['qs']
target_qs, logits = self_outputs['target_qs'], self_outputs['logits']
vt = vs[:, :-1] # [1, 66]
vtp1 = vs[:, 1:]
select_tokens = torch.gather(tokens[:, 1:], dim=1, index=a_idx)
cql_term = self.get_cql_loss(qs, select_tokens, terminals)
def get_cql_loss(self, qs, action_tokens, terminals):
n = (1 - terminals[:, :-1]).sum()
if self.double_q:
q1, q2 = qs
b, t, d = q1.shape
return ((F.cross_entropy(q1.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction='none').reshape(b, t) * (1 - terminals[:, :-1])) + (F.cross_entropy(q2.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction='none').reshape(b, t) * (1 - terminals[:, :-1]))).sum() / max(n.item(), 1.0)
b, t, d = qs.shape
return (F.cross_entropy(qs.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction='none').reshape(b, t) * (1 - terminals[:, :-1])).sum() / max(n.item(), 1.0)
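Since each Q head outputs one value per vocabulary token, the CQL term simply treats Q(s_t, ·)/cql_temp as logits and applies a cross-entropy toward the token that was actually taken, masked to non-terminal steps (and summed over both heads when double_q is on). Roughly:

\mathcal{L}_{\mathrm{CQL}} = \frac{1}{\max\!\big(\sum_t (1-d_t),\,1\big)} \sum_t (1-d_t)\,\Big[-\log \mathrm{softmax}\!\big(Q(s_t,\cdot)/\tau_{\mathrm{cql}}\big)_{a_t}\Big]

This pushes up the Q-value of the dataset action relative to every other token, i.e. the conservative (CQL-style) regularizer. Continuing in get_qvs: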
full_qs = qs
if self.double_q:
q1, q2 = qs
q1 = torch.gather(q1, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
q2 = torch.gather(q2, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
# tok_seq = [self.dataset.tokenizer.id_to_token(token) for token in select_tokens[0].detach().cpu().tolist()][:(1-terminals[0, :-1]).sum()]
# max_q_seq = torch.max(q1, q2)[0, :(1-terminals[0, :-1]).sum()].detach().cpu().tolist()
# print(self.dataset.tokenizer.decode(tokens[0, :][:attn_mask[0, :].sum().long()].tolist(), clean_up_tokenization_spaces=False))
# print(list(zip(tok_seq, max_q_seq)))
# print(rs)
qs = (q1, q2,)
else:
qs = torch.gather(qs, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
dm_term = self.get_dm_loss(full_qs, qs, terminals, self.dm_margin)
target_qs = torch.gather(target_qs, dim=2, index=select_tokens.unsqueeze(2)).squeeze(2)
with torch.no_grad():
weights = self.get_weights(tokens, vt, target_qs, s_idx, a_idx, terminals)
return {
'tokens': tokens,
'attn_mask': attn_mask,
'model_outputs': model_outputs,
'vs': vt,
'qs': qs,
'vns': vtp1,
'target_vs': vt,
'target_qs': target_qs,
'target_vns': vtp1,
'rs': rs,
'terminals': terminals,
'logits': logits,
'weights': weights,
'cql_term': cql_term,
'dm_term': dm_term,
}
Back to get_loss():
def get_loss(self,
items: InputType,
awac_weight=0.0,
v_loss_weight=0.0,
q_loss_weight=0.0,
cql_loss_weight=0.0,
dm_loss_weight=0.0,
mc_returns=False):
prepared_inputs = self.prepare_inputs(items)
a_idx = prepared_inputs['action_idxs']
get_qvs_outputs = self.get_qvs(items,
qv_kwargs={'output_attentions': True},
policy_kwargs={'output_attentions': True},
target_kwargs={'output_attentions': True},
skip_policy_on_train=(awac_weight == 0.0),
)
tokens, attn_mask, model_outputs = get_qvs_outputs['tokens'], get_qvs_outputs['attn_mask'], get_qvs_outputs['model_outputs']
vs, qs = get_qvs_outputs['vs'], get_qvs_outputs['qs']
vns, target_qs, rs = get_qvs_outputs['vns'], get_qvs_outputs['target_qs'], get_qvs_outputs['rs']
terminals, logits, weights = get_qvs_outputs['terminals'], get_qvs_outputs['logits'], get_qvs_outputs['weights']
n = (1 - terminals[:, :-1]).sum().item()
rs_downstream = self.get_downstream_rs(rs, self.gamma)
if mc_returns:
v_loss = self.get_v_loss(vs, rs_downstream, terminals)
**else:
v_loss = self.get_v_loss(vs, target_qs, terminals)**
q_loss = self.get_q_loss(vns, qs, rs, self.gamma, terminals)
cql_loss = get_qvs_outputs['cql_term']
dm_loss = get_qvs_outputs['dm_term']
token_loss = self.awac_loss(tokens, attn_mask, logits, weights)
loss = awac_weight * token_loss + v_loss_weight * v_loss + q_loss_weight * q_loss + cql_loss_weight * cql_loss + dm_loss_weight * dm_loss
return loss, logs, [postproc_f, hist_f]
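get_v_loss and get_q_loss aren't quoted here, but following the ILQL setup they should be the usual implicit Q-learning objectives: expectile regression of V(s) toward the (double-Q-min, clipped) target Q(s, a), and a squared TD error pulling Q(s, a) toward r + γ(1 − d)·V(s′); the total loss then mixes everything with the config weights. As a hedged summary:

\mathcal{L}_V = \mathbb{E}\big[\,\big|\tau - \mathbb{1}\{\bar{Q}(s,a) < V(s)\}\big|\;\big(\bar{Q}(s,a) - V(s)\big)^2\,\big]
\mathcal{L}_Q = \mathbb{E}\big[\big(r + \gamma\,(1-d)\,V(s') - Q(s,a)\big)^2\big]
\mathcal{L} = \lambda_{\mathrm{awac}}\,\mathcal{L}_\pi + \lambda_V\,\mathcal{L}_V + \lambda_Q\,\mathcal{L}_Q + \lambda_{\mathrm{cql}}\,\mathcal{L}_{\mathrm{CQL}} + \lambda_{\mathrm{dm}}\,\mathcal{L}_{\mathrm{DM}}

where τ is the expectile (the config's tau), the λ's are awac_weight, v_loss_weight, q_loss_weight, cql_loss_weight and dm_loss_weight, and L_π is the advantage-weighted BC (awac) loss computed from logits and weights.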
Evaluation part
iql_train_loop.py
evaluator_logs = evaluator.evaluate(accelerator.unwrap_model(model), eval_items)
src/models/iql_model.py
def evaluate(self, model: PerTokenIQL, items: InputType) -> Optional[Dict[str, Any]]:
policy = IQL_Policy(model, self.kind, **self.generation_kwargs)
tokens = model.prepare_inputs(items)['tokens']
total_token_reward = 0
total_env_reward = 0
for i in range(tokens.shape[0]):
**result, sequence = interact_environment(self.env, policy, None)**
src/data/language_environment.py
def interact_environment(env: Language_Environment, policy: Policy, obs: Optional[Language_Observation]=None):
obs_sequence = []
if obs is None:
obs = env.reset()
while not env.is_terminal():
**action = policy.act(obs)**
src/models/iql_model.py
def act(self, obs: Language_Observation) -> str:
item = DataPoint.from_obs(obs, self.iql_model.dataset.tokenizer, self.iql_model.dataset.token_reward)
**generations, logprobs, kls = self.generate([item], always_terminate, **self.generation_kwargs)**
self.kls_all.append(kls[0, 0].item())
self.logprobs_all.append(logprobs[0, 0].item())
return generations[0][1][0]
def generate(self, items: InputType,
termination_condition: Callable[[np.ndarray], bool], **kwargs):
prepared_inputs = self.iql_model.prepare_inputs(items)
tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
state_idxs, action_idxs = prepared_inputs['state_idxs'], prepared_inputs['action_idxs']
if self.kind == 'beam':
method = self.beam_raw
**elif self.kind == 'sample':
method = self.sample_raw**
else:
raise NotImplementedError
generations, info, kls = method(tokens, attn_mask,
state_idxs, action_idxs,
termination_condition,
**kwargs)
return generations, info, kls
def sample_raw(self,
tokens: torch.Tensor, attn_mask: torch.Tensor,
state_idxs: torch.Tensor, action_idxs: torch.Tensor,
termination_condition: Callable[[np.ndarray], bool],
num_generations=1, max_generation_len=None,
temp=1.0, top_k=None, top_p=None,
exp_adv=False, adv_weight=0.0, adv_clip=None,
include_logits=True, include_adv=True,
rerank_log_prob_weight: float=0.0,
rerank_advantage_weight: float=0.0,
prefix_embs: Optional[torch.Tensor]=None,
prefix_attn_mask: Optional[torch.Tensor]=None,
remove_prefix_position_embs: bool=False):
assert include_logits or include_adv
tokenizer = self.iql_model.dataset.tokenizer
max_length = self.iql_model.dataset.max_len
if max_length is None:
max_length = self.iql_model.model.config.n_positions
max_length = min(max_length, self.iql_model.model.config.n_positions)
device = self.iql_model.device
bsize = tokens.shape[0]
n = bsize * num_generations
if max_generation_len is None:
max_generation_len = max_length+1
input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
model_outputs = self.iql_model(tokens, attn_mask,
state_idxs, action_idxs,
prefix_embs=prefix_embs,
prefix_attn_mask=prefix_attn_mask,
remove_prefix_position_embs=remove_prefix_position_embs,
qv_kwargs={'use_cache': True},
policy_kwargs={'use_cache': True},
target_kwargs={'use_cache': True})['model_outputs']
kvs = {'qv': model_outputs['qv_model_outputs'].past_key_values}
**if self.iql_model.lm_target is not None:
kvs['target'] = model_outputs['target_model_outputs'].past_key_values
if self.iql_model.lm_policy is not None:
kvs['policy'] = model_outputs['policy_model_outputs'].past_key_values**
dialogue_lens = attn_mask.sum(dim=1)
tokens = pad_sequence(torch.repeat_interleave(tokens, num_generations, dim=0), max_length, tokenizer.pad_token_id, device, 1)
dialogue_lens = torch.repeat_interleave(dialogue_lens, num_generations, dim=0)
kvs['qv'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), kvs['qv'])
if 'target' in kvs:
kvs['target'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), kvs['target'])
if 'policy' in kvs:
kvs['policy'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), kvs['policy'])
log_probs = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
kls = torch.full((dialogue_lens.shape[0],), math.log(num_generations)-((num_generations-1)/num_generations)).to(device)
advantages = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
termination_mask = torch.full((dialogue_lens.shape[0],), 1).to(device)
state_idxs_temp, action_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device), torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device)
t = torch.min(dialogue_lens).int()
base_logits = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
curr_token = tokens[:, t-1].unsqueeze(1)
curr_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['qv'])
curr_target_kvs, curr_policy_kvs = curr_kvs, curr_kvs
if 'target' in kvs:
curr_target_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['target'])
if 'policy' in kvs:
curr_policy_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['policy'])
iql_outputs = self.iql_model(curr_token, None, state_idxs_temp, action_idxs_temp,
qv_kwargs={'use_cache': True, 'past_key_values': curr_kvs},
policy_kwargs={'use_cache': True, 'past_key_values': curr_policy_kvs},
target_kwargs={'use_cache': True, 'past_key_values': curr_target_kvs})
model_outputs, logits = iql_outputs['model_outputs'], iql_outputs['logits']
logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
edited_logits = process_logits(logits.clone(), temp=temp, top_k=top_k, top_p=top_p)
vs, qs = iql_outputs['target_vs'], iql_outputs['target_qs']
**if exp_adv:
adv_logits = adv_weight * (qs - vs.unsqueeze(2))**
else:
adv_sign = ((qs - vs.unsqueeze(2)) > 0.0).float()
adv_logits = adv_weight * adv_sign + (1 - adv_weight) * (1 - adv_sign)
adv_logits = torch.log(adv_logits)
if adv_clip is not None: # not the case in this run
adv_logits = torch.clip(adv_logits, max=adv_clip)
adv_logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
full_logits = (edited_logits if include_logits else 0.0) + (adv_logits if include_adv else 0.0) + base_logits.unsqueeze(1).unsqueeze(2)
cat_dist = torch.distributions.categorical.Categorical(logits=full_logits[:, 0])
original_cat_dist = torch.distributions.categorical.Categorical(logits=logits[:, 0])
new_tokens = cat_dist.sample()
log_probs += cat_dist.log_prob(new_tokens)
kls += (cat_dist.log_prob(new_tokens) - original_cat_dist.log_prob(new_tokens))
qs_chosen = torch.gather(qs.squeeze(1), dim=1, index=new_tokens.unsqueeze(1)).squeeze(1)
advantages += (qs_chosen - vs.squeeze(1))
tokens[:, t] = new_tokens
kvs['qv'] = update_kvs(kvs['qv'], model_outputs['qv_model_outputs'].past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
if 'target' in kvs:
kvs['target'] = update_kvs(kvs['target'], model_outputs['target_model_outputs'].past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
if 'policy' in kvs:
kvs['policy'] = update_kvs(kvs['policy'], model_outputs['policy_model_outputs'].past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
for idx in range(n):
if tokens[idx, t] == tokenizer.eoa_token_id and t >= dialogue_lens[idx]:
termination_mask[idx] *= (1 - int(termination_condition(tokenizer.decode(tokens[idx, :].tolist(),
clean_up_tokenization_spaces=False))))
t += 1
termination_mask *= ((t-dialogue_lens) < max_generation_len).int()
scores = ((advantages * rerank_advantage_weight) + (log_probs * rerank_log_prob_weight)).reshape(-1, num_generations)
order = torch.argsort(-scores, dim=1)
output_strs = [tokenizer.decode(tokens[i, :].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
processed_outputs = []
for i in range(len(input_strs)):
temp_outputs = []
for x in range(num_generations):
processed_str = output_strs[i*num_generations+order[i, x]][len(input_strs[i]):].strip()
if tokenizer.id_to_token(tokenizer.pad_token_id) in processed_str:
processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.pad_token_id))].strip()
if tokenizer.id_to_token(tokenizer.eoa_token_id) in processed_str:
processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.eoa_token_id))].strip()
temp_outputs.append(processed_str)
processed_outputs.append(temp_outputs)
scores = torch.gather(scores, dim=1, index=order)
log_probs = torch.gather(log_probs.reshape(-1, num_generations), dim=1, index=order)
kls = torch.gather(kls.reshape(-1, num_generations), dim=1, index=order)
return list(zip(input_strs, processed_outputs)), log_probs.reshape(-1, num_generations), kls
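GPT2LMHeadModel backbone as printed by PyTorch (12 blocks, hidden size 768, vocab size 50264):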
GPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50264, 768)
(wpe): Embedding(1024, 768)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0-11): 12 x GPT2Block(
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Conv1D()
(c_proj): Conv1D()
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D()
(c_proj): Conv1D()
(act): NewGELUActivation()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=768, out_features=50264, bias=False)
)
Summary notes (in progress)
What forward does
Pass the given item's token embeddings and attention mask through self.model (GPT2LMHeadModel) to produce hidden states
If the policy is not trained separately, policy_hidden_states is identical to hidden_states
From those hidden_states, find the indices of the action and state tokens and use torch.gather to build state_hidden_states and action_hidden_states (sketched below)
action_target_hidden_states is built from target_hidden_states in the same way
state_hidden_states goes into the value network and action_hidden_states into the Q network, producing vs and qs
action_target_hidden_states goes into the target Q network, producing target_qs
logits is just a zero tensor of shape [batch, seq_len, vocab_size]
Output: vs, target_vs (identical to vs), qs, target_qs (clipped to the smaller of target_qs and target_qs2), logits (zero tensor)
<lm_policy does not really do anything here>
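A minimal sketch of the gather/head step described above (function and head names are illustrative, not the actual ILQL code):
import torch

def gather_qv_heads(hidden_states, target_hidden_states, state_idxs, action_idxs,
                    v_head, q_head, target_q_head):
    # hidden_states, target_hidden_states: [B, T, d] from the GPT-2 backbone and its target copy
    # state_idxs: [B, S], action_idxs: [B, A] positions of state / action tokens
    d = hidden_states.shape[-1]
    gather_idx = lambda h, idx: torch.gather(h, 1, idx.unsqueeze(-1).expand(-1, -1, d))
    state_hidden = gather_idx(hidden_states, state_idxs)                   # [B, S, d]
    action_hidden = gather_idx(hidden_states, action_idxs)                 # [B, A, d]
    action_target_hidden = gather_idx(target_hidden_states, action_idxs)   # [B, A, d]
    vs = v_head(state_hidden).squeeze(-1)             # [B, S]
    qs = q_head(action_hidden)                        # [B, A, vocab]
    target_qs = target_q_head(action_target_hidden)   # [B, A, vocab]
    return vs, qs, target_qs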
What get_qvs does
Pass the item through forward
Compute the CQL loss between the qs from forward and the token corresponding to the action at step t+1 (rough sketch below)
The Q network needs to be thought through again
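Roughly, the CQL term treats the per-step Q values over the vocabulary as logits for the action token that was actually taken; a sketch with assumed shapes and names, not the exact get_qvs code:
import torch
import torch.nn.functional as F

def cql_term(qs, action_tokens, mask):
    # qs: [B, A, vocab] Q-values at each action position
    # action_tokens: [B, A] token ids actually taken at the next step; mask: [B, A] valid positions
    logprobs = F.log_softmax(qs, dim=-1)
    chosen = torch.gather(logprobs, 2, action_tokens.unsqueeze(-1)).squeeze(-1)  # [B, A]
    return -(chosen * mask).sum() / mask.sum().clamp(min=1.0)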
The model does not fit on a single GPU, so trying DeepSpeed
accelerate configuration saved at /home/doolee13/.cache/huggingface/accelerate/default_config.yaml
Command:
accelerate launch train_iql_deepspeed.py
"fp16": {
"enabled": "true",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 15,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.0001
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 1.000000e+08,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1.000000e+08,
"contiguous_gradients": false
},
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false,
"gradient_accumulation_steps": 1,
"steps_per_print": inf,
"bf16": {
"enabled": false
}
}
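Inside train_iql_deepspeed.py, the usual accelerate pattern would look roughly like this (a sketch assuming model, optimizer, train_loader are already built and compute_loss is a placeholder; the actual script may differ):
from accelerate import Accelerator

accelerator = Accelerator()  # picks up the saved accelerate / DeepSpeed config under `accelerate launch`
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
for batch in train_loader:
    loss = compute_loss(model, batch)   # hypothetical loss function
    accelerator.backward(loss)          # use instead of loss.backward() so DeepSpeed/fp16 scaling applies
    optimizer.step()
    optimizer.zero_grad()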
ILQL beam_raw part
tokenizer = self.iql_model.dataset.tokenizer
max_length = self.iql_model.dataset.max_len
if max_length is None:
max_length = self.iql_model.model.config.n_positions
max_length = min(max_length, self.iql_model.model.config.n_positions)
device = self.iql_model.device
# (2, 50264)
bsize, vocab_size = tokens.shape[0], tokenizer.num_tokens()
# 20
n = bsize * beam_width
if max_generation_len is None:
max_generation_len = max_length+1
input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
model_outputs = self.iql_model(tokens, attn_mask,
state_idxs, action_idxs,
prefix_embs=prefix_embs,
prefix_attn_mask=prefix_attn_mask,
remove_prefix_position_embs=remove_prefix_position_embs,
qv_kwargs={'use_cache': True},
policy_kwargs={'use_cache': True},
target_kwargs={'use_cache': True})['model_outputs']
kvs = {'qv': model_outputs['qv_model_outputs'].past_key_values}
if self.iql_model.lm_target is not None:
kvs['target'] = model_outputs['target_model_outputs'].past_key_values
if self.iql_model.lm_policy is not None:
kvs['policy'] = model_outputs['policy_model_outputs'].past_key_values
# [bsize, ] : [158, 284]
original_dialogue_lens = attn_mask.sum(dim=1)
# [bsize, beam_width]
# [0,0,0,0 .. ,0]
# [1,1,1,1 .. ,1]
# [bsize-1, ..., bsize-1]
batch_indicator = torch.stack(beam_width*[torch.arange(0, bsize).to(device)], dim=1)
# [bsize, max len in batch] -> [bsize * beam_width, max_length]
# [2, 284] -> [20, 1024]
tokens = pad_sequence(torch.repeat_interleave(tokens, beam_width, dim=0), max_length, tokenizer.pad_token_id, device, 1)
# [bsize * beam_width]
# [158, 158, 158, ..., 158, 284, 284, ..., 284]
dialogue_lens = torch.repeat_interleave(original_dialogue_lens, beam_width, dim=0)
kvs['qv'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2), kvs['qv'])
if 'target' in kvs:
kvs['target'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2), kvs['target'])
if 'policy' in kvs:
kvs['policy'] = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2), kvs['policy'])
# [bsize, beam_width] = [2, 20]
curr_scores = torch.zeros(bsize, beam_width).to(device) # (batch, k)
# [bsize, beam_width] = [2, 20]
logit_scores = torch.zeros(bsize, beam_width).to(device) # (batch, k)
# [bsize*beam_width]
termination_mask = torch.full((n,), 1).to(device)
state_idxs_temp, action_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device), torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device)
t = torch.min(dialogue_lens).int()
# [bsize * beam_width]
base_logits = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
# [bsize*beam_width, 1] = [20, 1]
# looks like [a, a, a, ..., a, b, b, ..., b]
curr_token = tokens[:, t-1].unsqueeze(1)
curr_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['qv'])
curr_target_kvs, curr_policy_kvs = curr_kvs, curr_kvs
if 'target' in kvs:
curr_target_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['target'])
if 'policy' in kvs:
curr_policy_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], kvs['policy'])
iql_outputs = self.iql_model(curr_token, None, state_idxs_temp, action_idxs_temp,
qv_kwargs={'use_cache': True, 'past_key_values': curr_kvs},
policy_kwargs={'use_cache': True, 'past_key_values': curr_policy_kvs},
target_kwargs={'use_cache': True, 'past_key_values': curr_target_kvs})
model_outputs, logits = iql_outputs['model_outputs'], iql_outputs['logits']
# logits shape : [20, 1, 50264]
logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
edited_logits = process_logits(logits.clone(), temp=temp, top_k=top_k, top_p=top_p)
vs, qs = iql_outputs['target_vs'], iql_outputs['target_qs']
if exp_adv:
adv_logits = adv_weight * (qs - vs.unsqueeze(2))
else:
adv_sign = ((qs - vs.unsqueeze(2)) > 0.0).float()
adv_logits = adv_weight * adv_sign + (1 - adv_weight) * (1 - adv_sign)
adv_logits = torch.log(adv_logits)
if adv_clip is not None:
adv_logits = torch.clip(adv_logits, max=adv_clip)
adv_logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = adv_logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
full_logits = (edited_logits if include_logits else 0.0) + (adv_logits if include_adv else 0.0) + base_logits.unsqueeze(1).unsqueeze(2)
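# beam search step: add each candidate token's log-prob to the running beam scores, then keep the top beam_width out of beam_width*vocab candidates per batch element and reindex the beams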
scores = (torch.log(F.softmax(full_logits, dim=-1)).reshape(1, bsize, beam_width, -1).permute(3, 0, 1, 2) + curr_scores).permute(1, 2, 3, 0).reshape(1, bsize, -1) # (time, batch, k*vocab)
curr_scores, top_k_ = torch.topk(scores[0, :, :], k=beam_width, dim=1) # (batch, k), (batch, k)
tokens = tokens[(batch_indicator * beam_width + (top_k_ // vocab_size)).reshape(-1), :]
tokens[:, t] = top_k_.reshape(-1) % vocab_size