diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..114376106 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +*.pyc +.vscode +output +build +diff_rasterization/diff_rast.egg-info +diff_rasterization/dist +tensorboard_3d +screenshots \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..fdce420e1 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "submodules/diff-gaussian-rasterization"] + path = submodules/diff-gaussian-rasterization + url = https://gitlab.inria.fr/bkerbl/diff-gaussian-rasterization.git +[submodule "submodules/simple-knn"] + path = submodules/simple-knn + url = https://gitlab.inria.fr/bkerbl/simple-knn.git +[submodule "SIBR_viewers_windows"] + path = SIBR_viewers_windows + url = https://gitlab.inria.fr/sibr/sibr_core.git +[submodule "SIBR_viewers_linux"] + path = SIBR_viewers_linux + url = https://gitlab.inria.fr/sibr/sibr_core.git diff --git a/README.md b/README.md new file mode 100644 index 000000000..6e9cd8d36 --- /dev/null +++ b/README.md @@ -0,0 +1,203 @@ +# 3D Gaussian Splatting for Real-Time Radiance Field Rendering +Bernhard Kerbl*, Georgios Kopanas*, Thomas Leimkühler, George Drettakis (* indicates equal contribution)
+| [Webpage](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) | [Full Paper](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf) | [Datasets (TODO)](TODO) | [Video](https://youtu.be/T_kXY43VZnk) | [Other GRAPHDECO Publications](http://www-sop.inria.fr/reves/publis/gdindex.php) | [FUNGRAPH project page](https://fungraph.inria.fr) |
+![Teaser image](assets/teaser.png) + +This repository contains the code associated with the paper "3D Gaussian Splatting for Real-Time Radiance Field Rendering", which can be found [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/). We further provide the reference images used to create the error metrics reported in the paper, as well as recently created, pre-trained models. + + + + + + +Abstract: *Radiance Field methods have recently revolutionized novel-view synthesis of scenes captured with multiple photos or videos. However, achieving high visual quality still requires neural networks that are costly to train and render, while recent faster methods inevitably trade off speed for quality. For unbounded and complete scenes (rather than isolated objects) and 1080p resolution rendering, no current method can achieve real-time display rates. We introduce three key elements that allow us to achieve state-of-the-art visual quality while maintaining competitive training times and importantly allow high-quality real-time (≥ 30 fps) novel-view synthesis at 1080p resolution. First, starting from sparse points produced during camera calibration, we represent the scene with 3D Gaussians that preserve desirable properties of continuous volumetric radiance fields for scene optimization while avoiding unnecessary computation in empty space; Second, we perform interleaved optimization/density control of the 3D Gaussians, notably optimizing anisotropic covariance to achieve an accurate representation of the scene; Third, we develop a fast visibility-aware rendering algorithm that supports anisotropic splatting and both accelerates training and allows realtime rendering. We demonstrate state-of-the-art visual quality and real-time rendering on several established datasets.* + +
+
+

BibTeX

+
@Article{kerbl3Dgaussians,
+      author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
+      title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
+      journal      = {ACM Transactions on Graphics},
+      number       = {4},
+      volume       = {42},
+      month        = {July},
+      year         = {2023},
+      url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
+}
+
+
+ + +## Funding and Acknowledgments + +This research was funded by the ERC Advanced grant FUNGRAPH No 788065. The authors are grateful to Adobe for generous donations, the OPAL infrastructure from Université Côte d’Azur and for the HPC resources from GENCI–IDRIS (Grant 2022-AD011013409). The authors thank the anonymous reviewers for their valuable feedback, P. Hedman and A. Tewari for proofreading earlier drafts also T. Müller, A. Yu and S. Fridovich-Keil for helping with the comparisons. + +## Cloning the Repository +TODO TODO Replace link + +The repository contains submodules, thus please check it out with +``` +git clone --recursive +``` + +## Overview + +The codebase has 4 main components: +- A PyTorch-based optimizer to produce a 3D Gaussian model from SfM inputs +- A network viewer that allows to connect to and visualize the optimization process +- An OpenGL-based real-time viewer to render trained models in real-time. +- A script to help you turn your own images into optimization-ready SfM data sets + +The components have different requirements w.r.t. both hardware and software. They have been tested on Windows 10 and Linux Ubuntu 22. Instructions for setting up and running each of them are found in the sections below. + +## Optimizer + +The optimizer uses PyTorch and CUDA extensions in a Python environment to produce trained models. + +### Hardware Requirements + +- CUDA-ready GPU with Compute Capability 7.0+ +- 24 GB VRAM to train the largest scenes in our test suite + +### Setup + +Our provided install method is based on Conda package and environment management. We suggest 3 options, depending on your available disk space. 
+ +#### Option 1 (Plenty of space on system drive) + +To produce our exact evaluation environment on a freshly set up machine should be straightforward once Conda is installed (at the expense of considerable disk space): +```shell +conda env create --file environment_full.yml # This will take some time +conda activate gaussian_splatting +``` +#### Option 2 (Little space on system drive) + +If you already have a recent C++ compiler and a version of the CUDA **development** kit 11 installed, you can opt to use the lighter-weight environment install instead. +```shell +conda env create --file environment_light.yml # This will take less time +conda activate gaussian_splatting +``` +#### Option 3 (Even less space on system drive) +Note that even with the light version, downloading packages and creating a new environment with Conda can require a significant amount of disk space. By default, Conda will use the main system hard drive. You can avoid this by specifying a different package download location and an environment on a different drive: + +```shell +conda config --add pkgs_dirs / +conda env create --file environment_light.yml --prefix //gaussian_splatting +conda activate //gaussian_splatting +``` + +#### Custom Install + +If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make changes, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 11.8). + +### Running + +To run the optimizer, simply use + +```shell +python train.py -s +``` + +TODO update link + +You can find our SfM data sets for Tanks&Temples and Deep Blending here. If you do not provide an output model directory (```-m```), trained models are written to folders with randomized unique names inside the ```output``` directory. 
At this point, the trained models may be viewed with the real-time viewer (see further below). + +### Evaluation +By default, the trained models use all available images in the dataset. To train them while withholding a test set for evaluation, use the ```--eval``` flag. This way, you can render training/test sets and produce error metrics as follows: +```shell +python train.py -s --eval # Train with train/test split +python render.py -m # Generate renderings +python metrics.py -m # Compute error metrics on renderings +``` + +We further provide the ```full_eval.py``` script. This script specifies the routine used in our evaluation and demonstrates the use of some additional parameters, e.g., ```--images (-i)``` to define alternative image directories within COLMAP data sets. If you have downloaded and extracted all the training data, you can run it like this: +``` +python full_eval.py --m360 --tat --db +``` +In the current version, this process takes about 7h on our reference machine containing an A6000. + +## Network Viewer + +The Network Viewer can be used to observe the training process and watch the model as it forms. It is not required for the basic workflow, but it is automatically set up when preparing SIBR for the Real-Time Viewer. + +### Hardware Requirements + +- OpenGL 4.5-ready GPU +- 8 GB VRAM + +### Setup + +If you cloned with submodules (e.g., using ```--recursive```), the source code for the viewers is found in ```SIBR_viewers_(windows|linux)``` (choose whichever fits your OS). The network viewer runs within the SIBR framework for Image-based Rendering applications. For setup, you will need the CUDA 11 **development** kit, a C++ compiler (use Visual Studio **2019** on Windows) and **CMake**, then follow the steps corresponding to your operating system. + +#### Windows +On Windows, CMake should take care of your dependencies +```shell +cd SIBR_viewers_windows +cmake -Bbuild . 
+cmake --build build --target install --config RelWithDebInfo +``` +You may specify a different configuration, e.g. ```Debug``` if you need more control during development. + +#### Ubuntu +For Ubuntu, you will need to install a few dependencies before running the project setup. +```shell +# Dependencies +sudo apt install -y libglew-dev libassimp-dev libboost-all-dev libgtk-3-dev libopencv-dev libglfw3-dev libavdevice-dev libavcodec-dev libeigen3-dev libxxf86vm-dev libembree-dev +# Project setup +cd SIBR_viewers_linux +cmake -Bbuild . +cmake --build build --target install +``` +If you receive a build error related to ```libglfw```, locate the library directory and set up a symbolic link there ```libglfw3.so``` → ``````. + +### Running +You may run the compiled ```SIBR_remoteGaussian_app_``` either by opening the build in your C++ development IDE or by running the installed app in ```install/bin```, e.g.: +```shell +./SIBR_viewers_windows/install/bin/SIBR_remoteGaussian_app_rwdi.exe +``` + +The network viewer allows you to connect to a running training process on the same or a different machine. If you are training on the same machine and OS, no command line parameters should be required: the optimizer communicates the location of the training data to the network viewer. By default, optimizer and network viewer will try to establish a connection on **localhost** on port **6009**. You can change this behavior by providing matching ```--ip``` and ```--port``` parameters to both the optimizer and the network viewer. If for some reason the path used by the optimizer to find the training data is not reachable by the network viewer (e.g., due to them running on different (virtual) machines), you may specify an override location to the viewer by using ```--path ```. + +### Navigation + +The SIBR interface provides several methods of navigating the scene. 
By default, you will start with an FPS navigator, which you can control with ```W, A, S, D``` for camera translation and ```Q, E, I, K, J, L``` for rotation. Alternatively, you may want to use a Trackball-style navigator (select from the floating menu). You can also snap to a camera from the data set with the ```Snap to``` button or find the closest camera with ```Snap to closest```. The floating menus also allow you to change the navigation speed. You can use the ```Scaling Modifier``` to control the size of the displayed Gaussians, or show the initial point cloud. + +## Real-Time Viewer + +The Real-Time Viewer can be used to render trained models with real-time frame rates. + +### Hardware Requirements + +- CUDA-ready GPU with Compute Capability 7.0+ +- OpenGL 4.5-ready GPU +- 8 GB VRAM + +### Setup + +The setup is the same as for the network viewer. + +### Running +You may run the compiled ```SIBR_gaussianViewer_app_<config>``` either by opening the build in your C++ development IDE or by running the installed app in ```install/bin```, e.g.: +```shell +./SIBR_viewers_windows/install/bin/SIBR_gaussianViewer_app_rwdi.exe --model-path <path to trained model> +``` + +It should suffice to provide the ```--model-path``` parameter pointing to a trained model directory. Alternatively, you can specify an override location for training input data using ```--path```. To use a specific resolution other than the auto-chosen one, specify ```--rendering-size <width> <height>```. To unlock the full frame rate, please disable V-Sync on your machine and enter full-screen mode (Menu → Display). + +### Navigation + +Navigation works exactly as it does in the network viewer. However, you also have the option to visualize the Gaussians by rendering them as ellipsoids from the floating menu. + +## Converting your own Scenes + +We provide a converter script ```convert.py```, which uses COLMAP to extract SfM information. Optionally, you can use ImageMagick to resize the input images. 
To use them, please first install a recent version of COLMAP (ideally CUDA-powered) and ImageMagick. Put the images you want to use in a directory ```<location>/input```. If you have COLMAP and ImageMagick on your system path, you can simply run +```shell +python convert.py -s <location> [--resize] #If not resizing, ImageMagick is not needed +``` +Alternatively, you can use the optional parameters ```--colmap_executable``` and ```--magick_executable``` to point to the respective paths. Please note that on Windows, the executable should point to the COLMAP ```.bat``` file that takes care of setting the execution environment. Once done, ```<location>``` will contain the expected COLMAP data set structure with undistorted, differently sized input images, in addition to your original images and temporary data in the directory ```distorted```. +## FAQ +- *Where do I get data sets, e.g., those referenced in ```full_eval.py```?* The MipNeRF360 data set is provided by the authors of the original paper on the project site. Note that two of the data sets cannot be openly shared and require you to consult the authors directly. For Tanks&Temples and Deep Blending, please use the download links provided above. + +- *24 GB of VRAM for training is a lot! Can't we do it with less?* Yes, most likely. By our calculations it should be possible with **way** less memory (~8GB). If we can find the time we will try to achieve this. If some PyTorch veteran out there wants to tackle this, we look forward to your pull request! 
# [diff scaffolding, kept verbatim]
# + diff --git a/arguments/__init__.py b/arguments/__init__.py new file mode 100644 index 000000000..424d38193 --- /dev/null +++ b/arguments/__init__.py @@ -0,0 +1,93 @@
"""Argument-parser parameter groups and command-line / config-file merging."""
from argparse import ArgumentParser, Namespace
import sys
import os


class GroupParams:
    """Plain attribute bag returned by ``ParamGroup.extract``."""
    pass


class ParamGroup:
    """Registers every instance attribute of a subclass as a command-line option.

    Subclasses assign their defaults as attributes *before* calling
    ``super().__init__``. An attribute named ``_foo`` becomes ``--foo`` and
    additionally gets the one-letter shorthand ``-f``; an attribute named
    ``foo`` becomes ``--foo`` only. ``bool`` defaults become ``store_true``
    flags, every other default is parsed with the type of its default value.
    """

    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
        """Add one argument group called ``name`` to ``parser``.

        Args:
            parser: parser that receives the options.
            name: title of the argument group.
            fill_none: when true, every option defaults to ``None`` instead of
                the attribute value, so that unset options can be told apart
                from explicitly given ones (see ``get_combined_args``).
        """
        group = parser.add_argument_group(name)
        for key, value in vars(self).items():
            shorthand = key.startswith("_")
            if shorthand:
                key = key[1:]
            # The option type always comes from the declared default, even when
            # the default itself is replaced by None below.
            value_type = type(value)
            default = None if fill_none else value

            flags = ["--" + key]
            if shorthand:
                flags.append("-" + key[0:1])

            if value_type is bool:
                group.add_argument(*flags, default=default, action="store_true")
            else:
                group.add_argument(*flags, default=default, type=value_type)

    def extract(self, args):
        """Return a ``GroupParams`` holding the entries of ``args`` declared by this group."""
        group = GroupParams()
        declared = vars(self)
        for name, value in vars(args).items():
            if name in declared or ("_" + name) in declared:
                setattr(group, name, value)
        return group


class ModelParams(ParamGroup):
    """Options of the "Loading Parameters" group."""

    def __init__(self, parser, sentinel=False):
        # `sentinel` is forwarded as `fill_none`: all defaults become None.
        self.sh_degree = 3
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
        self._resolution = 1
        self._white_background = False
        self.eval = False
        super().__init__(parser, "Loading Parameters", sentinel)


class PipelineParams(ParamGroup):
    """Options of the "Pipeline Parameters" group."""

    def __init__(self, parser):
        self.convert_SHs_python = False
        self.compute_cov3D_python = False
        super().__init__(parser, "Pipeline Parameters")


class OptimizationParams(ParamGroup):
    """Options of the "Optimization Parameters" group."""

    def __init__(self, parser):
        self.iterations = 30_000
        self.position_lr_init = 0.00016
        self.position_lr_final = 0.0000016
        self.position_lr_delay_mult = 0.01
        # NOTE: the misspelling is part of the public option name
        # (--posititon_lr_max_steps); renaming it would break callers.
        self.posititon_lr_max_steps = 30_000
        self.feature_lr = 0.0025
        self.opacity_lr = 0.05
        self.scaling_lr = 0.001
        self.rotation_lr = 0.001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        self.densify_grad_threshold = 0.0002
        super().__init__(parser, "Optimization Parameters")


def get_combined_args(parser : ArgumentParser):
    """Merge the options stored in ``<model_path>/cfg_args`` with ``sys.argv``.

    The command line is parsed with ``parser``. If a ``cfg_args`` file exists
    in the given model path, its content (a ``Namespace(...)`` expression) is
    used as the base; every command-line value that is not ``None`` then
    overrides it. A missing model path or a missing/unreadable config file is
    not an error: the result then contains the command-line values only.

    Returns:
        Namespace: the merged options.
    """
    args_cmdline = parser.parse_args(sys.argv[1:])
    cfgfile_string = "Namespace()"
    cfgfilepath = None

    try:
        # os.path.join raises TypeError when model_path is None (sentinel
        # defaults); open raises OSError when the file does not exist.
        cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        with open(cfgfilepath) as cfg_file:
            print("Config file found: {}".format(cfgfilepath))
            cfgfile_string = cfg_file.read()
    except (TypeError, OSError):
        print("Config file not found at", cfgfilepath)

    # SECURITY NOTE(review): eval() executes whatever the config file contains.
    # Only load model directories you trust.
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    for k, v in vars(args_cmdline).items():
        if v is not None:
            merged_dict[k] = v
    return Namespace(**merged_dict)


# [diff scaffolding for binary assets and the start of the convert.py header, kept verbatim]
# diff --git a/assets/logo_graphdeco.png b/assets/logo_graphdeco.png new file mode 100644 index 000000000..4818ac47b Binary files /dev/null and b/assets/logo_graphdeco.png differ
# diff --git a/assets/logo_inria.png b/assets/logo_inria.png new file mode 100644 index 000000000..f395b7a72 Binary files /dev/null and b/assets/logo_inria.png differ
# diff --git a/assets/logo_mpi.png b/assets/logo_mpi.png new file mode 100644 index 000000000..2282e7f18 Binary files /dev/null and b/assets/logo_mpi.png differ
# diff --git a/assets/logo_mpi.svg b/assets/logo_mpi.svg new file mode 100644 index 000000000..6cb3a0069 --- /dev/null +++ b/assets/logo_mpi.svgo newline at end of file
# diff --git a/assets/logo_uca.png b/assets/logo_uca.png new file mode 100644 index 000000000..e7f1a6f0a Binary files /dev/null and b/assets/logo_uca.png differ
# diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 000000000..98b8166be Binary files /dev/null and b/assets/teaser.png differ
# diff --git a/convert.py b/convert.py new file mode 100644
# [diff scaffolding, kept verbatim]
# index 000000000..a9d827cfd --- /dev/null +++ b/convert.py @@ -0,0 +1,85 @@
import os
from argparse import ArgumentParser
import shutil


def _run_step(command, step_name):
    """Run ``command`` through the shell and abort the script if it fails.

    Each stage consumes the output of the previous one, so continuing after a
    non-zero exit code would only produce follow-up errors.
    """
    exit_code = os.system(command)
    if exit_code != 0:
        raise SystemExit("{} failed with code {}. Exiting.".format(step_name, exit_code))


# This Python script is based on the shell converter script provided in the MipNerF 360 repository.
parser = ArgumentParser("Colmap converter")
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--source_path", "-s", required=True, type=str)
parser.add_argument("--camera", default="OPENCV", type=str)
parser.add_argument("--colmap_executable", default="", type=str)
parser.add_argument("--resize", action="store_true")
parser.add_argument("--magick_executable", default="", type=str)
args = parser.parse_args()

# Fall back to the executables on the system path when no explicit path is given.
colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
use_gpu = 1 if not args.no_gpu else 0

os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)

