RewardLLM 코드리뷰

이두현·2024년 3월 17일
0

세가지 task 가 있는데 negotiation 버전에 대해

먼저 SL 베이스라인을 돌려야 함

sl.py

model_ty = get_model_type(args.model_type)
  • get_model_type 은 models/__init__.py 에 있으며, model_type 이 cda_rnn_model 이기 때문에 CdaRnnModel 을 반환한다
  • model_ty는 models/cda_rnn_models.py의 CdaRnnModel 클래스

sl.py

corpus = model_ty.corpus_ty(
        args.data,
        freq_cutoff=args.unk_threshold,
        train=args.train_file,
        valid=args.val_file,
        test=args.test_file,
        verbose=True,
    )
  • SL 에서 corpus_ty 는 coarse_dialogue_acts/corpus.py 의 ActCorpus 클래스로 선언되어 있고, 이는 WordCorpus 클래스를 상속받고 있다. (WordCorpus 는 utils/data.py 에 선언)
  • freq_cutoff : word dictionary 에 들어가기 위한 등장 횟수 기준 (default 20). read_tag 가 `freq > freq_cutoff` 로 비교하므로 freq_cutoff 보다 많이 등장한 단어만 추가된다 (20 이면 21번 이상)
  • WordCorpus 의 init 이 자동 실행되는데 여기서 word_dict, item_dict, context_dict 이 만들어짐

utils.data.py

class WordCorpus(object):
    """An utility that stores the entire dataset.

    It has the train, valid and test datasets and corresponding dictionaries.
    """

    def __init__(self, path, freq_cutoff=2, train='train.txt',
                 valid='val.txt', test='test.txt', verbose=False):
        self.verbose = verbose
        # only add words from the train dataset
        self.word_dict, self.item_dict, self.context_dict = Dictionary.from_file(
            os.path.join(path, train),
            freq_cutoff=freq_cutoff)
@classmethod
def from_file(cls, file_name: str, freq_cutoff: int) -> Tuple['Dictionary', 'Dictionary', 'Dictionary']:
    """Build the word, item and context dictionaries from one dataset file.

    The file is scanned once per tag via ``cls.read_tag``:
      * ``dialogue`` -> word dictionary (default ``init_dict``; rare tokens
        are filtered with ``freq_cutoff``);
      * ``output``   -> item dictionary (``init_dict=False``, default cutoff);
      * ``input``    -> context dictionary (``init_dict=False``, default cutoff).

    Returns:
        ``(word_dict, item_dict, context_dict)``.
    """
    assert os.path.exists(file_name)
    # (tag, extra keyword arguments) in the order word, item, context.
    specs = (
        ('dialogue', {'freq_cutoff': freq_cutoff}),
        ('output', {'init_dict': False}),
        ('input', {'init_dict': False}),
    )
    word_dict, item_dict, context_dict = (
        cls.read_tag(file_name, tag, **kwargs) for tag, kwargs in specs
    )
    return word_dict, item_dict, context_dict
  • file_name은 data/train.txt 로 설정되어있음
@classmethod
def read_tag(cls, file_name: str, tag: str, freq_cutoff: int = -1, init_dict: bool = True) -> 'Dictionary':
    """Build a Dictionary from the tokens found between ``<tag>`` and ``</tag>``.

    Args:
        file_name: Location of the txt file with dialogues (one example per line).

        tag: The XML-like tag whose enclosed tokens should be collected
            (e.g. ``'dialogue'``, ``'input'``, ``'output'``).

        freq_cutoff: A token is added only if it occurs strictly MORE than
            ``freq_cutoff`` times in the file (the comparison is ``>``).
            The default ``-1`` therefore keeps every token seen at least once,
            while ``freq_cutoff=20`` keeps only tokens seen 21 times or more.

        init_dict: Passed to the Dictionary constructor as ``init``. If True,
            the dictionary is initialized with the default special tokens
            <eos>, <unk>, <selection>, and <pad>.

    Returns: A Dictionary containing the retained tokens, added in descending
        order of frequency; equally frequent tokens keep first-occurrence order.
    """
    from collections import Counter

    token_freqs = Counter()
    with open(file_name, 'r') as f:
        for line in f:
            token_freqs.update(get_tag(line.split(), tag))

    dictionary = cls(init=init_dict)
    # most_common() sorts by count (descending) with a stable sort, so ties keep
    # the order in which tokens were first seen in the file.
    for token, freq in token_freqs.most_common():
        if freq <= freq_cutoff:
            # Counts only decrease from here on, so no later token qualifies.
            break
        dictionary.add_word(token)
    return dictionary
<input> 1 4 4 1 1 2 </input> <dialogue> THEM: i would like 4 hats and you can have the rest . <eos> YOU: deal <eos> THEM: <selection> </dialogue> <output> item0=1 item1=0 item2=1 item0=0 item1=4 item2=0 </output> <partner_input> 1 0 4 2 1 2 </partner_input>
def get_tag(tokens: List[str], tag: str) -> List[str]:
    """Return the tokens enclosed by ``<tag>`` ... ``</tag>`` (delimiters excluded).

    Both delimiters must be present in ``tokens``; otherwise ``list.index``
    raises ``ValueError``.

    >>> get_tag('<input> 1 4 </input>'.split(), 'input')
    ['1', '4']
    """
    start = tokens.index(f'<{tag}>') + 1
    end = tokens.index(f'</{tag}>')
    return tokens[start:end]
  • train.txt 의 한 줄 예시는 위(get_tag 코드 바로 앞 줄)와 같이 되어 있음
  • get_tag 함수는 토큰 리스트에서 `<tag>` 와 `</tag>` 사이에 있는 토큰들을 (태그 자체는 제외하고) 리스트로 꺼내는 함수이다. 예를 들어 위 예시에서 tag='input' 이면 ['1', '4', '4', '1', '1', '2'] 를 반환한다.
  • 그러므로 read_tag 함수에서는 tag 사이에 주어진 tokens 를 대상으로 개수를 세고 등장빈도가 많은 순서대로 정렬한 token_freqs를 사용해 Dictionary 클래스를 인스턴스화 시킨 dictionary 변수에 add_word 함수를 사용해 dictionary를 만들어 반환한다.
  • 이를 dialogue, output, input 태그에 대해 각각 word, item, context 변수에 할당한다.

utils.data.py

WordCorpus 클래스 __init__ 의 나머지 부분 계속

# construct all 3 datasets
        self.train = self.tokenize(os.path.join(path, train)) if train else []
        self.valid = self.tokenize(os.path.join(path, valid)) if valid else []
        self.test = self.tokenize(os.path.join(path, test)) if test else []

        # find out the output length from the train dataset
        self.output_length = max([len(x[2]) for x in self.train])
  • Dictionary.from_file 이 위에 내용이고
  • dataset 이 tokenize 함수를 통해 만들어짐
  • output_length 는 self.train 의 각 예시 (input_idxs, word_idxs, item_idxs) 중 item_idxs 길이의 최대값으로 설정
def tokenize(self, file_name: str) -> List[Example]:
    """
    Tokenize and numericalize the dataset found at file_name.

    The lines are shuffled with ``random.shuffle`` before processing, so the
    returned examples are in random order.

    Args:
        file_name: The location of the dataset (one example per line).

    Returns: A list of ``(input_idxs, word_idxs, item_idxs)`` tuples:
        input_idxs: ``context_dict`` indices of the tokens inside
            ``<input>...</input>`` (item counts and per-item values).

        word_idxs: A list of token indexes for each of the words spoken in
            the dialogue, produced by ``get_word_indices``. This includes
            divider tokens like "YOU:", "THEM:", "<selection>", etc.

        item_idxs: ``item_dict`` indices of the tokens inside
            ``<output>...</output>``, i.e. the final allocation tokens
            (one entry per output token, not a single combined index).
    """
    lines = read_lines(file_name)
    random.shuffle(lines)

    unk = self.word_dict.get_idx('<unk>')
    dataset, total, unks = [], 0, 0
    for line in lines:
        tokens = line.split()
        input_tokens = get_tag(tokens, 'input')
        dialogue_tokens = get_tag(tokens, 'dialogue')

        input_idxs = self.context_dict.w2i(input_tokens)
        # NOTE(review): get_word_indices is defined elsewhere (subclasses such as
        # the coarse-dialogue-acts corpus may override it); input_tokens is simply
        # passed through — confirm whether the implementation in use reads it.
        word_idxs = self.get_word_indices(dialogue_tokens, input_tokens)
        item_idxs = self.item_dict.w2i(get_tag(tokens, 'output'))
        dataset.append((input_idxs, word_idxs, item_idxs))
        # statistics: total number of indices, and how many dialogue tokens are <unk>
        total += len(input_idxs) + len(word_idxs) + len(item_idxs)
        unks += np.count_nonzero([idx == unk for idx in word_idxs])

    if self.verbose:
        # An empty file gives total == 0; report a 0% ratio instead of dividing by zero.
        ratio = 100. * unks / total if total else 0.
        print('dataset %s, total %d, unks %s, ratio %0.2f%%, datapoints %d' % (
            file_name, total, unks, ratio, len(lines)))
    return dataset
  • file_name은 data/train.txt 이런식
  • read_lines 는 list of string 을 반환
  • word_dict 은 dialogue 내용들로 만든 dictionary 이고, unk 는 '<unk>' 토큰의 index 이다. 사전에 없어서 <unk> 로 바뀐 단어 수(unks)를 세는 통계용으로 쓰인다
  • for loop 안 input_tokens 와 dialogue_tokens 는 각각 input 태그와 dialogue 태그 사이의 토큰들을 리스트 형태로 담는다
  • context_dict 는 input 태그 사이의 정보로 만든 사전이기 때문에 이를 이용해 input_tokens 를 input_idxs 로 변환
  • get_word_indices는 마찬가지로 dialogue 토큰 사이 정보로 만든 사전으로 idx 변환하는 것인데 input_token 이 두번째 변수로 왜 들어가는지 모르겠음(함수에서도 사용되지 않음)

→ get_word_indices 는 coarse_dialogue_acts.corpus.py 에 저장된 함수를 상세히 봐야할 것이다

  • item_idxs 도 마찬가지이고, dataset 에는 input, dialogue, output 을 각각 index 로 변환한 (input_idxs, word_idxs, item_idxs) 튜플을 저장

다시 sl.py

model = model_ty(
        corpus.word_dict,
        corpus.item_dict,
        corpus.context_dict,
        corpus.output_length,
        args,
        device_id,
    )
  • corpus 는 ActCorpus 클래스의 인스턴스이고, word_dict, item_dict, context_dict 은 위에서 본 Dictionary.from_file 로 만들어진 것이다
  • recap) model_ty는 models/cda_rnn_models.py의 CdaRnnModel 클래스

sl.py

engine = Engine(model, args, device_id, verbose=True, corpus_type=args.corpus_type)
train_loss, valid_loss, select_loss = engine.train(corpus)

utils/engine.py

def train(self, corpus, N=None, callbacks: Iterable[Callable] = []):
        """Entry point."""
        N = len(corpus.word_dict) if N is None else N
        best_model, best_valid_select_loss = None, 1e100
        lr = self.args.lr
        last_decay_epoch = 0
        self.t = 0

        validdata = corpus.valid_dataset(self.args.bsz, device_id=self.device_id)
        for epoch in range(1, self.args.max_epoch + 1):
            traindata = corpus.train_dataset(self.args.bsz, device_id=self.device_id)
  • corpus 는 utils/data.py 의 WordCorpus 를 상속받는 ActCorpus 이며, train_dataset 함수 호출 시 부모 클래스 WordCorpus 의 해당 함수가 호출됨

utils.data.py

def train_dataset(self, bsz: int, shuffle=True, device_id=None):
    """Batch the training split.

    A shallow copy of ``self.train`` is handed to ``_split_into_batches`` so
    that the in-place shuffle/sort done there does not reorder ``self.train``.

    Returns:
        The ``(batches, stats)`` pair produced by ``_split_into_batches``.
    """
    train_copy = copy.copy(self.train)
    return self._split_into_batches(
        train_copy, bsz, shuffle=shuffle, device_id=device_id
    )

def _split_into_batches(self, dataset, bsz, shuffle=True, device_id=None):
        """Splits given dataset into batches."""
        if shuffle:
            random.shuffle(dataset)

        # sort and pad
        dataset.sort(key=lambda x: len(x[1]))
        pad = self.word_dict.get_idx('<pad>')

        batches = []
        stats = {
            'n': 0,
            'nonpadn': 0
        }

        for i in range(0, len(dataset), bsz):
            # groups contexes, words and items
            inputs, words, items = [], [], []
            for j in range(i, min(i + bsz, len(dataset))):
                inputs.append(dataset[j][0])
                words.append(dataset[j][1])
                items.append(dataset[j][2])

            # the longest dialogue in the batch
            max_len = len(words[-1])

            # pad all the dialogues to match the longest dialogue
            for j in range(len(words)):
                stats['n'] += max_len
                stats['nonpadn'] += len(words[j])
                words[j] += [pad] * (max_len - len(words[j]))

            # construct tensor for context
            ctx = torch.LongTensor(inputs).transpose(0, 1).contiguous()
            data = torch.LongTensor(words).transpose(0, 1).contiguous()
            # construct tensor for selection target
            sel_tgt = torch.LongTensor(items).transpose(0, 1).contiguous().view(-1)
            #if 16 in sel_tgt:
            #    print("FOUND THE FRIGGING THING")
            if device_id is not None:
                ctx = ctx.cuda(device_id)
                data = data.cuda(device_id)
                sel_tgt = sel_tgt.cuda(device_id)

            # construct tensor for input and target
            inpt = data.narrow(0, 0, data.size(0) - 1)
            tgt = data.narrow(0, 1, data.size(0) - 1).view(-1)

            batches.append((ctx, inpt, tgt, sel_tgt))

        if shuffle:
            random.shuffle(batches)

        return batches, stats
  • _split_into_batches 에 들어가는 dataset 은 (input_idxs, word_idxs, item_idxs) 튜플의 리스트로 구성되어 있다
  • 미니 배치별로 inputs, words, items 배열을 만든 후 words 의 경우 가장 긴 배열 길이에 맞춰 다른 words 원소들을 padding 시킨다.
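  • (inputs, words, items) 를 각각 (ctx, data, sel_tgt) tensor 로 만들고 (transpose 로 batch 가 두 번째 차원이 됨, sel_tgt 는 1차원으로 펼침), data 를 한 칸 어긋나게 잘라 inpt(마지막 토큰 제외)와 tgt(첫 토큰 제외, 다음 단어 예측 정답)를 만든다. 각 batch 는 (ctx, inpt, tgt, sel_tgt) 튜플이다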


reinforce.py

device, seed setting 후

reward_ty = get_reward_type(args.reward_type)
rewarder = reward_ty()
  • utility 가 default type이고 utils/reward.py 의 UtilityRewarder class를 반환한다
  • rewarder는 UtilityRewarder 클래스이다
alice_model = utils.load_model(args.alice_model, cuda=args.cuda)
alice_ty = get_agent_type(alice_model)
alice = alice_ty(alice_model, args, name="Alice", train=True, rewarder=rewarder)
  • args.alice_model : trained_models/sl/sl1.th (sl.py 에서 default model 인 rnn 으로 돌린 결과)
  • utils.utils의 get_agent_type은 모델 형에 따라 CdaAgent 클래스나 RlAgent 클래스를 반환한다 (여기서 Alice 는 RlAgent이다)
  • utils.agent 의 RlAgent가 반환되며 로그기록, 선택, 모델 업데이트 등의 클래스 함수들이 있다
bob_model = utils.load_model(args.bob_model, cuda=args.cuda)
bob_ty = get_agent_type(bob_model)
bob = bob_ty(bob_model, args, name="Bob", train=False, rewarder=rewarder)
  • bob의 weight는 freeze 시켜 훈련하지 않는다
args.novelty_model = utils.load_model(args.novelty_model, cuda=args.cuda)
  • dialogue 의 novelty 점수를 매기기 위한 model이라는데 이후 언급되지 않음..
dialog = Dialog([alice, bob], args)
logger = DialogLogger(log_file=args.log_file)
ctx_gen = ContextGenerator(args.context_file)
  • dialog, logger는 utils.dialog 파일에 존재, dialog는 대화를 하는 두 주체를 받아 dialog를 중계하는 class, DialogLogger는 이를 log_file에 dump 하는 클래스
  • ContextGenerator는 utils.utils 에 존재, context를 파일로부터 읽어와서 제공하는 클래스
assert alice_model.corpus_ty == bob_model.corpus_ty
corpus = alice_model.corpus_ty(
            args.data,
            freq_cutoff=args.unk_threshold,
            train=args.train_file,
            valid=args.val_file,
            test=args.test_file,
            verbose=True,
        )
  • alice_model은 DialogModel 클래스이고 corpus_ty 는 utils.data의 WordCorpus로 저장되어있음
  • train, val, test는 아래 깃헙 홈페이지에서 가져옴

GitHub - facebookresearch/end-to-end-negotiator: Deal or No Deal? End-to-End Learning for Negotiation Dialogues

engines = [
            Engine(alice_model, args, device_id, verbose=False),
            Engine(bob_model, args, device_id, verbose=False),
        ]
  • utils.engine 의 Engine class가 반환되며 training, evaluation을 담당하는 클래스이다
reinforce = Reinforce(dialog, ctx_gen, args, engines, corpus, logger, name)
reinforce.run()
  • reinforce.py 에 정의된 클래스로 에이전트간 dialogue를 진행시키며 alice를 update한다

여전히 reinforce.py

def run(self):
        """Entry point of the training."""
        # Assumes that both models are running on the same device_id
        assert self.engines[0].device_id == self.engines[1].device_id
        n = 0
        for e in range(self.args.nepoch):
            for ctxs in self.ctx_gen.iter(nepoch=1):
  • ctx_gen : ContextGenerator class (utils/utils.py)
  • ctxs는 context pair 를 포함 [[여섯 자리 context], [여섯자리 context]]

utils/utils.py

class ContextGenerator(object):
    """Dialogue context generator. Generates contexts from the file.

    The context file holds one context per line (whitespace-separated tokens);
    consecutive lines are grouped into pairs, one context per agent. A trailing
    unpaired line is ignored.
    """

    def __init__(self, context_file):
        self.ctxs = []
        with open(context_file, "r") as f:
            ctx_pair = []
            for line in f:
                ctx_pair.append(line.strip().split())
                if len(ctx_pair) == 2:
                    self.ctxs.append(ctx_pair)
                    ctx_pair = []

    def sample(self):
        """Return one context pair chosen uniformly at random."""
        return random.choice(self.ctxs)

    def iter(
        self, nepoch: int = None, neps: int = None, is_random=False
    ) -> Iterable[List[List[str]]]:
        """
        Iterate through all of the contexts specified in the context_file

        Args:
            nepoch: The number of times to iterate through every context in the file
            neps: The number of contexts to generate (cycling through the file as needed).
            is_random: If True, shuffle ``self.ctxs`` in place before each pass.

        Note: Specify either nepoch or neps, but not both

        Returns: A generator where each element contains a list of length 2,
            each specifying the utilities and counts for each agent in the game

        Raises (on the first ``next()``, since this is a generator):
            ValueError: if both nepoch and neps are given, or if neps > 0 but
                the file contained no context pairs.
            NotImplementedError: if neither nepoch nor neps is given.
        """
        if nepoch is not None and neps is not None:
            raise ValueError("Specify either number of epochs or episodes")

        if nepoch is not None:
            for _ in range(nepoch):
                if is_random:
                    random.shuffle(self.ctxs)
                for ctx in self.ctxs:
                    yield ctx
        elif neps is not None:
            if neps > 0 and not self.ctxs:
                # Without contexts the cycling loop below could never reach neps.
                raise ValueError("No contexts available to generate episodes from")
            n = 0
            while n < neps:
                if is_random:
                    random.shuffle(self.ctxs)
                for ctx in self.ctxs:
                    yield ctx
                    n += 1
                    if n == neps:
                        break
        else:
            raise NotImplementedError
  • context_file : data/self_play_lite.txt , 각 라인마다 6개의 숫자 (물건개수, utility)
  • self.ctxs 는 context pair 로 이뤄진 배열 list
  • neps 와 nepoch 중 하나만 지정해야 하고 (둘 다 주면 ValueError, 둘 다 없으면 NotImplementedError), iter 의 기본값은 둘 다 None 이지만 run() 에서 nepoch=1 로 호출하므로 self.ctxs 가 갖고 있는 context pair 를 순서대로 한 번씩 반환하는 generator 가 된다.

reinforce.py

for ctxs in self.ctx_gen.iter(nepoch=1):
                n += 1
                self.logger.dump("=" * 80)
                self.dialog.test_prompt = "=" * 80 + "\n"
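  • self.logger 와 self.dialog 는 앞서 reinforce.py 에서 만든 DialogLogger, Dialog 객체 (Reinforce 생성 시 전달됨)
  • episode 를 시작할 때마다 logger 에 '=' 80개로 된 구분선을 기록(dump)
  • self.dialog.test_prompt 를 같은 구분선으로 초기화 (클래스 변수가 아니라 인스턴스 속성)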

reinforce.py

skip = self.dialog.run(
                    ctxs, self.logger, update=(True, False), forced=False, training=True
                )

utils.dialog.py

def run(self, ctxs, logger, update, forced=False, training=True):
        """
        Run one episode of conversation
        """
        assert len(self.agents) == len(ctxs)
        self.ctxs = ctxs
        assert len(ctxs[0]) == len(ctxs[1]) == 6
        self.feed_context(ctxs, logger, forced=forced)
        writer, reader, first_agent_index = self.choose_starting_order()
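  • ctxs 는 여섯 개 숫자로 이뤄진 context pair (agent 수와 ctxs 개수가 같은지, 각 context 가 6개 값인지 assert 로 확인)
  • feed_context 로 각 agent 에 context 를 넣은 뒤, choose_starting_order 로 먼저 말할 agent(writer)와 상대(reader), 첫 agent 의 index 를 정해 반환 (무작위로 정하는 것으로 보임)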
def feed_context(self, ctxs, logger, forced=False):
    """
    Give each agent its context and record it in the log and the test prompt.

    For every (agent, ctx) pair the agent is fed its context, the logger dumps
    it, and one summary line per agent is appended to ``self.test_prompt``
    (item names come from ``logger.CODE2ITEM``). A dashed divider closes both
    the log and the prompt.
    """
    divider = "-" * 80
    for agent, ctx in zip(self.agents, ctxs):
        agent.feed_context(ctx)
        logger.dump_ctx(agent.name, ctx, forced=forced)

        # ctx alternates count/value for each of the three items.
        descriptions = []
        for item_no in range(3):
            item_name = logger.CODE2ITEM[item_no][1]
            count, value = ctx[2 * item_no], ctx[2 * item_no + 1]
            descriptions.append("%s=(count:%s value:%s)" % (item_name, count, value))
        summary = " ".join(descriptions)

        # "Bob" gets extra spaces so its colon lines up with "Alice".
        if agent.name != "Bob":
            self.test_prompt += f"{agent.name} : {summary}\n"
        else:
            self.test_prompt += f"{agent.name}   : {summary}\n"
    logger.dump(divider, forced=forced)
    self.test_prompt += divider + "\n"
  • agent 는 utils/agent 의 RlAgent 클래스를 의미

→ 이 부분을 다시 돌려 보니 실제로는 coarse_dialogue_acts.agents.CdaAgent 임을 확인

utils/agent.py

class RlAgent(LstmAgent):
		def feed_context(self, ctx):
        super(RlAgent, self).feed_context(ctx)
class LstmAgent(Agent):
    """An agent that uses DialogModel as an AI."""

		def feed_context(self, context):
        # the hidden state of all the pronounced words
        self.lang_hs = []
        # all the pronounced words
        self.words = []
        self.context = context
        # encoded context
        self.ctx = self._encode(context, self.model.context_dict)
        # hidded state of context
        self.ctx_h = self.model.forward_context(Variable(self.ctx).to(device))
        # current hidden state of the language rnn
        self.lang_h = self.model.zero_hid(1)
def _encode(self, inpt, dictionary):
        """A helper function that encodes the passed in words using the dictionary.

        inpt: is a list of strings.
        dictionary: prebuild mapping, see Dictionary in data.py
        """
        encoded = torch.LongTensor(dictionary.w2i(inpt)).unsqueeze(1)
        if self.model.device_id is not None:
            encoded = encoded.cuda(self.model.device_id)
        return encoded
  • self.model 은 DialogModel 이고 SL 단계에서 훈련된 것을 initial state로 사용하기 때문에 이미 train.txt 에 대해 dictionary 가 형성되어 있다. (자세한 방식은 SL 부분 참고)
  • _encode 함수를 통해 context 를 index 로 변환 후 dimension 을 추가해 self.ctx로 반환한다

models/dialog_model


class DialogModel(modules.CudaModule):
    corpus_ty = WordCorpus

    def __init__(self, word_dict, item_dict, context_dict, output_length, args, device_id):
        super(DialogModel, self).__init__(device_id)

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.args = args

        # embedding for words
        self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word).to(device)

        # context encoder
        ctx_encoder_ty = modules.RnnContextEncoder if args.rnn_ctx_encoder \
            else modules.MlpContextEncoder
        self.ctx_encoder = ctx_encoder_ty(len(self.context_dict), domain.input_length(),
                                          args.nembed_ctx, args.nhid_ctx, args.init_range, device_id).to(device)

		def forward_context(self, ctx):
        """Run context encoder."""
        # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #ctx = ctx.to(device)
        return self.ctx_encoder(ctx.to(device))
  • args.rnn_ctx_encoder 가 디폴트 false 이므로 ctx_encoder_ty 는 MlpContextEncoder (models/modules.py 에 정의)
  • forward_context 는 아래 정의된 모델에 context를 넘겨 hidden state를 주는 함수
class MlpContextEncoder(CudaModule):
    """Simple encoder for the dialogue context. Encodes counts and values via MLP.

    Counts and values are embedded with two separate tables, combined by an
    element-wise product, flattened per example and passed through
    ``Tanh -> Linear`` to produce the context hidden state.
    """

    def __init__(
        self, n: int, k: int, nembed: int, nhid: int, init_range: int, device_id: int
    ):
        """
        Args:
            n: The number of possible token values for the context
                (size of both the count and the value embedding tables).
            k: The number of count/value pairs per context. forward() feeds
                (ctx.size(0) // 2) * nembed features into Linear(k * nembed, ...),
                so k must equal half the context length (e.g. 3 for a 6-token context).
            nembed: The size of the embedding layer
            nhid: The size of the hidden layer
            init_range: Bound for the uniform initialization of both embedding
                tables; also passed to init_cont for the MLP. (Annotated int,
                but used as a uniform-range bound — NOTE(review): likely a float.)
            device_id: Forwarded to CudaModule.
        """
        super(MlpContextEncoder, self).__init__(device_id)

        # create separate embedding for counts and values
        # (`device` is a module-level global defined outside this excerpt)
        self.cnt_enc = nn.Embedding(n, nembed).to(device)
        self.val_enc = nn.Embedding(n, nembed).to(device)

        # NOTE: Tanh is applied BEFORE the Linear layer
        self.encoder = nn.Sequential(nn.Tanh(), nn.Linear(k * nembed, nhid)).to(device)

        self.cnt_enc.weight.data.uniform_(-init_range, init_range)
        self.val_enc.weight.data.uniform_(-init_range, init_range)
        init_cont(self.encoder, init_range)

    def forward(self, ctx):
        """Encode a batch of contexts.

        Assumes ctx is (context_len, batch) with the count of each item at even
        rows and its value at odd rows (consistent with the agent's _encode,
        which builds a (len, 1) column) — TODO confirm for other callers.

        Returns a tensor of shape (1, batch, nhid).
        """
        idx = np.arange(ctx.size(0) // 2)
        # extract counts and values: even rows are counts, odd rows are values
        cnt_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 0)))
        val_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 1)))
        cnt_idx = cnt_idx.to(device)
        val_idx = val_idx.to(device)

        cnt = ctx.index_select(0, cnt_idx)
        val = ctx.index_select(0, val_idx)

        # embed counts and values
        cnt_emb = self.cnt_enc(cnt.to(device))
        val_emb = self.val_enc(val.to(device))

        # element wise multiplication to get a hidden state
        h = torch.mul(cnt_emb, val_emb)
        # run the hidden state through the MLP: (pairs, batch, nembed) ->
        # (batch, pairs * nembed) so each example is one flat feature vector
        h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
        ctx_h = self.encoder(h.to(device)).unsqueeze(0).to(device)
        return ctx_h  # (1,1,64) for a single context with nhid=64; in general (1, batch, nhid)
  • count 용 embedding table과 value(물건 가치) 용 embedding table 이 따로있어서 이를 통과한 뒤 element-wise 곱 후 그 값을 mlp 에 통과시키는 구조

RlAgent(utils.agent.py) 의 feed_context 이어서

def feed_context(self, ctx):
    """Feed the game context and build this game's vocabulary mask.

    On top of the parent's set-up this:
      * resets ``self.logprobs``, which collects the log-probability of every
        generated word for the later policy-gradient estimate;
      * builds ``self.context_mask``, an array with one entry per vocabulary
        index. Entries whose word contains 'item' and requests more of some item
        than this context has (e.g. "propose: item0=0 item1=0 item2=3") get
        -999, all others 0. The mask is added to the word scores when writing,
        so impossible proposals are effectively never chosen.
    """
    super(RlAgent, self).feed_context(ctx)
    # save all the log probs for each generated word,
    # so we can use it later to estimate policy gradient.
    self.logprobs = []

    # ctx alternates count/value per item, so counts sit at the even positions.
    counts = [int(x) for x in ctx[::2]]
    vocab = self.model.word_dict.word2idx
    self.context_mask = np.zeros(len(vocab))
    # Use the index stored in word2idx (not the iteration position) so the mask
    # entry always lines up with the word's real vocabulary index.
    for word, idx in vocab.items():
        if 'item' not in word:
            continue
        # "propose: item0=0 item1=0 item2=3" -> [0, 0, 3]; parsing after '='
        # also handles multi-digit counts.
        requested = [int(part.rpartition('=')[2]) for part in word.split(" ")[1:]]
        assert len(counts) == len(requested)
        if any(want > have for have, want in zip(counts, requested)):
            self.context_mask[idx] = -999.
  • self.model 은 DialogModel(models.dialog_model) 이고 self.model.word_dict 은 utils.data 의 Dictionary class 이다. 여기서 word2idx 는 말 그대로 단어 → index 를 저장하고 있는 OrderedDict
  • (w, word) 는 (순서, 단어)
  • 현재 돌리는 상황에서 ‘item’ 이 포함된 경우는 items 밖에 없음

이 부분 잘 몰라서 나중에 돌려봐야 할듯

utils.dialog.py 이어서

# def run() 계속
# initialize dialogue-dependent variables
        self.conv, self.agent_order = [], []
        self.metrics.reset()
        self.num_sentences = 0  # reset num_sentences
        skip = 0
  • self.metrics 는 utils/metric.py 에 있는 MetricsContainer 클래스이고 self.metrics 라는 orderedDict 에 정보들을 저장한다.
  • reset 함수를 통해 dictionary를 초기화한다

utils.dialog.py

while True:
            # produce an utterance
            out = self.write(writer, logger, forced=forced)
def write(self, writer, logger, forced=False):
        """
        Produces an utterance and saves necessary meta information
        """
        # produce an utterance
        out = writer.write()
  • writer, reader 는 self.agents 의 원소(RlAgent 를 상속한 CdaAgent)를 하나씩 배정한 결과이고, logger 는 DialogLogger

coarse_dialogue_acts.agent.py → utils.agent.py(RlAgent 가 부모)

def write(self, max_words=100) -> List[str]:
        logprobs, outs, self.lang_h, lang_hs, _, _ = self.model.write(self.lang_h, self.ctx_h,
                                                                      max_words, self.args.temperature,
                                                                      self.context_mask)
  • RlAgent 의 self.model 은 cda_rnn_model

models.cda_rnn_model.py

def write(self, lang_h: torch.Tensor, ctx_h: torch.Tensor, max_words: int, temperature: float, context_mask: List[float],
          stop_tokens: List[str] = STOP_TOKENS, resume: bool = False) -> Tuple[
        List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate exactly one token per turn by delegating to the parent ``write``.

    ``max_words`` is kept for interface compatibility but ignored: the parent
    is always asked for a single word. In this model one vocabulary entry
    appears to be a whole coarse dialogue act
    (e.g. "propose: item0=0 item1=0 item2=3").

    Returns:
        The parent's six-tuple ``(logprobs, outs, lang_h, lang_hs, prob, out)``.
    """
    max_words = 1
    return super().write(lang_h, ctx_h, max_words, temperature, context_mask, stop_tokens, resume)

models.dialog_model.py (DialogModel이 부모모델)

# Before diving in: modules from DialogModel's constructor that write() uses.
# Word embedding table; write() also reuses its weight as the output projection.
self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word).to(device)
# GRU cell that generates the utterance one token at a time; its input is a word
# embedding concatenated with the context encoding (hence the input_size sum).
self.writer = nn.GRUCell(
            input_size=args.nhid_ctx + args.nembed_word,
            hidden_size=args.nhid_lang,
            bias=True).to(device)
# Projects the language hidden state back to the word-embedding size.
self.decoder = nn.Linear(args.nhid_lang, args.nembed_word).to(device)
def write(self, lang_h: torch.Tensor, ctx_h: torch.Tensor, max_words: int, temperature: float,
          context_mask: List[float],
          stop_tokens: List[str] = STOP_TOKENS, resume: bool = False) -> Tuple[
        List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate up to ``max_words`` tokens with the writer GRU cell (greedy decoding).

    Args:
        lang_h: hidden state of the language GRU cell; dim 1 is squeezed on
            entry and restored on exit.
        ctx_h: encoded context; dim 1 is squeezed on entry.
        max_words: maximum number of tokens to generate.
        temperature: scores are divided by this before the softmax.
        context_mask: additive per-vocabulary mask (e.g. -999 for impossible
            proposals); applied only when ``resume`` is False.
        stop_tokens: generation stops right after one of these words is produced.
        resume: if False, 'YOU:' is fed first and the special-token and context
            masks are applied; if True, generation continues from ``lang_h``
            without a prefix token and without masks.

    Returns:
        ``(logprobs, outs, lang_h, lang_hs, prob, out)``: per-step
        log-probabilities of the chosen words, the chosen word indices
        (concatenated), the final hidden state (dim 1 restored), all hidden
        states produced (concatenated), and the softmax distribution and
        decoder output of the last step.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    outs, logprobs, lang_hs = [], [], []
    # remove batch dimension from the language and context hidden states
    lang_h = lang_h.squeeze(1)
    ctx_h = ctx_h.squeeze(1)
    inpt = None if resume else self.word2var('YOU:')

    if not resume:
        # Build both additive masks once. Special tokens must not be generated in
        # a normal turn, and the context mask rules out impossible entries.
        # (Converting inside the loop would re-wrap an already-converted tensor.)
        special_mask = self.special_token_mask.to(device)
        ctx_mask = self.to_device(torch.FloatTensor(context_mask)).to(device)

    # generate words until max_words have been generated or a stop token appears
    for _ in range(max_words):
        # `inpt` is a 1-element index tensor; test for None explicitly, because its
        # truthiness is the index value and word index 0 would skip the update.
        if inpt is not None:
            # add the context to the word embedding
            inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
            # update RNN state with last word
            lang_h = self.writer(inpt_emb, lang_h)
            lang_hs.append(lang_h)

        # decode words using the inverse of the word embedding matrix
        out = self.decoder(lang_h)
        scores = F.linear(out, self.word_encoder.weight).div(temperature)
        # subtract constant to avoid overflows in exponentiation
        scores = scores.add(-scores.max().item()).squeeze(0)

        if not resume:
            scores = scores.add(special_mask).add(ctx_mask)

        prob = F.softmax(scores, dim=0)
        logprob = F.log_softmax(scores, dim=0)

        # greedy decoding: always take the most likely word
        word = torch.argmax(prob).unsqueeze(0).detach()
        logprobs.append(logprob.gather(0, word))
        outs.append(word.view(word.size()[0], 1))

        inpt = word

        # stop after a stop token (e.g. <eos>)
        if self.word_dict.get_word(word.item()) in stop_tokens:
            break

    # feed the last generated word back so the hidden state accounts for it
    inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
    lang_h = self.writer(inpt_emb, lang_h)
    lang_hs.append(lang_h)

    # add batch dimension back
    lang_h = lang_h.unsqueeze(1)

    return logprobs, torch.cat(outs), lang_h, torch.cat(lang_hs), prob, out
  • lang_h 는 word generator(GRU cell) 의 hidden state, ctx_h 는 현재 context embedding, resume 은 디폴트로 false 선언
  • max_words 는 1 로 고정되어 있음 (CdaRnnModel.write 에서 덮어씀)
  • resume 이 false 라서 inpt 는 self.word2var('YOU:') 이고, inpt_emb 에는 'YOU:' 의 word embedding 과 context embedding 을 이어 붙인 값이 들어간다
  • GRU cell 인 self.writer에 input_emb과 hidden state(lang_h) 를 넣어 새로운 lang_h를 받는다.
  • lang_hs 는 lang_h 를 모으는 배열로 생각됨
  • self.word_encoder의 weight를 바탕으로 self.decoder 나온 out 변수에 대한 score 뽑아냄
  • 각 원소에 대해 score 최대값을 빼주어 발산을 막음
  • resume 은 default False 이므로 masking 과정 진행
  • self.special_token_mask 는 len(self.word_dict) 길이의 tensor 로, special token 에 해당하는 위치에 -999 가 할당되어 있다 ('YOU:', 'THEM:' 외에 꺾쇠괄호 토큰 두 개가 원문에서 지워졌는데, 원본 코드 기준 '<unk>', '<pad>' 로 보임 — 확인 필요)
  • 위에서 나온 score 에 mask 를 더함
  • 그다음 context mask 를 더하는데 utils.agent.py 의 feed_context 함수에서 결정됨
  • score 에 softmax를 취해 prob 과 logprob를 계산하고 logprobs 에는 최대값인 logprob 를 append
  • outs 에는 대응하는 word(index 형태)를 저장하고 inpt 를 word 로 다시 설정
  • 마지막 나온 단어와 다시 inpt_emb 를 만들고 GRU cell 통과시켜 lang_h hidden state 업데이트 해 append

utils.agent.py def write 이어서

def write(self, max_words=100) -> List[str]:
    """Generate this agent's next utterance and record it.

    Runs ``self.model.write`` from the current language/context hidden states,
    keeps the new language hidden state, appends the log-probabilities and the
    hidden states, records the 'YOU:' prefix plus the generated word indices in
    ``self.words``, and returns the utterance decoded with ``self._decode``.
    """
    logprobs, outs, new_lang_h, lang_hs, _prob, _out = self.model.write(
        self.lang_h, self.ctx_h, max_words, self.args.temperature, self.context_mask
    )
    self.lang_h = new_lang_h
    # keep per-word log-probabilities for the policy-gradient update
    self.logprobs.extend(logprobs)
    self.lang_hs.append(lang_hs)
    # 'YOU:' is fed to the model before generation but is not part of outs,
    # so record it explicitly, followed by the generated words
    you_token = self.model.word2var('YOU:').unsqueeze(0)
    self.words.extend([you_token, outs])
    # every recorded word must have a matching hidden state
    n_words = torch.cat(self.words).size()[0]
    n_states = torch.cat(self.lang_hs).size()[0]
    assert n_words == n_states
    return self._decode(outs, self.model.word_dict)
  • 위에서 계속 다룬 self.model.write 결과를 각자 배열 저장소에 append
  • self.words 에는 ‘YOU:’ + 생성된 단어들(index꼴)이 추가됨
  • self._decode 는 index를 영어 word로 바꿔주는 함수

다시 utils.dialog.py

def write(self, writer, logger, forced=False):
        """
        Produces an utterance and saves necessary meta information
        """
        # produce an utterance
        out = writer.write()
        if not writer.human:
            # logger.dump_sent(writer.name, out, forced=forced)
            out_with_item_names = out[0].replace("item0", "book")
            out_with_item_names = out_with_item_names.replace("item1", "hat")
            out_with_item_names = out_with_item_names.replace("item2", "ball")
  • out 으로 나온 예시는 ['propose: item0=0 item1=0 item2=3'] 이런식으로
  • writer 는 coarse_dialog_acts.agent.CdaAgent
  • writer.human 은 False 로 되어 있음
  • out_with_item_names는 propose: book=0 hat=0 ball=3 이렇게 변환됨

계속

if not self._is_selection(out):
                logger.dump_sent(writer.name, out, forced=forced)
                if writer.name == "Bob":
                    self.test_prompt += f"{writer.name}   : {out_with_item_names}\n"
                else:
                    self.test_prompt += f"{writer.name} : {out_with_item_names}\n"
  • _is_selection 함수는 out 결과가 <selection> 인지 판별해주는 함수
  • test_prompt 는 이런식으로 됨
================================================================================
Alice : book=(count:1 value:0) hat=(count:1 value:1) ball=(count:3 value:3)
Bob   : book=(count:1 value:1) hat=(count:1 value:0) ball=(count:3 value:3)
--------------------------------------------------------------------------------
Bob   : propose: book=0 hat=0 ball=3
  • 나머지는 log 및 metric 저장 용
        self.conv.append(out)
        if len(self.conv) == 1 and ("propose" not in out[0] and "insist" not in out[0]):
            print("started conv with non-proposal : ", out)
            raise ValueError
        self.agent_order.append(writer.name)
        return out
  • 처음 나오는 out 일 경우 propose 나 insist 가 없으면 valueError 를 일으킨다
  • self.agent_order 에 현재 writer의 이름을 기록

utils.dialog.py 의 while 문 계속

while True:
            # produce an utterance
            out = self.write(writer, logger, forced=forced)

            # make other agent read
            self.read(reader, out)
  • reader 도 writer와 마찬가지로 coarse_dialogue_acts.agent.CdaAgent

coarse_dialogue_acts.agents.py CdaAgent → utils.agent RlAgent

RlAgent 는 LstmAgent 상속받고 있고 def read 함수가 없으므로

utils.agent LstmAgent

def read(self, inpt, prefix_token: str = 'THEM:'):
        inpt = self._encode(inpt, self.model.word_dict)
        lang_hs, self.lang_h = self.model.read(Variable(inpt).to(device), self.lang_h, self.ctx_h,
                                               prefix_token=prefix_token)
  • 여기서 inpt 은 writer 의 out

models.dialog_model.py DialogModel

def read(self, inpt: Tensor, lang_h: Tensor, ctx_h: Tensor, prefix_token: str = 'THEM:') -> Tuple[Tensor, Tensor]:
        """Reads a given utterance.

        Optionally prepends ``prefix_token`` to ``inpt``, embeds every token
        with ``self.word_encoder``, concatenates the context encoding ``ctx_h``
        onto each token embedding (along dim 2), and runs ``self.reader`` over
        the resulting sequence starting from ``lang_h``.

        Args:
            inpt: token indices of the utterance to read (the partner's output).
            lang_h: hidden state to start reading from.
            ctx_h: context encoding; it is expanded along dim 0 to the sequence
                length, so presumably it has size 1 there -- TODO confirm.
            prefix_token: token prepended before reading (default ``'THEM:'``);
                ``None`` skips the prefix.

        Returns:
            ``(out, lang_h)`` exactly as returned by ``self.reader`` --
            presumably per-step outputs and the final hidden state if
            ``self.reader`` is an ``nn.GRU``; verify against the model setup.
        """
        # inpt contains the pronounced utterance
        # add a "THEM:" token (or whatever prefix_token is) to the start of the message
        if prefix_token is not None:
            prefix = self.word2var(prefix_token).unsqueeze(0).to(device)
            inpt = inpt.to(device)
            inpt = torch.cat([prefix, inpt])
        # NOTE(review): when prefix_token is None, inpt is NOT moved to `device`
        # here, so callers must pass it already on the right device
        # (LstmAgent.read does call .to(device) before calling this).

        # embed words
        inpt_emb = self.word_encoder(inpt)

        # append the context embedding to every input word embedding:
        # repeat ctx_h once per token along dim 0, then concatenate on dim 2
        ctx_h_rep = ctx_h.expand(inpt_emb.size(0), ctx_h.size(1), ctx_h.size(2))
        inpt_emb = torch.cat([inpt_emb, ctx_h_rep], 2)

        # finally read in the words
        out, lang_h = self.reader(inpt_emb, lang_h)

        return out, lang_h
  • inpt에 THEM: 이라는 토큰을 추가
  • 각각의 inpt_emb 단어에 ctx_h 가 추가될 수 있도록 일련의 과정을 거침
  • self.reader 는 nn.GRU 모델 (self.writer가 nn.GRUCell 인 것과 차이)
  • 아마도 self.writer는 한 단어씩 생성하면서 방금 샘플링한 단어를 다음 step 입력으로 넣어야 하므로 GRUCell 로 한 step씩 돌리고, self.reader는 이미 주어진 문장 전체를 한 번에 읽으면 되므로 nn.GRU 를 쓰는 것으로 추정됨

utils.agent.py

def read(self, inpt, prefix_token: str = 'THEM:'):
        """Feed the partner's utterance into this agent's language state.

        Encodes ``inpt`` with the model's word dictionary, runs
        ``self.model.read`` from the current ``self.lang_h`` / ``self.ctx_h``,
        stores the returned state in ``self.lang_h``, and appends the reader
        outputs to ``self.lang_hs`` and the prefix token plus the encoded
        utterance to ``self.words``.

        Args:
            inpt: the utterance produced by the other agent (its ``out``).
            prefix_token: speaker token prepended to the utterance
                (default ``'THEM:'``).
        """
        inpt = self._encode(inpt, self.model.word_dict)
        lang_hs, self.lang_h = self.model.read(Variable(inpt).to(device), self.lang_h, self.ctx_h,
                                               prefix_token=prefix_token)
        # append new hidden states to the current list of the hidden states
        # (the assert below requires one row per token read, prefix included)
        self.lang_hs.append(lang_hs.squeeze(1))
        # first add the special prefix token (default 'THEM:')
        # self.words.append(self.model.word2var('THEM:').unsqueeze(0))
        self.words.append(self.model.word2var(prefix_token).unsqueeze(0))
        # then read the utterance
        self.words.append(Variable(inpt).to(device))
        # invariant: every token in self.words has a matching row in self.lang_hs
        # NOTE(review): if prefix_token is None, a model read that skips the prefix
        # (as DialogModel.read does) returns one row fewer than the words appended
        # here, so this assert would fail -- confirm None is never passed.
        assert (torch.cat(self.words).size()[0] == torch.cat(self.lang_hs).size()[0])
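  • lang_hs 는 nn.GRU 가 반환하는 모든 time step 의 출력(output), self.lang_h 는 마지막 hidden state
  • 모든 time step 출력(lang_hs)을 squeeze 해서 self.lang_hs 에 추가 (토큰마다 hidden state 하나)
  • self.words에 prefix token 인 ‘THEM:’ 과 writer가 생성한 inpt 을 append (마지막 assert 로 self.words 와 self.lang_hs 의 길이가 같은지 확인)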

utils.dialog.py 의 while 문 계속

while True:
            # produce an utterance
            out = self.write(writer, logger, forced=forced)

            # make other agent read
            self.read(reader, out)
						
						if self.is_end(out, writer, reader):
                break

            # swap roles
            writer, reader = reader, writer

		choices = self.generate_choices(ctxs, self.agents, logger, forced=forced)
  • is_end 는 max_sentence를 넘었는지, `<selection>` 으로 끝났는지를 판단

utils.dialog.py

def generate_choices(self, ctxs, agents, logger, forced):
        """
        Generate final choices for each agent

        Model agents ("Alice", "Bob", "Expert") choose from the whole dialogue
        (``self.conv``), the speaker order (``self.agent_order``) and ``ctxs``;
        a "Human" agent is asked via ``choose()`` with no arguments. Any other
        agent name yields ``None``.

        Args:
            ctxs: game contexts, passed through to ``agent.choose``.
            agents: agents to query, in order.
            logger: not used in this method.
            forced: not used in this method.

        Returns:
            A list with one choice per agent, in the order of ``agents``.
        """
        choices = []
        # generate choices for each of the agents
        for agent in agents:
            choice = None
            if agent.name == "Alice" or agent.name == "Bob" or agent.name == "Expert":
                choice = agent.choose(self.conv, self.agent_order, ctxs)
            elif agent.name == "Human":
                choice = agent.choose()
            choices.append(choice)
        return choices
  • agent는 coarse_dialogue_acts.agent.CdaAgent

utils.agent.py

def choose(self, conv, agent_order, ctxs, handcode='partial') -> List[str]:
				if handcode == 'full':
            return self.choose_handcoded(conv, agent_order, ctxs)
        elif handcode == 'partial':
            # Sample a choice
            choice, logprob, _ = self._choose(sample=False)
  • choose 로 들어오는 utils.dialog.py 의 self.conv 에는 out 에서 나온 내용들이 저장되어있다.

→ [['propose: item0=0 item1=0 item2=3'], ['agree'], ['<selection>']] (self.conv 의 모습)

def _choose(self, lang_hs=None, words=None, sample=False):
        # get all the possible choices
        choices = self.domain.generate_choices(self.context)
  • self.domain 은 utils.domain.ObjectDivisionDomain

utils.domain.py

def generate_choices(self, inpt: List[str]) -> List[List[str]]:
        """
        Given the number of items in a game, outputs all possible divisions of items
        Args:
            inpt: A parsed game context from the FB data set

        Returns: A list of token sequences, each representing a possible selection.
            Each sequence lists the chosen counts ('item<i>=<c>') followed by the
            complementary counts ('item<i>=<n - c>'). The final entry is a
            '<no_agreement>' sequence of length selection_length().
        """
        # only the item counts are needed; the second value from parse_context is ignored
        cnts, _ = self.parse_context(inpt)

        # Recursively enumerate every c in 0..cnts[idx] for each item index.
        # NOTE(review): `choice=[]` is a mutable default shared across calls; it is
        # safe here only because every append below is paired with a pop.
        def gen(cnts, idx=0, choice=[]):
            if idx >= len(cnts):
                # one complete assignment: chosen counts, then the remaining counts
                left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(choice)]
                right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, choice))]
                return [left_choice + right_choice]
            choices = []
            for c in range(cnts[idx] + 1):
                choice.append(c)
                choices += gen(cnts, idx + 1, choice)
                choice.pop()  # backtrack so the shared list is restored
            return choices

        choices = gen(cnts)
        # extra option meaning that no deal was reached
        choices.append(['<no_agreement>'] * self.selection_length())
        #choices.append(['<disconnect>'] * self.selection_length())
        return choices
  • self.parse_context 로 item 개수만 추출 후 cnts 에 반환
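  • 중첩함수 gen은 각 item 의 개수 0~cnt 를 모두 조합해서 ['item0=0', 'item1=0', 'item2=0', 'item0=1', 'item1=1', 'item2=3'] 과 같은 배열들을 반환 (앞 3개는 선택한 개수, 뒤 3개는 남은 개수 = 상대방 몫; 위 예시는 book=1 hat=1 ball=3 인 경우)
  • 마지막에 selection_length() 길이만큼 '<no_agreement>' 를 반복한 원소를 하나 추가한 뒤 반환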

utils.agent.py _choose 함수 계속

def _choose(self, lang_hs=None, words=None, sample=False):
        # get all the possible choices
        choices = self.domain.generate_choices(self.context)
        # concatenate the list of the hidden states into one tensor
        # print('debug', self.name, self.lang_hs, self.words)
        lang_hs = lang_hs if lang_hs is not None else torch.cat(self.lang_hs)
        # concatenate all the words into one tensor
        words = words if words is not None else torch.cat(self.words)
        # logits for each of the item
        logits = self.model.generate_choice_logits(words, lang_hs, self.ctx_h)
  • self.lang_hs : write 시 GRUCell, read 시 nn.GRU 에서 나온 토큰별 hidden state 모음
  • self.words : conversation 하면서 나눈 대화 모음 (out 들의 집합)
  • self.model 은 CdaRnnModel 인데 DialogModel 의 generate_choice_logits 함수를 호출

models.dialog_model.py

def generate_choice_logits(self, inpt, lang_h, ctx_h):
        """Similar to forward_selection, but is used while selfplaying.
        Thus it is dealing with batches of size 1.

        Args:
            inpt: word indices of the dialogue so far (in self-play,
                presumably ``torch.cat(self.words)`` from the agent).
            lang_h: language-model hidden states for those words, one row per
                word (presumably ``torch.cat(self.lang_hs)``).
            ctx_h: context encoding of the game.

        Returns:
            A list with one logit tensor per decoder in ``self.sel_decoders``.
        """
        # run a birnn over the concatenation of the input embeddings and
        # language model hidden states
        inpt_emb = self.word_encoder(inpt)
        # give lang_h a size-1 dim at position 1 so it lines up with inpt_emb,
        # then concatenate on the feature dim (2)
        h = torch.cat([lang_h.unsqueeze(1), inpt_emb], 2)
        h = self.dropout(h)

        # runs selection rnn over the hidden state h
        # (copies=2 presumably creates one initial state per direction -- confirm in zero_hid)
        attn_h = self.zero_hid(h.size(1), self.args.nhid_attn, copies=2)
        h, _ = self.sel_rnn(h, attn_h)
        # drop the size-1 batch dim
        h = h.squeeze(1)

        # perform attention: one score per position, softmax over positions (dim 0),
        # then a probability-weighted sum of the rnn outputs
        logit = self.attn(h).squeeze(1)
        prob = F.softmax(logit, dim=0).unsqueeze(1).expand_as(h)
        attn = torch.sum(torch.mul(h, prob), 0, keepdim=True)

        # concatenate attention and context hidden and pass it to the selection encoder
        ctx_h = ctx_h.squeeze(1)
        h = torch.cat([attn, ctx_h], 1)
        h = self.sel_encoder.forward(h)

        # generate logits for each item separately
        logits = [decoder.forward(h).squeeze(0) for decoder in self.sel_decoders]
        return logits
  • self.word_encoder 는 [258, 256] 크기의 embedding table
  • h 는 input embedding 과 write, read 시 생성된 GRU hidden state를 concat
  • zero_hid 함수로 self.sel_rnn (nn.GRU) 의 hidden state 담을 variable 만듦
  • self.attn 은 FC layer 두개로 이뤄짐
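  • self.sel_encoder 는 1-layer FC 뒤에 tanh activation 을 붙인 구조
  • self.sel_decoders 의 각 decoder 는 output dimension 이 self.item_dict 길이인 1-layer FC
  • decoder 의 개수는 selection_length() 개 (_choose 에서 logits[i] 를 i < selection_length() 범위로 사용; 양쪽 몫을 모두 고르므로 item 개수 × 2 로 보임)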

utils.agent.py def _choose 계속

def _choose(self, lang_hs=None, words=None, sample=False):
				# construct probability distribution over only the valid choices
        choices_logits = []
        for i in range(self.domain.selection_length()):
            idxs = [self.model.item_dict.get_idx(c[i]) for c in choices]
            idxs = Variable(torch.from_numpy(np.array(idxs))).to(device)
            idxs = self.model.to_device(idxs)
            choices_logits.append(torch.gather(logits[i], 0, idxs).unsqueeze(1))

        choice_logit = torch.sum(torch.cat(choices_logits, 1), 1, keepdim=False)
        # subtract the max to softmax more stable
        choice_logit = choice_logit.sub(choice_logit.max().item())
        prob = F.softmax(choice_logit, dim=0)
        if sample:
            # sample a choice
            idx = prob.multinomial(1).detach()
            logprob = F.log_softmax(choice_logit).gather(0, idx)
        else:
            # take the most probably choice
            _, idx = prob.max(0, keepdim=True)
            logprob = None

        p_agree = prob[idx.item()]

        # Pick only your choice
        return choices[idx.item()][:self.domain.selection_length()], logprob, p_agree.item()
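  • 아… 너무 많다 패스 (나중에 코드 짤 때 다시보기)
  • 요약: 선택지마다 각 자리(i)의 item_dict 인덱스로 logits[i] 를 gather 해서 더한 값이 그 선택지의 점수 → max 를 빼고 softmax → sample=True 면 multinomial 샘플링(logprob 계산), 아니면 argmax (logprob=None) → 고른 선택지, logprob, 그 선택지의 확률(p_agree)을 반환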

다시 utils.agent.py def choose 계속

# def choose 계속
						if conv[-1][0] != '<selection>':  # reached max_sentences
                return choice

            # Check case 1 (all agrees after the final proposal; choose this final proposal )
            assert conv[-1][0] == '<selection>'
            partial_convo = []
            for i, sent in reversed(list(enumerate(conv[0:-1]))):
                if 'propose' in sent[0] or 'insist' in sent[
                    0]:  # if agents disagreed after final proposal, then no agreement
                    if set(partial_convo) == {
                        'agree'}:  # if all the words after the final proposal were 'agree', then hard code selection to be final proposal
                        choice = self._get_choice(i, sent, agent_order, ctxs)
                        return choice
                    else:  # otherwise, just break
                        break
                else:
                    partial_convo.append(sent[0])
            # assert i != 0  # assert that propose or insist was in conversation

					# Check case 2 (agent propose -> other ends. Then must choose proposal agent proposed)
            end_agent = agent_order[-1]  # get person who ended
            # if ('propose' in conv[-2][0] or 'insist' in conv[-2][0]) and end_agent in ['Human', 'Bob']:
            if len(conv) > 1 and ('propose' in conv[-2][0] or 'insist' in conv[-2][0]) and end_agent != self.name:
                assert agent_order[-2] == self.name
                choice = self._get_choice(-2, conv[-2], agent_order, ctxs)
                return choice

            # Check case 3 where choice is no agreement
            if '<no_agreement>' in choice:
                return choice

            # Finally, make sure that agent selects previous selection
            for i, sent in reversed(list(enumerate(conv[0:-1]))):
                if agent_order[i] == self.name:  # find agent's last proposal
                    if 'propose' in sent[0] or 'insist' in sent[
                        0]:  # if agents disagreed after final proposal, then no agreement
                        choice = self._get_choice(i, sent, agent_order, ctxs)
                        return choice
                    elif 'agree' in sent[0]:
                        for j, sent2 in reversed(list(enumerate(conv[0:i]))):
                            if 'propose' in sent2[0] or 'insist' in sent2[
                                0]:  # if agents disagreed after final proposal, then no agreement
                                choice = self._get_choice(j, sent2, agent_order, ctxs)
                                return choice

                #    if agent_order[i] == 'Alice': # find Alice's last proposal
                #        choice = self._get_choice(i, sent, agent_order, ctxs)
                #        return choice
            # print("conv: ", conv)
            # print("agent order: ", agent_order)
            return choice
            # raise ValueError
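  • 이전 대화(conv)를 마지막 `<selection>` 직전부터 역순으로 순회
  • 세 가지 케이스 + 추가 경우를 따짐: (0) 마지막 발화가 `<selection>` 이 아니면(max_sentences 도달) 모델 choice 그대로 반환 (1) 마지막 propose/insist 이후 발화가 모두 agree 면 그 제안을 선택 (2) 상대가 대화를 끝냈고 직전 발화가 내 propose/insist 면 그 제안을 선택 (3) 모델이 고른 choice 에 `<no_agreement>` 가 있으면 그대로 반환 (추가) 그 외에는 내가 마지막으로 한 propose/insist, 또는 내가 agree 했다면 그 이전의 마지막 propose/insist 를 선택
  • choice 결과는 ['item0=1', 'item1=1', 'item2=0', 'item0=0', 'item1=0', 'item2=3'] 이런식으로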

utils.dialog.py def generate_choices 계속

def generate_choices(self, ctxs, agents, logger, forced):
        """
        Generate final choices for each agent

        Model agents ("Alice", "Bob", "Expert") choose from the whole dialogue
        (``self.conv``), the speaker order (``self.agent_order``) and ``ctxs``;
        a "Human" agent is asked via ``choose()`` with no arguments. Any other
        agent name yields ``None``.

        Args:
            ctxs: game contexts, passed through to ``agent.choose``.
            agents: agents to query, in order.
            logger: not used in this method.
            forced: not used in this method.

        Returns:
            A list with one choice per agent, in the order of ``agents``.
        """
        choices = []
        # generate choices for each of the agents
        for agent in agents:
            choice = None
            if agent.name == "Alice" or agent.name == "Bob" or agent.name == "Expert":
                choice = agent.choose(self.conv, self.agent_order, ctxs)
            elif agent.name == "Human":
                choice = agent.choose()
            choices.append(choice)
        return choices
  • 두 agent 들에 대해 각각 choice 결과를 담은 choices 배열을 return

utils.dialog.py def run 함수

# evaluate the choices, produce agreement and a reward
        agree, rewards = self.evaluate_choices(
            choices, ctxs, update, logger, forced, training
        )

utils.dialog.py

def evaluate_choices(self, choices, ctxs, update, logger, forced, training):
        """
        Evaluate the choices, produce agreement and a reward
        :return:
        """
        # evaluate the choices, produce agreement and a reward
        agree, rewards = self.domain.score_choices(choices, ctxs)
        assert len(rewards) == 2  # Otherwise, the partner_reward will be incorrect
        logger.dump("-" * 80, forced=forced)
        # self.test_prompt += f"Alice : book={choices[0][0][-1]} hat={choices[0][1][-1]} ball={choices[0][2][-1]}\n"
        # self.test_prompt += f"Bob   : book={choices[0][3][-1]} hat={choices[0][4][-1]} ball={choices[0][5][-1]}\n"
        self.test_prompt += "-" * 80 + "\n"
        logger.dump_agreement(agree, forced=forced)
        self.test_prompt += "Agreement!\n" if agree else "Disagreement?!\n"
        for i, (agent, reward) in enumerate(zip(self.agents, rewards)):
            if agent.name == "Bob":
                self.test_prompt += f"{agent.name}   : {reward} points\n"
            else:
                self.test_prompt += f"{agent.name} : {reward} points\n"
  • score_choices 는 choices와 context 바탕으로 각 agent 가 얻게 되는 점수를 반환
## add gpt3 style reward
        if training:
            # GPT3 REWARDS
            if self.model == "gpt3":
                style_rewards = self.gpt3_reward()
                if style_rewards == -1:
                    return -1, -1

            # GPT2 REWARDS
            elif self.model == "gpt2":
                style_rewards = self.gpt2_reward()
                if style_rewards == -1:
                    return -1, -1
  • self.model 이 gpt2 라 가정
  • style_rewards ==-1 은 parse 결과에 yes, no 포함되지 않은 경우를 의미

utils.dialog.py

def gpt2_reward(self):
        """Ask GPT-2 whether the finished dialogue matches the target style.

        Appends ``self.question`` to ``self.test_prompt``, prepends
        ``self.base_prompt``, generates a continuation with a GPT-2
        ``text-generation`` pipeline (seed fixed to 10), and parses the first
        space-separated word of the continuation.

        Returns:
            ``[0, 0]`` if that word contains "no", ``[10, 10]`` if it contains
            "yes", otherwise ``-1`` (unparseable answer). Each parsed answer is
            recorded in ``self.gpt3_answers`` as ``(test_prompt, word, 0 or 1)``.
        """
        # NOTE(review): this mutates self.test_prompt; calling it twice for the same
        # dialogue would append the question twice.
        self.test_prompt += self.question
        final_prompt = self.base_prompt + self.test_prompt
        # NOTE(review): the pipeline is constructed on every call, which is likely
        # expensive (model reload) -- consider building it once and reusing it.
        lm = pipeline("text-generation", model="gpt2", device='cuda:0')
        set_seed(10)
        # NOTE(review): per the transformers docs, max_length bounds prompt + new
        # tokens together, so a long few-shot prompt leaves little or no room to answer.
        full_response = lm(final_prompt, max_length=750, num_return_sequences=1)[0][
            "generated_text"
        ]
        # strip the echoed prompt by character count (assumes generated_text
        # starts with final_prompt verbatim)
        response = full_response[len(final_prompt) :]
        parsed_response = response.lower().strip().split(" ")[0]
        # substring checks with "no" first: e.g. "not" or "none" also count as "no"
        if "no" in parsed_response:
            self.gpt3_answers.append((self.test_prompt, parsed_response, 0))
            return [0, 0]
        elif "yes" in parsed_response:
            self.gpt3_answers.append((self.test_prompt, parsed_response, 1))
            return [10, 10]
        else:
            print(f"cannot parse lm answer!: {response}")
            return -1
  • self.question 은 선택한 스타일에 따라 ‘Is Alice a competitive negotiator?’ 와 같은 질문 문자열
  • final prompt 모습은 논문에 나온것과 같이 few-shot + choice 함수에 의한 결정사항과 그에 따른 reward + alice 가 [스타일]에 부합하는지 묻는 내용
  • 대답의 첫 단어(소문자)에 'no' 가 포함되면 [0, 0], 'yes' 가 포함되면 [10, 10] 점 부여 ('no' 를 먼저 substring 으로 검사하므로 'not' 등도 no 로 처리됨), 둘 다 없으면 -1

utils.dialog.py

# def evaluate_choices 계속
rewards[0] = style_rewards[0]
rewards[1] = style_rewards[1]

pareto = 0.0
        if agree:
            self.metrics.record("advantage", rewards[0] - rewards[1])
            pareto = self.metrics.record_pareto("pareto", ctxs, rewards)
        self.update_agents(agree, rewards, update, logger, forced=forced, pareto=pareto)
        self.metrics.record("time")
        self.metrics.record("dialog_len", len(self.conv))
        self.metrics.record("agree", int(agree))
        self.metrics.record("comb_rew", np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record("%s_rew" % agent.name, reward if agree else 0)
            self.metrics.record_end_of_dialogue("%s_diversity" % agent.name)
        return agree, rewards
  • score_choices 로 계산한 점수 대신 LM 이 판정한 두 style reward 로 rewards 를 각각 덮어씀
  • metric 기록 사이에 agent update 함수 끼어 있음

utils.agent.py

def update(self, agree: bool, reward: float, partner_reward: float = None, pareto=None):
        """Policy-gradient update of the model from the final dialogue reward.

        Does nothing when ``self.train`` is false. Otherwise increments
        ``self.t`` and, if any log-probabilities were recorded in
        ``self.logprobs``, turns the outcome into a scalar via
        ``self.rewarder.calc_reward(agree, reward, partner_reward)``, discounts
        it backwards over the recorded steps with ``self.args.gamma`` (the last
        step gets the full reward, the step k positions earlier gets
        ``r * gamma**k``), and takes one optimizer step on
        ``-sum(logprob * discounted_reward)`` with the gradient norm clipped to
        ``self.args.rl_clip``.

        Args:
            agree: whether the agents reached agreement.
            reward: this agent's reward.
            partner_reward: the other agent's reward, if available.
            pareto: not used in this method.
        """
        if not self.train:
            return

        self.t += 1

        if len(self.logprobs) == 0:
            return

        r = self.rewarder.calc_reward(agree, reward, partner_reward)

        # compute accumulated discounted reward
        # (insert at the front, so rewards[-1] == r and earlier entries are discounted more)
        g = Variable(torch.zeros(1, 1).fill_(r)).to(device)
        rewards = []
        for _ in self.logprobs:
            rewards.insert(0, g)
            g = g * self.args.gamma

        loss = 0
        # estimate the loss using one MonteCarlo rollout
        # (note: `r` is reused as the loop variable, shadowing the scalar reward above)
        for lp, r in zip(self.logprobs, rewards):
            loss -= lp * r

        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.args.rl_clip)

        # if self.args.visual and self.t % 10 == 0:
        # self.model_plot.update(self.t)
        # self.reward_plot.update('reward', self.t, reward)
        # self.loss_plot.update('loss', self.t, loss.item())
        # wandb.log({'rl reward': reward, 'rl loss':loss.item()})

        # NOTE(review): self.logprobs is not cleared in this method; presumably it is
        # reset elsewhere before the next dialogue -- confirm.
        self.opt.step()
profile
0100101

0개의 댓글