Skip to content

Commit

Permalink
feat: Replace opencv with torch
Browse files Browse the repository at this point in the history
  • Loading branch information
hugoycj committed Sep 27, 2024
1 parent b943d67 commit 9af46a7
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 225 deletions.
225 changes: 40 additions & 185 deletions submodules/Propagation/PatchMatch.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "PatchMatch.h"
#include <torch/extension.h>
#include <cfloat>

#include <cstdarg>

Expand Down Expand Up @@ -148,21 +149,20 @@ Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor dep
return camera;
}

void RescaleImageAndCamera(cv::Mat_<cv::Vec3b> &src, cv::Mat_<cv::Vec3b> &dst, cv::Mat_<float> &depth, Camera &camera)
void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera)
{
const int cols = depth.cols;
const int rows = depth.rows;
const int cols = depth.size(1);
const int rows = depth.size(0);

if (cols == src.cols && rows == src.rows) {
if (cols == src.size(1) && rows == src.size(0)) {
dst = src.clone();
return;
}

const float scale_x = cols / static_cast<float>(src.cols);
const float scale_y = rows / static_cast<float>(src.rows);

cv::resize(src, dst, cv::Size(cols,rows), 0, 0, cv::INTER_LINEAR);

const float scale_x = cols / static_cast<float>(src.size(1));
const float scale_y = rows / static_cast<float>(src.size(0));
dst = torch::nn::functional::interpolate(src.unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector<int64_t>({rows, cols})).mode(torch::kBilinear)).squeeze(0);

camera.K[0] *= scale_x;
camera.K[2] *= scale_x;
camera.K[4] *= scale_y;
Expand Down Expand Up @@ -209,9 +209,9 @@ void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, fl
point.y = (camera.K[3] * tmp.x + camera.K[4] * tmp.y + camera.K[5] * tmp.z) / depth;
}

float GetAngle( const cv::Vec3f &v1, const cv::Vec3f &v2 )
float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2)
{
float dot_product = v1[0] * v2[0] + v1[1] * v2[1] + v1[2] * v2[2];
float dot_product = v1[0].item<float>() * v2[0].item<float>() + v1[1].item<float>() * v2[1].item<float>() + v1[2].item<float>() * v2[2].item<float>();
float angle = acosf(dot_product);
//if angle is not a number the dot product was 1 and thus the two vectors should be identical --> return 0
if ( angle != angle )
Expand All @@ -220,124 +220,6 @@ float GetAngle( const cv::Vec3f &v1, const cv::Vec3f &v2 )
return angle;
}

int readDepthDmb(const std::string file_path, cv::Mat_<float> &depth)
{
FILE *inimage;
inimage = fopen(file_path.c_str(), "rb");
if (!inimage){
std::cout << "Error opening file " << file_path << std::endl;
return -1;
}

int32_t type, h, w, nb;

type = -1;

fread(&type,sizeof(int32_t),1,inimage);
fread(&h,sizeof(int32_t),1,inimage);
fread(&w,sizeof(int32_t),1,inimage);
fread(&nb,sizeof(int32_t),1,inimage);

if (type != 1) {
fclose(inimage);
return -1;
}

int32_t dataSize = h*w*nb;

depth = cv::Mat::zeros(h,w,CV_32F);
fread(depth.data,sizeof(float),dataSize,inimage);

fclose(inimage);
return 0;
}

int writeDepthDmb(const std::string file_path, const cv::Mat_<float> depth)
{
FILE *outimage;
outimage = fopen(file_path.c_str(), "wb");
if (!outimage) {
std::cout << "Error opening file " << file_path << std::endl;
}

int32_t type = 1;
int32_t h = depth.rows;
int32_t w = depth.cols;
int32_t nb = 1;

fwrite(&type,sizeof(int32_t),1,outimage);
fwrite(&h,sizeof(int32_t),1,outimage);
fwrite(&w,sizeof(int32_t),1,outimage);
fwrite(&nb,sizeof(int32_t),1,outimage);

float* data = (float*)depth.data;

int32_t datasize = w*h*nb;
fwrite(data,sizeof(float),datasize,outimage);

fclose(outimage);
return 0;
}