## Feature extraction
_run_step(
    colmap_command + " feature_extractor"
    + " --database_path " + args.source_path + "/distorted/database.db"
    + " --image_path " + args.source_path + "/input"
    + " --ImageReader.single_camera 1"
    + " --ImageReader.camera_model " + args.camera
    + " --SiftExtraction.use_gpu " + str(use_gpu),
    "Feature extraction")

## Feature matching
_run_step(
    colmap_command + " exhaustive_matcher"
    + " --database_path " + args.source_path + "/distorted/database.db"
    + " --SiftMatching.use_gpu " + str(use_gpu),
    "Feature matching")

### Bundle adjustment
# The default Mapper tolerance is unnecessarily large,
# decreasing it speeds up bundle adjustment steps.
_run_step(
    colmap_command + " mapper"
    + " --database_path " + args.source_path + "/distorted/database.db"
    + " --image_path " + args.source_path + "/input"
    + " --output_path " + args.source_path + "/distorted/sparse"
    + " --Mapper.ba_global_function_tolerance=0.000001",
    "Mapper")

### Image undistortion
## We need to undistort our images into ideal pinhole intrinsics.
_run_step(
    colmap_command + " image_undistorter"
    + " --image_path " + args.source_path + "/input"
    + " --input_path " + args.source_path + "/distorted/sparse/0"
    + " --output_path " + args.source_path
    + " --output_type COLMAP",
    "Image undistortion")

