Webshop 코드리뷰

이두현·2024년 3월 17일
0

webshop/baseline_models/test.py

line 106 ~ line 113 :

if args.mem:
        env.env.num_prev_obs = 1
        env.env.num_prev_actions = 5
        print('memory')
    else:
        env.env.num_prev_obs = 0
        env.env.num_prev_actions = 0
        print('no memory')
    

    if args.bart:
        bart_model = BartForConditionalGeneration.from_pretrained(args.bart_path)
        print('bart model loaded', args.bart_path)
    else:
        bart_model = None
  • args.mem 변수는 훈련에 이전 관측과 선택했던 action을 포함시킬 것인가를 나타냄

→ 논문에서는 이를 포함하면 오히려 성능이 떨어진다고 보고

  • bart_model 은 search창에 위치했을 때 query를 generation 하는 모델인데 존재하지 않는 경우 predict 함수에서 가능한 action 들 중 마지막 action을 선택하도록 짜여있다

  • main 함수 안의 args 변수는 위의 parse_args 함수에서 아래와 같이 반환된다

Namespace(bart=True, bart_path='./ckpts/web_search/checkpoint-800', image=True, mem=0, model_path='./ckpts/web_click/epoch_9/model.pth', softmax=True)

반면 전역 args 변수는 train_rl.py 함수에서

Untitled

다음과 같이 반환되는데 이는

Untitled

이미 존재하는 argument의 Namespace와 정의되지 않은 부분으로 나눠진 tuple을 반환하기 때문에 test.py 의 args에는 아래와 같은 정보가 전달된다.

Namespace(ban_buy=0, bert_path='', bptt=8, ckpt_freq=10000, click_item_name=1, clip=10, debug=0, embedding_dim=128, eval_freq=500, exploration_method='softmax', extra_search_path='./data/goal_query_predict.json', f='1', gamma=0.9, get_image=1, go_to_item=0, go_to_search=0, grad_encoder=1, harsh_reward=0, hidden_dim=128, human_goals=1, learning_rate=1e-05, log_freq=100, max_steps=300000, network='bert', num=None, num_envs=4, num_prev_actions=0, num_prev_obs=0, output_dir='logs', score_handicap=0, seed=0, state_format='text_rich', step_limit=100, test_freq=5000, w_en=1, w_il=0, w_pg=1, w_td=1, wandb=1)

이후 search query를 만드는 BART 모델은 modify 되지 않고 pretrained 모델을 가져와 사용한다

if args.bart:
        bart_model = BartForConditionalGeneration.from_pretrained(args.bart_path)
        print('bart model loaded', args.bart_path)
    else:
        bart_model = None

BERT 모델은 pre-trained 모델에 추가적인 layer들을 추가해 만들어지며 layer의 output을 RL state로 사용하고도 있다

config = BertConfigForWebshop(image=args.image)
    model = BertModelForWebshop(config)
    model.cuda()
    model.load_state_dict(torch.load(args.model_path), strict=False)

사용한 모델 구성은 아래와 같다

class BertModelForWebshop(PreTrainedModel):

    config_class = BertConfigForWebshop

    def __init__(self, config):
        super().__init__(config)
        bert_config = BertConfig.from_pretrained('bert-base-uncased')
        if config.pretrained_bert:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            self.bert = BertModel(config)
        self.bert.resize_token_embeddings(30526)
        self.attn = BiAttention(768, 0.0)
        self.linear_1 = nn.Linear(768 * 4, 768)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(768, 1)
        if config.image:
            self.image_linear = nn.Linear(512, 768)
        else:
            self.image_linear = None

        # for state value prediction, used in RL
        self.linear_3 = nn.Sequential(
                nn.Linear(768, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1),
            )

Untitled

  • 선언된 BiAttention 모듈은 그림에서 attention fusion layer로 여겨진다

main 함수에서 총 500개의 episode가 평가되며 각 episode는 최대 timestep 100 을 진행하게 된다(넘을 경우 reward : 0)

def episode(model, idx=None, verbose=False, softmax=False, rule=False, bart_model=None):
    obs, info = env.reset(idx)
    if verbose:
        print(info['goal'])
    for i in range(100):
        action = predict(obs, info, model, softmax=softmax, rule=rule, bart_model=bart_model)
        if verbose:
            print(action)
        obs, reward, done, info = env.step(action)
        if done:
            return reward
    return 0

main 함수에서는 episode에 rule = True 가 아닐 경우 score_softmax, True 일 경우 score_rule을 주고있다

score_softmax, score_rule = episode(model, idx=i, softmax=args.softmax, bart_model=bart_model), episode(model, idx=i, rule=True)

test.py 의 predict 함수

