NERF 코드 분석 4.run_nerf_helper

코드짜는침팬지·2023년 9월 17일
0

학부 연구생

목록 보기
4/10

이제 직접 nerf를 돌리는데 필요한 함수들을 차근차근 살펴보자

import torch
# torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)

Utility Functions:

img2mse: 이미지 간의 Mean Squared Error (MSE)를 계산한다.
mse2psnr: MSE 값을 Peak Signal-to-Noise Ratio (PSNR)로 변환한다.

여기서 PSNR (Peak Signal-to-Noise Ratio)은 두 이미지 간의 차이를 측정하는 데 사용되는 표준적인 메트릭인데,
특히 원본 이미지와 압축 또는 복원된 이미지 간의 차이를 측정하는 데 주로 사용된다.
PSNR은 높을수록 이미지의 품질이 더 좋다는 것을 뜻한다.

PSNR은 다음과 같은 수식으로 정의된다:

PSNR=10×log10(MAX2MSE)\text{PSNR} = 10 \times \log_{10} \left( \frac{\text{MAX}^2}{\text{MSE}} \right)

여기서:

  • MAX\text{MAX} 는 가능한 최대 픽셀 값인데 예를 들어, 8비트 이미지의 경우
    MAX=255\text{MAX} = 255 이며
  • MSE\text{MSE} 는 Mean Squared Error로, 두 이미지 간의 평균 제곱 오차다.

PSNR은 주로 dB (데시벨) 단위로 표시된다.
높은 PSNR 값은 두 이미지 간의 차이가 작다는 것을 나타내며,
낮은 PSNR 값은 두 이미지 간의 차이가 크다는 것을 나타낸다.
그러나 PSNR만으로 이미지 품질을 완전히 판단하는 것은 어려운데,

PSNR 값을 가진 이미지가 우리 눈에는 낮은 품질로 보일 수 있으니,
PSNR 외에도 다른 품질 메트릭을 함께 고려 하는것이 좋다.

to8b: 이미지를 8비트로 변환한다.

Positional Encoding (Embedder 클래스):

다음 파트는 입력 데이터에 대한 위치 인코딩을 수행한다.
위치 인코딩은 신경망이 주기적 함수 (예: sin, cos)를 사용하여 고차원 공간에서 패턴을 학습할 수 있도록 돕는다.
조금 더 자세하게 들어가보자.

# Positional encoding (section 5.1)
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()
  1. Embedder 클래스:
    • 위치 인코딩을 위한 함수들을 생성하고 저장하는 class를 생성한다.
  2. __init__(self, **kwargs):
    • 초기화 함수에서는 인수로 받은 kwargs를 클래스 변수로 저장하고, create_embedding_fn 메서드를 호출하여 위치 인코딩 함수들을 생성한다.

여기서 **kwargs 를 몇몇 파이썬 코드에서 봤을 것이다.
이에 대해 좀 더 설명 해보자면

**kwargs는 Python에서 함수나 클래스의 메서드에 임의의 개수의 키워드 인자를 전달할 때 사용하는 구문이다.
"kwargs"는 "keyword arguments"의 줄임말이며, 함수 내부에서는 kwargs가 딕셔너리로 처리된다.

**kwargs는 보통 아래와 같은 상황에서 유용하게 쓰이는데:

  1. 임의의 키워드 인자 수용: 함수나 메서드가 예상하지 못한 추가적인 키워드 인자를 받아들일 수 있게 해준다.

예를들어

def example_function(**kwargs):
    print(kwargs)

example_function(a=1, b=2, c=3)  # 출력: {'a': 1, 'b': 2, 'c': 3}

다른 프로그래밍 언어가 보면 기겁 할 만한 코드를 쓸 수 있다.

  1. 딕셔너리 언패킹: 또한 딕셔너리를 **를 사용하여 함수의 인자로 언패킹할 수 있다.

    def print_data(name, age):
        print(f"Name: {name}, Age: {age}")
    
    data_dict = {"name": "Alice", "age": 30}
    print_data(**data_dict)  # 출력: Name: Alice, Age: 30
  2. 함수나 메서드의 유연성 증가: **kwargs를 사용하면 함수나 메서드가 더 유연해지는데, 즉, 나중에 인자를 추가하거나 변경해도 기존의 코드를 수정하지 않아도 됩니다.

  3. 함수 데코레이터, 프레임워크, 라이브러리에서의 활용: 이러한 강력한 기능 덕분에 **kwargs는 함수 데코레이터나 프레임워크, 라이브러리에서 자주 사용된다. 이를 통해 사용자가 제공하는 임의의 인자를 수용하거나,
    내부적으로 특정 인자를 추가/제거할 수 있다.

