From eac4d202491b621d2105a99ddc62c34710589236 Mon Sep 17 00:00:00 2001
From: Haoxuan Li
Date: Sat, 26 Aug 2023 15:03:43 +0200
Subject: [PATCH] Add anneal-beta option for base surface model

---
 nerfstudio/models/bakedsdf.py           | 20 ---------
 nerfstudio/models/base_surface_model.py | 56 ++++++++++++++++++++++++-
 nerfstudio/models/neus_facto.py         | 36 ++--------------
 3 files changed, 57 insertions(+), 55 deletions(-)

diff --git a/nerfstudio/models/bakedsdf.py b/nerfstudio/models/bakedsdf.py
index 823ace56..cd27ff2e 100644
--- a/nerfstudio/models/bakedsdf.py
+++ b/nerfstudio/models/bakedsdf.py
@@ -190,26 +190,6 @@ def set_anneal(step):
                 )
             )
 
-        if self.config.use_anneal_beta:
-            # anneal the beta of volsdf before each training iterations
-            M = self.config.beta_anneal_max_num_iters
-            beta_init = self.config.beta_anneal_init
-            beta_end = self.config.beta_anneal_end
-
-            def set_beta(step):
-                # bakedsdf's beta schedule
-                train_frac = np.clip(step / M, 0, 1)
-                beta = beta_init / (1 + (beta_init - beta_end) / beta_end * (train_frac**0.8))
-                self.field.laplace_density.beta.data[...] = beta
-
-            callbacks.append(
-                TrainingCallback(
-                    where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
-                    update_every_num_iters=1,
-                    func=set_beta,
-                )
-            )
-
         if self.config.use_anneal_eikonal_weight:
             # anneal the beta of volsdf before each training iterations
             K = self.config.eikonal_anneal_max_num_iters
diff --git a/nerfstudio/models/base_surface_model.py b/nerfstudio/models/base_surface_model.py
index 2f228154..a7f72367 100644
--- a/nerfstudio/models/base_surface_model.py
+++ b/nerfstudio/models/base_surface_model.py
@@ -22,15 +22,22 @@
 from dataclasses import dataclass, field
 from typing import Dict, List, Tuple, Type
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch.nn import Parameter
-from torchmetrics.image import PeakSignalNoiseRatio
 from torchmetrics.functional import structural_similarity_index_measure
+from torchmetrics.image import PeakSignalNoiseRatio
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
 from torchtyping import TensorType
 from typing_extensions import Literal
+
 from nerfstudio.cameras.rays import RayBundle, RaySamples
+from nerfstudio.engine.callbacks import (
+    TrainingCallback,
+    TrainingCallbackAttributes,
+    TrainingCallbackLocation,
+)
 from nerfstudio.field_components.encodings import NeRFEncoding
 from nerfstudio.field_components.field_heads import FieldHeadNames
 from nerfstudio.field_components.spatial_distortions import SceneContraction
@@ -119,6 +126,14 @@ class SurfaceModelConfig(ModelConfig):
     """whether to use near and far collider from command line"""
     scene_contraction_norm: Literal["inf", "l2"] = "inf"
     """Which norm to use for the scene contraction."""
+    use_anneal_beta: bool = False
+    """Whether to anneal the NeuS variance following BakedSDF's beta schedule."""
+    beta_anneal_max_num_iters: int = 1_000_000
+    """Max num iterations for the beta annealing function."""
+    beta_anneal_init: float = 0.05
+    """Initial beta for the annealing function."""
+    beta_anneal_end: float = 0.0002
+    """Final beta for the annealing function."""
 
 
 class SurfaceModel(Model):
@@ -225,13 +240,50 @@ def populate_modules(self):
 
     def get_param_groups(self) -> Dict[str, List[Parameter]]:
         param_groups = {}
-        param_groups["fields"] = list(self.field.parameters())
+
+        if self.config.use_anneal_beta:
+            # keep the annealed deviation_network variance out of the optimizer
+            param_groups["fields"] = [
+                n_p[1] for n_p in filter(lambda n_p: "deviation_network" not in n_p[0], self.field.named_parameters())
+            ]
+        else:
+            param_groups["fields"] = list(self.field.parameters())
+
         if self.config.background_model != "none":
             param_groups["field_background"] = list(self.field_background.parameters())
         else:
             param_groups["field_background"] = list(self.field_background)
+
         return param_groups
 
+    def get_training_callbacks(
+        self, training_callback_attributes: TrainingCallbackAttributes
+    ) -> List[TrainingCallback]:
+        callbacks = super().get_training_callbacks(training_callback_attributes)
+
+        if self.config.use_anneal_beta:
+            # anneal beta before each training iteration, following bakedsdf's schedule
+            M = self.config.beta_anneal_max_num_iters
+            beta_init = self.config.beta_anneal_init
+            beta_end = self.config.beta_anneal_end
+
+            def set_beta(step):
+                # bakedsdf's beta schedule adapted to neus: exp(10 * variance) == 1 / beta
+                train_frac = np.clip(step / M, 0, 1)
+                beta = beta_init / (1 + (beta_init - beta_end) / beta_end * (train_frac**0.8))
+                beta = np.log(1.0 / beta) / 10.0
+                self.field.deviation_network.variance.data[...] = beta
+
+            callbacks.append(
+                TrainingCallback(
+                    where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
+                    update_every_num_iters=1,
+                    func=set_beta,
+                )
+            )
+
+        return callbacks
+
     @abstractmethod
     def sample_and_forward_field(self, ray_bundle: RayBundle) -> Dict:
         """_summary_
diff --git a/nerfstudio/models/neus_facto.py b/nerfstudio/models/neus_facto.py
index 8be4c2b8..20207b47 100644
--- a/nerfstudio/models/neus_facto.py
+++ b/nerfstudio/models/neus_facto.py
@@ -19,9 +19,9 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import List, Type, Tuple, Dict
-import numpy as np
+from typing import Dict, List, Tuple, Type
 
+import numpy as np
 import torch
 from torch.nn import Parameter
 
@@ -32,10 +32,10 @@
     TrainingCallbackLocation,
 )
 from nerfstudio.field_components.field_heads import FieldHeadNames
-from nerfstudio.models.neus import NeuSModel, NeuSModelConfig
 from nerfstudio.fields.density_fields import HashMLPDensityField
 from nerfstudio.model_components.losses import interlevel_loss, interlevel_loss_zip
 from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler
+from nerfstudio.models.neus import NeuSModel, NeuSModelConfig
 from nerfstudio.utils import colormaps
 
 
@@ -73,15 +73,6 @@ class NeuSFactoModelConfig(NeuSModelConfig):
     """Max num iterations for the annealing function."""
     use_single_jitter: bool = True
     """Whether use single jitter or not for the proposal networks."""
-    use_anneal_beta: bool = False
-    """whether to anneal beta of neus or not similar to bakedsdf"""
-    beta_anneal_max_num_iters: int = 1000_000
-    """max num iterations for the annealing function of beta"""
-    beta_anneal_init: float = 0.05
-    """initial beta for annealing function"""
-    beta_anneal_end: float = 0.0002
-    """final beta for annealing function"""
-    # TODO move to base model config since it can be used in all models
 enable_progressive_hash_encoding: bool = False
     """whether to use progressive hash encoding"""
     enable_numerical_gradients_schedule: bool = False
@@ -182,27 +173,6 @@ def set_anneal(step):
                 )
             )
 
-        if self.config.use_anneal_beta:
-            # anneal the beta of volsdf before each training iterations
-            M = self.config.beta_anneal_max_num_iters
-            beta_init = self.config.beta_anneal_init
-            beta_end = self.config.beta_anneal_end
-
-            def set_beta(step):
-                # bakedsdf's beta schedule adapted to neus
-                train_frac = np.clip(step / M, 0, 1)
-                beta = beta_init / (1 + (beta_init - beta_end) / beta_end * (train_frac**0.8))
-                beta = np.log(1.0 / beta) / 10.0
-                self.field.deviation_network.variance.data[...] = beta
-
-            callbacks.append(
-                TrainingCallback(
-                    where_to_run=[TrainingCallbackLocation.BEFORE_TRAIN_ITERATION],
-                    update_every_num_iters=1,
-                    func=set_beta,
-                )
-            )
-
         # read the hash encoding parameters from field
         level_init = self.config.level_init
         # schedule the delta in numerical gradients computation
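
Note for reviewers (not part of the patch): below is a minimal standalone sketch of
the schedule that set_beta applies, using the new SurfaceModelConfig defaults. The
last line assumes a LearnedVariance-style deviation network whose effective
sharpness is exp(10 * variance), so writing log(1 / beta) / 10 is equivalent to
setting s = 1 / beta; the helper name annealed_beta is illustrative only.

    import numpy as np

    M = 1_000_000      # beta_anneal_max_num_iters
    beta_init = 0.05   # beta_anneal_init
    beta_end = 0.0002  # beta_anneal_end

    def annealed_beta(step: int) -> float:
        # bakedsdf schedule: beta decays smoothly from beta_init (step 0)
        # to beta_end (step M); the 0.8 exponent front-loads the decay
        train_frac = np.clip(step / M, 0, 1)
        return beta_init / (1 + (beta_init - beta_end) / beta_end * train_frac**0.8)

    for step in (0, 10_000, 100_000, 500_000, 1_000_000):
        beta = annealed_beta(step)
        variance = np.log(1.0 / beta) / 10.0  # value written to deviation_network.variance
        print(f"step={step:>9,d}  beta={beta:.6f}  variance={variance:.4f}")

At step 0 this gives beta = 0.05 and it reaches beta = 0.0002 at step 1,000,000,
i.e. the NeuS density sharpens from s = 20 to s = 5000 over training.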