def predict(obs, info, model, softmax=False, rule=False, bart_model=None):
    # info : dictionary containing 'valid' : valid actions, 'goal' : given instruction to solve, 'score', 'estimate_score', 'prev_ob', image_feat'
    valid_acts = info['valid']
    if valid_acts[0].startswith('search['):
        if bart_model is None:
            return valid_acts[-1]
        else:
            goal = process_goal(obs)
            query = bart_predict(goal, bart_model, num_return_sequences=5, num_beams=5)
            # query = random.choice(query)  # in the paper, we sample from the top-5 generated results.
            query = query[0]  #... but use the top-1 generated search will lead to better results than the paper results.
            return f'search[{query}]'
            
    if rule:
        item_acts = [act for act in valid_acts if act.startswith('click[item - ')]
        if item_acts:
            return item_acts[0]
        else:
            assert 'click[buy now]' in valid_acts
            return 'click[buy now]'

    ipdb.set_trace()
                
    # https://heekangpark.github.io/nlp/huggingface-bert
    # 'input_ids' : id list of tokens , 'token_type_ids' : 0 if sentence A, 1 if sentence B , 'attention_mask' : 0 for padding token 1 for tokens that should be trained
    state_encodings = tokenizer(process(obs), max_length=512, truncation=True, padding='max_length') # process : ' -> nothing and sep -> SEP
    action_encodings = tokenizer(list(map(process, valid_acts)), max_length=512, truncation=True,  padding='max_length') # valid_acts : list of available actions 
    batch = {
        'state_input_ids': state_encodings['input_ids'],
        'state_attention_mask': state_encodings['attention_mask'],
        'action_input_ids': action_encodings['input_ids'],
        'action_attention_mask': action_encodings['attention_mask'],
        'sizes': len(valid_acts),
        'images': info['image_feat'].tolist(),
        'labels': 0
    }
    batch = data_collator([batch])
    # make batch cuda
    batch = {k: v.cuda() for k, v in batch.items()}
    outputs = model(**batch)
    if softmax:
        idx = torch.multinomial(F.softmax(outputs.logits[0], dim=0), 1)[0].item()
    else:
        idx = outputs.logits[0].argmax(0).item()
    return valid_acts[idx]
  • obs는 list 안에 character가 하나씩 들어가있는 구조

Untitled

위는 obs 일부 예시

  • data_collector 는 원래 여러 batch 에 대해 묶어준 후 dictionary를 반환하는 것처럼 보이지만 prediction 함수에서는 하나의 state에 대해서만 batch 로 넘겨주게 된다
  • model 은 BertModelForWebShop 이며 foward 함수는 loss와 logit을 반환한다
  • torch.multinomial 은 확률값을 가진 row들이 주어지면 second parameter 개수만큼의 index를 반환한다

score_softmax의 predict 함수 parameter는

obs, info, model, softmax = True, rule = False, bart_model = model

[위의 코드에서는 line 3 이후 else 부분]

  • search 창에 들어와있을 경우
    • bart_model 이 만들어내는 query 중 top -1 결과를 사용

[코드에서는 state, action encoding 줄부터 시작]

  • click 옵션에 들어온 경우
    • Bert 모델이 return 한 loss, logit 중
    • softmax 가 true 이므로 logit을 확률값으로 사용해 random 하게 반환한다

score_rule 의 predict 함수 parameter는

obs, info, model, softmax = False, rule = True, bart_model = none 이다

[위의 코드에서는 line 3]

  • search 창에 들어와있을 경우
    • valid_act 중 가장 마지막 action을 선택

[위의 코드에서는 if rule 부분]

  • click 옵션에 들어온 경우
    • click[item-] 형식의 옵션이 있을 경우 가장 먼저 있는 것을 선택,
    • 위의 옵션이 없을 경우 바로 click[buy now] 옵션 선택

Untitled


train_search_il.py 분석

먼저 get_dataset으로 dataset 준비

def get_dataset(name, flip=False, variant=None, size=None):
    fname = name + "-flip" if flip else name
    fpath = os.path.join(os.path.dirname(__file__), fname)
    d = {}
    splits = ["train", "validation", "test"]
    if name == "web_search":
        splits = ["train", "validation", "test", "all"]
    for split in splits:
        input, output = get_data(split) if name != "nl2bash" else get_data(
            split, variant=variant)
        l = len(input) if size is None else int(len(input) * size)
        print("{} size: {}".format(split, l))
        if flip:
            input, output = output, input
        input, output = input[:l], output[:l]
        d[split] = process_dataset(input, output)
    d = DatasetDict(d)
    return d
  • train_search_il 에서는 splits 가 train, val, test, all 로 나뉜다
  • flip 옵션이 없으므로 그대로 반환
def get_data(split):
    data = json.load(open(PATH))
    goals, searches = [], []
    for goal, search_list in data.items():
        goal = process_goal(goal)
        for search in search_list:
            search = process_str(search)
            goals.append(goal)
            searches.append(search)
    n = len(goals)
  • PATH 에는 goal_query_map이 저장되어있으며 하나의 goal에 여러 search 경험이 있는 경우가 있으므로 이를 모두 반영하도록 한다
human_goals = json.load(open(HUMAN_GOAL_PATH, 'r'))
    goal_range = range(len(human_goals))
    if split == 'train':
        goal_range = range(500, len(human_goals))
    elif split == 'validation':
        goal_range = range(500, 1500)
    elif split == 'test':
        goal_range = range(0, 500)

goals_, searches_ = [], []
    for goal, search in zip(goals, searches):
        if goal in human_goals and human_goals.index(goal) in goal_range:
            goals_.append(goal)
            searches_.append(search)
    return goals_, searches_
  • human_goals 안에는 goal_query 중에서 goal에 해당하는 것들이 모여있다
  • split 이 train, val, test 중 하나에 속할 경우, goal 이 human_goal 안에 포함되어 있고 정해놓은 goal_range안에 그 index 가 포함된다면 배열에 저장해 반환
elif split == "all":  # all human instructions, but without groundtruth search queries
        all_data = json.load(open(GOAL_PATH))
        all_goals = []
        all_goals_processed = []
        for ins_list in all_data.values():
            for ins in ins_list:
                ins = ins['instruction']
                all_goals.append(ins)
                all_goals_processed.append(process_str(ins))
        return all_goals_processed, all_goals
  • split 항목이 all 에 해당할 경우
  • 상품번호 : list of 정보들 형태로 이뤄져 있고 list of 정보들 중 instrution 부분을 따오고 이는 goal 에 해당하기 때문에 all_goals 배열에 저장 후 return