이 코드의 Embedder 클래스에서 **kwargs는 생성자에 전달된 모든 키워드 인자를 self.kwargs 딕셔너리에 저장하는 역할을 해준다.
이렇게 함으로써 클래스의 다른 메서드에서 이 인자들에 접근할 수 있게 된다.

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            out_dim += d
            
        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
        
        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
            
        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d
                    
        self.embed_fns = embed_fns
        self.out_dim = out_dim
  1. create_embedding_fn(self):

    • 실제로 위치 인코딩 함수들을 생성

    • embed_fns: 생성된 함수들을 저장할 리스트

    • d: 입력 차원

    • include_input: True인 경우, 입력을 그대로 반환하는 함수를 추가

    • max_freq: 주파수의 최대 로그 값

    • N_freqs: 사용할 주파수의 수

    • log_sampling: True인 경우, 로그 스케일로 주파수 대역을 생성

    • 각 주파수 대역에 대해, 주어진 주기 함수 (periodic_fns, 예: sin, cos)를 사용하여 위치 인코딩 함수를 생성하고 embed_fns에 추가한다.

        
    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

  1. embed(self, inputs):
    • 주어진 입력에 대해 모든 위치 인코딩 함수를 적용하고,
      결과를 연결하여 반환한다.
def get_embedder(multires, i=0):
    if i == -1:
        return nn.Identity(), 3
    
    embed_kwargs = {
                'include_input' : True,
                'input_dims' : 3,
                'max_freq_log2' : multires-1,
                'num_freqs' : multires,
                'log_sampling' : True,
                'periodic_fns' : [torch.sin, torch.cos],
    }
    
    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim
  1. get_embedder(multires, i=0):

    • 주어진 해상도 (multires)와 인덱스 i에 대한 위치 인코딩 함수를 반환한다.

    • 만약 i가 -1인 경우, 아무런 인코딩 없이 입력을 그대로 반환하는 함수를 반환한다.

    • 그렇지 않을 시, Embedder 객체를 생성하고 해당 객체의 embed 메서드를 반환한다.

즉 이 부분은 3D 좌표를 입력으로 받아, 주기 함수를 사용하여 해당 좌표를 더 높은 차원의 공간으로 매핑하는 위치 인코딩을 수행한다.
이렇게 함으로써 신경망이 공간의 복잡한 패턴과 관계를 더 잘 학습할 수 있게 된다.


숭배합니다, GOAT

모델 분석

자 이제 핵심 파트인 모델 분석이다.

여기선 파이토치를 처음 접한다고 생각하고 천천히 분석 해보겠다.


