There are three tasks; for the negotiation version, the SL baseline needs to be run first.
sl.py
model_ty = get_model_type(args.model_type)
sl.py
corpus = model_ty.corpus_ty(
args.data,
freq_cutoff=args.unk_threshold,
train=args.train_file,
valid=args.val_file,
test=args.test_file,
verbose=True,
)
utils.data.py
class WordCorpus(object):
"""An utility that stores the entire dataset.
It has the train, valid and test datasets and corresponding dictionaries.
"""
def __init__(self, path, freq_cutoff=2, train='train.txt',
valid='val.txt', test='test.txt', verbose=False):
self.verbose = verbose
# only add words from the train dataset
self.word_dict, self.item_dict, self.context_dict = Dictionary.from_file(
os.path.join(path, train),
freq_cutoff=freq_cutoff)
Dictionary.from_file (also in utils/data.py):
@classmethod
def from_file(cls, file_name: str, freq_cutoff: int) -> Tuple['Dictionary', 'Dictionary', 'Dictionary']:
"""Constructs a dictionary from the given file."""
assert os.path.exists(file_name)
word_dict = cls.read_tag(file_name, 'dialogue', freq_cutoff=freq_cutoff)
item_dict = cls.read_tag(file_name, 'output', init_dict=False)
context_dict = cls.read_tag(file_name, 'input', init_dict=False)
return word_dict, item_dict, context_dict
@classmethod
def read_tag(cls, file_name: str, tag: str, freq_cutoff: int = -1, init_dict: bool = True) -> 'Dictionary':
"""
Convert the tokens between a tag in a dataset into a dictionary
Args:
file_name: Location of the txt file with dialogues
tag: The XML tag which contains the tokens that you want to parse
freq_cutoff: A minimum number of times the token must appear for it to be added to the dictionary.
By default, all tokens that are seen at least once are added
init_dict: If True, will initialize the dictionary with the
default special tokens <eos>, <unk>, <selection>, and <pad>
Returns: A Dictionary that contains all of the tokens in between the specified tag
for all examples in the training file
"""
token_freqs = OrderedDict()
with open(file_name, 'r') as f:
for line in f:
tokens = line.strip().split()
tokens = get_tag(tokens, tag)
for token in tokens:
token_freqs[token] = token_freqs.get(token, 0) + 1
dictionary = cls(init=init_dict)
token_freqs = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True)
for token, freq in token_freqs:
if freq > freq_cutoff:
dictionary.add_word(token)
return dictionary
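For reference, a minimal sketch of the Dictionary interface that read_tag, w2i and get_idx calls elsewhere seem to rely on — this is my guess at its shape, not the actual utils/data.py class:
from typing import List

class MiniDictionary:
    """Hypothetical stand-in for utils.data.Dictionary; only the methods used in these notes."""
    def __init__(self, init=True):
        self.word2idx, self.idx2word = {}, []
        if init:
            for tok in ('<eos>', '<unk>', '<selection>', '<pad>'):
                self.add_word(tok)

    def add_word(self, word: str) -> int:
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def get_idx(self, word: str) -> int:
        # unknown words fall back to <unk> (only present when init=True)
        return self.word2idx.get(word, self.word2idx.get('<unk>', 0))

    def get_word(self, idx: int) -> str:
        return self.idx2word[idx]

    def w2i(self, words: List[str]) -> List[int]:
        return [self.get_idx(w) for w in words]

    def __len__(self):
        return len(self.idx2word)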
An example record from the training file:
<input> 1 4 4 1 1 2 </input> <dialogue> THEM: i would like 4 hats and you can have the rest . <eos> YOU: deal <eos> THEM: <selection> </dialogue> <output> item0=1 item1=0 item2=1 item0=0 item1=4 item2=0 </output> <partner_input> 1 0 4 2 1 2 </partner_input>
def get_tag(tokens: List[str], tag: str) -> List[str]:
"""Extracts the value inside the given tag."""
return tokens[tokens.index('<' + tag + '>') + 1:tokens.index('</' + tag + '>')]
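Quick sanity check of get_tag against the example record above (expected output, not actually run against the repo):
line = ("<input> 1 4 4 1 1 2 </input> "
        "<dialogue> THEM: i would like 4 hats and you can have the rest . <eos> "
        "YOU: deal <eos> THEM: <selection> </dialogue> "
        "<output> item0=1 item1=0 item2=1 item0=0 item1=4 item2=0 </output>")
tokens = line.split()
print(get_tag(tokens, 'input'))   # ['1', '4', '4', '1', '1', '2']
print(get_tag(tokens, 'output'))  # ['item0=1', 'item1=0', 'item2=1', 'item0=0', 'item1=4', 'item2=0']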
utils.data.py
continuing the rest of WordCorpus.__init__
# construct all 3 datasets
self.train = self.tokenize(os.path.join(path, train)) if train else []
self.valid = self.tokenize(os.path.join(path, valid)) if valid else []
self.test = self.tokenize(os.path.join(path, test)) if test else []
# find out the output length from the train dataset
self.output_length = max([len(x[2]) for x in self.train])
def tokenize(self, file_name: str) -> List[Example]:
"""
Tokenize and numericalize the dataset found at filename.
Args:
file_name: The location of the dataset
Returns: A list of examples. Each example contains:
input_idxs: A numerical representation of the context, which includes the number of items
in the game as well as the individual utilities for each item.
word_idxs: A list of token indexes for each of the words spoken in the dialogue. This includes divider tokens
like "YOU: ", "THEM: ", "<selection>", etc.
item_idxs: An index representing the allocation given to the user at the end of the game
Example index: "item0=0 item1=1 item2=2" -> 55
"""
lines = read_lines(file_name)
random.shuffle(lines)
unk = self.word_dict.get_idx('<unk>')
dataset, total, unks = [], 0, 0
for line in lines:
tokens = line.split()
input_tokens = get_tag(tokens, 'input')
dialogue_tokens = get_tag(tokens, 'dialogue')
input_idxs = self.context_dict.w2i(input_tokens)
word_idxs = self.get_word_indices(dialogue_tokens, input_tokens)
item_idxs = self.item_dict.w2i(get_tag(tokens, 'output'))
dataset.append((input_idxs, word_idxs, item_idxs))
# compute statistics
total += len(input_idxs) + len(word_idxs) + len(item_idxs)
unks += np.count_nonzero([idx == unk for idx in word_idxs])
if self.verbose:
print('dataset %s, total %d, unks %s, ratio %0.2f%%, datapoints %d' % (
file_name, total, unks, 100. * unks / total, len(lines)))
return dataset
→ get_word_indices is defined in coarse_dialogue_acts/corpus.py; that function needs a closer look later.
Back to sl.py
model = model_ty(
corpus.word_dict,
corpus.item_dict,
corpus.context_dict,
corpus.output_length,
args,
device_id,
)
sl.py
engine = Engine(model, args, device_id, verbose=True, corpus_type=args.corpus_type)
train_loss, valid_loss, select_loss = engine.train(corpus)
utils/engine.py
def train(self, corpus, N=None, callbacks: Iterable[Callable] = []):
"""Entry point."""
N = len(corpus.word_dict) if N is None else N
best_model, best_valid_select_loss = None, 1e100
lr = self.args.lr
last_decay_epoch = 0
self.t = 0
validdata = corpus.valid_dataset(self.args.bsz, device_id=self.device_id)
for epoch in range(1, self.args.max_epoch + 1):
traindata = corpus.train_dataset(self.args.bsz, device_id=self.device_id)
utils.data.py
def train_dataset(self, bsz: int, shuffle=True, device_id=None):
return self._split_into_batches(copy.copy(self.train), bsz,
shuffle=shuffle, device_id=device_id)
def _split_into_batches(self, dataset, bsz, shuffle=True, device_id=None):
"""Splits given dataset into batches."""
if shuffle:
random.shuffle(dataset)
# sort and pad
dataset.sort(key=lambda x: len(x[1]))
pad = self.word_dict.get_idx('<pad>')
batches = []
stats = {
'n': 0,
'nonpadn': 0
}
for i in range(0, len(dataset), bsz):
# groups contexts, words and items
inputs, words, items = [], [], []
for j in range(i, min(i + bsz, len(dataset))):
inputs.append(dataset[j][0])
words.append(dataset[j][1])
items.append(dataset[j][2])
# the longest dialogue in the batch
max_len = len(words[-1])
# pad all the dialogues to match the longest dialogue
for j in range(len(words)):
stats['n'] += max_len
stats['nonpadn'] += len(words[j])
words[j] += [pad] * (max_len - len(words[j]))
# construct tensor for context
ctx = torch.LongTensor(inputs).transpose(0, 1).contiguous()
data = torch.LongTensor(words).transpose(0, 1).contiguous()
# construct tensor for selection target
sel_tgt = torch.LongTensor(items).transpose(0, 1).contiguous().view(-1)
#if 16 in sel_tgt:
# print("FOUND THE FRIGGING THING")
if device_id is not None:
ctx = ctx.cuda(device_id)
data = data.cuda(device_id)
sel_tgt = sel_tgt.cuda(device_id)
# construct tensor for input and target
inpt = data.narrow(0, 0, data.size(0) - 1)
tgt = data.narrow(0, 1, data.size(0) - 1).view(-1)
batches.append((ctx, inpt, tgt, sel_tgt))
if shuffle:
random.shuffle(batches)
return batches, stats
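How I read the shapes of one batch (seq_len = the longest dialogue in that batch; contexts and outputs both have 6 tokens in this domain; `corpus` below is an already-built WordCorpus):
batches, stats = corpus.train_dataset(bsz=16)
ctx, inpt, tgt, sel_tgt = batches[0]
# ctx:     LongTensor (6, bsz)                context tokens, time-major after the transpose
# inpt:    LongTensor (seq_len - 1, bsz)      dialogue tokens fed to the language RNN
# tgt:     LongTensor ((seq_len - 1) * bsz,)  next-token targets, flattened
# sel_tgt: LongTensor (6 * bsz,)              selection targets, one index per output slot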
reinforce.py — after the device and seed are set up:
reward_ty = get_reward_type(args.reward_type)
rewarder = reward_ty()
alice_model = utils.load_model(args.alice_model, cuda=args.cuda)
alice_ty = get_agent_type(alice_model)
alice = alice_ty(alice_model, args, name="Alice", train=True, rewarder=rewarder)
bob_model = utils.load_model(args.bob_model, cuda=args.cuda)
bob_ty = get_agent_type(bob_model)
bob = bob_ty(bob_model, args, name="Bob", train=False, rewarder=rewarder)
args.novelty_model = utils.load_model(args.novelty_model, cuda=args.cuda)
dialog = Dialog([alice, bob], args)
logger = DialogLogger(log_file=args.log_file)
ctx_gen = ContextGenerator(args.context_file)
assert alice_model.corpus_ty == bob_model.corpus_ty
corpus = alice_model.corpus_ty(
args.data,
freq_cutoff=args.unk_threshold,
train=args.train_file,
valid=args.val_file,
test=args.test_file,
verbose=True,
)
engines = [
Engine(alice_model, args, device_id, verbose=False),
Engine(bob_model, args, device_id, verbose=False),
]
reinforce = Reinforce(dialog, ctx_gen, args, engines, corpus, logger, name)
reinforce.run()
Still in reinforce.py
def run(self):
"""Entry point of the training."""
# Assumes that both models are running on the same device_id
assert self.engines[0].device_id == self.engines[1].device_id
n = 0
for e in range(self.args.nepoch):
for ctxs in self.ctx_gen.iter(nepoch=1):
utils/utils.py
class ContextGenerator(object):
"""Dialogue context generator. Generates contexes from the file."""
def __init__(self, context_file):
self.ctxs = []
with open(context_file, "r") as f:
ctx_pair = []
for line in f:
ctx = line.strip().split()
ctx_pair.append(ctx)
if len(ctx_pair) == 2:
self.ctxs.append(ctx_pair)
ctx_pair = []
def sample(self):
return random.choice(self.ctxs)
def iter(
self, nepoch: int = None, neps: int = None, is_random=False
) -> Iterable[List[List[str]]]:
"""
Iterate through all of the contexts specified in the context_file
Args:
nepoch: The number of times to iterate through every context in the file
neps: The number of contexts to generate.
Note: Specify either nepoch or neps, but not both
Returns: A generator where each element contains a list of length 2,
each specifying the utilities and counts for each agent in the game
"""
if nepoch is not None and neps is not None:
raise ValueError("Specify either number of epochs or episodes")
if nepoch is not None:
for e in range(nepoch):
if is_random:
random.shuffle(self.ctxs)
for ctx in self.ctxs:
yield ctx
elif neps is not None:
n = 0
while n < neps:
if is_random:
random.shuffle(self.ctxs)
for ctx in self.ctxs:
yield ctx
n += 1
if n == neps:
break
else:
raise NotImplementedError
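Usage as I understand it (the file path is just a placeholder): every two consecutive lines in the context file become one episode's pair of contexts, one per agent:
ctx_gen = ContextGenerator("data/negotiate/selfplay.txt")
for ctxs in ctx_gen.iter(nepoch=1):
    alice_ctx, bob_ctx = ctxs
    # e.g. alice_ctx = ['1', '4', '4', '1', '1', '2'], bob_ctx = ['1', '0', '4', '2', '1', '2']
    break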
reinforce.py
for ctxs in self.ctx_gen.iter(nepoch=1):
n += 1
self.logger.dump("=" * 80)
self.dialog.test_prompt = "=" * 80 + "\n"
reinforce.py
skip = self.dialog.run(
ctxs, self.logger, update=(True, False), forced=False, training=True
)
utils.dialog.py
def run(self, ctxs, logger, update, forced=False, training=True):
"""
Run one episode of conversation
"""
assert len(self.agents) == len(ctxs)
self.ctxs = ctxs
assert len(ctxs[0]) == len(ctxs[1]) == 6
self.feed_context(ctxs, logger, forced=forced)
writer, reader, first_agent_index = self.choose_starting_order()
def feed_context(self, ctxs, logger, forced=False):
"""
Initialize agents by feeding in the contexts
and initializing other dialogue-specific variables
"""
# feed context
for agent, ctx in zip(self.agents, ctxs):
agent.feed_context(ctx)
logger.dump_ctx(agent.name, ctx, forced=forced)
s = " ".join(
[
"%s=(count:%s value:%s)"
% (logger.CODE2ITEM[i][1], ctx[2 * i], ctx[2 * i + 1])
for i in range(3)
]
)
if agent.name == "Bob":
self.test_prompt += f"{agent.name} : {s}\n"
else:
self.test_prompt += f"{agent.name} : {s}\n"
logger.dump("-" * 80, forced=forced)
self.test_prompt += "-" * 80 + "\n"
→ Re-running this confirmed that the agent here is coarse_dialogue_acts.agents.CdaAgent.
utils/agent.py
class RlAgent(LstmAgent):
def feed_context(self, ctx):
super(RlAgent, self).feed_context(ctx)
class LstmAgent(Agent):
"""An agent that uses DialogModel as an AI."""
def feed_context(self, context):
# the hidden state of all the pronounced words
self.lang_hs = []
# all the pronounced words
self.words = []
self.context = context
# encoded context
self.ctx = self._encode(context, self.model.context_dict)
# hidded state of context
self.ctx_h = self.model.forward_context(Variable(self.ctx).to(device))
# current hidden state of the language rnn
self.lang_h = self.model.zero_hid(1)
def _encode(self, inpt, dictionary):
"""A helper function that encodes the passed in words using the dictionary.
inpt: is a list of strings.
dictionary: prebuild mapping, see Dictionary in data.py
"""
encoded = torch.LongTensor(dictionary.w2i(inpt)).unsqueeze(1)
if self.model.device_id is not None:
encoded = encoded.cuda(self.model.device_id)
return encoded
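So feeding a raw context through _encode gives a column of indices (a sketch, assuming `agent` is an already-constructed LstmAgent):
ctx = ['1', '4', '4', '1', '1', '2']
encoded = agent._encode(ctx, agent.model.context_dict)
print(encoded.shape)  # torch.Size([6, 1]) - one column, ready for forward_context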
models/dialog_model.py
class DialogModel(modules.CudaModule):
corpus_ty = WordCorpus
def __init__(self, word_dict, item_dict, context_dict, output_length, args, device_id):
super(DialogModel, self).__init__(device_id)
domain = get_domain(args.domain)
self.word_dict = word_dict
self.item_dict = item_dict
self.context_dict = context_dict
self.args = args
# embedding for words
self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word).to(device)
# context encoder
ctx_encoder_ty = modules.RnnContextEncoder if args.rnn_ctx_encoder \
else modules.MlpContextEncoder
self.ctx_encoder = ctx_encoder_ty(len(self.context_dict), domain.input_length(),
args.nembed_ctx, args.nhid_ctx, args.init_range, device_id).to(device)
def forward_context(self, ctx):
"""Run context encoder."""
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#ctx = ctx.to(device)
return self.ctx_encoder(ctx.to(device))
class MlpContextEncoder(CudaModule):
"""Simple encoder for the dialogue context. Encoder counts and values via MLP."""
def __init__(
self, n: int, k: int, nembed: int, nhid: int, init_range: int, device_id: int
):
"""
Args:
n: The number of possible token values for the context.
k: The number of tokens that make up a full context
nembed: The size of the embedding layer
nhid: The size of the hidden layer
init_range: The range of values to initialize the parameters with
"""
super(MlpContextEncoder, self).__init__(device_id)
# create separate embedding for counts and values
self.cnt_enc = nn.Embedding(n, nembed).to(device)
self.val_enc = nn.Embedding(n, nembed).to(device)
self.encoder = nn.Sequential(nn.Tanh(), nn.Linear(k * nembed, nhid)).to(device)
self.cnt_enc.weight.data.uniform_(-init_range, init_range)
self.val_enc.weight.data.uniform_(-init_range, init_range)
init_cont(self.encoder, init_range)
def forward(self, ctx):
idx = np.arange(ctx.size(0) // 2)
# extract counts and values
cnt_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 0)))
val_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 1)))
cnt_idx = cnt_idx.to(device)
val_idx = val_idx.to(device)
cnt = ctx.index_select(0, cnt_idx)
val = ctx.index_select(0, val_idx)
# embed counts and values
cnt_emb = self.cnt_enc(cnt.to(device))
val_emb = self.val_enc(val.to(device))
# element wise multiplication to get a hidden state
h = torch.mul(cnt_emb, val_emb)
# run the hidden state through the MLP
h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
ctx_h = self.encoder(h.to(device)).unsqueeze(0).to(device)
return ctx_h # (1,1,64)
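The even/odd index trick just splits the 6-token context into counts and values; for the sample input "1 4 4 1 1 2":
import numpy as np

idx = np.arange(6 // 2)   # [0, 1, 2]
cnt_idx = 2 * idx + 0     # positions [0, 2, 4] -> counts 1, 4, 1
val_idx = 2 * idx + 1     # positions [1, 3, 5] -> values 4, 1, 2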
Continuing RlAgent.feed_context (utils/agent.py):
def feed_context(self, ctx):
super(RlAgent, self).feed_context(ctx)
# save all the log probs for each generated word,
# so we can use it later to estimate policy gradient.
self.logprobs = []
# generate context mask
counts = [int(x) for x in ctx[::2]]
vocab = self.model.word_dict.word2idx
self.context_mask = np.zeros(len(vocab))
for w, word in enumerate(vocab):
if 'item' in word:
items = word.split(" ")[1:]
counts2 = [int(item[-1]) for item in items]
assert len(counts) == len(counts2)
valid = [counts[i] >= counts2[i] for i in range(len(counts))]
if not np.all(valid):
self.context_mask[w] = -999.
I don't fully understand this part yet; I'll need to run it to check.
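My current reading, to be verified by running it: in the CDA vocabulary a whole proposal is a single token such as 'propose: item0=0 item1=0 item2=3', so this mask knocks out any act that asks for more of an item than the context contains. Roughly (the vocab entry below is hypothetical):
import numpy as np

ctx = ['1', '4', '4', '1', '1', '2']                      # sample context: counts 1, 4, 1
counts = [int(x) for x in ctx[::2]]                       # [1, 4, 1]

word = 'propose: item0=2 item1=0 item2=0'                 # hypothetical CDA vocab entry
counts2 = [int(tok[-1]) for tok in word.split(" ")[1:]]   # [2, 0, 0]
valid = [counts[i] >= counts2[i] for i in range(len(counts))]
print(np.all(valid))  # False -> this act would get context_mask[w] = -999.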
# def run() continued
# initialize dialogue-dependent variables
self.conv, self.agent_order = [], []
self.metrics.reset()
self.num_sentences = 0 # reset num_sentences
skip = 0
utils.dialog.py
while True:
# produce an utterance
out = self.write(writer, logger, forced=forced)
def write(self, writer, logger, forced=False):
"""
Produces an utterance and saves necessary meta information
"""
# produce an utterance
out = writer.write()
coarse_dialogue_acts/agent.py → utils/agent.py (RlAgent is the parent)
def write(self, max_words=100) -> List[str]:
logprobs, outs, self.lang_h, lang_hs, _, _ = self.model.write(self.lang_h, self.ctx_h,
max_words, self.args.temperature,
self.context_mask)
models.cda_rnn_model.py
def write(self, lang_h: torch.Tensor, ctx_h: torch.Tensor, max_words: int, temperature: float, context_mask: List[float],
stop_tokens: List[str] = STOP_TOKENS, resume: bool = False) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
max_words = 1
return super().write(lang_h, ctx_h, max_words, temperature, context_mask, stop_tokens, resume)
models/dialog_model.py (DialogModel is the parent model)
# before going into write(), the relevant modules from __init__:
self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word).to(device)
# the GRU cell that generates the utterance
self.writer = nn.GRUCell(
input_size=args.nhid_ctx + args.nembed_word,
hidden_size=args.nhid_lang,
bias=True).to(device)
self.decoder = nn.Linear(args.nhid_lang, args.nembed_word).to(device)
def write(self, lang_h: torch.Tensor, ctx_h: torch.Tensor, max_words: int, temperature: float,
context_mask: List[float],
stop_tokens: List[str] = STOP_TOKENS, resume: bool = False) -> Tuple[
List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
outs, logprobs, lang_hs = [], [], []
# remove batch dimension from the language and context hidden states
lang_h = lang_h.squeeze(1)
ctx_h = ctx_h.squeeze(1)
inpt = None if resume else self.word2var('YOU:')
# generate words until max_words have been generated or <selection>
for _ in range(max_words): # max words is just 1 in our case
if inpt:
# add the context to the word embedding
inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
# update RNN state with last word
lang_h = self.writer(inpt_emb, lang_h)
lang_hs.append(lang_h)
# decode words using the inverse of the word embedding matrix
out = self.decoder(lang_h)
scores = F.linear(out, self.word_encoder.weight).div(temperature)
# subtract constant to avoid overflows in exponentiation
scores = scores.add(-scores.max().item()).squeeze(0)
# disable special tokens from being generated in a normal turns
if not resume:
mask = Variable(self.special_token_mask).to(device)
scores = scores.add(mask)
context_mask = Variable(self.to_device(torch.FloatTensor(context_mask))).to(device)
scores = scores.add(context_mask)
prob = F.softmax(scores, dim=0)
logprob = F.log_softmax(scores, dim=0)
# word = prob.multinomial(1).detach()
word = torch.argmax(prob).unsqueeze(0).detach()
logprob = logprob.gather(0, word)
logprobs.append(logprob)
outs.append(word.view(word.size()[0], 1))
inpt = word
# check if we generated an <eos> token
if self.word_dict.get_word(word.item()) in stop_tokens:
break
# update the hidden state with the <eos> token
inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
lang_h = self.writer(inpt_emb, lang_h)
lang_hs.append(lang_h)
# add batch dimension back
lang_h = lang_h.unsqueeze(1)
return logprobs, torch.cat(outs), lang_h, torch.cat(lang_hs), prob, out
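A toy illustration of what the two additive masks do before the softmax: a -999 entry drives that token's probability to ~0, so the argmax (greedy pick, used here instead of the original multinomial sampling) can never emit a forbidden act. The numbers below are made up:
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.5, 0.5])
context_mask = torch.tensor([0.0, -999.0, 0.0])   # token 1 is forbidden by the context
prob = F.softmax(scores + context_mask, dim=0)
word = torch.argmax(prob)
print(prob)   # approximately tensor([0.82, 0.00, 0.18])
print(word)   # tensor(0)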
utils/agent.py, def write continued
def write(self, max_words=100) -> List[str]:
logprobs, outs, self.lang_h, lang_hs, _, _ = self.model.write(self.lang_h, self.ctx_h,
max_words, self.args.temperature,
self.context_mask)
# append log probs from the generated words
self.logprobs.extend(logprobs)
self.lang_hs.append(lang_hs)
# first add the special 'YOU:' token
self.words.append(self.model.word2var('YOU:').unsqueeze(0))
# then append the utterance
self.words.append(outs)
assert (torch.cat(self.words).size()[0] == torch.cat(self.lang_hs).size()[0])
return self._decode(outs, self.model.word_dict)
Back to utils/dialog.py
def write(self, writer, logger, forced=False):
"""
Produces an utterance and saves necessary meta information
"""
# produce an utterance
out = writer.write()
if not writer.human:
# logger.dump_sent(writer.name, out, forced=forced)
out_with_item_names = out[0].replace("item0", "book")
out_with_item_names = out_with_item_names.replace("item1", "hat")
out_with_item_names = out_with_item_names.replace("item2", "ball")
Continued
if not self._is_selection(out):
logger.dump_sent(writer.name, out, forced=forced)
if writer.name == "Bob":
self.test_prompt += f"{writer.name} : {out_with_item_names}\n"
else:
self.test_prompt += f"{writer.name} : {out_with_item_names}\n"
================================================================================
Alice : book=(count:1 value:0) hat=(count:1 value:1) ball=(count:3 value:3)
Bob : book=(count:1 value:1) hat=(count:1 value:0) ball=(count:3 value:3)
--------------------------------------------------------------------------------
Bob : propose: book=0 hat=0 ball=3
self.conv.append(out)
if len(self.conv) == 1 and ("propose" not in out[0] and "insist" not in out[0]):
print("started conv with non-proposal : ", out)
raise ValueError
self.agent_order.append(writer.name)
return out
Continuing the while loop in utils/dialog.py
while True:
# produce an utterance
out = self.write(writer, logger, forced=forced)
# make other agent read
self.read(reader, out)
coarse_dialogue_acts.agents.py CdaAgent → utils.agent RlAgent
RlAgent inherits from LstmAgent and does not define its own read, so this falls through to
utils/agent.py LstmAgent
def read(self, inpt, prefix_token: str = 'THEM:'):
inpt = self._encode(inpt, self.model.word_dict)
lang_hs, self.lang_h = self.model.read(Variable(inpt).to(device), self.lang_h, self.ctx_h,
prefix_token=prefix_token)
models.dialog_model.py DialogModel
def read(self, inpt: Tensor, lang_h: Tensor, ctx_h: Tensor, prefix_token: str = 'THEM:') -> Tuple[Tensor, Tensor]:
"""Reads a given utterance."""
# inpt contains the pronounced utterance
# add a "THEM:" token to the start of the message
if prefix_token is not None:
prefix = self.word2var(prefix_token).unsqueeze(0).to(device)
inpt = inpt.to(device)
inpt = torch.cat([prefix, inpt])
# embed words
inpt_emb = self.word_encoder(inpt)
# append the context embedding to every input word embedding
ctx_h_rep = ctx_h.expand(inpt_emb.size(0), ctx_h.size(1), ctx_h.size(2))
inpt_emb = torch.cat([inpt_emb, ctx_h_rep], 2)
# finally read in the words
out, lang_h = self.reader(inpt_emb, lang_h)
return out, lang_h
utils.agent.py
def read(self, inpt, prefix_token: str = 'THEM:'):
inpt = self._encode(inpt, self.model.word_dict)
lang_hs, self.lang_h = self.model.read(Variable(inpt).to(device), self.lang_h, self.ctx_h,
prefix_token=prefix_token)
# append new hidden states to the current list of the hidden states
self.lang_hs.append(lang_hs.squeeze(1))
# first add the special 'THEM:' token
# self.words.append(self.model.word2var('THEM:').unsqueeze(0))
self.words.append(self.model.word2var(prefix_token).unsqueeze(0))
# then read the utterance
self.words.append(Variable(inpt).to(device))
assert (torch.cat(self.words).size()[0] == torch.cat(self.lang_hs).size()[0])
utils/dialog.py — the run() loop continued
while True:
# produce an utterance
out = self.write(writer, logger, forced=forced)
# make other agent read
self.read(reader, out)
if self.is_end(out, writer, reader):
break
# swap roles
writer, reader = reader, writer
choices = self.generate_choices(ctxs, self.agents, logger, forced=forced)
utils/dialog.py
def generate_choices(self, ctxs, agents, logger, forced):
"""
Generate final choices for each agent
"""
choices = []
# generate choices for each of the agents
for agent in agents:
choice = None
if agent.name == "Alice" or agent.name == "Bob" or agent.name == "Expert":
choice = agent.choose(self.conv, self.agent_order, ctxs)
elif agent.name == "Human":
choice = agent.choose()
choices.append(choice)
return choices
utils/agent.py
def choose(self, conv, agent_order, ctxs, handcode='partial') -> List[str]:
if handcode == 'full':
return self.choose_handcoded(conv, agent_order, ctxs)
elif handcode == 'partial':
# Sample a choice
choice, logprob, _ = self._choose(sample=False)
→ [['propose: item0=0 item1=0 item2=3'], ['agree'], ['']] (an example of what self.conv looks like at this point)
def _choose(self, lang_hs=None, words=None, sample=False):
# get all the possible choices
choices = self.domain.generate_choices(self.context)
utils.domain.py
def generate_choices(self, inpt: List[str]) -> List[List[str]]:
"""
Given the number of items in a game, outputs all possible divisions of items
Args:
inpt: A parsed game context from the FB data set
Returns: A list of token sequences, each representing a possible selection
"""
cnts, _ = self.parse_context(inpt)
def gen(cnts, idx=0, choice=[]):
if idx >= len(cnts):
left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(choice)]
right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, choice))]
return [left_choice + right_choice]
choices = []
for c in range(cnts[idx] + 1):
choice.append(c)
choices += gen(cnts, idx + 1, choice)
choice.pop()
return choices
choices = gen(cnts)
choices.append(['<no_agreement>'] * self.selection_length())
#choices.append(['<disconnect>'] * self.selection_length())
return choices
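Worked example, assuming `domain` is the constructed negotiation domain and parse_context returns the counts [1, 4, 1] for the sample input "1 4 4 1 1 2": gen enumerates (1+1)*(4+1)*(1+1) = 20 splits (the first half of each is this agent's share, the second half is the remainder), and the <no_agreement> row makes 21:
choices = domain.generate_choices(['1', '4', '4', '1', '1', '2'])
print(len(choices))  # 21
print(choices[0])    # ['item0=0', 'item1=0', 'item2=0', 'item0=1', 'item1=4', 'item2=1']
print(choices[-1])   # ['<no_agreement>'] * selection_length()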
utils/agent.py, the _choose function continued
def _choose(self, lang_hs=None, words=None, sample=False):
# get all the possible choices
choices = self.domain.generate_choices(self.context)
# concatenate the list of the hidden states into one tensor
# print('debug', self.name, self.lang_hs, self.words)
lang_hs = lang_hs if lang_hs is not None else torch.cat(self.lang_hs)
# concatenate all the words into one tensor
words = words if words is not None else torch.cat(self.words)
# logits for each of the item
logits = self.model.generate_choice_logits(words, lang_hs, self.ctx_h)
models.dialog_model.py
def generate_choice_logits(self, inpt, lang_h, ctx_h):
"""Similar to forward_selection, but is used while selfplaying.
Thus it is dealing with batches of size 1.
"""
# run a birnn over the concatenation of the input embeddings and
# language model hidden states
inpt_emb = self.word_encoder(inpt)
h = torch.cat([lang_h.unsqueeze(1), inpt_emb], 2)
h = self.dropout(h)
# runs selection rnn over the hidden state h
attn_h = self.zero_hid(h.size(1), self.args.nhid_attn, copies=2)
h, _ = self.sel_rnn(h, attn_h)
h = h.squeeze(1)
# perform attention
logit = self.attn(h).squeeze(1)
prob = F.softmax(logit, dim=0).unsqueeze(1).expand_as(h)
attn = torch.sum(torch.mul(h, prob), 0, keepdim=True)
# concatenate attention and context hidden and pass it to the selection encoder
ctx_h = ctx_h.squeeze(1)
h = torch.cat([attn, ctx_h], 1)
h = self.sel_encoder.forward(h)
# generate logits for each item separately
logits = [decoder.forward(h).squeeze(0) for decoder in self.sel_decoders]
return logits
utils/agent.py, def _choose continued
def _choose(self, lang_hs=None, words=None, sample=False):
# construct probability distribution over only the valid choices
choices_logits = []
for i in range(self.domain.selection_length()):
idxs = [self.model.item_dict.get_idx(c[i]) for c in choices]
idxs = Variable(torch.from_numpy(np.array(idxs))).to(device)
idxs = self.model.to_device(idxs)
choices_logits.append(torch.gather(logits[i], 0, idxs).unsqueeze(1))
choice_logit = torch.sum(torch.cat(choices_logits, 1), 1, keepdim=False)
# subtract the max to make the softmax more stable
choice_logit = choice_logit.sub(choice_logit.max().item())
prob = F.softmax(choice_logit, dim=0)
if sample:
# sample a choice
idx = prob.multinomial(1).detach()
logprob = F.log_softmax(choice_logit, dim=0).gather(0, idx)
else:
# take the most probable choice
_, idx = prob.max(0, keepdim=True)
logprob = None
p_agree = prob[idx.item()]
# Pick only your choice
return choices[idx.item()][:self.domain.selection_length()], logprob, p_agree.item()
Back to utils/agent.py, def choose continued
# def choose continued
if conv[-1][0] != '<selection>': # reached max_sentences
return choice
# Check case 1 (all agrees after the final proposal; choose this final proposal )
assert conv[-1][0] == '<selection>'
partial_convo = []
for i, sent in reversed(list(enumerate(conv[0:-1]))):
if 'propose' in sent[0] or 'insist' in sent[0]:  # if agents disagreed after the final proposal, then no agreement
if set(partial_convo) == {'agree'}:  # if all the words after the final proposal were 'agree', then hard-code the selection to be the final proposal
choice = self._get_choice(i, sent, agent_order, ctxs)
return choice
else: # otherwise, just break
break
else:
partial_convo.append(sent[0])
# assert i != 0 # assert that propose or insist was in conversation
# Check case 2 (this agent proposes -> the other agent ends; then we must choose the proposal this agent made)
end_agent = agent_order[-1] # get person who ended
# if ('propose' in conv[-2][0] or 'insist' in conv[-2][0]) and end_agent in ['Human', 'Bob']:
if len(conv) > 1 and ('propose' in conv[-2][0] or 'insist' in conv[-2][0]) and end_agent != self.name:
assert agent_order[-2] == self.name
choice = self._get_choice(-2, conv[-2], agent_order, ctxs)
return choice
# Check case 3 where choice is no agreement
if '<no_agreement>' in choice:
return choice
# Finally, make sure that agent selects previous selection
for i, sent in reversed(list(enumerate(conv[0:-1]))):
if agent_order[i] == self.name: # find agent's last proposal
if 'propose' in sent[0] or 'insist' in sent[0]:  # if agents disagreed after final proposal, then no agreement
choice = self._get_choice(i, sent, agent_order, ctxs)
return choice
elif 'agree' in sent[0]:
for j, sent2 in reversed(list(enumerate(conv[0:i]))):
if 'propose' in sent2[0] or 'insist' in sent2[0]:  # if agents disagreed after final proposal, then no agreement
choice = self._get_choice(j, sent2, agent_order, ctxs)
return choice
# if agent_order[i] == 'Alice': # find Alice's last proposal
# choice = self._get_choice(i, sent, agent_order, ctxs)
# return choice
# print("conv: ", conv)
# print("agent order: ", agent_order)
return choice
# raise ValueError
utils/dialog.py, def generate_choices (back in the caller)
def generate_choices(self, ctxs, agents, logger, forced):
"""
Generate final choices for each agent
"""
choices = []
# generate choices for each of the agents
for agent in agents:
choice = None
if agent.name == "Alice" or agent.name == "Bob" or agent.name == "Expert":
choice = agent.choose(self.conv, self.agent_order, ctxs)
elif agent.name == "Human":
choice = agent.choose()
choices.append(choice)
return choices
utils/dialog.py, back in def run
# evaluate the choices, produce agreement and a reward
agree, rewards = self.evaluate_choices(
choices, ctxs, update, logger, forced, training
)
utils.dialog.py
def evaluate_choices(self, choices, ctxs, update, logger, forced, training):
"""
Evaluate the choices, produce agreement and a reward
:return:
"""
# evaluate the choices, produce agreement and a reward
agree, rewards = self.domain.score_choices(choices, ctxs)
assert len(rewards) == 2 # Otherwise, the partner_reward will be incorrect
logger.dump("-" * 80, forced=forced)
# self.test_prompt += f"Alice : book={choices[0][0][-1]} hat={choices[0][1][-1]} ball={choices[0][2][-1]}\n"
# self.test_prompt += f"Bob : book={choices[0][3][-1]} hat={choices[0][4][-1]} ball={choices[0][5][-1]}\n"
self.test_prompt += "-" * 80 + "\n"
logger.dump_agreement(agree, forced=forced)
self.test_prompt += "Agreement!\n" if agree else "Disagreement?!\n"
for i, (agent, reward) in enumerate(zip(self.agents, rewards)):
if agent.name == "Bob":
self.test_prompt += f"{agent.name} : {reward} points\n"
else:
self.test_prompt += f"{agent.name} : {reward} points\n"
## add gpt3 style reward
if training:
# GPT3 REWARDS
if self.model == "gpt3":
style_rewards = self.gpt3_reward()
if style_rewards == -1:
return -1, -1
# GPT2 REWARDS
elif self.model == "gpt2":
style_rewards = self.gpt2_reward()
if style_rewards == -1:
return -1, -1
utils.dialog.py
def gpt2_reward(self):
self.test_prompt += self.question
final_prompt = self.base_prompt + self.test_prompt
lm = pipeline("text-generation", model="gpt2", device='cuda:0')
set_seed(10)
full_response = lm(final_prompt, max_length=750, num_return_sequences=1)[0][
"generated_text"
]
response = full_response[len(final_prompt) :]
parsed_response = response.lower().strip().split(" ")[0]
if "no" in parsed_response:
self.gpt3_answers.append((self.test_prompt, parsed_response, 0))
return [0, 0]
elif "yes" in parsed_response:
self.gpt3_answers.append((self.test_prompt, parsed_response, 1))
return [10, 10]
else:
print(f"cannot parse lm answer!: {response}")
return -1
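The parsing step only looks at the first whitespace-separated word of the continuation; a small mirror of that logic (mine, not the repo's):
def parse_style_answer(response: str):
    """Mirror of the yes/no check in gpt2_reward."""
    first = response.lower().strip().split(" ")[0]
    if "no" in first:
        return [0, 0]
    if "yes" in first:
        return [10, 10]
    return -1   # unparsable -> evaluate_choices aborts the episode with (-1, -1)

print(parse_style_answer("Yes, Alice was polite."))  # [10, 10]
print(parse_style_answer("Probably"))                # -1
Note that "no" is checked before "yes", so a continuation starting with something like "Not really" also counts as a no.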
utils.dialog.py
# def evaluate_choices continued
rewards[0] = style_rewards[0]
rewards[1] = style_rewards[1]
pareto = 0.0
if agree:
self.metrics.record("advantage", rewards[0] - rewards[1])
pareto = self.metrics.record_pareto("pareto", ctxs, rewards)
self.update_agents(agree, rewards, update, logger, forced=forced, pareto=pareto)
self.metrics.record("time")
self.metrics.record("dialog_len", len(self.conv))
self.metrics.record("agree", int(agree))
self.metrics.record("comb_rew", np.sum(rewards) if agree else 0)
for agent, reward in zip(self.agents, rewards):
self.metrics.record("%s_rew" % agent.name, reward if agree else 0)
self.metrics.record_end_of_dialogue("%s_diversity" % agent.name)
return agree, rewards
utils.agent.py
def update(self, agree: bool, reward: float, partner_reward: float = None, pareto=None):
if not self.train:
return
self.t += 1
if len(self.logprobs) == 0:
return
r = self.rewarder.calc_reward(agree, reward, partner_reward)
# compute accumulated discounted reward
g = Variable(torch.zeros(1, 1).fill_(r)).to(device)
rewards = []
for _ in self.logprobs:
rewards.insert(0, g)
g = g * self.args.gamma
loss = 0
# estimate the loss using one MonteCarlo rollout
for lp, r in zip(self.logprobs, rewards):
loss -= lp * r
self.opt.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), self.args.rl_clip)
# if self.args.visual and self.t % 10 == 0:
# self.model_plot.update(self.t)
# self.reward_plot.update('reward', self.t, reward)
# self.loss_plot.update('loss', self.t, loss.item())
# wandb.log({'rl reward': reward, 'rl loss':loss.item()})
self.opt.step()
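The reward list ends up discounted from the last generated act backwards; for example, with r = 10, gamma = 0.95 and three stored logprobs:
r, gamma = 10.0, 0.95
g, rewards = r, []
for _ in range(3):        # one entry per stored logprob
    rewards.insert(0, g)
    g = g * gamma
print(rewards)            # [~9.02, 9.5, 10.0] - the earliest act gets the most discounted return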