train_setup
--use_kl False --learn_sigma False
args.experiment : random
args.modality : roc
Model_FILE : diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e
folder name : diffusion_models
Model_FILE:
diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e
improved-diffusion/scripts/run_train.py
with open(Model_FILE + '.sh', 'w') as f:
print(COMMANDLINE, file=f)
COMMANDLINE :
OPENAI_LOGDIR=diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e TOKENIZERS_PARALLELISM=false python scripts/train.py --checkpoint_path diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e --model_arch transformer --modality roc --save_interval 50000 --lr 0.0001 --batch_size 64 --diffusion_steps 2000 --noise_schedule sqrt --use_kl False --learn_sigma False --image_size 8 --num_channels 128 --seed 101 --dropout 0.1 --in_channel 128 --out_channel 128 --padding_mode pad --experiment random --lr_anneal_steps 400000 --weight_decay 0.0 --num_res_blocks 2 --predict_xstart True --training_mode e2e --vocab_size 11043 --roc_train ../datasets/ROCstory
improved-diffusion/improved_diffusion/script_util.py
def create_model_and_diffusion():
model = create_model(
image_size,
num_channels,
num_res_blocks,
learn_sigma=learn_sigma,
class_cond=class_cond,
use_checkpoint=use_checkpoint,
attention_resolutions=attention_resolutions,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
dropout=dropout,
model_arch=model_arch,
in_channel=in_channel,
out_channel=out_channel,
training_mode=training_mode,
vocab_size=vocab_size,
config_name=config_name,
experiment_mode=experiment_mode,
logits_mode=logits_mode,
)
def create_model():
elif model_arch == 'transformer':
if image_size == 256:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 64:
channel_mult = (1, 2, 3, 4)
elif image_size == 32:
channel_mult = (1, 2, 2, 2)
elif image_size == 16: # DEBUG**
channel_mult = (1, 2, 2, 2)
else:
channel_mult = (1, 2, 2, 2)
attention_ds = []
for res in attention_resolutions.split(","):
attention_ds.append(image_size // int(res))
return TransformerNetModel2(
in_channels=in_channel, # 3, DEBUG**
model_channels=num_channels,
out_channels=(out_channel if not learn_sigma else out_channel*2), # DEBUG** (3 if not learn_sigma else 6),
num_res_blocks=num_res_blocks,
attention_resolutions=tuple(attention_ds),
dropout=dropout,
channel_mult=channel_mult,
num_classes=(NUM_CLASSES if class_cond else None),
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
config_name=config_name,
training_mode=training_mode,
vocab_size=vocab_size,
experiment_mode=experiment_mode,
logits_mode=logits_mode,
)
else:
raise NotImplementedError
improved-diffusion/improved_diffusion/transformer_model2.py
class TransformerNetModel2(nn.Module): # line 674
if num_heads_upsample == -1: # True
num_heads_upsample = num_heads
if training_mode == 'e2e':
self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
if self.logits_mode == 2:
# self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=False)
self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=True)
**else:
self.lm_head = nn.Linear(self.in_channels, vocab_size)**
with th.no_grad():
self.lm_head.weight = self.word_embedding.weight
if experiment_mode == 'conditional_gen':
self.conditional_gen = True
self.encoder_emb = nn.Embedding(vocab_size, config.hidden_size)
self.encoder = BertEncoder(config)
print(config, 'conditional_gen')
config.is_decoder = True
config.add_cross_attention = True
**elif experiment_mode == 'lm':
self.conditional_gen = False**
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
SiLU(),
linear(time_embed_dim, config.hidden_size),
)
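With num_channels = 128 from the command line and BERT's hidden_size = 768, this path maps the sinusoidal timestep embedding 128 → 512 → 768, so that (in the forward pass, not shown here) it can be added to every position's hidden state.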
self.input_up_proj = nn.Sequential(nn.Linear(in_channels, config.hidden_size),
nn.Tanh(), nn.Linear(config.hidden_size, config.hidden_size))
if init_pretrained:
from transformers.models.bert.modeling_bert import BertModel
temp_bert = BertModel.from_pretrained(config_name, config=config)
del temp_bert.embeddings
del temp_bert.pooler
self.input_transformers = temp_bert.encoder
print('initializing from pretrained bert.')
**else:
print(config)
self.input_transformers = BertEncoder(config)**
from transformers.models.bert.modeling_bert import BertEncoder
# config
BertConfig {
"_name_or_path": "bert-base-uncased",
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.35.2",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# config2 = config
# config2.hidden_size = 2 * config.hidden_size
# self.output_transformers = BertEncoder(config)
self.output_down_proj = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
nn.Tanh(), nn.Linear(config.hidden_size, out_channels))
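Despite the image-style argument names inherited from improved-diffusion, the transformer path treats the input as a sequence of length image_size**2: with image_size = 8 this run uses 64 tokens. Each 128-dim word embedding is projected 128 → 768 by input_up_proj, run through the BertEncoder, and mapped back 768 → 128 by output_down_proj.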
improved-diffusion/improved_diffusion/script_util.py
def create_model_and_diffusion():
diffusion = create_gaussian_diffusion(
steps=diffusion_steps,
learn_sigma=learn_sigma,
sigma_small=sigma_small,
noise_schedule=noise_schedule,
use_kl=use_kl,
predict_xstart=predict_xstart,
rescale_timesteps=rescale_timesteps,
rescale_learned_sigmas=rescale_learned_sigmas,
timestep_respacing=timestep_respacing,
model_arch=model_arch,
training_mode=training_mode,
)
improved-diffusion/improved_diffusion/script_util.py
def create_gaussian_diffusion():
betas = gd.get_named_beta_schedule(noise_schedule, steps)
improved-diffusion/improved_diffusion/gaussian_diffusion.py
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
elif schedule_name == 'sqrt':
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: 1-np.sqrt(t + 0.0001),
)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
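As a quick sanity check, plugging the sqrt schedule into betas_for_alpha_bar reproduces the betas that appear in the kwargs dump further below. A minimal standalone sketch (re-declaring the function above):
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

betas = betas_for_alpha_bar(2000, lambda t: 1 - np.sqrt(t + 0.0001))  # the 'sqrt' schedule
print(betas[:3])   # [0.01464131 0.00888909 0.00706818]
print(betas[-1])   # 0.999 -- clipped by max_beta at the noisy end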
improved-diffusion/improved_diffusion/script_util.py
def create_gaussian_diffusion():
betas = gd.get_named_beta_schedule(noise_schedule, steps)
if training_mode == 'e2e':
# end to end training
if use_kl:
loss_type = gd.LossType.E2E_KL
**else:
loss_type = gd.LossType.E2E_MSE**
**if not timestep_respacing:
timestep_respacing = [steps]**
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
rescale_timesteps=rescale_timesteps,
model_arch=model_arch,
training_mode=training_mode,
)
improved-diffusion/improved_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts):
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
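Here timestep_respacing defaults to [steps], so the spacing is the identity; a respaced value like [50] (hypothetical for this run) picks roughly evenly spaced indices instead. A quick check, assuming the repo's improved_diffusion package is importable:
from improved_diffusion.respace import space_timesteps

print(sorted(space_timesteps(2000, [2000]))[:5])   # [0, 1, 2, 3, 4] -- all 2000 steps kept for training
print(sorted(space_timesteps(2000, [50]))[:5])     # [0, 41, 82, 122, 163] -- 50-step respacing for faster sampling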
improved-diffusion/improved_diffusion/script_util.py
def create_gaussian_diffusion():
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else **gd.ModelMeanType.START_X**
),
model_var_type=(
(
**gd.ModelVarType.FIXED_LARGE**
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
rescale_timesteps=rescale_timesteps,
model_arch=model_arch,
training_mode=training_mode,
)
improved-diffusion/improved_diffusion/respace.py
class SpacedDiffusion(GaussianDiffusion):
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
# print(kwargs.keys())
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
{'betas': array([0.01464131, 0.00888909, 0.00706818, ..., 0.35722328, 0.55561113,
0.999 ]), 'model_mean_type': <ModelMeanType.START_X: 2>, 'model_var_type': <ModelVarType.FIXED_LARGE: 3>, 'loss_type': <LossType.E2E_MSE: 6>, 'rescale_timesteps': True, 'model_arch': 'transformer', 'training_mode': 'e2e'}
improved-diffusion/improved_diffusion/gaussian_diffusion.py
class GaussianDiffusion:
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
self.rescale_timesteps = rescale_timesteps
self.model_arch=model_arch
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
self.betas = betas
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
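These buffers implement the closed-form Gaussian posterior q(x_{t-1} | x_t, x_0) = N(μ̃_t, β̃_t I) with
μ̃_t = (β_t·√ᾱ_{t-1} / (1 - ᾱ_t))·x_0 + ((1 - ᾱ_{t-1})·√α_t / (1 - ᾱ_t))·x_t
β̃_t = β_t·(1 - ᾱ_{t-1}) / (1 - ᾱ_t)
i.e. posterior_mean_coef1, posterior_mean_coef2, and posterior_variance above.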
improved-diffusion/improved_diffusion/respace.py
class SpacedDiffusion(GaussianDiffusion):
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
# print(kwargs.keys())
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
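Note that when use_timesteps covers every index, as with timestep_respacing = [2000] here, the new betas telescope back to the original schedule (1 - ᾱ_i/ᾱ_{i-1} = 1 - α_i = β_i), so SpacedDiffusion reduces to plain GaussianDiffusion with timestep_map = [0, 1, ..., 1999].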
create model and diffusion: done
improved-diffusion/scripts/train.py
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
def create_named_schedule_sampler(name, diffusion):
if name == "uniform":
return UniformSampler(diffusion)
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
self.diffusion = diffusion
self._weights = np.ones([diffusion.num_timesteps])
def weights(self):
return self._weights
improved-diffusion/scripts/train.py
else:
print('load data', '*'*50)
if args.modality == 'roc-aug' or args.modality == 'commonGen-aug':
tokenizer = load_tokenizer(args.modality, args.experiment, 'predictability/diffusion_models_v7/diff_roc_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart')
rev_tokenizer = {v: k for k, v in tokenizer.items()}
print(len(rev_tokenizer), 'loading from tokenizer. ')
elif args.use_bert_tokenizer == 'yes':
rev_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
else:
**rev_tokenizer = None**
if args.experiment == 'random1':
args.experiment = 'random'
print('loading from the vocabs here.')
assert args.in_channel == 64
assert args.modality == 'roc'
model22 = torch.nn.Embedding(args.vocab_size, args.in_channel)
model22_weight = torch.load('predictability/diffusion_models_v7/diff_roc-aug_pad_rand64_'
'transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e/'
'ema_0.9999_200000.pt', map_location='cpu')['word_embedding.weight']
model22.weight = model22_weight
model22.weight.requires_grad=False
**else:
model22 = None**
data = load_data_text(
data_dir=args.data_dir,
batch_size=args.batch_size,
image_size=args.image_size,
class_cond=args.class_cond,
data_args = args,
task_mode=args.modality,
padding_mode=args.padding_mode, #block, pad
load_vocab=rev_tokenizer,
model=model22,
)
improved-diffusion/improved_diffusion/text_datasets.py
def load_data_text(
*, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None,
task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
):
**if data_args.experiment.startswith('random') and model is None:
model = None**
elif data_args.experiment.startswith('random') and model is not None:
print('loading initialized random embeddings. ')
**if task_mode == 'roc' or task_mode == 'roc-aug' :
training_data, model = get_corpus_rocstory(data_args, model, image_size,
padding_mode=padding_mode, split=split,
load_vocab=load_vocab)**
def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
split='train', load_vocab=None):
import csv, torch, json
from spacy.lang.en import English
if data_args.experiment_mode == 'lm':
if data_args.modality == 'roc':
print('loading dataset from ROCStory')
nlp = English()
tokenizer = nlp.tokenizer
sentence_lst = []
print(f'loading from {data_args.roc_train}')
**if split == 'train':
print('loading form the TRAIN set')
path = f'{data_args.roc_train}/roc_train.json'**
elif split == 'valid':
print('loading form the VALID set')
path = f'{data_args.roc_train}/roc_valid.json'
else:
assert False, "invalid split for ROC dataset"
with open(path, 'r') as roc_reader:
for row in roc_reader:
sentences = json.loads(row)[0].strip()
word_lst = [x.text for x in tokenizer(sentences)]
sentence_lst.append(word_lst)
# get tokenizer.
if load_vocab is None:
counter = Counter()
for input_ids in sentence_lst:
counter.update(input_ids)
if load_vocab is None:
vocab_dict = {'START': 0, 'END': 1, 'UNK':2, 'PAD':3}
for k, v in counter.items():
if v > 10:
vocab_dict[k] = len(vocab_dict)
print(len(counter), len(vocab_dict))
path_save_vocab = f'{data_args.checkpoint_path}/vocab.json'
print(f'save the vocab to {path_save_vocab}')
with open(path_save_vocab, 'w') as f:
json.dump(vocab_dict, f)
if model is None and data_args.experiment == 'random':
model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
print('initializing the random embeddings', model)
torch.nn.init.normal_(model.weight)
path_save = f'{data_args.checkpoint_path}/random_emb.torch'
print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
torch.save(model.state_dict(), path_save)
if data_args.experiment_mode == 'lm' and data_args.modality in ['roc-aug', 'roc', 'yelp', 'commonGen', 'commonGen-aug'] \
and data_args.cache_mode=='no':
train_dataset = helper_tokenize_stream(sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode)
return train_dataset, model
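A toy run of the vocab rule above (the count > 10 threshold is from the code; the corpus is made up):
from collections import Counter

sentence_lst = [['the', 'cat']] * 11 + [['rare']]   # 'rare' appears only once
counter = Counter()
for input_ids in sentence_lst:
    counter.update(input_ids)

vocab_dict = {'START': 0, 'END': 1, 'UNK': 2, 'PAD': 3}
for k, v in counter.items():
    if v > 10:
        vocab_dict[k] = len(vocab_dict)
print(vocab_dict)   # {'START': 0, 'END': 1, 'UNK': 2, 'PAD': 3, 'the': 4, 'cat': 5} -- 'rare' falls to UNK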
back to load_data_text
def load_data_text(
*, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None,
task_mode='roc', model=None, padding_mode='block', split='train', load_vocab=None,
):
if data_args.modality in ['roc-aug', 'roc', 'book', 'yelp', 'commonGen', 'commonGen-aug'] and data_args.cache_mode=='no':
dataset = TextDataset_NoCache(
training_data,
image_size,
data_args,
model_arch=data_args.model_arch,
model_emb=model
)
if deterministic:
data_loader = DataLoader(
dataset,
batch_size=batch_size, # 20,
drop_last=True,
shuffle=False,
num_workers=1,
)
**else:
data_loader = DataLoader(
dataset,
batch_size=batch_size, # 20,
drop_last=True,
shuffle=True,
num_workers=1,
)**
while True:
yield from data_loader
improved-diffusion/scripts/train.py
model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
args.checkpoint_path, extra_args=args)
if args.modality == 'book' or args.use_bert_tokenizer == 'yes':
rev_tokenizer = tokenizer # BERT tokenizer BPE.
else:
**rev_tokenizer = {v: k for k, v in tokenizer.items()}**
improved-diffusion/improved_diffusion/rounding.py
def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):
if mode in ['random', 'random1', 'random_up_proj', 'glove']:
if modality == 'synth':
else:
import json
if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
if 'e2e' in file and modality == 'book':
emb_dim = 1
**else:
path_save_tokenizer = '{}/vocab.json'.format(file)
print(f'loading from {path_save_tokenizer}')
with open(path_save_tokenizer, 'r') as f:
vocab = json.load(f)
print(len(vocab))
tokenizer = {v: k for k, v in vocab.items()}**
model = torch.nn.Embedding(len(tokenizer), emb_dim)
path_save = '{}/random_emb.torch'.format(file)
model.load_state_dict(torch.load(path_save))
return model, tokenizer
improved-diffusion/scripts/train.py
data_valid = load_data_text(
data_dir=args.data_dir,
batch_size=args.batch_size,
image_size=args.image_size,
class_cond=args.class_cond,
data_args=args,
task_mode=args.modality,
padding_mode=args.padding_mode, # block, pad
split='valid',
load_vocab=rev_tokenizer,
model=model2,
)
# dist.barrier()
# import time
# while not os.path.exists(os.path.join(args.checkpoint_path, 'vocab.json')):
# time.sleep(1)
def get_mapping_func(args, diffusion, data):
model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
args.checkpoint_path, extra_args=args)
model3 = get_weights(model2, args)
print(model3, model3.weight.requires_grad)
mapping_func = partial(compute_logp, args, model3.cuda())
diffusion.mapping_func = mapping_func
return mapping_func
get_mapping_func(args, diffusion, data)
TrainLoop(
model=model,
diffusion=diffusion,
data=data,
batch_size=args.batch_size,
microbatch=args.microbatch,
lr=args.lr,
ema_rate=args.ema_rate,
log_interval=args.log_interval,
save_interval=args.save_interval,
resume_checkpoint=args.resume_checkpoint,
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
schedule_sampler=schedule_sampler,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
checkpoint_path=args.checkpoint_path,
gradient_clipping=args.gradient_clipping,
eval_data=data_valid,
eval_interval=args.eval_interval
).run_loop()
improved-diffusion/improved_diffusion/train_util.py
class TrainLoop:
def __init__():
self.model = model
self.diffusion = diffusion
self.data = data
self.eval_data = eval_data
self.batch_size = batch_size # 64
self.microbatch = microbatch if microbatch > 0 else batch_size # 64
self.lr = lr # 0.0001
self.ema_rate = (
[ema_rate]
if isinstance(ema_rate, float)
else [float(x) for x in ema_rate.split(",")]
) # [0.9999]
self.log_interval = log_interval
self.eval_interval = eval_interval
self.save_interval = save_interval
self.resume_checkpoint = resume_checkpoint
self.use_fp16 = use_fp16 # False
self.fp16_scale_growth = fp16_scale_growth
self.schedule_sampler = schedule_sampler or **UniformSampler(diffusion)**
self.weight_decay = weight_decay # 0.0
self.lr_anneal_steps = lr_anneal_steps # 400000
self.gradient_clipping = gradient_clipping # -1.0
self.step = 0
self.resume_step = 0
self.global_batch = self.batch_size * dist.get_world_size() # 64
self.model_params = list(self.model.parameters())
self.master_params = self.model_params
self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
self.sync_cuda = th.cuda.is_available()
self.checkpoint_path = checkpoint_path # DEBUG **
self._load_and_sync_parameters()
if self.use_fp16:
self._setup_fp16()
self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
if self.resume_step:
self._load_optimizer_state()
# Model was resumed, either due to a restart or a checkpoint
# being specified at the command line.
self.ema_params = [
self._load_ema_parameters(rate) for rate in self.ema_rate
]
**else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]**
if th.cuda.is_available(): # DEBUG **
self.use_ddp = True
self.ddp_model = DDP(
self.model,
device_ids=[dist_util.dev()],
output_device=dist_util.dev(),
broadcast_buffers=False,
bucket_cap_mb=128,
find_unused_parameters=False,
)
else:
if dist.get_world_size() > 1:
logger.warn(
"Distributed training requires CUDA. "
"Gradients will not be synchronized properly!"
)
self.use_ddp = False
self.ddp_model = self.model
def run_loop(self):
while (
not self.lr_anneal_steps
or self.step + self.resume_step < self.lr_anneal_steps
):
batch, cond = next(self.data)
**self.run_step(batch, cond)**
def run_step(self, batch, cond):
**self.forward_backward(batch, cond)**
if self.use_fp16:
self.optimize_fp16()
else:
self.optimize_normal()
self.log_step()
What values go into batch and cond?
improved-diffusion/improved_diffusion/text_datasets.py
def __getitem__(self, idx):
# We are not on a new enough PIL to support the `reducing_gap`
# argument, which uses BOX downsampling at powers of two first.
# Thus, we do it by hand to improve downsample quality.
with torch.no_grad():
input_ids = self.text_datasets['train'][idx]['input_ids']
model = self.model_emb
**if self.data_args.experiment.startswith('random'):**
**hidden_state = model(torch.tensor(input_ids))**
elif self.data_args.experiment == 'gpt2_pre_compress':
input_ids2 = torch.tensor(input_ids).to(model.device)
input_embs = model.transformer.wte(input_ids2) # input_embs
hidden_state = model.down_proj(input_embs)
hidden_state = hidden_state * data_args.emb_scale_factor
if self.model_arch == 'conv-unet':
arr = np.array(hidden_state,
dtype=np.float32).reshape(self.resolution, self.resolution, -1)
# print(self.eigen_transform.shape)
if self.eigen_transform is not None:
old_shape = arr.shape
arr = arr.reshape(1, -1) - self.eigen_transform['mean']
arr = arr @ self.eigen_transform['map']
arr = arr.reshape(old_shape)
if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
out_dict = {}
out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
# if self.local_classes is not None:
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
# print(out_dict.keys())
return np.transpose(arr, [2, 0, 1]), out_dict
elif self.model_arch == '1d-unet':
arr = np.array(hidden_state,
dtype=np.float32) # seqlen, dim
if self.eigen_transform is not None:
old_shape = arr.shape
arr = arr.reshape(1, -1) - self.eigen_transform['mean']
arr = arr @ self.eigen_transform['map']
arr = arr.reshape(old_shape)
if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
arr = np.transpose(arr, [1, 0])
out_dict = {}
out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
# out_dict['mapping_func'] = self.mapping_func
# if self.local_classes is not None:
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
# print(arr.shape)
return arr, out_dict
**else:
arr = np.array(hidden_state,
dtype=np.float32)
if self.eigen_transform is not None:
old_shape = arr.shape
# arr = arr.reshape(1, -1) @ self.eigen_transform
arr = arr.reshape(1, -1) - self.eigen_transform['mean']
arr = arr @ self.eigen_transform['map']
arr = arr.reshape(old_shape)**
if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
# print(arr.dtype)
# print(self.data_args.noise_level, 'using the noise level.')
arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
# print(arr.dtype)
out_dict = {}
out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
# out_dict['mapping_func'] = self.mapping_func
if self.data_args.experiment_mode == 'conditional_gen':
out_dict['src_ids'] = np.array(self.text_datasets['train'][idx]['src_ids'])
out_dict['src_mask'] = np.array(self.text_datasets['train'][idx]['src_mask'])
# if self.local_classes is not None:
# out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
return arr, out_dict
improved-diffusion/improved_diffusion/train_util.py
def run_step(self, batch, cond):
**self.forward_backward(batch, cond)**
if self.use_fp16:
self.optimize_fp16()
else:
self.optimize_normal()
self.log_step()
def forward_backward(self, batch, cond):
zero_grad(self.model_params)
for i in range(0, batch.shape[0], self.microbatch):
micro = batch[i : i + self.microbatch].to(dist_util.dev())
micro_cond = {
k: v[i : i + self.microbatch].to(dist_util.dev())
for k, v in cond.items()
}
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
# print(micro_cond.keys())
**compute_losses = functools.partial(
self.diffusion.training_losses,
self.ddp_model,
micro,
t,
model_kwargs=micro_cond,
)**
**if last_batch or not self.use_ddp:
losses = compute_losses()**
else:
with self.ddp_model.no_sync():
losses = compute_losses()
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
loss = (losses["loss"] * weights).mean()
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()
improved-diffusion/improved_diffusion/respace.py
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
# print('called training_losses')
return super().training_losses(self._wrap_model(model), *args, **kwargs)
improved-diffusion/improved_diffusion/gaussian_diffusion.py
def training_losses(self, model, *args, **kwargs):
**if self.training_mode == 'e2e':
return self.training_losses_e2e(model, *args, **kwargs)**
elif self.training_mode == 'e2e-simple':
return self.training_losses_e2e_simple(model, *args, **kwargs)
else:
return self.training_losses_emb(model, *args, **kwargs)
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
assert 'input_ids' in model_kwargs
input_ids = model_kwargs.pop('input_ids').to(t.device)
x_start_mean = model.model.module.get_embeds(input_ids)
std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
th.tensor([0]).to(x_start_mean.device),
x_start_mean.shape)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
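A small illustration of the broadcast helper (shapes are arbitrary, and the underscore-prefixed import is an assumption about the package layout):
import numpy as np
import torch as th
from improved_diffusion.gaussian_diffusion import _extract_into_tensor

arr = np.array([0.1, 0.2, 0.3])
t = th.tensor([2, 0])
res = _extract_into_tensor(arr, t, (2, 4, 8))    # picks arr[t], then expands over trailing dims
print(res.shape)                                 # torch.Size([2, 4, 8])
print(res[0, 0, 0].item(), res[1, 0, 0].item())  # ~0.3 ~0.1, one scalar per batch element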
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
assert 'input_ids' in model_kwargs
input_ids = model_kwargs.pop('input_ids').to(t.device)
x_start_mean = model.model.module.get_embeds(input_ids)
std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
th.tensor([0]).to(x_start_mean.device),
x_start_mean.shape)
x_start_log_var = 2 * th.log(std)
**x_start = self.get_x_start(x_start_mean, std)**
def get_x_start(self, x_start_mean, std):
noise = th.randn_like(x_start_mean)
# print(std.shape, noise.shape, x_start_mean.shape)
assert noise.shape == x_start_mean.shape
# print(x_start_mean.device, noise.device)
return (
x_start_mean + std * noise
)
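Concretely, std is sqrt(1 - ᾱ_0) = sqrt(β_0), since ᾱ_0 = 1 - β_0. With β_0 = 0.01464131 from the kwargs dump above:
import numpy as np
print(np.sqrt(0.01464131))   # ~0.121 -- x_0 is the clean word embedding plus this small Gaussian jitter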
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
assert 'input_ids' in model_kwargs
input_ids = model_kwargs.pop('input_ids').to(t.device)
x_start_mean = model.model.module.get_embeds(input_ids)
std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
th.tensor([0]).to(x_start_mean.device),
x_start_mean.shape)
x_start_log_var = 2 * th.log(std)
x_start = self.get_x_start(x_start_mean, std)
**if noise is None:
noise = th.randn_like(x_start)**
**x_t = self.q_sample(x_start, t, noise=noise) # reparametrization trick.**
get_logits = model.model.module.get_logits
def q_sample(self, x_start, t, noise=None):
""".
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
get_logits = model.model.module.get_logits
elif self.loss_type == LossType.E2E_MSE or self.loss_type == LossType.E2E_RESCALED_MSE:
# print(x_t.shape)
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
improved-diffusion/improved_diffusion/respace.py
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
# print(ts)
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# print(new_ts)
**if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)**
# temp = self.model(x, new_ts, **kwargs)
# print(temp.shape)
# return temp
# print(new_ts)
return self.model(x, new_ts, **kwargs)
input : x of shape [batch_dim, horizon_len, vector_dim]
        t of shape [batch_dim] (rescaled by 1000/2000, i.e. halved)
output : [batch_dim, horizon_len, vector_dim], same shape as x
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
get_logits = model.model.module.get_logits
elif self.loss_type == LossType.E2E_MSE or self.loss_type == LossType.E2E_RESCALED_MSE:
# print(x_t.shape)
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
# print( terms["mse"])
model_out_x_start = self.x0_helper(model_output, x_t, t)['pred_xstart']
t0_mask = (t == 0)
t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
# print(terms["mse"].shape, )
terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])
This is the loss actually used.
def x0_helper(self, model_output, x, t):
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
pred_prev = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
**if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = model_output**
else:
pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
**pred_prev, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)**
else:
raise NotImplementedError(self.model_mean_type)
return {'pred_xprev':pred_prev, 'pred_xstart':pred_xstart}
Summary) when t == 0, the MSE is between the clean token_id → 128-dim embedding (x_start_mean) and the model output;
when t ≠ 0, it is between the noise-perturbed 128-dim embedding (x_start) and the model output.
def training_losses_e2e(self, model, x_start, t, model_kwargs=None, noise=None):
get_logits = model.model.module.get_logits
elif self.loss_type == LossType.E2E_MSE or self.loss_type == LossType.E2E_RESCALED_MSE:
# print(x_t.shape)
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
target = {
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0],
ModelMeanType.START_X: x_start,
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
# print( terms["mse"])
model_out_x_start = self.x0_helper(model_output, x_t, t)['pred_xstart']
t0_mask = (t == 0)
t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
# print(terms["mse"].shape, )
terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])
out_mean, _, _ = self.q_mean_variance(x_start, th.LongTensor([self.num_timesteps - 1]).to(x_start.device))
tT_loss = mean_flat(out_mean ** 2)
decoder_nll = self.token_discrete_loss(x_start, get_logits, input_ids)
# assert (model.lm_head.weight == model.word_embedding.weight).all()
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:
# KEY
terms["loss"] = terms["mse"] + (decoder_nll + tT_loss)
# terms["loss"] = terms["mse"] + (1.0/self.num_timesteps) * decoder_nll + \
# (1.0/self.num_timesteps) * tT_loss
else:
raise NotImplementedError(self.loss_type)
return terms
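Putting the terms together for this run: loss = mse + decoder_nll + tT_loss, i.e. the timestep-dependent embedding MSE from above, the rounding cross-entropy that ties continuous states back to tokens, and a prior term that penalizes out_mean = √ᾱ_T·x_0 so that x_T ends up close to the N(0, I) prior.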
def token_discrete_loss(self, x_t, get_logits, input_ids):
if self.model_arch == 'conv-unet' or self.model_arch == '1d-unet':
reshaped_x_t = x_t.view(x_t.size(0), x_t.size(1), -1).permute(0, 2, 1)
**else:
# print(x_t.shape)
reshaped_x_t = x_t**
logits = get_logits(reshaped_x_t) # bsz, seqlen, vocab
# print(logits.shape)
loss_fct = th.nn.CrossEntropyLoss(reduction='none')
decoder_nll = loss_fct(logits.view(-1, logits.size(-1)), input_ids.view(-1)).view(input_ids.shape)
# print(decoder_nll.shape)
decoder_nll = decoder_nll.mean(dim=-1)
return decoder_nll
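A minimal standalone sketch of this rounding loss with the weight tying from TransformerNetModel2 (vocab 11043 and dim 128 match this run; batch and sequence sizes are made up):
import torch as th

vocab, dim = 11043, 128
word_embedding = th.nn.Embedding(vocab, dim)
lm_head = th.nn.Linear(dim, vocab)
with th.no_grad():
    lm_head.weight = word_embedding.weight   # tied, as in TransformerNetModel2

input_ids = th.randint(0, vocab, (2, 64))    # bsz=2, seqlen=64
logits = lm_head(word_embedding(input_ids))  # bsz, seqlen, vocab
loss_fct = th.nn.CrossEntropyLoss(reduction='none')
decoder_nll = loss_fct(logits.view(-1, vocab), input_ids.view(-1)).view(input_ids.shape).mean(dim=-1)
print(decoder_nll.shape)                     # torch.Size([2]) -- one NLL per sequence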
improved-diffusion/improved_diffusion/transformer_model2.py
def get_logits(self, hidden_repr):
if self.logits_mode == 1:
return self.lm_head(hidden_repr)
back to train_util.py
compute_losses = functools.partial(
self.diffusion.training_losses,
self.ddp_model,
micro,
t,
model_kwargs=micro_cond,
)
if last_batch or not self.use_ddp:
losses = compute_losses()
else:
with self.ddp_model.no_sync():
losses = compute_losses()
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
**loss = (losses["loss"] * weights).mean()**
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
**else:
loss.backward()**
train_util.py
def run_step(self, batch, cond):
self.forward_backward(batch, cond)
if self.use_fp16:
self.optimize_fp16()
else:
**self.optimize_normal()**
def optimize_normal(self):
if self.gradient_clipping > 0:
self.grad_clip()
self._log_grad_norm()
**self._anneal_lr()**
self.opt.step()
for rate, params in zip(self.ema_rate, self.ema_params):
update_ema(params, self.master_params, rate=rate)
def _anneal_lr(self):
if not self.lr_anneal_steps:
return
frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
lr = self.lr * (1 - frac_done)
for param_group in self.opt.param_groups:
param_group["lr"] = lr
# in nn.py
def update_ema(target_params, source_params, rate=0.99):
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
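With rate = 0.9999 (the ema_rate above), each call nudges the EMA copy by 0.01% toward the live weights: targ ← 0.9999·targ + 0.0001·src, which is exactly what mul_(rate).add_(src, alpha=1 - rate) computes.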
evaluation
train_util.py
if self.eval_data is not None and self.step % self.eval_interval == 0:
batch_eval, cond_eval = next(self.eval_data)
**self.forward_only(batch, cond)**
print('eval on validation set')
logger.dumpkvs()
save path : diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e
Can our case be applied in the decoder part?
decoder part
run this
python scripts/text_sample.py --model_path /home/doolee13/Diffusion-LM/improved-diffusion/diffusion_models/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd101_xstart_e2e/ema_0.9999_001000.pt --batch_size 50 --num_samples 50 --top_p -1.0 --out_dir generation_outputs
improved-diffusion/scripts/text_sample.py
if args.experiment == 'random1': args.experiment = 'random'
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log(f'the parameter count is {pytorch_total_params}')
# diffusion.rescale_timesteps = False # DEBUG --> REMOVE
print(diffusion.rescale_timesteps, 'a marker for whether we are in the debug mode')
model.to(dist_util.dev())
model.eval() # DEBUG
model2, tokenizer = load_models(args.modality, args.experiment, args.model_name_or_path, args.in_channel,
os.path.split(args.model_path)[0])
**if args.training_mode.startswith('e2e'):
print('e2e, load the right model embeddings', '*'*80)
model2.weight = th.nn.Parameter(model.word_embedding.weight.clone().cpu())**
while len(all_images) * args.batch_size < args.num_samples:
model_kwargs = {}
~~if args.experiment_mode == 'conditional_gen':
batch, model_kwargs = next(data)
model_kwargs.pop('input_ids')
if args.mbr_sample > 1:
model_kwargs = {k: v.to(dist_util.dev()).repeat_interleave(args.mbr_sample, dim=0) for k, v in model_kwargs.items()}
else:
model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
print([(k, v.shape) for (k,v) in model_kwargs.items()])~~
~~if args.class_cond:
classes = th.randint(
low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
)
model_kwargs["y"] = classes~~
sample_fn = (
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
)
if args.model_arch == '1d-unet':
if args.mbr_sample > 1 and args.experiment_mode == 'conditional_gen':
sample_shape = (args.batch_size * args.mbr_sample, args.in_channel, args.image_size ** 2)
else:
sample_shape = (args.batch_size, args.in_channel, args.image_size ** 2)
**else:**
if args.mbr_sample > 1 and args.experiment_mode == 'conditional_gen':
sample_shape = (args.batch_size * args.mbr_sample, args.image_size ** 2, args.in_channel)
**else:
sample_shape = (args.batch_size, args.image_size ** 2, args.in_channel)**
print(sample_shape)
**sample = sample_fn(
model,
sample_shape,
clip_denoised=args.clip_denoised,
denoised_fn=partial(denoised_fn_round, args, model3.cuda()) if args.clamp == 'clamp' else None,
model_kwargs=model_kwargs,
top_p =args.top_p,
)**
if args.model_arch == '1d-unet':
print(sample.shape)
sample = sample.permute(0, 2, 1)
print(sample.shape)
gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
if args.class_cond:
gathered_labels = [
th.zeros_like(classes) for _ in range(dist.get_world_size())
]
dist.all_gather(gathered_labels, classes)
all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
logger.log(f"created {len(all_images) * args.batch_size} samples")
Calling the sample function runs p_sample_loop in gaussian_diffusion.py.
def p_sample_loop():
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
top_p=top_p,
):
final = sample
return final["sample"]
def p_sample_loop_progressive():
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
**for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
top_p=top_p,
)
yield out
img = out["sample"]**
def p_sample():
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
def p_mean_variance():
else:
B, C = x.size(0), x.size(-1)
assert t.shape == (B,)
# print(x.shape)
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
**np.append(self.posterior_variance[1], self.betas[1:]),
np.log(np.append(self.posterior_variance[1], self.betas[1:])),**
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
**model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)**
def process_xstart(x):
if denoised_fn is not None:
# print(denoised_fn)
x = denoised_fn(x, t)
if clip_denoised:
return x.clamp(-1, 1)
return x
# in improved-diffusion/improved_diffusion/gaussian_diffusion.py
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance():
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
}
back to p_sample()
def p_sample():
else:
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"],
'greedy_mean':out["mean"], 'out':out}
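So each reverse step draws x_{t-1} = mean + exp(0.5·log_variance)·ε, with the noise masked out at t == 0 so the final step returns the posterior mean itself.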
back to p_sample_loop_progressive
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
top_p=top_p,
)
yield out
img = out["sample"]
Separately, a recap of the sampling process:
q(x_{t-1} | x_t, x_0) is the true posterior and contains no neural-network parameters, so to evaluate it we first need x_0.
In this paper the model predicts x_0 directly rather than epsilon, so x_0 is obtained by feeding (x_t, t) to the model.
Using the current x_t and the predicted x_0, evaluate q(x_{t-1} | x_t, x_0) to get the mean.
The parameterized mean μ_θ thus comes from q_posterior (it involves parameters only because the model produced the x_0 prediction), while the variance is a fixed constant determined by t, as in the FIXED_LARGE branch of p_mean_variance above.