[train, val, test] : return (goal, searches)

[all] : return (goal_processed, goal)

def process_dataset(input, output, max_len=256):
    # https://huggingface.co/docs/transformers/v4.26.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.from_pretrained
    # loook for __call__ func tokenzier
    # padding = 'max_length' : padd to a maximum level defined as max_length param // truncation = True : remove token from a longest sequence in the pair // return_tensors : torch.Tesnor if 'pt' 
    # __call__ func returns BatchEncoding
    input_encodings = tokenizer(input, padding='max_length',
                                max_length=max_len, truncation=True, return_tensors='pt')
    output_encodings = tokenizer(
        output, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt')
    ipdb.set_trace()
    labels = output_encodings['input_ids'] # 'input_ids' field means batch of encoded result  size : [len of train_data, max_seq]
    # https://huggingface.co/transformers/v4.11.3/_modules/transformers/models/bart/modeling_bart.html
    # info for shift_tokens_right
    # parameter is (input_ids, pad_token_id, decoder_start_token_id) : why give eos_token_id for decoder start?
    #https://stackoverflow.com/questions/64904840/why-we-need-a-decoder-start-token-id-during-generation-in-huggingface-bart 
    # reason why they used EOS_token for decoder start ! 
    # internally changes token_ids == -100 into pad_token_id (-100 indicates mask!, automatically ignored in pytorch loss funcs)
    decoder_input_ids = shift_tokens_right(labels, PAD_TOKEN_ID, EOS_TOKEN_ID)
    # so change back ? 
    labels[labels[:, :] == PAD_TOKEN_ID] = -100
    # generate Dataset class from dictionary
    dataset = Dataset.from_dict({
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'decoder_input_ids': decoder_input_ids,
        'labels': labels,
    })
    # select format for iter() output
    dataset.set_format(type='torch', columns=[
                       'input_ids', 'labels', 'decoder_input_ids', 'attention_mask'])
    return dataset
  • get_data 함수를 나온 후 위의 process_data 함수를 거치면 huggingface 의 Dataset class를 반환하게 된다
if __name__ == "__main__":
    ipdb.set_trace()
    dataset = get_dataset("web_search", flip=False)
    train_dataset = dataset['train']
    print(train_dataset[0])
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    model.resize_token_embeddings(len(tokenizer))
    # model = BartForConditionalGeneration.from_pretrained('./models/qdmr-high-level-base/checkpoint-10000')
    training_args = TrainingArguments(
        output_dir='./ckpts/web_search',
        num_train_epochs=10,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=50,
        weight_decay=0.01,
        evaluation_strategy="steps",
        logging_dir='./logs',
        logging_steps=50,
        eval_steps=20,
        save_steps=200
        # eval_accumulation_steps=1
    )
    # data collator : define ways to prepare batch from list form
    # set to default_data_collator in this setting
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dataset["validation"],
        compute_metrics=None,
    )
    trainer.train()
parser.add_argument("--bart_path", type=str, default='./ckpts/web_search/checkpoint-800', help="BART model path if using it")
  • test.py 함수에서 BART 모델을 불러오는 경로는 다음과 같다

TODO

  1. 기존 baseline에 exploration 을 증가시킨 방법으로 실험
  2. Inverse RL을 통해 baseline의 대부분을 차지하는 imitation learning을 대체

train_choice_il.py 분석

config = BertConfigForWebshop(
        image=args.image, pretrain_bert=args.pretrain)
model = BertModelForWebshop(config)
class BertConfigForWebshop(PretrainedConfig):
    model_type = "bert"

    def __init__(
        self,
        pretrained_bert=True,
        image=False,

        **kwargs
    ):
        self.pretrained_bert = pretrained_bert
        self.image = image
        super().__init__(**kwargs)
  • args.image는 디폴트 값 1 (parse argument 하는 과정에서)
  • args.pretrained_bert 도 1
train_dataset = get_dataset("train", mem=args.mem)
eval_dataset = get_dataset("eval", mem=args.mem)
def get_data(split, mem=False, filter_search=True):
    # mem : defalut false in all settings
    # MEM_PATH doesn't even exist
    path = MEM_PATH if mem else PATH
    print('Loading data from {}'.format(path))
    with open(path, 'r') as json_file:
        json_list = list(json_file)

    human_goals = json.load(open(HUMAN_GOAL_PATH, 'r'))

    random.seed(233)
    random.shuffle(json_list)

		goal_range = range(len(human_goals))
    if split == 'train':
        goal_range = range(1500, len(human_goals))
    elif split == 'eval':
        goal_range = range(500, 1500)
    elif split == 'test':
        goal_range = range(0, 500)
  • PATH 에 있는 데이터는 ‘states’, ‘available_actions’, ‘action_idxs’, ‘images’ 정보를 포함한 dictionary
  • 이 때 action_idxs는 available_actions 중에서 취한 action 을 의미한다
  • HUMAN_GOAL_PATH 에 있는 데이터는 string 형태의 human goal 모음 list