int readNormalDmb (const std::string file_path, cv::Mat_<cv::Vec3f> &normal)
{
FILE *inimage;
inimage = fopen(file_path.c_str(), "rb");
if (!inimage){
std::cout << "Error opening file " << file_path << std::endl;
return -1;
}

int32_t type, h, w, nb;

type = -1;

fread(&type,sizeof(int32_t),1,inimage);
fread(&h,sizeof(int32_t),1,inimage);
fread(&w,sizeof(int32_t),1,inimage);
fread(&nb,sizeof(int32_t),1,inimage);

if (type != 1) {
fclose(inimage);
return -1;
}

int32_t dataSize = h*w*nb;

normal = cv::Mat::zeros(h,w,CV_32FC3);
fread(normal.data,sizeof(float),dataSize,inimage);

fclose(inimage);
return 0;
}

int writeNormalDmb(const std::string file_path, const cv::Mat_<cv::Vec3f> normal)
{
FILE *outimage;
outimage = fopen(file_path.c_str(), "wb");
if (!outimage) {
std::cout << "Error opening file " << file_path << std::endl;
}

int32_t type = 1;
int32_t h = normal.rows;
int32_t w = normal.cols;
int32_t nb = 3;

fwrite(&type,sizeof(int32_t),1,outimage);
fwrite(&h,sizeof(int32_t),1,outimage);
fwrite(&w,sizeof(int32_t),1,outimage);
fwrite(&nb,sizeof(int32_t),1,outimage);

float* data = (float*)normal.data;

int32_t datasize = w*h*nb;
fwrite(data,sizeof(float),datasize,outimage);

fclose(outimage);
return 0;
}