# Move everything that was written directly into sparse/ down into sparse/0.
files = os.listdir(args.source_path + "/sparse")
os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
for file in files:
    if file == '0':
        continue
    source_file = os.path.join(args.source_path, "sparse", file)
    destination_file = os.path.join(args.source_path, "sparse", "0", file)
    shutil.move(source_file, destination_file)

if args.resize:
    print("Copying and resizing...")

    # Target folder and ImageMagick resize factor for each downscaled copy.
    resize_targets = (("images_2", "50%"), ("images_4", "25%"), ("images_8", "12.5%"))
    for folder, _ in resize_targets:
        os.makedirs(os.path.join(args.source_path, folder), exist_ok=True)

    # Copy each undistorted image into every target folder, then shrink the copy in place.
    files = os.listdir(args.source_path + "/images")
    for file in files:
        source_file = os.path.join(args.source_path, "images", file)
        for folder, percentage in resize_targets:
            destination_file = os.path.join(args.source_path, folder, file)
            shutil.copy2(source_file, destination_file)
            _run_step(
                magick_command + " mogrify -resize " + percentage + " " + destination_file,
                percentage + " resize")

print("Done.")
# [text of the following file in the diff, kept verbatim; it continues on the next line]
# \ No newline at end of file diff --git a/environment_full.yml b/environment_full.yml new file mode 100644 index 000000000..87e8dd802 --- /dev/null +++ b/environment_full.yml @@ -0,0 +1,19 @@ +name: gaussian_splatting +channels: + - pytorch + - 
conda-forge + - defaults +dependencies: + - cudatoolkit=11.6 + - cudatoolkit-dev=11.6 + - cxx-compiler=1.3.0 + - plyfile=0.8.1 + - python=3.7.13 + - pip=22.3.1 + - pytorch=1.12.1 + - torchaudio=0.12.1 + - torchvision=0.13.1 + - tqdm + - pip: + - submodules/diff-gaussian-rasterization + - submodules/simple-knn \ No newline at end of file diff --git a/environment_light.yml b/environment_light.yml new file mode 100644 index 000000000..b17a50f65 --- /dev/null +++ b/environment_light.yml @@ -0,0 +1,17 @@ +name: gaussian_splatting +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - cudatoolkit=11.6 + - plyfile=0.8.1 + - python=3.7.13 + - pip=22.3.1 + - pytorch=1.12.1 + - torchaudio=0.12.1 + - torchvision=0.13.1 + - tqdm + - pip: + - submodules/diff-gaussian-rasterization + - submodules/simple-knn \ No newline at end of file diff --git a/full_eval.py b/full_eval.py new file mode 100644 index 000000000..01be10608 --- /dev/null +++ b/full_eval.py @@ -0,0 +1,52 @@ +import os +from argparse import ArgumentParser + +mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] +mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] +tanks_and_temples_scenes = ["truck", "train"] +deep_blending_scenes = ["drjohnson", "playroom"] + +parser = ArgumentParser(description="Full evaluation script parameters") +parser.add_argument("--skip_training", action="store_true") +parser.add_argument("--skip_rendering", action="store_true") +parser.add_argument("--skip_metrics", action="store_true") +args, _ = parser.parse_known_args() + +if not args.skip_training: + parser.add_argument('--mipnerf360', "-m360", required=True, type=str) + parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) + parser.add_argument("--deepblending", "-db", required=True, type=str) + args = parser.parse_args() + + common_args = " --quiet --eval --test_iterations -1" + for scene in tanks_and_temples_scenes: + source = args.tanksandtemples + "/" + 
# [tail of full_eval.py and diff scaffolding, kept verbatim; it continues from the previous line]
# scene + os.system("python train.py -s " + source + " -m ./eval/" + scene + common_args) + for scene in deep_blending_scenes: + source = args.deepblending + "/" + scene + os.system("python train.py -s " + source + " -m ./eval/" + scene + common_args) + for scene in mipnerf360_outdoor_scenes: + source = args.mipnerf360 + "/" + scene + os.system("python train.py -s " + source + " -i images_4 -m ./eval/" + scene + common_args) + for scene in mipnerf360_indoor_scenes: + source = args.mipnerf360 + "/" + scene + os.system("python train.py -s " + source + " -i images_2 -m ./eval/" + scene + common_args) + +all_scenes = [] +all_scenes.extend(mipnerf360_outdoor_scenes) +all_scenes.extend(mipnerf360_indoor_scenes) +all_scenes.extend(tanks_and_temples_scenes) +all_scenes.extend(deep_blending_scenes) + +if not args.skip_rendering: + for scene in all_scenes: + os.system("python render.py --quiet --skip_train --eval --iteration 7000 -m ./eval/" + scene) + for scene in all_scenes: + os.system("python render.py --quiet --skip_train --eval --iteration 30000 -m ./eval/" + scene) + +if not args.skip_metrics: + scenes_string = "" + for scene in all_scenes: + scenes_string += "\"" + "./eval/" + scene + "\" " + + os.system("python metrics.py -m " + scenes_string) \ No newline at end of file
# diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py new file mode 100644 index 000000000..492318c87 --- /dev/null +++ b/gaussian_renderer/__init__.py @@ -0,0 +1,88 @@
import torch
import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: camera object; its FoVx/FoVy, image size, transforms
            and camera_center attributes are read.
        pc: Gaussian model to rasterize.
        pipe: pipeline options; compute_cov3D_python and convert_SHs_python are read.
        bg_color: background color handed to the rasterizer.
        scaling_modifier: scale modifier passed to the rasterizer (and to
            pc.get_covariance when the covariance is computed in Python).
        override_color: if not None, used as precomputed per-Gaussian colors
            instead of colors derived from the spherical harmonics.

    Returns:
        dict with "render" (the image), "viewspace_points" (the tensor that
        receives screen-space gradients), "visibility_filter" (radii > 0)
        and "radii".
    """

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        # Best effort, as in the original; rendering still works without retained gradients.
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    # Fix: this used to test `colors_precomp is None`, which is always true
    # right after the assignment above, so override_color was silently ignored.
    if override_color is None:
        if pipe.convert_SHs_python:
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_features
    else:
        colors_precomp = override_color

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, radii = rasterizer(
        means3D = means3D,
        means2D = means2D,
        shs = shs,
        colors_precomp = colors_precomp,
        opacities = opacity,
        scales = scales,
        rotations = rotations,
        cov3D_precomp = cov3D_precomp)

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
            "viewspace_points": screenspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii}

# [diff scaffolding, kept verbatim]
# diff --git a/gaussian_renderer/network_gui.py b/gaussian_renderer/network_gui.py new file mode 100644 index 000000000..a136f93e8 --- /dev/null +++ b/gaussian_renderer/network_gui.py @@ -0,0 +1,75 @@
import torch
import traceback
import socket
import json
from scene.cameras import MiniCam

host = "127.0.0.1"
port = 6009

conn = None
addr = None

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def init(wish_host, wish_port):
    """Bind the module-level listener to (wish_host, wish_port) and make accept() non-blocking."""
    global host, port, listener
    host = wish_host
    port = wish_port
    listener.bind((host, port))
    listener.listen()
    listener.settimeout(0)

def try_connect():
    """Accept a pending viewer connection if there is one; otherwise do nothing."""
    global conn, addr, listener
    try:
        conn, addr = listener.accept()
        print(f"\nConnected by {addr}")
        conn.settimeout(None)
    except OSError:
        # With a zero timeout accept() raises when no client is waiting;
        # polling again later is the intended behavior.
        pass

def _recv_exact(num_bytes):
    """Read exactly num_bytes from the connection.

    socket.recv may return fewer bytes than requested, so keep reading until
    the message is complete. Raises ConnectionError if the peer closes first.
    """
    chunks = []
    remaining = num_bytes
    while remaining > 0:
        chunk = conn.recv(remaining)
        if not chunk:
            raise ConnectionError("Connection closed before the full message was received")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)

def read():
    """Read one message (4-byte little-endian length, then UTF-8 JSON) and return the decoded object."""
    global conn
    messageLength = int.from_bytes(_recv_exact(4), 'little')
    message = _recv_exact(messageLength)
    return json.loads(message.decode("utf-8"))

def send(message_bytes, verify):
    """Send the optional payload, then the ASCII string `verify` prefixed by its 4-byte little-endian length."""
    global conn
    if message_bytes is not None:
        conn.sendall(message_bytes)
    conn.sendall(len(verify).to_bytes(4, 'little'))
    conn.sendall(bytes(verify, 'ascii'))

def receive():
    """Read one viewer message and turn it into a camera plus control flags.

    Returns:
        (custom_cam, do_training, do_shs_python, do_rot_scale_python,
        keep_alive, scaling_modifier), or a tuple of six None values when the
        requested resolution is zero in either dimension.
    """
    message = read()

    width = message["resolution_x"]
    height = message["resolution_y"]

    if width != 0 and height != 0:
        try:
            do_training = bool(message["train"])
            fovy = message["fov_y"]
            fovx = message["fov_x"]
            znear = message["z_near"]
            zfar = message["z_far"]
            do_shs_python = bool(message["shs_python"])
            do_rot_scale_python = bool(message["rot_scale_python"])
            keep_alive = bool(message["keep_alive"])
            scaling_modifier = message["scaling_modifier"]
            # The sign flips below negate columns 1 and 2 (view) and column 1
            # (projection). NOTE(review): presumably a viewer-to-renderer axis
            # convention change; confirm against the viewer code.
            world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
            world_view_transform[:,1] = -world_view_transform[:,1]
            world_view_transform[:,2] = -world_view_transform[:,2]
            full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
            full_proj_transform[:,1] = -full_proj_transform[:,1]
            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
        except Exception:
            print("")
            traceback.print_exc()
            raise
        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
    else:
        return None, None, None, None, None, None
# [head of the following file in the diff, kept verbatim; it continues on the next line]
# \ No newline at end of file diff --git a/lpipsPyTorch/__init__.py b/lpipsPyTorch/__init__.py new file mode 100644 index 000000000..2a6297daa --- /dev/null +++ b/lpipsPyTorch/__init__.py @@ -0,0 +1,21 @@ +import torch + +from .modules.lpips import LPIPS + + +def lpips(x: torch.Tensor, + y: torch.Tensor, + net_type: str = 'alex', + version: str = '0.1'): + r"""Function that measures + Learned Perceptual Image Patch Similarity (LPIPS). 
+ + Arguments: + x, y (torch.Tensor): the input tensors to compare. + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + device = x.device + criterion = LPIPS(net_type, version).to(device) + return criterion(x, y) diff --git a/lpipsPyTorch/modules/lpips.py b/lpipsPyTorch/modules/lpips.py new file mode 100644 index 000000000..9cd001d1e --- /dev/null +++ b/lpipsPyTorch/modules/lpips.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + +from .networks import get_network, LinLayers +from .utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type) + + # linear layers + self.lin = LinLayers(self.net.n_channels_list) + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0), 0, True) diff --git a/lpipsPyTorch/modules/networks.py b/lpipsPyTorch/modules/networks.py new file mode 100644 index 000000000..d36c6a561 --- /dev/null +++ b/lpipsPyTorch/modules/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from .utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 
'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = 
models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) diff --git a/lpipsPyTorch/modules/utils.py b/lpipsPyTorch/modules/utils.py new file mode 100644 index 000000000..3d15a0983 --- /dev/null +++ b/lpipsPyTorch/modules/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/metrics.py b/metrics.py new file mode 100644 index 000000000..469a415f5 --- /dev/null +++ b/metrics.py @@ -0,0 +1,87 @@ +from pathlib import Path +import os +from PIL import Image +import torch +import torchvision.transforms.functional as tf +from utils.loss_utils import ssim +from lpipsPyTorch import lpips +import json +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser + +def readImages(renders_dir, gt_dir): + renders = [] + gts = [] + image_names = [] + for fname in os.listdir(renders_dir): + render = Image.open(renders_dir / fname) + gt = Image.open(gt_dir / fname) + renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) + gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) + 
image_names.append(fname) + return renders, gts, image_names + +def evaluate(model_paths): + + full_dict = {} + per_view_dict = {} + full_dict_polytopeonly = {} + per_view_dict_polytopeonly = {} + + for scene_dir in model_paths: + print("Scene:", scene_dir) + full_dict[scene_dir] = {} + per_view_dict[scene_dir] = {} + full_dict_polytopeonly[scene_dir] = {} + per_view_dict_polytopeonly[scene_dir] = {} + + test_dir = Path(scene_dir) / "test" + + for method in os.listdir(test_dir): + print("Method:", method) + + full_dict[scene_dir][method] = {} + per_view_dict[scene_dir][method] = {} + full_dict_polytopeonly[scene_dir][method] = {} + per_view_dict_polytopeonly[scene_dir][method] = {} + + method_dir = test_dir / method + gt_dir = method_dir/ "gt" + renders_dir = method_dir / "renders" + renders, gts, image_names = readImages(renders_dir, gt_dir) + + ssims = [] + psnrs = [] + lpipss = [] + + for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): + ssims.append(ssim(renders[idx], gts[idx])) + psnrs.append(psnr(renders[idx], gts[idx])) + lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) + + print("SSIM: {}".format(torch.tensor(ssims).mean())) + print("PSNR: {}".format(torch.tensor(psnrs).mean())) + print("LPIPS: {}".format(torch.tensor(lpipss).mean())) + + full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), + "PSNR": torch.tensor(psnrs).mean().item(), + "LPIPS": torch.tensor(lpipss).mean().item()}) + per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, + "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, + "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) + + with open(scene_dir + "/results.json", 'w') as fp: + json.dump(full_dict[scene_dir], fp, indent=True) + with open(scene_dir + "/per_view.json", 'w') as fp: + json.dump(per_view_dict[scene_dir], fp, indent=True) + 
+if __name__ == "__main__": + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) + args = parser.parse_args() + evaluate(args.model_paths) diff --git a/render.py b/render.py new file mode 100644 index 000000000..29b8a71b1 --- /dev/null +++ b/render.py @@ -0,0 +1,55 @@ +import torch +from scene import Scene +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args +from gaussian_renderer import GaussianModel + +def render_set(model_path, name, iteration, views, gaussians, pipeline, background): + render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") + gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + + makedirs(render_path, exist_ok=True) + makedirs(gts_path, exist_ok=True) + + for idx, view in enumerate(tqdm(views, desc="Rendering progress")): + rendering = render(view, gaussians, pipeline, background)["render"] + gt = view.original_image[0:3, :, :] + torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) + torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) + +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): + with torch.no_grad(): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) + + bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + if not skip_train: + 
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) + + if not skip_test: + render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Testing script parameters") + model = ModelParams(parser, sentinel=True) + pipeline = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = get_combined_args(parser) + print("Rendering " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file diff --git a/scene/__init__.py b/scene/__init__.py new file mode 100644 index 000000000..335d497c8 --- /dev/null +++ b/scene/__init__.py @@ -0,0 +1,83 @@ +import os +import random +import json +from utils.system_utils import searchForMaxIteration +from scene.dataset_readers import sceneLoadTypeCallbacks +from scene.gaussian_model import GaussianModel +from arguments import ModelParams +from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON + +class Scene: + + gaussians : GaussianModel + + def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): + """b + :param path: Path to colmap scene main folder. 
+ """ + self.model_path = args.model_path + self.loaded_iter = None + self.gaussians = gaussians + + if load_iteration: + if load_iteration == -1: + self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) + else: + self.loaded_iter = load_iteration + print("Loading trained model at iteration {}".format(self.loaded_iter)) + + self.train_cameras = {} + self.test_cameras = {} + + if os.path.exists(os.path.join(args.source_path, "sparse")): + scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) + elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): + print("Found transforms_train.json file, assuming Blender data set!") + scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) + else: + assert False, "Could not recognize scene type!" + + if not self.loaded_iter: + with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: + dest_file.write(src_file.read()) + json_cams = [] + camlist = [] + if scene_info.test_cameras: + camlist.extend(scene_info.test_cameras) + if scene_info.train_cameras: + camlist.extend(scene_info.train_cameras) + for id, cam in enumerate(camlist): + json_cams.append(camera_to_JSON(id, cam)) + with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: + json.dump(json_cams, file) + + if shuffle: + random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling + random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling + + self.cameras_extent = scene_info.nerf_normalization["radius"] + + for resolution_scale in resolution_scales: + print("Loading Training Cameras") + self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) + print("Loading Test Cameras") + self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, 
resolution_scale, args) + + if self.loaded_iter: + self.gaussians.load_ply(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "point_cloud.ply"), + og_number_points=len(scene_info.point_cloud.points)) + else: + self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) + + def save(self, iteration): + point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + + def getTrainCameras(self, scale=1.0): + return self.train_cameras[scale] + + def getTestCameras(self, scale=1.0): + return self.test_cameras[scale] \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py new file mode 100644 index 000000000..c3dcc3241 --- /dev/null +++ b/scene/cameras.py @@ -0,0 +1,53 @@ +import torch +from torch import nn +import numpy as np +from utils.graphics_utils import getWorld2View2, getProjectionMatrix + +class Camera(nn.Module): + def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, + image_name, uid, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0 + ): + super(Camera, self).__init__() + + self.uid = uid + self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.image_name = image_name + + self.original_image = image.clamp(0.0, 1.0).cuda() + self.image_width = self.original_image.shape[2] + self.image_height = self.original_image.shape[1] + + if gt_alpha_mask is not None: + self.original_image *= gt_alpha_mask.cuda() + else: + self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda") + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + 
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + +class MiniCam: + def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + self.world_view_transform = world_view_transform + self.full_proj_transform = full_proj_transform + view_inv = torch.inverse(self.world_view_transform) + self.camera_center = view_inv[3][:3] + diff --git a/scene/colmap_loader.py b/scene/colmap_loader.py new file mode 100644 index 000000000..ebdb14e3d --- /dev/null +++ b/scene/colmap_loader.py @@ -0,0 +1,271 @@ +import numpy as np +import collections +import struct + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, 
model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. 
+ """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + xyzs = None + rgbs = None + errors = None + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = np.array(float(elems[7])) + if xyzs is None: + xyzs = xyz[None, ...] + rgbs = rgb[None, ...] + errors = error[None, ...] + else: + xyzs = np.append(xyzs, xyz[None, ...], axis=0) + rgbs = np.append(rgbs, rgb[None, ...], axis=0) + errors = np.append(errors, error[None, ...], axis=0) + return xyzs, rgbs, errors + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + + + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + + for p_id in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + xyzs[p_id] = xyz + rgbs[p_id] = rgb + errors[p_id] = error + return xyzs, rgbs, errors + +def read_intrinsics_text(path): + 
""" + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + +def read_extrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, 
point3D_ids=point3D_ids) + return images + + +def read_intrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_extrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_colmap_bin_array(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py + + 
:param path: path to the colmap binary file. + :return: nd array with the floating point values in the value + """ + with open(path, "rb") as fid: + width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, + usecols=(0, 1, 2), dtype=int) + fid.seek(0) + num_delimiter = 0 + byte = fid.read(1) + while True: + if byte == b"&": + num_delimiter += 1 + if num_delimiter >= 3: + break + byte = fid.read(1) + array = np.fromfile(fid, np.float32) + array = array.reshape((width, height, channels), order="F") + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py new file mode 100644 index 000000000..893904696 --- /dev/null +++ b/scene/dataset_readers.py @@ -0,0 +1,244 @@ +import os +import sys +from PIL import Image +from typing import NamedTuple +from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ + read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +import numpy as np +import json +from pathlib import Path +from plyfile import PlyData, PlyElement +from utils.sh_utils import SH2RGB +from scene.gaussian_model import BasicPointCloud + +class CameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + image: np.array + image_path: str + image_name: str + width: int + height: int + +class SceneInfo(NamedTuple): + point_cloud: BasicPointCloud + train_cameras: list + test_cameras: list + nerf_normalization: dict + ply_path: str + +def getNerfppNorm(cam_info): + def get_center_and_diag(cam_centers): + cam_centers = np.hstack(cam_centers) + avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) + center = avg_cam_center + dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) + diagonal = np.max(dist) + return center.flatten(), diagonal + + cam_centers = [] + + for cam in cam_info: + W2C = 
getWorld2View2(cam.R, cam.T) + C2W = np.linalg.inv(W2C) + cam_centers.append(C2W[:3, 3:4]) + + center, diagonal = get_center_and_diag(cam_centers) + radius = diagonal * 1.1 + + translate = -center + + return {"translate": translate, "radius": radius} + +def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): + cam_infos = [] + for idx, key in enumerate(cam_extrinsics): + sys.stdout.write('\r') + # the exact output you're looking for: + sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) + sys.stdout.flush() + + extr = cam_extrinsics[key] + intr = cam_intrinsics[extr.camera_id] + height = intr.height + width = intr.width + + uid = intr.id + R = np.transpose(qvec2rotmat(extr.qvec)) + T = np.array(extr.tvec) + + if intr.model=="SIMPLE_PINHOLE": + focal_length_x = intr.params[0] + FovY = focal2fov(focal_length_x, height) + FovX = focal2fov(focal_length_x, width) + elif intr.model=="PINHOLE": + focal_length_x = intr.params[0] + focal_length_y = intr.params[1] + FovY = focal2fov(focal_length_y, height) + FovX = focal2fov(focal_length_x, width) + else: + assert False, "Colmap camera model not handled!" 
+ + image_path = os.path.join(images_folder, os.path.basename(extr.name)) + image_name = os.path.basename(image_path).split(".")[0] + image = Image.open(image_path) + + cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=width, height=height) + cam_infos.append(cam_info) + sys.stdout.write('\n') + return cam_infos + +def fetchPly(path): + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + return BasicPointCloud(points=positions, colors=colors, normals=normals) + +def storePly(path, xyz, rgb): + # Define the dtype for the structured array + dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), + ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), + ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] + + normals = np.zeros_like(xyz) + + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb), axis=1) + elements[:] = list(map(tuple, attributes)) + + # Create the PlyData object and write to file + vertex_element = PlyElement.describe(elements, 'vertex') + ply_data = PlyData([vertex_element]) + ply_data.write(path) + +def readColmapSceneInfo(path, images, eval, llffhold=8): + try: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") + cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) + except: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") + cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) + 
+ reading_dir = "images" if images == None else images + cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) + cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) + + if eval: + train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] + test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] + else: + train_cam_infos = cam_infos + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "sparse/0/points3d.ply") + bin_path = os.path.join(path, "sparse/0/points3d.bin") + txt_path = os.path.join(path, "sparse/0/points3d.txt") + if not os.path.exists(ply_path): + print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") + try: + xyz, rgb, _ = read_points3D_binary(bin_path) + except: + xyz, rgb, _ = read_points3D_text(txt_path) + storePly(ply_path, xyz, rgb) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path) + return scene_info + +def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): + cam_infos = [] + + with open(os.path.join(path, transformsfile)) as json_file: + contents = json.load(json_file) + fovx = contents["camera_angle_x"] + + frames = contents["frames"] + for idx, frame in enumerate(frames): + cam_name = os.path.join(path, frame["file_path"] + extension) + + matrix = np.linalg.inv(np.array(frame["transform_matrix"])) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + + image_path = os.path.join(path, cam_name) + image_name = Path(cam_name).stem + image = Image.open(image_path) + + im_data = np.array(image.convert("RGBA")) + + bg = np.array([1,1,1]) if white_background else 
np.array([0, 0, 0]) + + norm_data = im_data / 255.0 + arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) + image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + + fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) + FovY = fovx + FovX = fovy + + cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) + + return cam_infos + +def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): + print("Reading Training Transforms") + train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) + print("Reading Test Transforms") + test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) + + if not eval: + train_cam_infos.extend(test_cam_infos) + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "points3d.ply") + if not os.path.exists(ply_path): + # Since this data set has no colmap data, we start with random points + num_pts = 100_000 + print(f"Generating random point cloud ({num_pts})...") + + # We create random points inside the bounds of the synthetic Blender scenes + xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) + + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path) + return scene_info + +sceneLoadTypeCallbacks = { + "Colmap": readColmapSceneInfo, + "Blender" : readNerfSyntheticInfo +} \ No newline at end of file diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py new file mode 
100644 index 000000000..6b2a249b2 --- /dev/null +++ b/scene/gaussian_model.py @@ -0,0 +1,356 @@ +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from utils.sh_utils import RGB2SH +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation + +class GaussianModel: + def __init__(self, sh_degree : int): + + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + + self._xyz = torch.empty(0) + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + + self.optimizer = None + + self.scaling_activation = torch.exp + self.scaling_inverse_activation = torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + @property + def get_scaling(self): + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_opacity(self): + return 
self.opacity_activation(self._opacity) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) + scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid(0.5 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + + l = [ + {'params': [self._xyz], 'lr': 
training_args.position_lr_init*self.spatial_lr_scale, "name": "xyz"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr*self.spatial_lr_scale, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, + lr_final=training_args.position_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.posititon_lr_max_steps) + + def update_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = 
self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def reset_opacity(self): + opacities_new = inverse_sigmoid(torch.ones_like(self.get_opacity)*0.01) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path, og_number_points=-1): + self.og_number_points = og_number_points + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scales = np.zeros((xyz.shape[0], len(scale_names))) + 
for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + + self.active_sh_degree = self.max_sh_degree + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] == name: + stored_state = self.optimizer.state.get(group['params'][0], None) + stored_state["exp_avg"] = torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + 
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), 
dim=0).requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): + d = {"xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling" : new_scaling, + "rotation" : new_rotation} + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[:grads.shape[0]] = grads.squeeze() + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) + + stds = self.get_scaling[selected_pts_mask].repeat(N,1) + means =torch.zeros((stds.size(0), 3),device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N,1) + 
new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) + new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) + + prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) + + new_xyz = self._xyz[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) + + def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += 
torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) + self.denom[update_filter] += 1 \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 000000000..838a4eb87 --- /dev/null +++ b/train.py @@ -0,0 +1,194 @@ +import os +import torch +from random import randint +from utils.loss_utils import l1_loss, ssim +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +def training(dataset, opt, pipe, testing_iterations, saving_iterations): + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(opt.iterations), desc="Training progress") + for iteration in range(1, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, pipe.do_shs_python, pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive() + if custom_cam != None: + net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + 
network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + # Render + render_pkg = render(viewpoint_cam, gaussians, pipe, background) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + + # Log and save + training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if 
iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold) + + if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + gaussians.update_learning_rate(iteration) + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration) + tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration) + tb_writer.add_scalar('iter_time', elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, + {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) + + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + images = torch.tensor([], device="cuda") + gts = 
torch.tensor([], device="cuda") + for idx, viewpoint in enumerate(config['cameras']): + image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0) + gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) + images = torch.cat((images, image.unsqueeze(0)), dim=0) + gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) + if tb_writer and (idx < 5): + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration) + if iteration == testing_iterations[0]: + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration) + + l1_test = l1_loss(images, gts) + psnr_test = psnr(images, gts).mean() + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + + if tb_writer: + tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + torch.cuda.empty_cache() + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) + parser.add_argument("--quiet", action="store_true") + args = parser.parse_args(sys.argv[1:]) + print("Optimizing " + 
args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations) + + # All done + print("\nTraining complete.") \ No newline at end of file diff --git a/utils/camera_utils.py b/utils/camera_utils.py new file mode 100644 index 000000000..c7789807e --- /dev/null +++ b/utils/camera_utils.py @@ -0,0 +1,57 @@ +from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +def loadCam(args, id, cam_info, resolution_scale): + orig_w, orig_h = cam_info.image.size + + if args.resolution in [1, 2, 4, 8]: + resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) + else: # should be a type that converts to float + global_down = orig_w/args.resolution + scale = float(global_down) * float(resolution_scale) + resolution = (int(orig_w / scale), int(orig_h / scale)) + + resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + gt_image = resized_image_rgb[:3, ...] + loaded_mask = None + + if resized_image_rgb.shape[1] == 4: + loaded_mask = resized_image_rgb[3:4, ...] 
+ + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=gt_image, gt_alpha_mask=loaded_mask, + image_name=cam_info.image_name, uid=id) + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + +def camera_to_JSON(id, camera : Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + 'id' : id, + 'img_name' : camera.image_name, + 'width' : camera.width, + 'height' : camera.height, + 'position': pos.tolist(), + 'rotation': serializable_array_2d, + 'fy' : fov2focal(camera.FovY, camera.height), + 'fx' : fov2focal(camera.FovX, camera.width) + } + return camera_entry diff --git a/utils/general_utils.py b/utils/general_utils.py new file mode 100644 index 000000000..0dc50dc71 --- /dev/null +++ b/utils/general_utils.py @@ -0,0 +1,122 @@ +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). 
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + 
L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py new file mode 100644 index 000000000..545fc4190 --- /dev/null +++ b/utils/graphics_utils.py @@ -0,0 +1,66 @@ +import torch +import math +import numpy as np +from typing import NamedTuple + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 
1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) \ No newline at end of file diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 000000000..23e1cb779 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,8 @@ +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/utils/loss_utils.py b/utils/loss_utils.py new file mode 100644 index 000000000..08c98a69c --- /dev/null +++ b/utils/loss_utils.py @@ -0,0 +1,53 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = 
window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + diff --git a/utils/sh_utils.py b/utils/sh_utils.py new file mode 100644 index 000000000..bbca7d192 --- /dev/null +++ b/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import torch

