diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a816307
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,148 @@
+data/*
+slurm/
+wandb/
+lightning_logs/
+runs/
+runs
+
+*.pyd
+
+# Editors
+.vscode/
+.idea/
+
+# Vagrant
+.vagrant/
+
+# Mac/OSX
+.DS_Store
+
+# Windows
+Thumbs.db
+
+# Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+*.out
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+/experiments/
+/pretrained-examples/
+
+test.ipynb
+debug/*
+checkpoints/*
+outputs/*
+.hydra/*
+*.mp4
+*.gif
+test.py
+*.slurm
+z/
\ No newline at end of file
diff --git a/README.md b/README.md
index 7f0c9a3..c2e48ec 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,94 @@
# SlotLifter
-Code for "SlotLifter: Slot-guided Feature Lifting for Learning Object-centric Radiance Fields" (ECCV 2024)
+
+
+
+
+
+
+
+
+
+
+
-# Coming Soon
+This repository contains the official implementation of the ECCV 2024 paper:
+
+[SlotLifter: Slot-guided Feature Lifting for Learning Object-centric Radiance Fields](https://arxiv.org/abs/2408.06697)
+
+[Yu Liu](https://yuliu-ly.github.io)\*, [Baoxiong Jia](https://buzz-beater.github.io)\*, [Yixin Chen](https://yixchen.github.io), [Siyuan Huang](https://siyuanhuang.com)
+
+
+
+
+
+## Environment Setup
+We provide all environment configurations in ``requirements.txt``. To install all packages, create a conda environment and install them as follows:
+```bash
+conda create -n slotlifter python=3.8
+conda activate slotlifter
+conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
+pip install -r requirements.txt
+```
+In our experiments, we used NVIDIA CUDA 11.3 on Ubuntu 20.04. Similar CUDA versions should also work with the corresponding ``torch`` and ``torchvision`` versions.
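+
+After installation, you can run a quick sanity check that PyTorch sees your GPU (a minimal snippet, not part of the provided scripts):
+```python
+import torch
+print(torch.__version__, torch.cuda.is_available())
+```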
+
+## Dataset
+### 1. ShapeStacks, ObjectsRoom, CLEVRTex, Flowers
+Download the ShapeStacks, ObjectsRoom, CLEVRTex, and Flowers datasets with
+```bash
+chmod +x scripts/downloads_data.sh
+./scripts/downloads_data.sh
+```
+For the ObjectsRoom dataset, you need to run ``objectsroom_process.py`` to convert the tfrecords dataset to PNG format.
+Remember to change ``DATA_ROOT`` in ``downloads_data.sh`` and ``objectsroom_process.py`` to your own paths.
+### 2. PTR, Birds, Dogs, Cars
+Download the PTR dataset following the instructions from http://ptr.csail.mit.edu. Download the CUB-Birds, Stanford Dogs, and Cars datasets from [here](https://drive.google.com/drive/folders/1zEzsKV2hOlwaNRzrEXc9oGdpTBrrVIVk), provided by the authors of [DRC](https://github.com/yuPeiyu98/DRC). We use ``birds.zip``, ``cars.tar``, and ``dogs.zip`` and uncompress them.
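+
+For reference, the archives can be extracted with Python's standard library (the target paths below are placeholders; point them at your own data root):
+```python
+import tarfile, zipfile
+zipfile.ZipFile('birds.zip').extractall('DATA_ROOT/birds')  # placeholder target paths
+tarfile.open('cars.tar').extractall('DATA_ROOT/cars')
+zipfile.ZipFile('dogs.zip').extractall('DATA_ROOT/dogs')
+```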
+
+### 3. YCB, ScanNet, COCO
+The YCB, ScanNet, and COCO datasets are available from [here](https://www.dropbox.com/sh/u1p1d6hysjxqauy/AACgEh0K5ANipuIeDnmaC5mQa?dl=0), provided by the authors of [UnsupObjSeg](https://github.com/vLAR-group/UnsupObjSeg).
+
+### 4. Data preparation
+Please organize the data as described [here](./data/README.md) before running experiments.
+
+## Training
+
+To train the model from scratch, we provide the following model files:
+ - ``train_trans_dec.py``: transformer-based model
+ - ``train_mixture_dec.py``: mixture-based model
+ - ``train_base_sa.py``: original slot-attention
+
+We provide training scripts under ``scripts/train``. Please use the following commands and change the ``.sh`` file to the one for the model you want to experiment with. Taking the transformer-based decoder experiment on Birds as an example, you can run:
+```bash
+$ cd scripts
+$ cd train
+$ chmod +x trans_dec_birds.sh
+$ ./trans_dec_birds.sh
+```
+Remember to change the paths in ``path.json`` to your own paths.
+## Reloading checkpoints & Evaluation
+
+To reload checkpoints and only run inference, we provide the following model files:
+ - ``test_trans_dec.py``: transformer-based model
+ - ``test_mixture_dec.py``: mixture-based model
+ - ``test_base_sa.py``: original slot-attention
+
+Similarly, we provide testing scripts under ``scripts/test``. We provide the transformer-based model for real-world datasets (Birds, Dogs, Cars, Flowers, YCB, ScanNet, COCO)
+and the mixture-based model for synthetic datasets (ShapeStacks, ObjectsRoom, CLEVRTex, PTR). We provide all checkpoints [here](https://drive.google.com/drive/folders/10LmK9JPWsSOcezqd6eLjuzn38VdwkBUf?usp=sharing). Please use the following commands and change the ``.sh`` file to the one for the model you want to experiment with:
+```bash
+$ cd scripts
+$ cd test
+$ chmod +x trans_dec_birds.sh
+$ ./trans_dec_birds.sh
+```
+
+## Citation
+If you find our paper and/or code helpful, please consider citing:
+```
+@inproceedings{Liu2024slotlifter,
+ title={SlotLifter: Slot-guided Feature Lifting for Learning Object-centric Radiance Fields},
+ author={Liu, Yu and Jia, Baoxiong and Chen, Yixin and Huang, Siyuan},
+ booktitle={European Conference on Computer Vision (ECCV)},
+ year={2024}
+}
+```
+
+## Acknowledgement
+This code builds heavily on resources from [PanopticLifting](https://github.com/nihalsid/panoptic-lifting), [BO-QSA](https://github.com/YuLiu-LY/BO-QSA), [SLATE](https://github.com/singhgautam/slate), [OSRT](https://github.com/stelzner/osrt), and [IBRNet](https://github.com/googleinterns/IBRNet). We thank the authors for open-sourcing their awesome projects.
diff --git a/assets/overview.png b/assets/overview.png
new file mode 100644
index 0000000..1bd364a
Binary files /dev/null and b/assets/overview.png differ
diff --git a/config/cfg/dtu.yaml b/config/cfg/dtu.yaml
new file mode 100644
index 0000000..e2aa2c0
--- /dev/null
+++ b/config/cfg/dtu.yaml
@@ -0,0 +1,107 @@
+# Wandb
+project: "dtu" # project name
+exp_name: test # experiment name
+entity: # username or teamname where you're sending runs
+group: # experiment groupname
+job_type: debug # train / test / debug ...
+tags: # tags for this run
+id: # unique Id for this run
+notes: # notes for this run
+watch_model: false # true for logging the gradient of parameters
+# Training
+lpips_net: vgg
+sample_mode: uniform
+bg_bound: 0.27
+force_bg: false
+force_bg_steps: 30000
+normalize: true
+benchmark: false
+deterministic: false
+render_src_view: false
+profiler: # use profiler to check time bottleneck
+resume: null
+ckpt_path: ''
+logger: # wandb or None
+log_path: "runs/dtu"
+chunk: 8192 # num of rays per chunk
+num_workers: 0
+seed: 42
+val_percent: 1.0 # if <= 1.0, fraction of the validation batches to run; otherwise an absolute number of batches
+train_percent: 1.0
+test_percent: 1.0
+val_check_interval: 4 # do validation every val_check_interval epochs. It could be less than 1
+grad_clip: 0.5
+precision: 32 # compute precision
+instance_steps: 500000
+stop_semantic_grad: false
+decay_noise: 20000
+seg_metrics:
+- ari
+- hiou
+- ari_fg
+# Optimizer
+optimizer: lion
+lr: 5e-5
+min_lr_factor: 0.02
+weight_decay: 0.001
+warmup_steps: 10000
+max_steps: 250000
+max_epochs: 10000
+decay_steps: 50000
+# Dataset
+norm_scene: false
+select_view_func: nearby # or uniform
+load_mask: false
+img_size:
+- 300
+- 400
+train_subsample_frames: 1
+val_subsample_frames: 5
+num_src_view: 4
+batch_size: 2
+ray_batchsize: 1024
+max_instances: 20
+dataset: dtu
+dataset_root: /home/yuliu/Dataset/DTU
+instance_dir: instance
+semantics_dir: semantics
+# dataset_root: data/hypersim/ai_001_008
+max_depth: 10
+visualized_indices: null
+overfit: false
+# Model
+# Multi-view enc
+feature_size: 32
+num_heads: 4
+conv_enc: false
+conv_dim: 32
+# slot_enc
+sigma_steps: 30000
+num_slots: 2
+num_iter: 3
+slot_size: 256
+drop_path: 0.2
+num_blocks: 1
+# slot_dec
+slot_density: true
+slot_dec_dim: 64
+num_dec_blocks: 4
+# NeRF
+random_proj_ratio: 1
+random_proj: true
+random_proj_steps: 30000
+n_samples: 64
+n_samples_fine: 64
+coarse_to_fine: true
+pe_view: 2
+pe_feat: 0
+grid_init: pos_enc
+monitor: psnr
+nerf_mlp_dim: 64
+suffix: ''
+scene_id: -1
+num_vis: 30
+hydra:
+ output_subdir: null # Disable saving of config files. We'll do that ourselves.
+ run:
+ dir: .
\ No newline at end of file
diff --git a/config/cfg/scannet.yaml b/config/cfg/scannet.yaml
new file mode 100644
index 0000000..36474d3
--- /dev/null
+++ b/config/cfg/scannet.yaml
@@ -0,0 +1,108 @@
+# Wandb
+project: "scannet" # project name
+exp_name: test # experiment name
+entity: # username or teamname where you're sending runs
+group: # experiment groupname
+job_type: debug # train / test / debug ...
+tags: # tags for this run
+id: # unique Id for this run
+notes: # notes for this run
+watch_model: false # true for logging the gradient of parameters
+# Training
+lpips_net: vgg
+sample_mode: uniform
+bg_bound: 0.27
+force_bg: false
+force_bg_steps: 30000
+normalize: true
+benchmark: false
+deterministic: false
+render_src_view: false
+profiler: # use profiler to check time bottleneck
+resume: null
+ckpt_path: ''
+logger: # wandb or None
+log_path: "runs/scannet"
+chunk: 8192 # num of rays per chunk
+num_workers: 0
+seed: 42
+val_percent: 1.0 # if <= 1.0, fraction of the validation batches to run; otherwise an absolute number of batches
+train_percent: 1.0
+test_percent: 1.0
+val_check_interval: 4 # do validation every val_check_interval epochs. It could be less than 1
+grad_clip: 0.5
+precision: 32 # compute precision
+instance_steps: 500000
+recon_rgb: true
+stop_semantic_grad: false
+decay_noise: 20000
+seg_metrics:
+- ari
+- hiou
+- ari_fg
+# Optimizer
+optimizer: lion
+lr: 5e-5
+min_lr_factor: 0.02
+weight_decay: 0.001
+warmup_steps: 10000
+max_steps: 250000
+max_epochs: 10000
+decay_steps: 50000
+# Dataset
+norm_scene: false
+select_view_func: nearby # or uniform
+load_mask: false
+img_size:
+- 480
+- 640
+train_subsample_frames: 1
+val_subsample_frames: 5
+num_src_view: 4
+batch_size: 2
+ray_batchsize: 1024
+max_instances: 20
+dataset: scannet
+dataset_root: /home/yuliu/Dataset/scannet
+instance_dir: instance
+semantics_dir: semantics
+# dataset_root: data/hypersim/ai_001_008
+max_depth: 10
+visualized_indices: null
+overfit: false
+# Model
+# Multi-view enc
+feature_size: 32
+num_heads: 4
+conv_enc: false
+conv_dim: 32
+# slot_enc
+sigma_steps: 30000
+num_slots: 8
+num_iter: 3
+slot_size: 256
+drop_path: 0.2
+num_blocks: 1
+# slot_dec
+slot_density: true
+slot_dec_dim: 64
+num_dec_blocks: 4
+# NeRF
+random_proj_ratio: 1
+random_proj: true
+random_proj_steps: 30000
+n_samples: 64
+n_samples_fine: 64
+pe_view: 2
+pe_feat: 0
+coarse_to_fine: true
+grid_init: pos_enc
+nerf_mlp_dim: 64
+monitor: psnr
+num_vis: 1
+scene_id: 0
+suffix: ''
+hydra:
+ output_subdir: null # Disable saving of config files. We'll do that ourselves.
+ run:
+ dir: .
\ No newline at end of file
diff --git a/config/cfg/uorf.yaml b/config/cfg/uorf.yaml
new file mode 100755
index 0000000..940c019
--- /dev/null
+++ b/config/cfg/uorf.yaml
@@ -0,0 +1,103 @@
+# Wandb
+project: "uorf" # project name
+exp_name: test # experiment name
+entity: # username or teamname where you're sending runs
+group: # experiment groupname
+job_type: debug # train / test / debug ...
+tags: # tags for this run
+id: # unique Id for this run
+notes: # notes for this run
+watch_model: false # true for logging the gradient of parameters
+# Training
+lpips_net: alex
+sample_mode: uniform
+bg_bound: 0.64
+force_bg: true
+force_bg_steps: 50000
+normalize: false
+benchmark: false
+deterministic: false
+render_src_view: true
+profiler: # use profiler to check time bottleneck
+resume: null
+ckpt_path: ''
+logger: # wandb or None
+log_path: "runs/uorf"
+chunk: 16384 # num of rays per chunk
+num_workers: 0
+seed: 42
+val_percent: 0.2 # if <= 1.0, fraction of the validation batches to run; otherwise an absolute number of batches
+train_percent: 1.0
+test_percent: 1.0
+val_check_interval: 4 # do validation every val_check_interval epochs. It could be less than 1
+grad_clip: 0.5
+precision: 32 # compute precision
+instance_steps: 500000
+stop_semantic_grad: false
+decay_noise: 20000
+seg_metrics: ['ari', 'hiou', 'ari_fg']
+# Optimizer
+optimizer: lion
+lr: 5e-5
+min_lr_factor: 0.02
+weight_decay: 0.001
+warmup_steps: 10000
+max_steps: 250000
+max_epochs: 10000
+decay_steps: 50000
+# Dataset
+subset: kitchen_matte
+norm_scene: true
+select_view_func: nearby # or uniform
+load_mask: false
+train_subsample_frames: 1
+val_subsample_frames: 5
+num_src_view: 1
+batch_size: 2
+ray_batchsize: 1024
+max_instances: 20
+img_size:
+- 128
+- 128
+n_scenes: 5
+dataset: uorf
+dataset_root: /home/yuliu/Dataset/uorf
+num_vis: 20
+max_depth: 10
+visualized_indices: null
+overfit: false
+# Model
+# Multi-view enc
+conv_enc: false
+feature_size: 64
+num_heads: 4
+conv_dim: 32
+# slot_enc
+sigma_steps: 30000
+num_slots: 5
+num_iter: 3
+slot_size: 256
+drop_path: 0.2
+num_blocks: 1
+# slot_dec
+slot_dec_dim: 64
+num_dec_blocks: 4
+slot_density: true
+# NeRF
+random_proj_ratio: 1
+random_proj: true
+random_proj_steps: 30000
+n_samples: 96
+n_samples_fine: 64
+coarse_to_fine: false
+pe_view: 10
+pe_feat: 0
+grid_init: pos_enc
+nerf_mlp_dim: 64
+scene_id: 0
+monitor: nv_ari
+suffix: ''
+hydra:
+ output_subdir: null # Disable saving of config files. We'll do that ourselves.
+ run:
+ dir: .
diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 0000000..f75c362
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - cfg: scannet
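+  # select a different dataset config from the command line via a hydra group override, e.g. cfg=dtu or cfg=uorf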
+hydra:
+ output_subdir: null # Disable saving of config files. We'll do that ourselves.
+ run:
+ dir: .
\ No newline at end of file
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000..8f8a3ee
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from dataset.scannet_dataset import ScannetDataset
+from dataset.uorf_dataset import MultiscenesDataset, VisualDataset
+from dataset.dtu_dataset import DTUDataset, VisDTUDataset
+
+datasets = {
+ "scannet": ScannetDataset,
+ "uorf": MultiscenesDataset,
+ "dtu": DTUDataset,
+ "uorf_vis": VisualDataset,
+ "dtu_vis": VisDTUDataset,
+}
+
+def get_dataset(config):
+ if config.job_type == 'vis':
+ dataset = datasets[config.dataset + "_vis"]
+ else:
+ dataset = datasets[config.dataset]
+ train_set = dataset(config, "train")
+ val_set = dataset(config, "val")
+ test_set = dataset(config, "test")
+ return train_set, val_set, test_set
+
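+# Example usage (a sketch; assumes a hydra config such as config/cfg/scannet.yaml,
+# which provides the `dataset`, `job_type`, `batch_size`, and `num_workers` fields):
+#   train_set, val_set, test_set = get_dataset(config)
+#   loader = torch.utils.data.DataLoader(train_set, batch_size=config.batch_size,
+#                                        num_workers=config.num_workers)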
diff --git a/dataset/data_utils.py b/dataset/data_utils.py
new file mode 100644
index 0000000..6b43ff4
--- /dev/null
+++ b/dataset/data_utils.py
@@ -0,0 +1,280 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import math
+from PIL import Image
+import torchvision.transforms as transforms
+import torch
+from scipy.spatial.transform import Rotation as R
+import cv2
+
+rng = np.random.RandomState(234)
+_EPS = np.finfo(float).eps * 4.0
+TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
+
+
+def vector_norm(data, axis=None, out=None):
+ """Return length, i.e. eucledian norm, of ndarray along axis.
+ """
+ data = np.array(data, dtype=np.float64, copy=True)
+ if out is None:
+ if data.ndim == 1:
+ return math.sqrt(np.dot(data, data))
+ data *= data
+ out = np.atleast_1d(np.sum(data, axis=axis))
+ np.sqrt(out, out)
+ return out
+ else:
+ data *= data
+ np.sum(data, axis=axis, out=out)
+ np.sqrt(out, out)
+
+
+def quaternion_about_axis(angle, axis):
+ """Return quaternion for rotation about axis.
+ """
+ quaternion = np.zeros((4, ), dtype=np.float64)
+ quaternion[:3] = axis[:3]
+ qlen = vector_norm(quaternion)
+ if qlen > _EPS:
+ quaternion *= math.sin(angle/2.0) / qlen
+ quaternion[3] = math.cos(angle/2.0)
+ return quaternion
+
+
+def quaternion_matrix(quaternion):
+ """Return homogeneous rotation matrix from quaternion.
+ """
+ q = np.array(quaternion[:4], dtype=np.float64, copy=True)
+ nq = np.dot(q, q)
+ if nq < _EPS:
+ return np.identity(4)
+ q *= math.sqrt(2.0 / nq)
+ q = np.outer(q, q)
+ return np.array((
+ (1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0),
+ ( q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0),
+ ( q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0),
+ ( 0.0, 0.0, 0.0, 1.0)
+ ), dtype=np.float64)
+
+
+def rectify_inplane_rotation(src_pose, tar_pose, src_img, th=40):
+ relative = np.linalg.inv(tar_pose).dot(src_pose)
+ relative_rot = relative[:3, :3]
+ r = R.from_matrix(relative_rot)
+ euler = r.as_euler('zxy', degrees=True)
+ euler_z = euler[0]
+ if np.abs(euler_z) < th:
+ return src_pose, src_img
+
+ R_rectify = R.from_euler('z', -euler_z, degrees=True).as_matrix()
+ src_R_rectified = src_pose[:3, :3].dot(R_rectify)
+ out_pose = np.eye(4)
+ out_pose[:3, :3] = src_R_rectified
+ out_pose[:3, 3:4] = src_pose[:3, 3:4]
+ h, w = src_img.shape[:2]
+ center = ((w - 1.) / 2., (h - 1.) / 2.)
+ M = cv2.getRotationMatrix2D(center, -euler_z, 1)
+ src_img = np.clip((255*src_img).astype(np.uint8), a_max=255, a_min=0)
+ rotated = cv2.warpAffine(src_img, M, (w, h), borderValue=(255, 255, 255), flags=cv2.INTER_LANCZOS4)
+ rotated = rotated.astype(np.float32) / 255.
+ return out_pose, rotated
+
+
+def random_crop(rgb, camera, src_rgbs, src_cameras, size=(400, 600), center=None):
+ h, w = rgb.shape[:2]
+ out_h, out_w = size[0], size[1]
+ if out_w >= w or out_h >= h:
+ return rgb, camera, src_rgbs, src_cameras
+
+ if center is not None:
+ center_h, center_w = center
+ else:
+ center_h = np.random.randint(low=out_h // 2 + 1, high=h - out_h // 2 - 1)
+ center_w = np.random.randint(low=out_w // 2 + 1, high=w - out_w // 2 - 1)
+
+ rgb_out = rgb[center_h - out_h // 2:center_h + out_h // 2, center_w - out_w // 2:center_w + out_w // 2, :]
+ src_rgbs = np.array(src_rgbs)
+ src_rgbs = src_rgbs[:, center_h - out_h // 2:center_h + out_h // 2,
+ center_w - out_w // 2:center_w + out_w // 2, :]
+ camera[0] = out_h
+ camera[1] = out_w
+ camera[4] -= center_w - out_w // 2
+ camera[8] -= center_h - out_h // 2
+ src_cameras[:, 4] -= center_w - out_w // 2
+ src_cameras[:, 8] -= center_h - out_h // 2
+ src_cameras[:, 0] = out_h
+ src_cameras[:, 1] = out_w
+ return rgb_out, camera, src_rgbs, src_cameras
+
+
+def random_flip(rgb, camera, src_rgbs, src_cameras):
+ h, w = rgb.shape[:2]
+ h_r, w_r = src_rgbs.shape[1:3]
+ rgb_out = np.flip(rgb, axis=1).copy()
+ src_rgbs = np.flip(src_rgbs, axis=-2).copy()
+ camera[2] *= -1
+ camera[4] = w - 1. - camera[4]
+ src_cameras[:, 2] *= -1
+ src_cameras[:, 4] = w_r - 1. - src_cameras[:, 4]
+ return rgb_out, camera, src_rgbs, src_cameras
+
+
+def get_color_jitter_params(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
+ color_jitter = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
+ transform = transforms.ColorJitter.get_params(color_jitter.brightness,
+ color_jitter.contrast,
+ color_jitter.saturation,
+ color_jitter.hue)
+ return transform
+
+
+def color_jitter(img, transform):
+ '''
+ Args:
+ img: np.float32 [h, w, 3]
+ transform:
+ Returns: transformed np.float32
+ '''
+ img = Image.fromarray((255.*img).astype(np.uint8))
+ img_trans = transform(img)
+ img_trans = np.array(img_trans).astype(np.float32) / 255.
+ return img_trans
+
+
+def color_jitter_all_rgbs(rgb, ref_rgbs, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
+ transform = get_color_jitter_params(brightness, contrast, saturation, hue)
+ rgb_trans = color_jitter(rgb, transform)
+ ref_rgbs_trans = []
+ for ref_rgb in ref_rgbs:
+ ref_rgbs_trans.append(color_jitter(ref_rgb, transform))
+
+ ref_rgbs_trans = np.array(ref_rgbs_trans)
+ return rgb_trans, ref_rgbs_trans
+
+
+def deepvoxels_parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
+ # Get camera intrinsics
+ with open(filepath, 'r') as file:
+ f, cx, cy = list(map(float, file.readline().split()))[:3]
+ grid_barycenter = torch.Tensor(list(map(float, file.readline().split())))
+ near_plane = float(file.readline())
+ scale = float(file.readline())
+ height, width = map(float, file.readline().split())
+
+ try:
+ world2cam_poses = int(file.readline())
+ except ValueError:
+ world2cam_poses = None
+
+ if world2cam_poses is None:
+ world2cam_poses = False
+
+ world2cam_poses = bool(world2cam_poses)
+
+ cx = cx / width * trgt_sidelength
+ cy = cy / height * trgt_sidelength
+ f = trgt_sidelength / height * f
+
+ fx = f
+ if invert_y:
+ fy = -f
+ else:
+ fy = f
+
+ # Build the intrinsic matrices
+ full_intrinsic = np.array([[fx, 0., cx, 0.],
+ [0., fy, cy, 0],
+ [0., 0, 1, 0],
+ [0, 0, 0, 1]])
+
+ return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
+
+
+def angular_dist_between_2_vectors(vec1, vec2):
+ vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER)
+ vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER)
+ angular_dists = np.arccos(np.clip(np.sum(vec1_unit*vec2_unit, axis=-1), -1.0, 1.0))
+ return angular_dists
+
+
+def batched_angular_dist_rot_matrix(R1, R2):
+ '''
+ calculate the angular distance between two rotation matrices (batched)
+ :param R1: the first rotation matrix [N, 3, 3]
+ :param R2: the second rotation matrix [N, 3, 3]
+    :return: angular distance in radians [N, ]
+ '''
+ assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3
+ return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2.,
+ a_min=-1 + TINY_NUMBER, a_max=1 - TINY_NUMBER))
+
+
+def get_nearest_pose_ids(tar_pose, ref_poses, num_select, tar_id=-1, angular_dist_method='vector',
+ scene_center=(0, 0, 0)):
+ '''
+ Args:
+        tar_pose: target pose [4, 4]
+        ref_poses: reference poses [N, 4, 4]
+ num_select: the number of nearest views to select
+ Returns: the selected indices
+ '''
+ num_cams = len(ref_poses)
+ num_select = min(num_select, num_cams-1)
+ batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0)
+
+ if angular_dist_method == 'matrix':
+ dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3])
+ elif angular_dist_method == 'vector':
+ tar_cam_locs = batched_tar_pose[:, :3, 3]
+ ref_cam_locs = ref_poses[:, :3, 3]
+ scene_center = np.array(scene_center)[None, ...]
+ tar_vectors = tar_cam_locs - scene_center
+ ref_vectors = ref_cam_locs - scene_center
+ dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors)
+ elif angular_dist_method == 'dist':
+ tar_cam_locs = batched_tar_pose[:, :3, 3]
+ ref_cam_locs = ref_poses[:, :3, 3]
+ dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1)
+ elif angular_dist_method == 'mix':
+ ang_dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3])
+ dists = 0.5 * ang_dists + 0.5 * np.linalg.norm(batched_tar_pose[:, :3, 3] - ref_poses[:, :3, 3], axis=1)
+ else:
+ raise Exception('unknown angular distance calculation method!')
+
+ sorted_ids = np.argsort(dists)
+ if tar_id >= 0:
+ sorted_ids = sorted_ids[1:]
+ selected_ids = sorted_ids[:num_select]
+ # print(angular_dists[selected_ids] * 180 / np.pi)
+ return selected_ids
+
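+# Example (a sketch): the datasets in this repo call this with 4x4 camera-to-scene
+# matrices and angular_dist_method='mix', e.g.
+#   ids = get_nearest_pose_ids(tar_pose, ref_poses, num_select=4, tar_id=0,
+#                              angular_dist_method='mix')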
+
+def resize(img: Image, cams, size):
+ r'''
+ Args:
+ img: PIL.Image
+ cams: [27] hw + intrinsics (3x3) + pose (4x4)
+ '''
+ W, H = img.size
+ img = img.resize(size, Image.LANCZOS)
+ w, h = size
+ cams[0] = h
+ cams[1] = w
+ cams[2:11] = np.reshape(
+ np.diag([w/W, h/H, 1]) @ np.reshape(cams[2:11], [3, 3]),
+ [9])
+ return img, cams
diff --git a/dataset/dtu_dataset.py b/dataset/dtu_dataset.py
new file mode 100644
index 0000000..0460dd9
--- /dev/null
+++ b/dataset/dtu_dataset.py
@@ -0,0 +1,435 @@
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+
+from pathlib import Path
+import torch
+from glob import glob
+from PIL import Image
+import numpy as np
+from tqdm import tqdm
+import cv2
+from util.misc import SubSampler
+from util.ray import get_rays
+from dataset.data_utils import get_nearest_pose_ids
+
+
+class DTUDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ cfg,
+ split="train",
+ ):
+ """
+ :param path dataset root path, contains metadata.yml
+ :param split train | val | test
+ :param list_prefix prefix for split lists: [train, val, test].lst
+ """
+ super().__init__()
+ self.data_root = cfg.dataset_root
+ self.normalize = cfg.normalize
+ self.depth_range = [0.1, 5.0]
+ self.split = split
+ self.img_root = f'image'
+ dino_patch_size = 14
+ self.dino_feats_size = cfg.dino_feats_size
+ self.feat_map_size = cfg.dino_feats_size[0] // dino_patch_size, cfg.dino_feats_size[1] // dino_patch_size
+ self.ray_batchsize = cfg.ray_batchsize
+ self.subsampler = SubSampler()
+ self.num_src_view = cfg.num_src_view
+
+ self.scene_id = cfg.scene_id
+
+ self.img_size = cfg.img_size[0], cfg.img_size[1]
+ # sub_format == "dtu":
+ self._coord_trans_world = torch.tensor(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ )
+ self._coord_trans_cam = torch.tensor(
+ [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ )
+
+ self.setup_data()
+
+ def __len__(self):
+ return sum([len(x) for x in self.indices])
+
+ def setup_data(self):
+ path = "new_val.lst" if self.split != "train" else "new_train.lst"
+ scene_list_path = os.path.join(self.data_root, path)
+ scenes = []
+ with open(scene_list_path, "r") as f:
+ for line in f:
+ scenes.append(line.strip())
+
+ self.scenes = [os.path.join(self.data_root, scene) for scene in scenes]
+ if self.scene_id != -1:
+ self.scenes = [self.scenes[self.scene_id]]
+ self.train_indices = []
+ self.val_indices = []
+ self.all_frame_names = []
+ self.cam2normscene = []
+ self.cam2scenes = []
+ self.intrinsics = []
+ self.scene2normscene = []
+ self.normscene_scale = []
+ self.all_cams = []
+ for scene in tqdm(self.scenes, desc="Loading DTU dataset"):
+ self.setup_one_scene(scene)
+ print(
+ "Loading DTU dataset from", self.data_root,
+ 'found', len(self.scenes),
+ self.split,"scenes",
+ )
+ if self.split == 'train':
+ self.indices = self.train_indices
+ else:
+ self.indices = self.val_indices
+ # self.indices = [list(range(len(self.all_frame_names[0])))]
+ self.frame_idx2scene_idx = {}
+ self.frame_idx2sample_idx = {}
+ frame_idx = 0
+ for scene_idx, scene_indices in enumerate(self.indices):
+ for _, sample_idx in enumerate(scene_indices):
+ self.frame_idx2scene_idx[frame_idx] = scene_idx
+ self.frame_idx2sample_idx[frame_idx] = sample_idx
+ frame_idx += 1
+ print("depth range", self.depth_range)
+
+ def setup_one_scene(self, scene):
+ all_frames = sorted(glob(os.path.join(scene, self.img_root, "*")))
+ all_frames = [x.split("/")[-1].split(".")[0] for x in all_frames]
+ self.all_frame_names.append(all_frames)
+
+ fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0
+ cam2scene = []
+ dims = []
+ img_w, img_h = Image.open(os.path.join(scene, self.img_root, f'{all_frames[0]}.png')).size
+ for i in range(len(all_frames)):
+ cams = np.load(os.path.join(self.data_root, scene, "cameras.npz"))
+
+ # Decompose projection matrix
+ P = cams["world_mat_" + str(i)]
+ P = P[:3]
+ K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
+ K = K / K[2, 2]
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose()
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ scale_mtx = cams.get("scale_mat_" + str(i))
+ if scale_mtx is not None:
+ norm_trans = scale_mtx[:3, 3:]
+ norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None]
+ pose[:3, 3:] -= norm_trans
+ pose[:3, 3:] /= norm_scale
+
+ fx += torch.tensor(K[0, 0])
+ fy += torch.tensor(K[1, 1])
+ cx += torch.tensor(K[0, 2])
+ cy += torch.tensor(K[1, 2])
+
+ pose = torch.tensor(pose, dtype=torch.float32)
+ cam2scene.append(pose)
+ dims.append([img_h, img_w])
+ fx /= len(all_frames)
+ fy /= len(all_frames)
+ cx /= len(all_frames)
+ cy /= len(all_frames)
+ intrinsic = torch.tensor([
+ [fx, 0, cx, 0],
+ [0, fy, cy, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]
+ ])
+ cam2scene = torch.stack(cam2scene).float()
+ intrinsics = intrinsic.unsqueeze(0).expand(len(cam2scene), -1, -1).float()
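+        # camera vector layout: image size [H, W] (2) + flattened 4x4 intrinsics (16) + flattened 4x4 cam2scene pose (16) = 34 dims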
+ cams = torch.cat([
+ torch.Tensor(self.img_size).expand(len(cam2scene), -1),
+ intrinsics.flatten(1, 2),
+ cam2scene.flatten(1, 2)], dim=1)
+ self.all_cams.append(cams)
+ self.cam2scenes.append(cam2scene)
+
+ if self.split == "train":
+ indices = list(range(len(all_frames)))
+ self.train_indices.append(indices)
+ else:
+ indices = list(range(len(all_frames)))
+ val_indices = indices[::8]
+ train_indices = [i for i in indices if i not in val_indices]
+ self.val_indices.append(val_indices)
+ self.train_indices.append(train_indices)
+
+ def load_sample(self, sample_index, scene_idx):
+ scene_dir = self.scenes[scene_idx]
+ image = Image.open(Path(scene_dir) / f"{self.img_root}" / f"{self.all_frame_names[scene_idx][sample_index]}.png")
+ image = torch.from_numpy(np.array(image) / 255).float() # [H, W, 3]
+ if self.normalize:
+ image = image * 2 - 1 # [-1, 1]
+ return image.view(-1, 3), torch.zeros(1)
+
+ def sample_support_views(self, pose, scene_idx):
+ # sample support ids
+ support_indices = self.train_indices[scene_idx]
+ if self.split == 'train':
+ subsample_factor = 1
+ nearest_pose_ids = get_nearest_pose_ids(pose.numpy(),
+ self.cam2scenes[scene_idx][support_indices].numpy(),
+ min(self.num_src_view * subsample_factor, 22),
+ tar_id=0,
+ angular_dist_method='mix')
+ nearest_pose_ids = np.random.choice(nearest_pose_ids, self.num_src_view, replace=False)
+ else:
+ nearest_pose_ids = get_nearest_pose_ids(pose.numpy(),
+ self.cam2scenes[scene_idx][support_indices].numpy(),
+ self.num_src_view,
+ tar_id=-1,
+ angular_dist_method='mix')
+ nearest_pose_ids = torch.from_numpy(nearest_pose_ids).long()
+
+ return nearest_pose_ids
+
+ def __getitem__(self, idx):
+ scene_idx = self.frame_idx2scene_idx[idx]
+ sample_idx = self.frame_idx2sample_idx[idx]
+ sample = {}
+
+ # sample support ids
+ pose = self.cam2scenes[scene_idx][sample_idx]
+ nearest_pose_ids = self.sample_support_views(pose, scene_idx)
+ support_indices = self.train_indices[scene_idx]
+ src_rgbs, src_cams, src_feats = [], [], []
+ for i in nearest_pose_ids:
+ id = support_indices[i]
+ src_cam = self.all_cams[scene_idx][id]
+ src_rgb, src_feat = self.load_sample(id, scene_idx)
+ src_rgbs.append(src_rgb)
+ src_cams.append(src_cam)
+ src_feats.append(src_feat)
+ src_rgbs = torch.stack(src_rgbs) # [N, H*W, 3]
+ src_cams = torch.stack(src_cams) # [N, 34]
+ src_feats = torch.stack(src_feats) # [N, H1, W1, D]
+
+ rgbs, feat = self.load_sample(sample_idx, scene_idx)
+ tgt_cam = self.all_cams[scene_idx][[sample_idx]]
+        tgt_rays = get_rays(tgt_cam, *self.img_size).view(-1, 6)  # [HW, 6]
+ H, W = self.img_size
+ N = self.num_src_view
+ if self.split == 'train':
+ # subsample
+ Br = self.ray_batchsize
+ subsample_idx = self.subsampler.idx_subsample(self.img_size, Br)
+ tgt_rgbs = rgbs.gather(0, subsample_idx.expand(-1, 3)) # [Br, 3]
+ tgt_rays = tgt_rays.gather(0, subsample_idx.expand(-1, 6)) # [Br, 3]
+ else:
+ tgt_rgbs = rgbs[None] # [1, HW, 3], [1, HW]
+ tgt_rays = tgt_rays[None] # [1, HW, 3]
+ sample['rgbs'] = tgt_rgbs # [Br, 3] or [N1, HW, 3]
+ sample['rays'] = tgt_rays # [Br, 3] or [N1, HW, 3]
+ sample['cam'] = tgt_cam # [Br, 34] or [N1, 34]
+ sample['src_rgbs'] = src_rgbs.reshape(N, H, W, 3) # [N, HW,3]
+        sample['src_cams'] = src_cams # [N, 34] 2+16+16
+ sample['src_feats'] = src_feats # [N, H1, W1, D]
+ sample['depth_range'] = torch.tensor(self.depth_range).float()
+ return sample
+
+
+class VisDTUDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ cfg,
+ split="train",
+ ):
+ """
+ :param path dataset root path, contains metadata.yml
+ :param split train | val | test
+ :param list_prefix prefix for split lists: [train, val, test].lst
+ """
+ super().__init__()
+ self.data_root = cfg.dataset_root
+ self.normalize = cfg.normalize
+ self.depth_range = [0.1, 5.0]
+ self.img_root = f'image'
+ self.num_src_view = cfg.num_src_view
+
+ self.scene_id = cfg.scene_id
+ self.img_size = cfg.img_size[0], cfg.img_size[1]
+ # sub_format == "dtu":
+ self._coord_trans_world = torch.tensor(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ )
+ self._coord_trans_cam = torch.tensor(
+ [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ )
+ self.num_vis = cfg.num_vis
+ self.setup_data()
+
+ def __len__(self):
+ return self.num_vis
+
+ def setup_data(self):
+ path = "new_val.lst"
+ scene_list_path = os.path.join(self.data_root, path)
+ scenes = []
+ with open(scene_list_path, "r") as f:
+ for line in f:
+ scenes.append(line.strip())
+
+ scenes = [os.path.join(self.data_root, scene) for scene in scenes]
+ self.scene = scenes[self.scene_id]
+ self.setup_one_scene(self.scene)
+
+ def setup_one_scene(self, scene):
+ all_frames = sorted(glob(os.path.join(scene, self.img_root, "*")))
+ all_frames = [x.split("/")[-1].split(".")[0] for x in all_frames]
+ self.all_frame = all_frames
+
+ fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0
+ cam2scene = []
+ dims = []
+ img_w, img_h = Image.open(os.path.join(scene, self.img_root, f'{all_frames[0]}.png')).size
+ for i in range(len(all_frames)):
+ cams = np.load(os.path.join(self.data_root, scene, "cameras.npz"))
+
+ # Decompose projection matrix
+ P = cams["world_mat_" + str(i)]
+ P = P[:3]
+ K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
+ K = K / K[2, 2]
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose()
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ scale_mtx = cams.get("scale_mat_" + str(i))
+ if scale_mtx is not None:
+ norm_trans = scale_mtx[:3, 3:]
+ norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None]
+ pose[:3, 3:] -= norm_trans
+ pose[:3, 3:] /= norm_scale
+
+ fx += torch.tensor(K[0, 0])
+ fy += torch.tensor(K[1, 1])
+ cx += torch.tensor(K[0, 2])
+ cy += torch.tensor(K[1, 2])
+
+ pose = torch.tensor(pose, dtype=torch.float32)
+ cam2scene.append(pose)
+ dims.append([img_h, img_w])
+ fx /= len(all_frames)
+ fy /= len(all_frames)
+ cx /= len(all_frames)
+ cy /= len(all_frames)
+ intrinsic = torch.tensor([
+ [fx, 0, cx, 0],
+ [0, fy, cy, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]
+ ])
+ cam2scene = torch.stack(cam2scene).float()
+ intrinsics = intrinsic.unsqueeze(0).expand(len(cam2scene), -1, -1).float()
+ cams = torch.cat([
+ torch.Tensor(self.img_size).expand(len(cam2scene), -1),
+ intrinsics.flatten(1, 2),
+ cam2scene.flatten(1, 2)], dim=1)
+ self.all_cams = cams
+ self.cam2scenes = cam2scene
+ indices = list(range(len(all_frames)))
+ self.indices = indices
+
+
+ N_poses = len(cam2scene)
+ tgt_cam2scene = cam2scene
+        tgt_cams = [torch.cat([torch.Tensor(self.img_size), intrinsic.flatten(),
+                    tgt_cam2scene[i].flatten()])[None] for i in range(self.num_vis)]
+ tgt_cams = torch.cat(tgt_cams, dim=0) # [N, 34]
+ tgt_rays = get_rays(tgt_cams, *self.img_size) # [N, HW, 6]
+ self.tgt_poses = tgt_cam2scene
+ self.tgt_cams = tgt_cams
+ self.tgt_rays = tgt_rays
+
+ def load_sample(self, sample_index):
+ image = Image.open(Path(self.scene) / f"{self.img_root}" / f"{self.all_frame[sample_index]}.png")
+ image = torch.from_numpy(np.array(image) / 255).float() # [H, W, 3]
+ if self.normalize:
+ image = image * 2 - 1 # [-1, 1]
+ return image.view(-1, 3), torch.zeros(1)
+
+ def sample_support_views(self, pose):
+ nearest_pose_ids = get_nearest_pose_ids(pose.numpy(),
+ self.cam2scenes.numpy(),
+ self.num_src_view,
+ tar_id=-1,
+ angular_dist_method='mix')
+ nearest_pose_ids = torch.from_numpy(nearest_pose_ids).long()
+ return nearest_pose_ids
+
+ def __getitem__(self, idx):
+ sample = {}
+
+ # sample support ids
+ pose = self.tgt_poses[idx]
+
+ nearest_pose_ids = self.sample_support_views(pose)
+ src_rgbs, src_cams, src_feats = [], [], []
+ for id in nearest_pose_ids:
+ src_cam = self.all_cams[id]
+ src_rgb, src_feat = self.load_sample(id)
+ src_rgbs.append(src_rgb)
+ src_cams.append(src_cam)
+ src_feats.append(src_feat)
+ src_rgbs = torch.stack(src_rgbs) # [N, H*W, 3]
+ src_cams = torch.stack(src_cams) # [N, 34]
+ src_feats = torch.stack(src_feats) # [N, H1, W1, D]
+
+ H, W = self.img_size
+ N = self.num_src_view
+ sample['rays'] = self.tgt_rays # [Br, 3] or [N1, HW, 3]
+ sample['cam'] = self.tgt_cams # [Br, 34] or [N1, 34]
+ sample['src_rgbs'] = src_rgbs.reshape(N, H, W, 3) # [N, HW,3]
+        sample['src_cams'] = src_cams # [N, 34] 2+16+16
+ sample['src_feats'] = src_feats # [N, H1, W1, D]
+ sample['depth_range'] = torch.tensor(self.depth_range).float()
+ return sample
+
+
+def rot_matrix_angular_dist(R1, R2):
+ assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3
+ return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2.,
+ a_min=-1 + 1e-6, a_max=1 - 1e-6))
+
+import hydra
+from torch.utils.data import DataLoader
+
+@hydra.main(config_path='../config/cfg', config_name='dtu', version_base='1.2')
+def main(config):
+ config.num_workers = 0
+    # config.lambda_depth = 0.1  # not defined in config/cfg/dtu.yaml; assigning it here would raise under hydra's struct mode
+
+ train_set = VisDTUDataset(config, "train")
+ val_set = VisDTUDataset(config, "val")
+ test_set = VisDTUDataset(config, "test")
+
+ train_loader = DataLoader(train_set, batch_size=config.batch_size, num_workers=config.num_workers)
+ # val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=config.num_workers)
+ test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=config.num_workers)
+ for i, batch in tqdm(enumerate(train_loader), total=len(train_set)):
+ print(batch['cam'].shape)
+        print(batch['src_rgbs'].shape)
+ break
+ for i, batch in tqdm(enumerate(test_loader), total=len(test_set)):
+ print(batch['cam'].shape)
+        print(batch['src_rgbs'].shape)
+ break
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/dataset/preprocess_scannet.py b/dataset/preprocess_scannet.py
new file mode 100644
index 0000000..e66b462
--- /dev/null
+++ b/dataset/preprocess_scannet.py
@@ -0,0 +1,145 @@
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-3])
+sys.path.append(root_path)
+
+import cv2
+import glob
+import json
+import argparse
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+from pathlib import Path
+from multiprocessing import cpu_count, Pool
+
+
+def get_keyframe_indices(filenames, window_size):
+ """
+ select non-blurry images within a moving window
+ """
+ scores = []
+ cores = cpu_count() - 1
+ # print("Using", cores, "cores")
+ with Pool(cores) as pool:
+        for result in pool.map(compute_blur_score_opencv, filenames):
+ scores.append(result)
+ keyframes = [i + np.argmin(scores[i:i + window_size]) for i in range(0, len(scores), window_size)]
+ return keyframes, scores
+
+
+def compute_blur_score_opencv(filename):
+ """
+ Estimate the amount of blur an image has with the variance of the Laplacian.
+ Normalize by pixel number to offset the effect of image size on pixel gradients & variance
+ https://github.com/deepfakes/faceswap/blob/ac40b0f52f5a745aa058f92339302065177dd28b/tools/sort/sort.py#L626
+ """
+ image = cv2.imread(str(filename))
+ if image.ndim == 3:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ blur_map = cv2.Laplacian(image, cv2.CV_32F)
+ score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1])
+ return 1.0 - score
+
+
+def subsample_scannet(src_folder, rate):
+ """
+ sample every nth frame from scannet
+ """
+ all_frames = sorted(list(x.stem for x in (src_folder / 'pose').iterdir()), key=lambda y: int(y) if y.isnumeric() else y)
+ total_sampled = int(len(all_frames) * rate)
+ sampled_frames = [all_frames[i * (len(all_frames) // total_sampled)] for i in range(total_sampled)]
+ unsampled_frames = [x for x in all_frames if x not in sampled_frames]
+ for frame in sampled_frames:
+ if 'inf' in Path(src_folder / "pose" / f"{frame}.txt").read_text():
+ unsampled_frames.append(frame)
+ folders = ["color", "depth", "instance", "pose", "semantics"]
+ exts = ['jpg', 'png', 'png', 'txt', 'png']
+ for folder, ext in tqdm(zip(folders, exts), desc='sampling'):
+ assert (src_folder / folder).exists(), src_folder
+ for frame in unsampled_frames:
+ if (src_folder / folder / f'{frame}.{ext}').exists():
+ os.remove(str(src_folder / folder / f'{frame}.{ext}'))
+ else:
+ print(str(src_folder / folder / f'{frame}.{ext}'), "already exists!")
+
+
+def subsample_scannet_blur_window(src_folder, min_frames):
+ """
+ sample non blurry frames from scannet
+ """
+ if os.path.exists(src_folder / f"sampled_frames.json"):
+ print("sampled_frames.json already exists, skipping")
+ return
+ scene_name = src_folder.name
+ all_frames = sorted(list(x.stem for x in (src_folder / f'pose').iterdir()), key=lambda y: int(y) if y.isnumeric() else y)
+ valid_frames = []
+ for frame in all_frames:
+ if 'inf' not in Path(src_folder / f"pose" / f"{frame}.txt").read_text():
+ valid_frames.append(frame)
+ print("Found", len(all_frames), "frames, ", len(valid_frames), "are valid")
+ valid_frame_paths = [Path(src_folder / f"color" / f"{frame}.jpg") for frame in valid_frames]
+ window_size = max(3, len(valid_frames) // min_frames)
+ frame_indices, _ = get_keyframe_indices(valid_frame_paths, window_size)
+ print("Using a window size of", window_size, "got", len(frame_indices), "frames")
+ sampled_frames = [valid_frames[i] for i in frame_indices]
+ # save as json
+ json.dump(sampled_frames, open(src_folder / f"sampled_frames.json", 'w'), indent=4)
+
+
+def resize_files(img_paths, resize_depth=False):
+ for img_path in img_paths:
+ img = Image.open(img_path) # 1296x968
+ if not os.path.exists(img_path.replace('color', 'color_512512')):
+ img1 = img.resize([512, 512], Image.Resampling.LANCZOS)
+ img1.save(img_path.replace('color', 'color_512512'))
+ if not os.path.exists(img_path.replace('color', 'color_480640')):
+ img2 = img.resize([640, 480], Image.Resampling.LANCZOS)
+ img2.save(img_path.replace('color', 'color_480640'))
+ if resize_depth and not os.path.exists(img_path.replace('color', 'depth_512512')):
+ p_depth = img_path.replace('color', 'depth').replace('.jpg', '.png')
+ depth = Image.open(p_depth)
+ depth1 = depth.resize([512, 512], Image.Resampling.NEAREST)
+ depth1.save(p_depth.replace('depth', 'depth_512512'))
+
+
+def process_one_scene(scene_path):
+ dest = scene_path
+ print('#' * 80)
+    scene_name = scene_path.name
+    print(f'subsampling from {scene_name}...')
+    subsample_scannet_blur_window(dest, min_frames=400)
+
+    print('resizing images...')
+    os.makedirs(f'{scene_path}/color_512512', exist_ok=True)
+    os.makedirs(f'{scene_path}/color_480640', exist_ok=True)
+    os.makedirs(f'{scene_path}/depth_512512', exist_ok=True)
+
+    img_ids = json.load(open(f'{scene_path}/sampled_frames.json', 'r'))
+    img_paths = [f'{scene_path}/color/{img_id}.jpg' for img_id in img_ids]
+ resize_files(img_paths, resize_depth=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='scannet preprocessing')
+ parser.add_argument('--data_root', required=False, default='/home/yuliu/Dataset/scannet', help='file path')
+ args = parser.parse_args()
+
+ scene_paths = sorted(glob.glob(f'{args.data_root}/*'))
+ for path in scene_paths:
+ scene_root = Path(path)
+ process_one_scene(scene_root)
+
+ test_root = args.data_root.replace('scannet', 'scannet_test')
+ scene_paths = sorted(glob.glob(f'{test_root}/*'))
+ for path in scene_paths:
+ scene_name = path.split('/')[-1]
+ print(f'resizing images for {scene_name}...')
+ os.makedirs(f'{path}/color_512512', exist_ok=True)
+ os.makedirs(f'{path}/color_480640', exist_ok=True)
+ os.makedirs(f'{path}/depth_512512', exist_ok=True)
+
+ img_paths = glob.glob(f'{path}/color/*.jpg')
+ resize_files(img_paths, resize_depth=True)
+
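+# Example invocation (a sketch; the data root below is a placeholder):
+#   python dataset/preprocess_scannet.py --data_root /path/to/scannet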
diff --git a/dataset/scannet_dataset.py b/dataset/scannet_dataset.py
new file mode 100644
index 0000000..d082a5a
--- /dev/null
+++ b/dataset/scannet_dataset.py
@@ -0,0 +1,263 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+from pathlib import Path
+
+import torch
+import numpy as np
+from PIL import Image
+import hydra
+import json
+from tqdm import tqdm
+from glob import glob
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from util.camera import compute_world2normscene
+from util.misc import SubSampler
+from util.ray import get_rays
+from dataset.data_utils import get_nearest_pose_ids
+
+
+def move_left_zero(x):
+ return '0' if int(x) == 0 else x.lstrip('0')
+
+
+class ScannetDataset(Dataset):
+ def __init__(self, cfg, split):
+ super().__init__()
+ self.split = 'test' if split != 'train' else 'train'
+ self.root_dir = Path(cfg.dataset_root)
+ if self.split == 'test':
+ self.root_dir = Path(cfg.dataset_root.replace('scannet', 'scannet_test'))
+ self.img_size = cfg.img_size[0], cfg.img_size[1]
+
+ self.img_root = f'color_480640'
+
+ self.train_subsample_frames = cfg.train_subsample_frames
+ self.val_subsample_frames = cfg.val_subsample_frames
+ self.num_src_view = cfg.num_src_view
+ self.world2scene = np.eye(4, dtype=np.float32)
+ self.depth_range = [0.1, 10]
+ self.max_depth = 10
+ self.normalize = cfg.normalize
+ self.ray_batchsize = cfg.ray_batchsize
+ self.norm_scene = cfg.norm_scene
+ self.subsampler = SubSampler()
+ self.setup_data()
+
+ def __len__(self):
+ return sum([len(x) for x in self.indices])
+
+ def setup_data(self):
+ scenes = sorted(glob(f'{str(self.root_dir)}/*'))
+ test_scenes = ["scene0000_01",
+ "scene0079_00",
+ "scene0158_00",
+ "scene0316_00",
+ "scene0521_00",
+ "scene0553_00",
+ "scene0616_00",
+ "scene0653_00",
+ ]
+ self.scenes = scenes
+ print(f"Using {len(self.scenes)} scenes for {self.split}")
+ scene_names = [x.split('/')[-1] for x in self.scenes]
+ print(self.split, scene_names)
+ self.train_indices = []
+ self.val_indices = []
+ self.all_frame_names = []
+ self.cam2normscene = []
+ self.cam2scenes = []
+ self.intrinsics = []
+ self.scene2normscene = []
+ self.normscene_scale = []
+ self.segmentation_data = []
+
+ for idx, scene_dir in enumerate(self.scenes):
+ self.setup_data_one_scene(idx)
+ if self.split == 'train':
+ self.indices = self.train_indices
+ else:
+ self.indices = self.val_indices
+ self.frame_idx2scene_idx = {}
+ self.frame_idx2sample_idx = {}
+ frame_idx = 0
+ for scene_idx, scene_indices in enumerate(self.indices):
+ for _, sample_idx in enumerate(scene_indices):
+ self.frame_idx2scene_idx[frame_idx] = scene_idx
+ self.frame_idx2sample_idx[frame_idx] = sample_idx
+ frame_idx += 1
+ if self.norm_scene:
+ self.depth_range = self.depth_range[0] * min(self.normscene_scale), self.depth_range[1] * max(self.normscene_scale)
+ print(f"Depth range: {self.depth_range}")
+
+ def setup_data_one_scene(self, scene_idx):
+ scene_dir = self.scenes[scene_idx]
+ # split
+ if self.split == 'train' or self.split == 'val':
+ frames = json.loads((Path(scene_dir) / "sampled_frames.json").read_text())
+ scene_frames = sorted([x.zfill(4) for x in frames])
+ scene_frames = [move_left_zero(x) for x in scene_frames]
+ sample_indices = list(range(len(scene_frames)))
+ val_indices = sample_indices[::8]
+ train_indices = [x for x in sample_indices if x not in val_indices]
+ val_indices = val_indices[::self.val_subsample_frames]
+ else:
+ train_frames = np.loadtxt(Path(scene_dir) / "train.txt", dtype=str)
+ val_frames = np.loadtxt(Path(scene_dir) / "test.txt", dtype=str)
+ scene_frames = train_frames.tolist() + val_frames.tolist()
+ scene_frames = [x.split('.')[0] for x in scene_frames]
+ train_indices = list(range(len(train_frames)))
+ val_indices = list(range(len(train_frames), len(scene_frames)))
+ sample_indices = list(range(len(scene_frames)))
+
+ # print(f"Loading {scene_name}")
+ self.train_indices.append(train_indices)
+ self.val_indices.append(val_indices)
+ self.all_frame_names.append(scene_frames)
+
+ dims, cam2scene = [], []
+ img_h, img_w = 968, 1296
+ intrinsic_color = np.array([[float(y.strip()) for y in x.strip().split()] for x in (Path(scene_dir) / f"intrinsic" / "intrinsic_color.txt").read_text().splitlines() if x != ''])
+ intrinsic_color = torch.from_numpy(intrinsic_color[:3, :3]).float()
+ scale_x, scale_y = self.img_size[1] / img_w, self.img_size[0] / img_h
+ intrinsic_normed = torch.diag(torch.Tensor([scale_x, scale_y, 1])) @ intrinsic_color
+ intrinsic = torch.eye(4)
+ intrinsic[:3, :3] = intrinsic_normed
+ self.intrinsics.append(intrinsic)
+
+ for sample_index in sample_indices:
+ cam2world = np.array([[float(y.strip()) for y in x.strip().split()] for x in (Path(scene_dir) / f"pose" / f"{scene_frames[sample_index]}.txt").read_text().splitlines() if x != ''])
+ cam2world = torch.from_numpy(self.world2scene @ cam2world).float()
+ cam2scene.append(cam2world)
+ dims.append([img_h, img_w])
+
+ cam2scene = torch.stack(cam2scene)
+ intrinsics = intrinsic_color.unsqueeze(0).expand(len(cam2scene), -1, -1)
+ scene2normscene = compute_world2normscene(
+ torch.Tensor(dims).float(),
+ intrinsics,
+ cam2scene,
+ max_depth=self.max_depth,
+ rescale_factor=1.0
+ )
+ self.scene2normscene.append(scene2normscene)
+ self.normscene_scale.append(scene2normscene[0, 0])
+ cam2normscene = []
+ for sample_index in sample_indices:
+ cam2normscene.append(scene2normscene @ cam2scene[sample_index])
+ cam2normscene = torch.stack(cam2normscene)
+
+ self.cam2normscene.append(cam2normscene)
+ self.cam2scenes.append(cam2scene)
+
+ def load_sample(self, sample_index, scene_idx):
+ scene_dir = self.scenes[scene_idx]
+ image = Image.open(Path(scene_dir) / f"{self.img_root}" / f"{self.all_frame_names[scene_idx][sample_index]}.jpg")
+ # image = image.resize(self.img_size[::-1], Image.BILINEAR)
+ image = torch.from_numpy(np.array(image) / 255).float() # [H, W, 3]
+ if self.normalize:
+ image = image * 2 - 1 # [-1, 1]
+ return image.view(-1, 3)
+
+ def sample_support_views(self, pose, scene_idx):
+ # sample support ids
+ support_indices = self.train_indices[scene_idx]
+ if self.split == 'train':
+ subsample_factor = np.random.choice(np.arange(1, 4), p=[0.2, 0.45, 0.35])
+ nearest_pose_ids = get_nearest_pose_ids(pose.numpy(),
+ self.cam2scenes[scene_idx][support_indices].numpy(),
+ min(self.num_src_view * subsample_factor, 22),
+ tar_id=0,
+ angular_dist_method='mix')
+ nearest_pose_ids = np.random.choice(nearest_pose_ids, self.num_src_view, replace=False)
+ else:
+ nearest_pose_ids = get_nearest_pose_ids(pose.numpy(),
+ self.cam2scenes[scene_idx][support_indices].numpy(),
+ self.num_src_view,
+ tar_id=-1,
+ angular_dist_method='mix')
+ nearest_pose_ids = torch.from_numpy(nearest_pose_ids).long()
+
+ return nearest_pose_ids
+
+ def __getitem__(self, idx):
+ scene_idx = self.frame_idx2scene_idx[idx]
+ sample_idx = self.frame_idx2sample_idx[idx]
+ sample = {}
+
+ # sample support ids
+ poses = self.cam2normscene[scene_idx] if self.norm_scene else self.cam2scenes[scene_idx]
+ pose = self.cam2scenes[scene_idx][sample_idx]
+ nearest_pose_ids = self.sample_support_views(pose, scene_idx)
+ support_indices = self.train_indices[scene_idx]
+ src_rgbs, src_cams = [], []
+ for i in nearest_pose_ids:
+ id = support_indices[i]
+ src_cam = torch.cat([torch.Tensor(self.img_size), self.intrinsics[scene_idx].flatten(),
+ poses[id].flatten()])
+ src_rgb = self.load_sample(id, scene_idx)
+ src_rgbs.append(src_rgb)
+ src_cams.append(src_cam)
+ src_rgbs = torch.stack(src_rgbs) # [N, H*W, 3]
+ src_cams = torch.stack(src_cams) # [N, 34]
+
+ rgbs = self.load_sample(sample_idx, scene_idx)
+ cam = torch.cat([torch.Tensor(self.img_size), self.intrinsics[scene_idx].flatten(),
+ poses[sample_idx].flatten()])
+ tgt_cam = cam[None] # [1, 34]
+        tgt_rays = get_rays(tgt_cam, *self.img_size).view(-1, 6)  # [HW, 6]
+ H, W = self.img_size
+ N = self.num_src_view
+ if self.split == 'train':
+ # subsample
+ Br = self.ray_batchsize
+ subsample_idx = self.subsampler.idx_subsample(self.img_size, Br)
+ sample['idx'] = subsample_idx # [Br, 1]
+ tgt_rgbs = rgbs.gather(0, subsample_idx.expand(-1, 3)) # [Br, 3]
+ tgt_rays = tgt_rays.gather(0, subsample_idx.expand(-1, 6)) # [Br, 3]
+ else:
+ tgt_rgbs = rgbs[None]# [1, HW, 3]
+ tgt_rays = tgt_rays[None] # [1, HW, 3]
+ semantics = Image.open(Path(self.scenes[scene_idx]) / "semantics" / f"{self.all_frame_names[scene_idx][sample_idx]}.png")
+ semantics = torch.tensor(np.array(semantics.resize(self.img_size[::-1], Image.Resampling.NEAREST), np.int32)).long().reshape(-1)
+ instances = Image.open(Path(self.scenes[scene_idx]) / "instance" / f"{self.all_frame_names[scene_idx][sample_idx]}.png")
+ instances = torch.tensor(np.array(instances.resize(self.img_size[::-1], Image.Resampling.NEAREST), np.int32)).long().reshape(-1)
+ sample['semantics'] = semantics # [H*W]
+ sample['instances'] = instances # [H*W]
+ sample['rgbs'] = tgt_rgbs # [Br, 3] or [N1, HW, 3]
+ sample['rays'] = tgt_rays # [Br, 3] or [N1, HW, 3]
+ sample['cam'] = tgt_cam # [Br, 34] or [N1, 34]
+ sample['src_rgbs'] = src_rgbs.reshape(N, H, W, 3) # [N, HW,3]
+ sample['src_cams'] = src_cams # [N, 34] 2+16+16
+ sample['depth_range'] = torch.tensor(self.depth_range).float()
+ return sample
+
+
+
+## test
+from torch.utils.data import DataLoader
+
+@hydra.main(config_path='../config/cfg', config_name='scannet', version_base='1.2')
+def main(config):
+ config.img_size = [480, 640]
+    # config.dino_feats_size = [210, 280]  # not defined in config/cfg/scannet.yaml; assigning it here would raise under hydra's struct mode
+ config.num_workers = 0
+ train_set = ScannetDataset(config, "train")
+ test_set = ScannetDataset(config, "test")
+ train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
+ test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=config.num_workers)
+ for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
+ B = config.batch_size
+ print(batch['rgbs'].shape)
+ break
+ for i, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
+ print(batch['instances'].unique())
+ break
+
+# test
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/dataset/uorf_dataset.py b/dataset/uorf_dataset.py
new file mode 100644
index 0000000..92f787e
--- /dev/null
+++ b/dataset/uorf_dataset.py
@@ -0,0 +1,457 @@
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+
+import torchvision.transforms.functional as TF
+
+from torch.utils.data import Dataset
+from PIL import Image
+import hydra
+import torch
+from glob import glob
+import numpy as np
+import random
+import csv
+from tqdm import tqdm
+from util.misc import SubSampler
+from util.camera import rotate_cam
+from util.ray import get_rays
+
+
+class MultiscenesDataset(Dataset):
+ def __init__(self, cfg, split='train'):
+ super().__init__()
+ self.split = split
+ self.root = f"{cfg.dataset_root}/{cfg.subset}/{cfg.subset}_test" if split != 'train' else f"{cfg.dataset_root}/{cfg.subset}/{cfg.subset}_train"
+ self.subset = cfg.subset
+ self.n_scenes = 5000
+ self.load_feat = False
+ self.feat_map_size = [16, 16]
+ self.ray_batchsize = cfg.ray_batchsize
+ self.img_size = cfg.img_size[0], cfg.img_size[1]
+ self.num_src_view = 1
+ self.depth_range = [6, 20]
+ self.max_depth = 20
+ self.render_src_view = cfg.render_src_view
+ self.load_mask = cfg.load_mask
+ self.norm_scene = cfg.norm_scene
+ self.subsampler = SubSampler()
+ self.setup_data()
+
+ def setup_data(self):
+ self.scenes = []
+ # for root in self.roots:
+ # self.root = root
+ file_path = os.path.join(self.root, 'files.csv')
+ if os.path.exists(file_path):
+ with open(file_path, newline='') as csvfile:
+ reader = csv.reader(csvfile)
+ for row in reader:
+ self.scenes.append(row)
+ else:
+ image_filenames = sorted(glob(os.path.join(self.root, '*.png'))) # root/00000_sc000_az00_el00.png
+ mask_filenames = sorted(glob(os.path.join(self.root, '*_mask.png')))
+ fg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_moving.png')))
+ moved_filenames = sorted(glob(os.path.join(self.root, '*_moved.png')))
+ bg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_bg.png')))
+ bg_in_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_providing_bg.png')))
+ changed_filenames = sorted(glob(os.path.join(self.root, '*_changed.png')))
+ bg_in_filenames = sorted(glob(os.path.join(self.root, '*_providing_bg.png')))
+ changed_filenames_set, bg_in_filenames_set = set(changed_filenames), set(bg_in_filenames)
+ bg_mask_filenames_set, bg_in_mask_filenames_set = set(bg_mask_filenames), set(bg_in_mask_filenames)
+ image_filenames_set, mask_filenames_set = set(image_filenames), set(mask_filenames)
+ fg_mask_filenames_set, moved_filenames_set = set(fg_mask_filenames), set(moved_filenames)
+ filenames_set = image_filenames_set - mask_filenames_set - fg_mask_filenames_set - moved_filenames_set - changed_filenames_set - bg_in_filenames_set - bg_mask_filenames_set - bg_in_mask_filenames_set
+ filenames = sorted(list(filenames_set))
+ scenes = []
+ for i in range(self.n_scenes):
+ scene_filenames = [x for x in filenames if 'sc{:04d}'.format(i) in x]
+ if len(scene_filenames) > 0:
+ scenes.append(scene_filenames)
+ self.scenes += scenes
+ with open(file_path, 'w', newline='') as csvfile:
+ writer = csv.writer(csvfile)
+ writer.writerows(scenes)
+ self.n_scenes = len(self.scenes)
+
+ self.intrinsics = []
+ self.scene2normscene = []
+ self.normscene_scale = []
+ self.cam2normscene = []
+ self.cam2scene = []
+ img_h, img_w = 256, 256
+ for i in tqdm(range(self.n_scenes), desc=f"Loading {self.split} scenes"):
+ cam2scene, dims = [], []
+ for path in self.scenes[i]:
+ pose_path = path.replace('.png', '_RT.txt')
+ pose = np.loadtxt(pose_path)
+ cam2scene.append(torch.tensor(pose, dtype=torch.float32)) # 4x4
+ dims.append([img_h, img_w])
+ cam2scene = torch.stack(cam2scene) # N, 4x4
+ self.cam2scene.append(cam2scene)
+
+ intrinsic = self.get_intrinsic(path.replace('.png', '_intrinsics.txt'))
+ intrinsic_normed = torch.diag(torch.Tensor([self.img_size[1] / img_w,
+ self.img_size[0] / img_h, 1, 1])) @ intrinsic
+ self.intrinsics.append(intrinsic_normed)
+ if self.norm_scene:
+ nss_scale = 7 # we follow the setting of uorf
+ scene2normscene = torch.Tensor([
+ [1/nss_scale, 0, 0, 0],
+ [0, 1/nss_scale, 0, 0],
+ [0, 0, 1/nss_scale, 0],
+ [0, 0, 0, 1],
+ ])
+ self.scene2normscene.append(scene2normscene)
+ self.normscene_scale.append(scene2normscene[0, 0])
+ cam2normscene = []
+ indice = list(range(len(cam2scene)))
+ for idx in indice:
+ cam2normscene.append(scene2normscene @ cam2scene[idx])
+ cam2normscene = torch.stack(cam2normscene)
+ self.cam2normscene.append(cam2normscene)
+ if self.norm_scene:
+ self.depth_range = self.depth_range[0] * min(self.normscene_scale), self.depth_range[1] * max(self.normscene_scale)
+ print(f"{self.n_scenes} scenes for {self.split}")
+ print(f"depth range: {self.depth_range}")
+
+ def _transform(self, img):
+ img = TF.resize(img, self.img_size)
+ img = TF.to_tensor(img)
+ return img
+
+ def get_intrinsic(self, path):
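+ # The intrinsic is built for a 256x256 frustum; `setup_data` later rescales it to
+ # `img_size` via the diagonal matrix used to construct `intrinsic_normed`.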
+ frustum_size = (256, 256)
+ if not os.path.isfile(path):
+ focal_ratio = (350. / 320., 350. / 240.)
+ focal_x = focal_ratio[0] * frustum_size[0]
+ focal_y = focal_ratio[1] * frustum_size[1]
+ bias_x = (frustum_size[0] - 1.) / 2.
+ bias_y = (frustum_size[1] - 1.) / 2.
+ else:
+ intrinsics = np.loadtxt(path)
+ focal_x = intrinsics[0, 0] * frustum_size[0]
+ focal_y = intrinsics[1, 1] * frustum_size[1]
+ bias_x = ((intrinsics[0, 2] + 1) * frustum_size[0] - 1.) / 2.
+ bias_y = ((intrinsics[1, 2] + 1) * frustum_size[1] - 1.) / 2.
+ intrinsic = torch.tensor([[focal_x, 0, bias_x, 0],
+ [0, focal_y, bias_y, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]])
+ return intrinsic.float()
+
+ def _transform_mask(self, img):
+ img = TF.resize(img, self.img_size, Image.NEAREST)
+ img = TF.to_tensor(img)
+ return img
+
+ def load_sample(self, scene_idx, sample_idx):
+ path = self.scenes[scene_idx][sample_idx]
+ img = Image.open(path).convert('RGB')
+ img = self._transform(img).permute([1, 2, 0]).view(-1, 3) # [HW, 3]
+
+ cam2scene = self.cam2normscene[scene_idx][sample_idx] if self.norm_scene else self.cam2scene[scene_idx][sample_idx]
+ cam = torch.cat([torch.Tensor(self.img_size), self.intrinsics[scene_idx].flatten(),
+ cam2scene.flatten()]) # 2+16+16=34
+
+
+ return img, cam
+
+ def __getitem__(self, index):
+ Br = self.ray_batchsize
+ sample = {}
+ if self.split == 'train':
+ scene_idx = index // 4
+ sample_idx = index % 4
+ filenames = self.scenes[scene_idx]
+ all_rgbs, all_cams = [], []
+ for i, path in enumerate(filenames):
+ img, cam = self.load_sample(scene_idx, i)
+ all_rgbs.append(img)
+ all_cams.append(cam)
+ all_rgbs = torch.stack(all_rgbs) # [N, H*W, 3]
+ all_cams = torch.stack(all_cams) # [N, 34]
+
+ tgt_cam = all_cams[0:1] # [1, 34]
+ all_rays = get_rays(all_cams, *self.img_size) # [N, HW, 6]
+ if np.random.rand() > 0.25:
+ # with probability 0.75, exclude the source view from the target rays
+ tgt_rgbs = all_rgbs[self.num_src_view:].view(-1, 3) # [(N-Nv)*HW, 3]
+ tgt_rays = all_rays[self.num_src_view:].view(-1, 6) # [(N-Nv)*HW, 6]
+ else:
+ tgt_rgbs = all_rgbs.view(-1, 3) # [N*HW, 3]
+ tgt_rays = all_rays.view(-1, 6) # [N*HW, 6]
+ subsample_idx = torch.randperm(tgt_rgbs.shape[0])[:Br][:, None] # [Br, 1]
+ tgt_rgbs = tgt_rgbs.gather(0, subsample_idx.expand(-1, 3)) # [Br, 3]
+ tgt_rays = tgt_rays.gather(0, subsample_idx.expand(-1, 6)) # [Br, 6]
+
+ src_rgbs = all_rgbs[0:self.num_src_view] # [N, HW, 3]
+ src_cams = all_cams[0:self.num_src_view] # [N, 34]
+ else:
+ filenames = self.scenes[index]
+ all_rgbs, all_cams = [], []
+ for i, path in enumerate(filenames):
+ img, cam = self.load_sample(index, i)
+ all_rgbs.append(img)
+ all_cams.append(cam)
+ all_rgbs = torch.stack(all_rgbs) # [N, H*W, 3]
+ all_cams = torch.stack(all_cams) # [N, 34]
+ if 'kitchen' in self.subset: # no ground truth mask
+ masks = torch.randint(4, [all_rgbs.shape[0], self.img_size[0] * self.img_size[1]])
+ else:
+ masks = []
+ for path in filenames:
+ mask_path = path.replace('.png', '_mask.png')
+ if os.path.isfile(mask_path):
+ mask = Image.open(mask_path).convert('RGB')
+ mask_l = mask.convert('L')
+ mask = self._transform_mask(mask)
+ # ret['mask'] = mask
+ mask_l = self._transform_mask(mask_l)
+ mask_flat = mask_l.flatten(start_dim=0) # HW,
+ greyscale_dict = mask_flat.unique(sorted=True) # 8,
+ # make sure the background is always 0
+ if self.subset not in ['room_texture', 'kitchen_shiny', 'kitchen_matte']:
+ bg_color = greyscale_dict[1].clone()
+ greyscale_dict[1] = greyscale_dict[0]
+ greyscale_dict[0] = bg_color
+ onehot_labels = mask_flat[:, None] == greyscale_dict # HWx8, one-hot
+ onehot_labels = onehot_labels.type(torch.uint8)
+ mask_idx = onehot_labels.argmax(dim=1).view(-1) # HW
+ masks.append(mask_idx)
+ masks = torch.stack(masks) # [N, HW]
+ sample['instances'] = masks
+ tgt_rgbs, tgt_cam = all_rgbs, all_cams
+ tgt_rays = get_rays(tgt_cam, *self.img_size) # [N, HW, 6]
+ Nv = self.num_src_view
+ src_rgbs, src_cams = all_rgbs[:Nv], all_cams[:Nv]
+ sample['rgbs'] = tgt_rgbs # [Br, 3] (train) or [N, HW, 3] (test)
+ sample['rays'] = tgt_rays # [Br, 6] (train) or [N, HW, 6] (test)
+ sample['cam'] = tgt_cam # [1, 34] (train) or [N, 34] (test)
+ sample['src_rgbs'] = src_rgbs.reshape(-1, *self.img_size, 3) # [Nv, H, W, 3], Nv = 1
+ sample['src_cams'] = src_cams # [N, 34] 2+16+16
+ sample['depth_range'] = torch.tensor(self.depth_range).float()
+
+ return sample
+
+ def __len__(self):
+ return self.n_scenes if self.split != 'train' else self.n_scenes * 4
+
+
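+# A compact sketch (not called by the dataset) of the grayscale-mask -> instance-id
+# conversion used in MultiscenesDataset.__getitem__ above, without the background
+# reordering step; `mask_l` is assumed to be the [1, H, W] tensor from `_transform_mask`:
+def _mask_to_instance_ids(mask_l):
+ mask_flat = mask_l.flatten() # [H*W] grayscale values
+ levels = mask_flat.unique(sorted=True) # [K] distinct gray levels
+ onehot = (mask_flat[:, None] == levels).to(torch.uint8) # [H*W, K] one-hot
+ return onehot.argmax(dim=1) # [H*W] instance id per pixel
+
+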
+class VisualDataset(Dataset):
+ def __init__(self, cfg, split='test'):
+ super().__init__()
+ self.split = split
+ self.subset = cfg.subset
+ self.root = f"{cfg.dataset_root}/{cfg.subset}/{cfg.subset}_{split}"
+ self.n_scenes = cfg.n_scenes
+ self.feat_map_size = [16, 16]
+ self.ray_batchsize = cfg.ray_batchsize
+ self.img_size = cfg.img_size[0], cfg.img_size[1]
+ self.num_src_view = 1
+ self.depth_range = [6, 20]
+ self.max_depth = 20
+ self.normalize = cfg.normalize
+ self.load_mask = cfg.load_mask
+ self.norm_scene = cfg.norm_scene
+ self.num_vis = cfg.num_vis
+ self.setup_data()
+
+ def setup_data(self):
+ self.scenes = []
+ file_path = os.path.join(self.root, 'files.csv')
+ if os.path.exists(file_path):
+ with open(file_path, newline='') as csvfile:
+ reader = csv.reader(csvfile)
+ for row in reader:
+ self.scenes.append(row)
+ else:
+ image_filenames = sorted(glob(os.path.join(self.root, '*.png'))) # root/00000_sc000_az00_el00.png
+ mask_filenames = sorted(glob(os.path.join(self.root, '*_mask.png')))
+ fg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_moving.png')))
+ moved_filenames = sorted(glob(os.path.join(self.root, '*_moved.png')))
+ bg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_bg.png')))
+ bg_in_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_providing_bg.png')))
+ changed_filenames = sorted(glob(os.path.join(self.root, '*_changed.png')))
+ bg_in_filenames = sorted(glob(os.path.join(self.root, '*_providing_bg.png')))
+ changed_filenames_set, bg_in_filenames_set = set(changed_filenames), set(bg_in_filenames)
+ bg_mask_filenames_set, bg_in_mask_filenames_set = set(bg_mask_filenames), set(bg_in_mask_filenames)
+ image_filenames_set, mask_filenames_set = set(image_filenames), set(mask_filenames)
+ fg_mask_filenames_set, moved_filenames_set = set(fg_mask_filenames), set(moved_filenames)
+ filenames_set = image_filenames_set - mask_filenames_set - fg_mask_filenames_set - moved_filenames_set - changed_filenames_set - bg_in_filenames_set - bg_mask_filenames_set - bg_in_mask_filenames_set
+ filenames = sorted(list(filenames_set))
+ scenes = []
+ for i in range(self.n_scenes):
+ scene_filenames = [x for x in filenames if 'sc{:04d}'.format(i) in x]
+ if len(scene_filenames) > 0:
+ scenes.append(scene_filenames)
+ self.scenes = scenes
+ with open(file_path, 'w', newline='') as csvfile:
+ writer = csv.writer(csvfile)
+ writer.writerows(scenes)
+ self.n_scenes = len(self.scenes)
+
+ self.intrinsics = []
+ self.normscene_scale = []
+ self.cam2normscene = []
+ self.cam2scene = []
+ img_h, img_w = 256, 256
+ self.n_scenes = len(self.scenes)
+ for i in tqdm(range(self.n_scenes), desc=f"Loading {self.split} scenes"):
+ cam2scene, dims = [], []
+ for path in self.scenes[i]:
+ pose_path = path.replace('.png', '_RT.txt')
+ pose = np.loadtxt(pose_path)
+ cam2scene.append(torch.tensor(pose, dtype=torch.float32)) # 4x4
+ dims.append([img_h, img_w])
+ cam2scene = torch.stack(cam2scene) # N, 4x4
+ self.cam2scene.append(cam2scene)
+
+ intrinsic = self.get_intrinsic(path.replace('.png', '_intrinsics.txt'))
+ intrinsic_normed = torch.diag(torch.Tensor([self.img_size[1] / img_w,
+ self.img_size[0] / img_h, 1, 1])) @ intrinsic
+ self.intrinsics.append(intrinsic_normed)
+ if self.norm_scene:
+ nss_scale = 7 # we follow the setting of uorf
+ scene2normscene = torch.Tensor([
+ [1/nss_scale, 0, 0, 0],
+ [0, 1/nss_scale, 0, 0],
+ [0, 0, 1/nss_scale, 0],
+ [0, 0, 0, 1],
+ ])
+ self.normscene_scale.append(scene2normscene[0, 0])
+ cam2normscene = []
+ indice = list(range(len(cam2scene)))
+ for idx in indice:
+ cam2normscene.append(scene2normscene @ cam2scene[idx])
+ cam2normscene = torch.stack(cam2normscene)
+ self.cam2normscene.append(cam2normscene)
+ if self.norm_scene:
+ self.depth_range = self.depth_range[0] * min(self.normscene_scale), self.depth_range[1] * max(self.normscene_scale)
+ print(f"{self.n_scenes} scenes for visualization")
+ print(f"depth range: {self.depth_range}")
+
+ def _transform(self, img):
+ img = TF.resize(img, self.img_size)
+ img = TF.to_tensor(img)
+ if self.normalize:
+ img = TF.normalize(img, [0.5] * img.shape[0], [0.5] * img.shape[0]) # [0,1] -> [-1,1]
+ return img
+
+ def get_intrinsic(self, path):
+ frustum_size = (256, 256)
+ if not os.path.isfile(path):
+ focal_ratio = (350. / 320., 350. / 240.)
+ focal_x = focal_ratio[0] * frustum_size[0]
+ focal_y = focal_ratio[1] * frustum_size[1]
+ bias_x = (frustum_size[0] - 1.) / 2.
+ bias_y = (frustum_size[1] - 1.) / 2.
+ else:
+ intrinsics = np.loadtxt(path)
+ focal_x = intrinsics[0, 0] * frustum_size[0]
+ focal_y = intrinsics[1, 1] * frustum_size[1]
+ bias_x = ((intrinsics[0, 2] + 1) * frustum_size[0] - 1.) / 2.
+ bias_y = ((intrinsics[1, 2] + 1) * frustum_size[1] - 1.) / 2.
+ intrinsic = torch.tensor([[focal_x, 0, bias_x, 0],
+ [0, focal_y, bias_y, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]])
+ return intrinsic.float()
+
+ def _transform_mask(self, img):
+ img = TF.resize(img, self.img_size, Image.NEAREST)
+ img = TF.to_tensor(img)
+ return img
+
+ def __getitem__(self, index):
+ sample = {}
+ path = self.scenes[index][0]
+ img = Image.open(path).convert('RGB')
+ src_rgbs = self._transform(img).permute([1, 2, 0])[None] # [1, H, W, 3]
+ cam2scene = self.cam2normscene[index][0] if self.norm_scene else self.cam2scene[index][0]
+ src_cams = torch.cat([torch.Tensor(self.img_size), self.intrinsics[index].flatten(),
+ cam2scene.flatten()])[None]
+ theta_list = np.linspace(0, np.pi * 0.25, self.num_vis, endpoint=False)
+ zoom_list = np.linspace(1, 0.8, self.num_vis // 2, endpoint=False).tolist()
+ zoom_list += np.linspace(0.8, 1.05, self.num_vis - self.num_vis // 2, endpoint=False).tolist() # handles odd num_vis
+ # theta_list = np.linspace(0, np.pi * 2, endpoint=False)
+ tgt_cam2scene = [rotate_cam(cam2scene, theta, zoom) for theta, zoom in zip(theta_list, zoom_list)]
+ tgt_cams = [torch.cat([torch.Tensor(self.img_size), self.intrinsics[index].flatten(),
+ tgt_cam2scene[i].flatten()])[None] for i in range(self.num_vis)]
+ tgt_cams = torch.cat(tgt_cams, dim=0) # [N, 34]
+ tgt_rays = get_rays(tgt_cams, *self.img_size) # [N, HW, 6]
+
+ if 'kitchen' in self.subset:
+ mask_idx = torch.randint(4, [self.img_size[0] * self.img_size[1]])
+ else:
+ mask_path = path.replace('.png', '_mask.png')
+ mask = Image.open(mask_path).convert('RGB')
+ mask_l = mask.convert('L')
+ mask = self._transform_mask(mask)
+ mask_l = self._transform_mask(mask_l)
+ mask_flat = mask_l.flatten(start_dim=0) # HW,
+ greyscale_dict = mask_flat.unique(sorted=True) # 8,
+ # make sure the background is always 0
+ if self.subset not in ['room_texture', 'kitchen_shiny', 'kitchen_matte']:
+ bg_color = greyscale_dict[1].clone()
+ greyscale_dict[1] = greyscale_dict[0]
+ greyscale_dict[0] = bg_color
+ onehot_labels = mask_flat[:, None] == greyscale_dict # HWx8, one-hot
+ onehot_labels = onehot_labels.type(torch.uint8)
+ mask_idx = onehot_labels.argmax(dim=1).view(-1) # HW
+ sample['instances'] = mask_idx
+ sample['rays'] = tgt_rays # [N, HW, 6]
+ sample['cam'] = tgt_cams # [N, 34]
+ sample['src_rgbs'] = src_rgbs # [N, H, W,3]
+ sample['src_cams'] = src_cams # [N, 34]
+ sample['depth_range'] = torch.tensor(self.depth_range).float()
+
+ return sample
+
+ def __len__(self):
+ return self.n_scenes
+
+
+# test
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+@hydra.main(config_path='../config/cfg', config_name='uorf', version_base='1.2')
+def main(config):
+ config.img_size = [128, 128]
+ config.num_workers = 4
+ config.subset = 'kitchen_matte'
+ config.dataset_root = "/home/yuliu/Dataset/uorf"
+ config.num_src_view = 1
+ train_set = MultiscenesDataset(config, 'train')
+ val_set = MultiscenesDataset(config, 'val')
+ train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0)
+ val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0)
+ # vis_set = MultiscenesVisualDataset(config)
+ # vis_loader = DataLoader(vis_set, batch_size=1, shuffle=False, num_workers=0)
+ for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
+ rgbs = batch['rgbs']
+ print(rgbs.shape)
+ cam = batch['cam']
+ src_rgbs = batch['src_rgbs']
+ src_cams = batch['src_cams']
+ if i > 7:
+ break
+ for i, batch in tqdm(enumerate(val_loader), total=len(val_loader)):
+ rgbs = batch['rgbs']
+ mask = batch['instances']
+ if i > 5:
+ break
+ # for i, batch in tqdm(enumerate(vis_loader), total=len(vis_loader)):
+ # cam = batch['cam']
+ # print(cam.shape)
+ # src_rgbs = batch['src_rgbs']
+ # src_cams = batch['src_cams']
+ # if i > 5:
+ # break
+
+
+if __name__ == '__main__':
+ main()
+
+
\ No newline at end of file
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/model/encoder.py b/model/encoder.py
new file mode 100644
index 0000000..c9ee068
--- /dev/null
+++ b/model/encoder.py
@@ -0,0 +1,333 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from util.ray import get_rays_origin_and_direction
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, num_octaves=8, start_octave=0):
+ super().__init__()
+ self.num_octaves = num_octaves
+ self.start_octave = start_octave
+
+ def forward(self, coords):
+ batch_size, num_points, dim = coords.shape
+
+ octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves)
+ octaves = octaves.float().to(coords)
+ multipliers = 2**octaves * math.pi
+ coords = coords.unsqueeze(-1)
+ while len(multipliers.shape) < len(coords.shape):
+ multipliers = multipliers.unsqueeze(0)
+
+ scaled_coords = coords * multipliers
+
+ sines = torch.sin(scaled_coords).reshape(batch_size, num_points, dim * self.num_octaves)
+ cosines = torch.cos(scaled_coords).reshape(batch_size, num_points, dim * self.num_octaves)
+
+ result = torch.cat((sines, cosines), -1)
+ return result
+
+
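+# For coords of shape [B, N, D] and `num_octaves` octaves, PositionalEncoding returns
+# [B, N, 2 * D * num_octaves] (sines followed by cosines). A small sanity-check sketch,
+# not called anywhere in the model:
+def _check_positional_encoding():
+ enc = PositionalEncoding(num_octaves=8)
+ coords = torch.rand(2, 5, 3)
+ assert enc(coords).shape == (2, 5, 2 * 3 * 8)
+
+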
+class RayEncoder(nn.Module):
+ def __init__(self, pos_octaves=8, pos_start_octave=0, ray_octaves=4, ray_start_octave=0):
+ super().__init__()
+ self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves, start_octave=pos_start_octave)
+ self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves, start_octave=ray_start_octave)
+
+ def forward(self, pos, rays):
+ if len(rays.shape) == 4:
+ batchsize, height, width, dims = rays.shape
+ pos_enc = self.pos_encoding(pos.unsqueeze(1))
+ pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1)
+ pos_enc = pos_enc.repeat(1, 1, height, width)
+ rays = rays.flatten(1, 2)
+
+ ray_enc = self.ray_encoding(rays)
+ ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1])
+ ray_enc = ray_enc.permute((0, 3, 1, 2))
+ x = torch.cat((pos_enc, ray_enc), 1)
+ else:
+ pos_enc = self.pos_encoding(pos).expand(-1, rays.shape[1], -1)
+ ray_enc = self.ray_encoding(rays)
+ x = torch.cat((pos_enc, ray_enc), -1)
+
+ return x
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation, padding_mode='reflect')
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, padding_mode='reflect')
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes, track_running_stats=False, affine=True)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes, track_running_stats=False, affine=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width, track_running_stats=False, affine=True)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width, track_running_stats=False, affine=True)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion, track_running_stats=False, affine=True)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class conv(nn.Module):
+ def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
+ super(conv, self).__init__()
+ self.kernel_size = kernel_size
+ self.conv = nn.Conv2d(num_in_layers,
+ num_out_layers,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(self.kernel_size - 1) // 2,
+ padding_mode='reflect')
+ self.bn = nn.InstanceNorm2d(num_out_layers, track_running_stats=False, affine=True)
+
+ def forward(self, x):
+ return F.elu(self.bn(self.conv(x)), inplace=True)
+
+
+class upconv(nn.Module):
+ def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
+ super(upconv, self).__init__()
+ self.scale = scale
+ self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)
+
+ def forward(self, x):
+ x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear')
+ return self.conv(x)
+
+
+class ResUNet(nn.Module):
+ def __init__(self,
+ in_ch=3,
+ encoder='resnet34',
+ coarse_out_ch=32,
+ fine_out_ch=32,
+ norm_layer=None,
+ coarse_only=False
+ ):
+
+ super(ResUNet, self).__init__()
+ assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "Incorrect encoder type"
+ if encoder in ['resnet18', 'resnet34']:
+ filters = [64, 128, 256, 512]
+ else:
+ filters = [256, 512, 1024, 2048]
+ self.coarse_only = coarse_only
+ if self.coarse_only:
+ fine_out_ch = 0
+ self.coarse_out_ch = coarse_out_ch
+ self.fine_out_ch = fine_out_ch
+ out_ch = coarse_out_ch + fine_out_ch
+
+ # original
+ layers = [3, 4, 6, 3]
+ if norm_layer is None:
+ # norm_layer = nn.BatchNorm2d
+ norm_layer = nn.InstanceNorm2d
+ self._norm_layer = norm_layer
+ self.dilation = 1
+ block = BasicBlock
+ replace_stride_with_dilation = [False, False, False]
+ self.inplanes = 64
+ self.groups = 1
+ self.base_width = 64
+ self.conv1 = nn.Conv2d(in_ch, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False, padding_mode='reflect')
+ self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True)
+ self.relu = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+
+ # decoder
+ self.upconv3 = upconv(filters[2], 128, 3, 2)
+ self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
+ self.upconv2 = upconv(128, 64, 3, 2)
+ self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)
+
+ # fine-level conv
+ self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion, track_running_stats=False, affine=True),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def skipconnect(self, x1, x2):
+ diffY = x2.size()[2] - x1.size()[2]
+ diffX = x2.size()[3] - x1.size()[3]
+
+ x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
+ diffY // 2, diffY - diffY // 2))
+
+ # for padding issues, see
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+
+ x = torch.cat([x2, x1], dim=1)
+ return x
+
+ def forward(self, x):
+ x = self.relu(self.bn1(self.conv1(x)))
+
+ x1 = self.layer1(x)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+
+ x = self.upconv3(x3)
+ x = self.skipconnect(x2, x)
+ x = self.iconv3(x)
+
+ x = self.upconv2(x)
+ x = self.skipconnect(x1, x)
+ x = self.iconv2(x)
+
+ x_out = self.out_conv(x)
+
+ if self.coarse_only:
+ x_coarse = x_out
+ x_fine = None
+ else:
+ x_coarse = x_out[:, :self.coarse_out_ch, :]
+ x_fine = x_out[:, -self.fine_out_ch:, :]
+ return x_coarse, x_fine
+
+
+class MultiViewResUNet(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.resunet = ResUNet(in_ch=3 + 3 + 3, coarse_out_ch=cfg.feature_size, coarse_only=True)
+ self.ray_encoder = RayEncoder(pos_octaves=10, pos_start_octave=0,
+ ray_octaves=10)
+
+ def forward(self, src_cams, images):
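+ # images: [B*N, 3, H, W]; every pixel is augmented with its ray origin and
+ # direction (3 + 3 extra channels), matching in_ch = 3 + 3 + 3 in the ResUNet above.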
+ BN, C, H, W = images.shape
+ rays_o, rays_d = get_rays_origin_and_direction(src_cams.view(BN, -1), H, W) # [B*N, 1, 3], [B*N, HW, 3]
+ ray_enc = torch.cat([rays_o.repeat(1, H*W, 1), rays_d], dim=-1).view(BN, H, W, 6).permute(0, 3, 1, 2)
+ x = torch.cat([images, ray_enc], dim=1)
+ x = self.resunet(x)[0]
+ return x
+
\ No newline at end of file
diff --git a/model/nerf.py b/model/nerf.py
new file mode 100644
index 0000000..a084a5c
--- /dev/null
+++ b/model/nerf.py
@@ -0,0 +1,248 @@
+# MIT License
+#
+# Copyright (c) 2022 Anpei Chen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import torch
+from torch import nn
+import numpy as np
+import torch.nn.functional as F
+
+from model.slot_attn import linear
+from model.projection import Projector
+
+
+@torch.jit.script
+def fused_mean_variance(x, weight):
+ mean = torch.sum(x*weight, dim=2, keepdim=True)
+ var = torch.sum(weight * (x - mean)**2, dim=2, keepdim=True)
+ return mean, var
+
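+# `fused_mean_variance` reduces over the source-view dimension (dim=2): for features of
+# shape [B, Np, Nv, C] and weights of shape [B, Np, Nv, 1] it returns a [B, Np, 1, C]
+# weighted mean and variance. A small illustrative check, not used by the model:
+def _check_fused_mean_variance():
+ x = torch.rand(1, 4, 3, 8) # [B, Np, Nv, C]
+ w = torch.full((1, 4, 3, 1), 1.0 / 3) # uniform weights over Nv = 3 views
+ mean, var = fused_mean_variance(x, w)
+ assert mean.shape == (1, 4, 1, 8) and var.shape == (1, 4, 1, 8)
+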
+
+class NeRF(nn.Module):
+ def __init__(self, cfg=None, n_samples=64):
+ super().__init__()
+ slot_dec_dim = cfg.slot_dec_dim
+ self.nerf_mlp_dim = cfg.nerf_mlp_dim
+ self.color_mlp = RenderMLP(slot_dec_dim, 3, cfg.pe_view, cfg.pe_feat, self.nerf_mlp_dim, cfg.normalize)
+ if not cfg.slot_density:
+ self.density_proj = linear(slot_dec_dim, 1)
+
+ self.projector = Projector()
+ self.grid_init = cfg.grid_init
+ self.random_proj_ratio = cfg.random_proj_ratio
+ self.slot_dec_dim = slot_dec_dim
+ if self.random_proj_ratio > 0:
+ if cfg.num_src_view > 1:
+ self.base_fc = linear(2 * (cfg.feature_size + 3), slot_dec_dim)
+ else:
+ self.base_fc = linear(cfg.feature_size + 3, slot_dec_dim)
+
+ self.pos_encoding = self.posenc(slot_dec_dim, n_samples)
+
+ def posenc(self, d_hid, n_samples):
+
+ def get_position_angle_vec(position):
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_samples)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+ sinusoid_table = torch.from_numpy(sinusoid_table).float().unsqueeze(0)
+ return sinusoid_table
+
+ def init_one_svd(self, n_components, grid_resolution, scale):
+ plane_coef, line_coef = [], []
+ for i in range(len(self.vector_mode)):
+ vec_id = self.vector_mode[i]
+ mat_id_0, mat_id_1 = self.matrix_mode[i]
+ plane_coef.append(torch.nn.Parameter(scale * torch.randn((1, n_components[i], grid_resolution[mat_id_1], grid_resolution[mat_id_0])), requires_grad=True))
+ line_coef.append(torch.nn.Parameter(scale * torch.randn((1, n_components[i], grid_resolution[vec_id], 1)), requires_grad=True))
+ return torch.nn.ParameterList(plane_coef), torch.nn.ParameterList(line_coef)
+
+ def get_coordinate_plane_line(self, xyz_sampled):
+ coordinate_plane = torch.stack((xyz_sampled[..., self.matrix_mode[0]], xyz_sampled[..., self.matrix_mode[1]], xyz_sampled[..., self.matrix_mode[2]])).detach().view(3, -1, 1, 2)
+ coordinate_line = torch.stack((xyz_sampled[..., self.vector_mode[0]], xyz_sampled[..., self.vector_mode[1]], xyz_sampled[..., self.vector_mode[2]]))
+ coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
+ return coordinate_plane, coordinate_line
+
+ def compute_feature(self, xyz_sampled):
+ if self.grid_init == 'tensorf':
+ coordinate_plane, coordinate_line = self.get_coordinate_plane_line(xyz_sampled)
+ plane_coef_point, line_coef_point = [], []
+ for idx_plane in range(len(self.plane_grid)):
+ plane_coef_point.append(F.grid_sample(self.plane_grid[idx_plane], coordinate_plane[[idx_plane]], align_corners=True).view(-1, *xyz_sampled.shape[:1]))
+ line_coef_point.append(F.grid_sample(self.line_grid[idx_plane], coordinate_line[[idx_plane]], align_corners=True).view(-1, *xyz_sampled.shape[:1]))
+ plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)
+ return self.basis_mat((plane_coef_point * line_coef_point).T)
+ elif self.grid_init == '3d':
+ points_feature = F.grid_sample(self.grid, xyz_sampled.reshape(1, -1, 1, 1, 3), align_corners=True).squeeze().permute(1, 0)
+ return self.basis_mat(points_feature)
+
+ def project_feature(self, xyz_sampled, cam, src_imgs, src_cams, src_feats, is_Train=False, proj_mask=None):
+ r"""
+ Input:
+ xyz_sampled: (B, Nr, Np, 3)
+ cam: (B, 34)
+ src_imgs: (B, Nv, 3, H, W)
+ src_cams: (B, Nv, 34)
+ src_feats: (B, Nv, D, H1, W1)
+ Output:
+ points_feature: (B, Nr*Np, D); proj_mask: (1, Nr, 1, 1) boolean mask (or None)
+ """
+ B, Nr, Np, _ = xyz_sampled.shape # number of points
+ xyz_sampled = xyz_sampled.view(B, -1, 3) # (B, Nr*Np, 3)
+ src_imgs = src_imgs.type_as(src_feats)
+ points_feature = xyz_sampled.new_zeros(B, Nr*Np, self.slot_dec_dim)
+ if self.random_proj_ratio > 0:
+ if self.random_proj_ratio < 1 and is_Train:
+ if proj_mask is None:
+ noise = torch.rand([1, Nr, 1, 1], device=xyz_sampled.device)
+ proj_mask = noise <= self.random_proj_ratio
+ m = proj_mask.expand(1, Nr, Np, 1).reshape(1, -1, 1)
+ else:
+ m = proj_mask.expand(1, Nr, Np, 1).reshape(1, -1, 1)
+ if torch.sum(m) > 0:
+ rgb_feat, mask, pixel_locations = self.projector.compute(xyz_sampled[m.expand(B, -1, 3)].view(B, -1, 3), cam, src_imgs, src_cams, src_feats)
+ Nv = src_imgs.shape[1]
+ if Nv == 1:
+ x = self.base_fc(rgb_feat)
+ points_feature[m.expand(B, -1, self.slot_dec_dim)] = torch.sum(x * mask, dim=2).view(-1)
+ else:
+ weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-6)
+ mean, var = fused_mean_variance(rgb_feat, weight)
+ x = torch.cat([mean, var], dim=-1).squeeze(2)
+ points_feature[m.expand(B, -1, self.slot_dec_dim)] = self.base_fc(x).view(-1)
+ else:
+ mask = torch.zeros(B, Nr*Np, 1, 1, device=xyz_sampled.device)
+ else:
+ rgb_feat, mask, pixel_locations = self.projector.compute(xyz_sampled.view(B, -1, 3), cam, src_imgs, src_cams, src_feats)
+ Nv = src_imgs.shape[1]
+ if Nv == 1:
+ points_feature = torch.sum(self.base_fc(rgb_feat) * mask, dim=2)
+ else:
+ weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-6)
+ mean, var = fused_mean_variance(rgb_feat, weight)
+ x = torch.cat([mean, var], dim=-1).squeeze(2)
+ points_feature = self.base_fc(x)
+ return points_feature, proj_mask
+
+ def compute_density(self, points_feature):
+ sigma = F.relu(self.density_proj(points_feature).squeeze(-1)) # (B, N)
+ return sigma
+
+
+class RenderMLP(nn.Module):
+
+ def __init__(self, in_channels, out_channels=3, pe_view=2, pe_feat=2, nerf_mlp_dim=128, normalize=True):
+ super().__init__()
+ self.pe_view = pe_view
+ self.pe_feat = pe_feat
+ self.output_channels = out_channels
+ self.view_independent = self.pe_view == 0 and self.pe_feat == 0
+ self.in_feat_mlp = in_channels
+
+ self.mlp = nn.Sequential(
+ linear(self.in_feat_mlp, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, out_channels)
+ )
+ self.normalize = normalize
+
+ def forward(self, rays_d, features):
+ out = self.mlp(features)
+
+ if self.normalize:
+ out = out.tanh()
+ else:
+ out = out.tanh() / 2 + 0.5
+ return out
+
+
+class SemanticMLP(nn.Module):
+ def __init__(self, in_channels, out_channels=3, pe_feat=2, nerf_mlp_dim=128):
+ super().__init__()
+ self.pe_feat = pe_feat
+ self.output_channels = out_channels
+ self.in_feat_mlp = 2 * pe_feat * in_channels + in_channels
+
+ self.mlp = nn.Sequential(
+ linear(self.in_feat_mlp, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, out_channels)
+ )
+
+ def forward(self, features):
+ indata = [features]
+ if self.pe_feat > 0:
+ indata += [positional_encoding(features, self.pe_feat)]
+ mlp_in = torch.cat(indata, dim=-1)
+ out = self.mlp(mlp_in)
+ return out
+
+
+class InstanceMLP(nn.Module):
+ def __init__(self, in_channels, out_channels=3, pe_feat=2, nerf_mlp_dim=128):
+ super().__init__()
+ self.pe_feat = pe_feat
+ self.output_channels = out_channels
+ self.in_feat_mlp = 2 * pe_feat * in_channels + in_channels
+
+ self.mlp = nn.Sequential(
+ linear(self.in_feat_mlp, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'),
+ nn.LeakyReLU(),
+ linear(nerf_mlp_dim, out_channels)
+ )
+
+ def forward(self, features):
+ indata = [features]
+ if self.pe_feat > 0:
+ indata += [positional_encoding(features, self.pe_feat)]
+ mlp_in = torch.cat(indata, dim=-1)
+ out = self.mlp(mlp_in)
+ return out
+
+
+def positional_encoding(positions, freqs):
+ freq_bands = (2 ** torch.arange(freqs)).to(positions.device)
+ pts = (positions[..., None] * freq_bands).reshape(positions.shape[:-1] + (freqs * positions.shape[-1],))
+ pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
+ return pts
+
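+# Note: for an input of shape [..., D] and `freqs` frequency bands, `positional_encoding`
+# returns [..., 2 * D * freqs]; e.g. [B, N, 3] points with freqs=10 become [B, N, 60].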
diff --git a/model/projection.py b/model/projection.py
new file mode 100644
index 0000000..9d77448
--- /dev/null
+++ b/model/projection.py
@@ -0,0 +1,103 @@
+# This file is modified from ContraNeRF
+import torch
+import torch.nn.functional as F
+
+
+class Projector():
+ def __init__(self):
+ pass
+
+ def inbound(self, pixel_locations, h, w):
+ return (pixel_locations[..., 0] <= w - 1.) & \
+ (pixel_locations[..., 0] >= 0) & \
+ (pixel_locations[..., 1] <= h - 1.) &\
+ (pixel_locations[..., 1] >= 0)
+
+ def normalize(self, pixel_locations, h, w):
+ resize_factor = torch.tensor([w-1., h-1.]).type_as(pixel_locations).to(pixel_locations.device)[None, None, :]
+ normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1. # [B, n_views, n_points, 2]
+ return normalized_pixel_locations
+
+ def compute_projections(self, xyz, src_cams):
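+ # World points are projected into each source view as
+ # p = K_src @ (inverse(cam2world_src) @ [x, y, z, 1]^T)[:3], pixel = p[:2] / p[2];
+ # points with non-positive depth (behind the camera) are masked out below.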
+ B, Nv = src_cams.shape[:2]
+ src_intrinsics = src_cams[..., 2:18].reshape(B*Nv, 4, 4)[:, :3, :3] # [B*n_views, 3, 3]
+ src_poses = src_cams[..., -16:].reshape(B*Nv, 4, 4)
+ xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1) # [B, n_points, 4]
+ projections = src_intrinsics.bmm(
+ torch.inverse(src_poses).bmm(xyz_h.transpose(-1, -2).repeat(Nv, 1, 1))[:, :3]
+ ) # [B*n_views, 3, n_points]
+ projections = projections.transpose(-2, -1).reshape(B, Nv, -1, 3) # [B, n_views, n_points, 3]
+ pixel_locations = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8) # [B, n_views, n_points, 2]
+ pixel_locations = torch.clamp(pixel_locations, min=-1e5, max=1e5)
+ mask = projections[..., 2] > 0 # a point is invalid if behind the camera
+ return pixel_locations, mask
+
+ def compute_angle(self, xyz, cam, src_cams):
+ B, Np, _ = xyz.shape
+ Nv = src_cams.shape[1]
+ src_poses = src_cams[..., -16:].reshape(B, Nv, 4, 4)
+ query_pose = cam[:, -16:].reshape(B, 1, 4, 4).expand(-1, Nv, -1, -1) # [B, n_views, 4, 4]
+ ray2tar_pose = (query_pose[:, :, :3, 3].unsqueeze(2) - xyz.unsqueeze(1)) # [B, n_views, n_samples, 3]
+ ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
+ ray2src_pose = (src_poses[:, :, :3, 3].unsqueeze(2) - xyz.unsqueeze(1))
+ ray2src_pose = ray2src_pose / (torch.norm(ray2src_pose, dim=-1, keepdim=True) + 1e-6)
+ ray_diff = ray2tar_pose - ray2src_pose
+ ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
+ ray_diff_dot = torch.sum(ray2tar_pose * ray2src_pose, dim=-1, keepdim=True)
+ ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
+ ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) # [B, n_views, n_samples, 4]
+ return ray_diff
+
+ def compute(self, xyz, cam, src_imgs, src_cams, featmaps):
+ r"""
+ Input:
+ xyz: [B, n_samples, 3]
+ cam: [B, 1, 34]
+ src_imgs: [B, n_views, 3, h, w]
+ src_cams: [B, n_views, 34]
+ featmaps: [B, n_views, d, h1, w1]
+ Output:
+ rgb_feat_sampled: [B, n_samples, n_views, d+3]
+ mask: [B, n_samples, n_views, 1]
+ pixel_locations: [B, n_views, n_samples, 2]
+ """
+ B, Nv, _, H, W = src_imgs.shape
+ Np = xyz.shape[1]
+ # xyz = xyz.reshape(B, 128, 128, 96, 3)[:, 20:80, 30:90].reshape(B, -1, 3)
+
+ # compute the projection of the query points to each reference image
+ pixel_locations, mask_in_front = self.compute_projections(xyz, src_cams) # [B, n_views, n_samples, 2], [B, n_views, n_samples]
+ # avoid numerical precision errors
+ pixel_locations[(pixel_locations < 0) & (pixel_locations > -0.5)] = 0
+ pixel_locations[(pixel_locations > H - 1) & (pixel_locations < H - 0.5)] = H - 1
+ pixel_locations[(pixel_locations > W - 1) & (pixel_locations < W - 0.5)] = W - 1
+
+ # # visualize for debug
+ # import matplotlib.pyplot as plt
+
+ # painted_img = src_imgs.clone().detach().cpu()[0].numpy()
+ # pixel_locations[(pixel_locations.abs() > 500).any(-1)] = -100
+ # for v in range(Nv):
+ # plt.imshow(painted_img[v, :, :, :].transpose(1, 2, 0))
+ # plt.scatter(pixel_locations[0, v, :, 0].cpu().numpy(), pixel_locations[0, v, :, 1].cpu().numpy(), c='r')
+ # plt.savefig('debug/{}.png'.format(v))
+ # plt.close()
+ normalized_pixel_locations = self.normalize(pixel_locations, H, W).reshape(B*Nv, 1, -1, 2) # [B*n_views, 1, n_samples, 2]
+
+ # rgb sampling
+ src_imgs = src_imgs.flatten(0, 1) # [B*n_views, 3, h, w]
+ rgbs_sampled = F.grid_sample(src_imgs, normalized_pixel_locations, align_corners=True).view(B, Nv, 3, Np) # [B, n_views, 3, n_samples]
+ rgbs_sampled = rgbs_sampled.permute(0, 3, 1, 2) # [B, n_samples, n_views, 3]
+
+ # deep feature sampling
+ featmaps = featmaps.flatten(0, 1) # [B*n_views, d, h1, w1]
+ feat_sampled = F.grid_sample(featmaps, normalized_pixel_locations, align_corners=True).view(B, Nv, -1, Np) # [B, n_views, d, n_samples]
+ feat_sampled = feat_sampled.permute(0, 3, 1, 2) # [B, n_samples, n_views, d]
+ rgb_feat_sampled = torch.cat([rgbs_sampled, feat_sampled], dim=-1) # [B, n_samples, n_views, d+3]
+
+ # mask
+ inbound = self.inbound(pixel_locations, H, W)
+ # ray_diff = self.compute_angle(xyz, cam, src_cams)
+ # ray_diff = ray_diff.permute(0, 2, 1, 3)
+ mask = (inbound * mask_in_front).float().transpose(1, 2)[..., None] # [B, n_samples, n_views, 1]
+ return rgb_feat_sampled, mask, pixel_locations
\ No newline at end of file
diff --git a/model/renderer.py b/model/renderer.py
new file mode 100644
index 0000000..ecf76d1
--- /dev/null
+++ b/model/renderer.py
@@ -0,0 +1,330 @@
+# MIT License
+#
+# Copyright (c) 2022 Anpei Chen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from model.nerf import NeRF
+from model.slot_attn import JointDecoder, linear
+
+
+@torch.jit.script
+def fused_mean_variance(x, weight):
+ mean = torch.sum(x*weight, dim=2, keepdim=True)
+ var = torch.sum(weight * (x - mean)**2, dim=2, keepdim=True)
+ return mean, var
+
+def positional_encoding(positions, freqs):
+ freq_bands = (2 ** torch.arange(freqs)).to(positions.device)
+ pts = (positions[..., None] * freq_bands).reshape(positions.shape[:-1] + (freqs * positions.shape[-1],))
+ pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
+ return pts
+
+
+class NeRFRenderer(nn.Module):
+
+ def __init__(self, depth_range, cfg):
+ super().__init__()
+ self.depth_range = depth_range
+ self.distance_scale = cfg.get('distance_scale', 25)
+ self.weight_thres_color = 0.0001
+ self.feat_weight_thres = 0.01
+ self.alpha_mask_threshold = 0.0075
+ self.step_size = None
+ self.stop_semantic_grad = cfg.stop_semantic_grad
+ self.nerf_mlp_dim = cfg.nerf_mlp_dim
+ self.slot_dec_dim = cfg.slot_dec_dim
+ self.n_samples = cfg.n_samples
+ self.n_samples_fine = cfg.n_samples_fine
+ self.num_instances = cfg.max_instances
+
+ self.slot_size = cfg.slot_size
+ self.inv_uniform = cfg.get('inv_uniform', False)
+ self.normalize = cfg.normalize
+ self.pos_proj = linear(120, self.slot_dec_dim)
+
+ def compute_points_feature(self, nerf: NeRF, slot_dec: JointDecoder,
+ xyz_sampled, slots, dists,
+ cam, src_imgs, src_cams, src_feats,
+ shape, is_train, rays_d, proj_mask=None):
+ B, Nr, Np = shape
+ points_feature, proj_mask = nerf.project_feature(xyz_sampled.view(B, Nr, Np, 3), cam, src_imgs, src_cams, src_feats, is_train, proj_mask)
+
+ pos_emb = positional_encoding(xyz_sampled.view(B, Nr*Np, 3), 10)
+ view_emb = positional_encoding(rays_d.view(B, Nr, 3), 10).unsqueeze(2).expand(-1, -1, Np, -1).reshape(B, Nr*Np, -1)
+ pos_emb = self.pos_proj(torch.cat([pos_emb, view_emb], dim=-1))
+ points_feature = points_feature + pos_emb
+
+ points_coor = xyz_sampled.view(B, Nr*Np, 3)
+ ret = slot_dec(points_feature, pos_emb, slots, points_coor, Nr)
+
+ points_feature, w_slot_dec, sigma_slot = ret['x'], ret['w'], ret.get('sigma', None)
+ points_feature = points_feature + pos_emb
+ points_feature = points_feature.view(B*Nr, Np, self.slot_dec_dim)
+ if sigma_slot is not None:
+ sigma = sigma_slot.view(B*Nr, Np)
+ else:
+ sigma = nerf.compute_density(points_feature).view(B*Nr, Np) # [B*Nr, Np]
+ alpha, weight, bg_weight = self.raw_to_alpha(sigma, dists * self.distance_scale)
+
+ return points_feature, w_slot_dec, weight, ret.get('sparse_loss', torch.zeros(1, device=points_feature.device)), proj_mask, sigma
+
+ def render_color(self, points_feature, viewdirs, w, appearance_mask, shape, color_mlp, white_bg, is_train):
+ B, Nr, Np = shape
+ rgb = points_feature.new_zeros((B*Nr, Np, 3))
+ valid_rgbs = color_mlp(viewdirs[appearance_mask], points_feature[appearance_mask])
+ rgb[appearance_mask] = valid_rgbs
+ rgb_map = torch.sum(w * rgb, -2).reshape(B, Nr, 3)
+ return rgb_map
+
+ def render_instance(self, points_feature, w, appearance_mask, shape, w_slot_dec=None, instance_mlp=None):
+ B, Nr, Np = shape
+ if instance_mlp is not None:
+ instances = points_feature.new_zeros((B*Nr, Np, self.num_instances))
+ valid_instances = instance_mlp(points_feature[appearance_mask])
+ instances[appearance_mask] = valid_instances
+ instances = F.softmax(instances, dim=-1)
+ instance_map = torch.sum(w * instances, -2).reshape(B, Nr, -1) # [B, Nr, K]
+ instance_map = instance_map / (torch.sum(instance_map, -1, keepdim=True) + 1e-8)
+ else:
+ instances = w_slot_dec.view(B*Nr, Np, -1)[..., 1:]
+ instance_map = torch.sum(w * instances, -2).reshape(B, Nr, -1) # [B, Nr, K]
+ instance_map = instance_map / (torch.sum(instance_map, -1, keepdim=True) + 1e-6)
+
+ return instance_map
+
+ def render(self, nerf: NeRF, slot_dec: JointDecoder,
+ xyz_sampled, z_vals, viewdirs, slots,
+ cam, src_imgs, src_cams, src_feats,
+ white_bg, is_train,
+ render_color=False, render_ins=False, render_depth=False, rays_d=None, proj_mask=None):
+ B, Nr, Np, _ = viewdirs.shape
+ dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
+ points_feature, w_slot_dec, weight, sparse_loss, proj_mask, sigma = self.compute_points_feature(nerf, slot_dec,
+ xyz_sampled, slots, dists,
+ cam, src_imgs, src_cams, src_feats,
+ (B, Nr, Np), is_train, rays_d, proj_mask)
+ appearance_mask = weight > self.weight_thres_color # [B*Nr, Np]
+ w = weight[..., None] # [B*Nr, Np, 1]
+
+ ret = {
+ "proj_mask": proj_mask,
+ "weight": weight.clone().detach(),
+ "sparse_loss": sparse_loss,
+ "w": w_slot_dec.view(B, Nr*Np, -1)[..., 1:],
+ "pts": xyz_sampled.view(B, Nr*Np, 3),
+ }
+
+ if render_color:
+ viewdirs = viewdirs.reshape(B*Nr, Np, 3)
+ rgb_map = self.render_color(points_feature, viewdirs, w, appearance_mask, (B, Nr, Np), nerf.color_mlp, white_bg, is_train)
+ ret["rgb"] = rgb_map
+
+ if render_depth:
+ depth_map = torch.sum(weight * z_vals, -1).reshape(B, Nr)
+ opacity_map = torch.sum(w, -2).reshape(B, Nr)
+ depth_map = depth_map + (1. - opacity_map) * z_vals.max()
+ ret["depth"] = depth_map
+
+ if render_ins:
+ if self.stop_semantic_grad:
+ w = w.detach()
+ instance_map = self.render_instance(points_feature, w, appearance_mask, (B, Nr, Np), w_slot_dec)
+ ret["instance"] = instance_map
+
+ return ret
+
+ def forward(self, nerf_c: NeRF, nerf_f: NeRF, slot_dec: JointDecoder, slot_dec_fine: JointDecoder,
+ rays, depth_range, slots,
+ cam, src_imgs, src_cams, src_feats,
+ white_bg=False, is_train=False,
+ render_color=False, render_depth=False,
+ render_ins=False):
+ r"""
+ Input:
+ rays: [B, Nr, 6]
+ slots: [B, K, D]
+ white_bg: True or False
+ is_train: True or False
+ cam: [B, 34]
+ src_imgs: [B, Nv, 3, H, W]
+ src_cams: [B, Nv, 34]
+ src_feats: [B, Nv, D, H1, W1]
+ Output:
+ rgb: [B, Nr, 3]
+ instance: [B, Nr, K]
+ depth: [B, Nr]
+ feats: [B, Nr, D]
+ """
+ B, Nr, _ = rays.shape
+ # assert B == 1, "Only support batch size 1"
+ Np = self.n_samples
+ rays = rays.reshape(-1, 6)
+ rays_o, rays_d = rays[:, :3], rays[:, 3:]
+ xyz_sampled, z_vals = sample_points_in_box(rays_o, rays_d, Np,
+ self.depth_range if self.depth_range is not None else depth_range, is_train, self.inv_uniform) # [B*Nr, n_samples, 3], [B*Nr, n_samples]
+ viewdirs = rays_d.view(B, Nr, 1, 3).expand(-1, -1, Np, -1)
+
+ ret_c = self.render(nerf_c, slot_dec, xyz_sampled, z_vals, viewdirs,
+ slots, cam, src_imgs, src_cams, src_feats,
+ white_bg, is_train,
+ render_color=render_color, render_ins=render_ins,
+ render_depth=render_depth, rays_d=rays_d)
+ weight = ret_c["weight"]
+ ret = {}
+ for k, v in ret_c.items():
+ if k != "weight" and k != 'proj_mask':
+ ret[k + "_c"] = v
+
+ if nerf_f is not None:
+ Np_fine = self.n_samples_fine
+ xyz_sampled, z_vals = sample_fine_pts(rays_o, rays_d, weight, z_vals, Np_fine, Np, is_train, self.inv_uniform)
+ viewdirs = rays_d.view(B, Nr, 1, 3).expand(-1, -1, Np + Np_fine, -1)
+ ret_f = self.render(nerf_f, slot_dec_fine, xyz_sampled, z_vals, viewdirs,
+ slots, cam, src_imgs, src_cams, src_feats,
+ white_bg, is_train,
+ render_color=render_color, render_ins=render_ins,
+ render_depth=render_depth, rays_d=rays_d, proj_mask=ret_c['proj_mask'])
+ for k,v in ret_f.items():
+ if k != "weight" and k != 'proj_mask':
+ ret[k + "_f"] = v
+ return ret
+
+ @staticmethod
+ def raw_to_alpha(sigma, dist):
+ alpha = 1. - torch.exp(-sigma * dist)
+ T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)
+ weights = alpha * T[:, :-1]
+ return alpha, weights, T[:, -1:]
+
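+# `raw_to_alpha` implements the standard volume-rendering quadrature:
+#   alpha_i = 1 - exp(-sigma_i * delta_i),  w_i = alpha_i * prod_{j<i} (1 - alpha_j).
+# A minimal sketch of compositing a per-ray quantity with the returned weights
+# (the helper is illustrative and not used by the renderer):
+def _composite(values, weights):
+ # values: [N_rays, Np, C], weights: [N_rays, Np] from raw_to_alpha
+ return torch.sum(weights.unsqueeze(-1) * values, dim=1) # [N_rays, C]
+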
+
+def sample_points_in_box(rays_o, rays_d, n_samples, depth_range, is_train, inv_uniform=False):
+ if isinstance(depth_range, tuple):
+ depth_range = torch.tensor(depth_range).float().to(rays_o.device)
+ depth_range = depth_range.expand(rays_d.shape[0], -1)
+ near_depth, far_depth = depth_range[..., 0], depth_range[..., 1]
+ if inv_uniform:
+ start = 1.0 / near_depth # [N_rays,]
+ step = (1.0 / far_depth - start) / (n_samples - 1)
+ inv_z_vals = torch.stack(
+ [start + i * step for i in range(n_samples)], dim=1
+ ) # [N_rays, n_samples]
+ z_vals = 1.0 / inv_z_vals
+ if is_train:
+ mids = 0.5 * (z_vals[:, 1:] + z_vals[:, :-1])
+ upper = torch.cat([mids, z_vals[:, -1:]], dim=-1)
+ lower = torch.cat([z_vals[:, 0:1], mids], dim=-1)
+ # uniform samples in those intervals
+ t_rand = torch.rand_like(z_vals)
+ z_vals = lower + (upper - lower) * t_rand # [N_rays, Np]
+ else:
+ step_size = (depth_range[..., 1:2] - depth_range[..., 0:1]) / (n_samples - 1)
+ rng = torch.arange(n_samples)[None].type_as(rays_o).expand(rays_o.shape[:-1] + (n_samples,))
+ if is_train:
+ rng = rng + torch.rand_like(rng[:, [0]]).type_as(rng)
+ step = step_size * rng.to(rays_o.device)
+ z_vals = (depth_range[..., 0:1] + step) # [B*Nr, n_samples]
+
+ rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None] # [B*Nr, n_samples, 3]
+ return rays_pts, z_vals
+
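+# In the uniform branch above, samples are placed at z_i = near + i * (far - near) / (Np - 1)
+# (jittered by up to one step during training) and lifted to 3D as p_i = o + z_i * d.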
+
+def sample_pdf(bins, weights, Np, det=False):
+ r'''
+ :param bins: tensor of shape [N_rays, M+1], M is the number of bins
+ :param weights: tensor of shape [N_rays, M]
+ :param Np: number of samples along each ray
+ :param det: if True, will perform deterministic sampling
+ :return: [N_rays, Np]
+ '''
+
+ M = weights.shape[1]
+ weights += 1e-5
+ # Get pdf
+ pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [N_rays, M]
+ cdf = torch.cumsum(pdf, dim=-1) # [N_rays, M]
+ cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1) # [N_rays, M+1]
+
+ # Take uniform samples
+ if det:
+ u = torch.linspace(0.0, 1.0, Np, device=bins.device)
+ u = u.unsqueeze(0).repeat(bins.shape[0], 1) # [N_rays, Np]
+ else:
+ u = torch.rand(bins.shape[0], Np, device=bins.device)
+ # Invert CDF
+ above_inds = torch.zeros_like(u, dtype=torch.long) # [N_rays, Np]
+ for i in range(M):
+ above_inds += (u >= cdf[:, i:i+1]).long()
+
+ # random sample inside each bin
+ below_inds = torch.clamp(above_inds-1, min=0)
+ inds_g = torch.stack((below_inds, above_inds), dim=2) # [N_rays, Np, 2]
+
+ cdf = cdf.unsqueeze(1).repeat(1, Np, 1) # [N_rays, Np, M+1]
+ cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) # [N_rays, Np, 2]
+
+ bins = bins.unsqueeze(1).repeat(1, Np, 1) # [N_rays, Np, M+1]
+ bins_g = torch.gather(input=bins, dim=-1, index=inds_g) # [N_rays, Np, 2]
+
+ # t = (u-cdf_g[:, :, 0]) / (cdf_g[:, :, 1] - cdf_g[:, :, 0] + TINY_NUMBER) # [N_rays, Np]
+ # fix numeric issue
+ denom = cdf_g[:, :, 1] - cdf_g[:, :, 0] # [N_rays, Np]
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+ t = (u - cdf_g[:, :, 0]) / denom
+
+ samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1]-bins_g[:, :, 0])
+
+ return samples
+
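+# Illustrative use of `sample_pdf` (not part of the pipeline): draw importance samples
+# from per-bin weights; with uniform weights the samples simply spread over the bin range.
+def _sample_pdf_example():
+ bins = torch.linspace(0.0, 1.0, 9)[None] # [1, M+1] bin edges, M = 8
+ weights = torch.ones(1, 8) # [1, M] uniform weights
+ return sample_pdf(bins, weights, Np=16, det=True) # [1, 16] samples in [0, 1]
+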
+def sample_fine_pts(rays_o, rays_d, weights, z_vals, N_importance, Np, det=True, inv_uniform=False):
+ if inv_uniform:
+ inv_z_vals = 1.0 / z_vals
+ inv_z_vals_mid = 0.5 * (inv_z_vals[:, 1:] + inv_z_vals[:, :-1]) # [N_rays, Np-1]
+ weights = weights[:, 1:-1] # [N_rays, Np-2]
+ inv_z_vals = sample_pdf(
+ bins=torch.flip(inv_z_vals_mid, dims=[1]),
+ weights=torch.flip(weights, dims=[1]),
+ Np=N_importance,
+ det=det,
+ ) # [N_rays, N_importance]
+ z_samples = 1.0 / inv_z_vals
+ else:
+ # take mid-points of depth samples
+ z_vals_mid = 0.5 * (z_vals[:, 1:] + z_vals[:, :-1]) # [N_rays, Np-1]
+ weights = weights[:, 1:-1] # [N_rays, Np-2]
+ z_samples = sample_pdf(
+ bins=z_vals_mid, weights=weights, Np=N_importance, det=det
+ ) # [N_rays, N_importance]
+
+ z_vals = torch.cat((z_vals, z_samples), dim=-1) # [N_rays, Np + N_importance]
+
+ # samples are sorted with increasing depth
+ z_vals, _ = torch.sort(z_vals, dim=-1)
+ N_total_samples = Np + N_importance
+
+ viewdirs = rays_d.unsqueeze(1).repeat(1, N_total_samples, 1)
+ ray_o = rays_o.unsqueeze(1).repeat(1, N_total_samples, 1)
+ pts = z_vals.unsqueeze(2) * viewdirs + ray_o # [N_rays, Np + N_importance, 3]
+ return pts, z_vals
diff --git a/model/slot_attn.py b/model/slot_attn.py
new file mode 100644
index 0000000..ba1c329
--- /dev/null
+++ b/model/slot_attn.py
@@ -0,0 +1,237 @@
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import DropPath
+
+from model.transformer import TransformerDecoder
+from model.encoder import MultiViewResUNet
+
+
+def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1., nonlinearity='relu'):
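+    """nn.Linear with Xavier (default) or Kaiming weight initialization and zero-initialized bias."""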
+
+ m = nn.Linear(in_features, out_features, bias)
+
+ if weight_init == 'kaiming':
+ nn.init.kaiming_uniform_(m.weight, nonlinearity=nonlinearity)
+ else:
+ nn.init.xavier_uniform_(m.weight, gain)
+
+ if bias:
+ nn.init.zeros_(m.bias)
+
+ return m
+
+
+def gru_cell(input_size, hidden_size, bias=True):
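+    """nn.GRUCell with Xavier-initialized input-hidden weights, orthogonal hidden-hidden weights, and zero biases."""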
+
+ m = nn.GRUCell(input_size, hidden_size, bias)
+
+ nn.init.xavier_uniform_(m.weight_ih)
+ nn.init.orthogonal_(m.weight_hh)
+
+ if bias:
+ nn.init.zeros_(m.bias_ih)
+ nn.init.zeros_(m.bias_hh)
+
+ return m
+
+
+class SlotAttention(nn.Module):
+ def __init__(
+ self,
+ feature_size,
+ slot_size,
+ drop_path=0.2,
+ num_head=1,
+ ):
+ super().__init__()
+ self.slot_size = slot_size
+ self.epsilon = 1.0
+ self.num_head = num_head
+
+ self.norm_feature = nn.LayerNorm(feature_size)
+ self.norm_mlp = nn.LayerNorm(slot_size)
+ self.norm_slots = nn.LayerNorm(slot_size)
+
+ self.project_q = linear(slot_size, slot_size, bias=False)
+ self.project_k = linear(feature_size, slot_size, bias=False)
+ self.project_v = linear(feature_size, slot_size, bias=False)
+
+ self.gru = gru_cell(slot_size, slot_size)
+
+ self.mlp = nn.Sequential(
+ linear(slot_size, slot_size * 4, weight_init='kaiming'),
+ nn.ReLU(),
+ linear(slot_size * 4, slot_size),
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
+
+ def forward(self, features, slots_init, num_src_view, num_iter=3):
+ # features: [batch_size, num_feature, inputs_size]
+ features = self.norm_feature(features)
+ k = self.project_k(features) # Shape: [B, num_features, slot_size]
+ v = self.project_v(features)
+
+ slots = slots_init
+ # Multiple rounds of attention.
+ for i in range(num_iter - 1):
+ slots, attn = self.iter(slots, k, v, num_src_view=num_src_view)
+ slots = slots.detach() + slots_init - slots_init.detach()
+ slots, attn = self.iter(slots, k, v, num_src_view=num_src_view)
+ return slots, attn
+
+ def iter(self, slots, k, v, num_src_view):
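+        """One slot-attention iteration: attend to the features with a softmax over slots
+        (competition), take the attention-weighted mean of the values, and update the
+        slots with a GRU followed by a residual MLP."""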
+ B, K, D = slots.shape
+ slots_prev = slots
+ slots = self.norm_slots(slots)
+ q = self.project_q(slots)
+
+ Nh = self.num_head
+ q = q.reshape(B, K, Nh, D//Nh).transpose(1, 2) # [B, Nh, K, D//Nh]
+ k = k.reshape(B, -1, Nh, D//Nh).transpose(1, 2) # [B, Nh, Nf, D//Nh]
+
+ # Attention
+ scale = (D//Nh) ** -0.5
+ attn_logits = torch.matmul(q, k.transpose(-1, -2)) * scale # [B, Nh, K, Nf]
+ attn_logits = attn_logits.mean(1) # [B, K, Nf]
+ attn = F.softmax(attn_logits, dim=1)
+
+        # Weighted mean
+ attn_sum = torch.sum(attn, dim=-1, keepdim=True) + self.epsilon
+ attn_wm = attn / attn_sum
+ updates = torch.einsum('bij, bjd->bid', attn_wm, v)
+
+ # Update slots
+ slots = self.gru(
+ updates.reshape(-1, D),
+ slots_prev.reshape(-1, D)
+ )
+ slots = slots.reshape(B, -1, D)
+ slots = slots + self.drop_path(self.mlp(self.norm_mlp(slots)))
+ return slots, attn
+
+
+class SlotEnc(nn.Module):
+ def __init__(
+ self, num_iter, num_slots, feature_size,
+ slot_size, drop_path=0.2, num_blocks=1):
+ super().__init__()
+
+ self.num_iter = num_iter
+ self.num_slots = num_slots
+ self.slot_size = slot_size
+        self.slot_attn = nn.ModuleList([
+ SlotAttention(feature_size, slot_size, drop_path=drop_path) for i in range(num_blocks)
+ ])
+ self.num_blocks = num_blocks
+ self.slots_init = nn.Parameter(torch.zeros(1, num_slots, slot_size))
+ nn.init.xavier_uniform_(self.slots_init)
+
+ def forward(self, f, sigma, num_src_view):
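+        """Sample initial slots as the learned mean plus noise scaled by sigma (annealed
+        during training), then refine them through the SlotAttention block(s)."""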
+ B, _, D = f.shape
+ # initialize slots.
+ mu = self.slots_init.expand(B, -1, -1)
+ z = torch.randn_like(mu).type_as(f)
+ slots = mu + z * sigma * mu.detach()
+ slots, attn = self.slot_attn[0](f, slots, num_iter=self.num_iter, num_src_view=num_src_view)
+ for i in range(self.num_blocks - 1):
+ slots, attn = self.slot_attn[i + 1](f, slots, num_iter=self.num_iter, num_src_view=num_src_view)
+ return slots, attn
+
+
+class Slot3D(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.feature_size = config.feature_size
+ self.multi_view_enc = MultiViewResUNet(config)
+ self.slot_enc = SlotEnc(
+ num_iter=config.num_iter,
+ num_slots=config.num_slots,
+ feature_size=config.feature_size,
+ slot_size=config.slot_size,
+ drop_path=config.drop_path,
+ )
+ self.num_slots = config.num_slots
+
+ def forward(self, src_cams=None, images=None, sigma=0):
+ """
+ Input:
+ images: [batch_size, N_views, 3, H, W]
+ src_cams: [batch_size, N_views, 25]
+ Output:
+ slots: [batch_size, num_slots, slot_size]
+ """
+ B, N_view = src_cams.shape[:2]
+ H, W = images.shape[3:]
+ features = self.multi_view_enc(src_cams, images.reshape(B*N_view, 3, H, W)) # [B*N_views, D, H1, W1]
+
+ H1, W1 = features.shape[-2:]
+ features = features.permute(0, 2, 3, 1) # [B*N_views, H1, W1, D]
+
+ feats = features.reshape(B, -1, self.feature_size)
+ attn = torch.zeros(B, self.num_slots, feats.shape[1], device=feats.device)
+ slots, attn = self.slot_enc(feats, sigma=sigma, num_src_view=N_view) # [B, K, slot_size]
+ return slots, attn, features.reshape(B, N_view, H1, W1, features.shape[-1])
+
+
+class JointDecoder(nn.Module):
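+    """Slot-conditioned point decoder: a transformer decoder conditions lifted point
+    features on the slots (plus a learned empty/background slot), a query-key dot
+    product assigns each point a distribution over slots, and the slot-weighted
+    mixture is projected to the output feature (optionally also a density)."""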
+ def __init__(self, cfg):
+ super().__init__()
+ d_kv = cfg.slot_size
+ self.num_slots = cfg.num_slots
+ self.tf = TransformerDecoder(
+ cfg.num_dec_blocks, cfg.slot_dec_dim, d_kv, 4, cfg.drop_path, 0)
+ self.empty_slot = nn.Parameter(torch.zeros(1, 1, cfg.slot_size))
+ nn.init.xavier_uniform_(self.empty_slot)
+
+ self.Wq = linear(cfg.slot_dec_dim, cfg.slot_dec_dim)
+ self.Wk = linear(cfg.slot_size, cfg.slot_dec_dim)
+ self.out_proj = linear(cfg.slot_size, cfg.slot_dec_dim)
+ self.force_bg = False
+ self.bg_bound = cfg.bg_bound
+ self.scale = cfg.slot_dec_dim ** -0.5
+ self.slot_density = cfg.get('slot_density', False)
+ if self.slot_density:
+ self.density_scale = nn.Parameter(torch.zeros(1))
+
+
+ def forward(self, point_feats, points_emb, slots, points_coor, Nr=0):
+ r"""
+ Input:
+ point_feats: [B, N, D] N = N_ray * N_points
+ slots: [B, K, slot_size]
+ Output:
+ x: [B, N, D]
+ w: [B, N, K]
+ """
+ slots = torch.cat([self.empty_slot.expand(slots.shape[0], -1, -1), slots], dim=1) # [B, K+1, D]
+ x = self.tf(point_feats, slots, points_emb, Nr) # [B, N, D]
+
+ # point slot mapping
+ q = self.Wq(x)
+ k = self.Wk(slots)
+ logits = torch.matmul(q, k.transpose(-1, -2)) * self.scale # [B, N, K+1]
+ if self.force_bg:
+            out_idx = (points_coor.abs() > self.bg_bound).any(-1)[..., None].repeat(1, 1, self.num_slots+1) # [B, N, K+1]
+ out_idx[:, :, 0:2] = False
+ logits[out_idx] = -torch.inf
+ w = F.softmax(logits, dim=-1) # [B, N, K+1]
+ x = torch.matmul(w, slots) # [B, N, D]
+ x = self.out_proj(x)
+ if self.slot_density:
+ slot_sigma = F.relu(logits)
+ sigma = (slot_sigma[..., 1:] * w[..., 1:]).sum(-1)
+ sigma = sigma * self.density_scale.exp() # [B, N]
+ else:
+ sigma = None
+
+ return {'x': x, 'w': w, 'sigma': sigma}
+
+
diff --git a/model/transformer.py b/model/transformer.py
new file mode 100644
index 0000000..634dfcd
--- /dev/null
+++ b/model/transformer.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1., nonlinearity='relu'):
+
+ m = nn.Linear(in_features, out_features, bias)
+
+ if weight_init == 'kaiming':
+ nn.init.kaiming_uniform_(m.weight, nonlinearity=nonlinearity)
+ else:
+ nn.init.xavier_uniform_(m.weight, gain)
+
+ if bias:
+ nn.init.zeros_(m.bias)
+
+ return m
+
+
+class MultiHeadAttention(nn.Module):
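+    """Multi-head attention with separate query and key/value input dimensions, an
+    internal projection of size d_proj, and dropout on both the attention map and the output."""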
+
+ def __init__(self, d_q, d_kv, d_proj, d_out, num_heads, dropout=0., gain=1., bias=True):
+ super().__init__()
+
+ assert d_proj % num_heads == 0, "d_proj must be divisible by num_heads"
+ self.num_heads = num_heads
+
+ self.attn_dropout = nn.Dropout(dropout)
+ self.output_dropout = nn.Dropout(dropout)
+
+ self.proj_q = linear(d_q, d_proj, bias=bias)
+ self.proj_k = linear(d_kv, d_proj, bias=bias)
+ self.proj_v = linear(d_kv, d_proj, bias=bias)
+ self.proj_o = linear(d_proj, d_out, bias=bias, gain=gain)
+
+
+ def forward(self, q, k, v, attn_mask=None):
+ B, T, _ = q.shape
+ _, S, _ = k.shape
+
+ q = self.proj_q(q).view(B, T, self.num_heads, -1).transpose(1, 2)
+ k = self.proj_k(k).view(B, S, self.num_heads, -1).transpose(1, 2)
+ v = self.proj_v(v).view(B, S, self.num_heads, -1).transpose(1, 2)
+
+ q = q * (q.shape[-1] ** (-0.5))
+ attn = torch.matmul(q, k.transpose(-1, -2))
+
+ if attn_mask is not None:
+ attn = attn.masked_fill(attn_mask, float('-inf'))
+
+ attn = F.softmax(attn, dim=-1)
+ attn_d = self.attn_dropout(attn)
+
+ output = torch.matmul(attn_d, v).transpose(1, 2).reshape(B, T, -1)
+ output = self.proj_o(output)
+ output = self.output_dropout(output)
+ return output
+
+
+class TransformerDecoderBlock(nn.Module):
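+    """Decoder block used by JointDecoder: cross-attention from ray-sample features to
+    the slots, a 1D convolution and feed-forward update, and self-attention restricted
+    to the samples of each ray (the reshape by N groups the points per ray)."""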
+
+ def __init__(self, d_q, d_kv, num_heads=4, drop_path=0., dropout=0., gain=1., is_first=False):
+ super().__init__()
+
+ self.cross_attn = False
+ if d_kv > 0:
+ self.encoder_decoder_attn_layer_norm = nn.LayerNorm(d_q)
+ self.encoder_decoder_attn = MultiHeadAttention(d_q, d_kv, d_q, d_q, num_heads, dropout, gain)
+ self.cross_attn = True
+
+ self.ffn = nn.Sequential(
+ nn.LayerNorm(d_q),
+ linear(d_q, 4 * d_q, weight_init='kaiming'),
+ nn.ReLU(),
+ linear(4 * d_q, d_q, gain=gain),
+ nn.Dropout(dropout))
+
+ self.drop_path1 = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
+ self.drop_path2 = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.self_attn_layer_norm = nn.LayerNorm(d_q)
+ self.self_attn = MultiHeadAttention(d_q, d_q, d_q // 4, d_q, num_heads, dropout, gain)
+ self.conv = nn.Sequential(
+ nn.Conv1d(d_q, d_q, kernel_size=3, padding=1),
+ nn.ReLU(),
+ )
+
+ def forward(self, input, encoder_output, N):
+ """
+ input: batch_size x target_len x d_model
+ encoder_output: batch_size x source_len x d_model
+ return: batch_size x target_len x d_model
+ """
+ B, L, D = input.shape
+ if self.cross_attn:
+ x = self.encoder_decoder_attn_layer_norm(input)
+ x = self.encoder_decoder_attn(x, encoder_output, encoder_output)
+ x = self.conv(x.reshape(B*N, -1, D).transpose(1, 2)).transpose(1, 2).reshape(B, L, D)
+ input = input + self.drop_path1(x)
+ x = self.ffn(x)
+ input = input + self.drop_path2(x)
+ x = self.self_attn_layer_norm(input)
+ x = x.reshape(B*N, -1, D)
+ x = self.self_attn(x, x, x)
+ x = input + x.reshape(B, -1, D)
+ return x
+
+
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, num_blocks, d_q, d_kv, num_heads, drop_path, dropout=0.):
+ super().__init__()
+ dpr = [x.item() for x in torch.linspace(0, drop_path, num_blocks)] # stochastic depth decay rule
+ gain = (3 * num_blocks) ** (-0.5)
+ self.blocks = nn.ModuleList(
+ [TransformerDecoderBlock(d_q, d_kv, num_heads, dpr[0], dropout, gain, is_first=True)] +
+ [TransformerDecoderBlock(d_q, d_kv, num_heads, dpr[i+1], dropout, gain, is_first=False)
+ for i in range(num_blocks - 1)])
+ self.layer_norm = nn.LayerNorm(d_q)
+
+ def forward(self, input, encoder_output, pos_emb, Nr):
+ """
+ input: batch_size x target_len x d_model
+ encoder_output: batch_size x source_len x d_model
+ return: batch_size x target_len x d_model
+ """
+ for i, block in enumerate(self.blocks):
+ input = block(input, encoder_output, Nr)
+ return self.layer_norm(input)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8dbdab0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+tqdm
+wandb
+seaborn
+scikit-learn==1.3.0
+pytorch-lightning==1.8.6
+tabulate==0.9.0
+hydra-core==1.3.1
+opencv-python==4.6.0.66
+timm==0.6.13
+einops==0.3.0
+transforms3d==0.4.1
+piq==0.8.0
+lpips==0.1.4
diff --git a/scripts/eval_dtu.sh b/scripts/eval_dtu.sh
new file mode 100644
index 0000000..d325f75
--- /dev/null
+++ b/scripts/eval_dtu.sh
@@ -0,0 +1,84 @@
+export CUDA_VISIBLE_DEVICES=$1
+conda activate slotlifter
+dataset=dtu
+ids=(1)
+# ids=(6 7)
+# ids=(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16)
+for scene in ${ids[@]};do
+    python trainer/train.py \
+ cfg=${dataset} \
+ cfg.job_type='test' \
+ cfg.exp_name=vis_${scene} \
+ cfg.test_percent=1.0 \
+ cfg.num_workers=16 \
+ cfg.precision=32 \
+ cfg.num_slots=8 \
+ cfg.enc_type='ibrnet' \
+ cfg.num_src_view=4 \
+ cfg.ckpt_path=/home/yuliu/Projects/OR3D/runs/dtu/f32_8s_64+64_01282054/checkpoints/best.ckpt \
+ # cfg.scene_id=${scene} \
+
+done
diff --git a/scripts/eval_scannet.sh b/scripts/eval_scannet.sh
new file mode 100644
index 0000000..96ce1e9
--- /dev/null
+++ b/scripts/eval_scannet.sh
@@ -0,0 +1,16 @@
+export CUDA_VISIBLE_DEVICES=$1
+conda activate slotlifter
+dataset=scannet
+seed=0
+python trainer/train.py \
+ cfg=${dataset} \
+ cfg.job_type='test' \
+ cfg.exp_name="eval_scannet" \
+ cfg.num_workers=8 \
+ cfg.chunk=16384 \
+ cfg.test_percent=1.0 \
+ cfg.ckpt_path=checkpoints/scannet.ckpt \
+ cfg.num_slots=8 \
+ cfg.val_subsample_frames=1 \
+ cfg.seed=${seed} \
+ # cfg.logger=wandb \
diff --git a/scripts/eval_uorf_data.sh b/scripts/eval_uorf_data.sh
new file mode 100644
index 0000000..7c90a8a
--- /dev/null
+++ b/scripts/eval_uorf_data.sh
@@ -0,0 +1,25 @@
+export CUDA_VISIBLE_DEVICES=$1
+conda activate slotlifter
+dataset=uorf
+# subset=clevr_567 # change the num_slots to 8 for clevr_567
+subset=room_chair
+# subset=room_diverse
+# subset=room_texture
+# subset=kitchen_matte
+# subset=kitchen_shiny
+seed=0
+python trainer/train.py \
+ cfg=${dataset} \
+ cfg.job_type='test' \
+ cfg.exp_name="eval_${subset}_seed1" \
+ cfg.group=${subset} \
+ cfg.subset=${subset} \
+ cfg.num_workers=8 \
+ cfg.chunk=16384 \
+ cfg.test_percent=1.0 \
+ cfg.ckpt_path=checkpoints/${subset}/seed1.ckpt \
+ cfg.num_slots=5 \
+ cfg.val_subsample_frames=1 \
+ cfg.seed=${seed} \
+ cfg.logger=wandb \
+
diff --git a/scripts/train_dtu.sh b/scripts/train_dtu.sh
new file mode 100644
index 0000000..b5e5143
--- /dev/null
+++ b/scripts/train_dtu.sh
@@ -0,0 +1,17 @@
+export CUDA_VISIBLE_DEVICES=6
+conda activate slotlifter
+dataset=dtu
+seed=0
+python trainer/train.py \
+ cfg=${dataset} \
+ cfg.job_type='train' \
+ cfg.exp_name="${dataset}_${seed}" \
+ cfg.batch_size=1 \
+ cfg.ray_batchsize=1024 \
+ cfg.val_check_interval=4 \
+ cfg.num_workers=16 \
+ cfg.num_slots=8 \
+ cfg.num_src_view=1 \
+ cfg.seed=${seed} \
+ # cfg.logger=wandb \
+
diff --git a/scripts/train_scannet.sh b/scripts/train_scannet.sh
new file mode 100644
index 0000000..93f91b0
--- /dev/null
+++ b/scripts/train_scannet.sh
@@ -0,0 +1,17 @@
+export CUDA_VISIBLE_DEVICES=7
+conda activate slotlifter
+dataset=scannet
+seed=0
+python trainer/train.py \
+ cfg=${dataset} \
+ cfg.job_type='train' \
+ cfg.exp_name="${dataset}_${seed}" \
+ cfg.batch_size=1 \
+ cfg.ray_batchsize=1024 \
+ cfg.val_check_interval=3 \
+ cfg.num_workers=16 \
+ cfg.num_slots=8 \
+ cfg.num_src_view=4 \
+ cfg.seed=${seed} \
+ # cfg.logger=wandb \
+
\ No newline at end of file
diff --git a/scripts/train_uorf_data.sh b/scripts/train_uorf_data.sh
new file mode 100644
index 0000000..fa169f3
--- /dev/null
+++ b/scripts/train_uorf_data.sh
@@ -0,0 +1,24 @@
+export CUDA_VISIBLE_DEVICES=6,7
+conda activate slotlifter
+dataset=uorf
+subset=clevr_567 # change the num_slots to 8 for clevr_567
+# subset=room_chair
+# subset=room_diverse
+# subset=room_texture
+# subset=kitchen_matte
+# subset=kitchen_shiny
+seed=0
+python trainer/train.py \
+ cfg=${dataset} \
+ cfg.job_type='train' \
+ cfg.exp_name="${subset}_${seed}" \
+ cfg.group=${subset} \
+ cfg.subset=${subset} \
+ cfg.num_workers=8 \
+ cfg.batch_size=2 \
+ cfg.ray_batchsize=1024 \
+ cfg.val_check_interval=10 \
+ cfg.num_slots=8 \
+ cfg.seed=${seed} \
+ cfg.logger=wandb \
+ # cfg.monitor=psnr \ # using for kitchen_shiny and kitchen_matte because these datasets do not have ground truth masks
\ No newline at end of file
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..14aedae
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import os
+from pathlib import Path
+from random import randint
+import datetime
+
+import torch
+import wandb
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.loggers import WandbLogger
+
+from util.distinct_colors import DistinctColors
+from util.misc import visualize_depth, get_boundary_mask
+from util.filesystem_logger import FilesystemLogger
+import torch.nn.functional as F
+
+
+def generate_experiment_name(config):
+ if config.resume is not None and config.job_type == 'train':
+ experiment = Path(config.resume).parents[1].name
+ # experiment = f"{config.exp_name}_{datetime.datetime.now().strftime('%m%d%H%M')}"
+ os.environ['experiment'] = experiment
+ elif not os.environ.get('experiment'):
+ experiment = f"{config.exp_name}_{datetime.datetime.now().strftime('%m%d%H%M')}"
+ os.environ['experiment'] = experiment
+ else:
+ experiment = os.environ['experiment']
+ return experiment
+
+
+def create_trainer(config):
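+    """Build the PyTorch Lightning Trainer: seed everything, optionally attach a W&B
+    logger, checkpoint on the validation monitor, and use DDP when more than one GPU
+    is visible."""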
+
+ config.exp_name = generate_experiment_name(config)
+ if config.val_check_interval > 1:
+ config.val_check_interval = int(config.val_check_interval)
+ if config.seed is None:
+ config.seed = randint(0, 999)
+
+ seed_everything(config.seed, workers=True)
+
+ # save code files
+ filesystem_logger = FilesystemLogger(config)
+
+ if config.logger == 'wandb':
+ logger = WandbLogger(
+ project=config.project,
+ entity=config.entity,
+ group=config.group,
+ name=config.exp_name,
+ job_type=config.job_type,
+ tags=config.tags,
+ notes=config.notes,
+ id=config.exp_name,
+ # settings=wandb.Settings(start_method='thread'),
+ )
+ else:
+ logger = False
+
+ checkpoint_callback = ModelCheckpoint(dirpath=(Path(config.log_path) / config.exp_name / "checkpoints"),
+ monitor=f'val/{config.monitor}',
+ save_top_k=1,
+ save_last=True,
+ mode='max',
+ )
+ callbacks = [LearningRateMonitor("step"), checkpoint_callback] if logger else []
+ gpu_count = torch.cuda.device_count()
+ if config.job_type == 'debug':
+ config.train_percent = 30
+ config.val_percent = 1
+ config.test_percent = 1
+ config.val_check_interval = 1
+
+ kwargs = {
+ 'resume_from_checkpoint': config.resume,
+ 'logger': logger,
+ 'accelerator': 'gpu',
+ 'devices': gpu_count,
+ 'strategy': 'ddp' if gpu_count > 1 else None,
+ 'num_sanity_val_steps': 1,
+ 'max_steps': config.max_steps,
+ 'max_epochs': config.max_epochs,
+ 'limit_train_batches': config.train_percent,
+ 'limit_val_batches': config.val_percent,
+ 'limit_test_batches': config.test_percent,
+ 'val_check_interval': float(min(config.val_check_interval, 1)),
+ 'check_val_every_n_epoch': max(1, config.val_check_interval),
+ 'callbacks': callbacks,
+ 'gradient_clip_val': config.grad_clip if config.grad_clip > 0 else None,
+ 'precision': config.precision,
+ 'profiler': config.profiler,
+ 'benchmark': config.benchmark,
+ 'deterministic': config.deterministic,
+ }
+ trainer = Trainer(**kwargs)
+ return trainer
+
+
+def visualize_panoptic_outputs(p_rgb, p_instances, p_depth, rgb, semantics, instances, H, W, depth=None, p_semantics=None):
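+    """Stack ground-truth and predicted RGB, semantics, instances, and depth into eight
+    [3, H, W] images (GT row followed by prediction row) for logging."""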
+ alpha = 0.65
+ distinct_colors = DistinctColors()
+ p_rgb = p_rgb.cpu()
+ img = p_rgb.view(H, W, 3).cpu().permute(2, 0, 1)
+
+ depth_scale = 1
+ if p_depth is not None:
+ p_depth = visualize_depth(p_depth.view(H, W) * depth_scale, use_global_norm=False)
+ else:
+ p_depth = torch.zeros_like(img)
+ if depth is not None:
+ depth = visualize_depth(depth.view(H, W) * depth_scale, use_global_norm=False)
+ else:
+ depth = torch.zeros_like(img)
+
+ def get_color(p_instances, idx_bg, im, rgb_):
+ colored_img_instance = distinct_colors.apply_colors_fast_torch(p_instances).float()
+ # boundaries_img_instances = get_boundary_mask(p_instances.view(H, W))
+ # colored_img_instance[p_instances == idx_bg, :] = rgb_[p_instances == idx_bg, :]
+ img_instances = colored_img_instance.view(H, W, 3).permute(2, 0, 1) * alpha + im * (1 - alpha)
+ # img_instances[:, boundaries_img_instances > 0] = 0
+ return img_instances
+
+ img_gt = rgb.view(H, W, 3).permute(2, 0, 1)
+ idx_bg = p_instances.sum(0).argmax().item()
+ p_instances = p_instances.argmax(dim=1).cpu()
+ img_instances = get_color(p_instances, idx_bg, img, p_rgb)
+
+ if semantics is not None and semantics.max() > 0:
+ img_semantics_gt = distinct_colors.apply_colors_fast_torch(semantics).view(H, W, 3).permute(2, 0, 1) * alpha + img_gt * (1 - alpha)
+ boundaries_img_semantics_gt = get_boundary_mask(semantics.view(H, W))
+ img_semantics_gt[:, boundaries_img_semantics_gt > 0] = 0
+ else:
+ img_semantics_gt = torch.zeros_like(img_gt)
+ if p_semantics is not None and p_semantics.max() > 0:
+ p_semantics = p_semantics.argmax(dim=1).cpu()
+ img_semantics = distinct_colors.apply_colors_fast_torch(p_semantics).view(H, W, 3).permute(2, 0, 1) * alpha + img * (1 - alpha)
+ boundaries_img_semantics = get_boundary_mask(p_semantics.view(H, W))
+ img_semantics[:, boundaries_img_semantics > 0] = 0
+ else:
+ img_semantics = torch.zeros_like(img_gt)
+ if instances is not None and instances.max() > 0:
+ img_instances_gt = get_color(instances.long(), 0, img_gt, rgb)
+ else:
+ img_instances_gt = torch.zeros_like(img_gt)
+ stack = torch.cat([torch.stack([img_gt, img_semantics_gt, img_instances_gt, depth]), torch.stack([img, img_semantics, img_instances, p_depth])], dim=0)
+ return stack
\ No newline at end of file
diff --git a/trainer/train.py b/trainer/train.py
new file mode 100644
index 0000000..b1fbc9e
--- /dev/null
+++ b/trainer/train.py
@@ -0,0 +1,492 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import math
+from pathlib import Path
+import torch
+import hydra
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_only
+from tabulate import tabulate
+from torchvision.utils import save_image, make_grid
+from torch.utils.data import DataLoader
+from dataset import get_dataset
+from model.nerf import NeRF
+from model.renderer import NeRFRenderer
+from trainer import create_trainer, visualize_panoptic_outputs
+from util.misc import visualize_depth
+from util.metrics import SegMetrics, ReconMetrics
+from model.slot_attn import Slot3D
+from torch import optim
+from util.optimizer import Lion
+from model.slot_attn import JointDecoder
+from PIL import Image
+import numpy as np
+import seaborn as sns
+
+
+def segmentation_to_rgb(seg, palette=None, num_objects=None, bg_color=(0, 0, 0)):
+ seg = seg[..., None]
+ if num_objects is None:
+ num_objects = np.max(seg) # assume consecutive numbering
+ num_objects += 1 # background
+ if palette is None:
+ # palette = [bg_color] + sns.color_palette('hls', num_objects-1)
+ palette = sns.color_palette('hls', num_objects)
+
+ seg_img = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.float32)
+ for i in range(num_objects):
+ seg_img[np.nonzero(seg[:, :, 0] == i)] = palette[i]
+ return seg_img
+
+
+def save_img_from_tensor(img, path):
+ r'''
+ img: tensor, [H, W, 3]
+ '''
+ img = img.cpu().numpy()
+ img = (img * 255).astype(np.uint8)
+ img = Image.fromarray(img)
+ img.save(path)
+
+
+def save_seg_from_tensor(seg, path):
+ r'''
+ seg: tensor, [H, W]
+ '''
+ seg = seg.cpu().numpy()
+ seg = segmentation_to_rgb(seg)
+ seg = (seg * 255).astype(np.uint8)
+ seg = Image.fromarray(seg)
+ seg.save(path)
+
+
+class TensoRFTrainer(pl.LightningModule):
+
+ def __init__(self, config):
+ super().__init__()
+ self.train_set, self.val_set, self.test_set = get_dataset(config)
+ self.train_sampler = None
+ if config.visualized_indices is None:
+ if config.job_type == 'vis':
+ config.visualized_indices = torch.arange(len(self.val_set)).tolist()
+ else:
+ config.visualized_indices = torch.randperm(len(self.val_set))[:4].tolist()
+ if config.job_type == 'debug':
+ config.visualized_indices = config.visualized_indices[0:1]
+ self.instance_steps = config.instance_steps
+ self.save_hyperparameters(config)
+
+ self.slot_3D = Slot3D(config)
+ self.slot_dec = JointDecoder(config)
+ self.slot_dec_fine = JointDecoder(config) if config.coarse_to_fine else None
+ self.nerf = NeRF(config, n_samples=config.n_samples)
+ self.nerf_fine = NeRF(config, n_samples=config.n_samples+config.n_samples_fine) if config.coarse_to_fine else None
+ depth_range = min(self.train_set.depth_range[0], self.val_set.depth_range[0], self.test_set.depth_range[0]), \
+ max(self.train_set.depth_range[1], self.val_set.depth_range[1], self.test_set.depth_range[1])
+ self.renderer = NeRFRenderer(depth_range, cfg=config)
+
+ self.loss = torch.nn.MSELoss(reduction='mean')
+
+ self.cfg = config
+ self.output_dir_result_images = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/images')
+ self.output_dir_result_images.mkdir(exist_ok=True)
+ self.output_dir_result_seg = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/seg')
+ self.output_dir_result_seg.mkdir(exist_ok=True)
+ self.output_dir_result_depth = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/depth')
+ self.output_dir_result_depth.mkdir(exist_ok=True)
+ self.output_dir_result_attn = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/attn')
+ self.output_dir_result_attn.mkdir(exist_ok=True)
+ self.sigma = 1.0
+
+ self.seg_metrics = SegMetrics(['ari', 'ari_fg'])
+ self.recon_metrics = ReconMetrics(config.lpips_net)
+ self.recon_rgb = config.get('recon_rgb', True)
+
+
+ def configure_optimizers(self):
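+        # Lion optimizer with per-module parameter groups: exponentially decayed LR for the
+        # NeRF/renderer (grid parameters at 20x the base LR), warmup + decay for the slot
+        # encoder/decoder.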
+ warmup_steps = self.cfg.warmup_steps
+ min_lr_factor = self.cfg.min_lr_factor
+ decay_steps = self.cfg.decay_steps
+ def lr_warmup_exp_decay(step: int):
+ factor = min(1, step / (warmup_steps + 1e-6))
+ decay_factor = 0.5 ** (step / decay_steps * 1.5)
+ return factor * decay_factor * (1 - min_lr_factor) + min_lr_factor
+ def lr_exp_decay(step: int):
+ decay_factor = 0.5 ** (step / decay_steps)
+ return decay_factor * (1 - min_lr_factor) + min_lr_factor
+ params_nerf = [{'params': self.nerf.parameters(),
+ 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}]
+ params_renderer = [{'params': self.renderer.parameters(),
+ 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}]
+ params_slot_enc = [{'params': (x[1] for x in self.slot_3D.named_parameters() if 'dino' not in x[0]),
+ 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}]
+ params_slot_dec = [{'params': self.slot_dec.parameters(),
+ 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}]
+ params = params_nerf + params_renderer + params_slot_enc + params_slot_dec
+ lr_lambda_list = [lr_exp_decay, lr_exp_decay, lr_warmup_exp_decay, lr_warmup_exp_decay]
+
+ if self.cfg.grid_init == 'tensorf' or self.cfg.grid_init == '3d':
+ params_grid = [{'params': [x[1] for x in self.renderer.named_parameters() if 'grid' in x[0]],
+ 'lr': self.cfg.lr * 20, 'weight_decay': self.cfg.weight_decay}]
+ params = params + params_grid
+ lr_lambda_list = lr_lambda_list + [lr_exp_decay]
+
+ if self.cfg.coarse_to_fine:
+ params_nerf_fine = [{'params': self.nerf_fine.parameters(),
+ 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}]
+ params_slot_dec_fine = [{'params': self.slot_dec_fine.parameters(),
+ 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}]
+ params = params + params_nerf_fine + params_slot_dec_fine
+ lr_lambda_list = lr_lambda_list + [lr_exp_decay] + [lr_warmup_exp_decay]
+
+ opt = Lion(params, weight_decay=self.cfg.weight_decay)
+ scheduler = optim.lr_scheduler.LambdaLR(optimizer=opt, lr_lambda=lr_lambda_list)
+
+ return [opt], [{"scheduler": scheduler, "interval": "step"}]
+
+ def forward(self, rays, depth_range, slots, view_feats, cam, src_rgbs, src_cams, is_train):
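+        # Render rays in chunks of cfg.chunk to bound memory; scalar outputs (losses/dists)
+        # are averaged over chunks while per-ray outputs are concatenated and flattened.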
+ B, Nr, _ = rays.shape
+ outputs = []
+ render_depth = not is_train
+ render_ins = self.global_step >= self.instance_steps or not is_train
+ render_color = self.recon_rgb or not is_train
+ for i in range(0, Nr, self.cfg.chunk):
+ outputs.append(self.renderer(self.nerf, self.nerf_fine, self.slot_dec, self.slot_dec_fine,
+ rays[:, i: i + self.cfg.chunk], depth_range, slots,
+ cam, src_rgbs, src_cams, view_feats,
+ False, is_train,
+ render_color=render_color, render_depth=render_depth,
+ render_ins=render_ins))
+ keys = outputs[0].keys()
+ out = {}
+ for k in keys:
+ if 'dist' in k or 'loss' in k:
+ out[k] = torch.stack([o[k] for o in outputs], 0).mean()
+ else:
+ out[k] = torch.cat([o[k] for o in outputs], 1).flatten(0, 1)
+ return out
+
+ def training_step(self, batch, batch_idx):
+ self.sigma = self.cosine_anneal(self.global_step, self.cfg.sigma_steps, final_value=0)
+ if self.cfg.random_proj:
+ ratio = self.cosine_anneal(self.global_step, self.cfg.random_proj_steps, start_value=0.99, final_value=0)
+ self.nerf.random_proj_ratio = 1 - ratio
+ self.log('ratio', 1-ratio)
+ self.log('sigma', self.sigma)
+ src_rgbs, src_cams = batch['src_rgbs'], batch['src_cams'] # [B, N, H, W, 3], [B, N, 34]
+ B, N_views = src_rgbs.shape[:2]
+
+ src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W]
+ slots, attn, view_feats = self.slot_3D(sigma=self.sigma, images=src_rgbs, src_cams=src_cams)
+
+ rgbs = batch['rgbs'] # [B, Br, 3]
+ rays = batch['rays'] # [B, Br, 6]
+ depth_range = batch['depth_range'] # [B, 2]
+ view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W]
+ output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, True)
+
+ loss_rgb = self.loss(output['rgb_c'], rgbs.view(-1, 3))
+ if self.cfg.coarse_to_fine:
+ loss_rgb = (loss_rgb + self.loss(output['rgb_f'], rgbs.view(-1, 3))) / 2
+ if self.cfg.normalize:
+ loss_rgb = loss_rgb / 4
+ loss = loss_rgb
+ self.log("train/loss_rgb", loss_rgb, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
+ return loss
+
+ def on_validation_epoch_start(self):
+ torch.cuda.empty_cache()
+ self.recon_metrics.set_divice(self.device)
+
+ def validation_step(self, batch, batch_idx):
+ out_put = {}
+
+ src_rgbs, src_cams = batch['src_rgbs'], batch['src_cams'] # [B, N, H, W, 3], [B, N, 34]
+
+ B, N_views, H, W, _ = src_rgbs.shape
+ src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W]
+ slots, attn, view_feats = self.slot_3D(sigma=0, images=src_rgbs, src_cams=src_cams)
+
+ # get rays from cameras
+ N = batch['cam'].shape[1]
+ rays = batch['rays'].view(1, -1, 6) # [1, Br, 6]
+ depth_range = batch['depth_range'] # [B, 2]
+ view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W]
+ output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, False)
+
+ if self.cfg.coarse_to_fine:
+ output_rgb = output['rgb_f']
+ output_instances = output['instance_f']
+ else:
+ output_rgb = output['rgb_c']
+ output_instances = output['instance_c']
+
+ shape = (N, H, W, 3)
+        rgbs = batch['rgbs'].view(N, H*W, -1) # [N, H*W, 3]
+ if self.cfg.normalize:
+ output_rgb = output_rgb * 0.5 + 0.5
+ rgbs = rgbs * 0.5 + 0.5
+ if self.cfg.dataset == 'dtu' or self.cfg.dataset == 'ibrnet':
+ recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2))
+ out_put.update(recon_metrics)
+ else:
+ rs_instances = batch['instances'].view(N, -1) # [N, H*W]
+ if self.cfg.dataset != 'scannet' and self.cfg.dataset != 'oct':
+ Nv = self.cfg.num_src_view
+ recon_metrics = self.recon_metrics(output_rgb.view(shape)[Nv:].permute(0, 3, 1, 2), rgbs.view(shape)[Nv:].permute(0, 3, 1, 2))
+ seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[:Nv], rs_instances[:Nv])
+ out_put.update(recon_metrics)
+ out_put.update(seg_metrics)
+ # src_metircs = self.recon_metrics(output_rgb.view(shape)[:Nv].permute(0, 3, 1, 2), rgbs.view(shape)[:Nv].permute(0, 3, 1, 2))
+ # for key, value in src_metircs.items():
+ # out_put['src_' + key] = value
+ nv_seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[Nv:], rs_instances[Nv:])
+ for key, value in nv_seg_metrics.items():
+ out_put['nv_' + key] = value
+ else:
+ recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2))
+ K = output_instances.shape[-1]
+ seg_metrics = self.seg_metrics(output_instances.reshape(N, -1, K), rs_instances)
+ out_put.update(recon_metrics)
+ out_put.update(seg_metrics)
+
+ return out_put
+
+ def validation_epoch_end(self, outputs):
+ keys = outputs[0].keys()
+ logs = {}
+ for k in keys:
+ v = torch.stack([x[k] for x in outputs]).mean()
+ logs['val/' + k] = v
+ self.log_dict(logs, sync_dist=True)
+ table = [keys, ]
+ table.append(tuple([logs['val/' + key] for key in table[0]]))
+ print(tabulate(table, headers='firstrow', tablefmt='fancy_grid'))
+ self.visualize(self.val_dataloader())
+
+ @rank_zero_only
+ def visualize(self, dataloader):
+ (self.output_dir_result_seg / f"{self.global_step:06d}").mkdir(exist_ok=True)
+ for batch_idx, batch in enumerate(dataloader):
+ if batch_idx in self.cfg.visualized_indices:
+ cam = batch['cam'].reshape(1, -1, 34).to(self.device) # [1, N, 34]
+ rays = batch['rays'].reshape(1, -1, 6).to(self.device) # [1, NHW, 6]
+ NHW = rays.shape[1]
+ instances = batch.get('instances', torch.zeros(NHW))
+ rgbs = batch.get('rgbs', torch.zeros([NHW, 3]))
+ depth = batch.get('depth', torch.zeros_like(instances))
+ semantics = batch.get('semantics', torch.zeros_like(instances))
+ src_rgbs, src_cams = batch['src_rgbs'].to(self.device), batch['src_cams'].to(self.device)
+ B, N_views, H, W, _ = src_rgbs.shape
+ src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W]
+ slots, attn, view_feats = self.slot_3D(sigma=self.sigma, images=src_rgbs, src_cams=src_cams)
+ view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W]
+ depth_range = batch['depth_range'].to(self.device) # [B, 2]
+ output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, False)
+ if self.cfg.coarse_to_fine:
+ output_rgb = output['rgb_f']
+ output_instances = output['instance_f']
+ output_depth = output['depth_f']
+ else:
+ output_rgb = output['rgb_c']
+ output_instances = output['instance_c']
+ output_depth = output['depth_c']
+ if self.cfg.normalize:
+ output_rgb = output_rgb * 0.5 + 0.5
+ rgbs = rgbs * 0.5 + 0.5
+ src_rgbs = src_rgbs * 0.5 + 0.5
+ images = src_rgbs[0].reshape(-1, 3, H, W).cpu() # [N, 3, H, W]
+ N = cam.shape[1]
+ Nv = self.cfg.num_src_view
+ n_pad = 4 - Nv % 4 if Nv % 4 != 0 else 0
+ for frame_id in range(N):
+ stack = visualize_panoptic_outputs(output_rgb.view(N, H*W, -1)[frame_id], output_instances.view(N, H*W, -1)[frame_id],
+ output_depth.view(N, H*W)[frame_id], rgbs.view(N, H*W, -1)[frame_id],
+ semantics.view(N, H*W)[frame_id], instances.view(N, H*W)[frame_id],
+ H, W, depth.view(N, H*W)[frame_id])
+ stack = torch.cat([images, torch.zeros(n_pad, 3, H, W), stack], dim=0)
+ if self.cfg.logger == 'wandb':
+ self.logger.log_image(key=f"images/{batch_idx:04d}_{frame_id:04d}", images=[make_grid(stack, value_range=(0, 1), nrow=4, normalize=True)])
+ save_image(stack, self.output_dir_result_images / f"{self.global_step:06d}_{batch_idx:04d}_{frame_id:04d}.jpg", value_range=(0, 1), nrow=4, normalize=True)
+ if self.cfg.dataset == 'dtu':
+ H1, W1 = 76, 100
+ else:
+ H1, W1 = H // 4, W // 4
+ attn = attn.reshape(-1, src_rgbs.shape[1], H1, W1)
+ attn = F.interpolate(attn, size=(H, W), mode='nearest')
+ attn = attn.permute(1, 0, 2, 3).unsqueeze(2).repeat(1, 1, 3, 1, 1)
+ K = attn.shape[1]
+ img = torch.cat([images.unsqueeze(1).cpu(), 1 - attn.cpu()], dim=1).reshape(-1, 3, H, W)
+ if self.cfg.logger == 'wandb':
+ self.logger.log_image(key=f"attn/{batch_idx:04d}", images=[make_grid(img, value_range=(0, 1), nrow=K+1, normalize=True)])
+
+ def on_test_epoch_start(self):
+ torch.cuda.empty_cache()
+ self.recon_metrics.set_divice(self.device)
+
+ def test_step(self, batch, batch_idx):
+ out_put = {}
+
+ src_rgbs, src_cams = batch['src_rgbs'], batch['src_cams'] # [B, N, H, W, 3], [B, N, 34]
+ B, Nv, H, W, _ = src_rgbs.shape
+ src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W]
+        slots, attn, view_feats = self.slot_3D(sigma=0, images=src_rgbs, src_cams=src_cams)
+
+ images = src_rgbs[0].reshape(-1, 3, H, W).cpu() # [N_src_view, 3, H, W]
+ H1, W1 = H // 4, W // 4
+ if self.cfg.dataset == 'dtu':
+ H1 = 76
+ attn = attn.reshape(-1, src_rgbs.shape[1], H1, W1)
+ attn = F.interpolate(attn, size=(H, W), mode='nearest')
+ attn = attn.permute(1, 0, 2, 3).unsqueeze(2).repeat(1, 1, 3, 1, 1)
+ img = torch.cat([images.unsqueeze(1).cpu(), 1 - attn.cpu()], dim=1).reshape(-1, 3, H, W)
+ img = make_grid(img, nrow=Nv+1, padding=0) # [3, (N_src_view+1)*H, W]
+ save_img_from_tensor(img.permute(1, 2, 0), self.output_dir_result_attn / f"{batch_idx:04d}_attn.png")
+
+ # get rays from cameras
+ cam = batch['cam'] # [1, N, 34]
+ N = cam.shape[1]
+ rays = batch['rays'].view(1, -1, 6) # [1, Br, 6]
+ depth_range = batch['depth_range'] # [B, 2]
+
+ view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W]
+ output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, False)
+ if self.cfg.coarse_to_fine:
+ output_rgb = output['rgb_f']
+ output_instances = output['instance_f']
+ output_depth = output['depth_f']
+ else:
+ output_rgb = output['rgb_c']
+ output_instances = output['instance_c']
+ output_depth = output['depth_c']
+
+ shape = (N, H, W, 3)
+        rgbs = batch['rgbs'].view(N, H*W, -1) # [N, H*W, 3]
+ if self.cfg.normalize:
+ output_rgb = output_rgb * 0.5 + 0.5
+ rgbs = rgbs * 0.5 + 0.5
+ if self.cfg.dataset == 'dtu':
+ recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2))
+ out_put.update(recon_metrics)
+ else:
+ rs_instances = batch['instances'].view(N, -1) # [N, H*W]
+ if self.cfg.dataset != 'scannet':
+ Nv = self.cfg.num_src_view
+ recon_metrics = self.recon_metrics(output_rgb.view(shape)[Nv:].permute(0, 3, 1, 2), rgbs.view(shape)[Nv:].permute(0, 3, 1, 2))
+ seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[:Nv], rs_instances[:Nv])
+ out_put.update(recon_metrics)
+ out_put.update(seg_metrics)
+                src_metrics = self.recon_metrics(output_rgb.view(shape)[:Nv].permute(0, 3, 1, 2), rgbs.view(shape)[:Nv].permute(0, 3, 1, 2))
+                for key, value in src_metrics.items():
+ out_put['src_' + key] = value
+ nv_seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[Nv:], rs_instances[Nv:])
+ for key, value in nv_seg_metrics.items():
+ out_put['nv_' + key] = value
+ else:
+ recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2))
+ K = output_instances.shape[-1]
+ seg_metrics = self.seg_metrics(output_instances.reshape(N, -1, K), rs_instances)
+ out_put.update(recon_metrics)
+ out_put.update(seg_metrics)
+
+ print(f'batch_idx: {batch_idx}')
+ for k, v in recon_metrics.items():
+ print(k, ': ', v.item())
+ for k, v in seg_metrics.items():
+ print(k, ': ', v.item())
+ print('-' * 40)
+ # save img
+ imgs_gt = rgbs.view(shape)
+ imgs_pred = output_rgb.view(shape)
+ # seg_gt = rs_instances.view(N, H, W)
+ seg_pred = output_instances.argmax(-1).view(N, H, W)
+ depth_pred = output_depth.view(N, H, W)
+ for n in range(N):
+ save_img_from_tensor(imgs_gt[n], self.output_dir_result_images / f"{batch_idx:04d}_{n:02d}_rgb_gt.png")
+ save_img_from_tensor(imgs_pred[n], self.output_dir_result_images / f"{batch_idx:04d}_{n:02d}_rgb_pred.png")
+ # save_seg_from_tensor(seg_gt[n], self.output_dir_result_seg / f"{batch_idx:04d}_{n:02d}_seg_gt.png")
+ save_seg_from_tensor(seg_pred[n], self.output_dir_result_seg / f"{batch_idx:04d}_{n:02d}_seg_pred.png")
+ depth = visualize_depth(depth_pred[n], use_global_norm=False) # [3, H, W]
+ save_img_from_tensor(depth.permute(1, 2, 0), self.output_dir_result_depth / f"{batch_idx:04d}_{n:02d}_depth_pred.png")
+ return out_put
+
+ def test_epoch_end(self, outputs):
+ keys = outputs[0].keys()
+ logs = {}
+ for k in keys:
+ v = torch.stack([x[k] for x in outputs]).mean()
+ logs['test/' + k] = v
+ self.log_dict(logs, sync_dist=True)
+ table = [keys, ]
+ table.append(tuple([logs['test/' + key] for key in table[0]]))
+ print(tabulate(table, headers='firstrow', tablefmt='fancy_grid'))
+ self.visualize(self.test_dataloader())
+
+ def train_dataloader(self):
+        shuffle = self.cfg.job_type != 'debug' and self.train_sampler is None
+ persistent_workers = self.cfg.num_workers > 0
+ return DataLoader(self.train_set, self.cfg.batch_size, shuffle=shuffle, pin_memory=True, num_workers=self.cfg.num_workers, sampler=self.train_sampler, persistent_workers=persistent_workers)
+
+ def val_dataloader(self):
+ return DataLoader(self.val_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers)
+
+ def test_dataloader(self):
+ return DataLoader(self.test_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers)
+
+ def on_train_epoch_start(self):
+ torch.cuda.empty_cache()
+ self.slot_dec.force_bg = self.cfg.force_bg and self.global_step < self.cfg.force_bg_steps
+ if self.cfg.coarse_to_fine:
+ print(f'Using {self.renderer.n_samples} points for coarse rendering, {self.renderer.n_samples_fine} points for fine rendering')
+ else:
+ print(f'Using {self.renderer.n_samples} points for rendering')
+
+ def cosine_anneal(self, step, final_step, start_step=0, start_value=1.0, final_value=0.1):
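+        """Cosine schedule from start_value (at start_step) to final_value (at final_step);
+        clamps to the endpoint values outside that range."""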
+ if start_value <= final_value or start_step >= final_step:
+ return final_value
+ if step < start_step:
+ value = start_value
+ elif step >= final_step:
+ value = final_value
+ else:
+ a = 0.5 * (start_value - final_value)
+ b = 0.5 * (start_value + final_value)
+ progress = (step - start_step) / (final_step - start_step)
+ value = a * math.cos(math.pi * progress) + b
+ return value
+
+
+@hydra.main(config_path='../config', config_name='config', version_base='1.2')
+def main(config):
+ config = config.cfg
+ trainer = create_trainer(config)
+ model = TensoRFTrainer(config)
+ if trainer.logger is not None and config.watch_model:
+ trainer.logger.watch(model)
+ if config.job_type == 'test':
+ ckpt = torch.load(config.ckpt_path)
+ model.load_state_dict(ckpt['state_dict'])
+ print(f'Load from checkpoint: {config.ckpt_path}')
+ trainer.test(model)
+ elif config.job_type == 'vis':
+ ckpt = torch.load(config.ckpt_path)
+ model.load_state_dict(ckpt['state_dict'])
+ print(f'Load from checkpoint: {config.ckpt_path}')
+ model.eval()
+ with torch.no_grad():
+ model.visualize(model.val_dataloader())
+ else:
+ trainer.fit(model)
+ trainer.test(model)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/trainer/visualize.py b/trainer/visualize.py
new file mode 100644
index 0000000..e87f1bb
--- /dev/null
+++ b/trainer/visualize.py
@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from tqdm import tqdm
+from pathlib import Path
+import torch
+import hydra
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from torchvision.utils import make_grid
+from torch.utils.data import DataLoader
+from dataset import get_dataset
+from model.nerf import NeRF
+from model.renderer import NeRFRenderer
+from model.slot_attn import Slot3D
+from model.slot_attn import JointDecoder
+from util.misc import visualize_depth
+from PIL import Image
+import numpy as np
+import seaborn as sns
+
+
+def segmentation_to_rgb(seg, palette=None, num_objects=None, bg_color=(0, 0, 0)):
+ seg = seg[..., None]
+ if num_objects is None:
+ num_objects = np.max(seg) # assume consecutive numbering
+ num_objects += 1 # background
+ if palette is None:
+ # palette = [bg_color] + sns.color_palette('hls', num_objects-1)
+ palette = sns.color_palette('hls', num_objects)
+
+ seg_img = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.float32)
+ for i in range(num_objects):
+ seg_img[np.nonzero(seg[:, :, 0] == i)] = palette[i]
+ return seg_img
+
+
+def save_img_from_tensor(img, path, transform=False):
+ r'''
+ img: tensor, [H, W, 3]
+ '''
+ img = img.cpu().numpy() * 255
+ img = img.astype(np.uint8)
+ # if transform:
+ # brighten the image
+ # img = cv2.convertScaleAbs(img, alpha=1.8, beta=1)
+ img = Image.fromarray(img)
+ img.save(path)
+
+
+def save_seg_from_tensor(seg, path):
+ r'''
+ seg: tensor, [H, W]
+ '''
+ seg = seg.cpu().numpy()
+ seg = segmentation_to_rgb(seg)
+ seg = (seg * 255).astype(np.uint8)
+ seg = Image.fromarray(seg)
+ seg.save(path)
+
+
+class Visualizer(pl.LightningModule):
+
+ def __init__(self, config):
+ super().__init__()
+ self.train_set, self.val_set, self.test_set = get_dataset(config)
+
+ self.slot_3D = Slot3D(config)
+        self.slot_dec = JointDecoder(config)
+        self.slot_dec_fine = JointDecoder(config) if config.coarse_to_fine else None
+ self.nerf = NeRF(config, n_samples=config.n_samples)
+ self.nerf_fine = NeRF(config, n_samples=config.n_samples+config.n_samples_fine) if config.coarse_to_fine else None
+ self.depth_range = min(self.train_set.depth_range[0], self.val_set.depth_range[0], self.test_set.depth_range[0]), \
+ max(self.train_set.depth_range[1], self.val_set.depth_range[1], self.test_set.depth_range[1])
+ self.renderer = NeRFRenderer(self.depth_range, cfg=config)
+
+ self.cfg = config
+ self.output_dir_result_images = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/images')
+ self.output_dir_result_images.mkdir(exist_ok=True)
+ self.output_dir_result_seg = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/seg')
+ self.output_dir_result_seg.mkdir(exist_ok=True)
+ self.output_dir_result_depth = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/depth')
+ self.output_dir_result_depth.mkdir(exist_ok=True)
+ self.output_dir_result_attn = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/attn')
+ self.output_dir_result_attn.mkdir(exist_ok=True)
+
+ self.scene_id = config.get('scene_id', 0)
+
+ def forward(self, rays, depth_range, slots, view_feats, cam, src_rgbs, src_cams, is_train):
+ B, Nr, _ = rays.shape
+ outputs = []
+ render_depth = True
+ render_ins = True
+ render_feat = False
+ render_color = True
+ for i in range(0, Nr, self.cfg.chunk):
+ outputs.append(self.renderer(self.nerf, self.nerf_fine, self.slot_dec, self.slot_dec_fine,
+ rays[:, i: i + self.cfg.chunk], depth_range, slots,
+ cam, src_rgbs, src_cams, view_feats,
+ False, is_train,
+ render_color=render_color, render_depth=render_depth,
+ render_sem=False, render_ins=render_ins, render_feat=render_feat))
+ keys = outputs[0].keys()
+ out = {}
+ for k in keys:
+ if 'dist' in k or 'loss' in k:
+ out[k] = torch.stack([o[k] for o in outputs], 0).mean()
+ else:
+ out[k] = torch.cat([o[k] for o in outputs], 1).flatten(0, 1)
+ return out
+
+ def visualize(self, dataloader):
+ for batch_idx, batch in tqdm(enumerate(dataloader)):
+ if batch_idx == self.scene_id or self.cfg.dataset == 'scannet' or self.cfg.dataset == 'dtu':
+ src_rgbs, src_cams = batch['src_rgbs'].to(self.device), batch['src_cams'].to(self.device)
+ B, Nv, H, W, _ = src_rgbs.shape
+ src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W]
+            slots, attn, view_feats = self.slot_3D(sigma=0, images=src_rgbs, src_cams=src_cams)
+ view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W]
+
+ images = src_rgbs[0].reshape(-1, 3, H, W).cpu() # [N_src_view, 3, H, W]
+ H1, W1 = H // 4, W // 4
+ if self.cfg.dataset == 'dtu':
+ H1 = 76
+ attn = attn.reshape(-1, src_rgbs.shape[1], H1, W1)
+ attn = F.interpolate(attn, size=(H, W), mode='nearest')
+ attn = attn.permute(1, 0, 2, 3).unsqueeze(2).repeat(1, 1, 3, 1, 1)
+ img = torch.cat([images.unsqueeze(1).cpu(), 1 - attn.cpu()], dim=1).reshape(-1, 3, H, W)
+ img = make_grid(img, nrow=img.shape[0], padding=0) # [3, (N_src_view+1)*H, W]
+ save_img_from_tensor(img.permute(1, 2, 0), self.output_dir_result_attn / f"{batch_idx:04d}_attn.png")
+ # input_img = images[0].permute(1, 2, 0) # [H, W, 3]
+ # save_img_from_tensor(input_img, self.output_dir_result_images / f"{batch_idx:04d}_input.png", True)
+ # instances = batch['instances'][0].view(H, W) # [H, W]
+ # save_seg_from_tensor(instances, self.output_dir_result_seg / f"{batch_idx:04d}_seg_gt.png")
+ if self.cfg.dataset == 'scannet':
+ gt_img = batch['rgbs'].view(H, W, 3) / 2 + 0.5 # [H, W, 3]
+ save_img_from_tensor(gt_img, self.output_dir_result_images / f"{batch_idx:04d}_rgb_gt.png", True)
+ gt_seg = batch['instances'].view(H, W) # [H, W]
+ save_seg_from_tensor(gt_seg, self.output_dir_result_seg / f"{batch_idx:04d}_seg_gt.png")
+
+ depth_range = batch['depth_range'].to(self.device) # [B, 2]
+ N = self.cfg.num_vis
+ all_rays = batch['rays'][0].to(self.device) # [N, HW, 6]
+ for n in tqdm(range(N)):
+                rays = all_rays[n:n+1]  # [1, HW, 6]
+ output = self(rays, depth_range, slots, view_feats, batch.get('azi_rot'), src_rgbs, src_cams, False)
+
+ if self.cfg.coarse_to_fine:
+ output_rgb = output['rgb_f']
+ output_instances = output['instance_f']
+ output_depth = output['depth_f']
+ else:
+ output_rgb = output['rgb_c']
+ output_instances = output['instance_c']
+ output_depth = output['depth_c']
+
+ if self.cfg.normalize:
+ output_rgb = output_rgb * 0.5 + 0.5
+ src_rgbs = src_rgbs * 0.5 + 0.5
+
+ shape = (H, W, 3)
+ imgs_pred = output_rgb.view(shape)
+ seg_pred = output_instances.argmax(-1).view(H, W)
+ depth_pred = output_depth.view(H, W)
+ save_img_from_tensor(imgs_pred, self.output_dir_result_images / f"{batch_idx:04d}_{n:02d}_rgb_pred.png", True)
+ save_seg_from_tensor(seg_pred, self.output_dir_result_seg / f"{batch_idx:04d}_{n:02d}_seg_pred.png")
+ depth = visualize_depth(depth_pred, maxval=self.depth_range[1], use_global_norm=True) # [3, H, W]
+ save_img_from_tensor(depth.permute(1, 2, 0), self.output_dir_result_depth / f"{batch_idx:04d}_{n:02d}_depth_pred.png")
+ # break
+
+ def train_dataloader(self):
+ return DataLoader(self.train_set, batch_size=1, shuffle=True, pin_memory=True, num_workers=self.cfg.num_workers)
+
+ def val_dataloader(self):
+ return DataLoader(self.val_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers)
+
+ def test_dataloader(self):
+ return DataLoader(self.test_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers)
+
+
+@hydra.main(config_path='../config', config_name='config', version_base='1.2')
+def main(config):
+ config = config.cfg
+ result_path = Path(f'{config.log_path}/{config.exp_name}')
+ result_path.mkdir(exist_ok=True)
+ model = Visualizer(config)
+ ckpt = torch.load(config.ckpt_path)
+ model.load_state_dict(ckpt['state_dict'])
+ print(f'Load from checkpoint: {config.ckpt_path}')
+ model.cuda()
+ model.eval()
+ with torch.no_grad():
+ model.visualize(model.val_dataloader())
+
+if __name__ == '__main__':
+ main()
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/util/camera.py b/util/camera.py
new file mode 100644
index 0000000..5e9850b
--- /dev/null
+++ b/util/camera.py
@@ -0,0 +1,307 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import math
+from scipy.spatial.transform import Rotation
+import numpy as np
+import torch
+from einops import repeat
+
+from util.transforms import trs_comp, dot
+
+
+def frustum_world_bounds(dims, intrinsics, cam2worlds, max_depth, form='bbox'):
+ """Compute bounds defined by the frustum provided cameras
+ Args:
+ dims (N,2): heights,widths of cameras
+ intrinsics (N,3,3): intrinsics (unnormalized, hence HW required)
+ cam2worlds (N,4,4): camera to world transformations
+ max_depth (float): depth of all frustums
+ form (str): bbox: convex bounding box, sphere: convex bounding sphere
+ """
+ # unproject corner points
+ h_img_corners = torch.Tensor([[0, 1, 1], [1, 0, 1], [1, 1, 1]])
+ intrinsics_inv = torch.linalg.inv(intrinsics[:, [1, 0, 2]]) # K in WH -> convert to HW
+ k = len(h_img_corners)
+ n = len(dims)
+ rep_HWds = repeat(torch.cat([dims, torch.ones((n, 1))], 1), "n c -> n k c", k=k)
+ skel_pts = rep_HWds * repeat(h_img_corners, "k c -> n k c", n=n) # (N,K,(hwd))
+ corners_cam_a = torch.einsum("nkij,nkj->nki", repeat(intrinsics_inv, "n x y -> n k x y", k=k), skel_pts) * max_depth
+ corners_cam_b = torch.einsum("nkij,nkj->nki", repeat(intrinsics_inv, "n x y -> n k x y", k=k), skel_pts) * 0.01
+ # nihalsid: adding corner with max depth, and corners with min depth
+ corners_cam = torch.cat([corners_cam_a, corners_cam_b], 0)
+ corners_cam_h = torch.cat([corners_cam, torch.ones(corners_cam.shape[0], corners_cam.shape[1], 1)], -1)
+
+ corners_world_h = torch.einsum("nij,nkj->nki", cam2worlds.repeat(2, 1, 1), corners_cam_h)
+ corners_world_flat = corners_world_h.reshape(-1, 4)[:, :3]
+ if form == 'bbox':
+ bounds = torch.stack([corners_world_flat.min(0).values, corners_world_flat.max(0).values])
+ return bounds
+ elif form == 'sphere':
+ corners_world_center = torch.mean(corners_world_flat, 0)
+ sphere_radius = torch.max(torch.norm((corners_world_flat - corners_world_center), dim=1))
+
+ return corners_world_center, sphere_radius
+ else:
+ raise Exception("Not implemented yet: Ellipsoid for example")
+
+
+def compute_world2normscene(dims, intrinsics, cam2worlds, max_depth, rescale_factor=1.0):
+ """Compute transform converting world to a normalized space enclosing all
+ cameras frustums (given depth) into a unit sphere
+ Note: max_depth=0 -> camera positions only are contained (like NeRF++ does it)
+
+ Args:
+ dims (N,2): heights,widths of cameras
+ intrinsics (N,3,3): intrinsics (unnormalized, hence HW required)
+ cam2worlds (N,4,4): camera to world transformations
+ max_depth (float): depth of all frustums
+ rescale_factor (float)>=1.0: factor to scale the world space even further so no camera is too close to the unit sphere surface
+ """
+ assert rescale_factor >= 1.0, "prevent cameras outside of unit sphere"
+
+ sphere_center, sphere_radius = frustum_world_bounds(dims, intrinsics, cam2worlds, max_depth, 'sphere') # sphere containing frustums
+ world2nscene = trs_comp(-sphere_center / (rescale_factor * sphere_radius), torch.eye(3), 1 / (rescale_factor * sphere_radius))
+
+ return world2nscene
+
+
+def depth_to_distance(depth, intrinsics):
+ uv = np.stack(np.meshgrid(
+ np.arange(depth.shape[1]),
+ np.arange(depth.shape[0])
+ ), -1).reshape(-1, 2)
+ depth = depth.reshape(-1)
+ uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1)
+ return depth * np.linalg.norm((np.linalg.inv(intrinsics) @ uvh.T).T, axis=1)
+
+
+def distance_to_depth(K, dist, uv=None):
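+    """Convert per-pixel ray distances to z-depth using the intrinsics K; works on both
+    NumPy arrays and torch tensors, generating pixel coordinates from the distance map
+    shape when uv is not provided."""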
+ if uv is None and len(dist.shape) >= 2:
+ # create mesh grid according to d
+ uv = np.stack(np.meshgrid(np.arange(dist.shape[1]), np.arange(dist.shape[0])), -1)
+ uv = uv.reshape(-1, 2)
+ dist = dist.reshape(-1)
+ if not isinstance(dist, np.ndarray):
+ uv = torch.from_numpy(uv).to(dist)
+ if isinstance(dist, np.ndarray):
+ # z * np.sqrt(x_temp**2+y_temp**2+z_temp**2) = dist
+ uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1)
+ temp_point = dot(np.linalg.inv(K), uvh)
+ z = dist / np.linalg.norm(temp_point, axis=1)
+
+ else:
+ uvh = torch.cat([uv, torch.ones(len(uv), 1).to(uv)], -1)
+ temp_point = dot(torch.inverse(K), uvh)
+ z = dist / torch.linalg.norm(temp_point, dim=1)
+ return z
+
+
+def unproject_2d_3d(cam2world, intrinsics, depth, dims):
+ uv = np.stack(np.meshgrid(np.arange(dims[0]), np.arange(dims[1])), -1)
+    uv = uv.reshape(-1, 2)
+    uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1)
+ cam_point = (torch.linalg.inv(intrinsics) @ torch.from_numpy(uvh).float().T).T * depth[:, None]
+ world_point = (cam2world[:3, :3] @ cam_point.T).T + cam2world[:3, 3]
+ return world_point
+
+
+def project_3d_2d(cam2world, K, world_point, with_dist=False, discrete=True, round=True):
+ if isinstance(world_point, np.ndarray):
+ cam_point = dot(np.linalg.inv(cam2world), world_point)
+ point_dist = np.sqrt((cam_point ** 2).sum(-1))
+ img_point = dot(K, cam_point)
+ uv_point = img_point[:, :2] / img_point[:, 2][:, None]
+ if discrete:
+ if round:
+ uv_point = np.round(uv_point)
+            uv_point = uv_point.astype(np.int32)  # np.int was removed in recent numpy versions
+ if with_dist:
+ return uv_point, img_point[:, 2], point_dist
+ return uv_point
+ else:
+ cam_point = dot(torch.inverse(cam2world), world_point)
+ point_dist = (cam_point ** 2).sum(-1).sqrt()
+ img_point = dot(K, cam_point)
+ uv_point = img_point[:, :2] / img_point[:, 2][:, None]
+ if discrete:
+ if round:
+ uv_point = torch.round(uv_point)
+ uv_point = uv_point.int()
+ if with_dist:
+ return uv_point, img_point[:, 2], point_dist
+
+ return uv_point
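+
+# Hedged usage sketch (added comment, not in the original code): lift pixels to
+# world space and project them back; the depth map and intrinsics below are
+# placeholders, and `dims` must satisfy dims[0] * dims[1] == len(depth).
+#   depth = torch.ones(64 * 48)
+#   K = torch.tensor([[100.0, 0.0, 32.0],
+#                     [0.0, 100.0, 24.0],
+#                     [0.0, 0.0, 1.0]])
+#   cam2world = torch.eye(4)
+#   pts_world = unproject_2d_3d(cam2world, K, depth, dims=(64, 48))
+#   uv, z, dist = project_3d_2d(cam2world, K, pts_world, with_dist=True, discrete=False)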
+
+
+def auto_orient_poses(poses, method="up"):
+ """Orients and centers the poses. We provide two methods for orientation: pca and up.
+ pca: Orient the poses so that the principal component of the points is aligned with the axes.
+ This method works well when all of the cameras are in the same plane.
+ up: Orient the poses so that the average up vector is aligned with the z axis.
+ This method works well when images are not at arbitrary angles.
+ Args:
+ poses: The poses to orient.
+ method: The method to use for orientation. Either "pca" or "up".
+ Returns:
+ The oriented poses.
+    Borrowed from nerfstudio.
+ """
+ translation = poses[..., :3, 3]
+
+ mean_translation = torch.mean(translation, dim=0)
+ translation = translation - mean_translation
+
+ if method == "pca":
+ _, eigvec = torch.linalg.eigh(translation.T @ translation)
+ eigvec = torch.flip(eigvec, dims=(-1,))
+
+ if torch.linalg.det(eigvec) < 0:
+ eigvec[:, 2] = -eigvec[:, 2]
+
+ transform = torch.cat([eigvec, eigvec @ -mean_translation[..., None]], dim=-1)
+ oriented_poses = transform @ poses
+
+ if oriented_poses.mean(axis=0)[2, 1] < 0:
+ oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3]
+ elif method == "up":
+ up = torch.mean(poses[:, :3, 1], dim=0)
+ up = up / torch.linalg.norm(up)
+
+ rotation = rotation_matrix(up, torch.Tensor([0, 0, 1]))
+ transform = torch.cat([rotation, rotation @ -mean_translation[..., None]], dim=-1)
+ oriented_poses = transform @ poses
+
+ return oriented_poses
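+
+# Illustrative sketch (added comment, not in the original code): re-orient a
+# batch of camera-to-world poses so the average up vector points along +z.
+#   poses = torch.eye(4).unsqueeze(0).repeat(4, 1, 1)      # (N, 4, 4) dummy poses
+#   oriented = auto_orient_poses(poses, method="up")        # (N, 3, 4)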
+
+
+def rotation_matrix(a, b):
+ """Compute the rotation matrix that rotates vector a to vector b.
+ Args:
+ a: The vector to rotate.
+ b: The vector to rotate to.
+ Returns:
+ The rotation matrix.
+    Borrowed from nerfstudio.
+ """
+ a = a / torch.linalg.norm(a)
+ b = b / torch.linalg.norm(b)
+ v = torch.cross(a, b)
+ c = torch.dot(a, b)
+ # If vectors are exactly opposite, we add a little noise to one of them
+ if c < -1 + 1e-8:
+ eps = (torch.rand(3) - 0.5) * 0.01
+ return rotation_matrix(a + eps, b)
+ s = torch.linalg.norm(v)
+ skew_sym_mat = torch.Tensor(
+ [
+ [0, -v[2], v[1]],
+ [v[2], 0, -v[0]],
+ [-v[1], v[0], 0],
+ ]
+ )
+ return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s ** 2 + 1e-8))
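+
+# Quick sanity check (added comment, not in the original code): the returned
+# matrix should map vector a onto vector b up to numerical error.
+#   a = torch.tensor([0.0, 1.0, 0.0])
+#   b = torch.tensor([0.0, 0.0, 1.0])
+#   R = rotation_matrix(a, b)
+#   assert torch.allclose(R @ a, b, atol=1e-5)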
+
+
+def _compute_residual_and_jacobian(x, y, xd, yd, k1=0.0, k2=0.0, k3=0.0, k4=0.0, p1=0.0, p2=0.0):
+ """Auxiliary function of radial_and_tangential_undistort()."""
+ # Adapted from https://github.com/google/nerfies/blob/main/nerfies/camera.py
+ # let r(x, y) = x^2 + y^2;
+ # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 +
+ # k4 * r(x, y)^4;
+ r = x * x + y * y
+ d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4)))
+
+ # The perfect projection is:
+ # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
+ # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
+ #
+ # Let's define
+ #
+ # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
+ # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
+ #
+ # We are looking for a solution that satisfies
+ # fx(x, y) = fy(x, y) = 0;
+ fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
+ fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
+
+ # Compute derivative of d over [x, y]
+ d_r = (k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4)))
+ d_x = 2.0 * x * d_r
+ d_y = 2.0 * y * d_r
+
+ # Compute derivative of fx over x and y.
+ fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
+ fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
+
+ # Compute derivative of fy over x and y.
+ fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
+ fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
+
+ return fx, fy, fx_x, fx_y, fy_x, fy_y
+
+
+def rotate(p, angle, o=[0, 0]):
+    '''
+    Rotate point p around point o by the given angle
+    p: (x, y)
+    o: (x, y)
+    angle: rotation angle in radians
+    '''
+ x = o[0] + math.cos(angle) * (p[0] - o[0]) - math.sin(angle) * (o[1] - p[1])
+ y = o[1] - math.sin(angle) * (p[0] - o[0]) - math.cos(angle) * (o[1] - p[1])
+ return x, y
+
+
+def rotate_cam(cam, theta, zoom_ratio=1., origin=[0, 0]):
+    r'''
+    Rotate the camera around the origin by theta (about the z-axis) and optionally rescale its translation
+    cam: (4, 4) camera-to-world matrix
+    origin: (x, y)
+    theta: rotation angle in radians, in [0, 2*pi)
+    zoom_ratio: scale factor applied to the camera translation
+    '''
+ new_cam = cam.clone().numpy()
+ R = new_cam[:3, :3] # rotation matrix
+
+ rot_z = Rotation.from_rotvec(theta * np.array([0, 0, 1])).as_matrix()
+ R = rot_z @ R
+ new_cam[:3, :3] = R
+
+ p = new_cam[:2, 3] # translation x, y
+ p = rotate(p, -theta, origin)
+ new_cam[:2, 3] = p
+ new_cam[:3, 3] = new_cam[:3, 3] * zoom_ratio
+ return torch.from_numpy(new_cam)
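+
+# Illustrative sketch (added comment, not in the original code): spin a camera a
+# quarter turn around the world z-axis, keeping its distance to the origin.
+#   cam = torch.eye(4)
+#   cam[:3, 3] = torch.tensor([1.0, 0.0, 0.5])
+#   new_cam = rotate_cam(cam, theta=np.pi / 2, zoom_ratio=1.0, origin=[0, 0])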
+
+
+def rot_matrix_angular_dist(R1, R2):
+ assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3
+ return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2.,
+ a_min=-1 + 1e-6, a_max=1 - 1e-6))
+
+def select_src_ids(tar_pose, ref_poses, num_select, dist_weight=0.5, drop_first=True):
+ num_cams = len(ref_poses)
+ num_select = min(num_select, num_cams - 1)
+ batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0)
+
+ angular_dists = rot_matrix_angular_dist(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3])
+ dists = np.linalg.norm(batched_tar_pose[:, :3, 3] - ref_poses[:, :3, 3], axis=1)
+ dists = dist_weight * dists + (1 - dist_weight) * angular_dists
+
+ sorted_ids = np.argsort(dists)
+ if drop_first:
+ selected_ids = sorted_ids[1:num_select+1]
+ else:
+ selected_ids = sorted_ids[:num_select]
+ return selected_ids
\ No newline at end of file
diff --git a/util/distinct_colors.py b/util/distinct_colors.py
new file mode 100644
index 0000000..136ffcc
--- /dev/null
+++ b/util/distinct_colors.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import torch
+import numpy as np
+
+
+class DistinctColors:
+
+ def __init__(self):
+ colors = [
+ '#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f55031', '#911eb4', '#42d4f4', '#bfef45', '#fabed4', '#469990',
+ '#dcb1ff', '#404E55', '#fffac8', '#809900', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#a9a9a9', '#f032e6',
+ '#806020', '#ffffff',
+
+ "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0030ED", "#3A2465", "#34362D", "#B4A8BD", "#0086AA",
+ "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81", "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700",
+
+ "#04F757", "#C8A1A1", "#1E6E00",
+ "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700",
+ "#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329",
+ "#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C",
+ ]
+ self.hex_colors = colors
+ # 0 = crimson / red, 1 = green, 2 = yellow, 3 = blue
+ # 4 = orange, 5 = purple, 6 = sky blue, 7 = lime green
+ self.colors = [hex_to_rgb(c) for c in colors]
+ self.color_assignments = {}
+ self.color_ctr = 0
+ self.fast_color_index = torch.from_numpy(np.array([hex_to_rgb(colors[i % len(colors)]) for i in range(8096)] + [hex_to_rgb('#000000')]))
+
+ def get_color(self, index, override_color_0=False):
+ colors = [x for x in self.hex_colors]
+ if override_color_0:
+ colors[0] = "#3f3f3f"
+ colors = [hex_to_rgb(c) for c in colors]
+ if index not in self.color_assignments:
+ self.color_assignments[index] = colors[self.color_ctr % len(self.colors)]
+ self.color_ctr += 1
+ return self.color_assignments[index]
+
+ def get_color_fast_torch(self, index):
+ return self.fast_color_index[index]
+
+ def get_color_fast_numpy(self, index, override_color_0=False):
+ index = np.array(index).astype(np.int32)
+ if override_color_0:
+ colors = [x for x in self.hex_colors]
+ colors[0] = "#3f3f3f"
+ fast_color_index = torch.from_numpy(np.array([hex_to_rgb(colors[i % len(colors)]) for i in range(8096)] + [hex_to_rgb('#000000')]))
+ return fast_color_index[index % fast_color_index.shape[0]].numpy()
+ else:
+ return self.fast_color_index[index % self.fast_color_index.shape[0]].numpy()
+
+ def apply_colors(self, arr):
+ out_arr = torch.zeros([arr.shape[0], 3])
+
+ for i in range(arr.shape[0]):
+ out_arr[i, :] = torch.tensor(self.get_color(arr[i].item()))
+ return out_arr
+
+ def apply_colors_fast_torch(self, arr):
+ return self.fast_color_index[arr % self.fast_color_index.shape[0]]
+
+ def apply_colors_fast_numpy(self, arr):
+ return self.fast_color_index.numpy()[arr % self.fast_color_index.shape[0]]
+
+
+def hex_to_rgb(x):
+ return [int(x[i:i + 2], 16) / 255 for i in (1, 3, 5)]
+
+
+def visualize_distinct_colors(num_vis=32):
+ from PIL import Image
+ dc = DistinctColors()
+    labels = np.ones((1, 64, 64)).astype(np.int32)  # np.int was removed in recent numpy versions
+ all_labels = []
+ for i in range(num_vis):
+ all_labels.append(labels * i)
+ all_labels = np.concatenate(all_labels, 0)
+ shape = all_labels.shape
+ labels_colored = dc.get_color_fast_numpy(all_labels.reshape(-1))
+ labels_colored = (labels_colored.reshape(shape[0] * shape[1], shape[2], 3) * 255).astype(np.uint8)
+ Image.fromarray(labels_colored).save("colormap.png")
+
+
+if __name__ == "__main__":
+ visualize_distinct_colors()
diff --git a/util/filesystem_logger.py b/util/filesystem_logger.py
new file mode 100644
index 0000000..adda84f
--- /dev/null
+++ b/util/filesystem_logger.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import argparse
+import shutil
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from omegaconf import OmegaConf
+from pytorch_lightning.loggers.logger import Logger
+from pytorch_lightning.loggers.logger import DummyExperiment
+from pytorch_lightning.loggers.logger import rank_zero_experiment
+
+
+class FilesystemLogger(Logger):
+
+ @property
+ def version(self) -> Union[int, str]:
+ return 0
+
+ @property
+ def name(self) -> str:
+ return "fslogger"
+
+ # noinspection PyMethodOverriding
+ def log_hyperparams(self, params: argparse.Namespace):
+ pass
+
+ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+ pass
+
+ def __init__(self, experiment_config, **_kwargs):
+ super().__init__()
+ self.experiment_config = experiment_config
+ self._experiment = None
+ # noinspection PyStatementEffect
+ self.experiment
+
+ @property
+ @rank_zero_experiment
+ def experiment(self):
+ if self._experiment is None:
+ self._experiment = DummyExperiment()
+ experiment_dir = Path(self.experiment_config["log_path"], self.experiment_config["exp_name"])
+ experiment_dir.mkdir(exist_ok=True, parents=True)
+
+ src_folders = ['config', 'data/splits', 'model', 'tests', 'trainer', 'util', 'data_processing', 'dataset']
+ sources = []
+ for src in src_folders:
+ sources.extend(list(Path(".").glob(f'{src}/**/*')))
+
+ files_to_copy = [x for x in sources if x.suffix in [".py", ".pyx", ".txt", ".so", ".pyd", ".h", ".cu", ".c", '.cpp', ".html"] and x.parts[0] != "runs" and x.parts[0] != "wandb"]
+
+ for f in files_to_copy:
+ Path(experiment_dir, "code", f).parents[0].mkdir(parents=True, exist_ok=True)
+ shutil.copyfile(f, Path(experiment_dir, "code", f))
+
+ Path(experiment_dir, "config.yaml").write_text(OmegaConf.to_yaml(self.experiment_config))
+
+ return self._experiment
diff --git a/util/metrics.py b/util/metrics.py
new file mode 100644
index 0000000..5cb807a
--- /dev/null
+++ b/util/metrics.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import torch
+import numpy as np
+from sklearn.metrics import adjusted_rand_score
+from scipy.optimize import linear_sum_assignment
+import torch.nn.functional as F
+import torch.nn as nn
+from functools import partial
+from piq import ssim
+from piq import psnr
+import lpips
+
+
+def average_ari(masks, masks_gt, fg_only=False, reduction='mean'):
+ r'''
+ Input:
+ masks: (B, K, N)
+ masks_gt: (B, N)
+ '''
+ ari = []
+ masks = masks.argmax(dim=1)
+ B = masks.shape[0]
+ for i in range(B):
+ m = masks[i].cpu().numpy()
+ m_gt = masks_gt[i].cpu().numpy()
+ if fg_only:
+ m = m[np.where(m_gt > 0)]
+ m_gt = m_gt[np.where(m_gt > 0)]
+ score = adjusted_rand_score(m, m_gt)
+ ari.append(score)
+ if reduction == 'mean':
+ return torch.Tensor(ari).mean()
+ else:
+ return torch.Tensor(ari)
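+
+# Illustrative sketch (added comment, not in the original code): ARI for a toy
+# image of N = 4 pixels and K = 2 slots; a segmentation that matches the ground
+# truth up to a label permutation scores 1.0.
+#   pred = torch.tensor([[[0.9, 0.9, 0.1, 0.1],
+#                         [0.1, 0.1, 0.9, 0.9]]])          # (B=1, K=2, N=4)
+#   gt = torch.tensor([[1, 1, 2, 2]])                      # (B=1, N=4)
+#   average_ari(pred, gt)                                  # tensor(1.)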
+
+
+def imask2bmask(imasks, ignore_index=None):
+    r"""Convert index masks to binary masks.
+    Args:
+        imasks: index masks, shape (B, N)
+        ignore_index: class index to drop (e.g. 0 for background)
+    Returns:
+        bmasks: a list of B binary masks, each of shape (K, N)
+    """
+ B, N = imasks.shape
+ bmasks = []
+ for i in range(B):
+ imask = imasks[i:i+1] # (1, N)
+ classes = imask.unique().tolist()
+ if ignore_index in classes:
+ classes.remove(ignore_index)
+ bmask = [imask == c for c in classes]
+ bmask = torch.cat(bmask, dim=0) # (K, N)
+ bmasks.append(bmask.float())
+ # can't use torch.stack because of different K
+ return bmasks
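+
+# Illustrative sketch (added comment, not in the original code): converting an
+# index mask with classes {0, 1, 2} into per-class binary masks, optionally
+# dropping the background class 0.
+#   imasks = torch.tensor([[0, 0, 1, 2]])                  # (B=1, N=4)
+#   imask2bmask(imasks)                                    # [tensor of shape (3, 4)]
+#   imask2bmask(imasks, ignore_index=0)                    # [tensor of shape (2, 4)]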
+
+
+def mean_best_overlap(masks, masks_gt, fg_only=False, reduction='mean'):
+ r"""Compute the best overlap between predicted and ground truth masks.
+ Args:
+ masks: predicted masks, shape (B, K, N), binary N = H*W
+ masks_gt: ground truth masks, shape (B, N), index
+ """
+ B = masks.shape[0]
+ ignore_index = None
+ if fg_only:
+ ignore_index = 0
+ bmasks_gt = imask2bmask(masks_gt, ignore_index=ignore_index) # a list of (K, N), len = B
+ mean_best_overlap = []
+ mOR = []
+ for i in range(B):
+ mask = masks[i].unsqueeze(0) > 0.5 # (1, K, N)
+ mask_gt = bmasks_gt[i].unsqueeze(1) > 0.5 # (K_gt, 1, N)
+ # Compute IOU
+ eps = 1e-8
+ intersection = (mask * mask_gt).sum(-1)
+ union = (mask + mask_gt).sum(-1)
+ iou = intersection / (union + eps) # (K_gt, K)
+ # Compute best overlap
+ best_overlap, _ = torch.max(iou, dim=1)
+ # Compute mean best overlap
+ mean_best_overlap.append(best_overlap.mean())
+ mOR.append((best_overlap > 0.5).float().mean())
+ if reduction == 'mean':
+ return torch.stack(mean_best_overlap).mean()
+ # , torch.stack(mOR).mean()
+ else:
+ return torch.stack(mean_best_overlap)
+ # , torch.stack(mOR)
+
+
+def iou_loss(pred, target):
+ """
+ Compute the iou loss: 1 - iou
+ pred: [K, N]
+ targets: [Kt, N]
+ """
+ eps = 1e-8
+ pred = pred > 0.5 # [K, N]
+ target = target > 0.5
+ intersection = (pred[:, None] & target[None]).sum(-1).float() # [K, Kt]
+ union = (pred[:, None] | target[None]).sum(-1).float() + eps # [K, Kt]
+ loss = 1 - (intersection / union) # [K, Kt]
+ return loss # [K, Kt]
+iou_loss_jit = torch.jit.script(iou_loss)
+
+
+class Matcher():
+ @torch.no_grad()
+ def forward(self, pred, target):
+ r"""
+ pred: [K, N]
+ targets: [Kt, N]
+ """
+ loss = iou_loss_jit(pred, target)
+ row_ind, col_ind = linear_sum_assignment(loss.cpu().numpy())
+ return torch.as_tensor(row_ind, dtype=torch.int64), torch.as_tensor(col_ind, dtype=torch.int64)
+
+ @torch.no_grad()
+ def batch_forward(self, pred, targets):
+ """
+ pred: [B, K, N]
+ targets: list of B x [Kt, N] Kt can be different for each target
+ """
+ indices = []
+ for i in range(pred.shape[0]):
+ indices.append(self.forward(pred[i], targets[i]))
+ return indices
+
+
+@torch.no_grad()
+def compute_iou(pred, target):
+ """
+ Input:
+ x: [K, N]
+ y: [K, N]
+ Return:
+ iou: [K, N]
+ """
+ eps = 1e-8
+ pred = pred > 0.5 # [K, N]
+ target = target > 0.5
+ intersection = (pred & target).sum(-1).float()
+ union = (pred | target).sum(-1).float() + eps # [K]
+ return (intersection / union).mean()
+compute_iou_jit = torch.jit.script(compute_iou)
+
+
+def matchedIoU(preds, targets, matcher, fg_only=False, reduction="mean"):
+ r"""
+ Input:
+ pred: [B, K, N]
+ targets: [B, N]
+ Return:
+ IoU: [1] or [B]
+ """
+ if preds.dim() == 2: # [K, N]
+ preds = preds.unsqueeze(0)
+ targets = targets.unsqueeze(0)
+
+ ious = []
+ B = preds.shape[0]
+ ignore_index = None
+ if fg_only:
+ ignore_index = 0
+ targets = imask2bmask(targets, ignore_index) # a list of [K1, N], len = B
+ for i in range(B):
+ tgt = targets[i]
+ pred = preds[i] # [K, N]
+ src_idx, tgt_idx = matcher.forward(pred, tgt)
+ src_pred = pred[src_idx] # [K1, N]
+ tgt_mask = tgt[tgt_idx] # [K1, N]
+ ious.append(compute_iou_jit(src_pred, tgt_mask))
+ ious = torch.stack(ious)
+ if reduction == "mean":
+ return ious.mean()
+ else:
+ return ious
+
+
+matcher = Matcher()
+SEGMETRICS = {
+ "hiou": partial(matchedIoU, matcher=matcher), # hungarian matched iou
+ "hiou_fg": partial(matchedIoU, fg_only=True, matcher=matcher),
+ "mbo": mean_best_overlap, # mean best overlap
+ "mbo_fg": partial(mean_best_overlap, fg_only=True),
+ "ari": average_ari,
+ "ari_fg": partial(average_ari, fg_only=True),
+}
+
+
+class SegMetrics(nn.Module):
+ def __init__(self, metrics=["hiou", "ari", "ari_fg"]):
+ super().__init__()
+ self.metrics = {}
+ for m in metrics:
+ self.metrics[m] = SEGMETRICS[m]
+
+ def forward(self, preds, targets):
+ r"""
+ Input:
+ preds: [B, N, K]
+ targets: [B, N]
+ Return:
+ metrics: dict of metrics
+ """
+ metrics = {}
+ valid = targets.sum(-1) > 0
+ preds = preds[valid]
+ targets = targets[valid]
+ preds = F.one_hot(preds.argmax(dim=-1), num_classes=preds.shape[-1]).permute(0, 2, 1).float() # [B, K, N]
+ for k, v in self.metrics.items():
+ if valid.sum() > 0:
+ metrics[k] = v(preds, targets)
+ else:
+ metrics[k] = torch.tensor(1).to(preds.device)
+ return metrics
+
+ def compute(self, preds, targets, metric='hiou'):
+ return self.metrics[metric](preds, targets)
+
+ def metrics_name(self):
+ return list(self.metrics.keys())
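+
+# Illustrative usage sketch (added comment, not in the original code); the
+# tensors below are random placeholders:
+#   seg_metrics = SegMetrics(metrics=["hiou", "ari", "ari_fg"])
+#   preds = torch.rand(2, 64 * 64, 7)                      # (B, N, K) slot probabilities
+#   targets = torch.randint(0, 5, (2, 64 * 64))            # (B, N) index masks
+#   scores = seg_metrics(preds, targets)                   # {"hiou": ..., "ari": ..., "ari_fg": ...}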
+
+
+class ReconMetrics(nn.Module):
+ def __init__(self, lpips_net='vgg'):
+ super().__init__()
+ self.metrics = {
+ "ssim": ssim,
+ "psnr": psnr,
+ "lpips": lpips.LPIPS(net=lpips_net),
+ }
+
+ def forward(self, preds, targets):
+ r"""
+ Input:
+ preds: [B, C, H, W]
+ targets: [B, C, H, W]
+ Return:
+ metrics: dict of metrics
+ """
+ metrics = {}
+ for k, v in self.metrics.items():
+ metrics[k] = v(preds, targets).mean()
+ return metrics
+
+ def compute(self, preds, targets, metric='psnr'):
+ return self.metrics[metric](preds, targets)
+
+ def metrics_name(self):
+ return list(self.metrics.keys())
+
+ def set_divice(self, device):
+ self.metrics["lpips"].to(device)
\ No newline at end of file
diff --git a/util/misc.py b/util/misc.py
new file mode 100644
index 0000000..6ad64b4
--- /dev/null
+++ b/util/misc.py
@@ -0,0 +1,364 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from collections import OrderedDict
+import os
+import sys
+root_path = os.path.abspath(__file__)
+root_path = '/'.join(root_path.split('/')[:-2])
+sys.path.append(root_path)
+import torch
+import math
+import numpy as np
+from pathlib import Path
+from matplotlib import cm
+from PIL import Image
+import torchvision.transforms as T
+
+
+def visualize_depth(depth, minval=0.001, maxval=1.5, use_global_norm=True):
+ x = depth
+ if isinstance(depth, torch.Tensor):
+ x = depth.cpu().numpy()
+ x = np.nan_to_num(x) # change nan to 0
+ if use_global_norm:
+ mi = minval
+ ma = maxval
+ else:
+ mi = np.min(x) # get minimum depth
+ ma = np.max(x)
+ x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
+ x_ = Image.fromarray((cm.get_cmap('jet')(x) * 255).astype(np.uint8))
+ x_ = T.ToTensor()(x_)[:3, :, :]
+ return x_
+
+
+def bounds(x):
+ lower = []
+ upper = []
+ for i in range(x.shape[1]):
+ lower.append(x[:, i].min())
+ upper.append(x[:, i].max())
+ return torch.tensor([lower, upper])
+
+
+def visualize_points(points, vis_path, colors=None):
+ if colors is None:
+ Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} 127 127 127" for p in points))
+ else:
+ Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}" for i, p in enumerate(points)))
+
+
+def visualize_points_as_pts(points, vis_path, colors=None):
+ if colors is None:
+ Path(vis_path).write_text("\n".join([f'{points.shape[0]}'] + [f"{p[0]} {p[1]} {p[2]} 255 127 127 127" for p in points]))
+ else:
+ Path(vis_path).write_text("\n".join([f'{points.shape[0]}'] + [f"{p[0]} {p[1]} {p[2]} 255 {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}" for i, p in enumerate(points)]))
+
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(EasyDict(mod=mod, outputs=outputs))
+
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen if t.requires_grad]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+        output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ self[name] = value
+
+ def __delattr__(self, name):
+ del self[name]
+
+
+def to_point_list(s):
+ return np.concatenate([c[:, np.newaxis] for c in np.where(s)], axis=1)
+
+
+def get_parameters_from_state_dict(state_dict, filter_key):
+ new_state_dict = OrderedDict()
+ for k in state_dict:
+ if k.startswith(filter_key):
+ new_state_dict[k.replace(filter_key + '.', '')] = state_dict[k]
+ return new_state_dict
+
+
+def logistic(n, zero_at):
+ return 1 - 1 / (1 + math.exp(-10 * (n / zero_at - 0.5)))
+
+
+def visualize_voxel_grid(output_path, voxel_grid, scale_to=(-1, 1)):
+ voxel_grid = ((voxel_grid - voxel_grid.min()) / (voxel_grid.max() - voxel_grid.min())).cpu()
+    points = list(torch.where(voxel_grid > 0))
+    rescale = lambda axis: scale_to[0] + (points[axis] / voxel_grid.shape[axis]) * (scale_to[1] - scale_to[0])
+    if len(points[0]) > 0:
+ colors = cm.get_cmap('jet')(voxel_grid.numpy())
+ colors = colors[points[0].numpy(), points[1].numpy(), points[2].numpy(), :]
+ points[0] = rescale(0)
+ points[1] = rescale(1)
+ points[2] = rescale(2)
+ Path(output_path).write_text("\n".join([f'v {points[0][i]} {points[1][i]} {points[2][i]} {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}' for i in range(points[0].shape[0])]))
+ else:
+ Path(output_path).write_text("")
+ print("no points found..")
+
+
+def visualize_labeled_points(locations, labels, output_path):
+ from util.distinct_colors import DistinctColors
+ distinct_colors = DistinctColors()
+ if isinstance(labels, torch.Tensor):
+ colored_arr = distinct_colors.get_color_fast_torch(labels.flatten().cpu().numpy().tolist()).reshape(list(labels.shape) + [3]).numpy()
+ else:
+ colored_arr = distinct_colors.get_color_fast_numpy(labels.flatten().tolist()).reshape(list(labels.shape) + [3])
+ visualize_points(locations, output_path, colored_arr)
+
+
+def visualize_weighted_points(output_path, xyz, weights, threshold=1e-4):
+ weights = weights.view(-1)
+ weights_mask = weights > threshold
+ colors = cm.get_cmap('jet')(weights[weights_mask].numpy())
+ visualize_points(xyz[weights_mask, :].numpy(), output_path, colors=colors)
+
+
+def visualize_mask(arr, path):
+ from util.distinct_colors import DistinctColors
+ distinct_colors = DistinctColors()
+ assert len(arr.shape) == 2, "should be an HxW array"
+ boundaries = get_boundary_mask(arr)
+ if isinstance(arr, torch.Tensor):
+ colored_arr = distinct_colors.get_color_fast_torch(arr.flatten().cpu().numpy().tolist()).reshape(list(arr.shape) + [3]).numpy()
+ else:
+ colored_arr = distinct_colors.get_color_fast_numpy(arr.flatten().tolist()).reshape(list(arr.shape) + [3])
+ colored_arr = (colored_arr * 255).astype(np.uint8)
+ colored_arr[boundaries > 0, :] = 0
+ Image.fromarray(colored_arr).save(path)
+
+
+def probability_to_normalized_entropy(probabilities):
+ entropy = torch.zeros_like(probabilities[:, 0])
+ for i in range(probabilities.shape[1]):
+ entropy = entropy - probabilities[:, i] * torch.log2(probabilities[:, i] + 1e-8)
+ entropy = entropy / math.log2(probabilities.shape[1])
+ return entropy
+
+
+def get_boundary_mask(arr, dilation_size=1):
+ import cv2
+ arr_t, arr_r, arr_b, arr_l = arr[1:, :], arr[:, 1:], arr[:-1, :], arr[:, :-1]
+ arr_t_1, arr_r_1, arr_b_1, arr_l_1 = arr[2:, :], arr[:, 2:], arr[:-2, :], arr[:, :-2]
+    kernel = np.ones((dilation_size, dilation_size), 'uint8')
+ if isinstance(arr, torch.Tensor):
+ arr_t = torch.cat([arr_t, arr[-1, :].unsqueeze(0)], dim=0)
+ arr_r = torch.cat([arr_r, arr[:, -1].unsqueeze(1)], dim=1)
+ arr_b = torch.cat([arr[0, :].unsqueeze(0), arr_b], dim=0)
+ arr_l = torch.cat([arr[:, 0].unsqueeze(1), arr_l], dim=1)
+
+ arr_t_1 = torch.cat([arr_t_1, arr[-2, :].unsqueeze(0), arr[-1, :].unsqueeze(0)], dim=0)
+ arr_r_1 = torch.cat([arr_r_1, arr[:, -2].unsqueeze(1), arr[:, -1].unsqueeze(1)], dim=1)
+ arr_b_1 = torch.cat([arr[0, :].unsqueeze(0), arr[1, :].unsqueeze(0), arr_b_1], dim=0)
+ arr_l_1 = torch.cat([arr[:, 0].unsqueeze(1), arr[:, 1].unsqueeze(1), arr_l_1], dim=1)
+
+ boundaries = torch.logical_or(torch.logical_or(torch.logical_or(torch.logical_and(arr_t != arr, arr_t_1 != arr), torch.logical_and(arr_r != arr, arr_r_1 != arr)), torch.logical_and(arr_b != arr, arr_b_1 != arr)), torch.logical_and(arr_l != arr, arr_l_1 != arr))
+
+ boundaries = boundaries.cpu().numpy().astype(np.uint8)
+ boundaries = cv2.dilate(boundaries, kernel, iterations=1)
+ boundaries = torch.from_numpy(boundaries).to(arr.device)
+ else:
+ arr_t = np.concatenate([arr_t, arr[-1, :][np.newaxis, :]], axis=0)
+ arr_r = np.concatenate([arr_r, arr[:, -1][:, np.newaxis]], axis=1)
+ arr_b = np.concatenate([arr[0, :][np.newaxis, :], arr_b], axis=0)
+ arr_l = np.concatenate([arr[:, 0][:, np.newaxis], arr_l], axis=1)
+
+ arr_t_1 = np.concatenate([arr_t_1, arr[-2, :][np.newaxis, :], arr[-1, :][np.newaxis, :]], axis=0)
+ arr_r_1 = np.concatenate([arr_r_1, arr[:, -2][:, np.newaxis], arr[:, -1][:, np.newaxis]], axis=1)
+ arr_b_1 = np.concatenate([arr[0, :][np.newaxis, :], arr[1, :][np.newaxis, :], arr_b_1], axis=0)
+ arr_l_1 = np.concatenate([arr[:, 0][:, np.newaxis], arr[:, 1][:, np.newaxis], arr_l_1], axis=1)
+
+ boundaries = np.logical_or(np.logical_or(np.logical_or(np.logical_and(arr_t != arr, arr_t_1 != arr), np.logical_and(arr_r != arr, arr_r_1 != arr)), np.logical_and(arr_b != arr, arr_b_1 != arr)), np.logical_and(arr_l != arr, arr_l_1 != arr)).astype(np.uint8)
+ boundaries = cv2.dilate(boundaries, kernel, iterations=1)
+
+ return boundaries
+
+
+def pixelid2patchid(pixelid, H, W, H1, W1):
+ """
+ Input:
+ pixelid: from 0 to H*W-1
+ Return:
+ patchid: from 0 to H1*W1-1
+ """
+ x, y = pixelid % W, pixelid // W
+ x1, y1 = x // (W // W1), y // (H // H1)
+ patchid = y1 * W1 + x1
+ return patchid
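+
+# Worked example (added comment, not in the original code): for a 4x4 image
+# split into a 2x2 grid of patches, pixel 6 lies at (x=2, y=1) and therefore
+# falls into the top-right patch:
+#   pixelid2patchid(6, H=4, W=4, H1=2, W1=2)               # -> 1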
+
+
+def patchify(x, p):
+ r"""
+ Input:
+ x: (H, W, D)
+ p: patch size
+ Return:
+ x: (hw, p*p, D)
+ """
+ H, W, D = x.shape
+ h, w = H // p, W // p
+ x = x.reshape(h, p, w, p, D)
+ x = torch.einsum('hpwqc->hwpqc', x)
+ x = x.reshape(h*w, p*p, D)
+ return x
+
+
+def batch_patchify(x, p):
+ r"""
+ Input:
+ x: (B, H, W, D)
+ p: patch size
+ Return:
+ x: (B, hw, p*p, D)
+ """
+ B, H, W, D = x.shape
+ h, w = H // p, W // p
+ x = x.reshape(B, h, p, w, p, D)
+ x = torch.einsum('bhpwqc->bhwpqc', x)
+ x = x.reshape(B, h*w, p*p, D)
+ return x
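+
+# Illustrative shape check (added comment, not in the original code):
+#   x = torch.randn(8, 12, 3)                              # (H, W, D)
+#   patchify(x, p=4).shape                                 # (hw, p*p, D) = (6, 16, 3)
+#   xb = torch.randn(2, 8, 12, 3)                          # (B, H, W, D)
+#   batch_patchify(xb, p=4).shape                          # (B, hw, p*p, D) = (2, 6, 16, 3)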
+
+
+class SubSampler():
+ def idx_subsample(self, img_size, Br, mode='random'):
+ r'''
+ Input:
+ img_size: H,W
+ Br: number of rays to sample
+ Return:
+ subsample_idx: [Br, 1]
+ '''
+ if mode == 'uniform':
+ b = int(Br ** 0.5)
+ assert b * b == Br, "Br must be a square number"
+ H, W = img_size
+ Nr = H * W
+ s_x, s_y = W // b, H // b
+ x = torch.randint(0, s_x, size=(1,)).item()
+ y = torch.randint(0, s_y, size=(1,)).item()
+ subsample_idx = torch.arange(Nr).reshape(H, W)[y:, x:][::s_y, ::s_x]
+ subsample_idx = subsample_idx.reshape(-1, 1)
+ Br1 = subsample_idx.shape[0]
+            if Br1 < Br:  # H or W is not divisible by b; pad with random rays
+ subsample_idx = torch.cat([subsample_idx,
+ torch.randint(0, Nr, [Br-Br1, 1])], dim=0)
+ else:
+ subsample_idx = subsample_idx[torch.randperm(Br1)[:Br]]
+ return subsample_idx
+ else:
+ H, W = img_size
+ Nr = H * W
+ subsample_idx = torch.randperm(Nr)[:Br].reshape(-1, 1)
+ return subsample_idx
+
+ def idx_subsample_patch(self, img_size, patch_size, s=1):
+ r'''
+ Input:
+ img_size: H,W
+ patch_size: P
+ s: stride for subsampling
+ Return:
+ subsample_idx: [Br, 1]
+ '''
+ H, W = img_size
+ Nr = H * W
+ P = patch_size
+ x = torch.randint(0, W//P, size=(1,)).item() * P + torch.randint(0, s, size=(1,)).item()
+ y = torch.randint(0, H//P, size=(1,)).item() * P + torch.randint(0, s, size=(1,)).item()
+ subsample_idx = torch.arange(Nr).view(H, W)[y:y+P, x:x+P][::s, ::s]
+ subsample_idx = subsample_idx.reshape(P//s*P//s, 1)
+ return subsample_idx
+
+ def subsample(self, idx, x_tuple):
+ r'''
+ Input:
+ idx: [Br, 1]
+ x_tuple: a tuple of tensors to subsample, [HW, C]
+ Return:
+ x_tuple: a tuple of tensors [Br, C]
+ '''
+ ret = []
+ for x in x_tuple:
+ ret.append(x.gather(0, idx.expand(-1, x.shape[1])))
+ return tuple(ret)
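+
+# Illustrative usage sketch (added comment, not in the original code); `rgb` and
+# `rays` stand for per-pixel tensors of shape (H*W, C):
+#   sampler = SubSampler()
+#   idx = sampler.idx_subsample((128, 128), Br=1024, mode='random')   # (1024, 1)
+#   rgb_s, rays_s = sampler.subsample(idx, (rgb, rays))               # each (1024, C)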
+
diff --git a/util/optimizer.py b/util/optimizer.py
new file mode 100644
index 0000000..cd6064c
--- /dev/null
+++ b/util/optimizer.py
@@ -0,0 +1,84 @@
+# Copyright 2023 Google Research. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""PyTorch implementation of the Lion optimizer."""
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Lion(Optimizer):
+ r"""Implements Lion algorithm."""
+
+ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
+ """Initialize the hyperparameters.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-4)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.99))
+ weight_decay (float, optional): weight decay coefficient (default: 0)
+ """
+
+ if not 0.0 <= lr:
+ raise ValueError('Invalid learning rate: {}'.format(lr))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+
+ Returns:
+ the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ # Perform stepweight decay
+ p.data.mul_(1 - group['lr'] * group['weight_decay'])
+
+ grad = p.grad
+ state = self.state[p]
+ # State initialization
+ if len(state) == 0:
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+
+ exp_avg = state['exp_avg']
+ beta1, beta2 = group['betas']
+
+ # Weight update
+ update = exp_avg * beta1 + grad * (1 - beta1)
+ p.add_(torch.sign(update), alpha=-group['lr'])
+ # Decay the momentum running average coefficient
+ exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
+
+ return loss
\ No newline at end of file
diff --git a/util/ray.py b/util/ray.py
new file mode 100644
index 0000000..3000388
--- /dev/null
+++ b/util/ray.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import torch
+
+
+def create_grid(H, W, render_stride=1):
+ xs = torch.linspace(0, W - 1, W)[::render_stride]
+ ys = torch.linspace(0, H - 1, H)[::render_stride]
+ i, j = torch.meshgrid(xs, ys, indexing='ij')
+ i, j = i + render_stride // 2, j + render_stride // 2
+ return i.t(), j.t()
+
+
+def get_ray_directions_with_intrinsics(H, W, intrinsics, render_stride=1):
+ B = intrinsics.shape[0]
+ h, w = H // render_stride, W // render_stride
+ i, j = create_grid(H, W, render_stride=render_stride)
+ i, j = i.to(intrinsics.device), j.to(intrinsics.device)
+ i, j = i[None, ...], j[None, ...]
+ fx, fy, cx, cy = intrinsics[:, 0:1, 0:1], intrinsics[:, 1:2, 1:2], intrinsics[:, 0:1, 2:3], intrinsics[:, 1:2, 2:3]
+ directions = torch.stack([
+ (i - cx) / fx, (j - cy) / fy, torch.ones([B, h, w], device=intrinsics.device)
+ ], -1)
+ return directions
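+
+# Illustrative sketch (added comment, not in the original code): per-pixel ray
+# directions in camera space for a batch of pinhole intrinsics.
+#   K = torch.tensor([[[100.0, 0.0, 32.0],
+#                      [0.0, 100.0, 24.0],
+#                      [0.0, 0.0, 1.0]]])                  # (B=1, 3, 3)
+#   dirs = get_ray_directions_with_intrinsics(48, 64, K)   # (1, 48, 64, 3), z component is 1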
+
+
+def get_rays(cameras, H, W, render_stride=1):
+    # cameras: (B, 34) = 2 (H, W) + 16 (intrinsics, flattened 4x4) + 16 (flattened cam2world)
+ h, w = H // render_stride, W // render_stride
+ rays_o, rays_d = get_rays_origin_and_direction(cameras, H, W, render_stride) # (B, 1, 3), (B, H*W, 3)
+ rays = torch.cat([
+ rays_o.expand(-1, h*w, -1), rays_d
+ ], -1) # (B, H*W, 6)
+ return rays
+
+
+def get_rays_origin_and_direction(cameras, H, W, render_stride=1):
+ B = cameras.shape[0]
+ h, w = H // render_stride, W // render_stride
+ intrinsics, cam2world = cameras[:, 2:18].reshape(-1, 4, 4)[:, :3, :3], cameras[:, -16:].reshape(-1, 4, 4)
+ directions = get_ray_directions_with_intrinsics(H, W, intrinsics, render_stride)
+ # directions: (B, H, W, 3), cam2world: (B, 4, 4)
+ rays_d = torch.matmul(directions.view(B, h*w, 3), cam2world[:, :3, :3].transpose(1, 2)) # (B, H*W, 3)
+ rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
+ rays_o = cam2world[:, :3, 3]
+
+ rays_d = rays_d.view(B, h*w, 3)
+ rays_o = rays_o.view(B, 1, 3)
+
+ return rays_o, rays_d
+
+
+def rays_intersect_sphere(rays_o, rays_d, r=1):
+ """
+ Solve for t such that a=ro+trd with ||a||=r
+ Quad -> r^2 = ||ro||^2 + 2t (ro.rd) + t^2||rd||^2
+ -> t = (-b +- sqrt(b^2 - 4ac))/(2a) with
+ a = ||rd||^2
+ b = 2(ro.rd)
+ c = ||ro||^2 - r^2
+    => (forward intersection) t = (sqrt(D) - (ro.rd)) / ||rd||^2
+    with D = (ro.rd)^2 + (r^2 - ||ro||^2) ||rd||^2
+ """
+ odotd = torch.sum(rays_o * rays_d, -1)
+ d_norm_sq = torch.sum(rays_d ** 2, -1)
+ o_norm_sq = torch.sum(rays_o ** 2, -1)
+ determinant = odotd ** 2 + (r ** 2 - o_norm_sq) * d_norm_sq
+ assert torch.all(
+ determinant >= 0
+ ), "Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!"
+ return (torch.sqrt(determinant) - odotd) / d_norm_sq
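+
+# Quick sanity check (added comment, not in the original code): a ray starting
+# at the origin and pointing along +x hits the unit sphere at t = 1.
+#   rays_o = torch.zeros(1, 3)
+#   rays_d = torch.tensor([[1.0, 0.0, 0.0]])
+#   rays_intersect_sphere(rays_o, rays_d, r=1)             # tensor([1.])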
diff --git a/util/transforms.py b/util/transforms.py
new file mode 100644
index 0000000..ea47bdb
--- /dev/null
+++ b/util/transforms.py
@@ -0,0 +1,220 @@
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import numpy as np
+import torch
+from transforms3d.euler import euler2mat
+from transforms3d.axangles import axangle2mat
+from transforms3d.quaternions import quat2mat
+
+
+def has_torch(*args):
+ return any([isinstance(x, torch.Tensor) for x in args])
+
+
+def dot(transform, points, coords=False):
+ if isinstance(points, torch.Tensor):
+ return dot_torch(transform, points, coords)
+ else:
+ if isinstance(transform, torch.Tensor): # points dominate
+ transform = transform.cpu().numpy()
+ if type(points) == list:
+ points = np.array(points)
+
+ if len(points.shape) == 1:
+ # single point
+ if transform.shape == (3, 3):
+ return transform @ points[:3]
+ else:
+ return (transform @ np.array([*points[:3], 1]))[:3]
+ if points.shape[1] == 3 or (coords and points.shape[1] > 3):
+ # nx[xyz,...]
+ if transform.shape == (4, 4):
+ pts = (transform[:3, :3] @ points[:, :3].T).T + transform[:3, 3]
+ elif transform.shape == (3, 3):
+ pts = (transform[:3, :3] @ points[:, :3].T).T
+ else:
+ raise RuntimeError("Format of transform not understood")
+ return np.concatenate([pts, points[:, 3:]], 1)
+ else:
+ raise RuntimeError(f"Format of points {points.shape} not understood")
+
+
+def dot_torch(transform, points, coords=False):
+ if not isinstance(transform, torch.Tensor):
+ transform = torch.from_numpy(transform).float()
+
+ transform = transform.to(points.device).float()
+ if type(points) == list:
+ points = torch.Tensor(points)
+ if len(points.shape) == 1:
+ # single point
+ if transform.shape == (3, 3):
+ return transform @ points[:3]
+ else:
+ return (transform @ torch.Tensor([*points[:3], 1]))[:3]
+ if points.shape[1] == 3 or (coords and points.shape[1] > 3):
+ # nx[xyz,...]
+ if transform.shape == (4, 4):
+ pts = (transform[:3, :3] @ points[:, :3].T).T + transform[:3, 3]
+ elif transform.shape == (3, 3):
+ pts = (transform[:3, :3] @ points[:, :3].T).T
+ else:
+ raise RuntimeError("Format of transform not understood")
+ return torch.cat([pts, points[:, 3:]], 1)
+ else:
+ raise RuntimeError(f"Format of points {points.shape} not understood")
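+
+# Illustrative sketch (added comment, not in the original code): `dot` applies a
+# 3x3 or 4x4 transform to a single point or an (N, 3+) array, dispatching to
+# numpy or torch based on the type of `points`.
+#   T = np.eye(4); T[:3, 3] = [1.0, 2.0, 3.0]
+#   dot(T, np.zeros((5, 3)))                               # (5, 3), every row is [1, 2, 3]
+#   dot(T, torch.zeros(5, 3))                              # same values, returned as a torch tensor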
+
+
+def dot2d(transform, points):
+ if type(points) == list:
+ points = np.array(points)
+
+ if len(points.shape) == 1:
+ # single point
+ if transform.shape == (2, 2):
+ return transform @ points[:2]
+ else:
+ return (transform @ np.array([*points[:2], 1]))[:2]
+ elif len(points.shape) == 2:
+ if points.shape[1] in [2, 3]:
+ # needs to be transposed for dot product
+ points = points.T
+ else:
+ raise RuntimeError("Format of points not understood")
+ # points in format [2/3,n]
+ if transform.shape == (3, 3):
+ return (transform[:2, :2] @ points[:2]).T + transform[:2, 2]
+ elif transform.shape == (2, 2):
+ return (transform[:2, :2] @ points[:2]).T
+ else:
+ raise RuntimeError("Format of transform not understood")
+
+
+def backproject(depth, intrinsics, cam2world=np.eye(4), color=None):
+    # depth: (H, W); color (optional): (H, W, 3)
+ h, w = depth.shape
+ valid_px = depth > 0
+ yv, xv = np.meshgrid(range(h), range(w), indexing="ij")
+ img_coords = np.stack([yv, xv], -1)
+ img_coords = img_coords[valid_px]
+ z_coords = depth[valid_px]
+ pts = uvd_backproject(img_coords, z_coords, intrinsics, cam2world, color[valid_px] if color is not None else None)
+
+ return pts
+
+
+def uvd_backproject(uv, d, intrinsics, cam2world=np.eye(4), color=None):
+ fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2]
+ py = (uv[:, 0] - cy) * d / fy
+ px = (uv[:, 1] - cx) * d / fx
+ pts = np.stack([px, py, d])
+
+ pts = cam2world[:3, :3] @ pts + np.tile(cam2world[:3, 3], (pts.shape[1], 1)).T
+ pts = pts.T
+ if color is not None:
+ pts = np.concatenate([pts, color], 1)
+
+ return pts
+
+
+def trs_decomp(A):
+ if has_torch(A):
+ s_vec = torch.norm(A[:3, :3], dim=0)
+ else:
+ s_vec = np.linalg.norm(A[:3, :3], axis=0)
+ R = A[:3, :3] / s_vec
+ t = A[:3, 3]
+ return t, R, s_vec
+
+
+def scale_mat(s, as_torch=True):
+ if isinstance(s, np.ndarray):
+ s_mat = np.eye(4)
+ s_mat[:3, :3] *= s
+ elif has_torch(s):
+ s_mat = torch.eye(4).to(s.device)
+ s_mat[:3, :3] *= s
+ else:
+ s_mat = torch.eye(4) if as_torch else np.eye(4)
+ s_mat[:3, :3] *= s
+ return s_mat
+
+
+def trans_mat(t):
+ if has_torch(t):
+ t_mat = torch.eye(4).to(t.device).float()
+ t_mat[:3, 3] = t
+ else:
+ t_mat = np.eye(4, dtype=np.float32)
+ t_mat[:3, 3] = t
+ return t_mat
+
+
+def rot_mat(axangle=None, euler=None, quat=None, as_torch=True):
+ R = np.eye(3)
+ if axangle is not None:
+ if euler is None:
+ axis, angle = axangle[0], axangle[1]
+ else:
+ axis, angle = axangle, euler
+ R = axangle2mat(axis, angle)
+ elif euler is not None:
+ R = euler2mat(*euler)
+ elif quat is not None:
+ R = quat2mat(quat)
+ if as_torch:
+ R = torch.Tensor(R)
+ return R
+
+
+def hmg(M):
+ if M.shape[0] == 3 and M.shape[1] == 3:
+ if has_torch(M):
+ hmg_M = torch.eye(4, dtype=M.dtype).to(M.device)
+ else:
+ hmg_M = np.eye(4, dtype=M.dtype)
+ hmg_M[:3, :3] = M
+ else:
+ hmg_M = M
+ return hmg_M
+
+
+def trs_comp(t, R, s_vec):
+ return trans_mat(t) @ hmg(R) @ scale_mat(s_vec)
+
+
+def tr_comp(t, R):
+ return trans_mat(t) @ hmg(R)
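+
+# Illustrative sketch (added comment, not in the original code): composing a
+# similarity transform from a translation, a rotation and a scale, as done by
+# compute_world2normscene earlier in this diff.
+#   t = torch.tensor([1.0, 0.0, 0.0])
+#   R = torch.eye(3)
+#   M = trs_comp(t, R, torch.tensor(0.5))                  # 4x4: scale by 0.5, then translate by t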
+
+
+def quat_from_two_vectors(v0, v1):
+ import quaternion as qt
+ v0 = v0 / np.linalg.norm(v0)
+ v1 = v1 / np.linalg.norm(v1)
+ c = v0.dot(v1)
+ if c < (-1 + 1e-8):
+ c = max(c, -1)
+ m = np.stack([v0, v1], 0)
+ _, _, vh = np.linalg.svd(m, full_matrices=True)
+ axis = vh[2]
+ w2 = (1 + c) * 0.5
+ w = np.sqrt(w2)
+ axis = axis * np.sqrt(1 - w2)
+ return qt.quaternion(w, *axis)
+
+ axis = np.cross(v0, v1)
+ s = np.sqrt((1 + c) * 2)
+ return qt.quaternion(s * 0.5, *(axis / s))
+
+
+def to4x4(pose):
+ constants = torch.zeros_like(pose[..., :1, :], device=pose.device)
+ constants[..., :, 3] = 1
+ return torch.cat([pose, constants], dim=-2)
+
+
+def normalize(poses):
+ pose_copy = torch.clone(poses)
+ pose_copy[..., :3, 3] /= torch.max(torch.abs(poses[..., :3, 3]))
+ return pose_copy