diff --git a/cache/cams/.cache b/cache/cams/.cache deleted file mode 100644 index e69de29..0000000 diff --git a/cache/depths/.cache b/cache/depths/.cache deleted file mode 100644 index e69de29..0000000 diff --git a/cache/images/.cache b/cache/images/.cache deleted file mode 100644 index e69de29..0000000 diff --git a/cache/propagated_depth/.cache b/cache/propagated_depth/.cache deleted file mode 100644 index e69de29..0000000 diff --git a/submodules/Propagation/CMakeLists.txt b/submodules/Propagation/CMakeLists.txt deleted file mode 100644 index 0b44db7..0000000 --- a/submodules/Propagation/CMakeLists.txt +++ /dev/null @@ -1,44 +0,0 @@ -cmake_minimum_required (VERSION 2.8) -project (Propagation) - -find_package(CUDA 6.0 REQUIRED ) # For Cuda Managed Memory and c++11 -find_package(OpenCV REQUIRED ) - -include_directories(${OpenCV_INCLUDE_DIRS}) -include_directories(.) - -set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-O3 --use_fast_math --maxrregcount=128 --ptxas-options=-v -std=c++11 --compiler-options -Wall -gencode arch=compute_86,code=sm_86 -gencode arch=compute_86,code=sm_86) - -if(CMAKE_COMPILER_IS_GNUCXX) - add_definitions(-std=c++11) - add_definitions(-pthread) - add_definitions(-Wall) - add_definitions(-Wextra) - add_definitions(-pedantic) - add_definitions(-Wno-unused-function) - add_definitions(-Wno-switch) - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -O3 -ffast-math -march=native") # extend release-profile with fast-math -endif() - -find_package(OpenMP) -if (OPENMP_FOUND) - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") -endif() - - -# For compilation ... -# Specify target & source files to compile it from -cuda_add_executable( - Propagation - main.h - Propagation.h - Propagation.cpp - Propagation.cu - main.cpp - ) - -# For linking ... -# Specify target & libraries to link it with -target_link_libraries(Propagation - ${OpenCV_LIBS} - ) diff --git a/submodules/Propagation/Propagation.cpp b/submodules/Propagation/PatchMatch.cpp similarity index 84% rename from submodules/Propagation/Propagation.cpp rename to submodules/Propagation/PatchMatch.cpp index f977d6e..1a2fcbb 100644 --- a/submodules/Propagation/Propagation.cpp +++ b/submodules/Propagation/PatchMatch.cpp @@ -1,4 +1,5 @@ -#include "Propagation.h" +#include "PatchMatch.h" +#include #include @@ -96,9 +97,9 @@ void CudaCheckError(const char* file, const int line) { } } -Propagation::Propagation() {} +PatchMatch::PatchMatch() {} -Propagation::~Propagation() +PatchMatch::~PatchMatch() { delete[] plane_hypotheses_host; delete[] costs_host; @@ -124,29 +125,25 @@ Propagation::~Propagation() } } -Camera ReadCamera(const std::string &cam_path) +Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval) { Camera camera; - std::ifstream file(cam_path); - - std::string line; - file >> line; for (int i = 0; i < 3; ++i) { - file >> camera.R[3 * i + 0] >> camera.R[3 * i + 1] >> camera.R[3 * i + 2] >> camera.t[i]; + camera.R[3 * i + 0] = pose[i][0].item(); + camera.R[3 * i + 1] = pose[i][1].item(); + camera.R[3 * i + 2] = pose[i][2].item(); + camera.t[i] = pose[i][3].item(); } - float tmp[4]; - file >> tmp[0] >> tmp[1] >> tmp[2] >> tmp[3]; - file >> line; - for (int i = 0; i < 3; ++i) { - file >> camera.K[3 * i + 0] >> camera.K[3 * i + 1] >> camera.K[3 * i + 2]; + camera.K[3 * i + 0] = intrinsic[i][0].item(); + camera.K[3 * i + 1] = intrinsic[i][1].item(); + camera.K[3 * i + 2] = intrinsic[i][2].item(); } - float depth_num; - float interval; - file >> camera.depth_min >> interval >> depth_num >> camera.depth_max; + camera.depth_min = depth_interval[0].item(); + camera.depth_max = depth_interval[3].item(); return camera; } @@ -406,54 +403,56 @@ static float GetDisparity(const Camera &camera, const int2 &p, const float &dept return std::sqrt(point3D[0] * point3D[0] + point3D[1] * point3D[1] + point3D[2] * point3D[2]); } -void Propagation::SetGeomConsistencyParams() +cv::Mat tensorToMat(const torch::Tensor& tensor) { + torch::Tensor tensor_contiguous = tensor.contiguous(); + torch::Tensor tensor_cpu_float = tensor_contiguous.to(torch::kCPU).to(torch::kFloat32); + + int height = tensor_cpu_float.size(0); + int width = tensor_cpu_float.size(1); + int channels = tensor_cpu_float.size(2); + + cv::Mat mat(cv::Size(width, height), CV_32FC(channels), tensor_cpu_float.data_ptr()); + + return mat.clone(); +} + +void PatchMatch::SetGeomConsistencyParams() { params.geom_consistency = true; params.max_iterations = 2; } -void Propagation::InuputInitialization(const std::string &dense_folder, const Problem &problem) +void PatchMatch::InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, + torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals) { images.clear(); cameras.clear(); - std::string image_folder = dense_folder + std::string("/images"); - std::string cam_folder = dense_folder + std::string("/cams"); - std::string depth_folder = dense_folder + std::string("/depths"); - - std::stringstream image_path; - image_path << image_folder << "/" << problem.ref_image_id << ".jpg"; - cv::Mat_ image_uint = cv::imread(image_path.str(), cv::IMREAD_GRAYSCALE); + cv::Mat image_color = tensorToMat(images_cuda[0]); cv::Mat image_float; - image_uint.convertTo(image_float, CV_32FC1); + cv::cvtColor(image_color, image_float, cv::COLOR_RGB2GRAY); + + image_float.convertTo(image_float, CV_32FC1); images.push_back(image_float); - std::stringstream cam_path; - cam_path << cam_folder << "/" << problem.ref_image_id << ".txt"; - Camera camera = ReadCamera(cam_path.str()); + Camera camera = ReadCamera(intrinsics_cuda[0], poses_cuda[0], depth_intervals[0]); camera.height = image_float.rows; camera.width = image_float.cols; cameras.push_back(camera); - std::stringstream depth_path; - depth_path << depth_folder << "/" << problem.ref_image_id << ".png"; - cv::Mat ref_depth = cv::imread(depth_path.str(), cv::IMREAD_ANYDEPTH | cv::IMREAD_GRAYSCALE); - ref_depth.convertTo(ref_depth, CV_32FC1); - //scale to metric m - ref_depth = ref_depth / 100.; + cv::Mat ref_depth = tensorToMat(depth_cuda); depths.push_back(ref_depth); - size_t num_src_images = problem.src_image_ids.size(); - for (size_t i = 0; i < num_src_images; ++i) { - std::stringstream image_path; - image_path << image_folder << "/" << problem.src_image_ids[i] << ".jpg"; - cv::Mat_ image_uint = cv::imread(image_path.str(), cv::IMREAD_GRAYSCALE); - cv::Mat image_float; - image_uint.convertTo(image_float, CV_32FC1); - images.push_back(image_float); - std::stringstream cam_path; - cam_path << cam_folder << "/" << problem.src_image_ids[i] << ".txt"; - Camera camera = ReadCamera(cam_path.str()); + int num_src_images = images_cuda.size(0); + for (int i = 1; i < num_src_images; ++i) { + cv::Mat src_image_color = tensorToMat(images_cuda[i]); + cv::Mat src_image_float; + cv::cvtColor(src_image_color, src_image_float, cv::COLOR_RGB2GRAY); + + src_image_float.convertTo(src_image_float, CV_32FC1); + images.push_back(src_image_float); + + Camera camera = ReadCamera(intrinsics_cuda[i], poses_cuda[i], depth_intervals[i]); camera.height = image_float.rows; camera.width = image_float.cols; cameras.push_back(camera); @@ -497,7 +496,7 @@ void Propagation::InuputInitialization(const std::string &dense_folder, const Pr } -void Propagation::CudaSpaceInitialization(const std::string &dense_folder, const Problem &problem) +void PatchMatch::CudaSpaceInitialization() { num_images = (int)images.size(); @@ -507,6 +506,7 @@ void Propagation::CudaSpaceInitialization(const std::string &dense_folder, const cudaChannelFormatDesc channelDesc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); cudaMallocArray(&cuArray[i], &channelDesc, cols, rows); + cudaMemcpy2DToArray (cuArray[i], 0, 0, images[i].ptr(), images[i].step[0], cols*sizeof(float), rows, cudaMemcpyHostToDevice); struct cudaResourceDesc resDesc; @@ -522,8 +522,16 @@ void Propagation::CudaSpaceInitialization(const std::string &dense_folder, const texDesc.readMode = cudaReadModeElementType; texDesc.normalizedCoords = 0; + // cudaError_t error = cudaGetLastError(); + // printf("CUDA notification0: %s\n", "test"); + // if (error != cudaSuccess) { + // printf("CUDA error step 0: %s\n", cudaGetErrorString(error)); + // // 错误处理代码 + // } + cudaCreateTextureObject(&(texture_objects_host.images[i]), &resDesc, &texDesc, NULL); } + cudaMalloc((void**)&texture_objects_cuda, sizeof(cudaTextureObjects)); cudaMemcpy(texture_objects_cuda, &texture_objects_host, sizeof(cudaTextureObjects), cudaMemcpyHostToDevice); @@ -541,39 +549,45 @@ void Propagation::CudaSpaceInitialization(const std::string &dense_folder, const cudaMalloc((void**)&depths_cuda, sizeof(float) * (cameras[0].height * cameras[0].width)); cudaMemcpy(depths_cuda, depths[0].ptr(), sizeof(float) * cameras[0].height * cameras[0].width, cudaMemcpyHostToDevice); + } -int Propagation::GetReferenceImageWidth() +int PatchMatch::GetReferenceImageWidth() { return cameras[0].width; } -int Propagation::GetReferenceImageHeight() +int PatchMatch::GetReferenceImageHeight() { return cameras[0].height; } -cv::Mat Propagation::GetReferenceImage() +cv::Mat PatchMatch::GetReferenceImage() { return images[0]; } -float4 Propagation::GetPlaneHypothesis(const int index) +float4 PatchMatch::GetPlaneHypothesis(const int index) { return plane_hypotheses_host[index]; } -float Propagation::GetCost(const int index) +float4* PatchMatch::GetPlaneHypotheses() +{ + return plane_hypotheses_host; +} + +float PatchMatch::GetCost(const int index) { return costs_host[index]; } -void Propagation::SetPatchSize(int patch_size) +void PatchMatch::SetPatchSize(int patch_size) { params.patch_size = patch_size; } -int Propagation::GetPatchSize() +int PatchMatch::GetPatchSize() { return params.patch_size; } diff --git a/submodules/Propagation/Propagation.h b/submodules/Propagation/PatchMatch.h similarity index 86% rename from submodules/Propagation/Propagation.h rename to submodules/Propagation/PatchMatch.h index 242de70..c968190 100644 --- a/submodules/Propagation/Propagation.h +++ b/submodules/Propagation/PatchMatch.h @@ -1,7 +1,8 @@ -#ifndef _Propagation_H_ -#define _Propagation_H_ +#ifndef _PatchMatch_H_ +#define _PatchMatch_H_ #include "main.h" +#include int readDepthDmb(const std::string file_path, cv::Mat_ &depth); int readNormalDmb(const std::string file_path, cv::Mat_ &normal); @@ -42,14 +43,14 @@ struct PatchMatchParams { bool geom_consistency = false; }; -class Propagation { +class PatchMatch { public: - Propagation(); - ~Propagation(); + PatchMatch(); + ~PatchMatch(); - void InuputInitialization(const std::string &dense_folder, const Problem &problem); + void InuputInitialization(torch::Tensor images_cuda, torch::Tensor intrinsics_cuda, torch::Tensor poses_cuda, torch::Tensor depth_cuda, torch::Tensor normal_cuda, torch::Tensor depth_intervals); void Colmap2MVS(const std::string &dense_folder, std::vector &problems); - void CudaSpaceInitialization(const std::string &dense_folder, const Problem &problem); + void CudaSpaceInitialization(); void RunPatchMatch(); void SetGeomConsistencyParams(); void SetPatchSize(int patch_size); @@ -59,6 +60,8 @@ class Propagation { cv::Mat GetReferenceImage(); float4 GetPlaneHypothesis(const int index); float GetCost(const int index); + float4* GetPlaneHypotheses(); + private: int num_images; std::vector images; @@ -82,4 +85,4 @@ class Propagation { float *depths_cuda; }; -#endif // _Propagation_H_ +#endif // _PatchMatch_H_ diff --git a/submodules/Propagation/Propagation.cu b/submodules/Propagation/Propagation.cu index 30d8127..b3a7573 100644 --- a/submodules/Propagation/Propagation.cu +++ b/submodules/Propagation/Propagation.cu @@ -1,4 +1,6 @@ -#include "Propagation.h" +#include "PatchMatch.h" +#include +#include __device__ void sort_small(float *d, const int n) { @@ -191,7 +193,7 @@ __device__ float4 GenerateRandomPlaneHypothesis(const Camera camera, const int2 depth = curand_uniform(rand_state) * (depth_max - depth_min) + depth_min; } // printf("initdepth: %f\n", init_depth); - // float depth = curand_uniform(rand_state) * (depth_max - depth_min) + depth_min; + float4 plane_hypothesis = GenerateRandomNormal(camera, p, rand_state, depth); plane_hypothesis.w = GetDistance2Origin(camera, p, depth, plane_hypothesis); return plane_hypothesis; @@ -1082,7 +1084,8 @@ __global__ void RedPixelFilter(const Camera *cameras, float4 *plane_hypotheses, CheckerboardFilter(cameras, plane_hypotheses, costs, p); } -void Propagation::RunPatchMatch() + +void PatchMatch::RunPatchMatch() { const int width = cameras[0].width; const int height = cameras[0].height; @@ -1110,12 +1113,14 @@ void Propagation::RunPatchMatch() int max_iterations = params.max_iterations; - RandomInitialization<<>>(texture_objects_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, depths_cuda); + RandomInitialization<<>>(texture_objects_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, depths_cuda); CUDA_SAFE_CALL(cudaDeviceSynchronize()); - + for (int i = 0; i < max_iterations; ++i) { + BlackPixelUpdate<<>>(texture_objects_cuda, texture_depths_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, i); CUDA_SAFE_CALL(cudaDeviceSynchronize()); + RedPixelUpdate<<>>(texture_objects_cuda, texture_depths_cuda, cameras_cuda, plane_hypotheses_cuda, costs_cuda, rand_states_cuda, selected_views_cuda, params, i); CUDA_SAFE_CALL(cudaDeviceSynchronize()); // printf("iteration: %d\n", i); @@ -1133,3 +1138,54 @@ void Propagation::RunPatchMatch() cudaMemcpy(costs_host, costs_cuda, sizeof(float) * width * height, cudaMemcpyDeviceToHost); CUDA_SAFE_CALL(cudaDeviceSynchronize()); } + +torch::Tensor matToTensor(cv::Mat& mat) { + cv::Mat mat_float; + if (mat.channels() == 3) { + mat.convertTo(mat_float, CV_32FC3); + } else if (mat.channels() == 1) { + mat.convertTo(mat_float, CV_32FC1); + } + + torch::Tensor tensor = torch::from_blob(mat_float.data, + {mat_float.rows, mat_float.cols, mat_float.channels()}, + torch::kFloat32).clone(); + + tensor = tensor.permute({2, 0, 1}); + + return tensor; +} + +torch::Tensor propagate_cuda(torch::Tensor images, torch::Tensor intrinsics, torch::Tensor poses, + torch::Tensor depth, torch::Tensor normal, torch::Tensor depth_intervals, int patch_size) +{ + cudaSetDevice(0); + + PatchMatch pm; + pm.SetPatchSize(patch_size); + + pm.InuputInitialization(images, intrinsics, poses, depth, normal, depth_intervals); + + pm.CudaSpaceInitialization(); + pm.RunPatchMatch(); + + const int width = pm.GetReferenceImageWidth(); + const int height = pm.GetReferenceImageHeight(); + + torch::Tensor depths = torch::zeros({height, width}, torch::kFloat); + torch::Tensor normals = torch::zeros({height, width, 3}, torch::kFloat); + + int numPixels = width * height; + + float4* plane_hypotheses = pm.GetPlaneHypotheses(); + + torch::Tensor planeHypothesisTensor = torch::from_blob(plane_hypotheses, {numPixels, 4}, torch::kFloat); + + torch::Tensor propagated_depth = planeHypothesisTensor.index({torch::indexing::Slice(), 3}).reshape({height, width}).unsqueeze(0); + + torch::Tensor propagated_normal = planeHypothesisTensor.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}).reshape({height, width, 3}).permute({2, 0, 1}); + + torch::Tensor results = torch::cat({propagated_depth, propagated_normal}, 0); + + return results; +} \ No newline at end of file diff --git a/submodules/Propagation/main.cpp b/submodules/Propagation/main.cpp deleted file mode 100644 index decab3e..0000000 --- a/submodules/Propagation/main.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#include "main.h" -#include "Propagation.h" -#include -#include - -void GenerateSampleList(const std::string &dense_folder, std::vector &problems) -{ - std::string cluster_list_path = dense_folder + std::string("/pair.txt"); - - problems.clear(); - - std::ifstream file(cluster_list_path); - - int num_images; - file >> num_images; - - for (int i = 0; i < num_images; ++i) { - Problem problem; - problem.src_image_ids.clear(); - file >> problem.ref_image_id; - - int num_src_images; - file >> num_src_images; - for (int j = 0; j < num_src_images; ++j) { - int id; - float score; - file >> id >> score; - if (score <= 0.0f) { - continue; - } - problem.src_image_ids.push_back(id); - } - problems.push_back(problem); - } -} - -void ProcessProblem(const std::string &dense_folder, const Problem &problem, bool geom_consistency, int patch_size) -{ - // std::cout << "Processing image " << std::setw(8) << std::setfill('0') << problem.ref_image_id << "..." << std::endl; - cudaSetDevice(1); - std::stringstream result_path; - result_path << dense_folder << "/propagated_depth"; - std::string result_folder = result_path.str(); - mkdir(result_folder.c_str(), 0777); - // std::cout << result_folder << std::endl; - - Propagation pro; - int temp = pro.GetPatchSize(); - pro.SetPatchSize(patch_size); - temp = pro.GetPatchSize(); - if (geom_consistency) { - pro.SetGeomConsistencyParams(); - } - pro.InuputInitialization(dense_folder, problem); - - pro.CudaSpaceInitialization(dense_folder, problem); - pro.RunPatchMatch(); - - const int width = pro.GetReferenceImageWidth(); - const int height = pro.GetReferenceImageHeight(); - - cv::Mat_ depths = cv::Mat::zeros(height, width, CV_32FC1); - cv::Mat_ normals = cv::Mat::zeros(height, width, CV_32FC3); - cv::Mat_ costs = cv::Mat::zeros(height, width, CV_32FC1); - - for (int col = 0; col < width; ++col) { - for (int row = 0; row < height; ++row) { - int center = row * width + col; - float4 plane_hypothesis = pro.GetPlaneHypothesis(center); - depths(row, col) = plane_hypothesis.w; - normals(row, col) = cv::Vec3f(plane_hypothesis.x, plane_hypothesis.y, plane_hypothesis.z); - costs(row, col) = pro.GetCost(center); - } - } - - std::string suffix = "/depths.dmb"; - if (geom_consistency) { - suffix = "/depths_geom.dmb"; - } - std::string depth_path = result_folder + suffix; - std::string normal_path = result_folder + "/normals.dmb"; - std::string cost_path = result_folder + "/costs.dmb"; - writeDepthDmb(depth_path, depths); - writeNormalDmb(normal_path, normals); - writeDepthDmb(cost_path, costs); - // std::cout << "Processing image " << std::setw(8) << std::setfill('0') << problem.ref_image_id << " done!" << std::endl; -} - -int main(int argc, char** argv) -{ - if (argc < 2) { - std::cout << "USAGE: Propagation filespath ref_id src_ids patchsize" << std::endl; - return -1; - } - - std::string dense_folder = argv[1]; - int ref_id = std::stoi(argv[2]); - std::string src_ids_str = argv[3]; - int patch_size = std::stoi(argv[4]); - - Problem problem; - problem.ref_image_id = ref_id; - - std::stringstream ss(src_ids_str); - std::string token; - - while (getline(ss, token, ' ')) { - int src_id = std::stoi(token); - problem.src_image_ids.push_back(src_id); - } - - bool geom_consistency = false; - ProcessProblem(dense_folder, problem, geom_consistency, patch_size); - - return 0; -} diff --git a/submodules/Propagation/main.h b/submodules/Propagation/main.h index f2d0001..13254fe 100644 --- a/submodules/Propagation/main.h +++ b/submodules/Propagation/main.h @@ -4,6 +4,8 @@ #include "opencv2/calib3d/calib3d.hpp" #include "opencv2/imgproc/imgproc.hpp" #include "opencv2/core/core.hpp" +#include +#include "opencv2/imgcodecs.hpp" #include "opencv2/highgui/highgui.hpp" // Includes CUDA diff --git a/submodules/Propagation/pro.cpp b/submodules/Propagation/pro.cpp new file mode 100644 index 0000000..d52db3c --- /dev/null +++ b/submodules/Propagation/pro.cpp @@ -0,0 +1,29 @@ +#include +#include + +torch::Tensor propagate_cuda( + torch::Tensor images, + torch::Tensor intrinsics, + torch::Tensor poses, + torch::Tensor depth, + torch::Tensor normal, + torch::Tensor depth_intervals, + int patch_size); + +torch::Tensor propagate( + torch::Tensor images, + torch::Tensor intrinsics, + torch::Tensor poses, + torch::Tensor depth, + torch::Tensor normal, + torch::Tensor depth_intervals, + int patch_size) { + + return propagate_cuda(images, intrinsics, poses, depth, normal, depth_intervals, patch_size); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // bundle adjustment kernels + m.def("propagate", &propagate, "plane propagation"); +} \ No newline at end of file diff --git a/submodules/Propagation/setup.py b/submodules/Propagation/setup.py new file mode 100644 index 0000000..e6b6768 --- /dev/null +++ b/submodules/Propagation/setup.py @@ -0,0 +1,27 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +import os.path as osp +ROOT = osp.dirname(osp.abspath(__file__)) + +setup( + name='gaussianpro', + ext_modules=[ + CUDAExtension('gaussianpro', + include_dirs=['/data/kcheng/anaconda3/envs/procuda/include/opencv4', '/usr/local/cuda-11.7/include', '.'], + library_dirs=['/data/kcheng/anaconda3/envs/procuda/lib'], + libraries=['opencv_core', 'opencv_imgproc', 'opencv_highgui', 'opencv_imgcodecs'], + sources=[ + 'PatchMatch.cpp', + 'Propagation.cu', + 'pro.cpp' + ], + extra_compile_args={ + 'cxx': ['-O3'], + 'nvcc': ['-O3', + '-gencode=arch=compute_86,code=sm_86', + ] + }), + ], + cmdclass={ 'build_ext' : BuildExtension } +) diff --git a/train.py b/train.py index 2b2cd9b..3618aa5 100644 --- a/train.py +++ b/train.py @@ -12,10 +12,10 @@ import os import torch from random import randint -from utils.loss_utils import l1_loss, ssim +from utils.loss_utils import l1_loss, ssim, compute_scale_and_shift, ScaleAndShiftInvariantLoss from utils.general_utils import vis_depth, read_propagted_depth from gaussian_renderer import render, network_gui -from utils.graphics_utils import depth_propagation, check_geometric_consistency +from utils.graphics_utils import surface_normal_from_depth, img_warping, depth_propagation, check_geometric_consistency, generate_edge_mask import sys from scene import Scene, GaussianModel from utils.general_utils import safe_state, load_pairs_relation @@ -27,6 +27,7 @@ import imageio import numpy as np import torchvision +import cv2 try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_FOUND = True @@ -92,14 +93,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi gaussians.oneupSHdegree() # Pick a random Camera + # if not viewpoint_stack: + # viewpoint_stack = scene.getTrainCameras().copy() randidx = randint(0, len(viewpoint_stack)-1) + # if iteration > propagated_iteration_begin and iteration < propagated_iteration_after and after_propagated: + # randidx = propagated_view_index viewpoint_cam = viewpoint_stack[randidx] - # set the neighboring frames if opt.depth_loss: if opt.dataset == '360': src_idxs = pairs[randidx] else: + # intervals = [-6, -3, 3, 6] if opt.dataset == 'waymo': intervals = [-2, -1, 1, 2] elif opt.dataset == 'scannet': @@ -110,7 +115,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi #propagate the gaussians first with torch.no_grad(): - if opt.depth_loss and iteration > propagated_iteration_begin and iteration < propagated_iteration_after and (iteration % opt.propagation_interval == 0): + if opt.depth_loss and iteration > propagated_iteration_begin and iteration < propagated_iteration_after and (iteration % opt.propagation_interval == 0 and not propagation_dict[viewpoint_cam.image_name]): + # if opt.depth_loss and iteration > propagated_iteration_begin and iteration < propagated_iteration_after and (iteration % opt.propagation_interval == 0): propagation_dict[viewpoint_cam.image_name] = True render_pkg = render(viewpoint_cam, gaussians, pipe, bg, @@ -123,18 +129,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi sky_mask = viewpoint_cam.sky_mask.to(opacity_mask.device).to(torch.bool) else: sky_mask = None + torchvision.utils.save_image(viewpoint_cam.original_image, "cost/"+viewpoint_cam.image_name+"_"+str(iteration)+"gt.png") # get the propagated depth - depth_propagation(viewpoint_cam, projected_depth, viewpoint_stack, src_idxs, opt.dataset, opt.patch_size) - propagated_depth, cost, normal = read_propagted_depth('./cache/propagated_depth') - cost = torch.tensor(cost).to(projected_depth.device) - normal = torch.tensor(normal).to(projected_depth.device) + propagated_depth, normal = depth_propagation(viewpoint_cam, projected_depth, viewpoint_stack, src_idxs, opt.dataset, opt.patch_size) + + # cache the propagated_depth + viewpoint_cam.depth = propagated_depth + #transform normal to camera coordinate R_w2c = torch.tensor(viewpoint_cam.R.T).cuda().to(torch.float32) # R_w2c[:, 1:] *= -1 normal = (R_w2c @ normal.view(-1, 3).permute(1, 0)).view(3, viewpoint_cam.image_height, viewpoint_cam.image_width) - - propagated_depth = torch.tensor(propagated_depth).to(projected_depth.device) valid_mask = propagated_depth != 300 # calculate the abs rel depth error of the propagated depth and rendered depth & render color error @@ -143,15 +149,18 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi abs_rel_error_threshold = opt.depth_error_max_threshold - (opt.depth_error_max_threshold - opt.depth_error_min_threshold) * (iteration - propagated_iteration_begin) / (propagated_iteration_after - propagated_iteration_begin) # color error render_color = render_pkg['render'] + torchvision.utils.save_image(render_color, "cost/"+viewpoint_cam.image_name+"_"+str(iteration)+"color.png") color_error = torch.abs(render_color - viewpoint_cam.original_image) color_error = color_error.mean(dim=0).squeeze() - #for waymo, quantile 0.6; for free dataset, quantile 0.4 error_mask = (abs_rel_error > abs_rel_error_threshold) - - # calculate the geometric consistency + + # # calculate the photometric consistency ref_K = viewpoint_cam.K + #c2w ref_pose = viewpoint_cam.world_view_transform.transpose(0, 1).inverse() + + # calculate the geometric consistency geometric_counts = None for idx, src_idx in enumerate(src_idxs): src_viewpoint = viewpoint_stack[src_idx] @@ -159,26 +168,29 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi src_pose = src_viewpoint.world_view_transform.transpose(0, 1).inverse() src_K = src_viewpoint.K - src_render_pkg = render(src_viewpoint, gaussians, pipe, bg, - return_normal=opt.normal_loss, return_opacity=False, return_depth=opt.depth_loss or opt.depth2normal_loss) - src_projected_depth = src_render_pkg['render_depth'] + if src_viewpoint.depth is None: + src_render_pkg = render(src_viewpoint, gaussians, pipe, bg, + return_normal=opt.normal_loss, return_opacity=False, return_depth=opt.depth_loss or opt.depth2normal_loss) + src_projected_depth = src_render_pkg['render_depth'] #get the src_depth first - depth_propagation(src_viewpoint, torch.zeros_like(src_projected_depth).cuda(), viewpoint_stack, src_idxs, opt.dataset, opt.patch_size) - src_depth, cost, src_normal = read_propagted_depth('./cache/propagated_depth') - src_depth = torch.tensor(src_depth).cuda() + src_depth, src_normal = depth_propagation(src_viewpoint, src_projected_depth, viewpoint_stack, src_idxs, opt.dataset, opt.patch_size) + src_viewpoint.depth = src_depth + else: + src_depth = src_viewpoint.depth + mask, depth_reprojected, x2d_src, y2d_src, relative_depth_diff = check_geometric_consistency(propagated_depth.unsqueeze(0), ref_K.unsqueeze(0), ref_pose.unsqueeze(0), src_depth.unsqueeze(0), src_K.unsqueeze(0), src_pose.unsqueeze(0), thre1=2, thre2=0.01) + if geometric_counts is None: geometric_counts = mask.to(torch.uint8) else: geometric_counts += mask.to(torch.uint8) cost = geometric_counts.squeeze() - cost_mask = cost >= 2 + cost_mask = cost >= 2 - #set -10 as nan normal[~(cost_mask.unsqueeze(0).repeat(3, 1, 1))] = -10 viewpoint_cam.normal = normal @@ -186,6 +198,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi if sky_mask is not None: propagated_mask = propagated_mask & sky_mask + propagated_depth[~cost_mask] = 300 + # propagated_mask = propagated_mask & edge_mask + propagated_depth[~propagated_mask] = 300 + if propagated_mask.sum() > 100: gaussians.densify_from_depth_propagation(viewpoint_cam, propagated_depth, propagated_mask.to(torch.bool), gt_image) @@ -238,6 +254,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi filter_mask = viewpoint_cam.sky_mask.to(normal_gt.device).to(torch.bool) normal_gt[~(filter_mask.unsqueeze(0).repeat(3, 1, 1))] = -10 filter_mask = (normal_gt != -10)[0, :, :].to(torch.bool) + l1_normal = torch.abs(rendered_normal - normal_gt).sum(dim=0)[filter_mask].mean() cos_normal = (1. - torch.sum(rendered_normal * normal_gt, dim = 0))[filter_mask].mean() loss += opt.lambda_l1_normal * l1_normal + opt.lambda_cos_normal * cos_normal @@ -352,8 +369,8 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i parser.add_argument('--port', type=int, default=6009) parser.add_argument('--debug_from', type=int, default=-1) parser.add_argument('--detect_anomaly', action='store_true', default=False) - parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 2000, 7000, 30000]) - parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 7000, 30000]) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 2000, 7000, 20000, 50000]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 7000, 20000, 50000]) parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--start_checkpoint", type=str, default = None) @@ -366,8 +383,6 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i # Initialize system state (RNG) safe_state(args.quiet) - # Start GUI server, configure and run training - # network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py index cd0695a..3630a9b 100644 --- a/utils/graphics_utils.py +++ b/utils/graphics_utils.py @@ -15,6 +15,7 @@ from typing import NamedTuple import cv2 import os +from gaussianpro import propagate class BasicPointCloud(NamedTuple): points : np.array @@ -286,6 +287,45 @@ def bilinear_sampler(img, coords, mask=False): return img +# def sparse_depth_from_projection(gaussians, viewpoint_cam): +# pc = gaussians.get_xyz.contiguous() +# K = viewpoint_cam.K +# img_height = viewpoint_cam.image_height +# img_width = viewpoint_cam.image_width +# znear = 0.1 +# zfar = 1000 +# proj_matrix = get_proj_matrix(K, (img_width, img_height), znear, zfar) +# proj_matrix = torch.tensor(proj_matrix).cuda().to(torch.float32) +# w2c = viewpoint_cam.world_view_transform.transpose(0, 1) +# c2w = w2c.inverse() +# c2w = c2w @ torch.tensor(np.diag([1., -1., -1., 1.]).astype(np.float32)).cuda() +# w2c = c2w.inverse() +# total_m = proj_matrix @ w2c +# index_buffer, _ = pcpr.forward(pc, total_m.unsqueeze(0), img_width, img_height, 512) +# sh = index_buffer.shape +# ind = index_buffer.view(-1).long().cuda() + +# xyz = pc.unsqueeze(0).permute(2,0,1) +# xyz = xyz.view(xyz.shape[0],-1) +# proj_xyz_world = torch.index_select(xyz, 1, ind) +# Rot, Trans = w2c[:3, :3], w2c[:3, 3][..., None] + +# proj_xyz_cam = Rot @ proj_xyz_world + Trans +# proj_depth = proj_xyz_cam[2,:][None,] +# proj_depth = proj_depth.view(proj_depth.shape[0], sh[0], sh[1], sh[2]) #[1, 4, 256, 256] +# proj_depth = proj_depth.permute(1, 0, 2, 3) +# proj_depth *= -1 + +# ##mask获取 +# mask = ind.clone() +# mask[mask>0] = 1 +# mask = mask.view(1, sh[0], sh[1], sh[2]) +# mask = mask.permute(1,0,2,3) + +# proj_depth = proj_depth * mask + +# return proj_depth.squeeze() + # project the reference point cloud into the source view, then project back #extrinsics here refers c2w def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): @@ -355,12 +395,8 @@ def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth return mask, depth_reprojected, x2d_src, y2d_src, relative_depth_diff -def depth_propagation(viewpoint_cam, projected_depth, viewpoint_stack, src_idxs, dataset, patch_size): - # pass data to c++ api for mvs - cdata_image_path = './cache/images' - cdata_camera_path = './cache/cams' - cdata_depth_path = './cache/depths' - +def depth_propagation(viewpoint_cam, rendered_depth, viewpoint_stack, src_idxs, dataset, patch_size): + depth_min = 0.1 if dataset == 'waymo': depth_max = 80 @@ -369,35 +405,38 @@ def depth_propagation(viewpoint_cam, projected_depth, viewpoint_stack, src_idxs, else: depth_max = 20 - # rendered_depth[rendered_depth>120] = 1e-3 - #scale it for float type - projected_depth = projected_depth * 100 - - ref_img = viewpoint_cam.original_image - ref_img = ref_img * 255 - ref_img = ref_img.permute((1, 2, 0)).detach().cpu().numpy().astype(np.uint8) - ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB) - ref_K = viewpoint_cam.K - ref_w2c = viewpoint_cam.world_view_transform.transpose(0, 1) - cv2.imwrite(os.path.join(cdata_image_path, "0.jpg"), ref_img) - cv2.imwrite(os.path.join(cdata_depth_path, "0.png"), projected_depth.detach().cpu().numpy().astype(np.uint16)) - write_cam_txt(os.path.join(cdata_camera_path, "0.txt"), ref_K.detach().cpu().numpy(), ref_w2c.detach().cpu().numpy(), - [depth_min, (depth_max-depth_min)/192.0, 192.0, depth_max]) + images = list() + intrinsics = list() + poses = list() + depth_intervals = list() + + images.append((viewpoint_cam.original_image * 255).permute((1, 2, 0)).to(torch.uint8)) + intrinsics.append(viewpoint_cam.K) + poses.append(viewpoint_cam.world_view_transform.transpose(0, 1)) + depth_interval = torch.tensor([depth_min, (depth_max-depth_min)/192.0, 192.0, depth_max]) + depth_intervals.append(depth_interval) + + depth = rendered_depth.unsqueeze(-1) + normal = torch.zeros_like(depth) + for idx, src_idx in enumerate(src_idxs): src_viewpoint = viewpoint_stack[src_idx] - src_w2c = src_viewpoint.world_view_transform.transpose(0, 1) - src_K = src_viewpoint.K - src_img = src_viewpoint.original_image - src_img = src_img * 255 - src_img = src_img.permute((1, 2, 0)).detach().cpu().numpy().astype(np.uint8) - src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) - - cv2.imwrite(os.path.join(cdata_image_path, str(idx+1)+".jpg"), src_img) - write_cam_txt(os.path.join(cdata_camera_path, str(idx+1)+".txt"), src_K.detach().cpu().numpy(), src_w2c.detach().cpu().numpy(), - [depth_min, (depth_max-depth_min)/192.0, 192.0, depth_max]) - # c++ api for depth propagation - propagation_command = './submodules/Propagation/Propagation ./cache 0 "1 2 3 4" ' + str(patch_size) - os.system(propagation_command) + images.append((src_viewpoint.original_image * 255).permute((1, 2, 0)).to(torch.uint8)) + intrinsics.append(src_viewpoint.K) + poses.append(src_viewpoint.world_view_transform.transpose(0, 1)) + depth_intervals.append(depth_interval) + + images = torch.stack(images) + intrinsics = torch.stack(intrinsics) + poses = torch.stack(poses) + depth_intervals = torch.stack(depth_intervals) + + results = propagate(images, intrinsics, poses, depth, normal, depth_intervals, patch_size) + propagated_depth = results[0].to(rendered_depth.device) + propagated_normal = results[1:4].to(rendered_depth.device).permute(1, 2, 0) + + return propagated_depth, propagated_normal + def generate_edge_mask(propagated_depth, patch_size): # img gradient @@ -416,4 +455,4 @@ def generate_edge_mask(propagated_depth, patch_size): dilated_mask = torch.round(dilated_mask).squeeze().to(torch.bool) dilated_mask = ~dilated_mask - return dilated_mask + return dilated_mask \ No newline at end of file