Neural graphics primitives, parameterized by fully connected neural networks, can be costly to train and evaluate. We reduce this cost with a versatile new input encoding that permits the use of a smaller network without sacrificing quality, thus significantly reducing the number of floating point and memory access operations: a small neural network is augmented by a multiresolution hash table of trainable feature vectors whose values are optimized through stochastic gradient descent. The multiresolution structure allows the network to disambiguate hash collisions, making for a simple architecture that is trivial to parallelize on modern GPUs. We leverage this parallelism by implementing the whole system using fully-fused CUDA kernels with a focus on minimizing wasted bandwidth and compute operations. We achieve a combined speedup of several orders of magnitude, enabling training of high-quality neural graphics primitives in a matter of seconds, and rendering in tens of milliseconds at a resolution of ${1920\!\times\!1080}$.
Source: Instant Neural Graphics Primitives with a Multiresolution Hash Encoding (Müller et al., 2022-01-16).
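Before the full script, here is a minimal, self-contained sketch of the multiresolution hash encoding described above, assuming the paper's spatial hash h(v) = (v_1 * pi_1 xor v_2 * pi_2 xor v_3 * pi_3) mod T: each of L levels owns a table of T trainable F-dimensional feature vectors, a query point is hashed at that level's grid resolution, and the trilinearly interpolated corner features of all levels are concatenated into the input of a small MLP. The names (spatial_hash, encode) and the explicit corner loop are illustrative choices for this sketch, not the paper's reference implementation; the script below performs the same interpolation with grid_sample instead.

import torch

# Illustrative hyperparameters (same values as the script below).
T, F, L = 2**19, 2, 16
N_min, N_max = 16, 2048
b = (N_max / N_min) ** (1.0 / (L - 1))                   # per-level growth factor
Nl = [int(N_min * b**l) for l in range(L)]               # grid resolution of each level
tables = [torch.randn(T, F) * 1e-4 for _ in range(L)]    # trainable nn.Parameters in practice
primes = (1, 2_654_435_761, 805_459_861)

def spatial_hash(ix, iy, iz):
    # h(v) = (v_1 * pi_1 xor v_2 * pi_2 xor v_3 * pi_3) mod T, as in the paper.
    return ((ix * primes[0]) ^ (iy * primes[1]) ^ (iz * primes[2])) % T

def encode(x):
    # Hash-encode points x in [0, 1]^3, shape (B, 3) -> (B, L * F).
    feats = []
    for level, N in enumerate(Nl):
        scaled = x * N
        floor = scaled.floor().long()
        frac = scaled - floor                            # trilinear interpolation weights
        acc = torch.zeros(x.shape[0], F)
        for dx in (0, 1):                                # visit the 8 corners of the voxel
            for dy in (0, 1):
                for dz in (0, 1):
                    corner = floor + torch.tensor([dx, dy, dz])
                    idx = spatial_hash(corner[:, 0], corner[:, 1], corner[:, 2])
                    w = ((frac[:, 0] if dx else 1 - frac[:, 0])
                         * (frac[:, 1] if dy else 1 - frac[:, 1])
                         * (frac[:, 2] if dz else 1 - frac[:, 2]))
                    acc += w.unsqueeze(1) * tables[level][idx]
        feats.append(acc)
    return torch.cat(feats, dim=1)                       # input to the density MLP

print(encode(torch.rand(4, 3)).shape)                    # torch.Size([4, 32])

The complete NeRF training and rendering script follows.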
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

@torch.no_grad()
def test(hn, hf, dataset, img_index, chunk_size=20, nb_bins=192, H=400, W=400):
    # Render one held-out image chunk by chunk of rows; relies on the globally
    # defined `model` and `device` from the __main__ block below.
    ray_origins = dataset[img_index * H * W: (img_index + 1) * H * W, :3]
    ray_directions = dataset[img_index * H * W: (img_index + 1) * H * W, 3:6]

    px_values = []  # list of regenerated pixel values
    for i in range(int(np.ceil(H / chunk_size))):  # iterate over chunks
        ray_origins_ = ray_origins[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        ray_directions_ = ray_directions[i * W * chunk_size: (i + 1) * W * chunk_size].to(device)
        px_values.append(render_rays(model, ray_origins_, ray_directions_,
                                     hn=hn, hf=hf, nb_bins=nb_bins))

    img = torch.cat(px_values).data.cpu().numpy().reshape(H, W, 3)
    img = (img.clip(0, 1) * 255.).astype(np.uint8)
    img = Image.fromarray(img)
    os.makedirs('novel_views', exist_ok=True)  # make sure the output folder exists
    img.save(f'novel_views/img_{img_index}.png')

class NGP(torch.nn.Module):
    def __init__(self, T, Nl, L, device, aabb_scale, F=2):
        super(NGP, self).__init__()
        self.T = T                    # hash table size per level
        self.Nl = Nl                  # grid resolution of each level
        self.F = F                    # feature dimension per level
        self.L = L                    # number of frequencies for encoding directions
        self.aabb_scale = aabb_scale  # scene bounding-box scale
        # One trainable hash table of shape (T, 2) per resolution level,
        # initialized uniformly in [-1e-4, 1e-4].
        self.lookup_tables = torch.nn.ParameterDict(
            {str(i): torch.nn.Parameter((torch.rand(
                (T, 2), device=device) * 2 - 1) * 1e-4) for i in range(len(Nl))})
        # Constants of the spatial hash function (pi1 = 1, as in the paper).
        self.pi1, self.pi2, self.pi3 = 1, 2_654_435_761, 805_459_861
        self.density_MLP = nn.Sequential(nn.Linear(self.F * len(Nl), 64),
                                         nn.ReLU(), nn.Linear(64, 16)).to(device)
        self.color_MLP = nn.Sequential(nn.Linear(27 + 16, 64), nn.ReLU(),
                                       nn.Linear(64, 64), nn.ReLU(),
                                       nn.Linear(64, 3), nn.Sigmoid()).to(device)

    def positional_encoding(self, x):
        # Frequency encoding of the viewing direction: 3 + 2 * 3 * L dimensions.
        out = [x]
        for j in range(self.L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        return torch.cat(out, dim=1)

    def forward(self, x, d):
        x /= self.aabb_scale
        mask = (x[:, 0].abs() < .5) & (x[:, 1].abs() < .5) & (x[:, 2].abs() < .5)
        x += 0.5  # x in [0, 1]^3
        color = torch.zeros((x.shape[0], 3), device=x.device)
        log_sigma = torch.zeros((x.shape[0]), device=x.device) - 100000
        features = torch.empty((x[mask].shape[0], self.F * len(self.Nl)), device=x.device)
        for i, N in enumerate(self.Nl):
            # Compute the 8 corners of the voxel containing each point,
            # ordered to match the nn.functional.grid_sample convention.
            floor = torch.floor(x[mask] * N)
            ceil = torch.ceil(x[mask] * N)
            vertices = torch.zeros((x[mask].shape[0], 8, 3), dtype=torch.int64, device=x.device)
            vertices[:, 0] = floor
            vertices[:, 1] = torch.cat((ceil[:, 0, None], floor[:, 1, None], floor[:, 2, None]), dim=1)
            vertices[:, 2] = torch.cat((floor[:, 0, None], ceil[:, 1, None], floor[:, 2, None]), dim=1)
            vertices[:, 3] = torch.cat((ceil[:, 0, None], ceil[:, 1, None], floor[:, 2, None]), dim=1)
            vertices[:, 4] = torch.cat((floor[:, 0, None], floor[:, 1, None], ceil[:, 2, None]), dim=1)
            vertices[:, 5] = torch.cat((ceil[:, 0, None], floor[:, 1, None], ceil[:, 2, None]), dim=1)
            vertices[:, 6] = torch.cat((floor[:, 0, None], ceil[:, 1, None], ceil[:, 2, None]), dim=1)
            vertices[:, 7] = ceil

            # Spatial hash of each corner: h(v) = (v_1 * pi1 xor v_2 * pi2 xor v_3 * pi3) mod T
            a = vertices[:, :, 0] * self.pi1
            b = vertices[:, :, 1] * self.pi2
            c = vertices[:, :, 2] * self.pi3
            h_x = torch.remainder(torch.bitwise_xor(torch.bitwise_xor(a, b), c), self.T)

            # Look up the 8 corner features and trilinearly interpolate them with grid_sample.
            looked_up = self.lookup_tables[str(i)][h_x].transpose(-1, -2)
            volume = looked_up.reshape((looked_up.shape[0], 2, 2, 2, 2))
            features[:, i * 2:(i + 1) * 2] = torch.nn.functional.grid_sample(
                volume,
                ((x[mask] * N - floor) - 0.5).unsqueeze(1).unsqueeze(1).unsqueeze(1),
                align_corners=False
            ).squeeze(-1).squeeze(-1).squeeze(-1)

        xi = self.positional_encoding(d[mask])
        h = self.density_MLP(features)
        log_sigma[mask] = h[:, 0]
        color[mask] = self.color_MLP(torch.cat((h, xi), dim=1))
        return color, torch.exp(log_sigma)

def compute_accumulated_transmittance(alphas):
    # Called with (1 - alpha): T_i = prod_{j < i} (1 - alpha_j), shifted so that T_0 = 1.
    accumulated_transmittance = torch.cumprod(alphas, 1)
    return torch.cat((torch.ones((accumulated_transmittance.shape[0], 1), device=alphas.device),
                      accumulated_transmittance[:, :-1]), dim=-1)

def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    device = ray_origins.device

    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    # Perturb the sampling along each ray (stratified sampling).
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor(
        [1e10], device=device).expand(ray_origins.shape[0], 1)), -1)

    # Compute the 3D points along each ray.
    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)
    # Expand the ray_directions tensor to match the shape of x.
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))

    # Volume rendering: alpha_i = 1 - exp(-sigma_i * delta_i), w_i = T_i * alpha_i.
    alpha = 1 - torch.exp(-sigma.reshape(x.shape[:-1]) * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors.reshape(x.shape)).sum(dim=1)
    weight_sum = weights.sum(-1).sum(-1)  # regularization for white background
    return c + 1 - weight_sum.unsqueeze(-1)

def train(nerf_model, optimizer, data_loader, device='cpu', hn=0, hf=1, nb_epochs=10,
          nb_bins=192, H=400, W=400):
    for _ in range(nb_epochs):
        for batch in tqdm(data_loader):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            gt_px_values = batch[:, 6:].to(device)

            pred_px_values = render_rays(nerf_model, ray_origins, ray_directions,
                                         hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((gt_px_values - pred_px_values) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
device = 'cuda'
training_dataset = torch.from_numpy(np.load('training_data_800x800.pkl',
allow_pickle=True))
testing_dataset = torch.from_numpy(np.load('testing_data_800x800.pkl',
allow_pickle=True))
L = 16
F = 2
T = 2**19
N_min = 16
N_max = 2048
b = np.exp((np.log(N_max) - np.log(N_min)) / (L - 1))
Nl = [int(np.floor(N_min * b**l)) for l in range(L)]
model = NGP(T, Nl, 4, device, 3)
model_optimizer = torch.optim.Adam(
[{"params": model.lookup_tables.parameters(), "lr": 1e-2, "betas": (0.9, 0.99), "eps": 1e-15, "weight_decay": 0.},
{"params": model.density_MLP.parameters(), "lr": 1e-2, "betas": (0.9, 0.99), "eps": 1e-15, "weight_decay": 10**-6},
{"params": model.color_MLP.parameters(), "lr": 1e-2, "betas": (0.9, 0.99), "eps": 1e-15, "weight_decay": 10**-6}])
data_loader = DataLoader(training_dataset, batch_size=2**14, shuffle=True)
train(model, model_optimizer, data_loader, nb_epochs=1, device=device,
hn=2, hf=6, nb_bins=192, H=800, W=800)
for img_index in range(200):
test(2, 6, testing_dataset, img_index, nb_bins=192, H=800, W=800)
Python implementation: Instant Neural Graphics Primitives with a Multiresolution Hash Encoding in 100 lines
2022-01-16