Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
348 changes: 219 additions & 129 deletions FastGeodis/__init__.py

Large diffs are not rendered by default.

155 changes: 155 additions & 0 deletions FastGeodis/fastgeodis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <vector>
#include "fastgeodis.h"
#include "common.h"
#include "geodis_pba.h"

#ifdef _OPENMP
#include <omp.h>
Expand Down Expand Up @@ -420,6 +421,154 @@ torch::Tensor GSF3d_fastmarch(const torch::Tensor &image, const torch::Tensor &m
return Dd_Md + De_Me;
}

torch::Tensor exact_euclidean2d(const torch::Tensor &mask, const std::vector<float> &spacing)
{
// Check input dimensions - expect BCHW format
const int num_dims = mask.dim();
if (num_dims != 4)
{
throw std::invalid_argument(
"exact_euclidean2d only supports 4D inputs (BCHW), received " + std::to_string(num_dims) + "D");
}

// Note: batch and channel dimensions are now supported

if (spacing.size() != 2)
{
throw std::invalid_argument(
"exact_euclidean2d requires 2D spacing, received " + std::to_string(spacing.size()));
}

if (mask.is_cuda())
{
#ifdef WITH_CUDA
if (!torch::cuda::is_available())
{
throw std::runtime_error(
"cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support");
}
return exact_euclidean2d_cuda(mask, spacing);
#else
AT_ERROR("exact_euclidean2d is only available with CUDA support. Not compiled with CUDA.");
#endif
}
else
{
AT_ERROR("exact_euclidean2d is only available on CUDA devices. Please move tensor to GPU.");
}
}

torch::Tensor exact_euclidean3d(const torch::Tensor &mask, const std::vector<float> &spacing)
{
// Check input dimensions - expect BCDHW format
const int num_dims = mask.dim();
if (num_dims != 5)
{
throw std::invalid_argument(
"exact_euclidean3d only supports 5D inputs (BCDHW), received " + std::to_string(num_dims) + "D");
}

// Note: batch and channel dimensions are now supported

if (spacing.size() != 3)
{
throw std::invalid_argument(
"exact_euclidean3d requires 3D spacing, received " + std::to_string(spacing.size()));
}

if (mask.is_cuda())
{
#ifdef WITH_CUDA
if (!torch::cuda::is_available())
{
throw std::runtime_error(
"cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support");
}
return exact_euclidean3d_cuda(mask, spacing);
#else
AT_ERROR("exact_euclidean3d is only available with CUDA support. Not compiled with CUDA.");
#endif
}
else
{
AT_ERROR("exact_euclidean3d is only available on CUDA devices. Please move tensor to GPU.");
}
}

torch::Tensor signed_exact_euclidean2d(const torch::Tensor &mask, const std::vector<float> &spacing)
{
// Check input dimensions - expect BCHW format
const int num_dims = mask.dim();
if (num_dims != 4)
{
throw std::invalid_argument(
"signed_exact_euclidean2d only supports 4D inputs (BCHW), received " + std::to_string(num_dims) + "D");
}

// Note: batch and channel dimensions are now supported

if (spacing.size() != 2)
{
throw std::invalid_argument(
"signed_exact_euclidean2d requires 2D spacing, received " + std::to_string(spacing.size()));
}

if (mask.is_cuda())
{
#ifdef WITH_CUDA
if (!torch::cuda::is_available())
{
throw std::runtime_error(
"cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support");
}
return signed_exact_euclidean2d_cuda(mask, spacing);
#else
AT_ERROR("signed_exact_euclidean2d is only available with CUDA support. Not compiled with CUDA.");
#endif
}
else
{
AT_ERROR("signed_exact_euclidean2d is only available on CUDA devices. Please move tensor to GPU.");
}
}

torch::Tensor signed_exact_euclidean3d(const torch::Tensor &mask, const std::vector<float> &spacing)
{
// Check input dimensions - expect BCDHW format
const int num_dims = mask.dim();
if (num_dims != 5)
{
throw std::invalid_argument(
"signed_exact_euclidean3d only supports 5D inputs (BCDHW), received " + std::to_string(num_dims) + "D");
}

// Note: batch and channel dimensions are now supported

if (spacing.size() != 3)
{
throw std::invalid_argument(
"signed_exact_euclidean3d requires 3D spacing, received " + std::to_string(spacing.size()));
}

if (mask.is_cuda())
{
#ifdef WITH_CUDA
if (!torch::cuda::is_available())
{
throw std::runtime_error(
"cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support");
}
return signed_exact_euclidean3d_cuda(mask, spacing);
#else
AT_ERROR("signed_exact_euclidean3d is only available with CUDA support. Not compiled with CUDA.");
#endif
}
else
{
AT_ERROR("signed_exact_euclidean3d is only available on CUDA devices. Please move tensor to GPU.");
}
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("generalised_geodesic2d", &generalised_geodesic2d, "Generalised Geodesic distance 2d");
Expand Down Expand Up @@ -449,4 +598,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("GSF2d_fastmarch", &GSF2d_fastmarch, "Geodesic Symmetric Filtering 2d using Fast Marching method");
m.def("GSF3d_fastmarch", &GSF3d_fastmarch, "Geodesic Symmetric Filtering 3d using Fast Marching method");

// Exact Euclidean Distance Transform using PBA+ algorithm
m.def("exact_euclidean2d", &exact_euclidean2d, "Exact Euclidean Distance Transform 2D using PBA+ algorithm");
m.def("exact_euclidean3d", &exact_euclidean3d, "Exact Euclidean Distance Transform 3D using PBA+ algorithm");
m.def("signed_exact_euclidean2d", &signed_exact_euclidean2d, "Signed Exact Euclidean Distance Transform 2D using PBA+ algorithm");
m.def("signed_exact_euclidean3d", &signed_exact_euclidean3d, "Signed Exact Euclidean Distance Transform 3D using PBA+ algorithm");

}
20 changes: 19 additions & 1 deletion FastGeodis/fastgeodis.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <torch/extension.h>
#include <vector>
#include "common.h"
#include "geodis_pba.h"

#ifdef WITH_CUDA
torch::Tensor generalised_geodesic2d_cuda(
Expand Down Expand Up @@ -293,4 +294,21 @@ torch::Tensor GSF3d_fastmarch(
const torch::Tensor &mask,
const float &theta,
const std::vector<float> &spacing,
const float &lambda);
const float &lambda);

// Exact Euclidean Distance Transform using PBA+ algorithm
torch::Tensor exact_euclidean2d(
const torch::Tensor &mask,
const std::vector<float> &spacing);

torch::Tensor exact_euclidean3d(
const torch::Tensor &mask,
const std::vector<float> &spacing);

torch::Tensor signed_exact_euclidean2d(
const torch::Tensor &mask,
const std::vector<float> &spacing);

torch::Tensor signed_exact_euclidean3d(
const torch::Tensor &mask,
const std::vector<float> &spacing);
Loading