for json_str in json_list:
        result = json.loads(json_str)
        # only strips out the perfect goal
        s = process_goal(result['states'][0])
        assert s in human_goals, s
        goal_idx = human_goals.index(s)
        if goal_idx not in goal_range:
            continue
  • json_str에는 위에서 말한 dictionary 가 들어있고 이중 ‘states’ 의 첫번째 원소 [0] 에는 instruction, 즉 goal 이 들어있음
  • human_goals 리스트의 index가 설정한 split dataset 범위 안에 있는지 검사
				num_trajs += 1
        if 'images' not in result:
            # if result['states'][i] is in product info state
            # result['images'][i] is nonzero (len of 512)
            result['images'] = [0] * len(result['states'])
        for state, valid_acts, idx, image in zip(result['states'], result['available_actions'], result['action_idxs'], result['images']):
            # result['action_idx'][i] saves action index that has been made 
            # result['states'][i+1] will be affected by this
            cnt += 1
            if filter_search and idx == -1:
                continue
            state_list.append(state)
            image_list.append([0.] * 512 if image == 0 else image)
            if len(valid_acts) > 20:  # do some action space reduction...
                bad += 1
                # [0~5] is in + randomly chosen 10 among [6~len(valid_acts)]
                new_idxs = list(range(6)) + \
                    random.sample(range(6, len(valid_acts)), 10)
                if idx not in new_idxs:
                    new_idxs += [idx]
                new_idxs = sorted(new_idxs)
                valid_acts = [valid_acts[i] for i in new_idxs]
                idx = new_idxs.index(idx)
                # print(valid_acts)
            action_list.extend(valid_acts)
            idx_list.append(idx)
            # in order to keep track of action_list ?
            size_list.append(len(valid_acts))
  • 이미지가 dictionary 에 없으면 [0] * 512 값으로 채워넣는다 (if, for 문 종합시)
  • 가능한 action 이 20 가지가 넘는 경우 action space를 줄인다
  • valid_acts는 다른 변수들과 다르게 배열 형태이므로 extend로 추가한다
  • 대신, size_list 에 valid_acts 크기를 저장해 몇번째 transition 에서 valid action 들이었는지 기록한다
  • idx는 action 들 중 선택한 것의 index를 의미한다
def get_dataset(split, mem=False):
    # tokenizers are from bert 
    states, actions, idxs, sizes, images = get_data(split, mem)
    state_encodings = tokenizer(
        states, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    action_encodings = tokenizer(
        actions, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
    # informations about attention mask
    # https://bo-10000.tistory.com/132
    dataset = {
        'state_input_ids': state_encodings['input_ids'],
        'state_attention_mask': state_encodings['attention_mask'],
        'action_input_ids': action_encodings['input_ids'].split(sizes),
        'action_attention_mask': action_encodings['attention_mask'].split(sizes),
        'sizes': sizes,
        'images': torch.tensor(images),
        'labels': idxs,
    }
    return Dataset.from_dict(dataset)
  • tokenizer 는 공통적으로 bert-base-uncased 사용
  • tensor.split(list) 는 list 안의 원소들 개수만큼 원래 tensor 를 나눠줌
  • action_list는 개수를 size_list 에 기록한 채 한 배열에 담아두었으므로 위와 같은 split 함수로 나눠 저장
def data_collator(batch):
    # collate_fn 
    # __get_item__ func of dataset will return above dictionary form
    # so batch would be list of above dictionary
    state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [
    ], [], [], [], [], [], []
    for sample in batch:
        state_input_ids.append(sample['state_input_ids'])
        state_attention_mask.append(sample['state_attention_mask'])
        action_input_ids.extend(sample['action_input_ids'])
        action_attention_mask.extend(sample['action_attention_mask'])
        sizes.append(sample['sizes'])
        labels.append(sample['labels'])
        images.append(sample['images'])
    max_state_len = max(sum(x) for x in state_attention_mask)
    max_action_len = max(sum(x) for x in action_attention_mask)
    return {
        'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
        'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
        'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
        'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
        'sizes': torch.tensor(sizes),
        'images': torch.tensor(images),
        'labels': torch.tensor(labels),
    }
  • 모든 데이터에 대해 가장 긴 state sequence 길이, 가장 긴 action sequence 길이를 기준으로 square 하게 만들어 list를 모두 tensor 형태로 반환
# DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    # any : if one is True it's true // not + any : true only all of them is false
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
  • 데이터와 optimizer 준비
# Scheduler and math around the number of training steps.
    # gradient_accumulation_steps: # of updates steps to accumulate before performing a backward pass
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(
            args.max_train_steps / num_update_steps_per_epoch)

    # num_warmup_steps : default 0 // num_training_steps: default None
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
  • max_train_steps가 안정해진 경우 num_train_epoch 에 기반해 결정
  • max_train_steps가 정해진 경우 else문으로 가서 num_train_epoch 를 덮어씌움
  • scheduler를 받은 후 항목들을 accelerator 에 포함
		if hasattr(args.checkpointing_steps, "isdigit"):
        checkpointing_steps = args.checkpointing_steps
        if args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
    else:
        checkpointing_steps = None

    # We need to initialize the trackers we use, and also store our configuration
    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("glue_no_trainer", experiment_config)

    # Get the metric function
    metric = load_metric("accuracy")

    # Train!
    total_batch_size = args.per_device_train_batch_size * \
        accelerator.num_processes * args.gradient_accumulation_steps
  • checkpointing_steps 변수는 몇 번의 step 마다 accelerator 가 현재 state 등 정보를 저장해야하는지 정하는 변수
  • default 값은 ‘epoch’
  • checkpointg_steps가 int형이면 변환, 아니면 유지
  • dataset library에서 metric 객체를 load
  • total_batch_size 구할 때 마지막에 곱해주는 항 이유를 모르겠지만 실제적으로 쓰이는 변수는 아니므로 패스

BertModelForWebshop

def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, images=None, labels=None):
        sizes = sizes.tolist()
        # print(state_input_ids.shape, action_input_ids.shape)
        # bert returns
        # last_hidden_state, pooler_output, hidden_states, attentions
        # if you want to know more about attention mask, 
        # https://huggingface.co/docs/transformers/v4.26.1/en/glossary#attention-mask
        state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0] # (batch, seq_len, hidden_size)
        if images is not None and self.image_linear is not None:
            images = self.image_linear(images) # (batch, 768)
            state_rep = torch.cat([images.unsqueeze(1), state_rep], dim=1) # (batch, seq_len+1, hidden_size)
            state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
        action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
        # several actions were possible for one state and sizes remembers # info
        state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
        state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
        act_lens = action_attention_mask.sum(1).tolist()
        state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
        state_action_rep = self.relu(self.linear_1(state_action_rep))
        act_values = get_aggregated(state_action_rep, act_lens, 'mean') # mean pool step
        act_values = self.linear_2(act_values).squeeze(1)

        logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]

        loss = None
        if labels is not None:
            loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
        
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
  • sizes 별로 repeat 하기 전 state_rep dimenision은 [batch, seq_len+1, hidden_size] 인데 repeat(j,1,1) 하게 되면 batch 방향으로만 j 번 반복하게 된다 (j는 size 배열의 원소)
  • attention_mask는 [batch, seq_len+1] 만큼의 dimension 이 있기 때문에 repeat을 (j,1) 로 하는 것이다
  • action_attention_mask는 [batch, seq_len] 차원으로 seq 방향에서 1인 mask의 개수를 세어 act_lens 배열에 저장한다
  • batch 개수만큼 scalar 값을 다 구한 후 action 개수만큼 잘라서 softmax를 취하는 형태를 갖고 있다
  • loss는 -log(p) 의 average로 계산된다
