import pickle
from genre.trie import Trie
from genre.fairseq_model import GENRE
with open("data/kilt_titles_trie_dict.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago").eval()

sentences = ["Einstein was a [START_ENT] German [END_ENT] physicist."]

import ipdb
ipdb.set_trace()

samples = model.sample(
    sentences,
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
print(samples)
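For context: `prefix_allowed_tokens_fn` receives the batch index and the tokens decoded so far, and must return the list of token ids that may come next; the KILT title trie answers exactly that query. A minimal sketch of the idea (illustration only — `ToyTrie` below is hypothetical and not the actual `genre.trie.Trie` implementation):

# Hypothetical, simplified prefix trie over token-id sequences.
class ToyTrie:
    def __init__(self, sequences):
        self.root = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def get(self, prefix):
        # Return the token ids that may follow `prefix`; [] if the prefix is unknown.
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return list(node.keys())

# Every candidate entity name is encoded once; during decoding the generator
# only keeps tokens that extend some encoded title.
toy = ToyTrie([[2, 11, 7], [2, 11, 9], [2, 5]])
print(toy.get([2, 11]))  # -> [7, 9]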
GENRE/genre/fairseq_model.py
class _GENREHubInterface:
    def sample(
        self,
        sentences: List[str],
        beam: int = 5,
        verbose: bool = False,
        text_to_id=None,
        marginalize=False,
        marginalize_lenpen=0.5,
        max_len_a=1024,
        max_len_b=1024,
        **kwargs,
    ) -> List[str]:
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        **batched_hypos = self.generate(
            tokenized_sentences,
            beam,
            verbose,
            max_len_a=max_len_a,
            max_len_b=max_len_b,
            **kwargs,
        )**

    def generate(self, *args, **kwargs) -> List[List[Dict[str, torch.Tensor]]]:
        return super(BARTHubInterface, self).generate(*args, **kwargs)
This is not in GENRE but in the fairseq directory: fairseq/fairseq/hub_utils.py
def generate(
    self,
    tokenized_sentences: List[torch.LongTensor],
    beam: int = 5,
    verbose: bool = False,
    skip_invalid_size_inputs=False,
    inference_step_args=None,
    **kwargs
) -> List[List[Dict[str, torch.Tensor]]]:
    if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
        return self.generate(
            tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
        )[0]

    # build generator using current args as well as any kwargs
    gen_args = copy.deepcopy(self.cfg.generation)
    with open_dict(gen_args):
        gen_args.beam = beam
        for k, v in kwargs.items():
            if k != "prefix_allowed_tokens_fn":
                setattr(gen_args, k, v)
    **generator = self.task.build_generator(
        self.models,
        gen_args,
        prefix_allowed_tokens_fn=kwargs.get("prefix_allowed_tokens_fn", None),
    )**

    inference_step_args = inference_step_args or {}
    results = []
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
        batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
        **translations = self.task.inference_step(
            generator, self.models, batch, **inference_step_args
        )**
        for id, hypos in zip(batch["id"].tolist(), translations):
            results.append((id, hypos))

    # sort output to match input order
    outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
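Note how `prefix_allowed_tokens_fn` is deliberately kept out of `gen_args` and handed straight to `task.build_generator`, which wires it into the search strategy: at every decoding step only the token ids the function returns for the current prefix can be selected. A minimal sketch of that mechanism (illustration only, not fairseq's actual constrained-search code; `constrain_step` is a name I made up):

import math
import torch

def constrain_step(lprobs, prefixes, prefix_allowed_tokens_fn):
    # lprobs: (batch, vocab) log-probabilities for the next token
    # prefixes: one tensor of already-generated token ids per batch entry
    mask = torch.full_like(lprobs, -math.inf)
    for batch_id, prefix in enumerate(prefixes):
        allowed = prefix_allowed_tokens_fn(batch_id, prefix)
        mask[batch_id, allowed] = 0.0  # allowed tokens keep their score
    return lprobs + mask  # everything else becomes -inf and is never picked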
Still in the fairseq directory: fairseq/fairseq/tasks/fairseq_task.py
def inference_step(
    self, generator, models, sample, prefix_tokens=None, constraints=None
):
    with torch.no_grad():
        return generator.generate(
            models, sample, prefix_tokens=prefix_tokens, constraints=constraints
        )
fairseq/fairseq/sequence_generator.py
def _generate(
    self,
    sample: Dict[str, Dict[str, Tensor]],
    prefix_tokens: Optional[Tensor] = None,
    constraints: Optional[Tensor] = None,
    bos_token: Optional[int] = None,
):
    incremental_states = torch.jit.annotate(
        List[Dict[str, Dict[str, Optional[Tensor]]]],
        [
            torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
            for i in range(self.model.models_size)
        ],
    )
    net_input = sample["net_input"]

    **if "src_tokens" in net_input:
        src_tokens = net_input["src_tokens"]
        # length of the source text being the character length except EndOfSentence and pad
        src_lengths = (
            (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        )**
    elif "source" in net_input:
        src_tokens = net_input["source"]
        src_lengths = (
            net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
            if net_input["padding_mask"] is not None
            else torch.tensor(src_tokens.size(-1)).to(src_tokens)
        )
    else:
        raise Exception("expected src_tokens or source in net input")
Observed net_input at this point (from the ipdb session):

{'src_tokens': tensor([[    0,   717, 40335,    21,    10,   646,  4014, 11328,  1215,  5382,
                          742,  1859,   646,  9309,  1215,  5382,   742, 33832,     4,     2]]),
 'src_lengths': tensor([20])}
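Side note on the length: assuming the usual fairseq/BART special ids (pad = 1, eos = 2), the recomputation above drops only the trailing eos, so it yields 19 even though the batch's own 'src_lengths' field says 20. A standalone check of just that line:

import torch

src_tokens = torch.tensor([[0, 717, 40335, 21, 10, 646, 4014, 11328, 1215, 5382,
                            742, 1859, 646, 9309, 1215, 5382, 742, 33832, 4, 2]])
pad, eos = 1, 2  # assumed defaults, not read from the actual dictionary
src_lengths = (src_tokens.ne(eos) & src_tokens.ne(pad)).long().sum(dim=1)
print(src_lengths)  # tensor([19]) -- only the final eos token is excluded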
    bsz, src_len = src_tokens.size()[:2]
    beam_size = self.beam_size

    if constraints is not None and not self.search.supports_constraints:
        raise NotImplementedError(
            "Target-side constraints were provided, but search method doesn't support them"
        )

    # Initialize constraints, when active
    self.search.init_constraints(constraints, beam_size)

    max_len: int = -1
    if self.match_source_len:
        max_len = src_lengths.max().item()
    else:
        **max_len = min(
            int(self.max_len_a * src_len + self.max_len_b),
            # exclude the EOS marker
            self.model.max_decoder_positions() - 1,
        )**
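    # Worked example (my annotation, not fairseq code): GENRE's sample() passes
    # max_len_a = max_len_b = 1024 and src_len is 20 for the example sentence, so
    #   int(1024 * 20 + 1024) = 21504,
    # which means max_len collapses to self.model.max_decoder_positions() - 1
    # (1023 if the decoder has the usual 1024 positions -- an assumption, not checked here).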
    assert (
        self.min_len <= max_len
    ), "min_len cannot be larger than max_len, please adjust these!"

    # compute the encoder output for each beam
    encoder_outs = self.model.forward_encoder(net_input)

    # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
    new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
    new_order = new_order.to(src_tokens.device).long()
    encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
    # ensure encoder_outs is a List.
    assert encoder_outs is not None
The part I added (my own code, applying the prefix constraint to the next-token probabilities):
probs = F.softmax(full_logits, dim=-1)
for i in range(n):
    if t <= dialogue_lens[i]:
        continue
    else:
        # At the first step past the original length curr_gen is empty, so every
        # first token allowed by the prefix function is included automatically.
        curr_gen = tokens[i, dialogue_lens[i]:t-1]
        pos_token_list = self.prefix_allowed_tokens(i, curr_gen)
        if not pos_token_list and termination_mask[i] != 0:
            # No allowed continuation but the sequence has not terminated:
            # force the eoa token by zeroing everything else (padding included).
            probs[i, 0, :] = torch.zeros_like(probs[i, 0, :], device=device)
            probs[i, 0, tokenizer.eoa_token_id] = 1 - 1e-7
        else:
            # Keep the pad token available at every step so that already-terminated
            # sequences can keep emitting padding.
            pos_token_list += [tokenizer.pad_token_id]
            mask = torch.tensor(
                [token in pos_token_list for token in range(len(probs[i, 0, :]))]
            ).to(device)
            probs[i, 0, :] = probs[i, 0, :] * mask
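For reference, the same idea stripped of the model-specific bookkeeping (dialogue_lens, termination_mask, eoa): given one probability row and the allowed token ids, zero out everything else and renormalize. The names here (apply_prefix_mask, allowed_ids) are illustrative only:

import torch

def apply_prefix_mask(prob_row: torch.Tensor, allowed_ids, pad_token_id: int) -> torch.Tensor:
    """Keep probability mass only on allowed_ids (plus pad), then renormalize."""
    keep = set(allowed_ids) | {pad_token_id}
    mask = torch.zeros_like(prob_row)
    mask[list(keep)] = 1.0            # index the allowed vocabulary entries directly
    masked = prob_row * mask
    total = masked.sum()
    return masked / total if total > 0 else mask / mask.sum()  # fall back to uniform over allowed ids

# Example: vocab of 6 tokens, trie allows {2, 3}, pad is 0.
row = torch.tensor([0.1, 0.2, 0.3, 0.1, 0.2, 0.1])
print(apply_prefix_mask(row, [2, 3], pad_token_id=0))  # tensor([0.2, 0.0, 0.6, 0.2, 0.0, 0.0])

Indexing the mask tensor with the allowed ids avoids the per-token membership loop over the whole vocabulary that the code above performs for every example at every step.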