diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a816307 --- /dev/null +++ b/.gitignore @@ -0,0 +1,148 @@ +data/* +slurm/ +wandb/ +lightning_logs/ +runs/ +runs + +*.pyd + +# Editors +.vscode/ +.idea/ + +# Vagrant +.vagrant/ + +# Mac/OSX +.DS_Store + +# Windows +Thumbs.db + +# Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +*.out +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json +/experiments/ +/pretrained-examples/ + +test.ipynb +debug/* +checkpoints/* +outputs/* +.hydra/* +*.mp4 +*.gif +test.py +*.slurm +z/ \ No newline at end of file diff --git a/README.md b/README.md index 7f0c9a3..c2e48ec 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,94 @@ # SlotLifter -Code for "SlotLifter: Slot-guided Feature Lifting for Learning Object-centric Radiance Fields" (ECCV 2024) +

+[Paper arXiv](https://arxiv.org/abs/2408.06697) · Paper PDF · Project Page

-# Coming Soon +This repository contains the official implementation of the ECCV 2024 paper: + +[SlotLifter: Slot-guided Feature Lifting for Learning Object-centric Radiance Fields](https://arxiv.org/abs/2408.06697) + +[Yu Liu](https://yuliu-ly.github.io)\*, [Baoxiong Jia](https://buzz-beater.github.io)\*, [Yixin Chen](https://yixchen.github.io), [Siyuan Huang](https://siyuanhuang.com) +
+

+![SlotLifter overview](assets/overview.png) +

+ +## Environment Setup +We provide all environment configurations in ``requirements.txt``. To install all packages, you can create a conda environment and install the packages as follows: +```bash +conda create -n slotlifter python=3.8 +conda activate slotlifter +conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia +pip install -r requirements.txt +``` +In our experiments, we used NVIDIA CUDA 11.3 on Ubuntu 20.04. Similar CUDA versions should also work with the corresponding versions of ``torch`` and ``torchvision``. + +## Dataset +### 1. ShapeStacks, ObjectsRoom, CLEVRTex, Flowers +Download the ShapeStacks, ObjectsRoom, CLEVRTex, and Flowers datasets with +```bash +chmod +x scripts/downloads_data.sh +./scripts/downloads_data.sh +``` +For the ObjectsRoom dataset, you need to run ``objectsroom_process.py`` to convert the tfrecords dataset to PNG format. +Remember to change the ``DATA_ROOT`` in ``downloads_data.sh`` and ``objectsroom_process.py`` to your own paths. +### 2. PTR, Birds, Dogs, Cars +Download the PTR dataset following the instructions at http://ptr.csail.mit.edu. Download the CUB-Birds, Stanford Dogs, and Cars datasets from [here](https://drive.google.com/drive/folders/1zEzsKV2hOlwaNRzrEXc9oGdpTBrrVIVk), provided by the authors of [DRC](https://github.com/yuPeiyu98/DRC). We use ```birds.zip```, ```cars.tar```, and ```dogs.zip``` and then uncompress them. + +### 3. YCB, ScanNet, COCO +The YCB, ScanNet, and COCO datasets are available from [here](https://www.dropbox.com/sh/u1p1d6hysjxqauy/AACgEh0K5ANipuIeDnmaC5mQa?dl=0), provided by the authors of [UnsupObjSeg](https://github.com/vLAR-group/UnsupObjSeg). + +### 4. Data preparation +Please organize the data following [the data README](./data/README.md) before running experiments. + +## Training + +To train the model from scratch, we provide the following model files: + - ``train_trans_dec.py``: transformer-based model + - ``train_mixture_dec.py``: mixture-based model + - ``train_base_sa.py``: original slot-attention +We provide training scripts under ``scripts/train``. Please use the following commands and change the ``.sh`` file to the model you want to experiment with. Taking the transformer-based decoder experiment on Birds as an example, you can run: +```bash +$ cd scripts +$ cd train +$ chmod +x trans_dec_birds.sh +$ ./trans_dec_birds.sh +``` +Remember to change the paths in ``path.json`` to your own paths. +## Reloading checkpoints & Evaluation + +To reload checkpoints and only run inference, we provide the following model files: + - ``test_trans_dec.py``: transformer-based model + - ``test_mixture_dec.py``: mixture-based model + - ``test_base_sa.py``: original slot-attention + +Similarly, we provide testing scripts under ```scripts/test```. We provide transformer-based models for the real-world datasets (Birds, Dogs, Cars, Flowers, YCB, ScanNet, COCO) +and mixture-based models for the synthetic datasets (ShapeStacks, ObjectsRoom, CLEVRTex, PTR). We provide all checkpoints [here](https://drive.google.com/drive/folders/10LmK9JPWsSOcezqd6eLjuzn38VdwkBUf?usp=sharing). 
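If you want to sanity-check a downloaded checkpoint before plugging it into the test scripts, the short sketch below prints the stored parameter names and shapes. The checkpoint filename is hypothetical, and the ``state_dict`` key is an assumption based on the usual PyTorch Lightning checkpoint layout, so adjust both to the file you actually downloaded:
```python
import torch

# Hypothetical filename: point this at the checkpoint you downloaded from the link above.
ckpt = torch.load("checkpoints/trans_dec_birds.ckpt", map_location="cpu")

# Lightning-style checkpoints usually store the weights under "state_dict";
# fall back to the loaded object itself if it is already a plain state dict.
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
print(f"Loaded {len(state_dict)} parameter tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```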
Please use the following command and change ``.sh`` file to the model you want to experiment with: +```bash +$ cd scripts +$ cd test +$ chmod +x trans_dec_birds.sh +$ ./trans_dec_birds.sh +``` + +## Citation +If you find our paper and/or code helpful, please consider citing: +``` +@inproceedings{Liu2024slotlifter, + title={SlotLifter: Slot-guided Feature Lifting for Learning Object-centric Radiance Fields}, + author={Liu, Yu and Jia, Baoxiong and Chen, Yixin and Huang, Siyuan}, + booktitle={European Conference on Computer Vision (ECCV)}, + year={2024} +} +``` + +## Acknowledgement +This code heavily used resources from [PanopticLifting](https://github.com/nihalsid/panoptic-lifting), [BO-QSA](https://github.com/YuLiu-LY/BO-QSA), [SLATE](https://github.com/singhgautam/slate), [OSRT](https://github.com/stelzner/osrt), [IBRNet](https://github.com/googleinterns/IBRNet). We thank the authors for open-sourcing their awesome projects. diff --git a/assets/overview.png b/assets/overview.png new file mode 100644 index 0000000..1bd364a Binary files /dev/null and b/assets/overview.png differ diff --git a/config/cfg/dtu.yaml b/config/cfg/dtu.yaml new file mode 100644 index 0000000..e2aa2c0 --- /dev/null +++ b/config/cfg/dtu.yaml @@ -0,0 +1,107 @@ +# Wandb +project: "dtu" # project name +exp_name: test # experiment name +entity: # username or teamname where you're sending runs +group: # experiment groupname +job_type: debug # train / test / debug ... +tags: # tags for this run +id: # unique Id for this run +notes: # notes for this run +watch_model: false # true for logging the gradient of parameters +# Training +lpips_net: vgg +sample_mode: uniform +bg_bound: 0.27 +force_bg: false +force_bg_steps: 30000 +normalize: true +benchmark: false +deterministic: false +render_src_view: false +profiler: # use profiler to check time bottleneck +resume: null +ckpt_path: '' +logger: # wandb or None +log_path: "runs/dtu" +chunk: 8192 # num of rays per chunk +num_workers: 0 +seed: 42 +val_percent: 1.0 # val_batchs = val_check_percent * val_batchs if val_batchs < 1 else val_batchs +train_percent: 1.0 +test_percent: 1.0 +val_check_interval: 4 # do validation every val_check_interval epochs. 
It could be less than 1 +grad_clip: 0.5 +precision: 32 # compute precision +instance_steps: 500000 +stop_semantic_grad: false +decay_noise: 20000 +seg_metrics: +- ari +- hiou +- ari_fg +# Optimizer +optimizer: lion +lr: 5e-5 +min_lr_factor: 0.02 +weight_decay: 0.001 +warmup_steps: 10000 +max_steps: 250000 +max_epochs: 10000 +decay_steps: 50000 +# Dataset +norm_scene: false +select_view_func: nearby # or uniform +load_mask: false +img_size: +- 300 +- 400 +train_subsample_frames: 1 +val_subsample_frames: 5 +num_src_view: 4 +batch_size: 2 +ray_batchsize: 1024 +max_instances: 20 +dataset: dtu +dataset_root: /home/yuliu/Dataset/DTU +instance_dir: instance +semantics_dir: semantics +# dataset_root: data/hypersim/ai_001_008 +max_depth: 10 +visualized_indices: null +overfit: false +# Model +# Multi-view enc +feature_size: 32 +num_heads: 4 +conv_enc: false +conv_dim: 32 +# slot_enc +sigma_steps: 30000 +num_slots: 2 +num_iter: 3 +slot_size: 256 +drop_path: 0.2 +num_blocks: 1 +# slot_dec +slot_density: true +slot_dec_dim: 64 +num_dec_blocks: 4 +# NeRF +random_proj_ratio: 1 +random_proj: true +random_proj_steps: 30000 +n_samples: 64 +n_samples_fine: 64 +coarse_to_fine: true +pe_view: 2 +pe_feat: 0 +grid_init: pos_enc +monitor: psnr +nerf_mlp_dim: 64 +suffix: '' +scene_id: -1 +num_vis: 30 +hydra: + output_subdir: null # Disable saving of config files. We'll do that ourselves. + run: + dir: . \ No newline at end of file diff --git a/config/cfg/scannet.yaml b/config/cfg/scannet.yaml new file mode 100644 index 0000000..36474d3 --- /dev/null +++ b/config/cfg/scannet.yaml @@ -0,0 +1,108 @@ +# Wandb +project: "scannet" # project name +exp_name: test # experiment name +entity: # username or teamname where you're sending runs +group: # experiment groupname +job_type: debug # train / test / debug ... +tags: # tags for this run +id: # unique Id for this run +notes: # notes for this run +watch_model: false # true for logging the gradient of parameters +# Training +lpips_net: vgg +sample_mode: uniform +bg_bound: 0.27 +force_bg: false +force_bg_steps: 30000 +normalize: true +benchmark: false +deterministic: false +render_src_view: false +profiler: # use profiler to check time bottleneck +resume: null +ckpt_path: '' +logger: # wandb or None +log_path: "runs/scannet" +chunk: 8192 # num of rays per chunk +num_workers: 0 +seed: 42 +val_percent: 1.0 # val_batchs = val_check_percent * val_batchs if val_batchs < 1 else val_batchs +train_percent: 1.0 +test_percent: 1.0 +val_check_interval: 4 # do validation every val_check_interval epochs. 
It could be less than 1 +grad_clip: 0.5 +precision: 32 # compute precision +instance_steps: 500000 +recon_rgb: true +stop_semantic_grad: false +decay_noise: 20000 +seg_metrics: +- ari +- hiou +- ari_fg +# Optimizer +optimizer: lion +lr: 5e-5 +min_lr_factor: 0.02 +weight_decay: 0.001 +warmup_steps: 10000 +max_steps: 250000 +max_epochs: 10000 +decay_steps: 50000 +# Dataset +norm_scene: false +select_view_func: nearby # or uniform +load_mask: false +img_size: +- 480 +- 640 +train_subsample_frames: 1 +val_subsample_frames: 5 +num_src_view: 4 +batch_size: 2 +ray_batchsize: 1024 +max_instances: 20 +dataset: scannet +dataset_root: /home/yuliu/Dataset/scannet +instance_dir: instance +semantics_dir: semantics +# dataset_root: data/hypersim/ai_001_008 +max_depth: 10 +visualized_indices: null +overfit: false +# Model +# Multi-view enc +feature_size: 32 +num_heads: 4 +conv_enc: false +conv_dim: 32 +# slot_enc +sigma_steps: 30000 +num_slots: 8 +num_iter: 3 +slot_size: 256 +drop_path: 0.2 +num_blocks: 1 +# slot_dec +slot_density: true +slot_dec_dim: 64 +num_dec_blocks: 4 +# NeRF +random_proj_ratio: 1 +random_proj: true +random_proj_steps: 30000 +n_samples: 64 +n_samples_fine: 64 +pe_view: 2 +pe_feat: 0 +coarse_to_fine: true +grid_init: pos_enc +nerf_mlp_dim: 64 +monitor: psnr +num_vis: 1 +scene_id: 0 +suffix: '' +hydra: + output_subdir: null # Disable saving of config files. We'll do that ourselves. + run: + dir: . \ No newline at end of file diff --git a/config/cfg/uorf.yaml b/config/cfg/uorf.yaml new file mode 100755 index 0000000..940c019 --- /dev/null +++ b/config/cfg/uorf.yaml @@ -0,0 +1,103 @@ +# Wandb +project: "uorf" # project name +exp_name: test # experiment name +entity: # username or teamname where you're sending runs +group: # experiment groupname +job_type: debug # train / test / debug ... +tags: # tags for this run +id: # unique Id for this run +notes: # notes for this run +watch_model: false # true for logging the gradient of parameters +# Training +lpips_net: alex +sample_mode: uniform +bg_bound: 0.64 +force_bg: true +force_bg_steps: 50000 +normalize: false +benchmark: false +deterministic: false +render_src_view: true +profiler: # use profiler to check time bottleneck +resume: null +ckpt_path: '' +logger: # wandb or None +log_path: "runs/uorf" +chunk: 16384 # num of rays per chunk +num_workers: 0 +seed: 42 +val_percent: 0.2 # val_batchs = val_check_percent * val_batchs if val_batchs < 1 else val_batchs +train_percent: 1.0 +test_percent: 1.0 +val_check_interval: 4 # do validation every val_check_interval epochs. 
It could be less than 1 +grad_clip: 0.5 +precision: 32 # compute precision +instance_steps: 500000 +stop_semantic_grad: false +decay_noise: 20000 +seg_metrics: ['ari', 'hiou', 'ari_fg'] +# Optimizer +optimizer: lion +lr: 5e-5 +min_lr_factor: 0.02 +weight_decay: 0.001 +warmup_steps: 10000 +max_steps: 250000 +max_epochs: 10000 +decay_steps: 50000 +# Dataset +subset: kitchen_matte +norm_scene: true +select_view_func: nearby # or uniform +load_mask: false +train_subsample_frames: 1 +val_subsample_frames: 5 +num_src_view: 1 +batch_size: 2 +ray_batchsize: 1024 +max_instances: 20 +img_size: +- 128 +- 128 +n_scenes: 5 +dataset: uorf +dataset_root: /home/yuliu/Dataset/uorf +num_vis: 20 +max_depth: 10 +visualized_indices: null +overfit: false +# Model +# Multi-view enc +conv_enc: false +feature_size: 64 +num_heads: 4 +conv_dim: 32 +# slot_enc +sigma_steps: 30000 +num_slots: 5 +num_iter: 3 +slot_size: 256 +drop_path: 0.2 +num_blocks: 1 +# slot_dec +slot_dec_dim: 64 +num_dec_blocks: 4 +slot_density: true +# NeRF +random_proj_ratio: 1 +random_proj: true +random_proj_steps: 30000 +n_samples: 96 +n_samples_fine: 64 +coarse_to_fine: false +pe_view: 10 +pe_feat: 0 +grid_init: pos_enc +nerf_mlp_dim: 64 +scene_id: 0 +monitor: nv_ari +suffix: '' +hydra: + output_subdir: null # Disable saving of config files. We'll do that ourselves. + run: + dir: . diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..f75c362 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,6 @@ +defaults: + - cfg: scannet +hydra: + output_subdir: null # Disable saving of config files. We'll do that ourselves. + run: + dir: . \ No newline at end of file diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..8f8a3ee --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from dataset.scannet_dataset import ScannetDataset +from dataset.uorf_dataset import MultiscenesDataset, VisualDataset +from dataset.dtu_dataset import DTUDataset, VisDTUDataset + +datasets = { + "scannet": ScannetDataset, + "uorf": MultiscenesDataset, + "dtu": DTUDataset, + "uorf_vis": VisualDataset, + "dtu_vis": VisDTUDataset, +} + +def get_dataset(config): + if config.job_type == 'vis': + dataset = datasets[config.dataset + "_vis"] + else: + dataset = datasets[config.dataset] + train_set = dataset(config, "train") + val_set = dataset(config, "val") + test_set = dataset(config, "test") + return train_set, val_set, test_set + diff --git a/dataset/data_utils.py b/dataset/data_utils.py new file mode 100644 index 0000000..6b43ff4 --- /dev/null +++ b/dataset/data_utils.py @@ -0,0 +1,280 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import math +from PIL import Image +import torchvision.transforms as transforms +import torch +from scipy.spatial.transform import Rotation as R +import cv2 + +rng = np.random.RandomState(234) +_EPS = np.finfo(float).eps * 4.0 +TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision + + +def vector_norm(data, axis=None, out=None): + """Return length, i.e. eucledian norm, of ndarray along axis. + """ + data = np.array(data, dtype=np.float64, copy=True) + if out is None: + if data.ndim == 1: + return math.sqrt(np.dot(data, data)) + data *= data + out = np.atleast_1d(np.sum(data, axis=axis)) + np.sqrt(out, out) + return out + else: + data *= data + np.sum(data, axis=axis, out=out) + np.sqrt(out, out) + + +def quaternion_about_axis(angle, axis): + """Return quaternion for rotation about axis. + """ + quaternion = np.zeros((4, ), dtype=np.float64) + quaternion[:3] = axis[:3] + qlen = vector_norm(quaternion) + if qlen > _EPS: + quaternion *= math.sin(angle/2.0) / qlen + quaternion[3] = math.cos(angle/2.0) + return quaternion + + +def quaternion_matrix(quaternion): + """Return homogeneous rotation matrix from quaternion. + """ + q = np.array(quaternion[:4], dtype=np.float64, copy=True) + nq = np.dot(q, q) + if nq < _EPS: + return np.identity(4) + q *= math.sqrt(2.0 / nq) + q = np.outer(q, q) + return np.array(( + (1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], 0.0), + ( q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], 0.0), + ( q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], 0.0), + ( 0.0, 0.0, 0.0, 1.0) + ), dtype=np.float64) + + +def rectify_inplane_rotation(src_pose, tar_pose, src_img, th=40): + relative = np.linalg.inv(tar_pose).dot(src_pose) + relative_rot = relative[:3, :3] + r = R.from_matrix(relative_rot) + euler = r.as_euler('zxy', degrees=True) + euler_z = euler[0] + if np.abs(euler_z) < th: + return src_pose, src_img + + R_rectify = R.from_euler('z', -euler_z, degrees=True).as_matrix() + src_R_rectified = src_pose[:3, :3].dot(R_rectify) + out_pose = np.eye(4) + out_pose[:3, :3] = src_R_rectified + out_pose[:3, 3:4] = src_pose[:3, 3:4] + h, w = src_img.shape[:2] + center = ((w - 1.) / 2., (h - 1.) / 2.) + M = cv2.getRotationMatrix2D(center, -euler_z, 1) + src_img = np.clip((255*src_img).astype(np.uint8), a_max=255, a_min=0) + rotated = cv2.warpAffine(src_img, M, (w, h), borderValue=(255, 255, 255), flags=cv2.INTER_LANCZOS4) + rotated = rotated.astype(np.float32) / 255. 
+ return out_pose, rotated + + +def random_crop(rgb, camera, src_rgbs, src_cameras, size=(400, 600), center=None): + h, w = rgb.shape[:2] + out_h, out_w = size[0], size[1] + if out_w >= w or out_h >= h: + return rgb, camera, src_rgbs, src_cameras + + if center is not None: + center_h, center_w = center + else: + center_h = np.random.randint(low=out_h // 2 + 1, high=h - out_h // 2 - 1) + center_w = np.random.randint(low=out_w // 2 + 1, high=w - out_w // 2 - 1) + + rgb_out = rgb[center_h - out_h // 2:center_h + out_h // 2, center_w - out_w // 2:center_w + out_w // 2, :] + src_rgbs = np.array(src_rgbs) + src_rgbs = src_rgbs[:, center_h - out_h // 2:center_h + out_h // 2, + center_w - out_w // 2:center_w + out_w // 2, :] + camera[0] = out_h + camera[1] = out_w + camera[4] -= center_w - out_w // 2 + camera[8] -= center_h - out_h // 2 + src_cameras[:, 4] -= center_w - out_w // 2 + src_cameras[:, 8] -= center_h - out_h // 2 + src_cameras[:, 0] = out_h + src_cameras[:, 1] = out_w + return rgb_out, camera, src_rgbs, src_cameras + + +def random_flip(rgb, camera, src_rgbs, src_cameras): + h, w = rgb.shape[:2] + h_r, w_r = src_rgbs.shape[1:3] + rgb_out = np.flip(rgb, axis=1).copy() + src_rgbs = np.flip(src_rgbs, axis=-2).copy() + camera[2] *= -1 + camera[4] = w - 1. - camera[4] + src_cameras[:, 2] *= -1 + src_cameras[:, 4] = w_r - 1. - src_cameras[:, 4] + return rgb_out, camera, src_rgbs, src_cameras + + +def get_color_jitter_params(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2): + color_jitter = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) + transform = transforms.ColorJitter.get_params(color_jitter.brightness, + color_jitter.contrast, + color_jitter.saturation, + color_jitter.hue) + return transform + + +def color_jitter(img, transform): + ''' + Args: + img: np.float32 [h, w, 3] + transform: + Returns: transformed np.float32 + ''' + img = Image.fromarray((255.*img).astype(np.uint8)) + img_trans = transform(img) + img_trans = np.array(img_trans).astype(np.float32) / 255. 
+ return img_trans + + +def color_jitter_all_rgbs(rgb, ref_rgbs, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2): + transform = get_color_jitter_params(brightness, contrast, saturation, hue) + rgb_trans = color_jitter(rgb, transform) + ref_rgbs_trans = [] + for ref_rgb in ref_rgbs: + ref_rgbs_trans.append(color_jitter(ref_rgb, transform)) + + ref_rgbs_trans = np.array(ref_rgbs_trans) + return rgb_trans, ref_rgbs_trans + + +def deepvoxels_parse_intrinsics(filepath, trgt_sidelength, invert_y=False): + # Get camera intrinsics + with open(filepath, 'r') as file: + f, cx, cy = list(map(float, file.readline().split()))[:3] + grid_barycenter = torch.Tensor(list(map(float, file.readline().split()))) + near_plane = float(file.readline()) + scale = float(file.readline()) + height, width = map(float, file.readline().split()) + + try: + world2cam_poses = int(file.readline()) + except ValueError: + world2cam_poses = None + + if world2cam_poses is None: + world2cam_poses = False + + world2cam_poses = bool(world2cam_poses) + + cx = cx / width * trgt_sidelength + cy = cy / height * trgt_sidelength + f = trgt_sidelength / height * f + + fx = f + if invert_y: + fy = -f + else: + fy = f + + # Build the intrinsic matrices + full_intrinsic = np.array([[fx, 0., cx, 0.], + [0., fy, cy, 0], + [0., 0, 1, 0], + [0, 0, 0, 1]]) + + return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses + + +def angular_dist_between_2_vectors(vec1, vec2): + vec1_unit = vec1 / (np.linalg.norm(vec1, axis=1, keepdims=True) + TINY_NUMBER) + vec2_unit = vec2 / (np.linalg.norm(vec2, axis=1, keepdims=True) + TINY_NUMBER) + angular_dists = np.arccos(np.clip(np.sum(vec1_unit*vec2_unit, axis=-1), -1.0, 1.0)) + return angular_dists + + +def batched_angular_dist_rot_matrix(R1, R2): + ''' + calculate the angular distance between two rotation matrices (batched) + :param R1: the first rotation matrix [N, 3, 3] + :param R2: the second rotation matrix [N, 3, 3] + :return: angular distance in radiance [N, ] + ''' + assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3 + return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2., + a_min=-1 + TINY_NUMBER, a_max=1 - TINY_NUMBER)) + + +def get_nearest_pose_ids(tar_pose, ref_poses, num_select, tar_id=-1, angular_dist_method='vector', + scene_center=(0, 0, 0)): + ''' + Args: + tar_pose: target pose [3, 3] + ref_poses: reference poses [N, 3, 3] + num_select: the number of nearest views to select + Returns: the selected indices + ''' + num_cams = len(ref_poses) + num_select = min(num_select, num_cams-1) + batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0) + + if angular_dist_method == 'matrix': + dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]) + elif angular_dist_method == 'vector': + tar_cam_locs = batched_tar_pose[:, :3, 3] + ref_cam_locs = ref_poses[:, :3, 3] + scene_center = np.array(scene_center)[None, ...] 
+ tar_vectors = tar_cam_locs - scene_center + ref_vectors = ref_cam_locs - scene_center + dists = angular_dist_between_2_vectors(tar_vectors, ref_vectors) + elif angular_dist_method == 'dist': + tar_cam_locs = batched_tar_pose[:, :3, 3] + ref_cam_locs = ref_poses[:, :3, 3] + dists = np.linalg.norm(tar_cam_locs - ref_cam_locs, axis=1) + elif angular_dist_method == 'mix': + ang_dists = batched_angular_dist_rot_matrix(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]) + dists = 0.5 * ang_dists + 0.5 * np.linalg.norm(batched_tar_pose[:, :3, 3] - ref_poses[:, :3, 3], axis=1) + else: + raise Exception('unknown angular distance calculation method!') + + sorted_ids = np.argsort(dists) + if tar_id >= 0: + sorted_ids = sorted_ids[1:] + selected_ids = sorted_ids[:num_select] + # print(angular_dists[selected_ids] * 180 / np.pi) + return selected_ids + + +def resize(img: Image, cams, size): + r''' + Args: + img: PIL.Image + cams: [27] hw + intrinsics (3x3) + pose (4x4) + ''' + W, H = img.size + img = img.resize(size, Image.LANCZOS) + w, h = size + cams[0] = h + cams[1] = w + cams[2:11] = np.reshape( + np.diag([w/W, h/H, 1]) @ np.reshape(cams[2:11], [3, 3]), + [9]) + return img, cams diff --git a/dataset/dtu_dataset.py b/dataset/dtu_dataset.py new file mode 100644 index 0000000..0460dd9 --- /dev/null +++ b/dataset/dtu_dataset.py @@ -0,0 +1,435 @@ +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) + +from pathlib import Path +import torch +from glob import glob +from PIL import Image +import numpy as np +from tqdm import tqdm +import cv2 +from util.misc import SubSampler +from pathlib import Path +from util.ray import get_rays +from dataset.data_utils import get_nearest_pose_ids + + +class DTUDataset(torch.utils.data.Dataset): + def __init__( + self, + cfg, + split="train", + ): + """ + :param path dataset root path, contains metadata.yml + :param split train | val | test + :param list_prefix prefix for split lists: [train, val, test].lst + """ + super().__init__() + self.data_root = cfg.dataset_root + self.normalize = cfg.normalize + self.depth_range = [0.1, 5.0] + self.split = split + self.img_root = f'image' + dino_patch_size = 14 + self.dino_feats_size = cfg.dino_feats_size + self.feat_map_size = cfg.dino_feats_size[0] // dino_patch_size, cfg.dino_feats_size[1] // dino_patch_size + self.ray_batchsize = cfg.ray_batchsize + self.subsampler = SubSampler() + self.num_src_view = cfg.num_src_view + + self.scene_id = cfg.scene_id + + self.img_size = cfg.img_size[0], cfg.img_size[1] + # sub_format == "dtu": + self._coord_trans_world = torch.tensor( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ) + self._coord_trans_cam = torch.tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ) + + self.setup_data() + + def __len__(self): + return sum([len(x) for x in self.indices]) + + def setup_data(self): + path = "new_val.lst" if self.split != "train" else "new_train.lst" + scene_list_path = os.path.join(self.data_root, path) + scenes = [] + with open(scene_list_path, "r") as f: + for line in f: + scenes.append(line.strip()) + + self.scenes = [os.path.join(self.data_root, scene) for scene in scenes] + if self.scene_id != -1: + self.scenes = [self.scenes[self.scene_id]] + self.train_indices = [] + self.val_indices = [] + self.all_frame_names = [] + self.cam2normscene = [] + self.cam2scenes = [] + self.intrinsics = [] + self.scene2normscene = [] + 
self.normscene_scale = [] + self.all_cams = [] + for scene in tqdm(self.scenes, desc="Loading DTU dataset"): + self.setup_one_scene(scene) + print( + "Loading DTU dataset from", self.data_root, + 'found', len(self.scenes), + self.split,"scenes", + ) + if self.split == 'train': + self.indices = self.train_indices + else: + self.indices = self.val_indices + # self.indices = [list(range(len(self.all_frame_names[0])))] + self.frame_idx2scene_idx = {} + self.frame_idx2sample_idx = {} + frame_idx = 0 + for scene_idx, scene_indices in enumerate(self.indices): + for _, sample_idx in enumerate(scene_indices): + self.frame_idx2scene_idx[frame_idx] = scene_idx + self.frame_idx2sample_idx[frame_idx] = sample_idx + frame_idx += 1 + print("depth range", self.depth_range) + + def setup_one_scene(self, scene): + all_frames = sorted(glob(os.path.join(scene, self.img_root, "*"))) + all_frames = [x.split("/")[-1].split(".")[0] for x in all_frames] + self.all_frame_names.append(all_frames) + + fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0 + cam2scene = [] + dims = [] + img_w, img_h = Image.open(os.path.join(scene, self.img_root, f'{all_frames[0]}.png')).size + for i in range(len(all_frames)): + cams = np.load(os.path.join(self.data_root, scene, "cameras.npz")) + + # Decompose projection matrix + P = cams["world_mat_" + str(i)] + P = P[:3] + K, R, t = cv2.decomposeProjectionMatrix(P)[:3] + K = K / K[2, 2] + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + scale_mtx = cams.get("scale_mat_" + str(i)) + if scale_mtx is not None: + norm_trans = scale_mtx[:3, 3:] + norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None] + pose[:3, 3:] -= norm_trans + pose[:3, 3:] /= norm_scale + + fx += torch.tensor(K[0, 0]) + fy += torch.tensor(K[1, 1]) + cx += torch.tensor(K[0, 2]) + cy += torch.tensor(K[1, 2]) + + pose = torch.tensor(pose, dtype=torch.float32) + cam2scene.append(pose) + dims.append([img_h, img_w]) + fx /= len(all_frames) + fy /= len(all_frames) + cx /= len(all_frames) + cy /= len(all_frames) + intrinsic = torch.tensor([ + [fx, 0, cx, 0], + [0, fy, cy, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + cam2scene = torch.stack(cam2scene).float() + intrinsics = intrinsic.unsqueeze(0).expand(len(cam2scene), -1, -1).float() + cams = torch.cat([ + torch.Tensor(self.img_size).expand(len(cam2scene), -1), + intrinsics.flatten(1, 2), + cam2scene.flatten(1, 2)], dim=1) + self.all_cams.append(cams) + self.cam2scenes.append(cam2scene) + + if self.split == "train": + indices = list(range(len(all_frames))) + self.train_indices.append(indices) + else: + indices = list(range(len(all_frames))) + val_indices = indices[::8] + train_indices = [i for i in indices if i not in val_indices] + self.val_indices.append(val_indices) + self.train_indices.append(train_indices) + + def load_sample(self, sample_index, scene_idx): + scene_dir = self.scenes[scene_idx] + image = Image.open(Path(scene_dir) / f"{self.img_root}" / f"{self.all_frame_names[scene_idx][sample_index]}.png") + image = torch.from_numpy(np.array(image) / 255).float() # [H, W, 3] + if self.normalize: + image = image * 2 - 1 # [-1, 1] + return image.view(-1, 3), torch.zeros(1) + + def sample_support_views(self, pose, scene_idx): + # sample support ids + support_indices = self.train_indices[scene_idx] + if self.split == 'train': + subsample_factor = 1 + nearest_pose_ids = get_nearest_pose_ids(pose.numpy(), + self.cam2scenes[scene_idx][support_indices].numpy(), + min(self.num_src_view * subsample_factor, 22), + tar_id=0, + 
angular_dist_method='mix') + nearest_pose_ids = np.random.choice(nearest_pose_ids, self.num_src_view, replace=False) + else: + nearest_pose_ids = get_nearest_pose_ids(pose.numpy(), + self.cam2scenes[scene_idx][support_indices].numpy(), + self.num_src_view, + tar_id=-1, + angular_dist_method='mix') + nearest_pose_ids = torch.from_numpy(nearest_pose_ids).long() + + return nearest_pose_ids + + def __getitem__(self, idx): + scene_idx = self.frame_idx2scene_idx[idx] + sample_idx = self.frame_idx2sample_idx[idx] + sample = {} + + # sample support ids + pose = self.cam2scenes[scene_idx][sample_idx] + nearest_pose_ids = self.sample_support_views(pose, scene_idx) + support_indices = self.train_indices[scene_idx] + src_rgbs, src_cams, src_feats = [], [], [] + for i in nearest_pose_ids: + id = support_indices[i] + src_cam = self.all_cams[scene_idx][id] + src_rgb, src_feat = self.load_sample(id, scene_idx) + src_rgbs.append(src_rgb) + src_cams.append(src_cam) + src_feats.append(src_feat) + src_rgbs = torch.stack(src_rgbs) # [N, H*W, 3] + src_cams = torch.stack(src_cams) # [N, 34] + src_feats = torch.stack(src_feats) # [N, H1, W1, D] + + rgbs, feat = self.load_sample(sample_idx, scene_idx) + tgt_cam = self.all_cams[scene_idx][[sample_idx]] + tgt_rays = get_rays(tgt_cam, *self.img_size).view(-1, 6) # [HW, 3] + H, W = self.img_size + N = self.num_src_view + if self.split == 'train': + # subsample + Br = self.ray_batchsize + subsample_idx = self.subsampler.idx_subsample(self.img_size, Br) + tgt_rgbs = rgbs.gather(0, subsample_idx.expand(-1, 3)) # [Br, 3] + tgt_rays = tgt_rays.gather(0, subsample_idx.expand(-1, 6)) # [Br, 3] + else: + tgt_rgbs = rgbs[None] # [1, HW, 3], [1, HW] + tgt_rays = tgt_rays[None] # [1, HW, 3] + sample['rgbs'] = tgt_rgbs # [Br, 3] or [N1, HW, 3] + sample['rays'] = tgt_rays # [Br, 3] or [N1, HW, 3] + sample['cam'] = tgt_cam # [Br, 34] or [N1, 34] + sample['src_rgbs'] = src_rgbs.reshape(N, H, W, 3) # [N, HW,3] + sample['src_cams'] = src_cams # [N, 34] 2+9+16 + sample['src_feats'] = src_feats # [N, H1, W1, D] + sample['depth_range'] = torch.tensor(self.depth_range).float() + return sample + + +class VisDTUDataset(torch.utils.data.Dataset): + def __init__( + self, + cfg, + split="train", + ): + """ + :param path dataset root path, contains metadata.yml + :param split train | val | test + :param list_prefix prefix for split lists: [train, val, test].lst + """ + super().__init__() + self.data_root = cfg.dataset_root + self.normalize = cfg.normalize + self.depth_range = [0.1, 5.0] + self.img_root = f'image' + self.num_src_view = cfg.num_src_view + + self.scene_id = cfg.scene_id + self.img_size = cfg.img_size[0], cfg.img_size[1] + # sub_format == "dtu": + self._coord_trans_world = torch.tensor( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ) + self._coord_trans_cam = torch.tensor( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ) + self.num_vis = cfg.num_vis + self.setup_data() + + def __len__(self): + return self.num_vis + + def setup_data(self): + path = "new_val.lst" + scene_list_path = os.path.join(self.data_root, path) + scenes = [] + with open(scene_list_path, "r") as f: + for line in f: + scenes.append(line.strip()) + + scenes = [os.path.join(self.data_root, scene) for scene in scenes] + self.scene = scenes[self.scene_id] + self.setup_one_scene(self.scene) + + def setup_one_scene(self, scene): + all_frames = sorted(glob(os.path.join(scene, self.img_root, "*"))) + all_frames = 
[x.split("/")[-1].split(".")[0] for x in all_frames] + self.all_frame = all_frames + + fx, fy, cx, cy = 0.0, 0.0, 0.0, 0.0 + cam2scene = [] + dims = [] + img_w, img_h = Image.open(os.path.join(scene, self.img_root, f'{all_frames[0]}.png')).size + for i in range(len(all_frames)): + cams = np.load(os.path.join(self.data_root, scene, "cameras.npz")) + + # Decompose projection matrix + P = cams["world_mat_" + str(i)] + P = P[:3] + K, R, t = cv2.decomposeProjectionMatrix(P)[:3] + K = K / K[2, 2] + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + scale_mtx = cams.get("scale_mat_" + str(i)) + if scale_mtx is not None: + norm_trans = scale_mtx[:3, 3:] + norm_scale = np.diagonal(scale_mtx[:3, :3])[..., None] + pose[:3, 3:] -= norm_trans + pose[:3, 3:] /= norm_scale + + fx += torch.tensor(K[0, 0]) + fy += torch.tensor(K[1, 1]) + cx += torch.tensor(K[0, 2]) + cy += torch.tensor(K[1, 2]) + + pose = torch.tensor(pose, dtype=torch.float32) + cam2scene.append(pose) + dims.append([img_h, img_w]) + fx /= len(all_frames) + fy /= len(all_frames) + cx /= len(all_frames) + cy /= len(all_frames) + intrinsic = torch.tensor([ + [fx, 0, cx, 0], + [0, fy, cy, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + cam2scene = torch.stack(cam2scene).float() + intrinsics = intrinsic.unsqueeze(0).expand(len(cam2scene), -1, -1).float() + cams = torch.cat([ + torch.Tensor(self.img_size).expand(len(cam2scene), -1), + intrinsics.flatten(1, 2), + cam2scene.flatten(1, 2)], dim=1) + self.all_cams = cams + self.cam2scenes = cam2scene + indices = list(range(len(all_frames))) + self.indices = indices + + + N_poses = len(cam2scene) + tgt_cam2scene = cam2scene + tgt_cams = [torch.cat([torch.Tensor(self.img_size), intrinsics.flatten(), + tgt_cam2scene[i].flatten()])[None] for i in range(self.num_vis)] + tgt_cams = torch.cat(tgt_cams, dim=0) # [N, 34] + tgt_rays = get_rays(tgt_cams, *self.img_size) # [N, HW, 6] + self.tgt_poses = tgt_cam2scene + self.tgt_cams = tgt_cams + self.tgt_rays = tgt_rays + + def load_sample(self, sample_index): + image = Image.open(Path(self.scene) / f"{self.img_root}" / f"{self.all_frame[sample_index]}.png") + image = torch.from_numpy(np.array(image) / 255).float() # [H, W, 3] + if self.normalize: + image = image * 2 - 1 # [-1, 1] + return image.view(-1, 3), torch.zeros(1) + + def sample_support_views(self, pose): + nearest_pose_ids = get_nearest_pose_ids(pose.numpy(), + self.cam2scenes.numpy(), + self.num_src_view, + tar_id=-1, + angular_dist_method='mix') + nearest_pose_ids = torch.from_numpy(nearest_pose_ids).long() + return nearest_pose_ids + + def __getitem__(self, idx): + sample = {} + + # sample support ids + pose = self.tgt_poses[idx] + + nearest_pose_ids = self.sample_support_views(pose) + src_rgbs, src_cams, src_feats = [], [], [] + for id in nearest_pose_ids: + src_cam = self.all_cams[id] + src_rgb, src_feat = self.load_sample(id) + src_rgbs.append(src_rgb) + src_cams.append(src_cam) + src_feats.append(src_feat) + src_rgbs = torch.stack(src_rgbs) # [N, H*W, 3] + src_cams = torch.stack(src_cams) # [N, 34] + src_feats = torch.stack(src_feats) # [N, H1, W1, D] + + H, W = self.img_size + N = self.num_src_view + sample['rays'] = self.tgt_rays # [Br, 3] or [N1, HW, 3] + sample['cam'] = self.tgt_cams # [Br, 34] or [N1, 34] + sample['src_rgbs'] = src_rgbs.reshape(N, H, W, 3) # [N, HW,3] + sample['src_cams'] = src_cams # [N, 34] 2+9+16 + sample['src_feats'] = src_feats # [N, H1, W1, D] + sample['depth_range'] = torch.tensor(self.depth_range).float() + return 
sample + + +def rot_matrix_angular_dist(R1, R2): + assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3 + return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2., + a_min=-1 + 1e-6, a_max=1 - 1e-6)) + +import hydra +from torch.utils.data import DataLoader + +@hydra.main(config_path='../config/cfg', config_name='dtu', version_base='1.2') +def main(config): + config.num_workers = 0 + config.lambda_depth = 0.1 + + train_set = VisDTUDataset(config, "train") + val_set = VisDTUDataset(config, "val") + test_set = VisDTUDataset(config, "test") + + train_loader = DataLoader(train_set, batch_size=config.batch_size, num_workers=config.num_workers) + # val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=config.num_workers) + test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=config.num_workers) + for i, batch in tqdm(enumerate(train_loader), total=len(train_set)): + print(batch['cam'].shape) + print(batch['rgbs'].shape) + break + for i, batch in tqdm(enumerate(test_loader), total=len(test_set)): + print(batch['cam'].shape) + print(batch['rgbs'].shape) + break + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/dataset/preprocess_scannet.py b/dataset/preprocess_scannet.py new file mode 100644 index 0000000..e66b462 --- /dev/null +++ b/dataset/preprocess_scannet.py @@ -0,0 +1,145 @@ +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-3]) +sys.path.append(root_path) + +import cv2 +import glob +import json +import argparse +import numpy as np +from tqdm import tqdm +from PIL import Image +from pathlib import Path +from multiprocessing import cpu_count, Pool + + +def get_keyframe_indices(filenames, window_size): + """ + select non-blurry images within a moving window + """ + scores = [] + cores = cpu_count() - 1 + # print("Using", cores, "cores") + with Pool(cores) as pool: + for result in pool.map(compute_blur_score_opencv, filenames) : + scores.append(result) + keyframes = [i + np.argmin(scores[i:i + window_size]) for i in range(0, len(scores), window_size)] + return keyframes, scores + + +def compute_blur_score_opencv(filename): + """ + Estimate the amount of blur an image has with the variance of the Laplacian. 
+ Normalize by pixel number to offset the effect of image size on pixel gradients & variance + https://github.com/deepfakes/faceswap/blob/ac40b0f52f5a745aa058f92339302065177dd28b/tools/sort/sort.py#L626 + """ + image = cv2.imread(str(filename)) + if image.ndim == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + blur_map = cv2.Laplacian(image, cv2.CV_32F) + score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1]) + return 1.0 - score + + +def subsample_scannet(src_folder, rate): + """ + sample every nth frame from scannet + """ + all_frames = sorted(list(x.stem for x in (src_folder / 'pose').iterdir()), key=lambda y: int(y) if y.isnumeric() else y) + total_sampled = int(len(all_frames) * rate) + sampled_frames = [all_frames[i * (len(all_frames) // total_sampled)] for i in range(total_sampled)] + unsampled_frames = [x for x in all_frames if x not in sampled_frames] + for frame in sampled_frames: + if 'inf' in Path(src_folder / "pose" / f"{frame}.txt").read_text(): + unsampled_frames.append(frame) + folders = ["color", "depth", "instance", "pose", "semantics"] + exts = ['jpg', 'png', 'png', 'txt', 'png'] + for folder, ext in tqdm(zip(folders, exts), desc='sampling'): + assert (src_folder / folder).exists(), src_folder + for frame in unsampled_frames: + if (src_folder / folder / f'{frame}.{ext}').exists(): + os.remove(str(src_folder / folder / f'{frame}.{ext}')) + else: + print(str(src_folder / folder / f'{frame}.{ext}'), "already exists!") + + +def subsample_scannet_blur_window(src_folder, min_frames): + """ + sample non blurry frames from scannet + """ + if os.path.exists(src_folder / f"sampled_frames.json"): + print("sampled_frames.json already exists, skipping") + return + scene_name = src_folder.name + all_frames = sorted(list(x.stem for x in (src_folder / f'pose').iterdir()), key=lambda y: int(y) if y.isnumeric() else y) + valid_frames = [] + for frame in all_frames: + if 'inf' not in Path(src_folder / f"pose" / f"{frame}.txt").read_text(): + valid_frames.append(frame) + print("Found", len(all_frames), "frames, ", len(valid_frames), "are valid") + valid_frame_paths = [Path(src_folder / f"color" / f"{frame}.jpg") for frame in valid_frames] + window_size = max(3, len(valid_frames) // min_frames) + frame_indices, _ = get_keyframe_indices(valid_frame_paths, window_size) + print("Using a window size of", window_size, "got", len(frame_indices), "frames") + sampled_frames = [valid_frames[i] for i in frame_indices] + # save as json + json.dump(sampled_frames, open(src_folder / f"sampled_frames.json", 'w'), indent=4) + + +def resize_files(img_paths, resize_depth=False): + for img_path in img_paths: + img = Image.open(img_path) # 1296x968 + if not os.path.exists(img_path.replace('color', 'color_512512')): + img1 = img.resize([512, 512], Image.Resampling.LANCZOS) + img1.save(img_path.replace('color', 'color_512512')) + if not os.path.exists(img_path.replace('color', 'color_480640')): + img2 = img.resize([640, 480], Image.Resampling.LANCZOS) + img2.save(img_path.replace('color', 'color_480640')) + if resize_depth and not os.path.exists(img_path.replace('color', 'depth_512512')): + p_depth = img_path.replace('color', 'depth').replace('.jpg', '.png') + depth = Image.open(p_depth) + depth1 = depth.resize([512, 512], Image.Resampling.NEAREST) + depth1.save(p_depth.replace('depth', 'depth_512512')) + + +def process_one_scene(scene_path): + dest = scene_path + print('#' * 80) + scene_name = path.split('/')[-1] + print(f'subsampling from {scene_name}...') + 
subsample_scannet_blur_window(dest, min_frames=400) + + print('resizing images...') + os.makedirs(f'{path}/color_512512', exist_ok=True) + os.makedirs(f'{path}/color_480640', exist_ok=True) + os.makedirs(f'{path}/depth_512512', exist_ok=True) + + img_ids = json.load(open(f'{path}/sampled_frames.json', 'r')) + img_paths = [f'{path}/color/{img_id}.jpg' for img_id in img_ids] + resize_files(img_paths, resize_depth=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='scannet preprocessing') + parser.add_argument('--data_root', required=False, default='/home/yuliu/Dataset/scannet', help='file path') + args = parser.parse_args() + + scene_paths = sorted(glob.glob(f'{args.data_root}/*')) + for path in scene_paths: + scene_root = Path(path) + process_one_scene(scene_root) + + test_root = args.data_root.replace('scannet', 'scannet_test') + scene_paths = sorted(glob.glob(f'{test_root}/*')) + for path in scene_paths: + scene_name = path.split('/')[-1] + print(f'resizing images for {scene_name}...') + os.makedirs(f'{path}/color_512512', exist_ok=True) + os.makedirs(f'{path}/color_480640', exist_ok=True) + os.makedirs(f'{path}/depth_512512', exist_ok=True) + + img_paths = glob.glob(f'{path}/color/*.jpg') + resize_files(img_paths, resize_depth=True) + diff --git a/dataset/scannet_dataset.py b/dataset/scannet_dataset.py new file mode 100644 index 0000000..d082a5a --- /dev/null +++ b/dataset/scannet_dataset.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) +from pathlib import Path + +import torch +import numpy as np +from PIL import Image +import hydra +import json +from tqdm import tqdm +from glob import glob +import torch.nn.functional as F +from torch.utils.data import Dataset +from util.camera import compute_world2normscene +from util.misc import SubSampler +from util.ray import get_rays +from dataset.data_utils import get_nearest_pose_ids + + +def move_left_zero(x): + return '0' if int(x) == 0 else x.lstrip('0') + + +class ScannetDataset(Dataset): + def __init__(self, cfg, split): + super().__init__() + self.split = 'test' if split != 'train' else 'train' + self.root_dir = Path(cfg.dataset_root) + if self.split == 'test': + self.root_dir = Path(cfg.dataset_root.replace('scannet', 'scannet_test')) + self.img_size = cfg.img_size[0], cfg.img_size[1] + + self.img_root = f'color_480640' + + self.train_subsample_frames = cfg.train_subsample_frames + self.val_subsample_frames = cfg.val_subsample_frames + self.num_src_view = cfg.num_src_view + self.world2scene = np.eye(4, dtype=np.float32) + self.depth_range = [0.1, 10] + self.max_depth = 10 + self.normalize = cfg.normalize + self.ray_batchsize = cfg.ray_batchsize + self.norm_scene = cfg.norm_scene + self.subsampler = SubSampler() + self.setup_data() + + def __len__(self): + return sum([len(x) for x in self.indices]) + + def setup_data(self): + scenes = sorted(glob(f'{str(self.root_dir)}/*')) + test_scenes = ["scene0000_01", + "scene0079_00", + "scene0158_00", + "scene0316_00", + "scene0521_00", + "scene0553_00", + "scene0616_00", + "scene0653_00", + ] + self.scenes = scenes + print(f"Using {len(self.scenes)} scenes for {self.split}") + scene_names = [x.split('/')[-1] for x in self.scenes] + print(self.split, scene_names) + self.train_indices = [] + self.val_indices = [] + self.all_frame_names = [] + self.cam2normscene = [] + self.cam2scenes = [] + 
self.intrinsics = [] + self.scene2normscene = [] + self.normscene_scale = [] + self.segmentation_data = [] + + for idx, scene_dir in enumerate(self.scenes): + self.setup_data_one_scene(idx) + if self.split == 'train': + self.indices = self.train_indices + else: + self.indices = self.val_indices + self.frame_idx2scene_idx = {} + self.frame_idx2sample_idx = {} + frame_idx = 0 + for scene_idx, scene_indices in enumerate(self.indices): + for _, sample_idx in enumerate(scene_indices): + self.frame_idx2scene_idx[frame_idx] = scene_idx + self.frame_idx2sample_idx[frame_idx] = sample_idx + frame_idx += 1 + if self.norm_scene: + self.depth_range = self.depth_range[0] * min(self.normscene_scale), self.depth_range[1] * max(self.normscene_scale) + print(f"Depth range: {self.depth_range}") + + def setup_data_one_scene(self, scene_idx): + scene_dir = self.scenes[scene_idx] + # split + if self.split == 'train' or self.split == 'val': + frames = json.loads((Path(scene_dir) / "sampled_frames.json").read_text()) + scene_frames = sorted([x.zfill(4) for x in frames]) + scene_frames = [move_left_zero(x) for x in scene_frames] + sample_indices = list(range(len(scene_frames))) + val_indices = sample_indices[::8] + train_indices = [x for x in sample_indices if x not in val_indices] + val_indices = val_indices[::self.val_subsample_frames] + else: + train_frames = np.loadtxt(Path(scene_dir) / "train.txt", dtype=str) + val_frames = np.loadtxt(Path(scene_dir) / "test.txt", dtype=str) + scene_frames = train_frames.tolist() + val_frames.tolist() + scene_frames = [x.split('.')[0] for x in scene_frames] + train_indices = list(range(len(train_frames))) + val_indices = list(range(len(train_frames), len(scene_frames))) + sample_indices = list(range(len(scene_frames))) + + # print(f"Loading {scene_name}") + self.train_indices.append(train_indices) + self.val_indices.append(val_indices) + self.all_frame_names.append(scene_frames) + + dims, cam2scene = [], [] + img_h, img_w = 968, 1296 + intrinsic_color = np.array([[float(y.strip()) for y in x.strip().split()] for x in (Path(scene_dir) / f"intrinsic" / "intrinsic_color.txt").read_text().splitlines() if x != '']) + intrinsic_color = torch.from_numpy(intrinsic_color[:3, :3]).float() + scale_x, scale_y = self.img_size[1] / img_w, self.img_size[0] / img_h + intrinsic_normed = torch.diag(torch.Tensor([scale_x, scale_y, 1])) @ intrinsic_color + intrinsic = torch.eye(4) + intrinsic[:3, :3] = intrinsic_normed + self.intrinsics.append(intrinsic) + + for sample_index in sample_indices: + cam2world = np.array([[float(y.strip()) for y in x.strip().split()] for x in (Path(scene_dir) / f"pose" / f"{scene_frames[sample_index]}.txt").read_text().splitlines() if x != '']) + cam2world = torch.from_numpy(self.world2scene @ cam2world).float() + cam2scene.append(cam2world) + dims.append([img_h, img_w]) + + cam2scene = torch.stack(cam2scene) + intrinsics = intrinsic_color.unsqueeze(0).expand(len(cam2scene), -1, -1) + scene2normscene = compute_world2normscene( + torch.Tensor(dims).float(), + intrinsics, + cam2scene, + max_depth=self.max_depth, + rescale_factor=1.0 + ) + self.scene2normscene.append(scene2normscene) + self.normscene_scale.append(scene2normscene[0, 0]) + cam2normscene = [] + for sample_index in sample_indices: + cam2normscene.append(scene2normscene @ cam2scene[sample_index]) + cam2normscene = torch.stack(cam2normscene) + + self.cam2normscene.append(cam2normscene) + self.cam2scenes.append(cam2scene) + + def load_sample(self, sample_index, scene_idx): + scene_dir = self.scenes[scene_idx] 
+ image = Image.open(Path(scene_dir) / f"{self.img_root}" / f"{self.all_frame_names[scene_idx][sample_index]}.jpg") + # image = image.resize(self.img_size[::-1], Image.BILINEAR) + image = torch.from_numpy(np.array(image) / 255).float() # [H, W, 3] + if self.normalize: + image = image * 2 - 1 # [-1, 1] + return image.view(-1, 3) + + def sample_support_views(self, pose, scene_idx): + # sample support ids + support_indices = self.train_indices[scene_idx] + if self.split == 'train': + subsample_factor = np.random.choice(np.arange(1, 4), p=[0.2, 0.45, 0.35]) + nearest_pose_ids = get_nearest_pose_ids(pose.numpy(), + self.cam2scenes[scene_idx][support_indices].numpy(), + min(self.num_src_view * subsample_factor, 22), + tar_id=0, + angular_dist_method='mix') + nearest_pose_ids = np.random.choice(nearest_pose_ids, self.num_src_view, replace=False) + else: + nearest_pose_ids = get_nearest_pose_ids(pose.numpy(), + self.cam2scenes[scene_idx][support_indices].numpy(), + self.num_src_view, + tar_id=-1, + angular_dist_method='mix') + nearest_pose_ids = torch.from_numpy(nearest_pose_ids).long() + + return nearest_pose_ids + + def __getitem__(self, idx): + scene_idx = self.frame_idx2scene_idx[idx] + sample_idx = self.frame_idx2sample_idx[idx] + sample = {} + + # sample support ids + poses = self.cam2normscene[scene_idx] if self.norm_scene else self.cam2scenes[scene_idx] + pose = self.cam2scenes[scene_idx][sample_idx] + nearest_pose_ids = self.sample_support_views(pose, scene_idx) + support_indices = self.train_indices[scene_idx] + src_rgbs, src_cams = [], [] + for i in nearest_pose_ids: + id = support_indices[i] + src_cam = torch.cat([torch.Tensor(self.img_size), self.intrinsics[scene_idx].flatten(), + poses[id].flatten()]) + src_rgb = self.load_sample(id, scene_idx) + src_rgbs.append(src_rgb) + src_cams.append(src_cam) + src_rgbs = torch.stack(src_rgbs) # [N, H*W, 3] + src_cams = torch.stack(src_cams) # [N, 34] + + rgbs = self.load_sample(sample_idx, scene_idx) + cam = torch.cat([torch.Tensor(self.img_size), self.intrinsics[scene_idx].flatten(), + poses[sample_idx].flatten()]) + tgt_cam = cam[None] # [1, 34] + tgt_rays = get_rays(tgt_cam, *self.img_size).view(-1, 6) # [HW, 3] + H, W = self.img_size + N = self.num_src_view + if self.split == 'train': + # subsample + Br = self.ray_batchsize + subsample_idx = self.subsampler.idx_subsample(self.img_size, Br) + sample['idx'] = subsample_idx # [Br, 1] + tgt_rgbs = rgbs.gather(0, subsample_idx.expand(-1, 3)) # [Br, 3] + tgt_rays = tgt_rays.gather(0, subsample_idx.expand(-1, 6)) # [Br, 3] + else: + tgt_rgbs = rgbs[None]# [1, HW, 3] + tgt_rays = tgt_rays[None] # [1, HW, 3] + semantics = Image.open(Path(self.scenes[scene_idx]) / "semantics" / f"{self.all_frame_names[scene_idx][sample_idx]}.png") + semantics = torch.tensor(np.array(semantics.resize(self.img_size[::-1], Image.Resampling.NEAREST), np.int32)).long().reshape(-1) + instances = Image.open(Path(self.scenes[scene_idx]) / "instance" / f"{self.all_frame_names[scene_idx][sample_idx]}.png") + instances = torch.tensor(np.array(instances.resize(self.img_size[::-1], Image.Resampling.NEAREST), np.int32)).long().reshape(-1) + sample['semantics'] = semantics # [H*W] + sample['instances'] = instances # [H*W] + sample['rgbs'] = tgt_rgbs # [Br, 3] or [N1, HW, 3] + sample['rays'] = tgt_rays # [Br, 3] or [N1, HW, 3] + sample['cam'] = tgt_cam # [Br, 34] or [N1, 34] + sample['src_rgbs'] = src_rgbs.reshape(N, H, W, 3) # [N, HW,3] + sample['src_cams'] = src_cams # [N, 34] 2+16+16 + sample['depth_range'] = 
torch.tensor(self.depth_range).float() + return sample + + + +## test +from torch.utils.data import DataLoader + +@hydra.main(config_path='../config/cfg', config_name='scannet', version_base='1.2') +def main(config): + config.img_size = [480, 640] + config.dino_feats_size = [210, 280] + config.num_workers = 0 + train_set = ScannetDataset(config, "train") + test_set = ScannetDataset(config, "test") + train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers) + test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=config.num_workers) + for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)): + B = config.batch_size + print(batch['rgbs'].shape) + break + for i, batch in tqdm(enumerate(test_loader), total=len(test_loader)): + print(batch['instances'].unique()) + break + +# test +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/dataset/uorf_dataset.py b/dataset/uorf_dataset.py new file mode 100644 index 0000000..92f787e --- /dev/null +++ b/dataset/uorf_dataset.py @@ -0,0 +1,457 @@ +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) + +import torchvision.transforms.functional as TF + +from torch.utils.data import Dataset +from PIL import Image +import hydra +import torch +from glob import glob +import numpy as np +import random +import csv +from util.misc import SubSampler +from util.camera import rotate_cam +from util.ray import get_rays + + +class MultiscenesDataset(Dataset): + def __init__(self, cfg, split='train'): + super().__init__() + self.split = split + self.root = f"{cfg.dataset_root}/{cfg.subset}/{cfg.subset}_test" if split != 'train' else f"{cfg.dataset_root}/{cfg.subset}/{cfg.subset}_train" + self.subset = cfg.subset + self.n_scenes = 5000 + self.load_feat = False + self.feat_map_size = [16, 16] + self.ray_batchsize = cfg.ray_batchsize + self.img_size = cfg.img_size[0], cfg.img_size[1] + self.num_src_view = 1 + self.depth_range = [6, 20] + self.max_depth = 20 + self.render_src_view = cfg.render_src_view + self.load_mask = cfg.load_mask + self.norm_scene = cfg.norm_scene + self.subsampler = SubSampler() + self.setup_data() + + def setup_data(self): + self.scenes = [] + # for root in self.roots: + # self.root = root + file_path = os.path.join(self.root, 'files.csv') + if os.path.exists(file_path): + with open(file_path, newline='') as csvfile: + reader = csv.reader(csvfile) + for row in reader: + self.scenes.append(row) + else: + image_filenames = sorted(glob(os.path.join(self.root, '*.png'))) # root/00000_sc000_az00_el00.png + mask_filenames = sorted(glob(os.path.join(self.root, '*_mask.png'))) + fg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_moving.png'))) + moved_filenames = sorted(glob(os.path.join(self.root, '*_moved.png'))) + bg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_bg.png'))) + bg_in_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_providing_bg.png'))) + changed_filenames = sorted(glob(os.path.join(self.root, '*_changed.png'))) + bg_in_filenames = sorted(glob(os.path.join(self.root, '*_providing_bg.png'))) + changed_filenames_set, bg_in_filenames_set = set(changed_filenames), set(bg_in_filenames) + bg_mask_filenames_set, bg_in_mask_filenames_set = set(bg_mask_filenames), set(bg_in_mask_filenames) + image_filenames_set, mask_filenames_set = set(image_filenames), set(mask_filenames) + fg_mask_filenames_set, 
moved_filenames_set = set(fg_mask_filenames), set(moved_filenames) + filenames_set = image_filenames_set - mask_filenames_set - fg_mask_filenames_set - moved_filenames_set - changed_filenames_set - bg_in_filenames_set - bg_mask_filenames_set - bg_in_mask_filenames_set + filenames = sorted(list(filenames_set)) + scenes = [] + for i in range(self.n_scenes): + scene_filenames = [x for x in filenames if 'sc{:04d}'.format(i) in x] + if len(scene_filenames) > 0: + scenes.append(scene_filenames) + self.scenes += scenes + with open(file_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerows(scenes) + self.n_scenes = len(self.scenes) + + self.intrinsics = [] + self.scene2normscene = [] + self.normscene_scale = [] + self.cam2normscene = [] + self.cam2scene = [] + img_h, img_w = 256, 256 + for i in tqdm(range(self.n_scenes), desc=f"Loading {self.split} scenes"): + cam2scene, dims = [], [] + for path in self.scenes[i]: + pose_path = path.replace('.png', '_RT.txt') + pose = np.loadtxt(pose_path) + cam2scene.append(torch.tensor(pose, dtype=torch.float32)) # 4x4 + dims.append([img_h, img_w]) + cam2scene = torch.stack(cam2scene) # N, 4x4 + self.cam2scene.append(cam2scene) + + intrinsic = self.get_intrinsic(path.replace('.png', '_intrinsics.txt')) + intrinsic_normed = torch.diag(torch.Tensor([self.img_size[1] / img_w, + self.img_size[0] / img_h, 1, 1])) @ intrinsic + self.intrinsics.append(intrinsic_normed) + if self.norm_scene: + nss_scale = 7 # we follow the setting of uorf + scene2normscene = torch.Tensor([ + [1/nss_scale, 0, 0, 0], + [0, 1/nss_scale, 0, 0], + [0, 0, 1/nss_scale, 0], + [0, 0, 0, 1], + ]) + self.scene2normscene.append(scene2normscene) + self.normscene_scale.append(scene2normscene[0, 0]) + cam2normscene = [] + indice = list(range(len(cam2scene))) + for idx in indice: + cam2normscene.append(scene2normscene @ cam2scene[idx]) + cam2normscene = torch.stack(cam2normscene) + self.cam2normscene.append(cam2normscene) + if self.norm_scene: + self.depth_range = self.depth_range[0] * min(self.normscene_scale), self.depth_range[1] * max(self.normscene_scale) + print(f"{self.n_scenes} scenes for {self.split}") + print(f"depth range: {self.depth_range}") + + def _transform(self, img): + img = TF.resize(img, self.img_size) + img = TF.to_tensor(img) + return img + + def get_intrinsic(self, path): + frustum_size = (256, 256) + if not os.path.isfile(path): + focal_ratio = (350. / 320., 350. / 240.) + focal_x = focal_ratio[0] * frustum_size[0] + focal_y = focal_ratio[1] * frustum_size[1] + bias_x = (frustum_size[0] - 1.) / 2. + bias_y = (frustum_size[1] - 1.) / 2. + else: + intrinsics = np.loadtxt(path) + focal_x = intrinsics[0, 0] * frustum_size[0] + focal_y = intrinsics[1, 1] * frustum_size[1] + bias_x = ((intrinsics[0, 2] + 1) * frustum_size[0] - 1.) / 2. + bias_y = ((intrinsics[1, 2] + 1) * frustum_size[1] - 1.) / 2. 
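+        # Assemble a pixel-space 4x4 intrinsic matrix; the focal lengths and principal
+        # point computed above are expressed in the 256x256 frustum used by the
+        # uORF-style renderings (rescaled to img_size later in setup_data).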
+ intrinsic = torch.tensor([[focal_x, 0, bias_x, 0], + [0, focal_y, bias_y, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + return intrinsic.float() + + def _transform_mask(self, img): + img = TF.resize(img, self.img_size, Image.NEAREST) + img = TF.to_tensor(img) + return img + + def load_sample(self, scene_idx, sample_idx): + path = self.scenes[scene_idx][sample_idx] + img = Image.open(path).convert('RGB') + img = self._transform(img).permute([1, 2, 0]).view(-1, 3) # [HW, 3] + + cam2scene = self.cam2normscene[scene_idx][sample_idx] if self.norm_scene else self.cam2scene[scene_idx][sample_idx] + cam = torch.cat([torch.Tensor(self.img_size), self.intrinsics[scene_idx].flatten(), + cam2scene.flatten()]) # 2+16+16=34 + + + return img, cam + + def __getitem__(self, index): + Br = self.ray_batchsize + sample = {} + if self.split == 'train': + scene_idx = index // 4 + sample_idx = index % 4 + filenames = self.scenes[scene_idx] + all_rgbs, all_cams = [], [] + for i, path in enumerate(filenames): + img, cam = self.load_sample(scene_idx, i) + all_rgbs.append(img) + all_cams.append(cam) + all_rgbs = torch.stack(all_rgbs) # [N, H*W, 3] + all_cams = torch.stack(all_cams) # [N, 34] + + tgt_cam = all_cams[0:1] # [1, 34] + all_rays = get_rays(all_cams, *self.img_size) # [N, HW, 6] + if np.random.rand() > 0.25: + # random delete source view + tgt_rgbs = all_rgbs[self.num_src_view:].view(-1, 3) # [N*Br, 3] # [N-1, HW, 3] + tgt_rays = all_rays[self.num_src_view:].view(-1, 6) # [N*Br, 6] + else: + tgt_rgbs = all_rgbs.view(-1, 3) # [N*Br, 3] + tgt_rays = all_rays.view(-1, 6) # [N*Br, 6] + subsample_idx = torch.randperm(tgt_rgbs.shape[0])[:Br][:, None] # [Br] + tgt_rgbs = tgt_rgbs.gather(0, subsample_idx.expand(-1, 3)) # [Br, 3] + tgt_rays = tgt_rays.gather(0, subsample_idx.expand(-1, 6)) # [Br, 6] + + src_rgbs = all_rgbs[0:self.num_src_view] # [N, HW, 3] + src_cams = all_cams[0:self.num_src_view] # [N, 34] + else: + filenames = self.scenes[index] + all_rgbs, all_cams = [], [] + for i, path in enumerate(filenames): + img, cam = self.load_sample(index, i) + all_rgbs.append(img) + all_cams.append(cam) + all_rgbs = torch.stack(all_rgbs) # [N, H*W, 3] + all_cams = torch.stack(all_cams) # [N, 34] + if 'kitchen' in self.subset: # no ground truth mask + masks = torch.randint(4, [all_rgbs.shape[0], self.img_size[0] * self.img_size[1]]) + else: + masks = [] + for path in filenames: + mask_path = path.replace('.png', '_mask.png') + if os.path.isfile(mask_path): + mask = Image.open(mask_path).convert('RGB') + mask_l = mask.convert('L') + mask = self._transform_mask(mask) + # ret['mask'] = mask + mask_l = self._transform_mask(mask_l) + mask_flat = mask_l.flatten(start_dim=0) # HW, + greyscale_dict = mask_flat.unique(sorted=True) # 8, + # make sure the background is always 0 + if self.subset not in ['room_texture', 'kitchen_shiny', 'kitchen_matte']: + bg_color = greyscale_dict[1].clone() + greyscale_dict[1] = greyscale_dict[0] + greyscale_dict[0] = bg_color + onehot_labels = mask_flat[:, None] == greyscale_dict # HWx8, one-hot + onehot_labels = onehot_labels.type(torch.uint8) + mask_idx = onehot_labels.argmax(dim=1).view(-1) # HW + masks.append(mask_idx) + masks = torch.stack(masks) # [N, HW] + sample['instances'] = masks + tgt_rgbs, tgt_cam = all_rgbs, all_cams + tgt_rays = get_rays(tgt_cam, *self.img_size) # [N, HW, 6] + Nv = self.num_src_view + src_rgbs, src_cams = all_rgbs[:Nv], all_cams[:Nv] + sample['rgbs'] = tgt_rgbs # [HW, 3] or [N, HW, 3] + sample['rays'] = tgt_rays # [HW, 6] or [N, HW, 6] + sample['cam'] = tgt_cam # [1, 
34] or [N, 34] + sample['src_rgbs'] = src_rgbs.reshape(-1, *self.img_size, 3) # [N, H, W,3] N = 1 + sample['src_cams'] = src_cams # [N, 34] 2+16+16 + sample['depth_range'] = torch.tensor(self.depth_range).float() + + return sample + + def __len__(self): + return self.n_scenes if self.split != 'train' else self.n_scenes * 4 + + +class VisualDataset(Dataset): + def __init__(self, cfg, split='test'): + super().__init__() + self.split = split + self.subset = cfg.subset + self.root = f"{cfg.dataset_root}/{cfg.subset}/{cfg.subset}_{split}" + self.n_scenes = cfg.n_scenes + self.feat_map_size = [16, 16] + self.ray_batchsize = cfg.ray_batchsize + self.img_size = cfg.img_size[0], cfg.img_size[1] + self.num_src_view = 1 + self.depth_range = [6, 20] + self.max_depth = 20 + self.normalize = cfg.normalize + self.load_mask = cfg.load_mask + self.norm_scene = cfg.norm_scene + self.num_vis = cfg.num_vis + self.setup_data() + + def setup_data(self): + self.scenes = [] + file_path = os.path.join(self.root, 'files.csv') + if os.path.exists(file_path): + with open(file_path, newline='') as csvfile: + reader = csv.reader(csvfile) + for row in reader: + self.scenes.append(row) + else: + image_filenames = sorted(glob(os.path.join(self.root, '*.png'))) # root/00000_sc000_az00_el00.png + mask_filenames = sorted(glob(os.path.join(self.root, '*_mask.png'))) + fg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_moving.png'))) + moved_filenames = sorted(glob(os.path.join(self.root, '*_moved.png'))) + bg_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_bg.png'))) + bg_in_mask_filenames = sorted(glob(os.path.join(self.root, '*_mask_for_providing_bg.png'))) + changed_filenames = sorted(glob(os.path.join(self.root, '*_changed.png'))) + bg_in_filenames = sorted(glob(os.path.join(self.root, '*_providing_bg.png'))) + changed_filenames_set, bg_in_filenames_set = set(changed_filenames), set(bg_in_filenames) + bg_mask_filenames_set, bg_in_mask_filenames_set = set(bg_mask_filenames), set(bg_in_mask_filenames) + image_filenames_set, mask_filenames_set = set(image_filenames), set(mask_filenames) + fg_mask_filenames_set, moved_filenames_set = set(fg_mask_filenames), set(moved_filenames) + filenames_set = image_filenames_set - mask_filenames_set - fg_mask_filenames_set - moved_filenames_set - changed_filenames_set - bg_in_filenames_set - bg_mask_filenames_set - bg_in_mask_filenames_set + filenames = sorted(list(filenames_set)) + scenes = [] + for i in range(self.n_scenes): + scene_filenames = [x for x in filenames if 'sc{:04d}'.format(i) in x] + if len(scene_filenames) > 0: + scenes.append(scene_filenames) + self.scenes = scenes + with open(file_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerows(scenes) + self.n_scenes = len(self.scenes) + + self.intrinsics = [] + self.normscene_scale = [] + self.cam2normscene = [] + self.cam2scene = [] + img_h, img_w = 256, 256 + self.n_scenes = len(self.scenes) + for i in tqdm(range(self.n_scenes), desc=f"Loading {self.split} scenes"): + cam2scene, dims = [], [] + for path in self.scenes[i]: + pose_path = path.replace('.png', '_RT.txt') + pose = np.loadtxt(pose_path) + cam2scene.append(torch.tensor(pose, dtype=torch.float32)) # 4x4 + dims.append([img_h, img_w]) + cam2scene = torch.stack(cam2scene) # N, 4x4 + self.cam2scene.append(cam2scene) + + intrinsic = self.get_intrinsic(path.replace('.png', '_intrinsics.txt')) + intrinsic_normed = torch.diag(torch.Tensor([self.img_size[1] / img_w, + self.img_size[0] / img_h, 1, 1])) @ intrinsic 
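+            # Rescale the native 256x256 intrinsics to the configured img_size:
+            # fx and cx are scaled by the width ratio, fy and cy by the height ratio.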
+ self.intrinsics.append(intrinsic_normed) + if self.norm_scene: + nss_scale = 7 # we follow the setting of uorf + scene2normscene = torch.Tensor([ + [1/nss_scale, 0, 0, 0], + [0, 1/nss_scale, 0, 0], + [0, 0, 1/nss_scale, 0], + [0, 0, 0, 1], + ]) + self.normscene_scale.append(scene2normscene[0, 0]) + cam2normscene = [] + indice = list(range(len(cam2scene))) + for idx in indice: + cam2normscene.append(scene2normscene @ cam2scene[idx]) + cam2normscene = torch.stack(cam2normscene) + self.cam2normscene.append(cam2normscene) + if self.norm_scene: + self.depth_range = self.depth_range[0] * min(self.normscene_scale), self.depth_range[1] * max(self.normscene_scale) + print(f"{self.n_scenes} scenes for visualization") + print(f"depth range: {self.depth_range}") + + def _transform(self, img): + img = TF.resize(img, self.img_size) + img = TF.to_tensor(img) + if self.normalize: + img = TF.normalize(img, [0.5] * img.shape[0], [0.5] * img.shape[0]) # [0,1] -> [-1,1] + return img + + def get_intrinsic(self, path): + frustum_size = (256, 256) + if not os.path.isfile(path): + focal_ratio = (350. / 320., 350. / 240.) + focal_x = focal_ratio[0] * frustum_size[0] + focal_y = focal_ratio[1] * frustum_size[1] + bias_x = (frustum_size[0] - 1.) / 2. + bias_y = (frustum_size[1] - 1.) / 2. + else: + intrinsics = np.loadtxt(path) + focal_x = intrinsics[0, 0] * frustum_size[0] + focal_y = intrinsics[1, 1] * frustum_size[1] + bias_x = ((intrinsics[0, 2] + 1) * frustum_size[0] - 1.) / 2. + bias_y = ((intrinsics[1, 2] + 1) * frustum_size[1] - 1.) / 2. + intrinsic = torch.tensor([[focal_x, 0, bias_x, 0], + [0, focal_y, bias_y, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + return intrinsic.float() + + def _transform_mask(self, img): + img = TF.resize(img, self.img_size, Image.NEAREST) + img = TF.to_tensor(img) + return img + + def __getitem__(self, index): + sample = {} + path = self.scenes[index][0] + img = Image.open(path).convert('RGB') + src_rgbs = self._transform(img).permute([1, 2, 0])[None] # [1, H, W, 3] + cam2scene = self.cam2normscene[index][0] if self.norm_scene else self.cam2scene[index][0] + src_cams = torch.cat([torch.Tensor(self.img_size), self.intrinsics[index].flatten(), + cam2scene.flatten()])[None] + theta_list = np.linspace(0, np.pi * 0.25, self.num_vis, endpoint=False) + zoom_list = np.linspace(1, 0.8, self.num_vis // 2, endpoint=False).tolist() + zoom_list += np.linspace(0.8, 1.05, self.num_vis // 2, endpoint=False).tolist() + # theta_list = np.linspace(0, np.pi * 2, endpoint=False) + tgt_cam2scene = [rotate_cam(cam2scene, theta, zoom) for theta, zoom in zip(theta_list, zoom_list)] + tgt_cams = [torch.cat([torch.Tensor(self.img_size), self.intrinsics[index].flatten(), + tgt_cam2scene[i].flatten()])[None] for i in range(self.num_vis)] + tgt_cams = torch.cat(tgt_cams, dim=0) # [N, 34] + tgt_rays = get_rays(tgt_cams, *self.img_size) # [N, HW, 6] + + if 'kitchen' in self.subset: + mask_idx = torch.randint(4, [self.img_size[0] * self.img_size[1]]) + else: + mask_path = path.replace('.png', '_mask.png') + mask = Image.open(mask_path).convert('RGB') + mask_l = mask.convert('L') + mask = self._transform_mask(mask) + mask_l = self._transform_mask(mask_l) + mask_flat = mask_l.flatten(start_dim=0) # HW, + greyscale_dict = mask_flat.unique(sorted=True) # 8, + # make sure the background is always 0 + if self.subset not in ['room_texture', 'kitchen_shiny', 'kitchen_matte']: + bg_color = greyscale_dict[1].clone() + greyscale_dict[1] = greyscale_dict[0] + greyscale_dict[0] = bg_color + onehot_labels = mask_flat[:, None] == 
greyscale_dict # HWx8, one-hot + onehot_labels = onehot_labels.type(torch.uint8) + mask_idx = onehot_labels.argmax(dim=1).view(-1) # HW + sample['instances'] = mask_idx + sample['rays'] = tgt_rays # [N, HW, 6] + sample['cam'] = tgt_cams # [N, 34] + sample['src_rgbs'] = src_rgbs # [N, H, W,3] + sample['src_cams'] = src_cams # [N, 34] + sample['depth_range'] = torch.tensor(self.depth_range).float() + + return sample + + def __len__(self): + return self.n_scenes + + +# test +from torch.utils.data import DataLoader +from tqdm import tqdm +@hydra.main(config_path='../config/cfg', config_name='uorf', version_base='1.2') +def main(config): + config.img_size = [128, 128] + config.num_workers = 4 + config.subset = 'kitchen_matte' + config.dataset_root = "/home/yuliu/Dataset/uorf" + config.num_src_view = 1 + train_set = MultiscenesDataset(config, 'train') + val_set = MultiscenesDataset(config, 'val') + train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0) + val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0) + # vis_set = MultiscenesVisualDataset(config) + # vis_loader = DataLoader(vis_set, batch_size=1, shuffle=False, num_workers=0) + for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)): + rgbs = batch['rgbs'] + print(rgbs.shape) + cam = batch['cam'] + src_rgbs = batch['src_rgbs'] + src_cams = batch['src_cams'] + if i > 7: + break + for i, batch in tqdm(enumerate(val_loader), total=len(val_loader)): + rgbs = batch['rgbs'] + mask = batch['instances'] + if i > 5: + break + # for i, batch in tqdm(enumerate(vis_loader), total=len(vis_loader)): + # cam = batch['cam'] + # print(cam.shape) + # src_rgbs = batch['src_rgbs'] + # src_cams = batch['src_cams'] + # if i > 5: + # break + + +if __name__ == '__main__': + main() + + \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/encoder.py b/model/encoder.py new file mode 100644 index 0000000..c9ee068 --- /dev/null +++ b/model/encoder.py @@ -0,0 +1,333 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
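+
+# Encoders used by SlotLifter:
+#   * PositionalEncoding / RayEncoder: sinusoidal encodings of 3D positions and ray directions.
+#   * ResUNet: a ResNet-style encoder-decoder that outputs a feature map at roughly 1/4 of the
+#     input resolution (conv1 + three strided stages, then two upsampling stages with skips).
+#   * MultiViewResUNet: concatenates each source view's RGB with its per-pixel ray origins and
+#     directions (3 + 3 + 3 = 9 input channels) and runs the ResUNet; the resulting per-view
+#     feature maps are shared by the slot encoder and by the point-feature projection step.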
+ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from util.ray import get_rays_origin_and_direction + + +class PositionalEncoding(nn.Module): + def __init__(self, num_octaves=8, start_octave=0): + super().__init__() + self.num_octaves = num_octaves + self.start_octave = start_octave + + def forward(self, coords): + batch_size, num_points, dim = coords.shape + + octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves) + octaves = octaves.float().to(coords) + multipliers = 2**octaves * math.pi + coords = coords.unsqueeze(-1) + while len(multipliers.shape) < len(coords.shape): + multipliers = multipliers.unsqueeze(0) + + scaled_coords = coords * multipliers + + sines = torch.sin(scaled_coords).reshape(batch_size, num_points, dim * self.num_octaves) + cosines = torch.cos(scaled_coords).reshape(batch_size, num_points, dim * self.num_octaves) + + result = torch.cat((sines, cosines), -1) + return result + + +class RayEncoder(nn.Module): + def __init__(self, pos_octaves=8, pos_start_octave=0, ray_octaves=4, ray_start_octave=0): + super().__init__() + self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves, start_octave=pos_start_octave) + self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves, start_octave=ray_start_octave) + + def forward(self, pos, rays): + if len(rays.shape) == 4: + batchsize, height, width, dims = rays.shape + pos_enc = self.pos_encoding(pos.unsqueeze(1)) + pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1) + pos_enc = pos_enc.repeat(1, 1, height, width) + rays = rays.flatten(1, 2) + + ray_enc = self.ray_encoding(rays) + ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1]) + ray_enc = ray_enc.permute((0, 3, 1, 2)) + x = torch.cat((pos_enc, ray_enc), 1) + else: + pos_enc = self.pos_encoding(pos).expand(-1, rays.shape[1], -1) + ray_enc = self.ray_encoding(rays) + x = torch.cat((pos_enc, ray_enc), -1) + + return x + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation, padding_mode='reflect') + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, padding_mode='reflect') + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes, track_running_stats=False, affine=True) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes, track_running_stats=False, affine=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) 
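+        # Post-activation residual block: the (optionally downsampled) identity branch
+        # is added before this final ReLU, so `out` is the activated residual sum.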
+ + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width, track_running_stats=False, affine=True) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width, track_running_stats=False, affine=True) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion, track_running_stats=False, affine=True) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class conv(nn.Module): + def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): + super(conv, self).__init__() + self.kernel_size = kernel_size + self.conv = nn.Conv2d(num_in_layers, + num_out_layers, + kernel_size=kernel_size, + stride=stride, + padding=(self.kernel_size - 1) // 2, + padding_mode='reflect') + self.bn = nn.InstanceNorm2d(num_out_layers, track_running_stats=False, affine=True) + + def forward(self, x): + return F.elu(self.bn(self.conv(x)), inplace=True) + + +class upconv(nn.Module): + def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): + super(upconv, self).__init__() + self.scale = scale + self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) + + def forward(self, x): + x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear') + return self.conv(x) + + +class ResUNet(nn.Module): + def __init__(self, + in_ch=3, + encoder='resnet34', + coarse_out_ch=32, + fine_out_ch=32, + norm_layer=None, + coarse_only=False + ): + + super(ResUNet, self).__init__() + assert encoder in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'], "Incorrect encoder type" + if encoder in ['resnet18', 'resnet34']: + filters = [64, 128, 256, 512] + else: + filters = [256, 512, 1024, 2048] + self.coarse_only = coarse_only + if self.coarse_only: + fine_out_ch = 0 + self.coarse_out_ch = coarse_out_ch + self.fine_out_ch = fine_out_ch + out_ch = coarse_out_ch + fine_out_ch + + # original + layers = [3, 4, 6, 3] + if norm_layer is None: + # norm_layer = nn.BatchNorm2d + norm_layer = nn.InstanceNorm2d + self._norm_layer = norm_layer + self.dilation = 1 + block = BasicBlock + replace_stride_with_dilation = [False, False, False] + self.inplanes = 64 + self.groups = 1 + self.base_width = 64 + self.conv1 = nn.Conv2d(in_ch, self.inplanes, 
kernel_size=7, stride=2, padding=3, + bias=False, padding_mode='reflect') + self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + + # decoder + self.upconv3 = upconv(filters[2], 128, 3, 2) + self.iconv3 = conv(filters[1] + 128, 128, 3, 1) + self.upconv2 = upconv(128, 64, 3, 2) + self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1) + + # fine-level conv + self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion, track_running_stats=False, affine=True), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def skipconnect(self, x1, x2): + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2)) + + # for padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + + x = torch.cat([x2, x1], dim=1) + return x + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + + x = self.upconv3(x3) + x = self.skipconnect(x2, x) + x = self.iconv3(x) + + x = self.upconv2(x) + x = self.skipconnect(x1, x) + x = self.iconv2(x) + + x_out = self.out_conv(x) + + if self.coarse_only: + x_coarse = x_out + x_fine = None + else: + x_coarse = x_out[:, :self.coarse_out_ch, :] + x_fine = x_out[:, -self.fine_out_ch:, :] + return x_coarse, x_fine + + +class MultiViewResUNet(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.resunet = ResUNet(in_ch=3 + 3 + 3, coarse_out_ch=cfg.feature_size, coarse_only=True) + self.ray_encoder = RayEncoder(pos_octaves=10, pos_start_octave=0, + ray_octaves=10) + + def forward(self, src_cams, images): + BN, C, H, W = images.shape + rays_o, rays_d = get_rays_origin_and_direction(src_cams.view(BN, -1), H, W) # [B*N, 1, 3], [B*N, HW, 3] + ray_enc = torch.cat([rays_o.repeat(1, H*W, 1), rays_d], dim=-1).view(BN, H, W, 6).permute(0, 3, 1, 2) + x = torch.cat([images, ray_enc], dim=1) + x = self.resunet(x)[0] + return x + \ No newline at end of file diff --git a/model/nerf.py b/model/nerf.py new file mode 100644 index 0000000..a084a5c --- /dev/null +++ b/model/nerf.py @@ -0,0 +1,248 @@ +# MIT License +# +# Copyright (c) 2022 Anpei Chen +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and 
associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import torch +from torch import nn +import numpy as np +import torch.nn.functional as F + +from model.slot_attn import linear +from model.projection import Projector + + +@torch.jit.script +def fused_mean_variance(x, weight): + mean = torch.sum(x*weight, dim=2, keepdim=True) + var = torch.sum(weight * (x - mean)**2, dim=2, keepdim=True) + return mean, var + + +class NeRF(nn.Module): + def __init__(self, cfg=None, n_samples=64): + super().__init__() + slot_dec_dim = cfg.slot_dec_dim + self.nerf_mlp_dim = cfg.nerf_mlp_dim + self.color_mlp = RenderMLP(slot_dec_dim, 3, cfg.pe_view, cfg.pe_feat, self.nerf_mlp_dim, cfg.normalize) + if not cfg.slot_density: + self.density_proj = linear(slot_dec_dim, 1) + + self.projector = Projector() + self.grid_init = cfg.grid_init + self.random_proj_ratio = cfg.random_proj_ratio + self.slot_dec_dim = slot_dec_dim + if self.random_proj_ratio > 0: + if cfg.num_src_view > 1: + self.base_fc = linear(2 * (cfg.feature_size + 3), slot_dec_dim) + else: + self.base_fc = linear(cfg.feature_size + 3, slot_dec_dim) + + self.pos_encoding = self.posenc(slot_dec_dim, n_samples) + + def posenc(self, d_hid, n_samples): + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_samples)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + sinusoid_table = torch.from_numpy(sinusoid_table).float().unsqueeze(0) + return sinusoid_table + + def init_one_svd(self, n_components, grid_resolution, scale): + plane_coef, line_coef = [], [] + for i in range(len(self.vector_mode)): + vec_id = self.vector_mode[i] + mat_id_0, mat_id_1 = self.matrix_mode[i] + plane_coef.append(torch.nn.Parameter(scale * torch.randn((1, n_components[i], grid_resolution[mat_id_1], grid_resolution[mat_id_0])), requires_grad=True)) + line_coef.append(torch.nn.Parameter(scale * torch.randn((1, n_components[i], grid_resolution[vec_id], 1)), requires_grad=True)) + return torch.nn.ParameterList(plane_coef), torch.nn.ParameterList(line_coef) + + def get_coordinate_plane_line(self, xyz_sampled): + coordinate_plane = torch.stack((xyz_sampled[..., self.matrix_mode[0]], xyz_sampled[..., self.matrix_mode[1]], xyz_sampled[..., self.matrix_mode[2]])).detach().view(3, -1, 1, 2) + coordinate_line = torch.stack((xyz_sampled[..., self.vector_mode[0]], xyz_sampled[..., 
self.vector_mode[1]], xyz_sampled[..., self.vector_mode[2]])) + coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) + return coordinate_plane, coordinate_line + + def compute_feature(self, xyz_sampled): + if self.grid_init == 'tensorf': + coordinate_plane, coordinate_line = self.get_coordinate_plane_line(xyz_sampled) + plane_coef_point, line_coef_point = [], [] + for idx_plane in range(len(self.plane_grid)): + plane_coef_point.append(F.grid_sample(self.plane_grid[idx_plane], coordinate_plane[[idx_plane]], align_corners=True).view(-1, *xyz_sampled.shape[:1])) + line_coef_point.append(F.grid_sample(self.line_grid[idx_plane], coordinate_line[[idx_plane]], align_corners=True).view(-1, *xyz_sampled.shape[:1])) + plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point) + return self.basis_mat((plane_coef_point * line_coef_point).T) + elif self.grid_init == '3d': + points_feature = F.grid_sample(self.grid, xyz_sampled.reshape(1, -1, 1, 1, 3), align_corners=True).squeeze().permute(1, 0) + return self.basis_mat(points_feature) + + def project_feature(self, xyz_sampled, cam, src_imgs, src_cams, src_feats, is_Train=False, proj_mask=None): + r""" + Input: + xyz_sampled: (B, Nr, Np, 3) + cam: (B, 27) + src_imgs: (B, Nv, 3, H, W) + src_cams: (B, Nv, 27) + src_feats: (B, Nv, D, H1, W1) + Output: + feature: (B, Nr*Np, D) + """ + B, Nr, Np, _ = xyz_sampled.shape # number of points + xyz_sampled = xyz_sampled.view(B, -1, 3) # (B, Nr*Np, 3) + src_imgs = src_imgs.type_as(src_feats) + points_feature = xyz_sampled.new_zeros(B, Nr*Np, self.slot_dec_dim) + if self.random_proj_ratio > 0: + if self.random_proj_ratio < 1 and is_Train: + if proj_mask is None: + noise = torch.rand([1, Nr, 1, 1], device=xyz_sampled.device) + proj_mask = noise <= self.random_proj_ratio + m = proj_mask.expand(1, Nr, Np, 1).reshape(1, -1, 1) + else: + m = proj_mask.expand(1, Nr, Np, 1).reshape(1, -1, 1) + if torch.sum(m) > 0: + rgb_feat, mask, pixel_locations = self.projector.compute(xyz_sampled[m.expand(B, -1, 3)].view(B, -1, 3), cam, src_imgs, src_cams, src_feats) + Nv = src_imgs.shape[1] + if Nv == 1: + x = self.base_fc(rgb_feat) + points_feature[m.expand(B, -1, self.slot_dec_dim)] = torch.sum(x * mask, dim=2).view(-1) + else: + weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-6) + mean, var = fused_mean_variance(rgb_feat, weight) + x = torch.cat([mean, var], dim=-1).squeeze(2) + points_feature[m.expand(B, -1, self.slot_dec_dim)] = self.base_fc(x).view(-1) + else: + mask = torch.zeros(B, Nr*Np, 1, 1, device=xyz_sampled.device) + else: + rgb_feat, mask, pixel_locations = self.projector.compute(xyz_sampled.view(B, -1, 3), cam, src_imgs, src_cams, src_feats) + Nv = src_imgs.shape[1] + if Nv == 1: + points_feature = torch.sum(self.base_fc(rgb_feat) * mask, dim=2) + else: + weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-6) + mean, var = fused_mean_variance(rgb_feat, weight) + x = torch.cat([mean, var], dim=-1).squeeze(2) + points_feature = self.base_fc(x) + return points_feature, proj_mask + + def compute_density(self, points_feature): + sigma = F.relu(self.density_proj(points_feature).squeeze(-1)) # (B, N) + return sigma + + +class RenderMLP(nn.Module): + + def __init__(self, in_channels, out_channels=3, pe_view=2, pe_feat=2, nerf_mlp_dim=128, normalize=True): + super().__init__() + self.pe_view = pe_view + self.pe_feat = pe_feat + self.output_channels = out_channels + self.view_independent = self.pe_view == 0 and 
self.pe_feat == 0 + self.in_feat_mlp = in_channels + + self.mlp = nn.Sequential( + linear(self.in_feat_mlp, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, out_channels) + ) + self.normalize = normalize + + def forward(self, rays_d, features): + out = self.mlp(features) + + if self.normalize: + out = out.tanh() + else: + out = out.tanh() / 2 + 0.5 + return out + + +class SemanticMLP(nn.Module): + def __init__(self, in_channels, out_channels=3, pe_feat=2, nerf_mlp_dim=128): + super().__init__() + self.pe_feat = pe_feat + self.output_channels = out_channels + self.in_feat_mlp = 2 * pe_feat * in_channels + in_channels + + self.mlp = nn.Sequential( + linear(self.in_feat_mlp, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, out_channels) + ) + + def forward(self, features): + indata = [features] + if self.pe_feat > 0: + indata += [positional_encoding(features, self.pe_feat)] + mlp_in = torch.cat(indata, dim=-1) + out = self.mlp(mlp_in) + return out + + +class InstanceMLP(nn.Module): + def __init__(self, in_channels, out_channels=3, pe_feat=2, nerf_mlp_dim=128): + super().__init__() + self.pe_feat = pe_feat + self.output_channels = out_channels + self.in_feat_mlp = 2 * pe_feat * in_channels + in_channels + + self.mlp = nn.Sequential( + linear(self.in_feat_mlp, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, nerf_mlp_dim, weight_init='kaiming', nonlinearity='leaky_relu'), + nn.LeakyReLU(), + linear(nerf_mlp_dim, out_channels) + ) + + def forward(self, features): + indata = [features] + if self.pe_feat > 0: + indata += [positional_encoding(features, self.pe_feat)] + mlp_in = torch.cat(indata, dim=-1) + out = self.mlp(mlp_in) + return out + + +def positional_encoding(positions, freqs): + freq_bands = (2 ** torch.arange(freqs)).to(positions.device) + pts = (positions[..., None] * freq_bands).reshape(positions.shape[:-1] + (freqs * positions.shape[-1],)) + pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) + return pts + diff --git a/model/projection.py b/model/projection.py new file mode 100644 index 0000000..9d77448 --- /dev/null +++ b/model/projection.py @@ -0,0 +1,103 @@ +# This file is modified from ContraNeRF +import torch +import torch.nn.functional as F + + +class Projector(): + def __init__(self): + pass + + def inbound(self, pixel_locations, h, w): + return (pixel_locations[..., 0] <= w - 1.) & \ + (pixel_locations[..., 0] >= 0) & \ + (pixel_locations[..., 1] <= h - 1.) 
&\ + (pixel_locations[..., 1] >= 0) + + def normalize(self, pixel_locations, h, w): + resize_factor = torch.tensor([w-1., h-1.]).type_as(pixel_locations).to(pixel_locations.device)[None, None, :] + normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1. # [n_views, n_points, 2] + return normalized_pixel_locations + + def compute_projections(self, xyz, src_cams): + B, Nv = src_cams.shape[:2] + src_intrinsics = src_cams[..., 2:18].reshape(B*Nv, 4, 4)[:, :3, :3] # [B*n_views, 3, 3] + src_poses = src_cams[..., -16:].reshape(B*Nv, 4, 4) + xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1) # [B, n_points, 4] + projections = src_intrinsics.bmm( + torch.inverse(src_poses).bmm(xyz_h.transpose(-1, -2).repeat(Nv, 1, 1))[:, :3] + ) # [B*n_views, 3, n_points] + projections = projections.transpose(-2, -1).reshape(B, Nv, -1, 3) # [B, n_views, n_points, 3] + pixel_locations = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8) # [B, n_views, n_points, 2] + pixel_locations = torch.clamp(pixel_locations, min=-1e5, max=1e5) + mask = projections[..., 2] > 0 # a point is invalid if behind the camera + return pixel_locations, mask + + def compute_angle(self, xyz, cam, src_cams): + B, Np, _ = xyz.shape + Nv = src_cams.shape[1] + src_poses = src_cams[..., -16:].reshape(B, Nv, 4, 4) + query_pose = cam[:, -16:].reshape(B, 1, 4, 4).expand(-1, Nv, -1, -1) # [B, n_views, 4, 4] + ray2tar_pose = (query_pose[:, :, :3, 3].unsqueeze(2) - xyz.unsqueeze(1)) # [B, n_views, n_samples, 3] + ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6) + ray2src_pose = (src_poses[:, :, :3, 3].unsqueeze(2) - xyz.unsqueeze(1)) + ray2src_pose = ray2src_pose / (torch.norm(ray2src_pose, dim=-1, keepdim=True) + 1e-6) + ray_diff = ray2tar_pose - ray2src_pose + ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True) + ray_diff_dot = torch.sum(ray2tar_pose * ray2src_pose, dim=-1, keepdim=True) + ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6) + ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) # [B, n_views, n_samples, 4] + return ray_diff + + def compute(self, xyz, cam, src_imgs, src_cams, featmaps): + r""" + Input: + xyz: [B, n_samples, 3] + cam: [B, 1, 34] + src_imgs: [B, n_views, 3, h, w] + src_cams: [B, n_views, 34] + featmaps: [B, n_views, d, h1, w1] + Output: + rgb_feat_sampled: [B, n_samples, n_views, d+3] + ray_diff: [B, n_samples, n_views, 4] + mask: [B, n_samples, n_views, 1] + """ + B, Nv, _, H, W = src_imgs.shape + Np = xyz.shape[1] + # xyz = xyz.reshape(B, 128, 128, 96, 3)[:, 20:80, 30:90].reshape(B, -1, 3) + + # compute the projection of the query points to each reference image + pixel_locations, mask_in_front = self.compute_projections(xyz, src_cams) # [B, n_views, n_samples, 2], [B, n_views, n_samples] + # avoid numerical precision errors + pixel_locations[(pixel_locations < 0) & (pixel_locations > -0.5)] = 0 + pixel_locations[(pixel_locations > H - 1) & (pixel_locations < H - 0.5)] = H - 1 + pixel_locations[(pixel_locations > W - 1) & (pixel_locations < W - 0.5)] = W - 1 + + # # visualize for debug + # import matplotlib.pyplot as plt + + # painted_img = src_imgs.clone().detach().cpu()[0].numpy() + # pixel_locations[(pixel_locations.abs() > 500).any(-1)] = -100 + # for v in range(Nv): + # plt.imshow(painted_img[v, :, :, :].transpose(1, 2, 0)) + # plt.scatter(pixel_locations[0, v, :, 0].cpu().numpy(), pixel_locations[0, v, :, 1].cpu().numpy(), c='r') + # plt.savefig('debug/{}.png'.format(v)) + # plt.close() + 
normalized_pixel_locations = self.normalize(pixel_locations, H, W).reshape(B*Nv, 1, -1, 2) # [B*n_views, 1, n_samples, 2] + + # rgb sampling + src_imgs = src_imgs.flatten(0, 1) # [B*n_views, 3, h, w] + rgbs_sampled = F.grid_sample(src_imgs, normalized_pixel_locations, align_corners=True).view(B, Nv, 3, Np) # [B, n_views, 3, n_samples] + rgbs_sampled = rgbs_sampled.permute(0, 3, 1, 2) # [B, n_samples, n_views, 3] + + # deep feature sampling + featmaps = featmaps.flatten(0, 1) # [B*n_views, d, h1, w1] + feat_sampled = F.grid_sample(featmaps, normalized_pixel_locations, align_corners=True).view(B, Nv, -1, Np) # [B, n_views, d, n_samples] + feat_sampled = feat_sampled.permute(0, 3, 1, 2) # [B, n_samples, n_views, d] + rgb_feat_sampled = torch.cat([rgbs_sampled, feat_sampled], dim=-1) # [B, n_samples, n_views, d+3] + + # mask + inbound = self.inbound(pixel_locations, H, W) + # ray_diff = self.compute_angle(xyz, cam, src_cams) + # ray_diff = ray_diff.permute(0, 2, 1, 3) + mask = (inbound * mask_in_front).float().transpose(1, 2)[..., None] # [B, n_samples, n_views, 1] + return rgb_feat_sampled, mask, pixel_locations \ No newline at end of file diff --git a/model/renderer.py b/model/renderer.py new file mode 100644 index 0000000..ecf76d1 --- /dev/null +++ b/model/renderer.py @@ -0,0 +1,330 @@ +# MIT License +# +# Copyright (c) 2022 Anpei Chen +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import torch +from torch import nn +import torch.nn.functional as F + +from model.nerf import NeRF +from model.slot_attn import JointDecoder, linear + + +@torch.jit.script +def fused_mean_variance(x, weight): + mean = torch.sum(x*weight, dim=2, keepdim=True) + var = torch.sum(weight * (x - mean)**2, dim=2, keepdim=True) + return mean, var + +def positional_encoding(positions, freqs): + freq_bands = (2 ** torch.arange(freqs)).to(positions.device) + pts = (positions[..., None] * freq_bands).reshape(positions.shape[:-1] + (freqs * positions.shape[-1],)) + pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) + return pts + + +class NeRFRenderer(nn.Module): + + def __init__(self, depth_range, cfg): + super().__init__() + self.depth_range = depth_range + self.distance_scale = cfg.get('distance_scale', 25) + self.weight_thres_color = 0.0001 + self.feat_weight_thres = 0.01 + self.alpha_mask_threshold = 0.0075 + self.step_size = None + self.stop_semantic_grad = cfg.stop_semantic_grad + self.nerf_mlp_dim = cfg.nerf_mlp_dim + self.slot_dec_dim = cfg.slot_dec_dim + self.n_samples = cfg.n_samples + self.n_samples_fine = cfg.n_samples_fine + self.num_instances = cfg.max_instances + + self.slot_size = cfg.slot_size + self.inv_uniform = cfg.get('inv_uniform', False) + self.normalize = cfg.normalize + self.pos_proj = linear(120, self.slot_dec_dim) + + def compute_points_feature(self, nerf: NeRF, slot_dec: JointDecoder, + xyz_sampled, slots, dists, + cam, src_imgs, src_cams, src_feats, + shape, is_train, rays_d, proj_mask=None): + B, Nr, Np = shape + points_feature, proj_mask = nerf.project_feature(xyz_sampled.view(B, Nr, Np, 3), cam, src_imgs, src_cams, src_feats, is_train, proj_mask) + + pos_emb = positional_encoding(xyz_sampled.view(B, Nr*Np, 3), 10) + view_emb = positional_encoding(rays_d.view(B, Nr, 3), 10).unsqueeze(2).expand(-1, -1, Np, -1).reshape(B, Nr*Np, -1) + pos_emb = self.pos_proj(torch.cat([pos_emb, view_emb], dim=-1)) + points_feature = points_feature + pos_emb + + points_coor = xyz_sampled.view(B, Nr*Np, 3) + ret = slot_dec(points_feature, pos_emb, slots, points_coor, Nr) + + points_feature, w_slot_dec, sigma_slot = ret['x'], ret['w'], ret.get('sigma', None) + points_feature = points_feature + pos_emb + points_feature = points_feature.view(B*Nr, Np, self.slot_dec_dim) + if sigma_slot is not None: + sigma = sigma_slot.view(B*Nr, Np) + else: + sigma = nerf.compute_density(points_feature).view(B*Nr, Np) # [B*Nr, Np] + alpha, weight, bg_weight = self.raw_to_alpha(sigma, dists * self.distance_scale) + + return points_feature, w_slot_dec, weight, ret.get('sparse_loss', torch.zeros(1, device=points_feature.device)), proj_mask, sigma + + def render_color(self, points_feature, viewdirs, w, appearance_mask, shape, color_mlp, white_bg, is_train): + B, Nr, Np = shape + rgb = points_feature.new_zeros((B*Nr, Np, 3)) + valid_rgbs = color_mlp(viewdirs[appearance_mask], points_feature[appearance_mask]) + rgb[appearance_mask] = valid_rgbs + rgb_map = torch.sum(w * rgb, -2).reshape(B, Nr, 3) + return rgb_map + + def render_instance(self, points_feature, w, appearance_mask, shape, w_slot_dec=None, instance_mlp=None): + B, Nr, Np = shape + if instance_mlp is not None: + instances = points_feature.new_zeros((B*Nr, Np, self.num_instances)) + valid_instances = instance_mlp(points_feature[appearance_mask]) + instances[appearance_mask] = valid_instances + instances = F.softmax(instances, dim=-1) + instance_map = torch.sum(w * instances, -2).reshape(B, Nr, -1) # [B, Nr, K] + instance_map = 
instance_map / (torch.sum(instance_map, -1, keepdim=True) + 1e-8) + else: + instances = w_slot_dec.view(B*Nr, Np, -1)[..., 1:] + instance_map = torch.sum(w * instances, -2).reshape(B, Nr, -1) # [B, Nr, K] + instance_map = instance_map / (torch.sum(instance_map, -1, keepdim=True) + 1e-6) + + return instance_map + + def render(self, nerf: NeRF, slot_dec: JointDecoder, + xyz_sampled, z_vals, viewdirs, slots, + cam, src_imgs, src_cams, src_feats, + white_bg, is_train, + render_color=False, render_ins=False, render_depth=False, rays_d=None, proj_mask=None): + B, Nr, Np, _ = viewdirs.shape + dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) + points_feature, w_slot_dec, weight, sparse_loss, proj_mask, sigma = self.compute_points_feature(nerf, slot_dec, + xyz_sampled, slots, dists, + cam, src_imgs, src_cams, src_feats, + (B, Nr, Np), is_train, rays_d, proj_mask) + appearance_mask = weight > self.weight_thres_color # [B*Nr, Np] + w = weight[..., None] # [B*Nr, Np, 1] + + ret = { + "proj_mask": proj_mask, + "weight": weight.clone().detach(), + "sparse_loss": sparse_loss, + "w": w_slot_dec.view(B, Nr*Np, -1)[..., 1:], + "pts": xyz_sampled.view(B, Nr*Np, 3), + } + + if render_color: + viewdirs = viewdirs.reshape(B*Nr, Np, 3) + rgb_map = self.render_color(points_feature, viewdirs, w, appearance_mask, (B, Nr, Np), nerf.color_mlp, white_bg, is_train) + ret["rgb"] = rgb_map + + if render_depth: + depth_map = torch.sum(weight * z_vals, -1).reshape(B, Nr) + opacity_map = torch.sum(w, -2).reshape(B, Nr) + depth_map = depth_map + (1. - opacity_map) * z_vals.max() + ret["depth"] = depth_map + + if render_ins: + if self.stop_semantic_grad: + w = w.detach() + instance_map = self.render_instance(points_feature, w, appearance_mask, (B, Nr, Np), w_slot_dec) + ret["instance"] = instance_map + + return ret + + def forward(self, nerf_c: NeRF, nerf_f: NeRF, slot_dec: JointDecoder, slot_dec_fine: JointDecoder, + rays, depth_range, slots, + cam, src_imgs, src_cams, src_feats, + white_bg=False, is_train=False, + render_color=False, render_depth=False, + render_ins=False): + r""" + Input: + rays: [B, Nr, 6] + slots: [B, K, D] + white_bg: True or False + is_train: True or False + cam: [B, 27] + src_imgs: [B, Nv, 3, H, W] + src_cams: [B, Nv, 27] + src_feats: [B, Nv, D, H1, W1] + Output: + rgb: [B, Nr, 3] + instance: [B, Nr, K] + depth: [B, Nr] + feats: [B, Nr, D] + """ + B, Nr, _ = rays.shape + # assert B == 1, "Only support batch size 1" + Np = self.n_samples + rays = rays.reshape(-1, 6) + rays_o, rays_d = rays[:, :3], rays[:, 3:] + xyz_sampled, z_vals = sample_points_in_box(rays_o, rays_d, Np, + self.depth_range if self.depth_range is not None else depth_range, is_train, self.inv_uniform) # [B*Nr, n_samples, 3], [B*Nr, n_samples] + viewdirs = rays_d.view(B, Nr, 1, 3).expand(-1, -1, Np, -1) + + ret_c = self.render(nerf_c, slot_dec, xyz_sampled, z_vals, viewdirs, + slots, cam, src_imgs, src_cams, src_feats, + white_bg, is_train, + render_color=render_color, render_ins=render_ins, + render_depth=render_depth, rays_d=rays_d) + weight = ret_c["weight"] + ret = {} + for k, v in ret_c.items(): + if k != "weight" and k != 'proj_mask': + ret[k + "_c"] = v + + if nerf_f is not None: + Np_fine = self.n_samples_fine + xyz_sampled, z_vals = sample_fine_pts(rays_o, rays_d, weight, z_vals, Np_fine, Np, is_train, self.inv_uniform) + viewdirs = rays_d.view(B, Nr, 1, 3).expand(-1, -1, Np + Np_fine, -1) + ret_f = self.render(nerf_f, slot_dec_fine, xyz_sampled, z_vals, viewdirs, + slots, cam, 
src_imgs, src_cams, src_feats, + white_bg, is_train, + render_color=render_color, render_ins=render_ins, + render_depth=render_depth, rays_d=rays_d, proj_mask=ret_c['proj_mask']) + for k,v in ret_f.items(): + if k != "weight" and k != 'proj_mask': + ret[k + "_f"] = v + return ret + + @staticmethod + def raw_to_alpha(sigma, dist): + alpha = 1. - torch.exp(-sigma * dist) + T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1) + weights = alpha * T[:, :-1] + return alpha, weights, T[:, -1:] + + +def sample_points_in_box(rays_o, rays_d, n_samples, depth_range, is_train, inv_uniform=False): + if isinstance(depth_range, tuple): + depth_range = torch.tensor(depth_range).float().to(rays_o.device) + depth_range = depth_range.expand(rays_d.shape[0], -1) + near_depth, far_depth = depth_range[..., 0], depth_range[..., 1] + if inv_uniform: + start = 1.0 / near_depth # [N_rays,] + step = (1.0 / far_depth - start) / (n_samples - 1) + inv_z_vals = torch.stack( + [start + i * step for i in range(n_samples)], dim=1 + ) # [N_rays, n_samples] + z_vals = 1.0 / inv_z_vals + if is_train: + mids = 0.5 * (z_vals[:, 1:] + z_vals[:, :-1]) + upper = torch.cat([mids, z_vals[:, -1:]], dim=-1) + lower = torch.cat([z_vals[:, 0:1], mids], dim=-1) + # uniform samples in those intervals + t_rand = torch.rand_like(z_vals) + z_vals = lower + (upper - lower) * t_rand # [N_rays, Np] + else: + step_size = (depth_range[..., 1:2] - depth_range[..., 0:1]) / (n_samples - 1) + rng = torch.arange(n_samples)[None].type_as(rays_o).expand(rays_o.shape[:-1] + (n_samples,)) + if is_train: + rng = rng + torch.rand_like(rng[:, [0]]).type_as(rng) + step = step_size * rng.to(rays_o.device) + z_vals = (depth_range[..., 0:1] + step) # [B*Nr, n_samples] + + rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None] # [B*Nr, n_samples, 3] + return rays_pts, z_vals + + +def sample_pdf(bins, weights, Np, det=False): + r''' + :param bins: tensor of shape [N_rays, M+1], M is the number of bins + :param weights: tensor of shape [N_rays, M] + :param Np: number of samples along each ray + :param det: if True, will perform deterministic sampling + :return: [N_rays, Np] + ''' + + M = weights.shape[1] + weights += 1e-5 + # Get pdf + pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # [N_rays, M] + cdf = torch.cumsum(pdf, dim=-1) # [N_rays, M] + cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1) # [N_rays, M+1] + + # Take uniform samples + if det: + u = torch.linspace(0.0, 1.0, Np, device=bins.device) + u = u.unsqueeze(0).repeat(bins.shape[0], 1) # [N_rays, Np] + else: + u = torch.rand(bins.shape[0], Np, device=bins.device) + # Invert CDF + above_inds = torch.zeros_like(u, dtype=torch.long) # [N_rays, Np] + for i in range(M): + above_inds += (u >= cdf[:, i:i+1]).long() + + # random sample inside each bin + below_inds = torch.clamp(above_inds-1, min=0) + inds_g = torch.stack((below_inds, above_inds), dim=2) # [N_rays, Np, 2] + + cdf = cdf.unsqueeze(1).repeat(1, Np, 1) # [N_rays, Np, M+1] + cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g) # [N_rays, Np, 2] + + bins = bins.unsqueeze(1).repeat(1, Np, 1) # [N_rays, Np, M+1] + bins_g = torch.gather(input=bins, dim=-1, index=inds_g) # [N_rays, Np, 2] + + # t = (u-cdf_g[:, :, 0]) / (cdf_g[:, :, 1] - cdf_g[:, :, 0] + TINY_NUMBER) # [N_rays, Np] + # fix numeric issue + denom = cdf_g[:, :, 1] - cdf_g[:, :, 0] # [N_rays, Np] + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[:, :, 0]) / denom + + 
samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1]-bins_g[:, :, 0]) + + return samples + +def sample_fine_pts(rays_o, rays_d, weights, z_vals, N_importance, Np, det=True, inv_uniform=False): + if inv_uniform: + inv_z_vals = 1.0 / z_vals + inv_z_vals_mid = 0.5 * (inv_z_vals[:, 1:] + inv_z_vals[:, :-1]) # [N_rays, Np-1] + weights = weights[:, 1:-1] # [N_rays, Np-2] + inv_z_vals = sample_pdf( + bins=torch.flip(inv_z_vals_mid, dims=[1]), + weights=torch.flip(weights, dims=[1]), + Np=N_importance, + det=det, + ) # [N_rays, N_importance] + z_samples = 1.0 / inv_z_vals + else: + # take mid-points of depth samples + z_vals_mid = 0.5 * (z_vals[:, 1:] + z_vals[:, :-1]) # [N_rays, Np-1] + weights = weights[:, 1:-1] # [N_rays, Np-2] + z_samples = sample_pdf( + bins=z_vals_mid, weights=weights, Np=N_importance, det=det + ) # [N_rays, N_importance] + + z_vals = torch.cat((z_vals, z_samples), dim=-1) # [N_rays, Np + N_importance] + + # samples are sorted with increasing depth + z_vals, _ = torch.sort(z_vals, dim=-1) + N_total_samples = Np + N_importance + + viewdirs = rays_d.unsqueeze(1).repeat(1, N_total_samples, 1) + ray_o = rays_o.unsqueeze(1).repeat(1, N_total_samples, 1) + pts = z_vals.unsqueeze(2) * viewdirs + ray_o # [N_rays, Np + N_importance, 3] + return pts, z_vals diff --git a/model/slot_attn.py b/model/slot_attn.py new file mode 100644 index 0000000..ba1c329 --- /dev/null +++ b/model/slot_attn.py @@ -0,0 +1,237 @@ +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath + +from model.transformer import TransformerDecoder +from model.encoder import MultiViewResUNet + + +def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1., nonlinearity='relu'): + + m = nn.Linear(in_features, out_features, bias) + + if weight_init == 'kaiming': + nn.init.kaiming_uniform_(m.weight, nonlinearity=nonlinearity) + else: + nn.init.xavier_uniform_(m.weight, gain) + + if bias: + nn.init.zeros_(m.bias) + + return m + + +def gru_cell(input_size, hidden_size, bias=True): + + m = nn.GRUCell(input_size, hidden_size, bias) + + nn.init.xavier_uniform_(m.weight_ih) + nn.init.orthogonal_(m.weight_hh) + + if bias: + nn.init.zeros_(m.bias_ih) + nn.init.zeros_(m.bias_hh) + + return m + + +class SlotAttention(nn.Module): + def __init__( + self, + feature_size, + slot_size, + drop_path=0.2, + num_head=1, + ): + super().__init__() + self.slot_size = slot_size + self.epsilon = 1.0 + self.num_head = num_head + + self.norm_feature = nn.LayerNorm(feature_size) + self.norm_mlp = nn.LayerNorm(slot_size) + self.norm_slots = nn.LayerNorm(slot_size) + + self.project_q = linear(slot_size, slot_size, bias=False) + self.project_k = linear(feature_size, slot_size, bias=False) + self.project_v = linear(feature_size, slot_size, bias=False) + + self.gru = gru_cell(slot_size, slot_size) + + self.mlp = nn.Sequential( + linear(slot_size, slot_size * 4, weight_init='kaiming'), + nn.ReLU(), + linear(slot_size * 4, slot_size), + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + def forward(self, features, slots_init, num_src_view, num_iter=3): + # features: [batch_size, num_feature, inputs_size] + features = self.norm_feature(features) + k = self.project_k(features) # Shape: [B, num_features, slot_size] + v = self.project_v(features) + + slots = slots_init + # Multiple rounds of attention. 
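+        # Only the final iteration below receives gradients: the earlier refinements are
+        # detached and re-attached to slots_init via slots.detach() + slots_init - slots_init.detach(),
+        # a truncated-backprop trick commonly used to stabilize slot-attention training while
+        # still updating the learned slot initialization.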
+ for i in range(num_iter - 1): + slots, attn = self.iter(slots, k, v, num_src_view=num_src_view) + slots = slots.detach() + slots_init - slots_init.detach() + slots, attn = self.iter(slots, k, v, num_src_view=num_src_view) + return slots, attn + + def iter(self, slots, k, v, num_src_view): + B, K, D = slots.shape + slots_prev = slots + slots = self.norm_slots(slots) + q = self.project_q(slots) + + Nh = self.num_head + q = q.reshape(B, K, Nh, D//Nh).transpose(1, 2) # [B, Nh, K, D//Nh] + k = k.reshape(B, -1, Nh, D//Nh).transpose(1, 2) # [B, Nh, Nf, D//Nh] + + # Attention + scale = (D//Nh) ** -0.5 + attn_logits = torch.matmul(q, k.transpose(-1, -2)) * scale # [B, Nh, K, Nf] + attn_logits = attn_logits.mean(1) # [B, K, Nf] + attn = F.softmax(attn_logits, dim=1) + + # # Weighted mean + attn_sum = torch.sum(attn, dim=-1, keepdim=True) + self.epsilon + attn_wm = attn / attn_sum + updates = torch.einsum('bij, bjd->bid', attn_wm, v) + + # Update slots + slots = self.gru( + updates.reshape(-1, D), + slots_prev.reshape(-1, D) + ) + slots = slots.reshape(B, -1, D) + slots = slots + self.drop_path(self.mlp(self.norm_mlp(slots))) + return slots, attn + + +class SlotEnc(nn.Module): + def __init__( + self, num_iter, num_slots, feature_size, + slot_size, drop_path=0.2, num_blocks=1): + super().__init__() + + self.num_iter = num_iter + self.num_slots = num_slots + self.slot_size = slot_size + self.slot_attn = self.slot_attn = nn.ModuleList([ + SlotAttention(feature_size, slot_size, drop_path=drop_path) for i in range(num_blocks) + ]) + self.num_blocks = num_blocks + self.slots_init = nn.Parameter(torch.zeros(1, num_slots, slot_size)) + nn.init.xavier_uniform_(self.slots_init) + + def forward(self, f, sigma, num_src_view): + B, _, D = f.shape + # initialize slots. + mu = self.slots_init.expand(B, -1, -1) + z = torch.randn_like(mu).type_as(f) + slots = mu + z * sigma * mu.detach() + slots, attn = self.slot_attn[0](f, slots, num_iter=self.num_iter, num_src_view=num_src_view) + for i in range(self.num_blocks - 1): + slots, attn = self.slot_attn[i + 1](f, slots, num_iter=self.num_iter, num_src_view=num_src_view) + return slots, attn + + +class Slot3D(nn.Module): + def __init__(self, config): + super().__init__() + self.feature_size = config.feature_size + self.multi_view_enc = MultiViewResUNet(config) + self.slot_enc = SlotEnc( + num_iter=config.num_iter, + num_slots=config.num_slots, + feature_size=config.feature_size, + slot_size=config.slot_size, + drop_path=config.drop_path, + ) + self.num_slots = config.num_slots + + def forward(self, src_cams=None, images=None, sigma=0): + """ + Input: + images: [batch_size, N_views, 3, H, W] + src_cams: [batch_size, N_views, 25] + Output: + slots: [batch_size, num_slots, slot_size] + """ + B, N_view = src_cams.shape[:2] + H, W = images.shape[3:] + features = self.multi_view_enc(src_cams, images.reshape(B*N_view, 3, H, W)) # [B*N_views, D, H1, W1] + + H1, W1 = features.shape[-2:] + features = features.permute(0, 2, 3, 1) # [B*N_views, H1, W1, D] + + feats = features.reshape(B, -1, self.feature_size) + attn = torch.zeros(B, self.num_slots, feats.shape[1], device=feats.device) + slots, attn = self.slot_enc(feats, sigma=sigma, num_src_view=N_view) # [B, K, slot_size] + return slots, attn, features.reshape(B, N_view, H1, W1, features.shape[-1]) + + +class JointDecoder(nn.Module): + def __init__(self, cfg): + super().__init__() + d_kv = cfg.slot_size + self.num_slots = cfg.num_slots + self.tf = TransformerDecoder( + cfg.num_dec_blocks, cfg.slot_dec_dim, d_kv, 4, cfg.drop_path, 
0) + self.empty_slot = nn.Parameter(torch.zeros(1, 1, cfg.slot_size)) + nn.init.xavier_uniform_(self.empty_slot) + + self.Wq = linear(cfg.slot_dec_dim, cfg.slot_dec_dim) + self.Wk = linear(cfg.slot_size, cfg.slot_dec_dim) + self.out_proj = linear(cfg.slot_size, cfg.slot_dec_dim) + self.force_bg = False + self.bg_bound = cfg.bg_bound + self.scale = cfg.slot_dec_dim ** -0.5 + self.slot_density = cfg.get('slot_density', False) + if self.slot_density: + self.density_scale = nn.Parameter(torch.zeros(1)) + + + def forward(self, point_feats, points_emb, slots, points_coor, Nr=0): + r""" + Input: + point_feats: [B, N, D] N = N_ray * N_points + slots: [B, K, slot_size] + Output: + x: [B, N, D] + w: [B, N, K] + """ + slots = torch.cat([self.empty_slot.expand(slots.shape[0], -1, -1), slots], dim=1) # [B, K+1, D] + x = self.tf(point_feats, slots, points_emb, Nr) # [B, N, D] + + # point slot mapping + q = self.Wq(x) + k = self.Wk(slots) + logits = torch.matmul(q, k.transpose(-1, -2)) * self.scale # [B, N, K+1] + if self.force_bg: + out_idx = (points_coor.abs() > self.bg_bound).any(-1)[..., None].repeat(1, 1, self.num_slots+1) # [B, N] + out_idx[:, :, 0:2] = False + logits[out_idx] = -torch.inf + w = F.softmax(logits, dim=-1) # [B, N, K+1] + x = torch.matmul(w, slots) # [B, N, D] + x = self.out_proj(x) + if self.slot_density: + slot_sigma = F.relu(logits) + sigma = (slot_sigma[..., 1:] * w[..., 1:]).sum(-1) + sigma = sigma * self.density_scale.exp() # [B, N] + else: + sigma = None + + return {'x': x, 'w': w, 'sigma': sigma} + + diff --git a/model/transformer.py b/model/transformer.py new file mode 100644 index 0000000..634dfcd --- /dev/null +++ b/model/transformer.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1., nonlinearity='relu'): + + m = nn.Linear(in_features, out_features, bias) + + if weight_init == 'kaiming': + nn.init.kaiming_uniform_(m.weight, nonlinearity=nonlinearity) + else: + nn.init.xavier_uniform_(m.weight, gain) + + if bias: + nn.init.zeros_(m.bias) + + return m + + +class MultiHeadAttention(nn.Module): + + def __init__(self, d_q, d_kv, d_proj, d_out, num_heads, dropout=0., gain=1., bias=True): + super().__init__() + + assert d_proj % num_heads == 0, "d_proj must be divisible by num_heads" + self.num_heads = num_heads + + self.attn_dropout = nn.Dropout(dropout) + self.output_dropout = nn.Dropout(dropout) + + self.proj_q = linear(d_q, d_proj, bias=bias) + self.proj_k = linear(d_kv, d_proj, bias=bias) + self.proj_v = linear(d_kv, d_proj, bias=bias) + self.proj_o = linear(d_proj, d_out, bias=bias, gain=gain) + + + def forward(self, q, k, v, attn_mask=None): + B, T, _ = q.shape + _, S, _ = k.shape + + q = self.proj_q(q).view(B, T, self.num_heads, -1).transpose(1, 2) + k = self.proj_k(k).view(B, S, self.num_heads, -1).transpose(1, 2) + v = self.proj_v(v).view(B, S, self.num_heads, -1).transpose(1, 2) + + q = q * (q.shape[-1] ** (-0.5)) + attn = torch.matmul(q, k.transpose(-1, -2)) + + if attn_mask is not None: + attn = attn.masked_fill(attn_mask, float('-inf')) + + attn = F.softmax(attn, dim=-1) + attn_d = self.attn_dropout(attn) + + output = torch.matmul(attn_d, v).transpose(1, 2).reshape(B, T, -1) + output = self.proj_o(output) + output = self.output_dropout(output) + return output + + +class TransformerDecoderBlock(nn.Module): + + def __init__(self, d_q, d_kv, num_heads=4, drop_path=0., dropout=0., gain=1., is_first=False): + super().__init__() + + self.cross_attn 
= False + if d_kv > 0: + self.encoder_decoder_attn_layer_norm = nn.LayerNorm(d_q) + self.encoder_decoder_attn = MultiHeadAttention(d_q, d_kv, d_q, d_q, num_heads, dropout, gain) + self.cross_attn = True + + self.ffn = nn.Sequential( + nn.LayerNorm(d_q), + linear(d_q, 4 * d_q, weight_init='kaiming'), + nn.ReLU(), + linear(4 * d_q, d_q, gain=gain), + nn.Dropout(dropout)) + + self.drop_path1 = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path2 = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity() + + self.self_attn_layer_norm = nn.LayerNorm(d_q) + self.self_attn = MultiHeadAttention(d_q, d_q, d_q // 4, d_q, num_heads, dropout, gain) + self.conv = nn.Sequential( + nn.Conv1d(d_q, d_q, kernel_size=3, padding=1), + nn.ReLU(), + ) + + def forward(self, input, encoder_output, N): + """ + input: batch_size x target_len x d_model + encoder_output: batch_size x source_len x d_model + return: batch_size x target_len x d_model + """ + B, L, D = input.shape + if self.cross_attn: + x = self.encoder_decoder_attn_layer_norm(input) + x = self.encoder_decoder_attn(x, encoder_output, encoder_output) + x = self.conv(x.reshape(B*N, -1, D).transpose(1, 2)).transpose(1, 2).reshape(B, L, D) + input = input + self.drop_path1(x) + x = self.ffn(x) + input = input + self.drop_path2(x) + x = self.self_attn_layer_norm(input) + x = x.reshape(B*N, -1, D) + x = self.self_attn(x, x, x) + x = input + x.reshape(B, -1, D) + return x + + +class TransformerDecoder(nn.Module): + + def __init__(self, num_blocks, d_q, d_kv, num_heads, drop_path, dropout=0.): + super().__init__() + dpr = [x.item() for x in torch.linspace(0, drop_path, num_blocks)] # stochastic depth decay rule + gain = (3 * num_blocks) ** (-0.5) + self.blocks = nn.ModuleList( + [TransformerDecoderBlock(d_q, d_kv, num_heads, dpr[0], dropout, gain, is_first=True)] + + [TransformerDecoderBlock(d_q, d_kv, num_heads, dpr[i+1], dropout, gain, is_first=False) + for i in range(num_blocks - 1)]) + self.layer_norm = nn.LayerNorm(d_q) + + def forward(self, input, encoder_output, pos_emb, Nr): + """ + input: batch_size x target_len x d_model + encoder_output: batch_size x source_len x d_model + return: batch_size x target_len x d_model + """ + for i, block in enumerate(self.blocks): + input = block(input, encoder_output, Nr) + return self.layer_norm(input) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8dbdab0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +tqdm +wandb +seaborn +scikit-learn==1.3.0 +pytorch-lightning==1.8.6 +tabulate==0.9.0 +hydra-core==1.3.1 +opencv-python==4.6.0.66 +timm==0.6.13 +einops==0.3.0 +transforms3d==0.4.1 +piq==0.8.0 +lpips==0.1.4 diff --git a/scripts/eval_dtu.sh b/scripts/eval_dtu.sh new file mode 100644 index 0000000..d325f75 --- /dev/null +++ b/scripts/eval_dtu.sh @@ -0,0 +1,84 @@ +export CUDA_VISIBLE_DEVICES=$1 +conda activate slotlifter +dataset=dtu +ids=(1) +# ids=(6 7) +# ids=(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16) +for scene in ${ids[@]};do + python trainer/train_panopli_tensorf.py \ + cfg=${dataset} \ + cfg.job_type='test' \ + cfg.exp_name=vis_${scene} \ + cfg.test_percent=1.0 \ + cfg.num_workers=16 \ + cfg.precision=32 \ + cfg.num_slots=8 \ + cfg.enc_type='ibrnet' \ + cfg.num_src_view=4 \ + cfg.ckpt_path=/home/yuliu/Projects/OR3D/runs/dtu/f32_8s_64+64_01282054/checkpoints/best.ckpt \ + # cfg.scene_id=${scene} \ + +done + # cfg.n_samples=96 \ + # cfg.logger=wandb \ + # cfg.coarse_to_fine=true \ + # 
cfg.resume='/home/yuliu/Projects/OR3D/runs/dtu/test_01271617/checkpoints/last.ckpt' \ + # cfg.seed=0 \ + # cfg.random_proj=false \ + # cfg.coarse_to_fine=true \ + # cfg.n_samples=32 \ + # cfg.n_samples_fine=64 \ + # cfg.random_proj_ratio=0 \ + # cfg.slot_density=false \ + # cfg.random_proj=false \ + # cfg.norm_scene=false \ + # cfg.distance_scale=1. \ + # cfg.lambda_depth=0 \ + # cfg.lambda_geo=0.1 \ + # cfg.n_samples=32 \ + # cfg.n_samples_fine=64 \ + # cfg.implicit_render=true \ + # cfg.lambda_depth=0.1 \ + # cfg.normalize=false \ + # cfg.random_proj=false \ + # cfg.rel_pos=true \ + # cfg.rel_pos=true \ + # cfg.feature_size=32 \ + # cfg.normalize=false \ + # cfg.n_samples=64 \ + # cfg.coarse_to_fine=false \ + # cfg.pe_feat=2 \ + # cfg.distance_scale=25 \ + # cfg.train_dataset='llff+ibrnet_collected' \ + # cfg.lambda_depth=0.5 \ + # cfg.sigma_steps=30000 \ + # cfg.resume='/home/yuliu/Projects/OR3D/runs/scannet/8s_512_3090_09042226/checkpoints/last.ckpt' \ + # cfg.lambda_corr=0.2 \ + # cfg.lambda_feat=0 \ + # cfg.recon_feat=true \ + # cfg.random_proj=false \ + # cfg.force_bg=false \ + # cfg.lambda_feat=0 \ + # cfg.recon_feat=true \ + # cfg.lambda_corr=0.02 \ + # cfg.n_samples=32 \ + # cfg.n_samples_fine=64 \ + # cfg.slot_dec_dim=128 \ + # cfg.chunk=4096 \ + # cfg.num_src_view=4 \ + # cfg.num_slots=8 \ + # cfg.bg_border=0.2 \ + # cfg.slot_dec_dim=32 \ + # cfg.nerf_mlp_dim=32 \ + # cfg.num_src_view=10 \ + # cfg.chunk=4096 \ + # cfg.feature_size=64 \ + # cfg.recon_feat=false \ + # cfg.lambda_depth=0 \ + # cfg.num_slots=8 \ + # cfg.recon_feat=true \ + # cfg.render_src_view=false \ + # cfg.num_src_view=4 \ + # cfg.random_proj=false \ + # cfg.chunk=16384 \ + \ No newline at end of file diff --git a/scripts/eval_scannet.sh b/scripts/eval_scannet.sh new file mode 100644 index 0000000..96ce1e9 --- /dev/null +++ b/scripts/eval_scannet.sh @@ -0,0 +1,16 @@ +export CUDA_VISIBLE_DEVICES=$1 +conda activate slotlifter +dataset=scannet +seed=0 +python trainer/train.py \ + cfg=${dataset} \ + cfg.job_type='test' \ + cfg.exp_name="eval_scannet" \ + cfg.num_workers=8 \ + cfg.chunk=16384 \ + cfg.test_percent=1.0 \ + cfg.ckpt_path=checkpoints/scannet.ckpt \ + cfg.num_slots=8 \ + cfg.val_subsample_frames=1 \ + cfg.seed=${seed} \ + # cfg.logger=wandb \ diff --git a/scripts/eval_uorf_data.sh b/scripts/eval_uorf_data.sh new file mode 100644 index 0000000..7c90a8a --- /dev/null +++ b/scripts/eval_uorf_data.sh @@ -0,0 +1,25 @@ +export CUDA_VISIBLE_DEVICES=$1 +conda activate slotlifter +dataset=uorf +subset=clevr_567 # change the num_slots to 8 for clevr_567 +subset=room_chair +# subset=room_diverse +# subset=room_texture +# subset=kitchen_matte +# subset=kitchen_shiny +seed=0 +python trainer/train.py \ + cfg=${dataset} \ + cfg.job_type='test' \ + cfg.exp_name="eval_${subset}_seed1" \ + cfg.group=${subset} \ + cfg.subset=${subset} \ + cfg.num_workers=8 \ + cfg.chunk=16384 \ + cfg.test_percent=1.0 \ + cfg.ckpt_path=checkpoints/${subset}/seed1.ckpt \ + cfg.num_slots=5 \ + cfg.val_subsample_frames=1 \ + cfg.seed=${seed} \ + cfg.logger=wandb \ + diff --git a/scripts/train_dtu.sh b/scripts/train_dtu.sh new file mode 100644 index 0000000..b5e5143 --- /dev/null +++ b/scripts/train_dtu.sh @@ -0,0 +1,17 @@ +export CUDA_VISIBLE_DEVICES=6 +conda activate slotlifter +dataset=dtu +seed=0 +python trainer/train.py \ + cfg=${dataset} \ + cfg.job_type='train' \ + cfg.exp_name="${dataset}_${seed}" \ + cfg.batch_size=1 \ + cfg.ray_batchsize=1024 \ + cfg.val_check_interval=4 \ + cfg.num_workers=16 \ + cfg.num_slots=8 \ + 
cfg.num_src_view=1 \ + cfg.seed=${seed} \ + # cfg.logger=wandb \ + diff --git a/scripts/train_scannet.sh b/scripts/train_scannet.sh new file mode 100644 index 0000000..93f91b0 --- /dev/null +++ b/scripts/train_scannet.sh @@ -0,0 +1,17 @@ +export CUDA_VISIBLE_DEVICES=7 +conda activate slotlifter +dataset=scannet +seed=0 +python trainer/train.py \ + cfg=${dataset} \ + cfg.job_type='train' \ + cfg.exp_name="${dataset}_${seed}" \ + cfg.batch_size=1 \ + cfg.ray_batchsize=1024 \ + cfg.val_check_interval=3 \ + cfg.num_workers=16 \ + cfg.num_slots=8 \ + cfg.num_src_view=4 \ + cfg.seed=${seed} \ + # cfg.logger=wandb \ + \ No newline at end of file diff --git a/scripts/train_uorf_data.sh b/scripts/train_uorf_data.sh new file mode 100644 index 0000000..fa169f3 --- /dev/null +++ b/scripts/train_uorf_data.sh @@ -0,0 +1,24 @@ +export CUDA_VISIBLE_DEVICES=6,7 +conda activate slotlifter +dataset=uorf +subset=clevr_567 # change the num_slots to 8 for clevr_567 +# subset=room_chair +# subset=room_diverse +# subset=room_texture +# subset=kitchen_matte +# subset=kitchen_shiny +seed=0 +python trainer/train.py \ + cfg=${dataset} \ + cfg.job_type='train' \ + cfg.exp_name="${subset}_${seed}" \ + cfg.group=${subset} \ + cfg.subset=${subset} \ + cfg.num_workers=8 \ + cfg.batch_size=2 \ + cfg.ray_batchsize=1024 \ + cfg.val_check_interval=10 \ + cfg.num_slots=8 \ + cfg.seed=${seed} \ + cfg.logger=wandb \ + # cfg.monitor=psnr \ # using for kitchen_shiny and kitchen_matte because these datasets do not have ground truth masks \ No newline at end of file diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000..14aedae --- /dev/null +++ b/trainer/__init__.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import os +from pathlib import Path +from random import randint +import datetime + +import torch +import wandb +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +from pytorch_lightning.loggers import WandbLogger + +from util.distinct_colors import DistinctColors +from util.misc import visualize_depth, get_boundary_mask +from util.filesystem_logger import FilesystemLogger +import torch.nn.functional as F + + +def generate_experiment_name(config): + if config.resume is not None and config.job_type == 'train': + experiment = Path(config.resume).parents[1].name + # experiment = f"{config.exp_name}_{datetime.datetime.now().strftime('%m%d%H%M')}" + os.environ['experiment'] = experiment + elif not os.environ.get('experiment'): + experiment = f"{config.exp_name}_{datetime.datetime.now().strftime('%m%d%H%M')}" + os.environ['experiment'] = experiment + else: + experiment = os.environ['experiment'] + return experiment + + +def create_trainer(config): + + config.exp_name = generate_experiment_name(config) + if config.val_check_interval > 1: + config.val_check_interval = int(config.val_check_interval) + if config.seed is None: + config.seed = randint(0, 999) + + seed_everything(config.seed, workers=True) + + # save code files + filesystem_logger = FilesystemLogger(config) + + if config.logger == 'wandb': + logger = WandbLogger( + project=config.project, + entity=config.entity, + group=config.group, + name=config.exp_name, + job_type=config.job_type, + tags=config.tags, + notes=config.notes, + id=config.exp_name, + # settings=wandb.Settings(start_method='thread'), + ) + else: + logger = False + + checkpoint_callback = ModelCheckpoint(dirpath=(Path(config.log_path) / config.exp_name / 
"checkpoints"), + monitor=f'val/{config.monitor}', + save_top_k=1, + save_last=True, + mode='max', + ) + callbacks = [LearningRateMonitor("step"), checkpoint_callback] if logger else [] + gpu_count = torch.cuda.device_count() + if config.job_type == 'debug': + config.train_percent = 30 + config.val_percent = 1 + config.test_percent = 1 + config.val_check_interval = 1 + + kwargs = { + 'resume_from_checkpoint': config.resume, + 'logger': logger, + 'accelerator': 'gpu', + 'devices': gpu_count, + 'strategy': 'ddp' if gpu_count > 1 else None, + 'num_sanity_val_steps': 1, + 'max_steps': config.max_steps, + 'max_epochs': config.max_epochs, + 'limit_train_batches': config.train_percent, + 'limit_val_batches': config.val_percent, + 'limit_test_batches': config.test_percent, + 'val_check_interval': float(min(config.val_check_interval, 1)), + 'check_val_every_n_epoch': max(1, config.val_check_interval), + 'callbacks': callbacks, + 'gradient_clip_val': config.grad_clip if config.grad_clip > 0 else None, + 'precision': config.precision, + 'profiler': config.profiler, + 'benchmark': config.benchmark, + 'deterministic': config.deterministic, + } + trainer = Trainer(**kwargs) + return trainer + + +def visualize_panoptic_outputs(p_rgb, p_instances, p_depth, rgb, semantics, instances, H, W, depth=None, p_semantics=None): + alpha = 0.65 + distinct_colors = DistinctColors() + p_rgb = p_rgb.cpu() + img = p_rgb.view(H, W, 3).cpu().permute(2, 0, 1) + + depth_scale = 1 + if p_depth is not None: + p_depth = visualize_depth(p_depth.view(H, W) * depth_scale, use_global_norm=False) + else: + p_depth = torch.zeros_like(img) + if depth is not None: + depth = visualize_depth(depth.view(H, W) * depth_scale, use_global_norm=False) + else: + depth = torch.zeros_like(img) + + def get_color(p_instances, idx_bg, im, rgb_): + colored_img_instance = distinct_colors.apply_colors_fast_torch(p_instances).float() + # boundaries_img_instances = get_boundary_mask(p_instances.view(H, W)) + # colored_img_instance[p_instances == idx_bg, :] = rgb_[p_instances == idx_bg, :] + img_instances = colored_img_instance.view(H, W, 3).permute(2, 0, 1) * alpha + im * (1 - alpha) + # img_instances[:, boundaries_img_instances > 0] = 0 + return img_instances + + img_gt = rgb.view(H, W, 3).permute(2, 0, 1) + idx_bg = p_instances.sum(0).argmax().item() + p_instances = p_instances.argmax(dim=1).cpu() + img_instances = get_color(p_instances, idx_bg, img, p_rgb) + + if semantics is not None and semantics.max() > 0: + img_semantics_gt = distinct_colors.apply_colors_fast_torch(semantics).view(H, W, 3).permute(2, 0, 1) * alpha + img_gt * (1 - alpha) + boundaries_img_semantics_gt = get_boundary_mask(semantics.view(H, W)) + img_semantics_gt[:, boundaries_img_semantics_gt > 0] = 0 + else: + img_semantics_gt = torch.zeros_like(img_gt) + if p_semantics is not None and p_semantics.max() > 0: + p_semantics = p_semantics.argmax(dim=1).cpu() + img_semantics = distinct_colors.apply_colors_fast_torch(p_semantics).view(H, W, 3).permute(2, 0, 1) * alpha + img * (1 - alpha) + boundaries_img_semantics = get_boundary_mask(p_semantics.view(H, W)) + img_semantics[:, boundaries_img_semantics > 0] = 0 + else: + img_semantics = torch.zeros_like(img_gt) + if instances is not None and instances.max() > 0: + img_instances_gt = get_color(instances.long(), 0, img_gt, rgb) + else: + img_instances_gt = torch.zeros_like(img_gt) + stack = torch.cat([torch.stack([img_gt, img_semantics_gt, img_instances_gt, depth]), torch.stack([img, img_semantics, img_instances, p_depth])], dim=0) + return 
stack \ No newline at end of file diff --git a/trainer/train.py b/trainer/train.py new file mode 100644 index 0000000..b1fbc9e --- /dev/null +++ b/trainer/train.py @@ -0,0 +1,492 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +import math +from pathlib import Path +import torch +import hydra +import torch.nn.functional as F +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from tabulate import tabulate +from torchvision.utils import save_image, make_grid +from torch.utils.data import DataLoader +from dataset import get_dataset +from model.nerf import NeRF +from model.renderer import NeRFRenderer +from trainer import create_trainer, visualize_panoptic_outputs +from util.misc import visualize_depth +from util.metrics import SegMetrics, ReconMetrics +from model.slot_attn import Slot3D +from torch import optim +from util.optimizer import Lion +from model.slot_attn import JointDecoder +from PIL import Image +import numpy as np +import seaborn as sns + + +def segmentation_to_rgb(seg, palette=None, num_objects=None, bg_color=(0, 0, 0)): + seg = seg[..., None] + if num_objects is None: + num_objects = np.max(seg) # assume consecutive numbering + num_objects += 1 # background + if palette is None: + # palette = [bg_color] + sns.color_palette('hls', num_objects-1) + palette = sns.color_palette('hls', num_objects) + + seg_img = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.float32) + for i in range(num_objects): + seg_img[np.nonzero(seg[:, :, 0] == i)] = palette[i] + return seg_img + + +def save_img_from_tensor(img, path): + r''' + img: tensor, [H, W, 3] + ''' + img = img.cpu().numpy() + img = (img * 255).astype(np.uint8) + img = Image.fromarray(img) + img.save(path) + + +def save_seg_from_tensor(seg, path): + r''' + seg: tensor, [H, W] + ''' + seg = seg.cpu().numpy() + seg = segmentation_to_rgb(seg) + seg = (seg * 255).astype(np.uint8) + seg = Image.fromarray(seg) + seg.save(path) + + +class TensoRFTrainer(pl.LightningModule): + + def __init__(self, config): + super().__init__() + self.train_set, self.val_set, self.test_set = get_dataset(config) + self.train_sampler = None + if config.visualized_indices is None: + if config.job_type == 'vis': + config.visualized_indices = torch.arange(len(self.val_set)).tolist() + else: + config.visualized_indices = torch.randperm(len(self.val_set))[:4].tolist() + if config.job_type == 'debug': + config.visualized_indices = config.visualized_indices[0:1] + self.instance_steps = config.instance_steps + self.save_hyperparameters(config) + + self.slot_3D = Slot3D(config) + self.slot_dec = JointDecoder(config) + self.slot_dec_fine = JointDecoder(config) if config.coarse_to_fine else None + self.nerf = NeRF(config, n_samples=config.n_samples) + self.nerf_fine = NeRF(config, n_samples=config.n_samples+config.n_samples_fine) if config.coarse_to_fine else None + depth_range = min(self.train_set.depth_range[0], self.val_set.depth_range[0], self.test_set.depth_range[0]), \ + max(self.train_set.depth_range[1], self.val_set.depth_range[1], self.test_set.depth_range[1]) + self.renderer = NeRFRenderer(depth_range, cfg=config) + + self.loss = torch.nn.MSELoss(reduction='mean') + + self.cfg = config + self.output_dir_result_images = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/images') + self.output_dir_result_images.mkdir(exist_ok=True) + 
self.output_dir_result_seg = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/seg') + self.output_dir_result_seg.mkdir(exist_ok=True) + self.output_dir_result_depth = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/depth') + self.output_dir_result_depth.mkdir(exist_ok=True) + self.output_dir_result_attn = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/attn') + self.output_dir_result_attn.mkdir(exist_ok=True) + self.sigma = 1.0 + + self.seg_metrics = SegMetrics(['ari', 'ari_fg']) + self.recon_metrics = ReconMetrics(config.lpips_net) + self.recon_rgb = config.get('recon_rgb', True) + + + def configure_optimizers(self): + warmup_steps = self.cfg.warmup_steps + min_lr_factor = self.cfg.min_lr_factor + decay_steps = self.cfg.decay_steps + def lr_warmup_exp_decay(step: int): + factor = min(1, step / (warmup_steps + 1e-6)) + decay_factor = 0.5 ** (step / decay_steps * 1.5) + return factor * decay_factor * (1 - min_lr_factor) + min_lr_factor + def lr_exp_decay(step: int): + decay_factor = 0.5 ** (step / decay_steps) + return decay_factor * (1 - min_lr_factor) + min_lr_factor + params_nerf = [{'params': self.nerf.parameters(), + 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}] + params_renderer = [{'params': self.renderer.parameters(), + 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}] + params_slot_enc = [{'params': (x[1] for x in self.slot_3D.named_parameters() if 'dino' not in x[0]), + 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}] + params_slot_dec = [{'params': self.slot_dec.parameters(), + 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}] + params = params_nerf + params_renderer + params_slot_enc + params_slot_dec + lr_lambda_list = [lr_exp_decay, lr_exp_decay, lr_warmup_exp_decay, lr_warmup_exp_decay] + + if self.cfg.grid_init == 'tensorf' or self.cfg.grid_init == '3d': + params_grid = [{'params': [x[1] for x in self.renderer.named_parameters() if 'grid' in x[0]], + 'lr': self.cfg.lr * 20, 'weight_decay': self.cfg.weight_decay}] + params = params + params_grid + lr_lambda_list = lr_lambda_list + [lr_exp_decay] + + if self.cfg.coarse_to_fine: + params_nerf_fine = [{'params': self.nerf_fine.parameters(), + 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}] + params_slot_dec_fine = [{'params': self.slot_dec_fine.parameters(), + 'lr': self.cfg.lr, 'weight_decay': self.cfg.weight_decay}] + params = params + params_nerf_fine + params_slot_dec_fine + lr_lambda_list = lr_lambda_list + [lr_exp_decay] + [lr_warmup_exp_decay] + + opt = Lion(params, weight_decay=self.cfg.weight_decay) + scheduler = optim.lr_scheduler.LambdaLR(optimizer=opt, lr_lambda=lr_lambda_list) + + return [opt], [{"scheduler": scheduler, "interval": "step"}] + + def forward(self, rays, depth_range, slots, view_feats, cam, src_rgbs, src_cams, is_train): + B, Nr, _ = rays.shape + outputs = [] + render_depth = not is_train + render_ins = self.global_step >= self.instance_steps or not is_train + render_color = self.recon_rgb or not is_train + for i in range(0, Nr, self.cfg.chunk): + outputs.append(self.renderer(self.nerf, self.nerf_fine, self.slot_dec, self.slot_dec_fine, + rays[:, i: i + self.cfg.chunk], depth_range, slots, + cam, src_rgbs, src_cams, view_feats, + False, is_train, + render_color=render_color, render_depth=render_depth, + render_ins=render_ins)) + keys = outputs[0].keys() + out = {} + for k in keys: + if 'dist' in k or 'loss' in k: + out[k] = torch.stack([o[k] for o in outputs], 0).mean() + else: + out[k] = torch.cat([o[k] for o in outputs], 1).flatten(0, 1) + return out + 
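The `training_step` below anneals the slot-initialization noise `sigma` (and, when `cfg.random_proj` is set, the random-projection ratio) with the `cosine_anneal` helper defined at the end of this class. A minimal, self-contained sketch of that schedule (the 10,000-step horizon here is illustrative; the trainer passes `cfg.sigma_steps`):

```python
import math

def cosine_anneal(step, final_step, start_step=0, start_value=1.0, final_value=0.1):
    # Mirrors TensoRFTrainer.cosine_anneal: hold start_value before start_step,
    # cosine-decay to final_value over [start_step, final_step], then hold final_value.
    if start_value <= final_value or start_step >= final_step:
        return final_value
    if step < start_step:
        return start_value
    if step >= final_step:
        return final_value
    a = 0.5 * (start_value - final_value)
    b = 0.5 * (start_value + final_value)
    progress = (step - start_step) / (final_step - start_step)
    return a * math.cos(math.pi * progress) + b

# In the trainer: sigma = cosine_anneal(global_step, cfg.sigma_steps, final_value=0)
for step in (0, 5_000, 10_000, 20_000):  # 10_000 stands in for cfg.sigma_steps
    print(step, cosine_anneal(step, 10_000, final_value=0))
# 0 -> 1.0, 5000 -> ~0.5, 10000 -> 0, 20000 -> 0
```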
+ def training_step(self, batch, batch_idx): + self.sigma = self.cosine_anneal(self.global_step, self.cfg.sigma_steps, final_value=0) + if self.cfg.random_proj: + ratio = self.cosine_anneal(self.global_step, self.cfg.random_proj_steps, start_value=0.99, final_value=0) + self.nerf.random_proj_ratio = 1 - ratio + self.log('ratio', 1-ratio) + self.log('sigma', self.sigma) + src_rgbs, src_cams = batch['src_rgbs'], batch['src_cams'] # [B, N, H, W, 3], [B, N, 34] + B, N_views = src_rgbs.shape[:2] + + src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W] + slots, attn, view_feats = self.slot_3D(sigma=self.sigma, images=src_rgbs, src_cams=src_cams) + + rgbs = batch['rgbs'] # [B, Br, 3] + rays = batch['rays'] # [B, Br, 6] + depth_range = batch['depth_range'] # [B, 2] + view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W] + output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, True) + + loss_rgb = self.loss(output['rgb_c'], rgbs.view(-1, 3)) + if self.cfg.coarse_to_fine: + loss_rgb = (loss_rgb + self.loss(output['rgb_f'], rgbs.view(-1, 3))) / 2 + if self.cfg.normalize: + loss_rgb = loss_rgb / 4 + loss = loss_rgb + self.log("train/loss_rgb", loss_rgb, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True) + return loss + + def on_validation_epoch_start(self): + torch.cuda.empty_cache() + self.recon_metrics.set_divice(self.device) + + def validation_step(self, batch, batch_idx): + out_put = {} + + src_rgbs, src_cams = batch['src_rgbs'], batch['src_cams'] # [B, N, H, W, 3], [B, N, 34] + + B, N_views, H, W, _ = src_rgbs.shape + src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W] + slots, attn, view_feats = self.slot_3D(sigma=0, images=src_rgbs, src_cams=src_cams) + + # get rays from cameras + N = batch['cam'].shape[1] + rays = batch['rays'].view(1, -1, 6) # [1, Br, 6] + depth_range = batch['depth_range'] # [B, 2] + view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W] + output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, False) + + if self.cfg.coarse_to_fine: + output_rgb = output['rgb_f'] + output_instances = output['instance_f'] + else: + output_rgb = output['rgb_c'] + output_instances = output['instance_c'] + + shape = (N, H, W, 3) + rgbs = batch['rgbs'].view(N, H*W, -1) # [1, N*H*W, 3] + if self.cfg.normalize: + output_rgb = output_rgb * 0.5 + 0.5 + rgbs = rgbs * 0.5 + 0.5 + if self.cfg.dataset == 'dtu' or self.cfg.dataset == 'ibrnet': + recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2)) + out_put.update(recon_metrics) + else: + rs_instances = batch['instances'].view(N, -1) # [N, H*W] + if self.cfg.dataset != 'scannet' and self.cfg.dataset != 'oct': + Nv = self.cfg.num_src_view + recon_metrics = self.recon_metrics(output_rgb.view(shape)[Nv:].permute(0, 3, 1, 2), rgbs.view(shape)[Nv:].permute(0, 3, 1, 2)) + seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[:Nv], rs_instances[:Nv]) + out_put.update(recon_metrics) + out_put.update(seg_metrics) + # src_metircs = self.recon_metrics(output_rgb.view(shape)[:Nv].permute(0, 3, 1, 2), rgbs.view(shape)[:Nv].permute(0, 3, 1, 2)) + # for key, value in src_metircs.items(): + # out_put['src_' + key] = value + nv_seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[Nv:], rs_instances[Nv:]) + for key, value in nv_seg_metrics.items(): + out_put['nv_' + key] = value + else: + recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), 
rgbs.view(shape).permute(0, 3, 1, 2)) + K = output_instances.shape[-1] + seg_metrics = self.seg_metrics(output_instances.reshape(N, -1, K), rs_instances) + out_put.update(recon_metrics) + out_put.update(seg_metrics) + + return out_put + + def validation_epoch_end(self, outputs): + keys = outputs[0].keys() + logs = {} + for k in keys: + v = torch.stack([x[k] for x in outputs]).mean() + logs['val/' + k] = v + self.log_dict(logs, sync_dist=True) + table = [keys, ] + table.append(tuple([logs['val/' + key] for key in table[0]])) + print(tabulate(table, headers='firstrow', tablefmt='fancy_grid')) + self.visualize(self.val_dataloader()) + + @rank_zero_only + def visualize(self, dataloader): + (self.output_dir_result_seg / f"{self.global_step:06d}").mkdir(exist_ok=True) + for batch_idx, batch in enumerate(dataloader): + if batch_idx in self.cfg.visualized_indices: + cam = batch['cam'].reshape(1, -1, 34).to(self.device) # [1, N, 34] + rays = batch['rays'].reshape(1, -1, 6).to(self.device) # [1, NHW, 6] + NHW = rays.shape[1] + instances = batch.get('instances', torch.zeros(NHW)) + rgbs = batch.get('rgbs', torch.zeros([NHW, 3])) + depth = batch.get('depth', torch.zeros_like(instances)) + semantics = batch.get('semantics', torch.zeros_like(instances)) + src_rgbs, src_cams = batch['src_rgbs'].to(self.device), batch['src_cams'].to(self.device) + B, N_views, H, W, _ = src_rgbs.shape + src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W] + slots, attn, view_feats = self.slot_3D(sigma=self.sigma, images=src_rgbs, src_cams=src_cams) + view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W] + depth_range = batch['depth_range'].to(self.device) # [B, 2] + output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, False) + if self.cfg.coarse_to_fine: + output_rgb = output['rgb_f'] + output_instances = output['instance_f'] + output_depth = output['depth_f'] + else: + output_rgb = output['rgb_c'] + output_instances = output['instance_c'] + output_depth = output['depth_c'] + if self.cfg.normalize: + output_rgb = output_rgb * 0.5 + 0.5 + rgbs = rgbs * 0.5 + 0.5 + src_rgbs = src_rgbs * 0.5 + 0.5 + images = src_rgbs[0].reshape(-1, 3, H, W).cpu() # [N, 3, H, W] + N = cam.shape[1] + Nv = self.cfg.num_src_view + n_pad = 4 - Nv % 4 if Nv % 4 != 0 else 0 + for frame_id in range(N): + stack = visualize_panoptic_outputs(output_rgb.view(N, H*W, -1)[frame_id], output_instances.view(N, H*W, -1)[frame_id], + output_depth.view(N, H*W)[frame_id], rgbs.view(N, H*W, -1)[frame_id], + semantics.view(N, H*W)[frame_id], instances.view(N, H*W)[frame_id], + H, W, depth.view(N, H*W)[frame_id]) + stack = torch.cat([images, torch.zeros(n_pad, 3, H, W), stack], dim=0) + if self.cfg.logger == 'wandb': + self.logger.log_image(key=f"images/{batch_idx:04d}_{frame_id:04d}", images=[make_grid(stack, value_range=(0, 1), nrow=4, normalize=True)]) + save_image(stack, self.output_dir_result_images / f"{self.global_step:06d}_{batch_idx:04d}_{frame_id:04d}.jpg", value_range=(0, 1), nrow=4, normalize=True) + if self.cfg.dataset == 'dtu': + H1, W1 = 76, 100 + else: + H1, W1 = H // 4, W // 4 + attn = attn.reshape(-1, src_rgbs.shape[1], H1, W1) + attn = F.interpolate(attn, size=(H, W), mode='nearest') + attn = attn.permute(1, 0, 2, 3).unsqueeze(2).repeat(1, 1, 3, 1, 1) + K = attn.shape[1] + img = torch.cat([images.unsqueeze(1).cpu(), 1 - attn.cpu()], dim=1).reshape(-1, 3, H, W) + if self.cfg.logger == 'wandb': + self.logger.log_image(key=f"attn/{batch_idx:04d}", images=[make_grid(img, value_range=(0, 1), nrow=K+1, 
normalize=True)]) + + def on_test_epoch_start(self): + torch.cuda.empty_cache() + self.recon_metrics.set_divice(self.device) + + def test_step(self, batch, batch_idx): + out_put = {} + + src_rgbs, src_cams = batch['src_rgbs'], batch['src_cams'] # [B, N, H, W, 3], [B, N, 34] + B, Nv, H, W, _ = src_rgbs.shape + src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W] + slots, attn, view_feats = self.slot_3D( sigma=0, images=src_rgbs, src_cams=src_cams) + + images = src_rgbs[0].reshape(-1, 3, H, W).cpu() # [N_src_view, 3, H, W] + H1, W1 = H // 4, W // 4 + if self.cfg.dataset == 'dtu': + H1 = 76 + attn = attn.reshape(-1, src_rgbs.shape[1], H1, W1) + attn = F.interpolate(attn, size=(H, W), mode='nearest') + attn = attn.permute(1, 0, 2, 3).unsqueeze(2).repeat(1, 1, 3, 1, 1) + img = torch.cat([images.unsqueeze(1).cpu(), 1 - attn.cpu()], dim=1).reshape(-1, 3, H, W) + img = make_grid(img, nrow=Nv+1, padding=0) # [3, (N_src_view+1)*H, W] + save_img_from_tensor(img.permute(1, 2, 0), self.output_dir_result_attn / f"{batch_idx:04d}_attn.png") + + # get rays from cameras + cam = batch['cam'] # [1, N, 34] + N = cam.shape[1] + rays = batch['rays'].view(1, -1, 6) # [1, Br, 6] + depth_range = batch['depth_range'] # [B, 2] + + view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W] + output = self(rays, depth_range, slots, view_feats, None, src_rgbs, src_cams, False) + if self.cfg.coarse_to_fine: + output_rgb = output['rgb_f'] + output_instances = output['instance_f'] + output_depth = output['depth_f'] + else: + output_rgb = output['rgb_c'] + output_instances = output['instance_c'] + output_depth = output['depth_c'] + + shape = (N, H, W, 3) + rgbs = batch['rgbs'].view(N, H*W, -1) # [1, N*H*W, 3] + if self.cfg.normalize: + output_rgb = output_rgb * 0.5 + 0.5 + rgbs = rgbs * 0.5 + 0.5 + if self.cfg.dataset == 'dtu': + recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2)) + out_put.update(recon_metrics) + else: + rs_instances = batch['instances'].view(N, -1) # [N, H*W] + if self.cfg.dataset != 'scannet': + Nv = self.cfg.num_src_view + recon_metrics = self.recon_metrics(output_rgb.view(shape)[Nv:].permute(0, 3, 1, 2), rgbs.view(shape)[Nv:].permute(0, 3, 1, 2)) + seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[:Nv], rs_instances[:Nv]) + out_put.update(recon_metrics) + out_put.update(seg_metrics) + src_metircs = self.recon_metrics(output_rgb.view(shape)[:Nv].permute(0, 3, 1, 2), rgbs.view(shape)[:Nv].permute(0, 3, 1, 2)) + for key, value in src_metircs.items(): + out_put['src_' + key] = value + nv_seg_metrics = self.seg_metrics(output_instances.view(N, H*W, -1)[Nv:], rs_instances[Nv:]) + for key, value in nv_seg_metrics.items(): + out_put['nv_' + key] = value + else: + recon_metrics = self.recon_metrics(output_rgb.view(shape).permute(0, 3, 1, 2), rgbs.view(shape).permute(0, 3, 1, 2)) + K = output_instances.shape[-1] + seg_metrics = self.seg_metrics(output_instances.reshape(N, -1, K), rs_instances) + out_put.update(recon_metrics) + out_put.update(seg_metrics) + + print(f'batch_idx: {batch_idx}') + for k, v in recon_metrics.items(): + print(k, ': ', v.item()) + for k, v in seg_metrics.items(): + print(k, ': ', v.item()) + print('-' * 40) + # save img + imgs_gt = rgbs.view(shape) + imgs_pred = output_rgb.view(shape) + # seg_gt = rs_instances.view(N, H, W) + seg_pred = output_instances.argmax(-1).view(N, H, W) + depth_pred = output_depth.view(N, H, W) + for n in range(N): + save_img_from_tensor(imgs_gt[n], 
self.output_dir_result_images / f"{batch_idx:04d}_{n:02d}_rgb_gt.png") + save_img_from_tensor(imgs_pred[n], self.output_dir_result_images / f"{batch_idx:04d}_{n:02d}_rgb_pred.png") + # save_seg_from_tensor(seg_gt[n], self.output_dir_result_seg / f"{batch_idx:04d}_{n:02d}_seg_gt.png") + save_seg_from_tensor(seg_pred[n], self.output_dir_result_seg / f"{batch_idx:04d}_{n:02d}_seg_pred.png") + depth = visualize_depth(depth_pred[n], use_global_norm=False) # [3, H, W] + save_img_from_tensor(depth.permute(1, 2, 0), self.output_dir_result_depth / f"{batch_idx:04d}_{n:02d}_depth_pred.png") + return out_put + + def test_epoch_end(self, outputs): + keys = outputs[0].keys() + logs = {} + for k in keys: + v = torch.stack([x[k] for x in outputs]).mean() + logs['test/' + k] = v + self.log_dict(logs, sync_dist=True) + table = [keys, ] + table.append(tuple([logs['test/' + key] for key in table[0]])) + print(tabulate(table, headers='firstrow', tablefmt='fancy_grid')) + self.visualize(self.test_dataloader()) + + def train_dataloader(self): + shuffle = False if self.cfg.job_type == 'debug' else True + shuffle = shuffle and self.train_sampler is None + persistent_workers = self.cfg.num_workers > 0 + return DataLoader(self.train_set, self.cfg.batch_size, shuffle=shuffle, pin_memory=True, num_workers=self.cfg.num_workers, sampler=self.train_sampler, persistent_workers=persistent_workers) + + def val_dataloader(self): + return DataLoader(self.val_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers) + + def test_dataloader(self): + return DataLoader(self.test_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers) + + def on_train_epoch_start(self): + torch.cuda.empty_cache() + self.slot_dec.force_bg = self.cfg.force_bg and self.global_step < self.cfg.force_bg_steps + if self.cfg.coarse_to_fine: + print(f'Using {self.renderer.n_samples} points for coarse rendering, {self.renderer.n_samples_fine} points for fine rendering') + else: + print(f'Using {self.renderer.n_samples} points for rendering') + + def cosine_anneal(self, step, final_step, start_step=0, start_value=1.0, final_value=0.1): + if start_value <= final_value or start_step >= final_step: + return final_value + if step < start_step: + value = start_value + elif step >= final_step: + value = final_value + else: + a = 0.5 * (start_value - final_value) + b = 0.5 * (start_value + final_value) + progress = (step - start_step) / (final_step - start_step) + value = a * math.cos(math.pi * progress) + b + return value + + +@hydra.main(config_path='../config', config_name='config', version_base='1.2') +def main(config): + config = config.cfg + trainer = create_trainer(config) + model = TensoRFTrainer(config) + if trainer.logger is not None and config.watch_model: + trainer.logger.watch(model) + if config.job_type == 'test': + ckpt = torch.load(config.ckpt_path) + model.load_state_dict(ckpt['state_dict']) + print(f'Load from checkpoint: {config.ckpt_path}') + trainer.test(model) + elif config.job_type == 'vis': + ckpt = torch.load(config.ckpt_path) + model.load_state_dict(ckpt['state_dict']) + print(f'Load from checkpoint: {config.ckpt_path}') + model.eval() + with torch.no_grad(): + model.visualize(model.val_dataloader()) + else: + trainer.fit(model) + trainer.test(model) + + +if __name__ == '__main__': + main() diff --git a/trainer/visualize.py b/trainer/visualize.py new file mode 100644 index 0000000..e87f1bb --- /dev/null +++ b/trainer/visualize.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, 
Inc. All Rights Reserved +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +from tqdm import tqdm +from pathlib import Path +import torch +import hydra +import torch.nn.functional as F +import pytorch_lightning as pl +from torchvision.utils import make_grid +from torch.utils.data import DataLoader +from dataset import get_dataset +from model.nerf import NeRF +from model.renderer import NeRFRenderer +from model.slot_attn import Slot3D +from model.slot_attn import SlotMixerDecoder +from util.misc import visualize_depth +from PIL import Image +import numpy as np +import seaborn as sns + + +def segmentation_to_rgb(seg, palette=None, num_objects=None, bg_color=(0, 0, 0)): + seg = seg[..., None] + if num_objects is None: + num_objects = np.max(seg) # assume consecutive numbering + num_objects += 1 # background + if palette is None: + # palette = [bg_color] + sns.color_palette('hls', num_objects-1) + palette = sns.color_palette('hls', num_objects) + + seg_img = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.float32) + for i in range(num_objects): + seg_img[np.nonzero(seg[:, :, 0] == i)] = palette[i] + return seg_img + + +def save_img_from_tensor(img, path, transform=False): + r''' + img: tensor, [H, W, 3] + ''' + img = img.cpu().numpy() * 255 + img = img.astype(np.uint8) + # if transform: + # brighten the image + # img = cv2.convertScaleAbs(img, alpha=1.8, beta=1) + img = Image.fromarray(img) + img.save(path) + + +def save_seg_from_tensor(seg, path): + r''' + seg: tensor, [H, W] + ''' + seg = seg.cpu().numpy() + seg = segmentation_to_rgb(seg) + seg = (seg * 255).astype(np.uint8) + seg = Image.fromarray(seg) + seg.save(path) + + +class Visualizer(pl.LightningModule): + + def __init__(self, config): + super().__init__() + self.train_set, self.val_set, self.test_set = get_dataset(config) + + self.slot_3D = Slot3D(config) + self.slot_dec = SlotMixerDecoder(config) + self.slot_dec_fine = SlotMixerDecoder(config) if config.coarse_to_fine else None + self.nerf = NeRF(config, n_samples=config.n_samples) + self.nerf_fine = NeRF(config, n_samples=config.n_samples+config.n_samples_fine) if config.coarse_to_fine else None + self.depth_range = min(self.train_set.depth_range[0], self.val_set.depth_range[0], self.test_set.depth_range[0]), \ + max(self.train_set.depth_range[1], self.val_set.depth_range[1], self.test_set.depth_range[1]) + self.renderer = NeRFRenderer(self.depth_range, cfg=config) + + self.cfg = config + self.output_dir_result_images = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/images') + self.output_dir_result_images.mkdir(exist_ok=True) + self.output_dir_result_seg = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/seg') + self.output_dir_result_seg.mkdir(exist_ok=True) + self.output_dir_result_depth = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/depth') + self.output_dir_result_depth.mkdir(exist_ok=True) + self.output_dir_result_attn = Path(f'{self.cfg.log_path}/{self.cfg.exp_name}/attn') + self.output_dir_result_attn.mkdir(exist_ok=True) + + self.scene_id = config.get('scene_id', 0) + + def forward(self, rays, depth_range, slots, view_feats, cam, src_rgbs, src_cams, is_train): + B, Nr, _ = rays.shape + outputs = [] + render_depth = True + render_ins = True + render_feat = False + render_color = True + for i in range(0, Nr, self.cfg.chunk): + outputs.append(self.renderer(self.nerf, self.nerf_fine, self.slot_dec, self.slot_dec_fine, + rays[:, i: i + 
self.cfg.chunk], depth_range, slots, + cam, src_rgbs, src_cams, view_feats, + False, is_train, + render_color=render_color, render_depth=render_depth, + render_sem=False, render_ins=render_ins, render_feat=render_feat)) + keys = outputs[0].keys() + out = {} + for k in keys: + if 'dist' in k or 'loss' in k: + out[k] = torch.stack([o[k] for o in outputs], 0).mean() + else: + out[k] = torch.cat([o[k] for o in outputs], 1).flatten(0, 1) + return out + + def visualize(self, dataloader): + for batch_idx, batch in tqdm(enumerate(dataloader)): + if batch_idx == self.scene_id or self.cfg.dataset == 'scannet' or self.cfg.dataset == 'dtu': + src_rgbs, src_cams = batch['src_rgbs'].to(self.device), batch['src_cams'].to(self.device) + B, Nv, H, W, _ = src_rgbs.shape + src_rgbs = src_rgbs.permute(0, 1, 4, 2, 3) # [B, Nv, 3, H, W] + slots, attn, view_feats = self.slot_3D(None, sigma=0, images=src_rgbs, src_cams=src_cams) + view_feats = view_feats.permute(0, 1, 4, 2, 3) # [B, Nv, D, H, W] + + images = src_rgbs[0].reshape(-1, 3, H, W).cpu() # [N_src_view, 3, H, W] + H1, W1 = H // 4, W // 4 + if self.cfg.dataset == 'dtu': + H1 = 76 + attn = attn.reshape(-1, src_rgbs.shape[1], H1, W1) + attn = F.interpolate(attn, size=(H, W), mode='nearest') + attn = attn.permute(1, 0, 2, 3).unsqueeze(2).repeat(1, 1, 3, 1, 1) + img = torch.cat([images.unsqueeze(1).cpu(), 1 - attn.cpu()], dim=1).reshape(-1, 3, H, W) + img = make_grid(img, nrow=img.shape[0], padding=0) # [3, (N_src_view+1)*H, W] + save_img_from_tensor(img.permute(1, 2, 0), self.output_dir_result_attn / f"{batch_idx:04d}_attn.png") + # input_img = images[0].permute(1, 2, 0) # [H, W, 3] + # save_img_from_tensor(input_img, self.output_dir_result_images / f"{batch_idx:04d}_input.png", True) + # instances = batch['instances'][0].view(H, W) # [H, W] + # save_seg_from_tensor(instances, self.output_dir_result_seg / f"{batch_idx:04d}_seg_gt.png") + if self.cfg.dataset == 'scannet': + gt_img = batch['rgbs'].view(H, W, 3) / 2 + 0.5 # [H, W, 3] + save_img_from_tensor(gt_img, self.output_dir_result_images / f"{batch_idx:04d}_rgb_gt.png", True) + gt_seg = batch['instances'].view(H, W) # [H, W] + save_seg_from_tensor(gt_seg, self.output_dir_result_seg / f"{batch_idx:04d}_seg_gt.png") + + depth_range = batch['depth_range'].to(self.device) # [B, 2] + N = self.cfg.num_vis + all_rays = batch['rays'][0].to(self.device) # [N, HW, 6] + for n in tqdm(range(N)): + rays = all_rays[n:n+1]# [1, HW, 6] + output = self(rays, depth_range, slots, view_feats, batch.get('azi_rot'), src_rgbs, src_cams, False) + + if self.cfg.coarse_to_fine: + output_rgb = output['rgb_f'] + output_instances = output['instance_f'] + output_depth = output['depth_f'] + else: + output_rgb = output['rgb_c'] + output_instances = output['instance_c'] + output_depth = output['depth_c'] + + if self.cfg.normalize: + output_rgb = output_rgb * 0.5 + 0.5 + src_rgbs = src_rgbs * 0.5 + 0.5 + + shape = (H, W, 3) + imgs_pred = output_rgb.view(shape) + seg_pred = output_instances.argmax(-1).view(H, W) + depth_pred = output_depth.view(H, W) + save_img_from_tensor(imgs_pred, self.output_dir_result_images / f"{batch_idx:04d}_{n:02d}_rgb_pred.png", True) + save_seg_from_tensor(seg_pred, self.output_dir_result_seg / f"{batch_idx:04d}_{n:02d}_seg_pred.png") + depth = visualize_depth(depth_pred, maxval=self.depth_range[1], use_global_norm=True) # [3, H, W] + save_img_from_tensor(depth.permute(1, 2, 0), self.output_dir_result_depth / f"{batch_idx:04d}_{n:02d}_depth_pred.png") + # break + + def train_dataloader(self): + return 
DataLoader(self.train_set, batch_size=1, shuffle=True, pin_memory=True, num_workers=self.cfg.num_workers) + + def val_dataloader(self): + return DataLoader(self.val_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers) + + def test_dataloader(self): + return DataLoader(self.test_set, batch_size=1, shuffle=False, pin_memory=True, num_workers=self.cfg.num_workers) + + +@hydra.main(config_path='../config', config_name='config', version_base='1.2') +def main(config): + config = config.cfg + result_path = Path(f'{config.log_path}/{config.exp_name}') + result_path.mkdir(exist_ok=True) + model = Visualizer(config) + ckpt = torch.load(config.ckpt_path) + model.load_state_dict(ckpt['state_dict']) + print(f'Load from checkpoint: {config.ckpt_path}') + model.cuda() + model.eval() + with torch.no_grad(): + model.visualize(model.val_dataloader()) + +if __name__ == '__main__': + main() diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/camera.py b/util/camera.py new file mode 100644 index 0000000..5e9850b --- /dev/null +++ b/util/camera.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import math +from scipy.spatial.transform import Rotation +import numpy as np +import torch +from einops import repeat + +from util.transforms import trs_comp, dot + + +def frustum_world_bounds(dims, intrinsics, cam2worlds, max_depth, form='bbox'): + """Compute bounds defined by the frustum provided cameras + Args: + dims (N,2): heights,widths of cameras + intrinsics (N,3,3): intrinsics (unnormalized, hence HW required) + cam2worlds (N,4,4): camera to world transformations + max_depth (float): depth of all frustums + form (str): bbox: convex bounding box, sphere: convex bounding sphere + """ + # unproject corner points + h_img_corners = torch.Tensor([[0, 1, 1], [1, 0, 1], [1, 1, 1]]) + intrinsics_inv = torch.linalg.inv(intrinsics[:, [1, 0, 2]]) # K in WH -> convert to HW + k = len(h_img_corners) + n = len(dims) + rep_HWds = repeat(torch.cat([dims, torch.ones((n, 1))], 1), "n c -> n k c", k=k) + skel_pts = rep_HWds * repeat(h_img_corners, "k c -> n k c", n=n) # (N,K,(hwd)) + corners_cam_a = torch.einsum("nkij,nkj->nki", repeat(intrinsics_inv, "n x y -> n k x y", k=k), skel_pts) * max_depth + corners_cam_b = torch.einsum("nkij,nkj->nki", repeat(intrinsics_inv, "n x y -> n k x y", k=k), skel_pts) * 0.01 + # nihalsid: adding corner with max depth, and corners with min depth + corners_cam = torch.cat([corners_cam_a, corners_cam_b], 0) + corners_cam_h = torch.cat([corners_cam, torch.ones(corners_cam.shape[0], corners_cam.shape[1], 1)], -1) + + corners_world_h = torch.einsum("nij,nkj->nki", cam2worlds.repeat(2, 1, 1), corners_cam_h) + corners_world_flat = corners_world_h.reshape(-1, 4)[:, :3] + if form == 'bbox': + bounds = torch.stack([corners_world_flat.min(0).values, corners_world_flat.max(0).values]) + return bounds + elif form == 'sphere': + corners_world_center = torch.mean(corners_world_flat, 0) + sphere_radius = torch.max(torch.norm((corners_world_flat - corners_world_center), dim=1)) + + # todo: remove visualization + ############################## + # from util.misc import visualize_points + # import trimesh + # visualize_points(corners_world_flat, "runs/world_flat.obj") + # sphere = trimesh.creation.icosphere(radius=sphere_radius.numpy()) + # sphere.apply_translation(corners_world_center) + # sphere.export("runs/world_bounds_sphere.obj") + ############################## + + return 
corners_world_center, sphere_radius + else: + raise Exception("Not implemented yet: Ellipsoid for example") + + +def compute_world2normscene(dims, intrinsics, cam2worlds, max_depth, rescale_factor=1.0): + """Compute transform converting world to a normalized space enclosing all + cameras frustums (given depth) into a unit sphere + Note: max_depth=0 -> camera positions only are contained (like NeRF++ does it) + + Args: + dims (N,2): heights,widths of cameras + intrinsics (N,3,3): intrinsics (unnormalized, hence HW required) + cam2worlds (N,4,4): camera to world transformations + max_depth (float): depth of all frustums + rescale_factor (float)>=1.0: factor to scale the world space even further so no camera is too close to the unit sphere surface + """ + assert rescale_factor >= 1.0, "prevent cameras outside of unit sphere" + + sphere_center, sphere_radius = frustum_world_bounds(dims, intrinsics, cam2worlds, max_depth, 'sphere') # sphere containing frustums + world2nscene = trs_comp(-sphere_center / (rescale_factor * sphere_radius), torch.eye(3), 1 / (rescale_factor * sphere_radius)) + + return world2nscene + + +def depth_to_distance(depth, intrinsics): + uv = np.stack(np.meshgrid( + np.arange(depth.shape[1]), + np.arange(depth.shape[0]) + ), -1).reshape(-1, 2) + depth = depth.reshape(-1) + uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1) + return depth * np.linalg.norm((np.linalg.inv(intrinsics) @ uvh.T).T, axis=1) + + +def distance_to_depth(K, dist, uv=None): + if uv is None and len(dist.shape) >= 2: + # create mesh grid according to d + uv = np.stack(np.meshgrid(np.arange(dist.shape[1]), np.arange(dist.shape[0])), -1) + uv = uv.reshape(-1, 2) + dist = dist.reshape(-1) + if not isinstance(dist, np.ndarray): + uv = torch.from_numpy(uv).to(dist) + if isinstance(dist, np.ndarray): + # z * np.sqrt(x_temp**2+y_temp**2+z_temp**2) = dist + uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1) + temp_point = dot(np.linalg.inv(K), uvh) + z = dist / np.linalg.norm(temp_point, axis=1) + + else: + uvh = torch.cat([uv, torch.ones(len(uv), 1).to(uv)], -1) + temp_point = dot(torch.inverse(K), uvh) + z = dist / torch.linalg.norm(temp_point, dim=1) + return z + + +def unproject_2d_3d(cam2world, intrinsics, depth, dims): + uv = np.stack(np.meshgrid(np.arange(dims[0]), np.arange(dims[1])), -1) + uv = torch.from_numpy(uv.reshape(-1, 2)) + uvh = np.concatenate([uv, np.ones((len(uv), 1))], -1) + cam_point = (torch.linalg.inv(intrinsics) @ torch.from_numpy(uvh).float().T).T * depth[:, None] + world_point = (cam2world[:3, :3] @ cam_point.T).T + cam2world[:3, 3] + return world_point + + +def project_3d_2d(cam2world, K, world_point, with_dist=False, discrete=True, round=True): + if isinstance(world_point, np.ndarray): + cam_point = dot(np.linalg.inv(cam2world), world_point) + point_dist = np.sqrt((cam_point ** 2).sum(-1)) + img_point = dot(K, cam_point) + uv_point = img_point[:, :2] / img_point[:, 2][:, None] + if discrete: + if round: + uv_point = np.round(uv_point) + uv_point = uv_point.astype(np.int) + if with_dist: + return uv_point, img_point[:, 2], point_dist + return uv_point + else: + cam_point = dot(torch.inverse(cam2world), world_point) + point_dist = (cam_point ** 2).sum(-1).sqrt() + img_point = dot(K, cam_point) + uv_point = img_point[:, :2] / img_point[:, 2][:, None] + if discrete: + if round: + uv_point = torch.round(uv_point) + uv_point = uv_point.int() + if with_dist: + return uv_point, img_point[:, 2], point_dist + + return uv_point + + +def auto_orient_poses(poses, method="up"): + """Orients 
and centers the poses. We provide two methods for orientation: pca and up. + pca: Orient the poses so that the principal component of the points is aligned with the axes. + This method works well when all of the cameras are in the same plane. + up: Orient the poses so that the average up vector is aligned with the z axis. + This method works well when images are not at arbitrary angles. + Args: + poses: The poses to orient. + method: The method to use for orientation. Either "pca" or "up". + Returns: + The oriented poses. + borrowed from from nerfstudio + """ + translation = poses[..., :3, 3] + + mean_translation = torch.mean(translation, dim=0) + translation = translation - mean_translation + + if method == "pca": + _, eigvec = torch.linalg.eigh(translation.T @ translation) + eigvec = torch.flip(eigvec, dims=(-1,)) + + if torch.linalg.det(eigvec) < 0: + eigvec[:, 2] = -eigvec[:, 2] + + transform = torch.cat([eigvec, eigvec @ -mean_translation[..., None]], dim=-1) + oriented_poses = transform @ poses + + if oriented_poses.mean(axis=0)[2, 1] < 0: + oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3] + elif method == "up": + up = torch.mean(poses[:, :3, 1], dim=0) + up = up / torch.linalg.norm(up) + + rotation = rotation_matrix(up, torch.Tensor([0, 0, 1])) + transform = torch.cat([rotation, rotation @ -mean_translation[..., None]], dim=-1) + oriented_poses = transform @ poses + + return oriented_poses + + +def rotation_matrix(a, b): + """Compute the rotation matrix that rotates vector a to vector b. + Args: + a: The vector to rotate. + b: The vector to rotate to. + Returns: + The rotation matrix. + borrowed from from nerfstudio + """ + a = a / torch.linalg.norm(a) + b = b / torch.linalg.norm(b) + v = torch.cross(a, b) + c = torch.dot(a, b) + # If vectors are exactly opposite, we add a little noise to one of them + if c < -1 + 1e-8: + eps = (torch.rand(3) - 0.5) * 0.01 + return rotation_matrix(a + eps, b) + s = torch.linalg.norm(v) + skew_sym_mat = torch.Tensor( + [ + [0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0], + ] + ) + return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s ** 2 + 1e-8)) + + +def _compute_residual_and_jacobian(x, y, xd, yd, k1=0.0, k2=0.0, k3=0.0, k4=0.0, p1=0.0, p2=0.0): + """Auxiliary function of radial_and_tangential_undistort().""" + # Adapted from https://github.com/google/nerfies/blob/main/nerfies/camera.py + # let r(x, y) = x^2 + y^2; + # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 + + # k4 * r(x, y)^4; + r = x * x + y * y + d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4))) + + # The perfect projection is: + # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2); + # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2); + # + # Let's define + # + # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd; + # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd; + # + # We are looking for a solution that satisfies + # fx(x, y) = fy(x, y) = 0; + fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd + fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd + + # Compute derivative of d over [x, y] + d_r = (k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4))) + d_x = 2.0 * x * d_r + d_y = 2.0 * y * d_r + + # Compute derivative of fx over x and y. + fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x + fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y + + # Compute derivative of fy over x and y. 
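+ # (fx, fy and the four partials fx_x, fx_y, fy_x, fy_y give the residual and 2x2
+ # Jacobian that radial_and_tangential_undistort() uses to solve fx = fy = 0 for the
+ # undistorted (x, y).)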
+ fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x + fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y + + return fx, fy, fx_x, fx_y, fy_x, fy_y + + +def rotate(p, angle, o=[0, 0]): + ''' + Rotate point p around point o with angle + p: (x, y) + o: (x, y) + angle: degree + H: height of image + ''' + x = o[0] + math.cos(angle) * (p[0] - o[0]) - math.sin(angle) * (o[1] - p[1]) + y = o[1] - math.sin(angle) * (p[0] - o[0]) - math.cos(angle) * (o[1] - p[1]) + return x, y + + +def rotate_cam(cam, theta, zoom_ratio=1., origin=[0, 0]): + r''' + Rotate camera around origin with theta + cam: (4, 4) + origin: (x, y) + theta: degree, 0~2pi + ''' + new_cam = cam.clone().numpy() + R = new_cam[:3, :3] # rotation matrix + + rot_z = Rotation.from_rotvec(theta * np.array([0, 0, 1])).as_matrix() + R = rot_z @ R + new_cam[:3, :3] = R + + p = new_cam[:2, 3] # translation x, y + p = rotate(p, -theta, origin) + new_cam[:2, 3] = p + new_cam[:3, 3] = new_cam[:3, 3] * zoom_ratio + return torch.from_numpy(new_cam) + + +def rot_matrix_angular_dist(R1, R2): + assert R1.shape[-1] == 3 and R2.shape[-1] == 3 and R1.shape[-2] == 3 and R2.shape[-2] == 3 + return np.arccos(np.clip((np.trace(np.matmul(R2.transpose(0, 2, 1), R1), axis1=1, axis2=2) - 1) / 2., + a_min=-1 + 1e-6, a_max=1 - 1e-6)) + +def select_src_ids(tar_pose, ref_poses, num_select, dist_weight=0.5, drop_first=True): + num_cams = len(ref_poses) + num_select = min(num_select, num_cams - 1) + batched_tar_pose = tar_pose[None, ...].repeat(num_cams, 0) + + angular_dists = rot_matrix_angular_dist(batched_tar_pose[:, :3, :3], ref_poses[:, :3, :3]) + dists = np.linalg.norm(batched_tar_pose[:, :3, 3] - ref_poses[:, :3, 3], axis=1) + dists = dist_weight * dists + (1 - dist_weight) * angular_dists + + sorted_ids = np.argsort(dists) + if drop_first: + selected_ids = sorted_ids[1:num_select+1] + else: + selected_ids = sorted_ids[:num_select] + return selected_ids \ No newline at end of file diff --git a/util/distinct_colors.py b/util/distinct_colors.py new file mode 100644 index 0000000..136ffcc --- /dev/null +++ b/util/distinct_colors.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import torch +import numpy as np + + +class DistinctColors: + + def __init__(self): + colors = [ + '#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f55031', '#911eb4', '#42d4f4', '#bfef45', '#fabed4', '#469990', + '#dcb1ff', '#404E55', '#fffac8', '#809900', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#a9a9a9', '#f032e6', + '#806020', '#ffffff', + + "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0030ED", "#3A2465", "#34362D", "#B4A8BD", "#0086AA", + "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81", "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700", + + "#04F757", "#C8A1A1", "#1E6E00", + "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700", + "#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329", + "#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C", + ] + self.hex_colors = colors + # 0 = crimson / red, 1 = green, 2 = yellow, 3 = blue + # 4 = orange, 5 = purple, 6 = sky blue, 7 = lime green + self.colors = [hex_to_rgb(c) for c in colors] + self.color_assignments = {} + self.color_ctr = 0 + self.fast_color_index = torch.from_numpy(np.array([hex_to_rgb(colors[i % len(colors)]) for i in range(8096)] + [hex_to_rgb('#000000')])) + + def get_color(self, index, override_color_0=False): + colors = [x for x in self.hex_colors] + if override_color_0: + colors[0] = "#3f3f3f" + colors = [hex_to_rgb(c) for c in colors] + if index not in self.color_assignments: + self.color_assignments[index] = colors[self.color_ctr % len(self.colors)] + self.color_ctr += 1 + return self.color_assignments[index] + + def get_color_fast_torch(self, index): + return self.fast_color_index[index] + + def get_color_fast_numpy(self, index, override_color_0=False): + index = np.array(index).astype(np.int32) + if override_color_0: + colors = [x for x in self.hex_colors] + colors[0] = "#3f3f3f" + fast_color_index = torch.from_numpy(np.array([hex_to_rgb(colors[i % len(colors)]) for i in range(8096)] + [hex_to_rgb('#000000')])) + return fast_color_index[index % fast_color_index.shape[0]].numpy() + else: + return self.fast_color_index[index % self.fast_color_index.shape[0]].numpy() + + def apply_colors(self, arr): + out_arr = torch.zeros([arr.shape[0], 3]) + + for i in range(arr.shape[0]): + out_arr[i, :] = torch.tensor(self.get_color(arr[i].item())) + return out_arr + + def apply_colors_fast_torch(self, arr): + return self.fast_color_index[arr % self.fast_color_index.shape[0]] + + def apply_colors_fast_numpy(self, arr): + return self.fast_color_index.numpy()[arr % self.fast_color_index.shape[0]] + + +def hex_to_rgb(x): + return [int(x[i:i + 2], 16) / 255 for i in (1, 3, 5)] + + +def visualize_distinct_colors(num_vis=32): + from PIL import Image + dc = DistinctColors() + labels = np.ones((1, 64, 64)).astype(np.int) + all_labels = [] + for i in range(num_vis): + all_labels.append(labels * i) + all_labels = np.concatenate(all_labels, 0) + shape = all_labels.shape + labels_colored = dc.get_color_fast_numpy(all_labels.reshape(-1)) + labels_colored = (labels_colored.reshape(shape[0] * shape[1], shape[2], 3) * 255).astype(np.uint8) + Image.fromarray(labels_colored).save("colormap.png") + + +if __name__ == "__main__": + visualize_distinct_colors() diff --git a/util/filesystem_logger.py b/util/filesystem_logger.py new file mode 100644 index 0000000..adda84f --- /dev/null +++ b/util/filesystem_logger.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import argparse +import shutil +from pathlib import Path +from typing import Dict, Optional, Union + +from omegaconf import OmegaConf +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.loggers.logger import DummyExperiment +from pytorch_lightning.loggers.logger import rank_zero_experiment + + +class FilesystemLogger(Logger): + + @property + def version(self) -> Union[int, str]: + return 0 + + @property + def name(self) -> str: + return "fslogger" + + # noinspection PyMethodOverriding + def log_hyperparams(self, params: argparse.Namespace): + pass + + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + pass + + def __init__(self, experiment_config, **_kwargs): + super().__init__() + self.experiment_config = experiment_config + self._experiment = None + # noinspection PyStatementEffect + self.experiment + + @property + @rank_zero_experiment + def experiment(self): + if self._experiment is None: + self._experiment = DummyExperiment() + experiment_dir = Path(self.experiment_config["log_path"], self.experiment_config["exp_name"]) + experiment_dir.mkdir(exist_ok=True, parents=True) + + src_folders = ['config', 'data/splits', 'model', 'tests', 'trainer', 'util', 'data_processing', 'dataset'] + sources = [] + for src in src_folders: + sources.extend(list(Path(".").glob(f'{src}/**/*'))) + + files_to_copy = [x for x in sources if x.suffix in [".py", ".pyx", ".txt", ".so", ".pyd", ".h", ".cu", ".c", '.cpp', ".html"] and x.parts[0] != "runs" and x.parts[0] != "wandb"] + + for f in files_to_copy: + Path(experiment_dir, "code", f).parents[0].mkdir(parents=True, exist_ok=True) + shutil.copyfile(f, Path(experiment_dir, "code", f)) + + Path(experiment_dir, "config.yaml").write_text(OmegaConf.to_yaml(self.experiment_config)) + + return self._experiment diff --git a/util/metrics.py b/util/metrics.py new file mode 100644 index 0000000..5cb807a --- /dev/null +++ b/util/metrics.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import torch +import numpy as np +from sklearn.metrics import adjusted_rand_score +from scipy.optimize import linear_sum_assignment +import torch.nn.functional as F +import torch.nn as nn +from functools import partial +from piq import ssim +from piq import psnr +import lpips + + +def average_ari(masks, masks_gt, fg_only=False, reduction='mean'): + r''' + Input: + masks: (B, K, N) + masks_gt: (B, N) + ''' + ari = [] + masks = masks.argmax(dim=1) + B = masks.shape[0] + for i in range(B): + m = masks[i].cpu().numpy() + m_gt = masks_gt[i].cpu().numpy() + if fg_only: + m = m[np.where(m_gt > 0)] + m_gt = m_gt[np.where(m_gt > 0)] + score = adjusted_rand_score(m, m_gt) + ari.append(score) + if reduction == 'mean': + return torch.Tensor(ari).mean() + else: + return torch.Tensor(ari) + + +def imask2bmask(imasks, ignore_index=None): + r"""Convert index mask to binary mask. + Args: + imask: index mask, shape (B, N) + Returns: + bmasks: # a list of (K, N), len = B + """ + B, N = imasks.shape + bmasks = [] + for i in range(B): + imask = imasks[i:i+1] # (1, N) + classes = imask.unique().tolist() + if ignore_index in classes: + classes.remove(ignore_index) + bmask = [imask == c for c in classes] + bmask = torch.cat(bmask, dim=0) # (K, N) + bmasks.append(bmask.float()) + # can't use torch.stack because of different K + return bmasks + + +def mean_best_overlap(masks, masks_gt, fg_only=False, reduction='mean'): + r"""Compute the best overlap between predicted and ground truth masks. 
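+    For each ground-truth segment, the highest IoU over all predicted masks is
+    taken; these best IoUs are averaged to give the mean best overlap (mBO).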
+ Args: + masks: predicted masks, shape (B, K, N), binary N = H*W + masks_gt: ground truth masks, shape (B, N), index + """ + B = masks.shape[0] + ignore_index = None + if fg_only: + ignore_index = 0 + bmasks_gt = imask2bmask(masks_gt, ignore_index=ignore_index) # a list of (K, N), len = B + mean_best_overlap = [] + mOR = [] + for i in range(B): + mask = masks[i].unsqueeze(0) > 0.5 # (1, K, N) + mask_gt = bmasks_gt[i].unsqueeze(1) > 0.5 # (K_gt, 1, N) + # Compute IOU + eps = 1e-8 + intersection = (mask * mask_gt).sum(-1) + union = (mask + mask_gt).sum(-1) + iou = intersection / (union + eps) # (K_gt, K) + # Compute best overlap + best_overlap, _ = torch.max(iou, dim=1) + # Compute mean best overlap + mean_best_overlap.append(best_overlap.mean()) + mOR.append((best_overlap > 0.5).float().mean()) + if reduction == 'mean': + return torch.stack(mean_best_overlap).mean() + # , torch.stack(mOR).mean() + else: + return torch.stack(mean_best_overlap) + # , torch.stack(mOR) + + +def iou_loss(pred, target): + """ + Compute the iou loss: 1 - iou + pred: [K, N] + targets: [Kt, N] + """ + eps = 1e-8 + pred = pred > 0.5 # [K, N] + target = target > 0.5 + intersection = (pred[:, None] & target[None]).sum(-1).float() # [K, Kt] + union = (pred[:, None] | target[None]).sum(-1).float() + eps # [K, Kt] + loss = 1 - (intersection / union) # [K, Kt] + return loss # [K, Kt] +iou_loss_jit = torch.jit.script(iou_loss) + + +class Matcher(): + @torch.no_grad() + def forward(self, pred, target): + r""" + pred: [K, N] + targets: [Kt, N] + """ + loss = iou_loss_jit(pred, target) + row_ind, col_ind = linear_sum_assignment(loss.cpu().numpy()) + return torch.as_tensor(row_ind, dtype=torch.int64), torch.as_tensor(col_ind, dtype=torch.int64) + + @torch.no_grad() + def batch_forward(self, pred, targets): + """ + pred: [B, K, N] + targets: list of B x [Kt, N] Kt can be different for each target + """ + indices = [] + for i in range(pred.shape[0]): + indices.append(self.forward(pred[i], targets[i])) + return indices + + +@torch.no_grad() +def compute_iou(pred, target): + """ + Input: + x: [K, N] + y: [K, N] + Return: + iou: [K, N] + """ + eps = 1e-8 + pred = pred > 0.5 # [K, N] + target = target > 0.5 + intersection = (pred & target).sum(-1).float() + union = (pred | target).sum(-1).float() + eps # [K] + return (intersection / union).mean() +compute_iou_jit = torch.jit.script(compute_iou) + + +def matchedIoU(preds, targets, matcher, fg_only=False, reduction="mean"): + r""" + Input: + pred: [B, K, N] + targets: [B, N] + Return: + IoU: [1] or [B] + """ + if preds.dim() == 2: # [K, N] + preds = preds.unsqueeze(0) + targets = targets.unsqueeze(0) + + ious = [] + B = preds.shape[0] + ignore_index = None + if fg_only: + ignore_index = 0 + targets = imask2bmask(targets, ignore_index) # a list of [K1, N], len = B + for i in range(B): + tgt = targets[i] + pred = preds[i] # [K, N] + src_idx, tgt_idx = matcher.forward(pred, tgt) + src_pred = pred[src_idx] # [K1, N] + tgt_mask = tgt[tgt_idx] # [K1, N] + ious.append(compute_iou_jit(src_pred, tgt_mask)) + ious = torch.stack(ious) + if reduction == "mean": + return ious.mean() + else: + return ious + + +matcher = Matcher() +SEGMETRICS = { + "hiou": partial(matchedIoU, matcher=matcher), # hungarian matched iou + "hiou_fg": partial(matchedIoU, fg_only=True, matcher=matcher), + "mbo": mean_best_overlap, # mean best overlap + "mbo_fg": partial(mean_best_overlap, fg_only=True), + "ari": average_ari, + "ari_fg": partial(average_ari, fg_only=True), +} +class SegMetrics(nn.Module): + def 
__init__(self, metrics=["hiou", "ari", "ari_fg"]): + super().__init__() + self.metrics = {} + for m in metrics: + self.metrics[m] = SEGMETRICS[m] + + def forward(self, preds, targets): + r""" + Input: + preds: [B, N, K] + targets: [B, N] + Return: + metrics: dict of metrics + """ + metrics = {} + valid = targets.sum(-1) > 0 + preds = preds[valid] + targets = targets[valid] + preds = F.one_hot(preds.argmax(dim=-1), num_classes=preds.shape[-1]).permute(0, 2, 1).float() # [B, K, N] + for k, v in self.metrics.items(): + if valid.sum() > 0: + metrics[k] = v(preds, targets) + else: + metrics[k] = torch.tensor(1).to(preds.device) + return metrics + + def compute(self, preds, targets, metric='hiou'): + return self.metrics[metric](preds, targets) + + def metrics_name(self): + return list(self.metrics.keys()) + + +class ReconMetrics(nn.Module): + def __init__(self, lpips_net='vgg'): + super().__init__() + self.metrics = { + "ssim": ssim, + "psnr": psnr, + "lpips": lpips.LPIPS(net=lpips_net), + } + + def forward(self, preds, targets): + r""" + Input: + preds: [B, C, H, W] + targets: [B, C, H, W] + Return: + metrics: dict of metrics + """ + metrics = {} + for k, v in self.metrics.items(): + metrics[k] = v(preds, targets).mean() + return metrics + + def compute(self, preds, targets, metric='psnr'): + return self.metrics[metric](preds, targets) + + def metrics_name(self): + return list(self.metrics.keys()) + + def set_divice(self, device): + self.metrics["lpips"].to(device) \ No newline at end of file diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000..6ad64b4 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,364 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from collections import OrderedDict +import os +import sys +root_path = os.path.abspath(__file__) +root_path = '/'.join(root_path.split('/')[:-2]) +sys.path.append(root_path) +import torch +import math +import numpy as np +from pathlib import Path +from matplotlib import cm +from PIL import Image +import torchvision.transforms as T + + +def visualize_depth(depth, minval=0.001, maxval=1.5, use_global_norm=True): + x = depth + if isinstance(depth, torch.Tensor): + x = depth.cpu().numpy() + x = np.nan_to_num(x) # change nan to 0 + if use_global_norm: + mi = minval + ma = maxval + else: + mi = np.min(x) # get minimum depth + ma = np.max(x) + x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 + x_ = Image.fromarray((cm.get_cmap('jet')(x) * 255).astype(np.uint8)) + x_ = T.ToTensor()(x_)[:3, :, :] + return x_ + + +def bounds(x): + lower = [] + upper = [] + for i in range(x.shape[1]): + lower.append(x[:, i].min()) + upper.append(x[:, i].max()) + return torch.tensor([lower, upper]) + + +def visualize_points(points, vis_path, colors=None): + if colors is None: + Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} 127 127 127" for p in points)) + else: + Path(vis_path).write_text("\n".join(f"v {p[0]} {p[1]} {p[2]} {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}" for i, p in enumerate(points))) + + +def visualize_points_as_pts(points, vis_path, colors=None): + if colors is None: + Path(vis_path).write_text("\n".join([f'{points.shape[0]}'] + [f"{p[0]} {p[1]} {p[2]} 255 127 127 127" for p in points])) + else: + Path(vis_path).write_text("\n".join([f'{points.shape[0]}'] + [f"{p[0]} {p[1]} {p[2]} 255 {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}" for i, p in enumerate(points)])) + + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not 
isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + + def pre_hook(_mod, _inputs): + nesting[0] += 1 + + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen if t.requires_grad] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. 
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name, value): + self[name] = value + + def __delattr__(self, name): + del self[name] + + +def to_point_list(s): + return np.concatenate([c[:, np.newaxis] for c in np.where(s)], axis=1) + + +def get_parameters_from_state_dict(state_dict, filter_key): + new_state_dict = OrderedDict() + for k in state_dict: + if k.startswith(filter_key): + new_state_dict[k.replace(filter_key + '.', '')] = state_dict[k] + return new_state_dict + + +def logistic(n, zero_at): + return 1 - 1 / (1 + math.exp(-10 * (n / zero_at - 0.5))) + + +def visualize_voxel_grid(output_path, voxel_grid, scale_to=(-1, 1)): + voxel_grid = ((voxel_grid - voxel_grid.min()) / (voxel_grid.max() - voxel_grid.min())).cpu() + rescale = lambda axis: scale_to[0] + (points[axis] / voxel_grid.shape[axis]) * (scale_to[1] - scale_to[0]) + points = list(torch.where(voxel_grid > 0)) + if len(points[0] > 0): + colors = cm.get_cmap('jet')(voxel_grid.numpy()) + colors = colors[points[0].numpy(), points[1].numpy(), points[2].numpy(), :] + points[0] = rescale(0) + points[1] = rescale(1) + points[2] = rescale(2) + Path(output_path).write_text("\n".join([f'v {points[0][i]} {points[1][i]} {points[2][i]} {colors[i, 0]} {colors[i, 1]} {colors[i, 2]}' for i in range(points[0].shape[0])])) + else: + Path(output_path).write_text("") + print("no points found..") + + +def visualize_labeled_points(locations, labels, output_path): + from util.distinct_colors import DistinctColors + distinct_colors = DistinctColors() + if isinstance(labels, torch.Tensor): + colored_arr = distinct_colors.get_color_fast_torch(labels.flatten().cpu().numpy().tolist()).reshape(list(labels.shape) + [3]).numpy() + else: + colored_arr = distinct_colors.get_color_fast_numpy(labels.flatten().tolist()).reshape(list(labels.shape) + [3]) + visualize_points(locations, output_path, colored_arr) + + +def visualize_weighted_points(output_path, xyz, weights, threshold=1e-4): + weights = weights.view(-1) + weights_mask = weights > threshold + colors = cm.get_cmap('jet')(weights[weights_mask].numpy()) + visualize_points(xyz[weights_mask, :].numpy(), output_path, colors=colors) + + +def visualize_mask(arr, path): + from util.distinct_colors import DistinctColors + distinct_colors = DistinctColors() + assert len(arr.shape) == 2, "should be an HxW array" + boundaries = get_boundary_mask(arr) + if isinstance(arr, torch.Tensor): + colored_arr = distinct_colors.get_color_fast_torch(arr.flatten().cpu().numpy().tolist()).reshape(list(arr.shape) + [3]).numpy() + else: + colored_arr = distinct_colors.get_color_fast_numpy(arr.flatten().tolist()).reshape(list(arr.shape) + [3]) + colored_arr = (colored_arr * 255).astype(np.uint8) + colored_arr[boundaries > 0, :] = 0 + Image.fromarray(colored_arr).save(path) + + +def probability_to_normalized_entropy(probabilities): + entropy = torch.zeros_like(probabilities[:, 0]) + for i in range(probabilities.shape[1]): + entropy = entropy - probabilities[:, i] * torch.log2(probabilities[:, i] + 1e-8) + entropy = entropy / math.log2(probabilities.shape[1]) + return entropy + + +def get_boundary_mask(arr, 
dialation_size=1): + import cv2 + arr_t, arr_r, arr_b, arr_l = arr[1:, :], arr[:, 1:], arr[:-1, :], arr[:, :-1] + arr_t_1, arr_r_1, arr_b_1, arr_l_1 = arr[2:, :], arr[:, 2:], arr[:-2, :], arr[:, :-2] + kernel = np.ones((dialation_size, dialation_size), 'uint8') + if isinstance(arr, torch.Tensor): + arr_t = torch.cat([arr_t, arr[-1, :].unsqueeze(0)], dim=0) + arr_r = torch.cat([arr_r, arr[:, -1].unsqueeze(1)], dim=1) + arr_b = torch.cat([arr[0, :].unsqueeze(0), arr_b], dim=0) + arr_l = torch.cat([arr[:, 0].unsqueeze(1), arr_l], dim=1) + + arr_t_1 = torch.cat([arr_t_1, arr[-2, :].unsqueeze(0), arr[-1, :].unsqueeze(0)], dim=0) + arr_r_1 = torch.cat([arr_r_1, arr[:, -2].unsqueeze(1), arr[:, -1].unsqueeze(1)], dim=1) + arr_b_1 = torch.cat([arr[0, :].unsqueeze(0), arr[1, :].unsqueeze(0), arr_b_1], dim=0) + arr_l_1 = torch.cat([arr[:, 0].unsqueeze(1), arr[:, 1].unsqueeze(1), arr_l_1], dim=1) + + boundaries = torch.logical_or(torch.logical_or(torch.logical_or(torch.logical_and(arr_t != arr, arr_t_1 != arr), torch.logical_and(arr_r != arr, arr_r_1 != arr)), torch.logical_and(arr_b != arr, arr_b_1 != arr)), torch.logical_and(arr_l != arr, arr_l_1 != arr)) + + boundaries = boundaries.cpu().numpy().astype(np.uint8) + boundaries = cv2.dilate(boundaries, kernel, iterations=1) + boundaries = torch.from_numpy(boundaries).to(arr.device) + else: + arr_t = np.concatenate([arr_t, arr[-1, :][np.newaxis, :]], axis=0) + arr_r = np.concatenate([arr_r, arr[:, -1][:, np.newaxis]], axis=1) + arr_b = np.concatenate([arr[0, :][np.newaxis, :], arr_b], axis=0) + arr_l = np.concatenate([arr[:, 0][:, np.newaxis], arr_l], axis=1) + + arr_t_1 = np.concatenate([arr_t_1, arr[-2, :][np.newaxis, :], arr[-1, :][np.newaxis, :]], axis=0) + arr_r_1 = np.concatenate([arr_r_1, arr[:, -2][:, np.newaxis], arr[:, -1][:, np.newaxis]], axis=1) + arr_b_1 = np.concatenate([arr[0, :][np.newaxis, :], arr[1, :][np.newaxis, :], arr_b_1], axis=0) + arr_l_1 = np.concatenate([arr[:, 0][:, np.newaxis], arr[:, 1][:, np.newaxis], arr_l_1], axis=1) + + boundaries = np.logical_or(np.logical_or(np.logical_or(np.logical_and(arr_t != arr, arr_t_1 != arr), np.logical_and(arr_r != arr, arr_r_1 != arr)), np.logical_and(arr_b != arr, arr_b_1 != arr)), np.logical_and(arr_l != arr, arr_l_1 != arr)).astype(np.uint8) + boundaries = cv2.dilate(boundaries, kernel, iterations=1) + + return boundaries + + +def pixelid2patchid(pixelid, H, W, H1, W1): + """ + Input: + pixelid: from 0 to H*W-1 + Return: + patchid: from 0 to H1*W1-1 + """ + x, y = pixelid % W, pixelid // W + x1, y1 = x // (W // W1), y // (H // H1) + patchid = y1 * W1 + x1 + return patchid + + +def patchify(x, p): + r""" + Input: + x: (H, W, D) + p: patch size + Return: + x: (hw, p*p, D) + """ + H, W, D = x.shape + h, w = H // p, W // p + x = x.reshape(h, p, w, p, D) + x = torch.einsum('hpwqc->hwpqc', x) + x = x.reshape(h*w, p*p, D) + return x + + +def batch_patchify(x, p): + r""" + Input: + x: (B, H, W, D) + p: patch size + Return: + x: (B, hw, p*p, D) + """ + B, H, W, D = x.shape + h, w = H // p, W // p + x = x.reshape(B, h, p, w, p, D) + x = torch.einsum('bhpwqc->bhwpqc', x) + x = x.reshape(B, h*w, p*p, D) + return x + + +class SubSampler(): + def idx_subsample(self, img_size, Br, mode='random'): + r''' + Input: + img_size: H,W + Br: number of rays to sample + Return: + subsample_idx: [Br, 1] + ''' + if mode == 'uniform': + b = int(Br ** 0.5) + assert b * b == Br, "Br must be a square number" + H, W = img_size + Nr = H * W + s_x, s_y = W // b, H // b + x = torch.randint(0, s_x, size=(1,)).item() + y = 
torch.randint(0, s_y, size=(1,)).item() + subsample_idx = torch.arange(Nr).reshape(H, W)[y:, x:][::s_y, ::s_x] + subsample_idx = subsample_idx.reshape(-1, 1) + Br1 = subsample_idx.shape[0] + if Br1 < Br: # b can't be devided by H or W + subsample_idx = torch.cat([subsample_idx, + torch.randint(0, Nr, [Br-Br1, 1])], dim=0) + else: + subsample_idx = subsample_idx[torch.randperm(Br1)[:Br]] + return subsample_idx + else: + H, W = img_size + Nr = H * W + subsample_idx = torch.randperm(Nr)[:Br].reshape(-1, 1) + return subsample_idx + + def idx_subsample_patch(self, img_size, patch_size, s=1): + r''' + Input: + img_size: H,W + patch_size: P + s: stride for subsampling + Return: + subsample_idx: [Br, 1] + ''' + H, W = img_size + Nr = H * W + P = patch_size + x = torch.randint(0, W//P, size=(1,)).item() * P + torch.randint(0, s, size=(1,)).item() + y = torch.randint(0, H//P, size=(1,)).item() * P + torch.randint(0, s, size=(1,)).item() + subsample_idx = torch.arange(Nr).view(H, W)[y:y+P, x:x+P][::s, ::s] + subsample_idx = subsample_idx.reshape(P//s*P//s, 1) + return subsample_idx + + def subsample(self, idx, x_tuple): + r''' + Input: + idx: [Br, 1] + x_tuple: a tuple of tensors to subsample, [HW, C] + Return: + x_tuple: a tuple of tensors [Br, C] + ''' + ret = [] + for x in x_tuple: + ret.append(x.gather(0, idx.expand(-1, x.shape[1]))) + return tuple(ret) + diff --git a/util/optimizer.py b/util/optimizer.py new file mode 100644 index 0000000..cd6064c --- /dev/null +++ b/util/optimizer.py @@ -0,0 +1,84 @@ +# Copyright 2023 Google Research. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""PyTorch implementation of the Lion optimizer.""" +import torch +from torch.optim.optimizer import Optimizer + + +class Lion(Optimizer): + r"""Implements Lion algorithm.""" + + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): + """Initialize the hyperparameters. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-4) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float, optional): weight decay coefficient (default: 0) + """ + + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. 
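+
+        Note: the update below first applies decoupled weight decay
+        (p <- p * (1 - lr * weight_decay)), then takes a signed step in the
+        direction of an interpolation between the momentum and the current
+        gradient, sign(beta1 * m + (1 - beta1) * g), and finally updates the
+        momentum as an exponential moving average with beta2.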
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + beta1, beta2 = group['betas'] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + p.add_(torch.sign(update), alpha=-group['lr']) + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss \ No newline at end of file diff --git a/util/ray.py b/util/ray.py new file mode 100644 index 0000000..3000388 --- /dev/null +++ b/util/ray.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import torch + + +def create_grid(H, W, render_stride=1): + xs = torch.linspace(0, W - 1, W)[::render_stride] + ys = torch.linspace(0, H - 1, H)[::render_stride] + i, j = torch.meshgrid(xs, ys, indexing='ij') + i, j = i + render_stride // 2, j + render_stride // 2 + return i.t(), j.t() + + +def get_ray_directions_with_intrinsics(H, W, intrinsics, render_stride=1): + B = intrinsics.shape[0] + h, w = H // render_stride, W // render_stride + i, j = create_grid(H, W, render_stride=render_stride) + i, j = i.to(intrinsics.device), j.to(intrinsics.device) + i, j = i[None, ...], j[None, ...] + fx, fy, cx, cy = intrinsics[:, 0:1, 0:1], intrinsics[:, 1:2, 1:2], intrinsics[:, 0:1, 2:3], intrinsics[:, 1:2, 2:3] + directions = torch.stack([ + (i - cx) / fx, (j - cy) / fy, torch.ones([B, h, w], device=intrinsics.device) + ], -1) + return directions + + +def get_rays(cameras, H, W, render_stride=1): + # cameras: (B, 27) 2 + 9 + 16 HW + intrinsics + cam2world + h, w = H // render_stride, W // render_stride + rays_o, rays_d = get_rays_origin_and_direction(cameras, H, W, render_stride) # (B, 1, 3), (B, H*W, 3) + rays = torch.cat([ + rays_o.expand(-1, h*w, -1), rays_d + ], -1) # (B, H*W, 6) + return rays + + +def get_rays_origin_and_direction(cameras, H, W, render_stride=1): + B = cameras.shape[0] + h, w = H // render_stride, W // render_stride + intrinsics, cam2world = cameras[:, 2:18].reshape(-1, 4, 4)[:, :3, :3], cameras[:, -16:].reshape(-1, 4, 4) + directions = get_ray_directions_with_intrinsics(H, W, intrinsics, render_stride) + # directions: (B, H, W, 3), cam2world: (B, 4, 4) + rays_d = torch.matmul(directions.view(B, h*w, 3), cam2world[:, :3, :3].transpose(1, 2)) # (B, H*W, 3) + rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) + rays_o = cam2world[:, :3, 3] + + rays_d = rays_d.view(B, h*w, 3) + rays_o = rays_o.view(B, 1, 3) + + return rays_o, rays_d + + +def rays_intersect_sphere(rays_o, rays_d, r=1): + """ + Solve for t such that a=ro+trd with ||a||=r + Quad -> r^2 = ||ro||^2 + 2t (ro.rd) + t^2||rd||^2 + -> t = (-b +- sqrt(b^2 - 4ac))/(2a) with + a = ||rd||^2 + b = 2(ro.rd) + c = ||ro||^2 - r^2 + => (forward intersection) t= (sqrt(D) - (ro.rd))/||rd||^2 + with D = (ro.rd)^2 - (r^2 - ||ro||^2) + """ + odotd = torch.sum(rays_o * rays_d, -1) + d_norm_sq = torch.sum(rays_d ** 2, -1) + o_norm_sq = torch.sum(rays_o ** 2, -1) + determinant = odotd ** 2 + (r ** 2 - o_norm_sq) * d_norm_sq + assert torch.all( + determinant >= 0 + ), "Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!" 
+ return (torch.sqrt(determinant) - odotd) / d_norm_sq diff --git a/util/transforms.py b/util/transforms.py new file mode 100644 index 0000000..ea47bdb --- /dev/null +++ b/util/transforms.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import numpy as np +import torch +from transforms3d.euler import euler2mat +from transforms3d.axangles import axangle2mat +from transforms3d.quaternions import quat2mat + + +def has_torch(*args): + return any([isinstance(x, torch.Tensor) for x in args]) + + +def dot(transform, points, coords=False): + if isinstance(points, torch.Tensor): + return dot_torch(transform, points, coords) + else: + if isinstance(transform, torch.Tensor): # points dominate + transform = transform.cpu().numpy() + if type(points) == list: + points = np.array(points) + + if len(points.shape) == 1: + # single point + if transform.shape == (3, 3): + return transform @ points[:3] + else: + return (transform @ np.array([*points[:3], 1]))[:3] + if points.shape[1] == 3 or (coords and points.shape[1] > 3): + # nx[xyz,...] + if transform.shape == (4, 4): + pts = (transform[:3, :3] @ points[:, :3].T).T + transform[:3, 3] + elif transform.shape == (3, 3): + pts = (transform[:3, :3] @ points[:, :3].T).T + else: + raise RuntimeError("Format of transform not understood") + return np.concatenate([pts, points[:, 3:]], 1) + else: + raise RuntimeError(f"Format of points {points.shape} not understood") + + +def dot_torch(transform, points, coords=False): + if not isinstance(transform, torch.Tensor): + transform = torch.from_numpy(transform).float() + + transform = transform.to(points.device).float() + if type(points) == list: + points = torch.Tensor(points) + if len(points.shape) == 1: + # single point + if transform.shape == (3, 3): + return transform @ points[:3] + else: + return (transform @ torch.Tensor([*points[:3], 1]))[:3] + if points.shape[1] == 3 or (coords and points.shape[1] > 3): + # nx[xyz,...] 
+ if transform.shape == (4, 4): + pts = (transform[:3, :3] @ points[:, :3].T).T + transform[:3, 3] + elif transform.shape == (3, 3): + pts = (transform[:3, :3] @ points[:, :3].T).T + else: + raise RuntimeError("Format of transform not understood") + return torch.cat([pts, points[:, 3:]], 1) + else: + raise RuntimeError(f"Format of points {points.shape} not understood") + + +def dot2d(transform, points): + if type(points) == list: + points = np.array(points) + + if len(points.shape) == 1: + # single point + if transform.shape == (2, 2): + return transform @ points[:2] + else: + return (transform @ np.array([*points[:2], 1]))[:2] + elif len(points.shape) == 2: + if points.shape[1] in [2, 3]: + # needs to be transposed for dot product + points = points.T + else: + raise RuntimeError("Format of points not understood") + # points in format [2/3,n] + if transform.shape == (3, 3): + return (transform[:2, :2] @ points[:2]).T + transform[:2, 2] + elif transform.shape == (2, 2): + return (transform[:2, :2] @ points[:2]).T + else: + raise RuntimeError("Format of transform not understood") + + +def backproject(depth, intrinsics, cam2world=np.eye(4), color=None): + # in height x width (xrgb) + h, w = depth.shape + valid_px = depth > 0 + yv, xv = np.meshgrid(range(h), range(w), indexing="ij") + img_coords = np.stack([yv, xv], -1) + img_coords = img_coords[valid_px] + z_coords = depth[valid_px] + pts = uvd_backproject(img_coords, z_coords, intrinsics, cam2world, color[valid_px] if color is not None else None) + + return pts + + +def uvd_backproject(uv, d, intrinsics, cam2world=np.eye(4), color=None): + fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2] + py = (uv[:, 0] - cy) * d / fy + px = (uv[:, 1] - cx) * d / fx + pts = np.stack([px, py, d]) + + pts = cam2world[:3, :3] @ pts + np.tile(cam2world[:3, 3], (pts.shape[1], 1)).T + pts = pts.T + if color is not None: + pts = np.concatenate([pts, color], 1) + + return pts + + +def trs_decomp(A): + if has_torch(A): + s_vec = torch.norm(A[:3, :3], dim=0) + else: + s_vec = np.linalg.norm(A[:3, :3], axis=0) + R = A[:3, :3] / s_vec + t = A[:3, 3] + return t, R, s_vec + + +def scale_mat(s, as_torch=True): + if isinstance(s, np.ndarray): + s_mat = np.eye(4) + s_mat[:3, :3] *= s + elif has_torch(s): + s_mat = torch.eye(4).to(s.device) + s_mat[:3, :3] *= s + s_mat + else: + s_mat = torch.eye(4) if as_torch else np.eye(4) + s_mat[:3, :3] *= s + return s_mat + + +def trans_mat(t): + if has_torch(t): + t_mat = torch.eye(4).to(t.device).float() + t_mat[:3, 3] = t + else: + t_mat = np.eye(4, dtype=np.float32) + t_mat[:3, 3] = t + return t_mat + + +def rot_mat(axangle=None, euler=None, quat=None, as_torch=True): + R = np.eye(3) + if axangle is not None: + if euler is None: + axis, angle = axangle[0], axangle[1] + else: + axis, angle = axangle, euler + R = axangle2mat(axis, angle) + elif euler is not None: + R = euler2mat(*euler) + elif quat is not None: + R = quat2mat(quat) + if as_torch: + R = torch.Tensor(R) + return R + + +def hmg(M): + if M.shape[0] == 3 and M.shape[1] == 3: + if has_torch(M): + hmg_M = torch.eye(4, dtype=M.dtype).to(M.device) + else: + hmg_M = np.eye(4, dtype=M.dtype) + hmg_M[:3, :3] = M + else: + hmg_M = M + return hmg_M + + +def trs_comp(t, R, s_vec): + return trans_mat(t) @ hmg(R) @ scale_mat(s_vec) + + +def tr_comp(t, R): + return trans_mat(t) @ hmg(R) + + +def quat_from_two_vectors(v0, v1): + import quaternion as qt + v0 = v0 / np.linalg.norm(v0) + v1 = v1 / np.linalg.norm(v1) + c = v0.dot(v1) + if c < (-1 + 
1e-8): + c = max(c, -1) + m = np.stack([v0, v1], 0) + _, _, vh = np.linalg.svd(m, full_matrices=True) + axis = vh[2] + w2 = (1 + c) * 0.5 + w = np.sqrt(w2) + axis = axis * np.sqrt(1 - w2) + return qt.quaternion(w, *axis) + + axis = np.cross(v0, v1) + s = np.sqrt((1 + c) * 2) + return qt.quaternion(s * 0.5, *(axis / s)) + + +def to4x4(pose): + constants = torch.zeros_like(pose[..., :1, :], device=pose.device) + constants[..., :, 3] = 1 + return torch.cat([pose, constants], dim=-2) + + +def normalize(poses): + pose_copy = torch.clone(poses) + pose_copy[..., :3, 3] /= torch.max(torch.abs(poses[..., :3, 3])) + return pose_copy
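+
+
+# Minimal illustrative sanity check (assumed example, not invoked by training
+# or evaluation): compose a transform from translation/rotation/scale with
+# trs_comp, apply it with dot(), and verify that trs_decomp recovers the parts.
+if __name__ == "__main__":
+    t = torch.tensor([1.0, 2.0, 3.0])
+    R = rot_mat(euler=[0.1, 0.2, 0.3])
+    s = torch.tensor([2.0, 2.0, 2.0])
+    M = trs_comp(t, R, s)                      # 4x4: translate @ rotate @ scale
+    pts = torch.rand(5, 3)
+    out = dot(M, pts)                          # apply the 4x4 transform to Nx3 points
+    ref = (R @ (pts * s).T).T + t              # scale, rotate, then translate by hand
+    assert torch.allclose(out, ref, atol=1e-5)
+    t2, R2, s2 = trs_decomp(M)                 # recover translation, rotation, scale
+    assert torch.allclose(t2, t, atol=1e-5)
+    assert torch.allclose(R2, R, atol=1e-5)
+    assert torch.allclose(s2, s, atol=1e-5)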