def get_aggregated(output, lens, method):
    """
    Get the aggregated hidden state of the encoder.
    B x D
    """
    if method == 'mean':
        return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
    elif method == 'last':
        return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
    elif method == 'first':
        return output[:, 0, :]
  • output 은 (batch, action_seq_len, 768) 의 dimension을 갖는 state_action_rep 변수이다
  • lens에는 batch 각 원소에 대한 action seq의 길이를 측정한 act_lens 변수가 들어간다
  • output[i,:j,:] 차원은 [1,j,hidden] 이 아니라 [j, hidden]이 되므로 j 에 해당하는 action 만큼의 action 이 mean 처리된다
  • [batch, hidden] 차원을 반환할 것이다

다시 main.py

for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = total_step = 0

        for step, batch in enumerate(train_dataloader):
            # We need to skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == starting_epoch:
                if resume_step is not None and step < resume_step:
                    completed_steps += 1
                    continue
            outputs = model(**batch)
            loss = outputs.loss
            # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
                total_step += 1
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)

            metric.add_batch(
                predictions=torch.stack([logit.argmax(dim=0)
                                        for logit in outputs.logits]),
                references=batch["labels"]
            )

            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1
  • total_loss 는 log 기록용
  • gradient_accumulation_step 으로 나눈 loss 를 보내고 optimizer step은 step 이 개수만큼 진행된 후
  • predictions 에 argmax 값들을 저장

checkpointing_steps 변수에 step 기준으로 저장되어있을 경우

					if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps }"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

checkpointing_steps 변수에 epoch 기준으로 저장되어있을 경우

			if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            os.makedirs(output_dir, exist_ok=True)
            unwrapped_model = accelerator.unwrap_model(model)
            torch.save(unwrapped_model.state_dict(),
                       os.path.join(output_dir, "model.pth"))

train_ril.py 분석

def train(agent, eval_env, test_env, envs, args):
    # Agent, WebEnv, WebEnv, list of WebEnvs, args
    start = time.time()
    states, valids, transitions = [], [], []
    state0 = None
    for env in envs:
        ob, info = env.reset()
        if state0 is None:
            state0 = (ob, info)
        states.append(agent.build_state(ob, info))
        valids.append(info['valid'])
  • eval_env, test_env 는 WebEnv 라는 WebAgentTextEnv 의 wrapper class 를 인자로 받는다
  • envs 는 env 개수만큼 WebEnv list 로 생성되어 전달된다

class WebEnv

def reset(self, idx=None):
        if idx is None:
            idx = random.sample(self.goal_idxs, k=1)[0]
        ob, info = self.env.reset(idx) # WebAgentText env
  • random index를 선정 후 WebAgentTextEnv reset 함수를 실행한다

class WebAgentTextEnv

def reset(self, session=None, instruction_text=None):
        """Create a new session and reset environment variables"""
        session_int = None
        if session is not None:
            self.session = str(session)
            if isinstance(session, int):
                session_int = session
        else:
            self.session = ''.join(random.choices(string.ascii_lowercase, k=10))
        if self.session_prefix is not None:
            self.session = self.session_prefix + self.session

        init_url = f'{self.base_url}/{self.session}'
        self.browser.get(init_url, session_id=self.session, session_int=session_int)

        self.text_to_clickable = None
        self.instruction_text = self.get_instruction_text() if instruction_text is None else instruction_text
        obs = self.observation # this is a funtion
        self.prev_obs = [obs]
        self.prev_actions = []
        return obs, None
  • session 으로 위에서 설정한 random idx를 전달받는다
  • self.session_prefix는 main 에서 envs parameter 받을 때 list 중 몇 번째 env 인지 나타내는 index ex) train0 이런식으로 이름붙임
  • self.session 은 ex) train0_idx 이런식으로 이름붙음
  • self.browser 는 SimBrowser 클래스로 html 환경을 text로 볼 수 있게 중화? 시켜 주는 class
  • self.browser.get 을 하면 self.page_source에 html 소스를 저장하고 self.current_url 변수에 현재 페이지 정보를 저장 (class 변수에 이 값들을 저장해놓음)
  • self.instruction_text 에는 ‘Instruction: ….’ 이런식으로 사용자가 원하는 상품에 대한 내용 반환
  • self.observation 함수는 html 을 parsing 후 알아보기 편한 text-rich 상태로 반환해주며 예로는 아래와 같음

