DDIMsampler코드
DDIM : Training 과정 중 noising - denoising 과정에서 Non-Markov chain
기존 DDPM에선 denoising에서 베이지안룰 과 마르코프 체인을 이용해서 모든 스텝t를 사용했지만 DDIM에선 t를 위해 x0과 xt-1만 사용
보통 ddim_sampler = DDIMSampler(model) 같이 model만 넘겨주는 듯
또한 보통 (B H W C)의 이미지를 (B C H W)이미지로 변환 후 모델로 넘겨야함
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
ddim_sampler 객체 생성 후
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
와 같이 바로 sample 함수로 이미지를 뽑음
역할은 샘플링 과정을 수행하기 전 입력 인자 확인 & 스케줄을 생성함
위의 sample함수에서 스케줄러를 생성하고 , 배치 크기 등 인자를 확인하고 실행하는 함수
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
#샘플링 과정에서 사용할 device정의 betas속성에 위치한 device불러옴
b = shape[0]
#shape는 size(B C H W)이므로 b는 batch
if x_T is None:
img = torch.randn(shape, device=device)
#x_T가 존재하지 않으면 정규 분포에서 무작위로 생성
else:
img = x_T
#img는 초기이미지
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
#샘플링 과정에서 사용할 timestep설정
#timestep에 따라 생성된 이미지의 특성이 달라질 수 있음
#여기선 미리 정의해둔 timestep이아니면 사용하지 않는듯
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
#학습시작
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
#mask는 igm와 img_orig는 일부분 섞어주는 역할
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
정리 : sampling이란 과정은 모델의 Inference과정이라고 할 수 있다.
따라서 이미 학습된 모델들을 사용하기만 하며 Inference과정에는 파라미터들로 적절한 timestep , iterator을 정한후 p_sampler , q_sampler를 이용함
self.p_sample_ddim() : denoising process과정으로 노이즈 이미지에서 원하는 이미지로 복원
self.q_sample(x0,ts) : (????아직은 뇌피셜) q과정이 noising process임 x0은 원본이미지 , ts변수는 noise를 추가하는 스텝의 수 이과정이 latent diffusuion의 부분인듯