# Model
class NeRF(nn.Module):
  • NeRF라는 클래스를 정의하며, PyTorch의 기본 모듈인 nn.Module을 상속받는다.
    이를 통해 PyTorch의 다양한 기능(예: 자동 미분)을 사용할 수 있게 된다.
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        """ 
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs
  • 클래스의 생성자를 정의, 부모 클래스인 nn.Module의 생성자를 호출한다.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
        
        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
  • pts_linears는 입력 데이터를 처리하는 선형 레이어의 리스트로,
    skips에 지정된 인덱스에서는 입력 데이터를 추가로 받는다.
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
  • views_linears는 뷰 방향 정보를 처리하는 선형 레이어의 리스트다.
        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
        
        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)
  • use_viewdirs 플래그에 따라 추가적인 레이어들을 정의한다.
    use_viewdirs가 참이면 뷰 방향 정보를 사용하여 RGB와 투명도를 계산한다.
    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)
        
            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return outputs    
  • 이 부분에서는 입력 x 를 받아 위치 정보와 뷰 방향 정보로 분리하고
    이 정보를 처리하는 레이어들을 순차적으로 통과 시킨다.
    skips 에 지정된 인덱스에선 입력 위치 정보를 추가로 연결한다.

  • use_viewdirs 플래그에 따라 뷰 방향 정보를 처리하거나 바로 출력을 반환한다.

    def load_weights_from_keras(self, weights):
        assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
        
        # Load pts_linears
        for i in range(self.D):
            idx_pts_linears = 2 * i
            self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))    
            self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
        
        # Load feature_linear
        idx_feature_linear = 2 * self.D
        self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
        self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))

        # Load views_linears
        idx_views_linears = 2 * self.D + 2
        self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
        self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))

        # Load rgb_linear
        idx_rbg_linear = 2 * self.D + 4
        self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
        self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))

        # Load alpha_linear
        idx_alpha_linear = 2 * self.D + 6
        self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
        self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
  • 이후 Keras에서 학습된 가중치를 로드하는데,
    Keras 모델에서 저장된 가중치를 PyTorch 모델에 로드하는 과정을 나타낸다.
    각 레이어의 가중치와 편향을 순차적으로 로드한다.

helper 함수

nerf의 뼈대가 되는 pytorch 함수는 위에 적혀있는 대로이고
아래는 트레이닝을 위한 helper 함수이다.

# Ray helpers
def get_rays(H, W, K, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3,-1].expand(rays_d.shape)
    return rays_o, rays_d
  1. get_rays(H, W, K, c2w)
    • 입력:
      • H, W: 이미지의 높이와 너비
      • K: 카메라의 intrinsic matrix
      • c2w: 카메라에서 월드 좌표계로의 변환 행렬
    • 역할:
      • 이미지의 각 픽셀에 대한 광선(ray)의 원점과 방향 계산
    • 출력:
      • rays_o: 각 광선의 원점
      • rays_d: 각 광선의 방향
def get_rays_np(H, W, K, c2w):
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
    dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
    return rays_o, rays_d
  1. get_rays_np(H, W, K, c2w)
    • get_rays와 동일한 기능을 하지만, numpy를 사용하여 계산한다.
def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[...,2]) / rays_d[...,2]
    rays_o = rays_o + t[...,None] * rays_d
    
    # Projection
    o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
    o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
    o2 = 1. + 2. * near / rays_o[...,2]

    d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
    d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
    d2 = -2. * near / rays_o[...,2]
    
    rays_o = torch.stack([o0,o1,o2], -1)
    rays_d = torch.stack([d0,d1,d2], -1)
    
    return rays_o, rays_d
  1. ndc_rays(H, W, focal, near, rays_o, rays_d)
    • 입력:
      • H, W: 이미지의 높이와 너비
      • focal: 초점 거리
      • near: 광선을 추적하기 시작하는 거리
      • rays_o, rays_d: 광선의 원점과 방향
    • 역할:
      • 광선을 NDC (Normalized Device Coordinates)로 변환한다.
        이는 광선 추적을 위한 좌표계 변환이다.
    • 출력:
      • 변환된 광선의 원점과 방향
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5 # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

    return samples
  1. sample_pdf(bins, weights, N_samples, det=False, pytest=False)
    • 입력:
      • bins: 샘플링할 구간
      • weights: 각 구간의 가중치
      • N_samples: 샘플의 수
      • det: 결정론적 샘플링을 할지 여부
      • pytest: 테스트를 위한 플래그
    • 역할:
      • 주어진 가중치를 기반으로 PDF (Probability Density Function)를 생성하고,
        이를 사용하여 CDF (Cumulative Distribution Function)를 생성한다.
      • CDF를 기반으로 샘플을 추출합니다.
        이는 광선 추적에서 광선을 따라 특정 거리에서 샘플링을 수행할 때 사용된다.
    • 출력:
      • 샘플링된 값들
  • 이 부분은 NeRF의 핵심 부분 중 하나인 광선 추적과 관련된 도우미 함수이다.
    광선 추적은 3D 환경에서 광선의 경로를 시뮬레이션하여 이미지를 생성하는 기술이며
    NeRF는 이 광선 추적 기술을 사용하여 3D 환경을 묘사하는 신경망을 학습한다.
profile
학과 꼴찌 공대 호소인

0개의 댓글