Untitled

다시 class WebEnv

				self.session = self.env.server.user_sessions[self.env.session]
        if info is None:
            info = {}
        self.cur_ob, self.prev_ob = ob, None
        info.update({'valid': self.get_valid_actions(), 'goal': self.env.instruction_text,
                     'score': 0, 'estimate_score': self.score(),
                     'prev_ob': self.prev_ob, 'desc': '', 'feat': ''
                     })
        self.steps = 0
				self.item_rank = -1
        return ob, info
  • self.env 는 WebAgentTextEnv 를 의미하고 self.env.session 은 예시로 계속드는 train0_idx 를 반환한다
  • self.enmv.server는 SimServer를 의미하고 user_sessions dictionary 는 WebAgentTextEnv 에서 reset 호출시 SimBrowser의 get 함수를 부르게 되는데 이는 SimServer 에 대한 wrapper 클래스로서 receive 함수를 호출하게 되는데 이 함수에 user_sessions dict 을 채우는 내용이 들어있다
  • self.session 은 이와 같은 내용을 포함하고 있다

Untitled

  • estimate_score 도 avail_actions 중에 choice 옵션이 없으므로 0을 기록해놓는다
  • item_rank -1 은 나중에 이유를 알아내야 할 듯?

info 업데이트에 들어가는 함수

WebEnv class에 있는 get_valid_actions 함수

def get_valid_actions(self):
        valid_info = self.env.get_available_actions()
        if valid_info['has_search_bar']:  # only search action available
            atts = self.session['goal']['attributes']
            query = self.session['goal']['query']
            inst = self.session['goal']['instruction_text']
            texts = self.get_search_texts(atts, query, inst)
            valids = [f'search[{text}]' for text in texts]
        else:
            valids = []  # and text.startswith('b')]
            for text in valid_info['clickables']:
                # ban buy when options not completed
                if text == 'buy now' and self.ban_buy:
                    cur_options = len(self.session['options'])
                    all_options = len(
                        self.env.server.product_item_dict[self.session["asin"]]['customization_options'])
                    if cur_options != all_options:
                        continue
                if text != 'search':
                    if self.click_item_name and text in self.asin2name:
                        text = 'item - ' + self.asin2name[text]
                    valids.append(f'click[{text}]')
                # do some action space reduction...
                if self.reduce_click and len(valids) > 20:
                    valids = valids[:6] + random.sample(valids[6:], 10)
        if len(valids) == 0:
            valids = ['finish']
        return valids
  • valid_info 는 WebAgentTextEnv에서 search option 만 가능하지 여부와 선택할 수 있다면 action버튼 종류들을 반환한다
  • search창에 있는 경우 self.get_search_texts 함수를 통해 search_il 모델이 만든 instruction 별 query를 이용해 검색한다
  • get_search_texts에서 사용되는 self.extra_search는 search_il 이 만든 search query를 의미한다
  • 이 함수에서 쓰인 query가 어떤 의미인지 이해하지 못했지만 search_il 이 있는이상 사용되지 않는다!

(일단 reset 단계에서는 모두 시작이 search 창에서 부터이다 )

다시 train_rl.py

for env in envs:
        ob, info = env.reset()
        if state0 is None:
            state0 = (ob, info)
        ipdb.set_trace()
        states.append(agent.build_state(ob, info))
        valids.append(info['valid'])
  • info[’valid’]에는 해당 session 에서 수행 가능한 action 들의 모음이 있다

agent.py

def build_state(self, ob, info):
        """ Returns a state representation built from various info sources. """
        obs_ids = self.encode(ob)
        goal_ids = self.encode(info['goal'])
        click = info['valid'][0].startswith('click[')
        estimate = info['estimate_score']
        obs_str = ob.replace('\n', '[SEP]')
        goal_str = info['goal']
        image_feat = info.get('image_feat')
        return State(obs_ids, goal_ids, click, estimate, obs_str, goal_str, image_feat)
  • Agent 의 tokenizer 는 모두 bert uncased 이다
  • obs 에 대한 tokenized 결과, instruction(goal)에 대한 tokenized 결과를 반환한다
  • click 엔 valid action 중 click 버튼이 존재하는지를 저장한다
  • estimate : self.score() 함수 결과를 반환 (reset 직후에는 0)
  • obs 에서 개행문자 \n를 [SEP] 토큰으로 치환
  • get으로 없을 경우에도 오류가 나지 않게 dict 에서 받을 수 있다

train_rl.py

for step in range(1, args.max_steps + 1):
        # get actions from policy
        action_strs, action_ids, values = agent.act(states, valids, method=args.exploration_method)

agent.py

