webshop/baseline_models/test.py
line 106 ~ line 113 :
if args.mem:
env.env.num_prev_obs = 1
env.env.num_prev_actions = 5
print('memory')
else:
env.env.num_prev_obs = 0
env.env.num_prev_actions = 0
print('no memory')
if args.bart:
bart_model = BartForConditionalGeneration.from_pretrained(args.bart_path)
print('bart model loaded', args.bart_path)
else:
bart_model = None
→ The paper reports that including this memory of previous observations/actions actually degrades performance.
bart_model is the model that generates a search query when the agent is on the search page; when it is None, the predict function simply picks the last action among the valid actions.
The args variable inside the main function is returned by the parse_args function above as follows:
Namespace(bart=True, bart_path='./ckpts/web_search/checkpoint-800', image=True, mem=0, model_path='./ckpts/web_click/epoch_9/model.pth', softmax=True)
The global args variable, in contrast, comes from the parser in train_rl.py. Because that parser returns a tuple split into a Namespace of the arguments it defines and the undefined leftovers (the behavior of argparse's parse_known_args), the global args in test.py ends up carrying the information below.
Namespace(ban_buy=0, bert_path='', bptt=8, ckpt_freq=10000, click_item_name=1, clip=10, debug=0, embedding_dim=128, eval_freq=500, exploration_method='softmax', extra_search_path='./data/goal_query_predict.json', f='1', gamma=0.9, get_image=1, go_to_item=0, go_to_search=0, grad_encoder=1, harsh_reward=0, hidden_dim=128, human_goals=1, learning_rate=1e-05, log_freq=100, max_steps=300000, network='bert', num=None, num_envs=4, num_prev_actions=0, num_prev_obs=0, output_dir='logs', score_handicap=0, seed=0, state_format='text_rich', step_limit=100, test_freq=5000, w_en=1, w_il=0, w_pg=1, w_td=1, wandb=1)
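As a side note, this Namespace/leftover split is just how argparse's parse_known_args behaves; a minimal standalone sketch (hypothetical arguments, not the project's parser):
import argparse

# parse_known_args() returns a (Namespace, leftover_args) tuple:
# recognized arguments land in the Namespace, everything else is passed through.
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--network", type=str, default="bert")

known, unknown = parser.parse_known_args(["--seed", "1", "--not_defined", "x"])
print(known)    # Namespace(seed=1, network='bert')
print(unknown)  # ['--not_defined', 'x']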
The BART model that generates the search query is then used as-is: the pretrained checkpoint is loaded without any further modification.
if args.bart:
bart_model = BartForConditionalGeneration.from_pretrained(args.bart_path)
print('bart model loaded', args.bart_path)
else:
bart_model = None
The BERT model is built by adding extra layers on top of a pre-trained model, and the output of these layers is also used as the RL state.
config = BertConfigForWebshop(image=args.image)
model = BertModelForWebshop(config)
model.cuda()
model.load_state_dict(torch.load(args.model_path), strict=False)
The model architecture used is shown below.
class BertModelForWebshop(PreTrainedModel):
config_class = BertConfigForWebshop
def __init__(self, config):
super().__init__(config)
bert_config = BertConfig.from_pretrained('bert-base-uncased')
if config.pretrained_bert:
self.bert = BertModel.from_pretrained('bert-base-uncased')
else:
self.bert = BertModel(config)
self.bert.resize_token_embeddings(30526)
self.attn = BiAttention(768, 0.0)
self.linear_1 = nn.Linear(768 * 4, 768)
self.relu = nn.ReLU()
self.linear_2 = nn.Linear(768, 1)
if config.image:
self.image_linear = nn.Linear(512, 768)
else:
self.image_linear = None
# for state value prediction, used in RL
self.linear_3 = nn.Sequential(
nn.Linear(768, 128),
nn.LeakyReLU(),
nn.Linear(128, 1),
)
In the main function, a total of 500 episodes are evaluated, and each episode runs for at most 100 timesteps (if the limit is exceeded, the reward is 0).
def episode(model, idx=None, verbose=False, softmax=False, rule=False, bart_model=None):
obs, info = env.reset(idx)
if verbose:
print(info['goal'])
for i in range(100):
action = predict(obs, info, model, softmax=softmax, rule=rule, bart_model=bart_model)
if verbose:
print(action)
obs, reward, done, info = env.step(action)
if done:
return reward
return 0
In the main function, the episode call without rule=True produces score_softmax, and the call with rule=True produces score_rule.
score_softmax, score_rule = episode(model, idx=i, softmax=args.softmax, bart_model=bart_model), episode(model, idx=i, rule=True)
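As a rough sketch (not the exact code in test.py), the evaluation loop can be thought of as accumulating both scores over the 500 test goals and averaging them:
# Hedged sketch of the evaluation loop; variable names are illustrative.
scores_softmax, scores_rule = [], []
for i in range(500):
    scores_softmax.append(episode(model, idx=i, softmax=args.softmax, bart_model=bart_model))
    scores_rule.append(episode(model, idx=i, rule=True))
print('avg reward (softmax policy):', sum(scores_softmax) / len(scores_softmax))
print('avg reward (rule baseline) :', sum(scores_rule) / len(scores_rule))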
The predict function of test.py:
def predict(obs, info, model, softmax=False, rule=False, bart_model=None):
# info : dictionary containing 'valid' : valid actions, 'goal' : the instruction to solve, 'score', 'estimate_score', 'prev_ob', 'image_feat'
valid_acts = info['valid']
if valid_acts[0].startswith('search['):
if bart_model is None:
return valid_acts[-1]
else:
goal = process_goal(obs)
query = bart_predict(goal, bart_model, num_return_sequences=5, num_beams=5)
# query = random.choice(query) # in the paper, we sample from the top-5 generated results.
query = query[0] #... but use the top-1 generated search will lead to better results than the paper results.
return f'search[{query}]'
if rule:
item_acts = [act for act in valid_acts if act.startswith('click[item - ')]
if item_acts:
return item_acts[0]
else:
assert 'click[buy now]' in valid_acts
return 'click[buy now]'
ipdb.set_trace()
# https://heekangpark.github.io/nlp/huggingface-bert
# 'input_ids' : id list of tokens , 'token_type_ids' : 0 if sentence A, 1 if sentence B , 'attention_mask' : 0 for padding tokens, 1 for tokens that should be attended to
state_encodings = tokenizer(process(obs), max_length=512, truncation=True, padding='max_length') # process : ' -> nothing and sep -> SEP
action_encodings = tokenizer(list(map(process, valid_acts)), max_length=512, truncation=True, padding='max_length') # valid_acts : list of available actions
batch = {
'state_input_ids': state_encodings['input_ids'],
'state_attention_mask': state_encodings['attention_mask'],
'action_input_ids': action_encodings['input_ids'],
'action_attention_mask': action_encodings['attention_mask'],
'sizes': len(valid_acts),
'images': info['image_feat'].tolist(),
'labels': 0
}
batch = data_collator([batch])
# make batch cuda
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
if softmax:
idx = torch.multinomial(F.softmax(outputs.logits[0], dim=0), 1)[0].item()
else:
idx = outputs.logits[0].argmax(0).item()
return valid_acts[idx]
Above: a partial example of obs.
The predict call that produces score_softmax receives the parameters
obs, info, model, softmax=True, rule=False, bart_model=bart_model
[in the code above, this takes the else branch after line 3]
[and then proceeds from the state/action encoding lines].
The predict call that produces score_rule receives
obs, info, model, softmax=False, rule=True, bart_model=None
[in the code above, this hits line 3]
[and then goes through the if rule branch].
Analysis of train_search_il.py
First, the dataset is prepared with get_dataset.
def get_dataset(name, flip=False, variant=None, size=None):
fname = name + "-flip" if flip else name
fpath = os.path.join(os.path.dirname(__file__), fname)
d = {}
splits = ["train", "validation", "test"]
if name == "web_search":
splits = ["train", "validation", "test", "all"]
for split in splits:
input, output = get_data(split) if name != "nl2bash" else get_data(
split, variant=variant)
l = len(input) if size is None else int(len(input) * size)
print("{} size: {}".format(split, l))
if flip:
input, output = output, input
input, output = input[:l], output[:l]
d[split] = process_dataset(input, output)
d = DatasetDict(d)
return d
def get_data(split):
data = json.load(open(PATH))
goals, searches = [], []
for goal, search_list in data.items():
goal = process_goal(goal)
for search in search_list:
search = process_str(search)
goals.append(goal)
searches.append(search)
n = len(goals)
human_goals = json.load(open(HUMAN_GOAL_PATH, 'r'))
goal_range = range(len(human_goals))
if split == 'train':
goal_range = range(500, len(human_goals))
elif split == 'validation':
goal_range = range(500, 1500)
elif split == 'test':
goal_range = range(0, 500)
goals_, searches_ = [], []
for goal, search in zip(goals, searches):
if goal in human_goals and human_goals.index(goal) in goal_range:
goals_.append(goal)
searches_.append(search)
return goals_, searches_
elif split == "all": # all human instructions, but without groundtruth search queries
all_data = json.load(open(GOAL_PATH))
all_goals = []
all_goals_processed = []
for ins_list in all_data.values():
for ins in ins_list:
ins = ins['instruction']
all_goals.append(ins)
all_goals_processed.append(process_str(ins))
return all_goals_processed, all_goals
[train, val, test] : return (goal, searches)
[all] : return (goal_processed, goal)
def process_dataset(input, output, max_len=256):
# https://huggingface.co/docs/transformers/v4.26.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.from_pretrained
# look at the tokenizer's __call__ function
# padding='max_length' : pad up to the length given by the max_length param // truncation=True : truncate tokens from the longest sequence in the pair // return_tensors='pt' : return torch.Tensor
# __call__ func returns BatchEncoding
input_encodings = tokenizer(input, padding='max_length',
max_length=max_len, truncation=True, return_tensors='pt')
output_encodings = tokenizer(
output, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt')
ipdb.set_trace()
labels = output_encodings['input_ids'] # 'input_ids' field means batch of encoded result size : [len of train_data, max_seq]
# https://huggingface.co/transformers/v4.11.3/_modules/transformers/models/bart/modeling_bart.html
# info for shift_tokens_right
# parameter is (input_ids, pad_token_id, decoder_start_token_id) : why give eos_token_id for decoder start?
#https://stackoverflow.com/questions/64904840/why-we-need-a-decoder-start-token-id-during-generation-in-huggingface-bart
# reason why they used EOS_token for decoder start !
# internally changes token_ids == -100 into pad_token_id (-100 indicates mask!, automatically ignored in pytorch loss funcs)
decoder_input_ids = shift_tokens_right(labels, PAD_TOKEN_ID, EOS_TOKEN_ID)
# so change back ?
labels[labels[:, :] == PAD_TOKEN_ID] = -100
# generate Dataset class from dictionary
dataset = Dataset.from_dict({
'input_ids': input_encodings['input_ids'],
'attention_mask': input_encodings['attention_mask'],
'decoder_input_ids': decoder_input_ids,
'labels': labels,
})
# select format for iter() output
dataset.set_format(type='torch', columns=[
'input_ids', 'labels', 'decoder_input_ids', 'attention_mask'])
return dataset
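To see concretely what shift_tokens_right and the -100 replacement do, here is a small toy example (it uses the plain facebook/bart-base tokenizer, whereas the repo uses its own extended tokenizer):
import torch
from transformers import BartTokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right

tok = BartTokenizer.from_pretrained('facebook/bart-base')
enc = tok(['red pillow covers'], padding='max_length', max_length=8,
          truncation=True, return_tensors='pt')

labels = enc['input_ids'].clone()                 # [<s> ... </s> <pad> <pad> ...]
decoder_input_ids = shift_tokens_right(
    labels, tok.pad_token_id, tok.eos_token_id)   # shifted right, starts with </s>
labels[labels == tok.pad_token_id] = -100         # pad positions are ignored by the loss

print(decoder_input_ids[0].tolist())
print(labels[0].tolist())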
if __name__ == "__main__":
ipdb.set_trace()
dataset = get_dataset("web_search", flip=False)
train_dataset = dataset['train']
print(train_dataset[0])
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
model.resize_token_embeddings(len(tokenizer))
# model = BartForConditionalGeneration.from_pretrained('./models/qdmr-high-level-base/checkpoint-10000')
training_args = TrainingArguments(
output_dir='./ckpts/web_search',
num_train_epochs=10,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
warmup_steps=50,
weight_decay=0.01,
evaluation_strategy="steps",
logging_dir='./logs',
logging_steps=50,
eval_steps=20,
save_steps=200
# eval_accumulation_steps=1
)
# data collator : define ways to prepare batch from list form
# set to default_data_collator in this setting
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=dataset["validation"],
compute_metrics=None,
)
trainer.train()
parser.add_argument("--bart_path", type=str, default='./ckpts/web_search/checkpoint-800', help="BART model path if using it")
TODO
Analysis of train_choice_il.py
config = BertConfigForWebshop(
image=args.image, pretrain_bert=args.pretrain)
model = BertModelForWebshop(config)
class BertConfigForWebshop(PretrainedConfig):
model_type = "bert"
def __init__(
self,
pretrained_bert=True,
image=False,
**kwargs
):
self.pretrained_bert = pretrained_bert
self.image = image
super().__init__(**kwargs)
train_dataset = get_dataset("train", mem=args.mem)
eval_dataset = get_dataset("eval", mem=args.mem)
def get_data(split, mem=False, filter_search=True):
# mem : default False in all settings
# MEM_PATH doesn't even exist
path = MEM_PATH if mem else PATH
print('Loading data from {}'.format(path))
with open(path, 'r') as json_file:
json_list = list(json_file)
human_goals = json.load(open(HUMAN_GOAL_PATH, 'r'))
random.seed(233)
random.shuffle(json_list)
goal_range = range(len(human_goals))
if split == 'train':
goal_range = range(1500, len(human_goals))
elif split == 'eval':
goal_range = range(500, 1500)
elif split == 'test':
goal_range = range(0, 500)
for json_str in json_list:
result = json.loads(json_str)
# only strips out the perfect goal
s = process_goal(result['states'][0])
assert s in human_goals, s
goal_idx = human_goals.index(s)
if goal_idx not in goal_range:
continue
num_trajs += 1
if 'images' not in result:
# if result['states'][i] is in product info state
# result['images'][i] is nonzero (len of 512)
result['images'] = [0] * len(result['states'])
for state, valid_acts, idx, image in zip(result['states'], result['available_actions'], result['action_idxs'], result['images']):
# result['action_idxs'][i] stores the index of the action that was taken
# result['states'][i+1] is determined by this action
cnt += 1
if filter_search and idx == -1:
continue
state_list.append(state)
image_list.append([0.] * 512 if image == 0 else image)
if len(valid_acts) > 20: # do some action space reduction...
bad += 1
# indices [0~5] are always kept, plus 10 randomly chosen from [6~len(valid_acts))
new_idxs = list(range(6)) + \
random.sample(range(6, len(valid_acts)), 10)
if idx not in new_idxs:
new_idxs += [idx]
new_idxs = sorted(new_idxs)
valid_acts = [valid_acts[i] for i in new_idxs]
idx = new_idxs.index(idx)
# print(valid_acts)
action_list.extend(valid_acts)
idx_list.append(idx)
# in order to keep track of action_list ?
size_list.append(len(valid_acts))
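A small worked example of the action-space reduction above (made-up numbers): with 25 valid actions and the demonstrated action at index 23,
# Toy illustration of the reduction: keep indices 0-5, 10 random others,
# and always the demonstrated index, then remap idx into the reduced list.
import random
random.seed(0)

valid_acts = [f'click[option {i}]' for i in range(25)]
idx = 23                                          # action actually taken in the demo
new_idxs = list(range(6)) + random.sample(range(6, len(valid_acts)), 10)
if idx not in new_idxs:
    new_idxs += [idx]
new_idxs = sorted(new_idxs)
valid_acts = [valid_acts[i] for i in new_idxs]    # 16 or 17 actions survive
idx = new_idxs.index(idx)                         # remapped label
print(len(valid_acts), idx)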
def get_dataset(split, mem=False):
# tokenizers are from bert
states, actions, idxs, sizes, images = get_data(split, mem)
state_encodings = tokenizer(
states, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
action_encodings = tokenizer(
actions, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
# information about the attention mask
# https://bo-10000.tistory.com/132
dataset = {
'state_input_ids': state_encodings['input_ids'],
'state_attention_mask': state_encodings['attention_mask'],
'action_input_ids': action_encodings['input_ids'].split(sizes),
'action_attention_mask': action_encodings['attention_mask'].split(sizes),
'sizes': sizes,
'images': torch.tensor(images),
'labels': idxs,
}
return Dataset.from_dict(dataset)
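The .split(sizes) calls above regroup the flat per-action encodings by state; the underlying torch operation in isolation:
import torch

action_input_ids = torch.arange(12).reshape(6, 2)    # 6 actions in total, toy length-2 encodings
sizes = [2, 1, 3]                                     # number of actions per state
chunks = action_input_ids.split(sizes)                # one tensor per state
print([tuple(c.shape) for c in chunks])               # [(2, 2), (1, 2), (3, 2)]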
def data_collator(batch):
# collate_fn
# the dataset's __getitem__ returns the dictionary form above
# so batch would be list of above dictionary
state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [
], [], [], [], [], [], []
for sample in batch:
state_input_ids.append(sample['state_input_ids'])
state_attention_mask.append(sample['state_attention_mask'])
action_input_ids.extend(sample['action_input_ids'])
action_attention_mask.extend(sample['action_attention_mask'])
sizes.append(sample['sizes'])
labels.append(sample['labels'])
images.append(sample['images'])
max_state_len = max(sum(x) for x in state_attention_mask)
max_action_len = max(sum(x) for x in action_attention_mask)
return {
'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
'sizes': torch.tensor(sizes),
'images': torch.tensor(images),
'labels': torch.tensor(labels),
}
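Since get_dataset padded everything to the full max_length, the collator trims each batch down to its longest real sequence using the attention-mask sums; a tiny example of that trimming:
import torch

state_attention_mask = [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]   # padded to length 5
max_state_len = max(sum(x) for x in state_attention_mask)   # 3
trimmed = torch.tensor(state_attention_mask)[:, :max_state_len]
print(trimmed.shape)   # torch.Size([2, 3])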
# DataLoaders creation:
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
)
eval_dataloader = DataLoader(
eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
# any : True if at least one element is True // not any : True only when all of them are False
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
# gradient_accumulation_steps: number of steps whose gradients are accumulated before an optimizer update
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
args.num_train_epochs = math.ceil(
args.max_train_steps / num_update_steps_per_epoch)
# num_warmup_steps : default 0 // num_training_steps: default None
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if hasattr(args.checkpointing_steps, "isdigit"):
checkpointing_steps = args.checkpointing_steps
if args.checkpointing_steps.isdigit():
checkpointing_steps = int(args.checkpointing_steps)
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("glue_no_trainer", experiment_config)
# Get the metric function
metric = load_metric("accuracy")
# Train!
total_batch_size = args.per_device_train_batch_size * \
accelerator.num_processes * args.gradient_accumulation_steps
BertModelForWebshop
def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
sizes = sizes.tolist()
# print(state_input_ids.shape, action_input_ids.shape)
# bert returns
# last_hidden_state, pooler_output, hidden_states, attentions
# if you want to know more about attention mask,
# https://huggingface.co/docs/transformers/v4.26.1/en/glossary#attention-mask
state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0] # (batch, seq_len, hidden_size)
if images is not None and self.image_linear is not None:
images = self.image_linear(images) # (batch, 768)
state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1) # (batch, seq_len+1, hidden_size)
state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
# several actions are possible for one state; sizes records how many actions belong to each state
state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
act_lens = action_attention_mask.sum(1).tolist()
state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
state_action_rep = self.relu(self.linear_1(state_action_rep))
act_values = get_aggregated(state_action_rep, act_lens, 'mean') # mean pool step
act_values = self.linear_2(act_values).squeeze(1)
logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]
loss = None
if labels is not None:
loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
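The last few lines of forward are a listwise cross-entropy over a variable number of actions per state; a toy version with made-up scores:
import torch
import torch.nn.functional as F

act_values = torch.tensor([0.2, 1.5, -0.3, 0.9, 0.1])   # flattened action scores
sizes = [3, 2]                                           # 3 actions for state 0, 2 for state 1
labels = [1, 0]                                          # demonstrated action per state
logits = [F.log_softmax(v, dim=0) for v in act_values.split(sizes)]
loss = -sum(logit[label] for logit, label in zip(logits, labels)) / len(logits)
print(loss)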
def get_aggregated(output, lens, method):
"""
Get the aggregated hidden state of the encoder.
B x D
"""
if method == 'mean':
return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
elif method == 'last':
return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
elif method == 'first':
return output[:, 0, :]
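A quick check of the 'mean' branch, which averages each sequence only over its real (unpadded) length:
import torch

output = torch.arange(24, dtype=torch.float).reshape(2, 4, 3)   # (B=2, T=4, D=3)
lens = [2, 4]                                                   # real lengths per sequence
pooled = torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
print(pooled.shape)   # torch.Size([2, 3])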
Back to the main script:
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
if args.with_tracking:
total_loss = total_step = 0
for step, batch in enumerate(train_dataloader):
# We need to skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == starting_epoch:
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
total_step += 1
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
metric.add_batch(
predictions=torch.stack([logit.argmax(dim=0)
for logit in outputs.logits]),
references=batch["labels"]
)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
completed_steps += 1
If checkpointing_steps holds a step count:
if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps }"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
If checkpointing_steps is set to "epoch":
if args.checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
os.makedirs(output_dir, exist_ok=True)
unwrapped_model = accelerator.unwrap_model(model)
torch.save(unwrapped_model.state_dict(),
os.path.join(output_dir, "model.pth"))
Analysis of train_rl.py
def train(agent, eval_env, test_env, envs, args):
# Agent, WebEnv, WebEnv, list of WebEnvs, args
start = time.time()
states, valids, transitions = [], [], []
state0 = None
for env in envs:
ob, info = env.reset()
if state0 is None:
state0 = (ob, info)
states.append(agent.build_state(ob, info))
valids.append(info['valid'])
class WebEnv
def reset(self, idx=None):
if idx is None:
idx = random.sample(self.goal_idxs, k=1)[0]
ob, info = self.env.reset(idx) # WebAgentText env
class WebAgentTextEnv
def reset(self, session=None, instruction_text=None):
"""Create a new session and reset environment variables"""
session_int = None
if session is not None:
self.session = str(session)
if isinstance(session, int):
session_int = session
else:
self.session = ''.join(random.choices(string.ascii_lowercase, k=10))
if self.session_prefix is not None:
self.session = self.session_prefix + self.session
init_url = f'{self.base_url}/{self.session}'
self.browser.get(init_url, session_id=self.session, session_int=session_int)
self.text_to_clickable = None
self.instruction_text = self.get_instruction_text() if instruction_text is None else instruction_text
obs = self.observation # observation is a property, not a stored attribute
self.prev_obs = [obs]
self.prev_actions = []
return obs, None
Back to class WebEnv:
self.session = self.env.server.user_sessions[self.env.session]
if info is None:
info = {}
self.cur_ob, self.prev_ob = ob, None
info.update({'valid': self.get_valid_actions(), 'goal': self.env.instruction_text,
'score': 0, 'estimate_score': self.score(),
'prev_ob': self.prev_ob, 'desc': '', 'feat': ''
})
self.steps = 0
self.item_rank = -1
return ob, info
Functions involved in the info update:
the get_valid_actions function of the WebEnv class.
def get_valid_actions(self):
valid_info = self.env.get_available_actions()
if valid_info['has_search_bar']: # only search action available
atts = self.session['goal']['attributes']
query = self.session['goal']['query']
inst = self.session['goal']['instruction_text']
texts = self.get_search_texts(atts, query, inst)
valids = [f'search[{text}]' for text in texts]
else:
valids = [] # and text.startswith('b')]
for text in valid_info['clickables']:
# ban buy when options not completed
if text == 'buy now' and self.ban_buy:
cur_options = len(self.session['options'])
all_options = len(
self.env.server.product_item_dict[self.session["asin"]]['customization_options'])
if cur_options != all_options:
continue
if text != 'search':
if self.click_item_name and text in self.asin2name:
text = 'item - ' + self.asin2name[text]
valids.append(f'click[{text}]')
# do some action space reduction...
if self.reduce_click and len(valids) > 20:
valids = valids[:6] + random.sample(valids[6:], 10)
if len(valids) == 0:
valids = ['finish']
return valids
(Note that at the reset stage, every episode starts from the search page.)
Back to train_rl.py
for env in envs:
ob, info = env.reset()
if state0 is None:
state0 = (ob, info)
ipdb.set_trace()
states.append(agent.build_state(ob, info))
valids.append(info['valid'])
agent.py
def build_state(self, ob, info):
""" Returns a state representation built from various info sources. """
obs_ids = self.encode(ob)
goal_ids = self.encode(info['goal'])
click = info['valid'][0].startswith('click[')
estimate = info['estimate_score']
obs_str = ob.replace('\n', '[SEP]')
goal_str = info['goal']
image_feat = info.get('image_feat')
return State(obs_ids, goal_ids, click, estimate, obs_str, goal_str, image_feat)
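State here is a namedtuple; a hedged reconstruction of its definition, with the field names inferred from how it is used elsewhere (state.obs in rl_forward, state.image_feat for the image features):
# Assumed field names, inferred from usage; the actual definition lives in agent.py.
from collections import namedtuple

State = namedtuple('State',
                   ('obs', 'goal', 'click', 'estimate', 'obs_str', 'goal_str', 'image_feat'))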
train_rl.py
for step in range(1, args.max_steps + 1):
# get actions from policy
action_strs, action_ids, values = agent.act(states, valids, method=args.exploration_method)
agent.py
def act(self, states, valid_acts, method, state_strs=None, eps=0.1):
""" Returns a string action from poss_acts. """
act_ids = self.encode_valids(valid_acts) # [#envs, # actions in info['valid']]
# sample actions
act_values, act_sizes, values = self.network.rl_forward(states, act_ids, value=True, act=True)
bert.py
def rl_forward(self, state_batch, act_batch, value=False, q=False, act=False):
act_values = []
act_sizes = []
values = []
# state_batch, act_batch : [num envs, State named tuple]
# act_batch : [num_envs, # valid_acts, each valid_act tokenized result]
for state, valid_acts in zip(state_batch, act_batch):
with torch.set_grad_enabled(not act):
state_ids = torch.tensor([state.obs]).cuda() # state.obs : encoded obs by tokenizer
state_mask = (state_ids > 0).int() # to mask-out 0 values
act_lens = [len(_) for _ in valid_acts]
act_ids = [torch.tensor(_) for _ in valid_acts]
# pad according to longest sequence
# returns [# valid_acts, longest tokenized result]
act_ids = nn.utils.rnn.pad_sequence(act_ids, batch_first=True).cuda()
act_mask = (act_ids > 0).int()
act_size = torch.tensor([len(valid_acts)]).cuda()
if self.image_linear is not None:
images = [state.image_feat]
images = [torch.zeros(512) if _ is None else _ for _ in images]
images = torch.stack(images).cuda() # BS x 512
else:
images = None
# returns [num actions]
logits = self.forward(state_ids, state_mask, act_ids, act_mask, act_size, images=images).logits[0]
act_values.append(logits)
act_sizes.append(len(valid_acts))
if value:
v = self.bert(state_ids, state_mask)[0]
# when state_ids : [1,37] // v : [1, 37, 768]
values.append(self.linear_3(v[0][0])) # append one scalar value
act_values = torch.cat(act_values, dim=0) # append all actions in one vector
act_values = torch.cat([F.log_softmax(_, dim=0) for _ in act_values.split(act_sizes)], dim=0)
# Optionally, output state value prediction
if value:
values = torch.cat(values, dim=0)
return act_values, act_sizes, values
else:
return act_values, act_sizes
agent.py
act_values = act_values.split(act_sizes) # splited into [num envs, each act sizes]
if method == 'softmax':
act_probs = [F.softmax(vals, dim=0) for vals in act_values] # vals are logits
# stochastic action sampling
act_idxs = [torch.multinomial(probs, num_samples=1).item() for probs in act_probs]
elif method == 'greedy':
act_idxs = [vals.argmax(dim=0).item() for vals in act_values]
elif method == 'eps': # eps exploration
act_idxs = [vals.argmax(dim=0).item() if random.random() > eps else random.randint(0, len(vals)-1) for vals in act_values]
acts = [acts[idx] for acts, idx in zip(act_ids, act_idxs)] # [num_envs, each action dimension]
# decode actions
act_strs, act_ids = [], [] # string_version and tokenized number version
for act, idx, valids in zip(acts, act_idxs, valid_acts):
if torch.is_tensor(act):
act = act.tolist()
if 102 in act:
act = act[:act.index(102) + 1]
act_ids.append(act) # [101, ..., 102]
if idx is None: # generative
act_str = self.decode(act)
else: # int
act_str = valids[idx]
act_strs.append(act_str)
return act_strs, act_ids, values
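The cutoff at token id 102 works because the tokenizer here is BERT's: 102 is [SEP] (and 101 is [CLS]), so each encoded action is trimmed right after its [SEP]. A quick check:
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-uncased')
print(tok.convert_tokens_to_ids('[CLS]'))   # 101
print(tok.convert_tokens_to_ids('[SEP]'))   # 102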
train_rl.py
for env, action_str, action_id, state in zip(envs, action_strs, action_ids, states):
ob, reward, done, info = env.step(action_str)
if state0 is None: # first state
state0 = (ob, info)
r_att = r_opt = 0
if 'verbose' in info: # record specific logs
r_att = info['verbose'].get('r_att', 0)
r_option = info['verbose'].get('r_option ', 0)
r_price = info['verbose'].get('r_price', 0)
r_type = info['verbose'].get('r_type', 0)
w_att = info['verbose'].get('w_att', 0)
w_option = info['verbose'].get('w_option', 0)
w_price = info['verbose'].get('w_price', 0)
reward_str = f'{reward/10:.2f} = ({r_att:.2f} * {w_att:.2f} + {r_option:.2f} * {w_option:.2f} + {r_price:.2f} * {w_price:.2f}) * {r_type:.2f}'
else:
reward_str = str(reward)
log('Reward{}: {}, Done {}\n'.format(step, reward_str, done))
next_state = agent.build_state(ob, info)
next_valid = info['valid']
next_states, next_valids, rewards, dones = \
next_states + [next_state], next_valids + [next_valid], rewards + [reward], dones + [done]
# RL update
transitions.append(TransitionPG(states, action_ids, rewards, values, agent.encode_valids(valids), dones))
if len(transitions) >= args.bptt:
_, _, last_values = agent.act(next_states, next_valids, method='softmax')
stats = agent.update(transitions, last_values, step=step)
for k, v in stats.items():
tb.logkv_mean(k, v)
del transitions[:]
torch.cuda.empty_cache()
agent.py
def update(self, transitions, last_values, step=None, rewards_invdy=None):
returns, advs = discount_reward(transitions, last_values, self.gamma) # [#timesteps, num_envs]
stats_global = defaultdict(float)
def discount_reward(transitions, last_values, gamma):
returns, advantages = [], []
R = last_values.detach() # always detached
for t in reversed(range(len(transitions))):
_, _, rewards, values, _, dones = transitions[t]
R = torch.FloatTensor(rewards).to(device) + gamma * R * (1 - torch.FloatTensor(dones).to(device))
baseline = values
adv = R - baseline
returns.append(R)
advantages.append(adv)
return returns[::-1], advantages[::-1]
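A standalone numeric check of what discount_reward computes (made-up values: one env, two timesteps, gamma = 0.9, bootstrap value 0.5, done only at the last step):
import torch

gamma = 0.9
R = torch.tensor([0.5])                      # detached last value (bootstrap)
rewards = [[0.0], [1.0]]                     # per timestep, single env
values = [torch.tensor([0.4]), torch.tensor([0.6])]
dones = [[0.0], [1.0]]

returns, advantages = [], []
for t in reversed(range(2)):
    R = torch.tensor(rewards[t]) + gamma * R * (1 - torch.tensor(dones[t]))
    returns.append(R)
    advantages.append(R - values[t])
print(returns[::-1])      # [tensor([0.9000]), tensor([1.])]
print(advantages[::-1])   # [tensor([0.5000]), tensor([0.4000])]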
Continuing in agent.py:
for transition, adv in zip(transitions, advs):
stats = {}
# len of transition.state = num_envs
log_valid, valid_sizes = self.network.rl_forward(transition.state, transition.valid_acts)
act_values = log_valid.split(valid_sizes)
# stores log_a value for each envs
log_a = torch.stack([values[acts.index(act)]
for values, acts, act in zip(act_values, transition.valid_acts, transition.act)])
stats['loss_pg'] = - (log_a * adv.detach()).mean()
stats['loss_td'] = adv.pow(2).mean() # L_value in paper
stats['loss_il'] = - log_valid.mean() # not in paper & its weight set to zero!
stats['loss_en'] = (log_valid * log_valid.exp()).mean()
for k in stats:
stats[k] = self.w[k] * stats[k] / len(transitions)
stats['loss'] = sum(stats[k] for k in stats)
stats['returns'] = torch.stack(returns).mean() / len(transitions)
stats['advs'] = torch.stack(advs).mean() / len(transitions)
stats['loss'].backward()
# Compute the gradient norm
stats['gradnorm_unclipped'] = sum(p.grad.norm(2).item() for p in self.network.parameters() if p.grad is not None)
nn.utils.clip_grad_norm_(self.network.parameters(), self.clip) # self.clip = max_norm // set to 10
stats['gradnorm_clipped'] = sum(p.grad.norm(2).item() for p in self.network.parameters() if p.grad is not None)
for k, v in stats.items():
stats_global[k] += v.item() if torch.is_tensor(v) else v
del stats
self.optimizer.step()
self.optimizer.zero_grad()
return stats_global
# handle done
for i, env in enumerate(envs):
if dones[i]:
ob, info = env.reset()
if i == 0:
state0 = (ob, info)
next_states[i] = agent.build_state(ob, info)
next_valids[i] = info['valid']
states, valids = next_states, next_valids