diff --git a/configs/fern_mf.txt b/configs/fern_mf.txt
new file mode 100644
index 000000000..f02eb6964
--- /dev/null
+++ b/configs/fern_mf.txt
@@ -0,0 +1,16 @@
+expname = fern_test_mf
+basedir = ./logs
+datadir = ./data/nerf_llff_data/fern
+dataset_type = llff
+
+factor = 8
+llffhold = 8
+
+N_rand = 1024
+N_samples = 64
+N_importance = 64
+
+use_viewdirs = True
+raw_noise_std = 1e0
+
+i_embed = 2
\ No newline at end of file
diff --git a/configs/lego_mf.txt b/configs/lego_mf.txt
new file mode 100644
index 000000000..f32beaf7e
--- /dev/null
+++ b/configs/lego_mf.txt
@@ -0,0 +1,21 @@
+expname = blender_paper_lego_mf
+basedir = ./logs
+datadir = ./data/nerf_synthetic/lego
+dataset_type = blender
+
+no_batching = True
+
+use_viewdirs = True
+white_bkgd = True
+lrate_decay = 500
+
+N_samples = 64
+N_importance = 128
+N_rand = 1024
+
+precrop_iters = 500
+precrop_frac = 0.5
+
+half_res = True
+
+i_embed = 2
\ No newline at end of file
diff --git a/run_nerf.py b/run_nerf.py
index bc270be86..d10643b0c 100644
--- a/run_nerf.py
+++ b/run_nerf.py
@@ -18,6 +18,9 @@
 from load_blender import load_blender_data
 from load_LINEMOD import load_LINEMOD_data
 
+from torch.utils.tensorboard import SummaryWriter
+from torchstat import stat
+
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 np.random.seed(0)
@@ -467,7 +470,7 @@ def config_parser():
     parser.add_argument("--use_viewdirs", action='store_true',
                         help='use full 5D input instead of 3D')
     parser.add_argument("--i_embed", type=int, default=0,
-                        help='set 0 for default positional encoding, -1 for none')
+                        help='set 0 for default positional encoding, -1 for none, 2 for multivariable Fourier')
     parser.add_argument("--multires", type=int, default=10,
                         help='log2 of max freq for positional encoding (3D location)')
     parser.add_argument("--multires_views", type=int, default=4,
@@ -650,6 +653,10 @@ def train():
     # Move testing data to GPU
     render_poses = torch.Tensor(render_poses).to(device)
 
+    # Check model size
+    # stat(render_kwargs_train['network_fn'], (65536, 2003))
+    # stat(render_kwargs_train['network_fine'], (65536, 2003))
+
     # Short circuit if only rendering out from trained model
     if args.render_only:
         print('RENDER ONLY')
@@ -705,7 +712,7 @@ def train():
     print('VAL views are', i_val)
 
     # Summary writers
-    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
+    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname, 'base'))
 
     start = start + 1
     for i in trange(start, N_iters):
@@ -827,6 +834,8 @@ def train():
 
         if i%args.i_print==0:
             tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
+            writer.add_scalar("Train/Loss", loss.item(), i)
+            writer.add_scalar("Train/PSNR", psnr.item(), i)
         """
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
@@ -875,4 +884,4 @@ def train():
 if __name__=='__main__':
     torch.set_default_tensor_type('torch.cuda.FloatTensor')
 
-    train()
+    train()
\ No newline at end of file
diff --git a/run_nerf_helpers.py b/run_nerf_helpers.py
index bc6ee779d..80d44cc6b 100644
--- a/run_nerf_helpers.py
+++ b/run_nerf_helpers.py
@@ -3,8 +3,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
+import itertools
 
+DEBUG = False
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 # Misc
 img2mse = lambda x, y : torch.mean((x - y) ** 2)
 mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
@@ -13,17 +17,38 @@
 # Positional encoding (section 5.1)
 class Embedder:
 
+    def spherical_features(sqrt_dim=10, rand=True):
+        if rand:  # Random
+            # np.random.seed(0)  # generate consistent random features
+            u, v = np.random.rand(2, sqrt_dim ** 2)
+        else:  # Stratified
+            segs = np.linspace(0, 1, sqrt_dim)
+            u, v = np.array(list(itertools.product(segs, segs))).transpose()
+        # Spherical sampling
+        i = 2 * np.pi * u
+        j = np.arccos(1 - 2 * v)
+        x = np.sin(j) * np.cos(i)
+        y = np.sin(j) * np.sin(i)
+        z = np.cos(j)
+        return torch.from_numpy(np.stack((x, y, z)).transpose()).to(device)
+
     def __init__(self, **kwargs):
         self.kwargs = kwargs
         self.create_embedding_fn()
 
     def create_embedding_fn(self):
         embed_fns = []
+        # Multivariable Fourier Basis
+        embed_mffns = []
+
         d = self.kwargs['input_dims']
         out_dim = 0
+        out_mf_dim = 0
         if self.kwargs['include_input']:
             embed_fns.append(lambda x : x)
+            embed_mffns.append(lambda x : x)
             out_dim += d
+            out_mf_dim += d
 
         max_freq = self.kwargs['max_freq_log2']
         N_freqs = self.kwargs['num_freqs']
@@ -37,13 +62,45 @@ def create_embedding_fn(self):
             for p_fn in self.kwargs['periodic_fns']:
                 embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                 out_dim += d
-
+
+        # # Multivariable Fourier Basis
+        # for freq_x in freq_bands:
+        #     for freq_y in freq_bands:
+        #         for freq_z in freq_bands:
+        #             for p_fn in self.kwargs['periodic_fns']:
+        #                 embed_mffns.append(lambda x, p_fn=p_fn, freq_x=freq_x, freq_y=freq_y,
+        #                     freq_z=freq_z : p_fn(x[:, 0:1] * freq_x + x[:, 1:2] * freq_y + x[:, 2:3] * freq_z))
+        #                 out_mf_dim += 1
+
+        # Spherical Fourier Basis
+        np.random.seed(0)  # generate consistent random features
+        for freq in freq_bands:
+            for unit_vec in Embedder.spherical_features():
+                for p_fn in self.kwargs['periodic_fns']:
+                    embed_mffns.append(lambda x, p_fn=p_fn, freq=freq, vec=unit_vec :
+                        p_fn(freq * (vec[0] * x[:, 0:1] + vec[1] * x[:, 1:2] + vec[2] * x[:, 2:3])))
+                    out_mf_dim += 1
+
         self.embed_fns = embed_fns
+        self.embed_mffns = embed_mffns
         self.out_dim = out_dim
+        self.out_mf_dim = out_mf_dim
 
     def embed(self, inputs):
         return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
 
+    def embed_mf(self, inputs):
+        if DEBUG:
+            fns = [fn(inputs) for fn in self.embed_fns]
+            mffns = [fn(inputs) for fn in self.embed_mffns]
+            print("fns", torch.cat(fns, -1).shape)
+            print("out_dim", self.out_dim)
+            print("mffns", torch.cat(mffns, -1).shape)
+            print("out_mf_dim", self.out_mf_dim)
+            exit(-1)
+        return torch.cat([fn(inputs) for fn in self.embed_mffns], -1)
+
 
 def get_embedder(multires, i=0):
 
     if i == -1:
@@ -59,6 +116,12 @@ def get_embedder(multires, i=0):
     }
 
     embedder_obj = Embedder(**embed_kwargs)
+
+    # Multivariable Fourier Basis
+    if i == 2:
+        embed = lambda x, eo=embedder_obj : eo.embed_mf(x)
+        return embed, embedder_obj.out_mf_dim
+
     embed = lambda x, eo=embedder_obj : eo.embed(x)
     return embed, embedder_obj.out_dim
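
Reviewer note: below is a minimal, self-contained sketch of what the new `i_embed = 2` path computes, for sanity-checking outside the repo. It assumes the defaults in this patch (`sqrt_dim = 10`, i.e. 100 random directions, and `multires = 10` frequency octaves). The helper names `spherical_dirs` and `embed_mf_sketch` are illustrative, not part of the diff, and the feature ordering differs from the nested loops in `create_embedding_fn`, but the output width is the same: 3 + 2 * 10 * 10**2 = 2003, which matches the (65536, 2003) input shape in the commented-out `stat(...)` calls in run_nerf.py.

    # Illustrative sketch, not a drop-in replacement for the Embedder class.
    import numpy as np
    import torch

    def spherical_dirs(sqrt_dim=10, seed=0):
        # Roughly uniform directions on the unit sphere via inverse-CDF sampling:
        # theta = 2*pi*u covers longitude; phi = arccos(1 - 2*v) makes the area
        # density uniform in latitude (same mapping as spherical_features above).
        rng = np.random.default_rng(seed)
        u, v = rng.random((2, sqrt_dim ** 2))
        theta, phi = 2 * np.pi * u, np.arccos(1 - 2 * v)
        dirs = np.stack([np.sin(phi) * np.cos(theta),
                         np.sin(phi) * np.sin(theta),
                         np.cos(phi)], axis=-1)        # (sqrt_dim**2, 3)
        return torch.from_numpy(dirs).float()

    def embed_mf_sketch(x, dirs, n_freqs=10):
        # Project each 3D point onto every direction, then apply sin/cos at
        # log-spaced frequencies 2^0 .. 2^(n_freqs - 1), keeping the raw input
        # (include_input = True in the Embedder kwargs).
        freq_bands = 2.0 ** torch.arange(n_freqs)
        proj = x @ dirs.T                              # (N, sqrt_dim**2)
        feats = [x]
        for freq in freq_bands:
            feats += [torch.sin(freq * proj), torch.cos(freq * proj)]
        return torch.cat(feats, -1)

    out = embed_mf_sketch(torch.rand(4, 3), spherical_dirs())
    print(out.shape)                                   # torch.Size([4, 2003])

Compared with the axis-aligned positional encoding (out_dim = 3 + 2 * 10 * 3 = 63 at these settings), this basis mixes the coordinates by projecting onto many random directions before the sin/cos, in the spirit of random Fourier features, at the cost of a much wider network input layer.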