def act(self, states, valid_acts, method, state_strs=None, eps=0.1):
        """ Returns a string action from poss_acts. """
        act_ids = self.encode_valids(valid_acts) # [#envs, # actions in info['valid']]
				
				# sample actions
        act_values, act_sizes, values = self.network.rl_forward(states, act_ids, value=True, act=True)
  • parmater로 넘겨지는 valid_act 의 크기는 [num envs, 각 가능한 action 개수]
  • self.network = BertModelForWebShop
  • act = True 는 gradient 끌 것임을 의미

bert.py

def rl_forward(self, state_batch, act_batch, value=False, q=False, act=False):
        act_values = []
        act_sizes = []
        values = []
        # state_batch, act_batch : [num envs, State named tuple]
        # act_batch : [num_envs, # valid_acts, each valid_act tokenized result]
        for state, valid_acts in zip(state_batch, act_batch):
            with torch.set_grad_enabled(not act):
                state_ids = torch.tensor([state.obs]).cuda() # state.obs : encoded obs by tokenizer
                state_mask = (state_ids > 0).int()  # to mask-out 0 values
                act_lens = [len(_) for _ in valid_acts]
                act_ids = [torch.tensor(_) for _ in valid_acts]
                # pad according to longest sequence
                # returns [# valid_acts, longest tokenized result]
                act_ids = nn.utils.rnn.pad_sequence(act_ids, batch_first=True).cuda()
                act_mask = (act_ids > 0).int()
                act_size = torch.tensor([len(valid_acts)]).cuda()
                if self.image_linear is not None:
                    images = [state.image_feat]
                    images = [torch.zeros(512) if _ is None else _ for _ in images] 
                    images = torch.stack(images).cuda()  # BS x 512
                else:
                    images = None
                # returns [num actions]
                logits = self.forward(state_ids, state_mask, act_ids, act_mask, act_size, images=images).logits[0]
                act_values.append(logits)
                act_sizes.append(len(valid_acts))
            if value:
                v = self.bert(state_ids, state_mask)[0]
                # when state_ids : [1,37] // v : [1, 37, 768]
                values.append(self.linear_3(v[0][0])) # append one scalar value
        act_values = torch.cat(act_values, dim=0) # append all actions in one vector
        act_values = torch.cat([F.log_softmax(_, dim=0) for _ in act_values.split(act_sizes)], dim=0)
        # Optionally, output state value prediction
        if value:
            values = torch.cat(values, dim=0)
            return act_values, act_sizes, values
        else:
            return act_values, act_sizes
  • 앞서 말한대로 gradient 는 끈 채로 진행 (action 만 선택하는 것이므로)
  • nn.utils.rnn.pad_sequence는 batch_first = True 일 경우 (batch, max_seq) 기준으로 pad결과를 반환해준다
  • action 예측 시 choice 모델과 weight를 공유한다
  • value는 Q(s(o,a)) 를 나타내는 것 같다 (scalar value)

agent.py

				act_values = act_values.split(act_sizes) # splited into [num envs, each act sizes]
        if method == 'softmax':
            act_probs = [F.softmax(vals, dim=0) for vals in act_values] # vals are logits
            # stochastic action sampling
            act_idxs = [torch.multinomial(probs, num_samples=1).item() for probs in act_probs]    
        elif method == 'greedy':
            act_idxs = [vals.argmax(dim=0).item() for vals in act_values]
        elif method == 'eps': # eps exploration
            act_idxs = [vals.argmax(dim=0).item() if random.random() > eps else random.randint(0, len(vals)-1) for vals in act_values]
        acts = [acts[idx] for acts, idx in zip(act_ids, act_idxs)] # [num_envs, each action dimension]

        # decode actions
        act_strs, act_ids = [], [] # string_version and tokenized number version
        for act, idx, valids in zip(acts, act_idxs, valid_acts):
            if torch.is_tensor(act):
                act = act.tolist()
            if 102 in act:
                act = act[:act.index(102) + 1]
            act_ids.append(act)  # [101, ..., 102]
            if idx is None:  # generative
                act_str = self.decode(act)
            else:  # int
                act_str = valids[idx]
            act_strs.append(act_str)
        return act_strs, act_ids, values
  • reset 이후 softmax로 설정되어있음
  • bert network 를 공유한 action probability 에 의거해 action을 stochastic 하게 고른다
  • acts 에는 각 env에 대해 선택된 action 의 숫자 tokenzied version 이 들어가 있음
  • act_strs, act_ids에 action 의 문자버전, 숫자 tokenzied 버전을 저장한다
  • 이들의 첫 dimension은 env 의 개수이다
  • 101과 102번 이 의미하는 token이 무엇인지 봐야할듯

train.py