void StoreColorPlyFileBinaryPointCloud (const std::string &plyFilePath, const std::vector<PointList> &pc)
{
std::cout << "store 3D points to ply file" << std::endl;
Expand Down Expand Up @@ -403,19 +285,6 @@ static float GetDisparity(const Camera &camera, const int2 &p, const float &dept
return std::sqrt(point3D[0] * point3D[0] + point3D[1] * point3D[1] + point3D[2] * point3D[2]);
}

cv::Mat tensorToMat(const torch::Tensor& tensor) {
torch::Tensor tensor_contiguous = tensor.contiguous();
torch::Tensor tensor_cpu_float = tensor_contiguous.to(torch::kCPU).to(torch::kFloat32);

int height = tensor_cpu_float.size(0);
int width = tensor_cpu_float.size(1);
int channels = tensor_cpu_float.size(2);

cv::Mat mat(cv::Size(width, height), CV_32FC(channels), tensor_cpu_float.data_ptr<float>());

return mat.clone();
}

void PatchMatch::SetGeomConsistencyParams()
{
params.geom_consistency = true;
Expand All @@ -428,69 +297,62 @@ void PatchMatch::InuputInitialization(torch::Tensor images_cuda, torch::Tensor i
images.clear();
cameras.clear();

cv::Mat image_color = tensorToMat(images_cuda[0]);
cv::Mat image_float;
cv::cvtColor(image_color, image_float, cv::COLOR_RGB2GRAY);

image_float.convertTo(image_float, CV_32FC1);
torch::Tensor image_color = images_cuda[0];
torch::Tensor image_float = torch::mean(image_color, /*dim=*/2, /*keepdim=*/true).squeeze();
image_float = image_float.to(torch::kFloat32);
images.push_back(image_float);

Camera camera = ReadCamera(intrinsics_cuda[0], poses_cuda[0], depth_intervals[0]);
camera.height = image_float.rows;
camera.width = image_float.cols;
camera.height = image_float.size(0);
camera.width = image_float.size(1);
cameras.push_back(camera);

cv::Mat ref_depth = tensorToMat(depth_cuda);
torch::Tensor ref_depth = depth_cuda;
depths.push_back(ref_depth);

int num_src_images = images_cuda.size(0);
for (int i = 1; i < num_src_images; ++i) {
cv::Mat src_image_color = tensorToMat(images_cuda[i]);
cv::Mat src_image_float;
cv::cvtColor(src_image_color, src_image_float, cv::COLOR_RGB2GRAY);

src_image_float.convertTo(src_image_float, CV_32FC1);
torch::Tensor src_image_color = images_cuda[i];
torch::Tensor src_image_float = torch::mean(src_image_color, /*dim=*/2, /*keepdim=*/true).squeeze();
src_image_float = src_image_float.to(torch::kFloat32);
images.push_back(src_image_float);

Camera camera = ReadCamera(intrinsics_cuda[i], poses_cuda[i], depth_intervals[i]);
camera.height = image_float.rows;
camera.width = image_float.cols;
camera.height = src_image_float.size(0);
camera.width = src_image_float.size(1);
cameras.push_back(camera);
}

// Scale cameras and images
for (size_t i = 0; i < images.size(); ++i) {
if (images[i].cols <= params.max_image_size && images[i].rows <= params.max_image_size) {
if (images[i].size(1) <= params.max_image_size && images[i].size(0) <= params.max_image_size) {
continue;
}

const float factor_x = static_cast<float>(params.max_image_size) / images[i].cols;
const float factor_y = static_cast<float>(params.max_image_size) / images[i].rows;
const float factor_x = static_cast<float>(params.max_image_size) / images[i].size(1);
const float factor_y = static_cast<float>(params.max_image_size) / images[i].size(0);
const float factor = std::min(factor_x, factor_y);

const int new_cols = std::round(images[i].cols * factor);
const int new_rows = std::round(images[i].rows * factor);
const int new_cols = std::round(images[i].size(1) * factor);
const int new_rows = std::round(images[i].size(0) * factor);

const float scale_x = new_cols / static_cast<float>(images[i].cols);
const float scale_y = new_rows / static_cast<float>(images[i].rows);
const float scale_x = new_cols / static_cast<float>(images[i].size(1));
const float scale_y = new_rows / static_cast<float>(images[i].size(0));

cv::Mat_<float> scaled_image_float;
cv::resize(images[i], scaled_image_float, cv::Size(new_cols,new_rows), 0, 0, cv::INTER_LINEAR);
torch::Tensor scaled_image_float = torch::nn::functional::interpolate(images[i].unsqueeze(0), torch::nn::functional::InterpolateFuncOptions().size(std::vector<int64_t>({new_rows, new_cols})).mode(torch::kBilinear)).squeeze(0);
images[i] = scaled_image_float.clone();

cameras[i].K[0] *= scale_x;
cameras[i].K[2] *= scale_x;
cameras[i].K[4] *= scale_y;
cameras[i].K[5] *= scale_y;
cameras[i].height = scaled_image_float.rows;
cameras[i].width = scaled_image_float.cols;
cameras[i].height = scaled_image_float.size(0);
cameras[i].width = scaled_image_float.size(1);
}

params.depth_min = cameras[0].depth_min * 0.6f;
params.depth_max = cameras[0].depth_max * 1.2f;
// std::cout << "depth range: " << params.depth_min << " " << params.depth_max << std::endl;
params.num_images = (int)images.size();
// std::cout << "num images: " << params.num_images << std::endl;
params.disparity_min = cameras[0].K[0] * params.baseline / params.depth_max;
params.disparity_max = cameras[0].K[0] * params.baseline / params.depth_min;

Expand All @@ -501,13 +363,13 @@ void PatchMatch::CudaSpaceInitialization()
num_images = (int)images.size();

for (int i = 0; i < num_images; ++i) {
int rows = images[i].rows;
int cols = images[i].cols;
int rows = images[i].size(0);
int cols = images[i].size(1);

cudaChannelFormatDesc channelDesc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat);
cudaMallocArray(&cuArray[i], &channelDesc, cols, rows);

cudaMemcpy2DToArray (cuArray[i], 0, 0, images[i].ptr<float>(), images[i].step[0], cols*sizeof(float), rows, cudaMemcpyHostToDevice);
cudaMemcpy2DToArray(cuArray[i], 0, 0, images[i].data_ptr<float>(), images[i].stride(0) * sizeof(float), cols * sizeof(float), rows, cudaMemcpyHostToDevice);

struct cudaResourceDesc resDesc;
memset(&resDesc, 0, sizeof(cudaResourceDesc));
Expand All @@ -522,13 +384,6 @@ void PatchMatch::CudaSpaceInitialization()
texDesc.readMode = cudaReadModeElementType;
texDesc.normalizedCoords = 0;

// cudaError_t error = cudaGetLastError();
// printf("CUDA notification0: %s\n", "test");
// if (error != cudaSuccess) {
// printf("CUDA error step 0: %s\n", cudaGetErrorString(error));
// // 错误处理代码
// }

cudaCreateTextureObject(&(texture_objects_host.images[i]), &resDesc, &texDesc, NULL);
}

Expand All @@ -538,8 +393,9 @@ void PatchMatch::CudaSpaceInitialization()
cudaMalloc((void**)&cameras_cuda, sizeof(Camera) * (num_images));
cudaMemcpy(cameras_cuda, &cameras[0], sizeof(Camera) * (num_images), cudaMemcpyHostToDevice);

plane_hypotheses_host = new float4[cameras[0].height * cameras[0].width];
cudaMalloc((void**)&plane_hypotheses_cuda, sizeof(float4) * (cameras[0].height * cameras[0].width));
int total_pixels = cameras[0].height * cameras[0].width;
plane_hypotheses_host = new float4[total_pixels];
cudaMalloc((void**)&plane_hypotheses_cuda, sizeof(float4) * total_pixels);

costs_host = new float[cameras[0].height * cameras[0].width];
cudaMalloc((void**)&costs_cuda, sizeof(float) * (cameras[0].height * cameras[0].width));
Expand All @@ -548,8 +404,7 @@ void PatchMatch::CudaSpaceInitialization()
cudaMalloc((void**)&selected_views_cuda, sizeof(unsigned int) * (cameras[0].height * cameras[0].width));

cudaMalloc((void**)&depths_cuda, sizeof(float) * (cameras[0].height * cameras[0].width));
cudaMemcpy(depths_cuda, depths[0].ptr<float>(), sizeof(float) * cameras[0].height * cameras[0].width, cudaMemcpyHostToDevice);

cudaMemcpy(depths_cuda, depths[0].data_ptr<float>(), sizeof(float) * cameras[0].height * cameras[0].width, cudaMemcpyHostToDevice);
}

int PatchMatch::GetReferenceImageWidth()
Expand All @@ -562,7 +417,7 @@ int PatchMatch::GetReferenceImageHeight()
return cameras[0].height;
}

cv::Mat PatchMatch::GetReferenceImage()
torch::Tensor PatchMatch::GetReferenceImage()
{
return images[0];
}
Expand Down
19 changes: 7 additions & 12 deletions submodules/Propagation/PatchMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,12 @@
#include "main.h"
#include <torch/extension.h>

int readDepthDmb(const std::string file_path, cv::Mat_<float> &depth);
int readNormalDmb(const std::string file_path, cv::Mat_<cv::Vec3f> &normal);
int writeDepthDmb(const std::string file_path, const cv::Mat_<float> depth);
int writeNormalDmb(const std::string file_path, const cv::Mat_<cv::Vec3f> normal);

Camera ReadCamera(const std::string &cam_path);
void RescaleImageAndCamera(cv::Mat_<cv::Vec3b> &src, cv::Mat_<cv::Vec3b> &dst, cv::Mat_<float> &depth, Camera &camera);
Camera ReadCamera(torch::Tensor intrinsic, torch::Tensor pose, torch::Tensor depth_interval);
void RescaleImageAndCamera(torch::Tensor &src, torch::Tensor &dst, torch::Tensor &depth, Camera &camera);
float3 Get3DPointonWorld(const int x, const int y, const float depth, const Camera camera);
void ProjectonCamera(const float3 PointX, const Camera camera, float2 &point, float &depth);
float GetAngle(const cv::Vec3f &v1, const cv::Vec3f &v2);
void StoreColorPlyFileBinaryPointCloud (const std::string &plyFilePath, const std::vector<PointList> &pc);
float GetAngle(const torch::Tensor &v1, const torch::Tensor &v2);
void StoreColorPlyFileBinaryPointCloud(const std::string &plyFilePath, const std::vector<PointList> &pc);

#define CUDA_SAFE_CALL(error) CudaSafeCall(error, __FILE__, __LINE__)
#define CUDA_CHECK_ERROR() CudaCheckError(__FILE__, __LINE__)
Expand Down Expand Up @@ -57,15 +52,15 @@ class PatchMatch {
int GetPatchSize();
int GetReferenceImageWidth();
int GetReferenceImageHeight();
cv::Mat GetReferenceImage();
torch::Tensor GetReferenceImage();
float4 GetPlaneHypothesis(const int index);
float GetCost(const int index);
float4* GetPlaneHypotheses();

private:
int num_images;
std::vector<cv::Mat> images;
std::vector<cv::Mat> depths;
std::vector<torch::Tensor> images;
std::vector<torch::Tensor> depths;
std::vector<Camera> cameras;
cudaTextureObjects texture_objects_host;
cudaTextureObjects texture_depths_host;
Expand Down
Loading

0 comments on commit 9af46a7

Please sign in to comment.