diffusion-lm code review

이두현 · March 17, 2024

train_setup

-use_kl False --learn_sigma False

args.experiment : random

args.modality : roc

Model_FILE : diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e

folder name : diffusion_models

Model_FILE:

diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e

improved-diffusion/scripts/run_train.py

with open(Model_FILE + '.sh', 'w') as f:
        print(COMMANDLINE, file=f)

COMMANDLINE :

OPENAI_LOGDIR=diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e TOKENIZERS_PARALLELISM=false python scripts/train.py --checkpoint_path diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e --model_arch transformer --modality roc --save_interval 50000 --lr 0.0001 --batch_size 64 --diffusion_steps 2000 --noise_schedule sqrt --use_kl False --learn_sigma False --image_size 8 --num_channels 128 --seed 101 --dropout 0.1 --in_channel 128 --out_channel 128 --padding_mode pad --experiment random --lr_anneal_steps 400000 --weight_decay 0.0 --num_res_blocks 2 --predict_xstart True --training_mode e2e --vocab_size 11043 --roc_train ../datasets/ROCstory

  • Run the command above from the improved-diffusion directory

improved-diffusion/improved_diffusion/script_util.py

def create_model_and_diffusion():
model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        model_arch=model_arch,
        in_channel=in_channel,
        out_channel=out_channel,
        training_mode=training_mode,
        vocab_size=vocab_size,
        config_name=config_name,
        experiment_mode=experiment_mode,
        logits_mode=logits_mode,
    )
def create_model():
elif model_arch == 'transformer':
        if image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        elif image_size == 32:
            channel_mult = (1, 2, 2, 2)
        elif image_size == 16:  # DEBUG**
            channel_mult = (1, 2, 2, 2)
        else:
            channel_mult = (1, 2, 2, 2)

        attention_ds = []
        for res in attention_resolutions.split(","):
            attention_ds.append(image_size // int(res))

        return TransformerNetModel2(
            in_channels=in_channel,  # 3, DEBUG**
            model_channels=num_channels,
            out_channels=(out_channel if not learn_sigma else out_channel*2),  # DEBUG**  (3 if not learn_sigma else 6),
            num_res_blocks=num_res_blocks,
            attention_resolutions=tuple(attention_ds),
            dropout=dropout,
            channel_mult=channel_mult,
            num_classes=(NUM_CLASSES if class_cond else None),
            use_checkpoint=use_checkpoint,
            num_heads=num_heads,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            config_name=config_name,
            training_mode=training_mode,
            vocab_size=vocab_size,
            experiment_mode=experiment_mode,
            logits_mode=logits_mode,
        )
    else:
        raise NotImplementedError
  • image_size : 8
    • channel_mult : (1, 2, 2, 2)
  • attention_resolutions : 16,8
  • attention_ds : [0, 1]

improved-diffusion/improved_diffusion/transformer_model2.py

class TransformerNetModel2(nn.Module): # line 674
     if num_heads_upsample == -1: # True
            num_heads_upsample = num_heads
     if training_mode == 'e2e':
            self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
            if self.logits_mode == 2:
                # self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=False)
                self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=True)

            **else:
                self.lm_head = nn.Linear(self.in_channels, vocab_size)**
            with th.no_grad():
                self.lm_head.weight = self.word_embedding.weight

       if experiment_mode == 'conditional_gen':
            self.conditional_gen = True
            self.encoder_emb = nn.Embedding(vocab_size, config.hidden_size)
            self.encoder = BertEncoder(config)
            print(config, 'conditional_gen')
            config.is_decoder = True
            config.add_cross_attention = True
        **elif experiment_mode == 'lm':
            self.conditional_gen = False**
  • num_heads_upsample = num_heads = 4
  • training_mode is 'e2e' and logits_mode = 1
  • self.word_embedding : maps each input token id to a 128-dim vector
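A minimal sketch of the embedding / lm_head weight tying in the excerpt above (assuming vocab_size = 11043 and in_channels = 128 from this run):

```python
import torch as th
import torch.nn as nn

vocab_size, in_channels = 11043, 128            # values taken from the command line above
word_embedding = nn.Embedding(vocab_size, in_channels)
lm_head = nn.Linear(in_channels, vocab_size)
with th.no_grad():
    lm_head.weight = word_embedding.weight      # the same (vocab_size, 128) matrix is shared

tokens = th.randint(0, vocab_size, (64, 64))    # (batch, seq_len)
emb = word_embedding(tokens)                    # (64, 64, 128)
logits = lm_head(emb)                           # (64, 64, 11043)
```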
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, config.hidden_size),
        )
	      self.input_up_proj = nn.Sequential(nn.Linear(in_channels, config.hidden_size),
                       nn.Tanh(), nn.Linear(config.hidden_size, config.hidden_size))
				if init_pretrained:
            from transformers.models.bert.modeling_bert import BertModel
            temp_bert = BertModel.from_pretrained(config_name, config=config)
            del temp_bert.embeddings
            del temp_bert.pooler
            self.input_transformers = temp_bert.encoder
            print('initializing from pretrained bert.')
        **else:
            print(config)
            self.input_transformers = BertEncoder(config)**

from transformers.models.bert.modeling_bert import BertEncoder
  • model_channels : 128, || config.hidden_size : 768
  • input_up_proj : takes the 128-dim token embeddings and projects them up to hidden_size before they enter the BERT encoder
  • BertEncoder is the middle stage of the input_up_proj → BertEncoder → output_down_proj pipeline
# config 
BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.35.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
				self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # config2 = config
        # config2.hidden_size = 2 * config.hidden_size
        # self.output_transformers = BertEncoder(config)

        self.output_down_proj = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                                              nn.Tanh(), nn.Linear(config.hidden_size, out_channels))
  • config.max_position_embeddings : 512
  • print(torch.arange(config.max_position_embeddings).expand((1, -1)).shape) : [1, 512]
  • config.layer_norm_eps : 1e-12
  • out_channels : 128
  • self.output_down_proj is the layer that maps the transformer output back down to out_channels
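Putting the pieces together, a hedged sketch of the up-proj → BertEncoder → down-proj path described above (shapes only; position/time embeddings omitted, and the variable names here are mine):

```python
import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers.models.bert.modeling_bert import BertEncoder

config = AutoConfig.from_pretrained('bert-base-uncased')    # hidden_size = 768
in_channels, out_channels = 128, 128

input_up_proj = nn.Sequential(nn.Linear(in_channels, config.hidden_size),
                              nn.Tanh(), nn.Linear(config.hidden_size, config.hidden_size))
encoder = BertEncoder(config)                                # 12-layer BERT encoder stack
output_down_proj = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                                 nn.Tanh(), nn.Linear(config.hidden_size, out_channels))

x = torch.randn(64, 64, in_channels)                         # (batch, seq_len, emb_dim)
h = input_up_proj(x)                                         # (64, 64, 768)
h = encoder(h).last_hidden_state                             # (64, 64, 768)
y = output_down_proj(h)                                      # (64, 64, 128)
```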

improved-diffusion/improved_diffusion/script_util.py

def create_model_and_diffusion():
diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        sigma_small=sigma_small,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
        model_arch=model_arch,
        training_mode=training_mode,
    )

improved-diffusion/improved_diffusion/script_util.py

def create_gaussian_diffusion():
betas = gd.get_named_beta_schedule(noise_schedule, steps)
  • noise_schedule : sqrt
  • betas is returned as an np.array

improved-diffusion/improved_diffusion/gaussian_diffusion.py

def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
elif schedule_name == 'sqrt':
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: 1-np.sqrt(t + 0.0001),
        )
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
  • steps : 2000, betas shape : (2000,)
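A small standalone sketch of the sqrt schedule (same discretization as the excerpt, numpy only); the first and last values match the betas array printed further down:

```python
import numpy as np

def betas_for_alpha_bar(num_steps, alpha_bar, max_beta=0.999):
    # discretize alpha_bar(t) on t in [0, 1] and convert to per-step betas
    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

betas = betas_for_alpha_bar(2000, lambda t: 1 - np.sqrt(t + 0.0001))  # the 'sqrt' schedule
print(betas.shape)              # (2000,)
print(betas[:3], betas[-1])     # [0.01464131 0.00888909 0.00706818] ... 0.999
```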

improved-diffusion/improved_diffusion/script_util.py

def create_model_and_diffusion():
	   betas = gd.get_named_beta_schedule(noise_schedule, steps)
	   if training_mode == 'e2e':
        # end to end training
        if use_kl:
            loss_type = gd.LossType.E2E_KL
        **else:
            loss_type = gd.LossType.E2E_MSE**
	   **if not timestep_respacing:
        timestep_respacing = [steps]**
		 
		 return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        model_arch=model_arch,
        training_mode=training_mode,
    )
  • training_mode : e2e
  • use_kl : false
  • loss_type : LossType.E2E_MSE
  • timestep_respacing is empty (falsy), so it becomes [2000]

improved-diffusion/improved_diffusion/respace.py

def space_timesteps(num_timesteps, section_counts):
	  size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
  • num_timesteps : 2000, section_counts : [2000]
  • with the current settings this is simply {0, 1, 2, …, 1999}
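A quick check of that claim (assuming the improved_diffusion package is importable, e.g. run from inside the improved-diffusion directory):

```python
# space_timesteps(2000, [2000]) keeps every original timestep
from improved_diffusion.respace import space_timesteps

kept = space_timesteps(2000, [2000])
assert kept == set(range(2000))
print(len(kept))   # 2000
```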

improved-diffusion/improved_diffusion/script_util.py

def create_model_and_diffusion():
		return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else **gd.ModelMeanType.START_X**
        ),
        model_var_type=(
            (
                **gd.ModelVarType.FIXED_LARGE**
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        model_arch=model_arch,
        training_mode=training_mode,
    )
  • use_timesteps : {0, 1, …, 1999}
  • predict_xstart : True || sigma_small : False || learn_sigma : False
  • model_mean_type : gd.ModelMeanType.START_X
  • model_var_type : gd.ModelVarType.FIXED_LARGE
  • loss_type : LossType.E2E_MSE

improved-diffusion/improved_diffusion/respace.py

class SpacedDiffusion(GaussianDiffusion):
			def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # print(kwargs.keys())
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
  • self.original_num_steps : 2000
  • kwargs :

{'betas': array([0.01464131, 0.00888909, 0.00706818, ..., 0.35722328, 0.55561113,
0.999 ]), 'model_mean_type': <ModelMeanType.START_X: 2>, 'model_var_type': <ModelVarType.FIXED_LARGE: 3>, 'loss_type': <LossType.E2E_MSE: 6>, 'rescale_timesteps': True, 'model_arch': 'transformer', 'training_mode': 'e2e'}

improved-diffusion/improved_diffusion/gaussian_diffusion.py

class GaussianDiffusion:
	self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps
        self.model_arch=model_arch

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

				alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )
  • mean_type : ModelMeanType.START_X || var_type : ModelVarType.FIXED_LARGE
  • rescale_timesteps : True
  • num_timesteps : 2000

improved-diffusion/improved_diffusion/respace.py

class SpacedDiffusion(GaussianDiffusion):
			def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # print(kwargs.keys())
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
	      for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)
  • new_betas will also end up with length 2000

create_model_and_diffusion done


improved-diffusion/scripts/train.py

model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
  • name : uniform
def create_named_schedule_sampler(name, diffusion):
    if name == "uniform":
        return UniformSampler(diffusion)

class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights
  • diffusion.num_timesteps : 2000

improved-diffusion/scripts/train.py

else:
        print('load data', '*'*50)
        if args.modality == 'roc-aug' or args.modality == 'commonGen-aug':
            tokenizer = load_tokenizer(args.modality, args.experiment, 'predictability/diffusion_models_v7/diff_roc_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart')
            rev_tokenizer = {v: k for k, v in tokenizer.items()}
            print(len(rev_tokenizer), 'loading from tokenizer. ')
        elif args.use_bert_tokenizer == 'yes':
            rev_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        else:
            **rev_tokenizer = None**

if args.experiment == 'random1':
            args.experiment = 'random'
            print('loading from the vocabs here.')
            assert args.in_channel == 64
            assert args.modality == 'roc'
            model22 = torch.nn.Embedding(args.vocab_size, args.in_channel)
            model22_weight = torch.load('predictability/diffusion_models_v7/diff_roc-aug_pad_rand64_'
                                        'transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e/'
                                        'ema_0.9999_200000.pt', map_location='cpu')['word_embedding.weight']
            model22.weight = model22_weight
            model22.weight.requires_grad=False
        **else:
            model22 = None**

data = load_data_text(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=args.class_cond,
            data_args = args,
            task_mode=args.modality,
            padding_mode=args.padding_mode, #block, pad
            load_vocab=rev_tokenizer,
            model=model22,
        )

improved-diffusion/improved_diffusion/text_datasets.py

def load_data_text(
    *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None, 
        task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
):

**if data_args.experiment.startswith('random') and model is None:
        model = None**
    elif data_args.experiment.startswith('random') and model is not None:
        print('loading initialized random embeddings. ')

**if task_mode == 'roc' or task_mode == 'roc-aug' :
        training_data, model = get_corpus_rocstory(data_args, model, image_size,
                                            padding_mode=padding_mode, split=split,
                                            load_vocab=load_vocab)**
def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
                        split='train', load_vocab=None):
    import csv, torch, json
    from spacy.lang.en import English

		if data_args.experiment_mode == 'lm':
        if data_args.modality == 'roc':
            print('loading dataset from ROCStory')
            nlp = English()
            tokenizer = nlp.tokenizer
            sentence_lst = []
            print(f'loading from {data_args.roc_train}')
            **if split == 'train':
                print('loading form the TRAIN set')
                path = f'{data_args.roc_train}/roc_train.json'**
            elif split == 'valid':
                print('loading form the VALID set')
                path = f'{data_args.roc_train}/roc_valid.json'
            else:
                assert False, "invalid split for ROC dataset"

	      with open(path, 'r') as roc_reader:
                for row in roc_reader:
                    sentences = json.loads(row)[0].strip()
                    word_lst = [x.text for x in tokenizer(sentences)]
                    sentence_lst.append(word_lst)

				# get tokenizer.
        if load_vocab is None:
            counter = Counter()
            for input_ids in sentence_lst:
                counter.update(input_ids)

				if load_vocab is None:
        vocab_dict = {'START': 0, 'END': 1, 'UNK':2, 'PAD':3}
        for k, v in counter.items():
            if v > 10:
                vocab_dict[k] = len(vocab_dict)
        print(len(counter), len(vocab_dict))

        path_save_vocab = f'{data_args.checkpoint_path}/vocab.json'
        print(f'save the vocab to {path_save_vocab}')
        with open(path_save_vocab, 'w') as f:
            json.dump(vocab_dict, f)
  • Depending on the split, the tokenized sentences are stored in sentence_lst as a list of lists
  • Words that appear more than 10 times get the next index after the predefined 0, 1, 2, 3
    • e.g. {'START': 0, 'END': 1, 'UNK': 2, 'PAD': 3, 'Brad': 4}
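A toy version of the vocab construction above (hypothetical sentences in place of ROCStories):

```python
from collections import Counter

sentence_lst = [["brad", "went", "home"], ["brad", "went", "out"]] * 11   # toy corpus
counter = Counter()
for word_lst in sentence_lst:
    counter.update(word_lst)

vocab_dict = {'START': 0, 'END': 1, 'UNK': 2, 'PAD': 3}
for k, v in counter.items():
    if v > 10:                          # only words seen more than 10 times get an id
        vocab_dict[k] = len(vocab_dict)
print(vocab_dict)                       # {'START': 0, ..., 'PAD': 3, 'brad': 4, 'went': 5, ...}
```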
if model is None and data_args.experiment == 'random':
        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print('initializing the random embeddings', model)
        torch.nn.init.normal_(model.weight)
        path_save = f'{data_args.checkpoint_path}/random_emb.torch'
        print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
        torch.save(model.state_dict(), path_save)

if data_args.experiment_mode == 'lm' and data_args.modality in ['roc-aug', 'roc', 'yelp', 'commonGen', 'commonGen-aug'] \
            and data_args.cache_mode=='no':
        train_dataset = helper_tokenize_stream(sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode)
        return train_dataset, model
  • model maps each token to a 128-dim vector
  • skipping the train_dataset construction (helper_tokenize_stream) for now

Back to load_data_text

def load_data_text(
    *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None, 
        task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
):

if data_args.modality in ['roc-aug', 'roc', 'book', 'yelp', 'commonGen', 'commonGen-aug'] and data_args.cache_mode=='no':
        dataset = TextDataset_NoCache(
            training_data,
            image_size,
            data_args,
            model_arch=data_args.model_arch,
            model_emb=model
        )

    if deterministic:

        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,  # 20,
            drop_last=True,
            shuffle=False,
            num_workers=1,
        )

    **else:
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,  # 20,
            drop_last=True,
            shuffle=True,
            num_workers=1,
        )**
    while True:
        yield from data_loader
  • no embedding module is actually applied here; the DataLoader just wraps the dataset

improved-diffusion/scripts/train.py

model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
                                        args.checkpoint_path, extra_args=args)
        if args.modality == 'book' or args.use_bert_tokenizer == 'yes':
            rev_tokenizer = tokenizer # BERT tokenizer BPE.
        else:
            **rev_tokenizer = {v: k for k, v in tokenizer.items()}**
  • rev_tokenizer is the inverted tokenizer dict

improved-diffusion/improved_diffusion/rounding.py

def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):

    if mode in ['random', 'random1', 'random_up_proj', 'glove']:
        if modality == 'synth':
            ...  # synth branch omitted in this excerpt
        else:
            import json
            if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
                tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
                if 'e2e' in file and modality == 'book':
                    emb_dim = 1
	          **else:
                path_save_tokenizer = '{}/vocab.json'.format(file)
                print(f'loading from {path_save_tokenizer}')
                with open(path_save_tokenizer, 'r') as f:
                    vocab = json.load(f)
                print(len(vocab))
                tokenizer = {v: k for k, v in vocab.items()}**
            model = torch.nn.Embedding(len(tokenizer), emb_dim)
            path_save = '{}/random_emb.torch'.format(file)
            model.load_state_dict(torch.load(path_save))

    return model, tokenizer
  • reload the vocab (tokenizer) that was saved earlier
  • model : the module that embeds a word into 128 dims

improved-diffusion/scripts/train.py

data_valid = load_data_text(
            data_dir=args.data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=args.class_cond,
            data_args=args,
            task_mode=args.modality,
            padding_mode=args.padding_mode,  # block, pad
            split='valid',
            load_vocab=rev_tokenizer,
            model=model2,
        )

    # dist.barrier()
    # import time
    # while not os.path.exists(os.path.join(args.checkpoint_path, 'vocab.json')):
    #     time.sleep(1)
    def get_mapping_func(args, diffusion, data):
        model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
                                        args.checkpoint_path, extra_args=args)
        model3 = get_weights(model2, args)
        print(model3, model3.weight.requires_grad)
        mapping_func = partial(compute_logp, args, model3.cuda())
        diffusion.mapping_func = mapping_func
        return mapping_func

    get_mapping_func(args, diffusion, data)

    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=data_valid,
        eval_interval=args.eval_interval
    ).run_loop()
  • not yet sure what get_mapping_func is for

improved-diffusion/improved_diffusion/train_util.py

class TrainLoop:
def __init__():
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.eval_data = eval_data
        self.batch_size = batch_size # 64
        self.microbatch = microbatch if microbatch > 0 else batch_size # 64
        self.lr = lr # 0.0001
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        ) # [0.9999]
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16 # False
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or **UniformSampler(diffusion)**
        self.weight_decay = weight_decay # 0.0
        self.lr_anneal_steps = lr_anneal_steps # 400000
        self.gradient_clipping = gradient_clipping # -1.0

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size() # 64

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self.checkpoint_path = checkpoint_path # DEBUG **
  • self.model includes both the word_embedding and the BERT-based network
  • dist.get_world_size() : 1
			  self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        **else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]**

        if th.cuda.is_available(): # DEBUG **
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
	            self.ddp_model = self.model
  • ._load_and_sync_parameters() : resume from a previous training run if there is one
  • use_fp16 : False
  • master_params : self.model.parameters() converted to a list
    • ema_params also holds a copy of these parameters
def run_loop(self):
	while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            **self.run_step(batch, cond)**

def run_step(self, batch, cond):
        **self.forward_backward(batch, cond)**
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()
  • batch : [64, 64, 128] || cond[’input_ids’] : [64, 64]

What goes into batch and cond?

improved-diffusion/improved_diffusion/text_datasets.py

def __getitem__(self, idx):

        # We are not on a new enough PIL to support the `reducing_gap`
        # argument, which uses BOX downsampling at powers of two first.
        # Thus, we do it by hand to improve downsample quality.
        with torch.no_grad():
            input_ids = self.text_datasets['train'][idx]['input_ids']
            model = self.model_emb
            **if self.data_args.experiment.startswith('random'):**
                **hidden_state = model(torch.tensor(input_ids))**
            elif self.data_args.experiment == 'gpt2_pre_compress':
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor

            if self.model_arch == 'conv-unet':
                arr = np.array(hidden_state,
                               dtype=np.float32).reshape(self.resolution, self.resolution, -1)
                # print(self.eigen_transform.shape)
                if self.eigen_transform is not None:
                    old_shape = arr.shape
                    arr = arr.reshape(1, -1) - self.eigen_transform['mean']
                    arr = arr @ self.eigen_transform['map']
                    arr = arr.reshape(old_shape)
                if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
                    arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)

                out_dict = {}
                out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
                # if self.local_classes is not None:
                #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                # print(out_dict.keys())
                return np.transpose(arr, [2, 0, 1]), out_dict
            elif self.model_arch == '1d-unet':
                arr = np.array(hidden_state,
                               dtype=np.float32)  # seqlen, dim
                if self.eigen_transform is not None:
                    old_shape = arr.shape
                    arr = arr.reshape(1, -1) - self.eigen_transform['mean']
                    arr = arr @ self.eigen_transform['map']
                    arr = arr.reshape(old_shape)
                if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
                    arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
                arr = np.transpose(arr, [1, 0])
                out_dict = {}
                out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
                # out_dict['mapping_func'] = self.mapping_func
                # if self.local_classes is not None:
                #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                # print(arr.shape)
                return arr, out_dict
            **else:
                arr = np.array(hidden_state,
                               dtype=np.float32)
                if self.eigen_transform is not None:
                    old_shape = arr.shape
                    # arr = arr.reshape(1, -1) @ self.eigen_transform
                    arr = arr.reshape(1, -1) - self.eigen_transform['mean']
                    arr = arr @ self.eigen_transform['map']
                    arr = arr.reshape(old_shape)**

                if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
                    # print(arr.dtype)
                    # print(self.data_args.noise_level, 'using the noise level.')
                    arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
                    # print(arr.dtype)

                out_dict = {}
                out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
                # out_dict['mapping_func'] = self.mapping_func
                if self.data_args.experiment_mode == 'conditional_gen':
                    out_dict['src_ids'] = np.array(self.text_datasets['train'][idx]['src_ids'])
                    out_dict['src_mask'] = np.array(self.text_datasets['train'][idx]['src_mask'])
                # if self.local_classes is not None:
                #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                return arr, out_dict
  • What is handed over as batch:
    • word id → embedding layer → (optional) eigen_transform → returned
    • cond → the token ids that went into input_ids are returned
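A hedged sketch of what a single __getitem__ returns in the transformer case (no eigen_transform, no noise_level; the vocab size and dims are placeholders from this run):

```python
import numpy as np
import torch

model_emb = torch.nn.Embedding(11043, 128)               # stands in for the saved random embedding
input_ids = np.random.randint(0, 11043, size=(64,))      # one sequence of 64 token ids
with torch.no_grad():
    hidden_state = model_emb(torch.tensor(input_ids))
arr = np.array(hidden_state, dtype=np.float32)           # (64, 128) -> one element of `batch`
out_dict = {'input_ids': np.array(input_ids)}            # (64,)     -> one element of `cond`
print(arr.shape, out_dict['input_ids'].shape)            # (64, 128) (64,)
```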

improved-diffusion/improved_diffusion/train_util.py

def run_step(self, batch, cond):
        **self.forward_backward(batch, cond)**
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

def forward_backward(self, batch, cond):
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
            # print(micro_cond.keys())
            **compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )**

            **if last_batch or not self.use_ddp:
                losses = compute_losses()**
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()
  • microbatch is 64, so the original batch is used as-is
  • cond and micro_cond contain only input_ids (token indices)
  • the sampler is the uniform sampler (improved-diffusion/improved_diffusion/resample.py), so the sampled indices t are a (batch,)-sized sample of diffusion timesteps and weights is filled with 1.0 of the same shape

improved-diffusion/improved_diffusion/respace.py

def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        # print('called training_losses')
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

improved-diffusion/improved_diffusion/gaussian_diffusion.py

def training_losses(self, model, *args, **kwargs):
        **if self.training_mode == 'e2e':
            return self.training_losses_e2e(model, *args, **kwargs)**
        elif self.training_mode == 'e2e-simple':
            return self.training_losses_e2e_simple(model, *args, **kwargs)
        else:
            return self.training_losses_emb(model, *args, **kwargs)

def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
	   assert 'input_ids' in model_kwargs
        input_ids = model_kwargs.pop('input_ids').to(t.device)
        x_start_mean = model.model.module.get_embeds(input_ids)
		    std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
                                   th.tensor([0]).to(x_start_mean.device),
                                   x_start_mean.shape)
  • input_ids shape : [64,64]
  • in get_embeds in improved-diffusion/improved_diffusion/transformer_model2.py
    • the value comes out after passing through self.word_embedding : nn.Embedding(vocab_size, self.in_channels)
  • x_start_mean.shape : [64, 64, 128] (batch, horizon, emb dimension)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
  • res starts as the single value of self.sqrt_one_minus_alphas_cumprod corresponding to the given timestep
  • since this is the starting step, a 0 (already moved to the device) is passed as timesteps
  • a tensor of the same shape as x_start_mean, filled by broadcasting that single value, is returned
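A minimal reproduction of that broadcast (a stand-in array instead of the real schedule):

```python
import numpy as np
import torch as th

arr = np.linspace(0.1, 0.9, 2000)        # stand-in for self.sqrt_one_minus_alphas_cumprod
timesteps = th.tensor([0])               # t = 0, as passed in the excerpt above

res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()   # shape (1,)
while len(res.shape) < 3:
    res = res[..., None]                 # (1, 1, 1)
std = res.expand(64, 64, 128)            # every entry equals arr[0]
print(std.shape)                         # torch.Size([64, 64, 128])
```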
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
	   assert 'input_ids' in model_kwargs
        input_ids = model_kwargs.pop('input_ids').to(t.device)
        x_start_mean = model.model.module.get_embeds(input_ids)
		    std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
                                   th.tensor([0]).to(x_start_mean.device),
                                   x_start_mean.shape)

				x_start_log_var = 2 * th.log(std)
        **x_start = self.get_x_start(x_start_mean, std)**
def get_x_start(self, x_start_mean, std):
        noise = th.randn_like(x_start_mean)
        # print(std.shape, noise.shape, x_start_mean.shape)
        assert noise.shape == x_start_mean.shape
        # print(x_start_mean.device, noise.device)
        return (
             x_start_mean + std * noise
        )
  • a noise tensor with the same shape as the std computed above is drawn, scaled by std, and added to the original value
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
	   assert 'input_ids' in model_kwargs
        input_ids = model_kwargs.pop('input_ids').to(t.device)
        x_start_mean = model.model.module.get_embeds(input_ids)
		    std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
                                   th.tensor([0]).to(x_start_mean.device),
                                   x_start_mean.shape)

				x_start_log_var = 2 * th.log(std)
        x_start = self.get_x_start(x_start_mean, std)
				**if noise is None:
            noise = th.randn_like(x_start)**
        **x_t = self.q_sample(x_start, t, noise=noise) # reparametrization trick.**
        get_logits = model.model.module.get_logits
def q_sample(self, x_start, t, noise=None):
        """.
        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )
  • Docstring: x_start means x_0 || t is the t in q(x_t | x_0)
    • return value : a noised version of x_start
  • _extract_into_tensor, covered earlier, broadcasts constants stored as numpy arrays (such as self.sqrt_alphas_cumprod) to the shape of x_start so the computation happens element-wise

(figure omitted: the closed form q(x_t | x_0) = N(x_t; √ᾱ_t · x_0, (1 − ᾱ_t)·I) that q_sample draws from)
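A self-contained sketch of what q_sample computes (a stand-in linear schedule just to show shapes; the real run uses the sqrt schedule):

```python
import numpy as np
import torch as th

betas = np.linspace(1e-4, 0.02, 2000)                    # stand-in schedule for illustration
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_ac = th.from_numpy(np.sqrt(alphas_cumprod)).float()
sqrt_1m_ac = th.from_numpy(np.sqrt(1.0 - alphas_cumprod)).float()

x_start = th.randn(64, 64, 128)                          # (batch, seq_len, emb_dim)
t = th.randint(0, 2000, (64,))
noise = th.randn_like(x_start)
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, broadcast over (seq_len, emb_dim)
x_t = sqrt_ac[t].view(-1, 1, 1) * x_start + sqrt_1m_ac[t].view(-1, 1, 1) * noise
print(x_t.shape)                                         # torch.Size([64, 64, 128])
```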

def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
	    get_logits = model.model.module.get_logits

			elif self.loss_type == LossType.E2E_MSE or self.loss_type == LossType.E2E_RESCALED_MSE:
            # print(x_t.shape)
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
  • get_logits appears to grab the get_logits function from transformer_model2.py
  • calling model does not hit forward directly; the __call__ of the wrapper in respace.py runs first

improved-diffusion/improved_diffusion/respace.py

class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        # print(ts)
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        # print(new_ts)
        **if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)**
        # temp = self.model(x, new_ts, **kwargs)
        # print(temp.shape)
        # return temp
        # print(new_ts)
        return self.model(x, new_ts, **kwargs)
  • passing new_ts through map_tensor changes nothing here (map_tensor is just [0, …, 1999])
  • rescale_timesteps is True and self.original_num_steps = 2000, so the original timesteps are multiplied by 1000/2000 = 1/2
    • why…? presumably so the model always sees timesteps on a fixed 0–1000 scale regardless of how many diffusion steps are used
  • self.model is the TransformerNetModel2, and its forward is what actually runs

input : x of shape [batch_dim, horizon_len, vector_dim], t of shape [batch_dim] (but scaled by 1/2)

output : [batch_dim, horizon_len, vector_dim], same as the input
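A tiny check of the rescaling (the map through timestep_map is the identity here):

```python
import torch as th

original_num_steps = 2000
timestep_map = list(range(original_num_steps))           # no respacing -> identity map
ts = th.tensor([0, 1000, 1999])
new_ts = th.tensor(timestep_map)[ts].float() * (1000.0 / original_num_steps)
print(new_ts)                                            # tensor([  0.0000, 500.0000, 999.5000])
```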

def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
	    get_logits = model.model.module.get_logits

			elif self.loss_type == LossType.E2E_MSE or self.loss_type == LossType.E2E_RESCALED_MSE:
            # print(x_t.shape)
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)

						target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            # print( terms["mse"])
            model_out_x_start = self.x0_helper(model_output, x_t, t)['pred_xstart']
            t0_mask = (t == 0)
            t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
            # print(terms["mse"].shape, )
            terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])
  • The model output is a prediction of one of the targets in the dict above (the posterior mean of x_{t−1}, x_0, or the noise ε), depending on model_mean_type
  • since model_mean_type is START_X, the target is x_start
  • the model is said to predict x_0 directly, and the simplified loss the paper uses is (roughly) L_simple(x_0) = Σ_t E‖f_θ(x_t, t) − x_0‖², which is exactly the mse term above (the equation image is omitted here)

  • mean_flat takes the mean over every dimension except the batch dimension
    • terms['mse'] : shape [batch_size,]
    • defined in improved-diffusion/improved_diffusion/nn.py
  • when the sampled timestep t is 0 the loss is t0_loss, otherwise it is the mse computed above
    • x_start_mean is the pure token embedding before any std noise is mixed in
    • model_out_x_start after x0_helper is just model_output
def x0_helper(self, model_output, x, t):
        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart =  self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            pred_prev = model_output

        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            **if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = model_output**
            else:
                pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            **pred_prev, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )**

        else:
            raise NotImplementedError(self.model_mean_type)
        return {'pred_xprev':pred_prev, 'pred_xstart':pred_xstart}
  • the pred_xstart returned here is what gets used, and it is just model_output unchanged

Summary) when t == 0, the mse is between the 128-dim embedding of the token ids (x_start_mean) and model_output

when t ≠ 0, the mse is between the std-noised 128-dim embedding (x_start) and model_output

def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
	    get_logits = model.model.module.get_logits

			elif self.loss_type == LossType.E2E_MSE or self.loss_type == LossType.E2E_RESCALED_MSE:
            # print(x_t.shape)
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)

						target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            # print( terms["mse"])
            model_out_x_start = self.x0_helper(model_output, x_t, t)['pred_xstart']
            t0_mask = (t == 0)
            t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
            # print(terms["mse"].shape, )
            terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])
						
						out_mean, _, _ = self.q_mean_variance(x_start, th.LongTensor([self.num_timesteps - 1]).to(x_start.device))
            tT_loss =  mean_flat(out_mean ** 2)

            decoder_nll = self.token_discrete_loss(x_start, get_logits, input_ids)

            # assert (model.lm_head.weight == model.word_embedding.weight).all()

            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                # KEY
                terms["loss"] = terms["mse"] + (decoder_nll + tT_loss)
                # terms["loss"] = terms["mse"] + (1.0/self.num_timesteps) * decoder_nll + \
                #                 (1.0/self.num_timesteps) * tT_loss
        else:
            raise NotImplementedError(self.loss_type)

        return terms
  • here q_mean_variance builds the mean and variance of q(x_t | x_0) in closed form (equation image omitted)
  • whereas the q_sample function used earlier, unlike the other code, skips the explicit mean/var and samples directly
  • tT_loss, the term on the predicted x_T, seems to be included so that out_mean, which should end up as pure Gaussian noise, is driven toward zero
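Putting the three terms together, a hedged sketch of the final e2e loss as the excerpt assembles it (helper names mirror the code; model_out_x_start equals model_output here because the model predicts x_0):

```python
import torch as th

def mean_flat(x):
    # mean over every dimension except the batch dimension
    return x.mean(dim=list(range(1, len(x.shape))))

def e2e_loss(model_output, x_start, x_start_mean, out_mean, logits, input_ids, t):
    mse = mean_flat((x_start - model_output) ** 2)            # target is x_0 (START_X)
    t0_loss = mean_flat((x_start_mean - model_output) ** 2)   # at t == 0, target is the clean embedding
    mse = th.where(t == 0, t0_loss, mse)
    tT_loss = mean_flat(out_mean ** 2)                        # push q(x_T | x_0)'s mean toward 0
    ce = th.nn.CrossEntropyLoss(reduction='none')
    decoder_nll = ce(logits.view(-1, logits.size(-1)),
                     input_ids.view(-1)).view(input_ids.shape).mean(dim=-1)   # rounding term
    return mse + tT_loss + decoder_nll                        # shape: [batch_size]
```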
def token_discrete_loss(self, x_t, get_logits, input_ids):
        if self.model_arch == 'conv-unet' or  self.model_arch == '1d-unet':
            reshaped_x_t = x_t.view(x_t.size(0), x_t.size(1), -1).permute(0, 2, 1)
        **else:
            # print(x_t.shape)
            reshaped_x_t = x_t**
        logits = get_logits(reshaped_x_t)  # bsz, seqlen, vocab
        # print(logits.shape)
        loss_fct = th.nn.CrossEntropyLoss(reduction='none')
        decoder_nll = loss_fct(logits.view(-1, logits.size(-1)), input_ids.view(-1)).view(input_ids.shape)
        # print(decoder_nll.shape)
        decoder_nll = decoder_nll.mean(dim=-1)
        return decoder_nll
  • reshaped_x_t is just x_start with the std noise mixed in; it is pushed through the decoder (lm_head) and the cross-entropy between the resulting vocab logits and the actual input token ids is added as the decoder_nll term
  • logits : [64, 64, vocab] → [4096, vocab]
  • input_ids : [64, 64] → [4096]

improved-diffusion/improved_diffusion/transformer_model2.py

def get_logits(self, hidden_repr):
        if self.logits_mode == 1:
            return self.lm_head(hidden_repr)
  • lm_head is the nn.Linear(128, vocab_size) decoder module

Back to train_util.py

compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            **loss = (losses["loss"] * weights).mean()**
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            **else:
                loss.backward()**
  • losses["loss"] and weights are both [batch_size]-dim, and weights is all 1s

train_util.py

def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            **self.optimize_normal()**

def optimize_normal(self):
        if self.gradient_clipping > 0:
            self.grad_clip()
        self._log_grad_norm()
        **self._anneal_lr()**
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

# in nn.py
def update_ema(target_params, source_params, rate=0.99):
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
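For reference, the linear anneal at a few steps (lr = 1e-4 and lr_anneal_steps = 400000 as in this run):

```python
lr, lr_anneal_steps = 1e-4, 400_000
for step in (0, 100_000, 200_000, 399_999):
    frac_done = step / lr_anneal_steps
    print(step, lr * (1 - frac_done))   # 1e-4, 7.5e-05, 5e-05, ~2.5e-10
```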

evaluation

train_util.py

if self.eval_data is not None and self.step % self.eval_interval == 0:
                batch_eval, cond_eval = next(self.eval_data)
                **self.forward_only(batch, cond)**
                print('eval on validation set')
                logger.dumpkvs()
  • almost identical to training, except it runs under torch.no_grad()

Save path : diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e



Can the decoder part be adapted to our case..?

  • Current code execution flow
    • token id → 128-dim vector : x_start_mean
    • add std → x_start
    • cross entropy between decoder(x_start) and the token ids
  • Our case?
    • the state horizon already exists
    • add std
    • an action decoder model over the two std-mixed states..?


decoder part

run this

python scripts/text_sample.py --model_path /home/doolee13/Diffusion-LM/improved-diffusion/diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e/ema_0.9999_001000.pt --batch_size 50 --num_samples 50 --top_p -1.0 --out_dir generation_outputs

improved-diffusion/scripts/text_sample.py

	  if args.experiment == 'random1': args.experiment = 'random'
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f'the parameter count is {pytorch_total_params}')

    # diffusion.rescale_timesteps = False  # DEBUG --> REMOVE
    print(diffusion.rescale_timesteps, 'a marker for whether we are in the debug mode')
    model.to(dist_util.dev())
    model.eval() # DEBUG

		model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
                                    os.path.split(args.model_path)[0])

		**if args.training_mode.startswith('e2e'):
        print('e2e, load the right model embeddings', '*'*80)
        model2.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu())**
  • create_model_and_diffusion is the same as above
  • model2 and tokenizer are also the same
    • model2 is nn.Embedding(vocab_size, 128)
while len(all_images) * args.batch_size < args.num_samples:
        model_kwargs = {}
        ~~if args.experiment_mode == 'conditional_gen':
            batch, model_kwargs = next(data)
            model_kwargs.pop('input_ids')
            if args.mbr_sample > 1:
                model_kwargs = {k: v.to(dist_util.dev()).repeat_interleave(args.mbr_sample, dim=0) for k, v in model_kwargs.items()}
            else:
                model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
            print([(k, v.shape) for (k,v) in model_kwargs.items()])~~
        ~~if args.class_cond:
            classes = th.randint(
                low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
            )
            model_kwargs["y"] = classes~~
        sample_fn = (
            diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
        )
        if args.model_arch == '1d-unet':
            if args.mbr_sample > 1 and args.experiment_mode == 'conditional_gen':
                sample_shape = (args.batch_size * args.mbr_sample, args.in_channel, args.image_size ** 2)
            else:
                sample_shape = (args.batch_size,  args.in_channel, args.image_size ** 2)
        **else:**
            if args.mbr_sample > 1 and args.experiment_mode == 'conditional_gen':
                sample_shape = (args.batch_size * args.mbr_sample, args.image_size ** 2, args.in_channel)
            **else:
                sample_shape = (args.batch_size, args.image_size ** 2, args.in_channel**)
        print(sample_shape)
        **sample = sample_fn(
            model,
            sample_shape,
            clip_denoised=args.clip_denoised,
            denoised_fn=partial(denoised_fn_round, args, model3.cuda()) if args.clamp == 'clamp' else None,
            model_kwargs=model_kwargs,
            top_p =args.top_p,
        )**

        if args.model_arch == '1d-unet':
            print(sample.shape)
            sample = sample.permute(0, 2, 1)
        print(sample.shape)

        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
        if args.class_cond:
            gathered_labels = [
                th.zeros_like(classes) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(gathered_labels, classes)
            all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
        logger.log(f"created {len(all_images) * args.batch_size} samples")
  • args.use_ddim is False, so sample_fn is diffusion.p_sample_loop
  • sample_shape : [bsize, 64, 128]

Running the sample function executes p_sample_loop in gaussian_diffusion.py

def p_sample_loop():
				final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            top_p=top_p,
        ):
            final = sample
        return final["sample"]

def p_sample_loop_progressive():
				if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

				**for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    top_p=top_p,
                )
                yield out
                img = out["sample"]**
  • device : ‘cuda’,
  • img : random noise
  • indices : [1999 ~ 0]
  • t : [bsize, ]
def p_sample():
				out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

def p_mean_variance():
				else:
            B, C = x.size(0), x.size(-1)
        assert t.shape == (B,)
        # print(x.shape)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

				else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    **np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),**
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

			elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)

						**model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )**

def process_xstart(x):
            if denoised_fn is not None:
                # print(denoised_fn)
                x = denoised_fn(x, t)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

# in improved_diffusion/improved_diffusion/test_util.py
  • a function that samples x_{t−1} given x_t
  • for this we need the distribution p(x_{t−1} | x_t)
  • B, C : 50, 128
  • model_output : [50, 64, 128]
  • model_variance and model_log_variance are assigned the bolded FIXED_LARGE values, and the entries matching the current timestep are returned
    • shape : [50, 64, 128]

  • with that, the sigma (variance) corresponding to p(x_{t−1} | x_t) is ready; it is a fixed constant per timestep (figures omitted)
  • process_xstart (via denoised_fn) is the step that nudges each position toward the nearest real word embedding (the clamping trick)
    • skipped here since it does not apply to our case
    • it runs inside p_mean_variance before q_posterior_mean_variance is computed
def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
  • the mean of x_{t−1} is obtained using q(x_{t−1} | x_t, x_0)
  • only the model mean computed here is used
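The two coefficients used above, computed standalone (a stand-in linear schedule, numpy only):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 2000)                     # stand-in schedule for illustration
alphas = 1.0 - betas
ac = np.cumprod(alphas)
ac_prev = np.append(1.0, ac[:-1])

posterior_mean_coef1 = betas * np.sqrt(ac_prev) / (1.0 - ac)           # weight on x_0
posterior_mean_coef2 = (1.0 - ac_prev) * np.sqrt(alphas) / (1.0 - ac)  # weight on x_t
# posterior_mean = coef1[t] * x_start + coef2[t] * x_t, exactly as in the excerpt above
print(posterior_mean_coef1.shape, posterior_mean_coef2.shape)          # (2000,) (2000,)
```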
	def p_mean_variance():
		assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }
  • model_mean : the value from q(x_{t−1} | x_t, x_0)
  • variance, log_variance : fixed constants (depending on t)
  • pred_xstart : the model's prediction of x_0 given (x_t, t), after post-processing (process_xstart)

Back to p_sample()

def p_sample():
		else:
            noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"],
                'greedy_mean':out["mean"], 'out':out}
  • noise and out["mean"] shape : [50, 64, 128]
  • nonzero_mask shape : [50, 1, 1]
  • sample : the predicted x_{t−1} with noise mixed in when t ≠ 0
  • pred_xstart : the x_0 predicted at step t
  • greedy_mean : the predicted x_{t−1} without noise
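A minimal reproduction of the masked sampling step (shapes from this run; mean and log_variance are placeholders):

```python
import torch as th

x = th.randn(50, 64, 128)
t = th.tensor([5] * 50)                                 # current timestep for each batch element
mean = th.randn_like(x)                                 # stands in for out["mean"]
log_variance = th.zeros_like(x)                         # stands in for out["log_variance"]
noise = th.randn_like(x)

nonzero_mask = (t != 0).float().view(-1, 1, 1)          # 0 where t == 0 -> no noise added there
sample = mean + nonzero_mask * th.exp(0.5 * log_variance) * noise
print(sample.shape, nonzero_mask.shape)                 # torch.Size([50, 64, 128]) torch.Size([50, 1, 1])
```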

Back to p_sample_loop_progressive

for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    top_p=top_p,
                )
                yield out
                img = out["sample"]
  • the sampling loop keeps going, with img set to the (noised) x_{t−1} predicted at step t

A separate recap of the sampling process

The reverse step is based on the posterior q(x_{t−1} | x_t, x_0) = N(x_{t−1}; μ̃_t(x_t, x_0), β̃_t·I), which contains no neural-network parameters, so to use it we first need x_0.

In this paper the model predicts x_0 directly (rather than ε), so x_0 is obtained from the model given (x_t, t).

With the existing x_t and the x_0 drawn above, q(x_{t−1} | x_t, x_0) gives the mean.

The parameterized mean μ_θ comes from q_posterior (it is parameterized only because the model supplied the x_0 prediction), while the variance is a fixed constant determined by t, as in the fixed-variance entries above (the equation images are omitted here).