# Constant factors multiplied into the degree-0 .. degree-4 terms of eval_sh.
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp: only indexing and arithmetic are used.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Degrees 0-4 are supported
        sh: array of SH coeffs [..., C, (deg + 1) ** 2]; extra trailing
            coefficients beyond (deg + 1) ** 2 are ignored
        dirs: array of unit directions [..., 3]
    Returns:
        [..., C]
    Raises:
        AssertionError: if deg is outside 0-4 or sh carries fewer than
            (deg + 1) ** 2 coefficients.
    """
    assert 0 <= deg <= 4, "eval_sh supports SH degrees 0 through 4"
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff, "sh must provide at least (deg + 1) ** 2 coefficients"

    # Degree 0: direction-independent term.
    result = C0 * sh[..., 0]
    if deg == 0:
        return result

    # Slices (not integer indices) keep a trailing axis of size 1 so that each
    # component broadcasts against the [..., C] coefficient slices.
    x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
    result = (result -
              C1 * y * sh[..., 1] +
              C1 * z * sh[..., 2] -
              C1 * x * sh[..., 3])
    if deg == 1:
        return result

    xx, yy, zz = x * x, y * y, z * z
    xy, yz, xz = x * y, y * z, x * z
    result = (result +
              C2[0] * xy * sh[..., 4] +
              C2[1] * yz * sh[..., 5] +
              C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
              C2[3] * xz * sh[..., 7] +
              C2[4] * (xx - yy) * sh[..., 8])
    if deg == 2:
        return result

    result = (result +
              C3[0] * y * (3 * xx - yy) * sh[..., 9] +
              C3[1] * xy * z * sh[..., 10] +
              C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
              C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
              C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
              C3[5] * z * (xx - yy) * sh[..., 14] +
              C3[6] * x * (xx - 3 * yy) * sh[..., 15])
    if deg == 3:
        return result

    result = (result +
              C4[0] * xy * (xx - yy) * sh[..., 16] +
              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result


def RGB2SH(rgb):
    """Return ``(rgb - 0.5) / C0``; the algebraic inverse of ``SH2RGB``.

    NOTE(review): presumably maps a colour value to a degree-0 SH
    coefficient - confirm against callers.
    """
    return (rgb - 0.5) / C0


def SH2RGB(sh):
    """Return ``sh * C0 + 0.5``; the algebraic inverse of ``RGB2SH``."""
    return sh * C0 + 0.5


# --- utils/system_utils.py (new file mode 100644, index 000000000..133553857) ---
from errno import EEXIST
from os import makedirs, path
import os


def mkdir_p(folder_path):
    """Create ``folder_path`` and any missing parents, like ``mkdir -p``.

    An already existing directory is not an error. If the path exists but is
    not a directory, the ``OSError`` raised by ``makedirs`` propagates.
    """
    # exist_ok=True gives the same outcome as catching EEXIST and checking
    # that the existing path is a directory.
    makedirs(folder_path, exist_ok=True)


def searchForMaxIteration(folder):
    """Return the largest integer suffix among the entry names in ``folder``.

    Each name is split on ``"_"`` and its last piece parsed as an integer
    (``"iteration_7000"`` -> ``7000``). Entries whose last piece is not an
    integer are skipped, so unrelated files do not abort the search.

    Raises:
        ValueError: if no entry in ``folder`` ends in an integer suffix.
    """
    saved_iters = []
    for fname in os.listdir(folder):
        try:
            saved_iters.append(int(fname.split("_")[-1]))
        except ValueError:
            # Not an "<anything>_<int>" entry (e.g. ".DS_Store"); ignore it.
            continue
    if not saved_iters:
        raise ValueError(
            "no entries with a numeric '_<iteration>' suffix in {!r}".format(folder)
        )
    return max(saved_iters)