nerfstudio/models/bakedsdf.py (20 changes: 0 additions & 20 deletions)
@@ -190,26 +190,6 @@ def set_anneal(step):
)
)

if self.config.use_anneal_beta:
# anneal the beta of volsdf before each training iteration
M = self.config.beta_anneal_max_num_iters
beta_init = self.config.beta_anneal_init
beta_end = self.config.beta_anneal_end

def set_beta(step):
# bakedsdf's beta schedule
train_frac = np.clip(step / M, 0, 1)
beta = beta_init / (1 + (beta_init - beta_end) / beta_end * (train_frac**0.8))
self.field.laplace_density.beta.data[...] = beta

callbacks.append(
TrainingCallback(
where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
update_every_num_iters=1,
func=set_beta,
)
)

if self.config.use_anneal_eikonal_weight:
# anneal the eikonal loss weight before each training iteration
K = self.config.eikonal_anneal_max_num_iters
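The block removed above is BakedSDF's beta annealing callback, which this PR hoists into the shared `SurfaceModel` in `base_surface_model.py` below. As a minimal standalone sketch of the schedule (hypothetical helper name; defaults taken from the config values in this PR), beta decays from `beta_init` to `beta_end` as training progresses:

```python
import numpy as np

def bakedsdf_beta_schedule(step: int, max_iters: int = 1_000_000,
                           beta_init: float = 0.05, beta_end: float = 0.0002) -> float:
    # train_frac goes 0 -> 1 over training; the 0.8 exponent front-loads the decay
    train_frac = np.clip(step / max_iters, 0, 1)
    return beta_init / (1 + (beta_init - beta_end) / beta_end * train_frac**0.8)

assert np.isclose(bakedsdf_beta_schedule(0), 0.05)            # starts at beta_init
assert np.isclose(bakedsdf_beta_schedule(1_000_000), 0.0002)  # ends at beta_end
```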
nerfstudio/models/base_surface_model.py (56 changes: 54 additions & 2 deletions)
@@ -22,15 +22,22 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchtyping import TensorType
from typing_extensions import Literal

from nerfstudio.cameras.rays import RayBundle, RaySamples
from nerfstudio.engine.callbacks import (
TrainingCallback,
TrainingCallbackAttributes,
TrainingCallbackLocation,
)
from nerfstudio.field_components.encodings import NeRFEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.spatial_distortions import SceneContraction
@@ -119,6 +126,14 @@ class SurfaceModelConfig(ModelConfig):
"""whether to use near and far collider from command line"""
scene_contraction_norm: Literal["inf", "l2"] = "inf"
"""Which norm to use for the scene contraction."""
use_anneal_beta: bool = False
"""Whether to anneal the beta of NeuS, similar to BakedSDF."""
beta_anneal_max_num_iters: int = 1_000_000
"""Max num iterations for the beta annealing function."""
beta_anneal_init: float = 0.05
"""Initial beta for the annealing function."""
beta_anneal_end: float = 0.0002
"""Final beta for the annealing function."""


class SurfaceModel(Model):
Expand Down Expand Up @@ -225,13 +240,50 @@ def populate_modules(self):

def get_param_groups(self) -> Dict[str, List[Parameter]]:
param_groups = {}
param_groups["fields"] = list(self.field.parameters())

if self.config.use_anneal_beta:
# don't optimize beta in the Laplace density when the annealing callback controls it
param_groups["fields"] = [
    param for name, param in self.field.named_parameters() if "laplace_density" not in name
]
else:
param_groups["fields"] = list(self.field.parameters())

if self.config.background_model != "none":
param_groups["field_background"] = list(self.field_background.parameters())
else:
param_groups["field_background"] = list(self.field_background)

return param_groups
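Filtering here matters because the annealing callback takes over the beta schedule, so the corresponding density parameter is excluded from the optimizer rather than trained against a value that is overwritten every iteration. A minimal sketch of the filtering idiom, using a hypothetical toy module:

```python
import torch.nn as nn

class ToyField(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(3, 1)
        self.laplace_density = nn.Linear(1, 1)  # stand-in for LaplaceDensity and its beta

field = ToyField()
trainable = [p for name, p in field.named_parameters() if "laplace_density" not in name]
assert len(trainable) == 2  # only mlp.weight and mlp.bias; density params are excluded
```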

def get_training_callbacks(
self, training_callback_attributes: TrainingCallbackAttributes
) -> List[TrainingCallback]:
callbacks = super().get_training_callbacks(training_callback_attributes)

if self.config.use_anneal_beta:
# anneal the beta of volsdf before each training iteration
M = self.config.beta_anneal_max_num_iters
beta_init = self.config.beta_anneal_init
beta_end = self.config.beta_anneal_end

def set_beta(step):
# bakedsdf's beta schedule adapted to neus
train_frac = np.clip(step / M, 0, 1)
beta = beta_init / (1 + (beta_init - beta_end) / beta_end * (train_frac**0.8))
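# assuming the deviation network maps its parameter via inv_s = exp(10 * variance),
# storing log(1 / beta) / 10 makes inv_s equal 1 / beta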
beta = np.log(1.0 / beta) / 10.0
self.field.deviation_network.variance.data[...] = beta

callbacks.append(
TrainingCallback(
where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
update_every_num_iters=1,
func=set_beta,
)
)

return callbacks

@abstractmethod
def sample_and_forward_field(self, ray_bundle: RayBundle) -> Dict:
"""_summary_
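One line differs from the deleted bakedsdf.py version: `beta = np.log(1.0 / beta) / 10.0`. A quick numerical check of that mapping (assuming NeuS's single variance network computes `inv_s = exp(10 * variance)`, as in the NeuS reference implementation):

```python
import numpy as np

beta = 0.05                            # annealed VolSDF beta
variance = np.log(1.0 / beta) / 10.0   # value written into deviation_network.variance
inv_s = np.exp(10.0 * variance)        # what the deviation network recovers
assert np.isclose(inv_s, 1.0 / beta)   # the annealed beta acts as 1 / inv_s
```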
nerfstudio/models/neus_facto.py (36 changes: 3 additions & 33 deletions)
@@ -19,9 +19,9 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Type, Tuple, Dict
import numpy as np
from typing import Dict, List, Tuple, Type

import numpy as np
import torch
from torch.nn import Parameter

@@ -32,10 +32,10 @@
TrainingCallbackLocation,
)
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.models.neus import NeuSModel, NeuSModelConfig
from nerfstudio.fields.density_fields import HashMLPDensityField
from nerfstudio.model_components.losses import interlevel_loss, interlevel_loss_zip
from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler
from nerfstudio.models.neus import NeuSModel, NeuSModelConfig
from nerfstudio.utils import colormaps


@@ -73,15 +73,6 @@ class NeuSFactoModelConfig(NeuSModelConfig):
"""Max num iterations for the annealing function."""
use_single_jitter: bool = True
"""Whether use single jitter or not for the proposal networks."""
use_anneal_beta: bool = False
"""Whether to anneal the beta of NeuS, similar to BakedSDF."""
beta_anneal_max_num_iters: int = 1_000_000
"""Max num iterations for the beta annealing function."""
beta_anneal_init: float = 0.05
"""Initial beta for the annealing function."""
beta_anneal_end: float = 0.0002
"""Final beta for the annealing function."""
# TODO move to base model config since it can be used in all models
enable_progressive_hash_encoding: bool = False
"""whether to use progressive hash encoding"""
enable_numerical_gradients_schedule: bool = False
@@ -182,27 +173,6 @@ def set_anneal(step):
)
)

if self.config.use_anneal_beta:
# anneal the beta of volsdf before each training iteration
M = self.config.beta_anneal_max_num_iters
beta_init = self.config.beta_anneal_init
beta_end = self.config.beta_anneal_end

def set_beta(step):
# bakedsdf's beta schedule adapted to neus
train_frac = np.clip(step / M, 0, 1)
beta = beta_init / (1 + (beta_init - beta_end) / beta_end * (train_frac**0.8))
beta = np.log(1.0 / beta) / 10.0
self.field.deviation_network.variance.data[...] = beta

callbacks.append(
TrainingCallback(
where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
update_every_num_iters=1,
func=set_beta,
)
)

# read the hash encoding parameters from field
level_init = self.config.level_init
# schedule the delta in numerical gradients computation
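With the callback and its config options hoisted into `SurfaceModel` / `SurfaceModelConfig`, any surface model can opt into the schedule without duplicating the code. A hypothetical configuration (values illustrative, not from this PR):

```python
from nerfstudio.models.neus_facto import NeuSFactoModelConfig

config = NeuSFactoModelConfig(
    use_anneal_beta=True,               # enable the BakedSDF-style beta schedule
    beta_anneal_max_num_iters=250_000,  # anneal over a shorter run than the 1M default
    beta_anneal_init=0.05,
    beta_anneal_end=0.0002,
)
```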