for env, action_str, action_id, state in zip(envs, action_strs, action_ids, states):
            ob, reward, done, info = env.step(action_str)
            if state0 is None:  # first state
                state0 = (ob, info)
                r_att = r_opt = 0
                if 'verbose' in info: # record specific logs
                    r_att = info['verbose'].get('r_att', 0)
                    r_option = info['verbose'].get('r_option ', 0)
                    r_price = info['verbose'].get('r_price', 0)
                    r_type = info['verbose'].get('r_type', 0)
                    w_att = info['verbose'].get('w_att', 0)
                    w_option = info['verbose'].get('w_option', 0)
                    w_price = info['verbose'].get('w_price', 0)
                    reward_str = f'{reward/10:.2f} = ({r_att:.2f} * {w_att:.2f} + {r_option:.2f} * {w_option:.2f} + {r_price:.2f} * {w_price:.2f}) * {r_type:.2f}'
                else:
                    reward_str = str(reward)
                log('Reward{}: {}, Done {}\n'.format(step, reward_str, done))
            next_state = agent.build_state(ob, info)
            next_valid = info['valid']
            next_states, next_valids, rewards, dones = \
                next_states + [next_state], next_valids + [next_valid], rewards + [reward], dones + [done]
  • ‘verbose’ 옵션이 커져있으면 자세한 output log를 기록
  • agent.build_state 는 obs 와 info 정보를 갖고 State named tuple을 반환
  • env 한 step 결과를 next_state, next_valids … 배열에 저장한다
				# RL update
        transitions.append(TransitionPG(states, action_ids, rewards, values, agent.encode_valids(valids), dones))
        if len(transitions) >= args.bptt:
            _, _, last_values = agent.act(next_states, next_valids, method='softmax')
            stats = agent.update(transitions, last_values, step=step)
            for k, v in stats.items():
                tb.logkv_mean(k, v)
            del transitions[:]
            torch.cuda.empty_cache()
  • transitions 배열에 TransitionPG named tuple 정보를 저장한다 (agent.encode_valids는 list of list 를 encode 해주는 함수다)
  • rl update 를 시작할 최소 transition 개수인 args.bptt가 쌓이면 next_states 와 valid action 포함 집합을 보내 S(o,a) 에 대한 Q 값을 얻는다 (agent.act 함수로부터, act 함수는 위에서 분석 완료함)
  • transition 이 원하는 개수만큼 쌓이면 update 진행 후 다시 비우기

agent.py

def update(self, transitions, last_values, step=None, rewards_invdy=None):
        returns, advs = discount_reward(transitions, last_values, self.gamma) # [#timesteps, num_envs]
        stats_global = defaultdict(float)
  • TD 방식으로 계산한 return 과 advantage 값을 반환한다
def discount_reward(transitions, last_values, gamma):
    returns, advantages = [], []
    R = last_values.detach()  # always detached
    for t in reversed(range(len(transitions))):
        _, _, rewards, values, _, dones = transitions[t]
        R = torch.FloatTensor(rewards).to(device) + gamma * R * (1 - torch.FloatTensor(dones).to(device))
        baseline = values
        adv = R - baseline
        returns.append(R)
        advantages.append(adv)
    return returns[::-1], advantages[::-1]
  • .예측값은 rt + gamma * r{t+1} 이므로 받은 reward를 거꾸로 순회하면서 각 timestep 에 대한 값을 구하고 마지막엔 다시 reverse 해서 반환
  • advantage func도 예측값에서 baseline을 빼서 반환

agent.py이어서

for transition, adv in zip(transitions, advs):
            stats = {}
            # len of transition.state = num_envs
            log_valid, valid_sizes = self.network.rl_forward(transition.state, transition.valid_acts)
            act_values = log_valid.split(valid_sizes)
            # stores log_a value for each envs
            log_a = torch.stack([values[acts.index(act)]
                                        for values, acts, act in zip(act_values, transition.valid_acts, transition.act)])

            stats['loss_pg'] = - (log_a * adv.detach()).mean()
            stats['loss_td'] = adv.pow(2).mean() # L_value in paper
            stats['loss_il'] = - log_valid.mean() # not in paper & its weight set to zero!
            stats['loss_en'] = (log_valid * log_valid.exp()).mean()
            for k in stats:
                stats[k] = self.w[k] * stats[k] / len(transitions)
            stats['loss'] = sum(stats[k] for k in stats)
            stats['returns'] = torch.stack(returns).mean() / len(transitions)
            stats['advs'] = torch.stack(advs).mean() / len(transitions)
            stats['loss'].backward()
  • BertModelForWeshop (self.network)에서 나온 log_valid 값은 num_envs 값들을 하나로 이어붙인 tensor 이다 (내가 본 step에선 44) [log_softmax 취해져있음]
  • log_softmax 는 결국 논문에서 log_pi 를 의미하므로 log_a에 각 env별 log_pi 값들이 저장되어 있음
  • 이후 stats에 저장된 loss 들은 아래 부분을 구현한 것

Untitled

  • loss_il 은 실험적으로 빼고 논문에도 제외한 것 같음
  • loss_il 에 대한 wieght가 0이기 때문에 stats[’loss’]에 전체 loss 저장 후 backward 진행
  • 각 envs 들은 batch 로 생각하는 것?
  • loss 들을 transition 의 길이로 나눠주고 있음
						# Compute the gradient norm
            stats['gradnorm_unclipped'] = sum(p.grad.norm(2).item() for p in self.network.parameters() if p.grad is not None)
            nn.utils.clip_grad_norm_(self.network.parameters(), self.clip) # self.clip = max_norm // set to 10 
            stats['gradnorm_clipped'] = sum(p.grad.norm(2).item() for p in self.network.parameters() if p.grad is not None)
            for k, v in stats.items():
                stats_global[k] += v.item() if torch.is_tensor(v) else v
            del stats
        self.optimizer.step()
        self.optimizer.zero_grad()
        return stats_global
  • gradient clipped 결과와 아닌 결과를 비교
  • 결과적으론 clip 한다는 뜻?
  • transition 길이만큼 진행 후 optimizer 가동
  • 다른 부분에서 optimizer zero_grad 해주나?
				# handle done
        for i, env in enumerate(envs):
            if dones[i]:               
                ob, info = env.reset()
                if i == 0:
                    state0 = (ob, info)
                next_states[i] = agent.build_state(ob, info)
                next_valids[i] = info['valid']
        states, valids = next_states, next_valids
  • step 을돌면서 먼저 끝난 환경이 있으면 다시 reset 시켜 참여할 수 있도록
profile
0100101

0개의 댓글