Code review of huggingface's prefix_allowed_tokens_fn

이두현 · March 17, 2024

Code review of prefix_allowed_tokens_fn


import pickle
from genre.trie import Trie
from genre.fairseq_model import GENRE

with open("data/kilt_titles_trie_dict.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

model = GENRE.from_pretrained("models/fairseq_entity_disambiguation_aidayago").eval()
sentences = ["Einstein was a [START_ENT] German [END_ENT] physicist."]
import ipdb
ipdb.set_trace()
samples = model.sample(
    sentences,
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
print(samples)
  • This is the starting point
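
The lambda above just asks a trie for the tokens that are allowed to follow the tokens decoded so far. A minimal sketch of that idea (a simplified stand-in, not the actual genre.trie.Trie implementation):

# Toy dict-backed trie: get(prefix) returns the tokens that may come next.
class ToyTrie:
    def __init__(self, sequences):
        self.trie_dict = {}
        for seq in sequences:
            node = self.trie_dict
            for tok in seq:
                node = node.setdefault(tok, {})

    def get(self, prefix):
        node = self.trie_dict
        for tok in prefix:
            if tok not in node:
                return []          # prefix not in the trie -> nothing is allowed
            node = node[tok]
        return list(node.keys())   # allowed next tokens

trie = ToyTrie([[0, 5, 7, 2], [0, 5, 8, 2]])
print(trie.get([0, 5]))  # [7, 8] -> only these two tokens survive at this step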

GENRE/genre/fairseq_model.py

class _GENREHubInterface:
    def sample(
        self,
        sentences: List[str],
        beam: int = 5,
        verbose: bool = False,
        text_to_id=None,
        marginalize=False,
        marginalize_lenpen=0.5,
        max_len_a=1024,
        max_len_b=1024,
        **kwargs,
    ) -> List[str]:
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]

        batched_hypos = self.generate(
            tokenized_sentences,
            beam,
            verbose,
            max_len_a=max_len_a,
            max_len_b=max_len_b,
            **kwargs,
        )
  • **kwargs: this is where prefix_allowed_tokens_fn gets passed along
  • Tokenization happens first
    • Each sentence in the input list is encoded and the resulting tensors are passed along as a list
def generate(self, *args, **kwargs) -> List[List[Dict[str, torch.Tensor]]]:
    return super(BARTHubInterface, self).generate(*args, **kwargs)
  • args: (tokenized_sentences, beam, verbose)
  • kwargs: max_len_a=max_len_a, max_len_b=max_len_b, prefix_allowed_tokens_fn
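
A toy illustration (all names here are made up) of how prefix_allowed_tokens_fn simply rides along in **kwargs from sample() through this thin generate() wrapper into hub_utils.generate():

# sample() -> generate(*args, **kwargs) -> hub_utils.generate(..., **kwargs)
def hub_generate(tokenized, beam=5, **kwargs):
    fn = kwargs.get("prefix_allowed_tokens_fn", None)
    return f"beam={beam}, constrained={fn is not None}"

def sample(sentences, beam=5, **kwargs):
    tokenized = [s.split() for s in sentences]   # stand-in for self.encode
    return hub_generate(tokenized, beam, **kwargs)

print(sample(["Einstein was a physicist."],
             prefix_allowed_tokens_fn=lambda batch_id, sent: [0]))
# beam=5, constrained=True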

This one is not in GENRE but in the fairseq directory: fairseq/fairseq/hub_utils.py

def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            return self.generate(
                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
            )[0]

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                if k != "prefix_allowed_tokens_fn":
                    setattr(gen_args, k, v)
        generator = self.task.build_generator(
            self.models,
            gen_args,
            prefix_allowed_tokens_fn=kwargs.get("prefix_allowed_tokens_fn", None),
        )
        
        inference_step_args = inference_step_args or {}
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for id, hypos in zip(batch["id"].tolist(), translations):
                results.append((id, hypos))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
  • The first if statement is False here (the input is a list of tensors, not a single 1-D tensor)
  • kwargs.items() holds max_len_a, max_len_b, and prefix_allowed_tokens_fn
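
When prefix_allowed_tokens_fn is handed to build_generator, fairseq wires in a prefix-constrained beam search (PrefixConstrainedBeamSearch in fairseq/search.py). A hedged sketch of the core masking idea only, not the actual fairseq code:

import math
import torch

def apply_prefix_mask(lprobs, prev_tokens, batch_ids, prefix_allowed_tokens_fn):
    # lprobs: (num_hyps, vocab_size), prev_tokens: (num_hyps, cur_len)
    mask = torch.full_like(lprobs, -math.inf)
    for hyp, (batch_id, sent) in enumerate(zip(batch_ids, prev_tokens)):
        allowed = prefix_allowed_tokens_fn(batch_id, sent)
        mask[hyp, allowed] = 0.0
    return lprobs + mask   # disallowed tokens get -inf and can never win the beam

lprobs = torch.log_softmax(torch.randn(2, 6), dim=-1)
prev = torch.tensor([[0, 5], [0, 5]])
print(apply_prefix_mask(lprobs, prev, [0, 0], lambda b, sent: [2, 3]))
# only columns 2 and 3 stay finite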

Continuing in the fairseq directory: fairseq/fairseq/tasks/fairseq_task.py

def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, constraints=constraints
            )

fairseq/fairseq/sequence_generator.py

def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text being the character length except EndOfSentence and pad
            src_lengths = (
                (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
            )
        elif "source" in net_input:
            src_tokens = net_input["source"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
            )
        else:
            raise Exception("expected src_tokens or source in net input")
  • Not sure why torch.jit.annotate is used here; presumably it just gives TorchScript a type hint for the empty per-model state dicts so the generator stays scriptable
  • Shape of net_input:

{'src_tokens': tensor([[ 0, 717, 40335, 21, 10, 646, 4014, 11328, 1215, 5382,
742, 1859, 646, 9309, 1215, 5382, 742, 33832, 4, 2]]), 'src_lengths': tensor([20])}

  • src_lengths stores the number of tokens that are neither self.eos nor self.pad
    • Here that is 19 (the last token is the eos token)
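
A quick check of that computation on the src_tokens shown above, assuming fairseq/BART's usual pad index 1 and eos index 2:

import torch

src_tokens = torch.tensor([[0, 717, 40335, 21, 10, 646, 4014, 11328, 1215, 5382,
                            742, 1859, 646, 9309, 1215, 5382, 742, 33832, 4, 2]])
eos, pad = 2, 1
src_lengths = (src_tokens.ne(eos) & src_tokens.ne(pad)).long().sum(dim=1)
print(src_lengths)  # tensor([19]) -> 20 tokens minus the trailing eos
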
        bsz, src_len = src_tokens.size()[:2]
        beam_size = self.beam_size

        if constraints is not None and not self.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.search.init_constraints(constraints, beam_size)

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                self.model.max_decoder_positions() - 1,
            )
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"
        # compute the encoder output for each beam
        encoder_outs = self.model.forward_encoder(net_input)
  • bsz: 1, src_len: 20 (this seems to include the padding and eos tokens)
  • No constraints are given
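
With GENRE's sample() defaults max_len_a = max_len_b = 1024 and src_len = 20, the first term of the min() is huge, so max_len is effectively capped by the decoder position limit. A quick check, assuming BART-large's usual 1024 decoder positions:

max_len_a, max_len_b, src_len = 1024, 1024, 20
max_decoder_positions = 1024   # assumption: BART-large default
max_len = min(int(max_len_a * src_len + max_len_b), max_decoder_positions - 1)
print(max_len)  # 1023
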
        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
        # ensure encoder_outs is a List.
        assert encoder_outs is not None
  • With bsz: 2 and beam_size: 5 (a run with two input sentences):
    • new_order: [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
  • print(encoder_outs[0]['encoder_out'].shape): [22, 2, 1024]
  • After reorder_encoder_out the shape is [22, 10, 1024]
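
Reproducing those numbers in isolation (the encoder output here is random, only the shapes matter; index_select along the batch dimension is effectively what reorder_encoder_out does):

import torch

bsz, beam_size = 2, 5
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
print(new_order)  # tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

encoder_out = torch.randn(22, bsz, 1024)           # [src_len, bsz, hidden]
expanded = encoder_out.index_select(1, new_order)  # one copy per beam hypothesis
print(expanded.shape)                              # torch.Size([22, 10, 1024])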

The part I added myself

probs = F.softmax(full_logits, dim=-1)
for i in range(n):
    if t <= dialogue_lens[i]:
        continue
    else:
        # for the first step after the original length, curr_gen is empty
        # (the empty prefix is handled automatically)
        curr_gen = tokens[i, dialogue_lens[i]:t-1]
        pos_token_list = self.prefix_allowed_tokens(i, curr_gen)

        # no allowed token and not yet terminated -> force the eoa token
        # (zero out everything else, including the padding token)
        if not pos_token_list and termination_mask[i] != 0:
            mask = torch.zeros_like(probs[i, 0, :], device=device)
            probs[i, 0, :] = probs[i, 0, :] * mask
            probs[i, 0, tokenizer.eoa_token_id] = 1 - 1e-7
            # keep the forced eoa probability: skip the generic masking below,
            # which would otherwise zero it out again
            continue

        # always keep the pad token available (the terminated case relies on it)
        pos_token_list += [tokenizer.pad_token_id]
        mask = torch.tensor(
            [token in pos_token_list for token in range(len(probs[i, 0, :]))]
        ).to(device)
        probs[i, 0, :] = probs[i, 0, :] * mask
  1. The paper mentioned that applying this to probabilities rather than to logits works better
  • curr_gen is the list of tokens generated so far, not something that came in as input
  • pos_token_list is the list of tokens allowed at this step
  • When there is no pos_token (and only when the sequence has not terminated), the eoa token is forced
    • The terminated case is already forced to the pad token further up
  • The highest-probability token is then picked among the allowed pos_tokens and the padding token
    • Sequences that have not terminated yet are already prevented from producing the pad token further up
    • Since pos_tokens exist, the sequence must not end yet, so the eoa token cannot be generated...?
      • Everything except pos_token_list + pad is masked out
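
A tiny self-contained check of what the mask guarantees: whatever gets picked from the masked probs is always one of the allowed tokens (the token ids here are made up):

import torch

vocab_size, pad_token_id = 10, 0
probs = torch.rand(vocab_size)                 # stand-in for probs[i, 0, :]
pos_token_list = [3, 7] + [pad_token_id]       # allowed tokens + pad
mask = torch.tensor([tok in pos_token_list for tok in range(vocab_size)])
masked = probs * mask
print(masked.nonzero().flatten())  # tensor([0, 3, 7]) -> only allowed ids remain
print(masked.argmax())             # the chosen token is guaranteed to be allowed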