DDIM_sampler 코드분석

안민기·2023년 3월 7일
0

ControNet

목록 보기
2/2

DDIMsampler코드
DDIM : Training 과정 중 noising - denoising 과정에서 Non-Markov chain
기존 DDPM에선 denoising에서 베이지안룰 과 마르코프 체인을 이용해서 모든 스텝t를 사용했지만 DDIM에선 t를 위해 x0과 xt-1만 사용

클래스 선언

보통 ddim_sampler = DDIMSampler(model) 같이 model만 넘겨주는 듯
또한 보통 (B H W C)의 이미지를 (B C H W)이미지로 변환 후 모델로 넘겨야함

class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

주요 함수들

sample()

ddim_sampler 객체 생성 후

samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)

와 같이 바로 sample 함수로 이미지를 뽑음

역할은 샘플링 과정을 수행하기 전 입력 인자 확인 & 스케줄을 생성함

ddim_sampling()

위의 sample함수에서 스케줄러를 생성하고 , 배치 크기 등 인자를 확인하고 실행하는 함수

def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
                      
        device = self.model.betas.device
        	#샘플링 과정에서 사용할 device정의 betas속성에 위치한 device불러옴
        b = shape[0]
        	#shape는 size(B C H W)이므로 b는 batch
            
        if x_T is None:
            img = torch.randn(shape, device=device)
            	#x_T가 존재하지 않으면 정규 분포에서 무작위로 생성
        else:
            img = x_T
        #img는 초기이미지

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]
            #샘플링 과정에서 사용할 timestep설정
            #timestep에 따라 생성된 이미지의 특성이 달라질 수 있음
            #여기선 미리 정의해둔 timestep이아니면 사용하지 않는듯

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
		
        #학습시작
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img
            #mask는 igm와 img_orig는 일부분 섞어주는 역할
          

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

정리 : sampling이란 과정은 모델의 Inference과정이라고 할 수 있다.
따라서 이미 학습된 모델들을 사용하기만 하며 Inference과정에는 파라미터들로 적절한 timestep , iterator을 정한후 p_sampler , q_sampler를 이용함

self.p_sample_ddim() : denoising process과정으로 노이즈 이미지에서 원하는 이미지로 복원
self.q_sample(x0,ts) : (????아직은 뇌피셜) q과정이 noising process임 x0은 원본이미지 , ts변수는 noise를 추가하는 스텝의 수 이과정이 latent diffusuion의 부분인듯

profile
Trendy AI Developer

0개의 댓글