이제 직접 nerf를 돌리는데 필요한 함수들을 차근차근 살펴보자
import torch
# torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
img2mse: 이미지 간의 Mean Squared Error (MSE)를 계산한다.
mse2psnr: MSE 값을 Peak Signal-to-Noise Ratio (PSNR)로 변환한다.
여기서 PSNR (Peak Signal-to-Noise Ratio)은 두 이미지 간의 차이를 측정하는 데 사용되는 표준적인 메트릭인데,
특히 원본 이미지와 압축 또는 복원된 이미지 간의 차이를 측정하는 데 주로 사용된다.
PSNR은 높을수록 이미지의 품질이 더 좋다는 것을 뜻한다.
PSNR은 다음과 같은 수식으로 정의된다:
여기서:
PSNR은 주로 dB (데시벨) 단위로 표시된다.
높은 PSNR 값은 두 이미지 간의 차이가 작다는 것을 나타내며,
낮은 PSNR 값은 두 이미지 간의 차이가 크다는 것을 나타낸다.
그러나 PSNR만으로 이미지 품질을 완전히 판단하는 것은 어려운데,
PSNR 값을 가진 이미지가 우리 눈에는 낮은 품질로 보일 수 있으니,
PSNR 외에도 다른 품질 메트릭을 함께 고려 하는것이 좋다.
to8b: 이미지를 8비트로 변환한다.
다음 파트는 입력 데이터에 대한 위치 인코딩을 수행한다.
위치 인코딩은 신경망이 주기적 함수 (예: sin, cos)를 사용하여 고차원 공간에서 패턴을 학습할 수 있도록 돕는다.
조금 더 자세하게 들어가보자.
# Positional encoding (section 5.1)
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
__init__(self, **kwargs)
:kwargs
를 클래스 변수로 저장하고, create_embedding_fn
메서드를 호출하여 위치 인코딩 함수들을 생성한다.여기서 **kwargs 를 몇몇 파이썬 코드에서 봤을 것이다.
이에 대해 좀 더 설명 해보자면
**kwargs
는 Python에서 함수나 클래스의 메서드에 임의의 개수의 키워드 인자를 전달할 때 사용하는 구문이다.
"kwargs"는 "keyword arguments"의 줄임말이며, 함수 내부에서는 kwargs
가 딕셔너리로 처리된다.
**kwargs
는 보통 아래와 같은 상황에서 유용하게 쓰이는데:
예를들어
def example_function(**kwargs):
print(kwargs)
example_function(a=1, b=2, c=3) # 출력: {'a': 1, 'b': 2, 'c': 3}
다른 프로그래밍 언어가 보면 기겁 할 만한 코드를 쓸 수 있다.
딕셔너리 언패킹: 또한 딕셔너리를 **
를 사용하여 함수의 인자로 언패킹할 수 있다.
def print_data(name, age):
print(f"Name: {name}, Age: {age}")
data_dict = {"name": "Alice", "age": 30}
print_data(**data_dict) # 출력: Name: Alice, Age: 30
함수나 메서드의 유연성 증가: **kwargs
를 사용하면 함수나 메서드가 더 유연해지는데, 즉, 나중에 인자를 추가하거나 변경해도 기존의 코드를 수정하지 않아도 됩니다.
함수 데코레이터, 프레임워크, 라이브러리에서의 활용: 이러한 강력한 기능 덕분에 **kwargs
는 함수 데코레이터나 프레임워크, 라이브러리에서 자주 사용된다. 이를 통해 사용자가 제공하는 임의의 인자를 수용하거나,
내부적으로 특정 인자를 추가/제거할 수 있다.
이 코드의 Embedder
클래스에서 **kwargs
는 생성자에 전달된 모든 키워드 인자를 self.kwargs
딕셔너리에 저장하는 역할을 해준다.
이렇게 함으로써 클래스의 다른 메서드에서 이 인자들에 접근할 수 있게 된다.
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
create_embedding_fn(self)
:
실제로 위치 인코딩 함수들을 생성
embed_fns
: 생성된 함수들을 저장할 리스트
d
: 입력 차원
include_input
: True인 경우, 입력을 그대로 반환하는 함수를 추가
max_freq
: 주파수의 최대 로그 값
N_freqs
: 사용할 주파수의 수
log_sampling
: True인 경우, 로그 스케일로 주파수 대역을 생성
각 주파수 대역에 대해, 주어진 주기 함수 (periodic_fns
, 예: sin, cos)를 사용하여 위치 인코딩 함수를 생성하고 embed_fns
에 추가한다.
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
embed(self, inputs)
:def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : True,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
get_embedder(multires, i=0)
:
주어진 해상도 (multires
)와 인덱스 i
에 대한 위치 인코딩 함수를 반환한다.
만약 i
가 -1인 경우, 아무런 인코딩 없이 입력을 그대로 반환하는 함수를 반환한다.
그렇지 않을 시, Embedder
객체를 생성하고 해당 객체의 embed
메서드를 반환한다.
즉 이 부분은 3D 좌표를 입력으로 받아, 주기 함수를 사용하여 해당 좌표를 더 높은 차원의 공간으로 매핑하는 위치 인코딩을 수행한다.
이렇게 함으로써 신경망이 공간의 복잡한 패턴과 관계를 더 잘 학습할 수 있게 된다.
숭배합니다, GOAT
자 이제 핵심 파트인 모델 분석이다.
여기선 파이토치를 처음 접한다고 생각하고 천천히 분석 해보겠다.
# Model
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
"""
"""
super(NeRF, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
### Implementation according to the paper
# self.views_linears = nn.ModuleList(
# [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W//2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
else:
outputs = self.output_linear(h)
return outputs
이 부분에서는 입력 x 를 받아 위치 정보와 뷰 방향 정보로 분리하고
이 정보를 처리하는 레이어들을 순차적으로 통과 시킨다.
skips 에 지정된 인덱스에선 입력 위치 정보를 추가로 연결한다.
use_viewdirs 플래그에 따라 뷰 방향 정보를 처리하거나 바로 출력을 반환한다.
def load_weights_from_keras(self, weights):
assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
# Load pts_linears
for i in range(self.D):
idx_pts_linears = 2 * i
self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1]))
# Load feature_linear
idx_feature_linear = 2 * self.D
self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1]))
# Load views_linears
idx_views_linears = 2 * self.D + 2
self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1]))
# Load rgb_linear
idx_rbg_linear = 2 * self.D + 4
self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1]))
# Load alpha_linear
idx_alpha_linear = 2 * self.D + 6
self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1]))
nerf의 뼈대가 되는 pytorch 함수는 위에 적혀있는 대로이고
아래는 트레이닝을 위한 helper 함수이다.
# Ray helpers
def get_rays(H, W, K, c2w):
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
return rays_o, rays_d
H
, W
: 이미지의 높이와 너비K
: 카메라의 intrinsic matrixc2w
: 카메라에서 월드 좌표계로의 변환 행렬rays_o
: 각 광선의 원점rays_d
: 각 광선의 방향def get_rays_np(H, W, K, c2w):
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
return rays_o, rays_d
get_rays
와 동일한 기능을 하지만, numpy를 사용하여 계산한다.def ndc_rays(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = -(near + rays_o[...,2]) / rays_d[...,2]
rays_o = rays_o + t[...,None] * rays_d
# Projection
o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
o2 = 1. + 2. * near / rays_o[...,2]
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
d2 = -2. * near / rays_o[...,2]
rays_o = torch.stack([o0,o1,o2], -1)
rays_d = torch.stack([d0,d1,d2], -1)
return rays_o, rays_d
H
, W
: 이미지의 높이와 너비focal
: 초점 거리near
: 광선을 추적하기 시작하는 거리rays_o
, rays_d
: 광선의 원점과 방향# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
return samples
bins
: 샘플링할 구간weights
: 각 구간의 가중치N_samples
: 샘플의 수det
: 결정론적 샘플링을 할지 여부pytest
: 테스트를 위한 플래그