How to implement Dynamic Masking(feat. RoBERTa)

0

TIL

목록 보기
13/16

RoBERTa 논문에서는 기존에 MLM과는 다른 masking인 dynamic masking을 사용한다고 말한다. 기존 MLM은 계속 동일한 단어를 epoch마다 예측하기에 의미 없는 단어를 계속 masking하고 있을 수 있으며 overfitting이 발생할 수도 있다고 말한다.

그렇기에 RoBERTa 연구진은 데이터가 모델에 들어갈 때마다 매번 masking을 새롭게 해주어 이러한 단점을 보완하고자 했다. RoBERTa의 가장 주요한 논문 내용도 이것이라 볼 수 있다.

HuggingFace가 기본적으로 Dynamic Masking을 사용하기에 편하게 per-train을 시켜줄 수 있지만 어떻게 진행되는지 코드를 뜯어보고자 RoBERTa 깃헙을 뜯어보았다. 설명이 된 것을 찾을 수 없어 본인 방식으로 정리했으므로 틀린 부분에 대한 지적은 언제나 환영합니다.

💡 내용

class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.
    Input items are masked according to the specified masking probability.
    Args:
        dataset: Dataset to wrap.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab
        mask_idx: Id of mask token in vocab
        return_masked_tokens: controls whether to return the non-masked tokens
            (the default) or to return a tensor with the original masked token
            IDs (and *pad_idx* elsewhere). The latter is useful as targets for
            masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        bpe: BPE to use for whole-word masking.
        mask_multiple_length : repeat each mask index multiple times. Default
            value is 1.
        mask_stdev : standard deviation of masks distribution in case of
            multiple masking. Default value is 0.
    """ 

코드의 내용에 대한 간략한 정리가 나와있다.
그 중 특이했던 점은 mask_multiple_length, mask_stdev 이었다.

  • mask_multiple_length : 해당 index를 몇 번이나 반복할 것인지. 여기서 해당 단어가 여러 번 반복되는 것이 아니라 dynamic masking의 기준이 되는 index가 몇 번이나 반복되냐는 것이다.
    • dynamic masking은 그 특성상 한 번의 epoch마다도 masking이 계속 다른 것에 적용이 된다.
  • mask_stdev : mask가 반복되는 횟수에 대해서 얼마나 옮겨갈 것인지에 대한 표준편차. stdev가 설정이 되지 않았다면 계단식으로, stdev가 설정되었다면 반복횟수를 평균으로 mask를 옮기게 된다.
def __getitem_cached__(self, seed: int, epoch: int, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx,
            )

            if self.mask_whole_words is not None:
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

가장 먼저 item은 이번에 반복할 index의 data이다. 즉, 문서 하나로 생각하면 된다. sz는 이것의 길이이다(왜 도대체 sz라고 하는 거지...)

이때에 mask_whole_words라는 것이 등장하는데 이는 BPE로 encoding이 되어 있는 단어들을 BPE 단위가 아니라 단어 하나씩만을 mask시키고 싶은 단어들의 집합이다.

  • 여기에는 New York 등이 들어갈 수 있을 것이다.

이러한 친구들은 특별하게 단어들을 찾아 sz의 길이를 변형시켜준다. 그렇게 그리고 이들을 반영해 masking을 해주도록 설정해준다.

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz / float(self.mask_multiple_length)
                + np.random.rand()
            )

            # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            if self.mask_stdev > 0.0:
                lengths = np.random.normal(
                    self.mask_multiple_length, self.mask_stdev, size=num_mask
                )
                lengths = [max(0, int(round(x))) for x in lengths]
                mask_idc = np.asarray(
                    [
                        mask_idc[j] + offset
                        for j in range(len(mask_idc))
                        for offset in range(lengths[j])
                    ],
                    dtype=np.int64,
                )
            else:
                mask_idc = np.concatenate(
                    [mask_idc + i for i in range(self.mask_multiple_length)]
                )
            mask_idc = mask_idc[mask_idc < len(mask)]
            try:
                mask[mask_idc] = True
            except:  # something wrong
                print(
                    "Assigning mask indexes {} to mask {} failed!".format(
                        mask_idc, mask
                    )
                )
                raise

이 부분이 참 특이했다. 먼저 mask 여부를 저장할 numpy를 만들어준다. 그리고 내가 정해둔 비율에 따라 마스크의 개수를 설정해준다.

그후에 stdev가 존재한다면 mask_multiple_length을 평균으로 mask_stdev을 표준편차로 만들어서 idx에 offset으로 더해준다.

이런 방식을 사용하면 mask_multiple_length을 중심으로 숫자들이 등장하게 되고 mask가 index를 중심으로 움직이게 된다.
만약 존재하지 않는다면 내가 반복할 횟수까지를 for문에 넣어 mask를 계단식으로 옮겨준다.

마지막으로 이들을 반영해 index numpy를 만들고 이들에 따라 masking을 해주는 데이터셋을 만들어준다. 이는 모델에 데이터가 들어갈 때마다 반복적으로 시행이 된다.

✏️ 후기

생각보다 코드가 너무 불친절했다. 또한 왜 mask_multiple_length가 평균으로 설정이 되어 index를 옮겨주고 있는지도 이해가 잘 가지 않았다. random성을 주기 위한 중심을 설정해준 느낌이었는데 그것이 반복되는 횟수와 무슨 관련이 있는지 모르겠다. 이에 대한 notation도 잘 되어 있지 않아 이해를 하는데 한참 걸렸다.

이런 코드를 보니 나도 모델링할 때 이런 방식으로 하면 가독성이 많이 떨어지는구나를 느꼈다. 적어도 data에 readme를 간단하게나마 추가하는 것이 좋을 것 같다는 생각이 들었다.

참고문헌 : https://hryang06.github.io/nlp/BERT/

https://github.com/facebookresearch/fairseq/tree/main/fairseq

profile
프리미어와 IDE만 있다면 무엇이든 만들 수 있어

0개의 댓글