From 01023200e511fb4cde35b02637ac24c77a14b35a Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Wed, 14 Jan 2026 05:24:47 +0000 Subject: [PATCH 01/12] Add magic_enum dependency --- CMakeLists.txt | 1 + cmake/Dependencies.cmake | 16 ++++++++++++++-- cpp/core/CMakeLists.txt | 1 + 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e4d378862..99ae458264 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,7 @@ option(MRTRIX_USE_SYSTEM_NIFTI "Use system-installed NIfTI C headers" OFF) option(MRTRIX_USE_SYSTEM_GTEST "Use system-installed Google Test library" OFF) option(MRTRIX_USE_SYSTEM_DAWN "Use system-installed Dawn library" OFF) option(MRTRIX_USE_SYSTEM_SLANG "Use system-installed Slang library" OFF) +option(MRTRIX_USE_SYSTEM_MAGIC_ENUM "Use system-installed Magic Enum library" OFF) option(MRTRIX_USE_SYSTEM_TCB_SPAN "Use system-installed TCB Span library" OFF) if(MRTRIX_BUILD_TESTS) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 710e8d4b7f..e13077200b 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -42,11 +42,11 @@ endif() # Nifti headers if(MRTRIX_USE_SYSTEM_NIFTI) - find_path(NIFTI1_INCLUDE_DIR + find_path(NIFTI1_INCLUDE_DIR NAMES nifti1.h PATHS ${NIFTI_DIR} /usr/include /usr/local/include ) - find_path(NIFTI2_INCLUDE_DIR + find_path(NIFTI2_INCLUDE_DIR NAMES nifti2.h PATHS ${NIFTI_DIR} /usr/include /usr/local/include ) @@ -236,3 +236,15 @@ else() add_library(tcb::span ALIAS tcb_span) endif() +# magic_enum +if(MRTRIX_USE_SYSTEM_MAGIC_ENUM) + find_package(magic_enum 0.9.7 CONFIG REQUIRED) +else() + set(magic_enum_url "https://github.com/Neargye/magic_enum/archive/refs/tags/v0.9.7.tar.gz") + FetchContent_Declare( + magic_enum + DOWNLOAD_EXTRACT_TIMESTAMP ON + URL ${magic_enum_url} + ) + FetchContent_MakeAvailable(magic_enum) +endif() diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt index 28a6951372..1c776027fa 100644 --- a/cpp/core/CMakeLists.txt 
+++ b/cpp/core/CMakeLists.txt @@ -119,6 +119,7 @@ target_link_libraries(mrtrix-core PUBLIC nlohmann_json::nlohmann_json nifti::nifti tcb::span + magic_enum::magic_enum ) if(MRTRIX_ENABLE_GPU) From 73e11892637fe5ab00175d2d0290a3e564de9856 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Wed, 14 Jan 2026 08:48:00 +0000 Subject: [PATCH 02/12] First draft of affine registration on GPU --- cpp/cmd/mrreggpu.cpp | 486 ++++++++++++++++++ cpp/core/gpu/registration/adabelief.cpp | 83 +++ cpp/core/gpu/registration/adabelief.h | 56 ++ .../gpu/registration/calculatorinterface.h | 53 ++ cpp/core/gpu/registration/calculatoroutput.h | 25 + .../gpu/registration/convergencechecker.cpp | 76 +++ .../gpu/registration/convergencechecker.h | 50 ++ cpp/core/gpu/registration/eigenhelpers.cpp | 88 ++++ cpp/core/gpu/registration/eigenhelpers.h | 45 ++ .../gpu/registration/globalregistration.cpp | 478 +++++++++++++++++ .../gpu/registration/globalregistration.h | 24 + cpp/core/gpu/registration/imageoperations.cpp | 215 ++++++++ cpp/core/gpu/registration/imageoperations.h | 68 +++ cpp/core/gpu/registration/initialisation.cpp | 292 +++++++++++ cpp/core/gpu/registration/initialisation.h | 37 ++ .../initialisation_rotation_search.cpp | 105 ++++ .../initialisation_rotation_search.h | 51 ++ cpp/core/gpu/registration/ncccalculator.cpp | 275 ++++++++++ cpp/core/gpu/registration/ncccalculator.h | 78 +++ cpp/core/gpu/registration/nmicalculator.cpp | 366 +++++++++++++ cpp/core/gpu/registration/nmicalculator.h | 97 ++++ .../gpu/registration/registrationtypes.cpp | 293 +++++++++++ cpp/core/gpu/registration/registrationtypes.h | 154 ++++++ cpp/core/gpu/registration/ssdcalculator.cpp | 163 ++++++ cpp/core/gpu/registration/ssdcalculator.h | 65 +++ cpp/core/gpu/registration/utils.cpp | 144 ++++++ cpp/core/gpu/registration/utils.h | 63 +++ .../gpu/registration/voxelscannermatrices.h | 57 ++ cpp/core/gpu/shaders/atomic_utils.slang | 17 + cpp/core/gpu/shaders/center_of_mass.slang | 63 +++ 
cpp/core/gpu/shaders/downsample_image.slang | 54 ++ cpp/core/gpu/shaders/parzen_binner.slang | 88 ++++ cpp/core/gpu/shaders/reduction_image.slang | 214 ++++++++ cpp/core/gpu/shaders/reduction_utils.slang | 243 +++++++++ .../registration/coordinate_mapper.slang | 87 ++++ .../shaders/registration/cubic_bspline.slang | 22 + .../registration/global_transformation.slang | 359 +++++++++++++ .../registration/joint_histogram.slang | 206 ++++++++ .../gpu/shaders/registration/moments.slang | 59 +++ cpp/core/gpu/shaders/registration/ncc.slang | 336 ++++++++++++ cpp/core/gpu/shaders/registration/nmi.slang | 375 ++++++++++++++ cpp/core/gpu/shaders/registration/ssd.slang | 137 +++++ .../registration/voxelscannermatrices.slang | 7 + cpp/core/gpu/shaders/texture_utils.slang | 135 +++++ 44 files changed, 6389 insertions(+) create mode 100644 cpp/cmd/mrreggpu.cpp create mode 100644 cpp/core/gpu/registration/adabelief.cpp create mode 100644 cpp/core/gpu/registration/adabelief.h create mode 100644 cpp/core/gpu/registration/calculatorinterface.h create mode 100644 cpp/core/gpu/registration/calculatoroutput.h create mode 100644 cpp/core/gpu/registration/convergencechecker.cpp create mode 100644 cpp/core/gpu/registration/convergencechecker.h create mode 100644 cpp/core/gpu/registration/eigenhelpers.cpp create mode 100644 cpp/core/gpu/registration/eigenhelpers.h create mode 100644 cpp/core/gpu/registration/globalregistration.cpp create mode 100644 cpp/core/gpu/registration/globalregistration.h create mode 100644 cpp/core/gpu/registration/imageoperations.cpp create mode 100644 cpp/core/gpu/registration/imageoperations.h create mode 100644 cpp/core/gpu/registration/initialisation.cpp create mode 100644 cpp/core/gpu/registration/initialisation.h create mode 100644 cpp/core/gpu/registration/initialisation_rotation_search.cpp create mode 100644 cpp/core/gpu/registration/initialisation_rotation_search.h create mode 100644 cpp/core/gpu/registration/ncccalculator.cpp create mode 100644 
cpp/core/gpu/registration/ncccalculator.h create mode 100644 cpp/core/gpu/registration/nmicalculator.cpp create mode 100644 cpp/core/gpu/registration/nmicalculator.h create mode 100644 cpp/core/gpu/registration/registrationtypes.cpp create mode 100644 cpp/core/gpu/registration/registrationtypes.h create mode 100644 cpp/core/gpu/registration/ssdcalculator.cpp create mode 100644 cpp/core/gpu/registration/ssdcalculator.h create mode 100644 cpp/core/gpu/registration/utils.cpp create mode 100644 cpp/core/gpu/registration/utils.h create mode 100644 cpp/core/gpu/registration/voxelscannermatrices.h create mode 100644 cpp/core/gpu/shaders/atomic_utils.slang create mode 100644 cpp/core/gpu/shaders/center_of_mass.slang create mode 100644 cpp/core/gpu/shaders/downsample_image.slang create mode 100644 cpp/core/gpu/shaders/parzen_binner.slang create mode 100644 cpp/core/gpu/shaders/reduction_image.slang create mode 100644 cpp/core/gpu/shaders/reduction_utils.slang create mode 100644 cpp/core/gpu/shaders/registration/coordinate_mapper.slang create mode 100644 cpp/core/gpu/shaders/registration/cubic_bspline.slang create mode 100644 cpp/core/gpu/shaders/registration/global_transformation.slang create mode 100644 cpp/core/gpu/shaders/registration/joint_histogram.slang create mode 100644 cpp/core/gpu/shaders/registration/moments.slang create mode 100644 cpp/core/gpu/shaders/registration/ncc.slang create mode 100644 cpp/core/gpu/shaders/registration/nmi.slang create mode 100644 cpp/core/gpu/shaders/registration/ssd.slang create mode 100644 cpp/core/gpu/shaders/registration/voxelscannermatrices.slang create mode 100644 cpp/core/gpu/shaders/texture_utils.slang diff --git a/cpp/cmd/mrreggpu.cpp b/cpp/cmd/mrreggpu.cpp new file mode 100644 index 0000000000..3b2fbbdcaa --- /dev/null +++ b/cpp/cmd/mrreggpu.cpp @@ -0,0 +1,486 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. 
If a copy of the MPL was not distributed with this
 + * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 + *
 + * Covered Software is provided under this License on an "as is"
 + * basis, without warranty of any kind, either expressed, implied, or
 + * statutory, including, without limitation, warranties that the
 + * Covered Software is free of defects, merchantable, fit for a
 + * particular purpose or non-infringing.
 + * See the Mozilla Public License v. 2.0 for more details.
 + *
 + * For more details, see http://www.mrtrix.org/.
 + */
 +
 + #include "adapter/reslice.h"
 + #include "app.h"
 + #include "cmdline_option.h"
 + #include "command.h" // IWYU pragma: keep
 + #include "datatype.h"
 + #include "exception.h"
 + #include "file/matrix.h"
 + #include "filter/reslice.h"
 + #include "gpu/gpu.h"
 + #include "gpu/registration/eigenhelpers.h"
 + #include "gpu/registration/globalregistration.h"
 + #include "gpu/registration/registrationtypes.h"
 + #include "gpu/registration/imageoperations.h"
 + #include "gpu/registration/imageoperations.h"
 + #include "header.h"
 + #include "image.h"
 + #include "image_helpers.h"
 + #include "interp/cubic.h"
 + #include "magic_enum/magic_enum.hpp"
 + #include "math/average_space.h"
 + #include "mrtrix.h"
 + #include "types.h"
 +
 + #include <Eigen/Dense>
 + #include <unsupported/Eigen/MatrixFunctions>
 +
 + #include <algorithm>
 + #include <cmath>
 + #include <cstddef>
 + #include <cstdint>
 + #include <filesystem>
 + #include <optional>
 + #include <string>
 + #include <string_view>
 + #include <vector>
 +
 + using namespace MR;
 + using namespace App;
 +
 + namespace {
 +
 + template <typename Enum> std::vector<std::string> lowercase_enum_names() {
 +   static constexpr auto names_view = magic_enum::enum_names<Enum>();
 +   std::vector<std::string> result;
 +   result.reserve(names_view.size());
 +   for (const auto &s : names_view) {
 +     result.push_back(MR::lowercase(std::string(s)));
 +   }
 +   return result;
 + }
 + template <typename Enum> std::string enum_name_lowercase(Enum e) {
 +   auto name = magic_enum::enum_name(e);
 +   return MR::lowercase(std::string(name));
 + }
 +
 + template <typename Enum> Enum from_name(std::string_view name) {
 +   auto e = magic_enum::enum_cast<Enum>(name, magic_enum::case_insensitive);
 +   if
(!e.has_value()) { + std::string error = "Unsupported value '" + std::string(name) + "'. Supported values are: "; + const auto names = lowercase_enum_names(); + for (const auto &n : names) { + error += n + ", "; + } + throw Exception(error); + } + return e.value(); + } + + constexpr float default_max_search_angle = 45.0F; + constexpr TransformationType default_transformation_type = TransformationType::Affine; + constexpr MetricType default_metric_type = MetricType::NMI; + constexpr uint32_t default_ncc_window_radius = 0U; + constexpr uint32_t default_max_iterations = 500; + const std::vector supported_metric_types = lowercase_enum_names(); + const std::vector supported_transform_types = lowercase_enum_names(); + const std::vector supported_init_translations = lowercase_enum_names(); + const std::vector supported_init_rotations = lowercase_enum_names(); + + struct HalfwayTransforms { + transform_type half; + transform_type half_inverse; + Eigen::Matrix4d half_matrix; + Eigen::Matrix4d half_inverse_matrix; + }; + + HalfwayTransforms compute_halfway_transforms(const transform_type &scanner_transform) { + const Eigen::Matrix4d matrix = EigenHelpers::to_homogeneous_mat4d(scanner_transform); + const double det = matrix.block<3, 3>(0, 0).determinant(); + if (!std::isfinite(det) || det <= 0.0) { + throw Exception("Cannot compute halfway transform: non-invertible or reflected transform."); + } + const Eigen::Matrix4d half_matrix = matrix.sqrt(); + const Eigen::Matrix4d half_inverse_matrix = half_matrix.inverse(); + return HalfwayTransforms{ + .half = EigenHelpers::from_homogeneous_mat4d(half_matrix), + .half_inverse = EigenHelpers::from_homogeneous_mat4d(half_inverse_matrix), + .half_matrix = half_matrix, + .half_inverse_matrix = half_inverse_matrix, + }; + } + + } // namespace + + // clang-format off + // NOLINTBEGIN(readability-implicit-bool-conversion) + void usage() { + AUTHOR = "Daljit Singh", + SYNOPSIS = "Affine image registration on the GPU."; + + ARGUMENTS + + 
Argument ("image1 image2", "input image 1 ('moving') and input image 2 ('template')").type_image_in() + + Argument ("contrast1 contrast2", + "optional list of additional input images used as additional contrasts." + " Can be used multiple times." + " contrastX and imageX must share the same coordinate system.").type_image_in().optional().allow_multiple(); + + + OPTIONS + + Option ("transformed", "image1 transformed to image2 space after registration." + " Note that -transformed needs to be repeated for each contrast.") + .allow_multiple() + + Argument("image").type_image_out().optional() + + + Option ("transformed_midway", "image1 and image2 after registration transformed and regridded to the midway space." + " Note that -transformed_midway needs to be repeated for each contrast.") + .allow_multiple() + + Argument("image1_transformed").type_image_out() + + Argument("image2_transformed").type_image_out() + + + Option ("matrix", "write the transformation matrix used for reslicing image1 into image2 space.") + + Argument("filename").type_file_out() + + + Option ("matrix_1tomidway", "write the transformation matrix used for reslicing image1 into midway space.") + + Argument("filename").type_file_out() + + + Option ("matrix_2tomidway", "write the transformation matrix used for reslicing image2 into midway space.") + + Argument("filename").type_file_out() + + + Option ("type", "type of transform (rigid, affine)") + + Argument("name").type_choice(supported_transform_types) + + + Option ("metric", "similarity metric to use (nmi, ssd, ncc)") + + Argument("name").type_choice(supported_metric_types) + + // TODO: Should we mention that using a large window radius (> 3) is not recommended + // as it's computationally expensive and usually does not improve results? 
+ + Option("ncc_radius",
+          "window radius (in voxels) for the NCC metric; set to 0 for global NCC (default: " +
+          std::to_string(default_ncc_window_radius) + ").")
+ + Argument("radius").type_integer(0, 15)
+
+ + Option("mask1", "a mask to define the region of image1 to use for optimisation.")
+ + Argument("filename").type_image_in()
+
+ + Option("mask2", "a mask to define the region of image2 to use for optimisation.")
+ + Argument("filename").type_image_in()
+
+ + Option("max_iter", "maximum number of iterations (default: " + std::to_string(default_max_iterations) + ")")
+ + Argument("number").type_integer(10, 1000)
+
+ + Option("init_translation",
+          "initialise the translation and centre of rotation;"
+          " Valid choices are:"
+          " mass (aligns the centres of mass of both images, default);"
+          " geometric (aligns geometric image centres);"
+          " none.")
+ + Argument("type").type_choice(supported_init_translations)
+
+ + Option("init_rotation",
+          "Method to use to initialise the rotation."
+          " Valid choices are:"
+          " search (search for the best rotation using the selected metric);"
+          " moments (rotation based on directions of intensity variance with respect to centre of mass);"
+          " none (default).")
+ + Argument("type").type_choice(supported_init_rotations)
+
+ + Option("init_rotation_max_angle",
+          "Maximum rotation angle (degrees) to sample when init_rotation=search (default: " + std::to_string(default_max_search_angle) +
+          "). Use a larger value only when images may be grossly misaligned.")
+ + Argument("degrees").type_float(0.0, 180.0)
+
+ + Option("init_matrix",
+          "initialise the registration with the supplied transformation matrix "
+          "(as a 4x4 matrix in scanner coordinates). "
+          "Note that this overrides init_translation and init_rotation initialisation")
+ + Argument("filename").type_file_in()
+ + Option ("mc_weights", "relative weight of images used for multi-contrast registration.
Default: 1.0 (equal weighting)") + + Argument ("weights").type_sequence_float (); + } + // NOLINTEND(readability-implicit-bool-conversion) + + // clang-format on + + struct HeaderPair { + Header header1; + Header header2; + }; + + struct ImagePair { + Image image1; + Image image2; + }; + + void run() { + auto gpu_context_request = GPU::ComputeContext::request_async(); + std::vector header_pairs; + const size_t arg_size = argument.size(); + if (arg_size % 2 != 0 || arg_size < 2) { + const auto error = MR::join(argument, " "); + throw Exception("Unexpected number of input images, arguments: " + error); + } + + for (size_t i = 0; i < arg_size; i += 2) { + header_pairs.push_back({Header::open(argument[i]), Header::open(argument[i + 1])}); + } + + for (const auto &[header1, header2] : header_pairs) { + if (header1.ndim() != header2.ndim()) { + throw Exception("Input images " + header1.name() + " and " + header2.name() + + " have different number of dimensions: " + std::to_string(header1.ndim()) + " and " + + std::to_string(header2.ndim())); + } + check_3D_nonunity(header1); + check_3D_nonunity(header2); + } + + const TransformationType transform_type = from_name( + get_option_value("type", enum_name_lowercase(default_transformation_type))); + + const MetricType metric_type = + from_name(get_option_value("metric", enum_name_lowercase(default_metric_type))); + + const uint32_t ncc_window_radius = get_option_value("ncc_radius", default_ncc_window_radius); + + std::optional> mask1; + std::optional> mask2; + const auto mask1_option = get_options("mask1"); + const auto mask2_option = get_options("mask2"); + if (!mask1_option.empty()) { + mask1 = Image::open(mask1_option[0][0]); + } + if (!mask2_option.empty()) { + mask2 = Image::open(mask2_option[0][0]); + } + if (mask1) { + if (mask1->ndim() != 3) { + throw Exception("mask1 must be a 3D image."); + } + check_dimensions(*mask1, header_pairs.front().header1, 0, 3); + } + if (mask2) { + if (mask2->ndim() != 3) { + throw 
Exception("mask2 must be a 3D image."); + } + check_dimensions(*mask2, header_pairs.front().header2, 0, 3); + } + + const uint32_t max_iterations = get_option_value("max_iter", default_max_iterations); + + const auto init_matrix_option = get_options("init_matrix"); + + Eigen::Vector3d centre; + InitialGuess initial_guess; + if (!init_matrix_option.empty()) { + // TODO: compute centre from images. Also check what's the correct thing to do in this case. + initial_guess = File::Matrix::load_transform(init_matrix_option[0][0], centre); + } else { + const InitTranslationChoice init_translation = + from_name(get_option_value("init_translation", "mass")); + const InitRotationChoice init_rotation = + from_name(get_option_value("init_rotation", "none")); + const float init_rotation_max_angle = get_option_value("init_rotation_max_angle", default_max_search_angle); + Metric init_metric; + switch (metric_type) { + case MetricType::NMI: + init_metric = NMIMetric{}; + break; + case MetricType::SSD: + init_metric = SSDMetric{}; + break; + case MetricType::NCC: + init_metric = NCCMetric{.window_radius = ncc_window_radius}; + break; + default: + throw Exception("Unsupported metric type"); + } + + initial_guess = InitialisationOptions{ + .translation_choice = init_translation, + .rotation_choice = init_rotation, + .cost_metric = init_metric, + .max_search_angle_degrees = init_rotation_max_angle, + }; + } + + // TODO: we only support 3D images for now. We'll need to extend this to + // support 4D images later. 
+ for (const auto &[header1, header2] : header_pairs) { + check_dimensions(header1, header_pairs.front().header1, 0, 3); + check_dimensions(header2, header_pairs.front().header2, 0, 3); + + if (header1.ndim() != 3 || header2.ndim() != 3) { + throw Exception("Input images with dimensionality other than 3 are not supported."); + } + } + + const auto weight_options = get_options("mc_weights"); + std::vector mc_weights; + if (!weight_options.empty()) { + mc_weights = parse_floats(weight_options[0][0]); + if (mc_weights.size() == 1) { + mc_weights.resize(header_pairs.size(), mc_weights[0]); + } else if (mc_weights.size() != header_pairs.size()) { + throw Exception("number of mc_weights does not match number of contrasts"); + } + const bool weights_positive = std::all_of(mc_weights.begin(), mc_weights.end(), [](auto w) { return w >= 0.0; }); + if (!weights_positive) { + throw Exception("mc_weights must be non-negative"); + } + } + + std::vector channels; + size_t index = 0; + for (const auto &[header1, header2] : header_pairs) { + const ChannelConfig channel{ + .image1 = Image::open(header1.name()).with_direct_io(), + .image2 = Image::open(header2.name()).with_direct_io(), + .image1Mask = mask1, + .image2Mask = mask2, + .weight = static_cast(mc_weights.empty() ? 
1.0F : mc_weights[index]), + }; + + channels.push_back(channel); + ++index; + } + + Metric metric; + switch (metric_type) { + case MetricType::NMI: + metric = NMIMetric{}; + break; + case MetricType::SSD: + metric = SSDMetric{}; + break; + case MetricType::NCC: + metric = NCCMetric{.window_radius = ncc_window_radius}; + break; + default: + throw Exception("Unsupported metric type"); + } + + const RegistrationConfig registration_config{ + .channels = channels, + .transformation_type = transform_type, + .initial_guess = initial_guess, + .metric = metric, + .max_iterations = max_iterations, + }; + + auto gpu_compute_context = gpu_context_request.get(); + const RegistrationResult registration_result = GPU::run_registration(registration_config, gpu_compute_context); + + const std::string matrix_filename = get_option_value("matrix", ""); + const std::string matrix_1tomid_filename = get_option_value("matrix_1tomidway", ""); + const std::string matrix_2tomid_filename = get_option_value("matrix_2tomidway", ""); + + const auto transformed_option = get_options("transformed"); + std::vector transformed_filenames; + if (!transformed_option.empty()) { + if (transformed_option.size() > header_pairs.size()) { + throw Exception("Number of -transformed images exceeds number of contrasts"); + } + if (transformed_option.size() < header_pairs.size()) { + WARN("Number of -transformed images is less than number of contrasts."); + } + for (size_t i = 0; i < transformed_option.size(); ++i) { + const std::filesystem::path output_path(transformed_option[i][0]); + transformed_filenames.push_back(output_path); + const auto input1_path = std::filesystem::path(header_pairs[i].header1.name()); + INFO(input1_path.filename().string() + ", transformed to space of image2, will be saved to " + + output_path.string()); + } + } + + const auto transformed_midway_option = get_options("transformed_midway"); + std::vector transformed_midway1_filenames; + std::vector transformed_midway2_filenames; + if 
(!transformed_midway_option.empty()) { + if (transformed_midway_option.size() > header_pairs.size()) { + throw Exception("Number of -transformed_midway images exceeds number of contrasts"); + } + if (transformed_midway_option.size() < header_pairs.size()) { + WARN("Number of -transformed_midway images is less than number of contrasts."); + } + for (size_t i = 0; i < transformed_midway_option.size(); ++i) { + if (transformed_midway_option[i].args.size() != 2U) { + throw Exception("Each -transformed_midway option requires two output images."); + } + const std::filesystem::path output1_path(transformed_midway_option[i][0]); + const std::filesystem::path output2_path(transformed_midway_option[i][1]); + transformed_midway1_filenames.push_back(output1_path); + transformed_midway2_filenames.push_back(output2_path); + const auto input1_path = std::filesystem::path(header_pairs[i].header1.name()); + const auto input2_path = std::filesystem::path(header_pairs[i].header2.name()); + INFO(input1_path.filename().string() + ", transformed to midway space, will be saved to " + + output1_path.string()); + INFO(input2_path.filename().string() + ", transformed to midway space, will be saved to " + + output2_path.string()); + } + } + + const bool needs_halfway_transforms = + !transformed_midway1_filenames.empty() || !matrix_1tomid_filename.empty() || !matrix_2tomid_filename.empty(); + std::optional halfway_transforms; + if (needs_halfway_transforms) { + halfway_transforms = compute_halfway_transforms(registration_result.transformation); + } + + if (!matrix_filename.empty() || !matrix_1tomid_filename.empty() || !matrix_2tomid_filename.empty()) { + const Eigen::Vector3d centre = image_centre_scanner_space(header_pairs.front().header1); + if (!matrix_filename.empty()) { + File::Matrix::save_transform(registration_result.transformation, centre, matrix_filename); + } + if (!matrix_1tomid_filename.empty()) { + File::Matrix::save_transform(halfway_transforms->half, centre, 
matrix_1tomid_filename); + } + if (!matrix_2tomid_filename.empty()) { + File::Matrix::save_transform(halfway_transforms->half_inverse, centre, matrix_2tomid_filename); + } + } + + if (!transformed_filenames.empty()) { + const size_t transforms_to_write = std::min(transformed_filenames.size(), header_pairs.size()); + for (size_t idx = 0; idx < transforms_to_write; ++idx) { + const auto &[header1, header2] = header_pairs[idx]; + + Image input_image = Image::open(header1.name()); + Header output_header(header2); + output_header.datatype() = DataType::from(); + auto output_image = Image::create(transformed_filenames[idx].string(), output_header).with_direct_io(); + + Filter::reslice( + input_image, output_image, registration_result.transformation, Adapter::AutoOverSample, 0.0F); + } + } + + if (!transformed_midway1_filenames.empty()) { + // Compute midpioint transforms in scanner space and then build a midway output header that can hold both images + using ProjectiveTransform = Eigen::Transform; + const ProjectiveTransform half_projective(halfway_transforms->half_matrix); + const ProjectiveTransform half_inverse_projective(halfway_transforms->half_inverse_matrix); + + const size_t transforms_to_write = std::min(transformed_midway1_filenames.size(), header_pairs.size()); + for (size_t idx = 0; idx < transforms_to_write; ++idx) { + const auto &[header1, header2] = header_pairs[idx]; + Header output_header = compute_minimum_average_header(header1, header2, half_inverse_projective, half_projective); + output_header.datatype() = DataType::from(); + + Image input_image1 = Image::open(header1.name()); + auto output_image1 = Image::create(transformed_midway1_filenames[idx].string(), output_header).with_direct_io(); + Filter::reslice( + input_image1, output_image1, halfway_transforms->half, Adapter::AutoOverSample, 0.0F); + + Image input_image2 = Image::open(header2.name()); + auto output_image2 = Image::create(transformed_midway2_filenames[idx].string(), 
output_header).with_direct_io(); + Filter::reslice( + input_image2, output_image2, halfway_transforms->half_inverse, Adapter::AutoOverSample, 0.0F); + } + } + } diff --git a/cpp/core/gpu/registration/adabelief.cpp b/cpp/core/gpu/registration/adabelief.cpp new file mode 100644 index 0000000000..b3c2783adf --- /dev/null +++ b/cpp/core/gpu/registration/adabelief.cpp @@ -0,0 +1,83 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */
+
+#include "gpu/registration/adabelief.h"
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstddef>
+#include <numeric>
+#include <tcb/span.hpp>
+#include <vector>
+
+AdaBelief::AdaBelief(const std::vector<Parameter> &parameters, float beta1, float beta2, float epsilon)
+    : m_parameters(parameters),
+      m_beta1(beta1),
+      m_beta2(beta2),
+      m_epsilon(epsilon),
+      m_timeStep(0),
+      m_firstMoments(parameters.size(), 0.0F),
+      m_secondMoments(parameters.size(), 0.0F),
+      m_mask(parameters.size(), 0),
+      m_updates(parameters.size(), 0.0F) {}
+
+std::vector<float> AdaBelief::parameterValues() const {
+  std::vector<float> values(m_parameters.size());
+  std::transform(
+      m_parameters.begin(), m_parameters.end(), values.begin(), [](const Parameter &param) { return param.value; });
+  return values;
+}
+
+void AdaBelief::setParameterValues(tcb::span<const float> values) {
+  for (size_t i = 0; i < m_parameters.size() && i < values.size(); ++i) {
+    m_parameters[i].value = values[i];
+  }
+}
+
+void AdaBelief::reset() {
+  m_timeStep = 0;
+  m_firstMoments.assign(m_parameters.size(), 0.0F);
+  m_secondMoments.assign(m_parameters.size(), 0.0F);
+  m_mask.assign(m_parameters.size(), 0);
+  m_updates.assign(m_parameters.size(), 0.0F);
+}
+
+void AdaBelief::step(const std::vector<float> &gradients) {
+  assert(gradients.size() == m_parameters.size());
+  ++m_timeStep;
+
+  // First pass: compute moment estimates, the update direction, and binary masks
+  for (size_t i = 0; i < m_parameters.size(); ++i) {
+    m_firstMoments[i] = m_beta1 * m_firstMoments[i] + (1.0F - m_beta1) * gradients[i];
+    const float diff = gradients[i] - m_firstMoments[i];
+    m_secondMoments[i] = m_beta2 * m_secondMoments[i] + (1.0F - m_beta2) * diff * diff;
+    const float firstMomentCorrected = m_firstMoments[i] / (1.0F - std::pow(m_beta1, m_timeStep));
+    const float secondMomentCorrected = m_secondMoments[i] / (1.0F - std::pow(m_beta2, m_timeStep));
+    m_updates[i] = firstMomentCorrected / (std::sqrt(secondMomentCorrected) + m_epsilon);
+    // Create mask: update only if update and gradient are aligned (i.e. product > 0)
+    m_mask[i] = (m_updates[i] * gradients[i] > 0.0F) ? 1 : 0;
+  }
+
+  // Compute the average mask value across all parameters
+  const float maskSum = std::accumulate(m_mask.begin(), m_mask.end(), 0);
+  const float maskMean = maskSum / static_cast<float>(m_mask.size());
+
+  // Second pass: apply cautious update to each parameter
+  for (size_t i = 0; i < m_parameters.size(); ++i) {
+    // If the mask is 0, no update is applied; otherwise, scale the update to compensate for the overall sparsity.
+    m_parameters[i].value -= m_parameters[i].learning_rate * (m_updates[i] * m_mask[i] / (maskMean + m_epsilon));
+  }
+}
diff --git a/cpp/core/gpu/registration/adabelief.h b/cpp/core/gpu/registration/adabelief.h
new file mode 100644
index 0000000000..160b77bd05
--- /dev/null
+++ b/cpp/core/gpu/registration/adabelief.h
@@ -0,0 +1,56 @@
+/* Copyright (c) 2008-2025 the MRtrix3 contributors.
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/.
+ *
+ * Covered Software is provided under this License on an "as is"
+ * basis, without warranty of any kind, either expressed, implied, or
+ * statutory, including, without limitation, warranties that the
+ * Covered Software is free of defects, merchantable, fit for a
+ * particular purpose or non-infringing.
+ * See the Mozilla Public License v. 2.0 for more details.
+ *
+ * For more details, see http://www.mrtrix.org/.
+ */
+
+#pragma once
+
+#include <cstddef>
+#include <tcb/span.hpp>
+#include <vector>
+
+// AdaBelief is an improved version of Adam that takes into account the curvature of the loss function.
+// See here https://arxiv.org/abs/2010.07468
+// This version is further enhanced by the idea in this paper https://arxiv.org/abs/2411.16085
+// which consists of applying an element-wise mask to the update such that only the components
+// where the proposed update direction and the current gradient are aligned
+// (i.e., have the same sign) are applied. This ensures that every step reliably reduces the loss
+// and avoids potential overshooting or oscillations.
+class AdaBelief {
+public:
+  struct Parameter {
+    float value;
+    float learning_rate;
+  };
+
+  AdaBelief(const std::vector<Parameter> &parameters, float beta1 = 0.7F, float beta2 = 0.9999F, float epsilon = 1e-6F);
+
+  void step(const std::vector<float> &gradients);
+  std::vector<float> parameterValues() const;
+  void setParameterValues(tcb::span<const float> values);
+
+  // Resets the optimizer internal state (moments and timestep) but keeps the parameters unchanged
+  void reset();
+
+private:
+  std::vector<Parameter> m_parameters;
+  float m_beta1;
+  float m_beta2;
+  float m_epsilon;
+  int m_timeStep;
+  std::vector<float> m_firstMoments;  // Exponential moving average of gradients (m_t)
+  std::vector<float> m_secondMoments; // Exponential moving average of squared deviations ((g_t - m_t)^2)
+  std::vector<int> m_mask;
+  std::vector<float> m_updates;
+};
diff --git a/cpp/core/gpu/registration/calculatorinterface.h b/cpp/core/gpu/registration/calculatorinterface.h
new file mode 100644
index 0000000000..028f336617
--- /dev/null
+++ b/cpp/core/gpu/registration/calculatorinterface.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2008-2025 the MRtrix3 contributors.
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/.
+ * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "match_variant.h" +#include "gpu/registration/ncccalculator.h" +#include "gpu/registration/nmicalculator.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/ssdcalculator.h" + +#include +#include + +namespace MR::GPU { + +class Calculator final : public std::variant { +public: + // Expose std::variant constructors publicly + using std::variant::variant; + + struct Config { + Texture fixed_texture; + Texture moving_texture; + MR::Transform fixed_transform; + MR::Transform moving_transform; + float downscale_factor; + Metric metric; + }; + + void update(const GlobalTransform &transformation) { + MR::match_v(*this, [&](auto &&arg) { arg.update(transformation); }); + } + + IterationResult get_result() const { + return MR::match_v(*this, [&](auto &&arg) { return arg.get_result(); }); + } +}; + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/calculatoroutput.h b/cpp/core/gpu/registration/calculatoroutput.h new file mode 100644 index 0000000000..f7b75a6586 --- /dev/null +++ b/cpp/core/gpu/registration/calculatoroutput.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include + +namespace MR::GPU { + +enum class CalculatorOutput : uint8_t { Cost, CostAndGradients }; + +} diff --git a/cpp/core/gpu/registration/convergencechecker.cpp b/cpp/core/gpu/registration/convergencechecker.cpp new file mode 100644 index 0000000000..0efed148b9 --- /dev/null +++ b/cpp/core/gpu/registration/convergencechecker.cpp @@ -0,0 +1,76 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/convergencechecker.h" +#include "exception.h" +#include +#include +#include +#include +#include +#include +#include + +namespace MR { + +ConvergenceChecker::ConvergenceChecker(const Config &checkerConfiguration) : m_configuration(checkerConfiguration) { + assert(!m_configuration.param_thresholds.empty()); + assert(m_configuration.patienceLimit > 0U); +} + +bool ConvergenceChecker::has_converged(tcb::span current_params, float current_cost) { + if (m_configuration.param_thresholds.size() != current_params.size()) { + throw std::invalid_argument("ConvergenceChecker::has_converged: parameter threshold configuration mismatch."); + } + + if (!m_initialized) { + DEBUG("ConvergenceChecker: Initializing with first parameters and cost."); + m_minimum_cost = current_cost; + m_best_params.assign(current_params.begin(), current_params.end()); + m_initialized = true; + return false; + } + + const bool has_better_cost = current_cost < m_minimum_cost; + + const auto significant_param_improvement = [&]() { + for (size_t idx = 0; idx < current_params.size(); ++idx) { + const float param_diff = std::fabs(m_best_params[idx] - current_params[idx]); + if (param_diff >= m_configuration.param_thresholds[idx]) { + return true; + } + } + return false; + }(); + + if (has_better_cost) { + m_minimum_cost = current_cost; + m_best_params.assign(current_params.begin(), current_params.end()); + // For the patience counter, only significant parameter improvements count + m_patience_counter = significant_param_improvement ? 0U : m_patience_counter + 1U; + DEBUG("ConvergenceChecker: Better cost found. Resetting patience counter to " + std::to_string(m_patience_counter) + + "."); + } else { + DEBUG("ConvergenceChecker: No better cost found. 
Incrementing patience counter."); + ++m_patience_counter; + } + + return m_patience_counter >= m_configuration.patienceLimit; +} + +void ConvergenceChecker::reset_patience() { m_patience_counter = 0; } + +} // namespace MR diff --git a/cpp/core/gpu/registration/convergencechecker.h b/cpp/core/gpu/registration/convergencechecker.h new file mode 100644 index 0000000000..8616c2e317 --- /dev/null +++ b/cpp/core/gpu/registration/convergencechecker.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace MR { + +struct ConvergenceChecker { + + struct Config { + // Minimum required cost improvement to reset patience counter + uint32_t patienceLimit = 10; + // Absolute thresholds for each parameter + // NOTE: the order must match the order of parameters in the optimization + std::vector param_thresholds; + }; + + explicit ConvergenceChecker(const Config &checkerConfiguration); + + bool has_converged(tcb::span current_transform_params, float current_cost); + + void reset_patience(); + +private: + float m_minimum_cost = std::numeric_limits::max(); + uint32_t m_patience_counter = 0; + bool m_initialized = false; + Config m_configuration; + std::vector m_best_params; +}; + +} // namespace MR diff --git a/cpp/core/gpu/registration/eigenhelpers.cpp b/cpp/core/gpu/registration/eigenhelpers.cpp new file mode 100644 index 0000000000..b00c028b98 --- /dev/null +++ b/cpp/core/gpu/registration/eigenhelpers.cpp @@ -0,0 +1,88 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/eigenhelpers.h" +#include "types.h" + +#include +#include +#include + +namespace MR::EigenHelpers { + +namespace { +template Eigen::Matrix to_homogeneous_mat4(const transform_type &source_transform) { + Eigen::Matrix matrix = Eigen::Matrix::Identity(); + matrix.template block<3, 4>(0, 0) = source_transform.matrix().template cast(); + return matrix; +} + +template transform_type from_homogeneous_mat4(const Eigen::Matrix &source_matrix) { + transform_type result; + result.linear() = source_matrix.template block<3, 3>(0, 0).template cast(); + result.translation() = source_matrix.template block<3, 1>(0, 3).template cast(); + return result; +} +} // namespace + +// Eigen::Transform +Eigen::Matrix4f to_homogeneous_mat4f(const transform_type &source_transform) { + return to_homogeneous_mat4(source_transform); +} + +Eigen::Matrix4d to_homogeneous_mat4d(const transform_type &source_transform) { + return to_homogeneous_mat4(source_transform); +} + +transform_type from_homogeneous_mat4f(const Eigen::Matrix4f &source_matrix) { + return from_homogeneous_mat4(source_matrix); +} + +transform_type from_homogeneous_mat4d(const Eigen::Matrix4d &source_matrix) { + return from_homogeneous_mat4(source_matrix); +} + +std::array to_array(const Eigen::Matrix4f &matrix) { + std::array array{}; + Eigen::Map(array.data()) = matrix; + return array; +} + +std::array to_array(const transform_type &transform) { + return to_array(to_homogeneous_mat4f(transform)); +} + +Eigen::Vector3f to_vector3f(const std::array &array) { + return Eigen::Vector3f(array[0], array[1], array[2]); +} + +std::array to_array(const Eigen::Vector3f &vector) { return {vector.x(), vector.y(), vector.z()}; } + +Eigen::Matrix4f make_scaling_mat4f(float scale_factor) { + Eigen::Matrix4f m = Eigen::Matrix4f::Identity(); + m(0, 0) = scale_factor; + m(1, 1) = scale_factor; + m(2, 2) = scale_factor; + return m; +} + +std::string to_string(const Eigen::Matrix4f &matrix) { + std::stringstream ss; 
+ ss << matrix; + return ss.str(); +} + +} // namespace MR::EigenHelpers diff --git a/cpp/core/gpu/registration/eigenhelpers.h b/cpp/core/gpu/registration/eigenhelpers.h new file mode 100644 index 0000000000..4758319e01 --- /dev/null +++ b/cpp/core/gpu/registration/eigenhelpers.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "types.h" + +#include + +#include +#include + +namespace MR::EigenHelpers { + +Eigen::Matrix4f to_homogeneous_mat4f(const transform_type &source_transform); +Eigen::Matrix4d to_homogeneous_mat4d(const transform_type &source_transform); + +transform_type from_homogeneous_mat4f(const Eigen::Matrix4f &source_matrix); +transform_type from_homogeneous_mat4d(const Eigen::Matrix4d &source_matrix); + +std::array to_array(const Eigen::Matrix4f &matrix); +std::array to_array(const transform_type &transform); +std::array to_array(const Eigen::Vector3f &vector); + +Eigen::Vector3f to_vector3f(const std::array &array); + +/// Returns a 4x4 homogeneous scaling matrix for the given scale factor. 
+Eigen::Matrix4f make_scaling_mat4f(float scale_factor); + +std::string to_string(const Eigen::Matrix4f &matrix); + +} // namespace MR::EigenHelpers diff --git a/cpp/core/gpu/registration/globalregistration.cpp b/cpp/core/gpu/registration/globalregistration.cpp new file mode 100644 index 0000000000..b7fce61c83 --- /dev/null +++ b/cpp/core/gpu/registration/globalregistration.cpp @@ -0,0 +1,478 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/globalregistration.h" +#include "gpu/registration/adabelief.h" +#include "gpu/registration/calculatorinterface.h" +#include "gpu/registration/convergencechecker.h" +#include "gpu/registration/eigenhelpers.h" +#include "exception.h" +#include "gpu/gpu.h" +#include "header.h" +#include "image.h" +#include "gpu/registration/imageoperations.h" +#include "gpu/registration/initialisation.h" +#include "match_variant.h" +#include "gpu/registration/ncccalculator.h" +#include "gpu/registration/nmicalculator.h" +#include "gpu/registration/registrationtypes.h" +#include +#include "gpu/registration/ssdcalculator.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace MR; +using namespace MR::GPU; + +constexpr float base_learning_rate = 0.1F; +// Threshold for considering translation parameters to have changed significantly (in mm) +constexpr float translation_significant_threshold = 1e-2F; +// Patience limits for convergence checking +// We want a higher patience on the coarsest level to get a good initial alignment +constexpr uint32_t coarsest_level_patience = 10U; +constexpr uint32_t finer_levels_patience = 5U; + +namespace { +AdaBelief create_optimiser(tcb::span initial_params, float translation_learning_rate) { + std::vector optimization_parameters(initial_params.size()); + for (size_t i = 0; i < optimization_parameters.size(); ++i) { + auto ¶m = optimization_parameters[i]; + param.value = initial_params[i]; + // Non-translation parameters have smaller learning rate to account for their larger impact on the transformation + param.learning_rate = i < 3 ? 
translation_learning_rate : translation_learning_rate * 1e-3F; + } + return AdaBelief(optimization_parameters); +} + +std::vector make_convergence_thresholds(size_t param_count) { + std::vector thresholds(param_count); + for (size_t i = 0; i < param_count; ++i) { + thresholds[i] = (i < 3) ? translation_significant_threshold : translation_significant_threshold * 1e-2F; + } + return thresholds; +} +} // namespace + +namespace MR::GPU { + +struct WeightedGradients { + explicit WeightedGradients(size_t N) : m_gradients(N, 0.0F) {} + void add(const std::vector &gradients, float weight) { + if (gradients.size() != m_gradients.size()) { + throw std::logic_error("WeightedGradients::add: gradient size mismatch"); + } + for (size_t i = 0; i < gradients.size(); ++i) { + m_gradients[i] += weight * gradients[i]; + } + } + + const std::vector &get() const { return m_gradients; } + +private: + std::vector m_gradients; +}; + +struct LevelData { + Texture texture1; + Texture texture2; + std::optional moving_mask; + std::optional fixed_mask; + Calculator calculator; + std::optional reverse_calculator; +}; + +struct ChannelData { + std::vector levels; + float weight = 1.0F; +}; + +RegistrationResult run_registration(const RegistrationConfig &config, const GPU::ComputeContext &context) { + constexpr uint32_t num_levels = 3U; + const bool is_affine = config.transformation_type == TransformationType::Affine; + const uint32_t degrees_of_freedom = is_affine ? 
12U : 6U; + + std::vector channels_data; + for (const auto &channel_config : config.channels) { + const auto &image1 = channel_config.image1; + const auto &image2 = channel_config.image2; + + const Texture texture1 = context.new_texture_from_host_image(image1); + const Texture texture2 = context.new_texture_from_host_image(image2); + std::optional texture1_mask; + if (channel_config.image1Mask) { + texture1_mask = context.new_texture_from_host_image(*channel_config.image1Mask); + } + std::optional texture2_mask; + if (channel_config.image2Mask) { + texture2_mask = context.new_texture_from_host_image(*channel_config.image2Mask); + } + + const std::vector pyramid1 = createDownsampledPyramid(texture1, num_levels, context); + const std::vector pyramid2 = createDownsampledPyramid(texture2, num_levels, context); + std::vector pyramid1_mask; + if (texture1_mask) { + pyramid1_mask = createDownsampledPyramid(*texture1_mask, num_levels, context); + } + std::vector pyramid2_mask; + if (texture2_mask) { + pyramid2_mask = createDownsampledPyramid(*texture2_mask, num_levels, context); + } + + std::vector levels; + for (size_t level = 0; level < num_levels; ++level) { + // The pyramid is arranged so index 0 is the lowest resolution and + // index (num_levels-1) is full resolution. The transform downscale is + // how much the texture is downsampled relative to the original image. 
+ const float level_downscale = std::exp2f(static_cast(num_levels - 1U - level)); + + std::optional level_moving_mask; + if (!pyramid1_mask.empty()) { + level_moving_mask = pyramid1_mask[level]; + } + std::optional level_fixed_mask; + if (!pyramid2_mask.empty()) { + level_fixed_mask = pyramid2_mask[level]; + } + + auto calculator = MR::match_v( + config.metric, + [&](const NMIMetric &nmi_metric) -> Calculator { + const NMICalculator::Config nmi_config{ + .transformation_type = config.transformation_type, + .fixed = pyramid2[level], + .moving = pyramid1[level], + .fixed_mask = level_fixed_mask, + .moving_mask = level_moving_mask, + .voxel_scanner_matrices = VoxelScannerMatrices::from_image_pair(image1, image2, level_downscale), + .num_bins = nmi_metric.num_bins, + .output = CalculatorOutput::CostAndGradients, + .context = &context, + }; + return NMICalculator(nmi_config); + }, + [&](const SSDMetric &) -> Calculator { + const SSDCalculator::Config ssd_config{ + .transformation_type = config.transformation_type, + .fixed = pyramid2[level], + .moving = pyramid1[level], + .fixed_mask = level_fixed_mask, + .moving_mask = level_moving_mask, + .voxel_scanner_matrices = VoxelScannerMatrices::from_image_pair(image1, image2, level_downscale), + .output = CalculatorOutput::CostAndGradients, + .context = &context, + }; + return SSDCalculator(ssd_config); + }, + [&](const NCCMetric &ncc_metric) -> Calculator { + const NCCCalculator::Config ncc_config{ + .transformation_type = config.transformation_type, + .fixed = pyramid2[level], + .moving = pyramid1[level], + .fixed_mask = level_fixed_mask, + .moving_mask = level_moving_mask, + .voxel_scanner_matrices = VoxelScannerMatrices::from_image_pair(image1, image2, level_downscale), + .window_radius = ncc_metric.window_radius, + .output = CalculatorOutput::CostAndGradients, + .context = &context, + }; + return NCCCalculator(ncc_config); + }); + + std::optional reverse_calculator; + if (level == (num_levels - 1U)) { + reverse_calculator 
= MR::match_v( + config.metric, + [&](const NMIMetric &nmi_metric) -> Calculator { + const NMICalculator::Config nmi_config{ + .transformation_type = config.transformation_type, + .fixed = pyramid1[level], + .moving = pyramid2[level], + .fixed_mask = level_moving_mask, + .moving_mask = level_fixed_mask, + .voxel_scanner_matrices = + VoxelScannerMatrices::from_image_pair(image2, image1, level_downscale), + .num_bins = nmi_metric.num_bins, + .output = CalculatorOutput::CostAndGradients, + .context = &context, + }; + return NMICalculator(nmi_config); + }, + [&](const SSDMetric &) -> Calculator { + const SSDCalculator::Config ssd_config{ + .transformation_type = config.transformation_type, + .fixed = pyramid1[level], + .moving = pyramid2[level], + .fixed_mask = level_moving_mask, + .moving_mask = level_fixed_mask, + .voxel_scanner_matrices = + VoxelScannerMatrices::from_image_pair(image2, image1, level_downscale), + .output = CalculatorOutput::CostAndGradients, + .context = &context, + }; + return SSDCalculator(ssd_config); + }, + [&](const NCCMetric &ncc_metric) -> Calculator { + const NCCCalculator::Config ncc_config{ + .transformation_type = config.transformation_type, + .fixed = pyramid1[level], + .moving = pyramid2[level], + .fixed_mask = level_moving_mask, + .moving_mask = level_fixed_mask, + .voxel_scanner_matrices = + VoxelScannerMatrices::from_image_pair(image2, image1, level_downscale), + .window_radius = ncc_metric.window_radius, + .output = CalculatorOutput::CostAndGradients, + .context = &context, + }; + return NCCCalculator(ncc_config); + }); + } + levels.emplace_back( + LevelData{pyramid1[level], + pyramid2[level], + level_moving_mask, + level_fixed_mask, + std::move(calculator), + std::move(reverse_calculator)}); + } + channels_data.emplace_back(ChannelData{.levels = levels, .weight = channel_config.weight}); + } + if (channels_data.empty()) { + throw MR::Exception("No channels provided for registration"); + } + + const GlobalTransform initial_transform 
= [&]() { + return MR::match_v( + config.initial_guess, + [&](const transform_type &t) { + return GlobalTransform::from_affine_compact( + t, + image_centre_scanner_space, float>(config.channels.front().image1), + config.transformation_type); + }, + [&](const InitialisationOptions &init_options) { + // Use the lowest resolution level for initialisation from the first channel only + const auto &first_level = channels_data.front().levels.front(); + const float init_transform_downscale = std::exp2f(static_cast(num_levels - 1U)); + const auto voxel_scanner = VoxelScannerMatrices::from_image_pair( + config.channels.front().image1, config.channels.front().image2, init_transform_downscale); + + const InitialisationConfig init_config{ + .moving_texture = first_level.texture1, + .target_texture = first_level.texture2, + .moving_mask = first_level.moving_mask, + .target_mask = first_level.fixed_mask, + .voxel_scanner_matrices = voxel_scanner, + .options = init_options, + }; + + const auto rigid = initialise_transformation(init_config, context); + + return config.transformation_type == TransformationType::Rigid ? rigid.as_rigid() : rigid.as_affine(); + }); + }(); + + const auto convergence_thresholds = make_convergence_thresholds(degrees_of_freedom); + + GlobalTransform current_transform = initial_transform; + GlobalTransform best_transform = initial_transform; + + // To make registration symmetric we run the registration in both directions + // and take the average of the resulting transforms at each iteration using Lie algebra averaging. + // This avoids the need of defining an average middle space which can introduce sampling bias. + // See https://doi.org/10.1117/1.jmi.1.2.024003 by Modat et al. + // This only needs to be be done for the final level since lower levels are just approximations. + // TODO: verify the symmetricity of the registration process. 
+ for (size_t level = 0; level < num_levels; ++level) { + GlobalTransform best_transform_level = current_transform; + float best_cost = std::numeric_limits::infinity(); + ConvergenceChecker convergence_checker(ConvergenceChecker::Config{ + .patienceLimit = (level == 0U) ? coarsest_level_patience : finer_levels_patience, + .param_thresholds = convergence_thresholds, + }); + const float learning_rate = base_learning_rate / std::pow(2.0F, static_cast(level)); + + const bool symmetric_level = (level == (num_levels - 1U)); + if (!symmetric_level) { + AdaBelief adabelief = create_optimiser(current_transform.parameters(), learning_rate); + + for (size_t iter = 0; iter < config.max_iterations; ++iter) { + // Dispatch gradient calculations for all channels + for (auto &channel_data : channels_data) { + auto &channel_calculator = channel_data.levels[level].calculator; + channel_calculator.update(current_transform); + } + + WeightedGradients channel_gradients(degrees_of_freedom); + float total_cost = 0.0F; + // Gather results for each channel accumulating gradients and cost + for (const auto &channel_data : channels_data) { + const auto &channel_calculator = channel_data.levels[level].calculator; + const auto channel_result = channel_calculator.get_result(); + channel_gradients.add(channel_result.gradients, channel_data.weight); + total_cost += channel_result.cost * channel_data.weight; + } + + if (total_cost < best_cost) { + best_cost = total_cost; + best_transform_level = current_transform; + } + + INFO("Current transformation matrix at level " + std::to_string(level) + " iteration " + + std::to_string(iter) + ":\n" + EigenHelpers::to_string(current_transform.to_matrix4f())); + + INFO("Level " + std::to_string(level) + ", Iteration " + std::to_string(iter) + + ", Cost: " + std::to_string(total_cost)); + + if (convergence_checker.has_converged(current_transform.parameters(), total_cost)) { + CONSOLE("Convergence reached at level " + std::to_string(level) + " after " + 
std::to_string(iter) + + " iterations."); + break; + } + adabelief.step(channel_gradients.get()); + const auto updated_params = adabelief.parameterValues(); + current_transform.set_params(updated_params); + } + current_transform = best_transform_level; + best_transform = best_transform_level; + continue; + } + + const auto pivot_moving = + image_centre_scanner_space, float>(config.channels.front().image1).template cast(); + const auto pivot_fixed = + image_centre_scanner_space, float>(config.channels.front().image2).template cast(); + + // Re-parameterise the current transform around the fixed pivot for the forward direction. + GlobalTransform current_transform_fwd = GlobalTransform::from_affine_compact( + current_transform.to_affine_compact(), pivot_fixed, config.transformation_type); + GlobalTransform current_transform_bwd = GlobalTransform::from_affine_compact( + current_transform.to_affine_compact().inverse(), pivot_moving, config.transformation_type); + + AdaBelief adabelief_fwd = create_optimiser(current_transform_fwd.parameters(), learning_rate); + AdaBelief adabelief_bwd = create_optimiser(current_transform_bwd.parameters(), learning_rate); + + for (size_t iter = 0; iter < config.max_iterations; ++iter) { + // Dispatch gradient calculations for all channels in both directions. + for (auto &channel_data : channels_data) { + auto &channel_calculator_fwd = channel_data.levels[level].calculator; + channel_calculator_fwd.update(current_transform_fwd); + + auto &reverse_calculator = channel_data.levels[level].reverse_calculator; + reverse_calculator->update(current_transform_bwd); + } + + WeightedGradients channel_gradients_fwd(degrees_of_freedom); + WeightedGradients channel_gradients_bwd(degrees_of_freedom); + float total_cost_fwd = 0.0F; + float total_cost_bwd = 0.0F; + + // Gather results for each channel accumulating gradients and cost. 
+ for (const auto &channel_data : channels_data) { + const auto &channel_calculator_fwd = channel_data.levels[level].calculator; + const auto channel_result_fwd = channel_calculator_fwd.get_result(); + channel_gradients_fwd.add(channel_result_fwd.gradients, channel_data.weight); + total_cost_fwd += channel_result_fwd.cost * channel_data.weight; + + const auto &reverse_calculator = channel_data.levels[level].reverse_calculator; + const auto channel_result_bwd = reverse_calculator->get_result(); + channel_gradients_bwd.add(channel_result_bwd.gradients, channel_data.weight); + total_cost_bwd += channel_result_bwd.cost * channel_data.weight; + } + + const float total_cost = total_cost_fwd + total_cost_bwd; + if (total_cost < best_cost) { + best_cost = total_cost; + best_transform_level = current_transform_fwd; + } + + INFO("Current transformation matrix (fwd) at level " + std::to_string(level) + " iteration " + + std::to_string(iter) + ":\n" + EigenHelpers::to_string(current_transform_fwd.to_matrix4f())); + INFO("Current transformation matrix (bwd) at level " + std::to_string(level) + " iteration " + + std::to_string(iter) + ":\n" + EigenHelpers::to_string(current_transform_bwd.to_matrix4f())); + + INFO("Level " + std::to_string(level) + ", Iteration " + std::to_string(iter) + + ", Cost (fwd+bwd): " + std::to_string(total_cost_fwd) + "+" + std::to_string(total_cost_bwd) + + " = " + std::to_string(total_cost)); + + if (convergence_checker.has_converged(current_transform_fwd.parameters(), total_cost)) { + CONSOLE("Convergence reached at level " + std::to_string(level) + " after " + + std::to_string(iter) + " iterations."); + break; + } + + adabelief_fwd.step(channel_gradients_fwd.get()); + adabelief_bwd.step(channel_gradients_bwd.get()); + + const auto updated_params_fwd = adabelief_fwd.parameterValues(); + const auto updated_params_bwd = adabelief_bwd.parameterValues(); + current_transform_fwd.set_params(updated_params_fwd); + 
current_transform_bwd.set_params(updated_params_bwd); + + // Lie algebra averaging to enforce symmetry. + const Eigen::Matrix4f t_fwd = current_transform_fwd.to_matrix4f(); + const Eigen::Matrix4f t_bwd = current_transform_bwd.to_matrix4f(); + const Eigen::Matrix4f mean_log = 0.5F * (t_fwd.log() + t_bwd.inverse().log()); + const Eigen::Matrix4f avg_mat = mean_log.exp(); + const auto avg_tform = EigenHelpers::from_homogeneous_mat4f(avg_mat); + + current_transform_fwd = GlobalTransform::from_affine_compact(avg_tform, pivot_fixed, config.transformation_type); + current_transform_bwd = + GlobalTransform::from_affine_compact(avg_tform.inverse(), pivot_moving, config.transformation_type); + + const auto params_fwd = current_transform_fwd.parameters(); + adabelief_fwd.setParameterValues(params_fwd); + const auto params_bwd = current_transform_bwd.parameters(); + adabelief_bwd.setParameterValues(params_bwd); + } + + current_transform = best_transform_level; + best_transform = best_transform_level; + } + + INFO("Final transformation matrix:\n" + EigenHelpers::to_string(best_transform.to_matrix4f())); + return RegistrationResult{.transformation = best_transform.to_affine_compact()}; +} + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/globalregistration.h b/cpp/core/gpu/registration/globalregistration.h new file mode 100644 index 0000000000..8c11df651f --- /dev/null +++ b/cpp/core/gpu/registration/globalregistration.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" + +namespace MR::GPU { +RegistrationResult run_registration(const RegistrationConfig &config, const GPU::ComputeContext &context); +} diff --git a/cpp/core/gpu/registration/imageoperations.cpp b/cpp/core/gpu/registration/imageoperations.cpp new file mode 100644 index 0000000000..032c0266c2 --- /dev/null +++ b/cpp/core/gpu/registration/imageoperations.cpp @@ -0,0 +1,215 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/imageoperations.h" +#include "gpu/registration/eigenhelpers.h" +#include "gpu/gpu.h" +#include +#include "types.h" +#include "gpu/registration/utils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { +using Coordinate3D = std::array; + +struct MomentUniforms { + alignas(16) std::array centre{}; +}; +static_assert(sizeof(MomentUniforms) % 16 == 0, "MomentUniforms must be 16-byte aligned"); + +Coordinate3D centerOfMass(const Texture &texture, + const ComputeContext &context, + const transform_type &imageTransform, + std::optional mask) { + const WorkgroupSize workgroupSize{.x = 8, .y = 8, .z = 4}; + + const Buffer weightedPositionBuffer = context.new_empty_buffer(3); + const Buffer totalWeightBuffer = context.new_empty_buffer(1); + + const KernelSpec centerOfMassKernelSpec{ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/center_of_mass.slang"}, + .workgroup_size = workgroupSize, + .constants = {{"kUseMask", static_cast(mask.has_value())}}, + }, + .bindings_map = {{"weightedPositions", weightedPositionBuffer}, + {"totalIntensity", totalWeightBuffer}, + {"image", texture}, + {"mask", mask ? 
*mask : texture}}, + }; + + const Kernel centerOfMassKernel = context.new_kernel(centerOfMassKernelSpec); + const DispatchGrid dispatch_grid = DispatchGrid::element_wise_texture(texture, workgroupSize); + context.dispatch_kernel(centerOfMassKernel, dispatch_grid); + + std::array weightedPositionValues{}; + uint32_t totalWeightValue; + + context.download_buffer(weightedPositionBuffer, weightedPositionValues); + context.download_buffer(totalWeightBuffer, &totalWeightValue, sizeof(float)); + + // Now reinterpret the downloaded data as float + std::array weightedPosition; + weightedPosition[0] = *reinterpret_cast(&weightedPositionValues[0]); + weightedPosition[1] = *reinterpret_cast(&weightedPositionValues[1]); + weightedPosition[2] = *reinterpret_cast(&weightedPositionValues[2]); + + const float totalWeight = *reinterpret_cast(&totalWeightValue); + + const auto center = Eigen::Vector4f( + weightedPosition[0] / totalWeight, weightedPosition[1] / totalWeight, weightedPosition[2] / totalWeight, 1.0F); + + assert(center.x() >= 0 && center.x() <= texture.spec.width && center.y() >= 0 && center.y() <= texture.spec.height && + center.z() >= 0 && center.z() <= texture.spec.depth && "Center of mass is out of the bounds of the image"); + + const auto centerScanner = imageTransform.matrix().cast() * center; + + return {centerScanner.x(), centerScanner.y(), centerScanner.z()}; +} + +Eigen::Matrix3f computeScannerMoments(const Texture &texture, + const ComputeContext &context, + const Eigen::Matrix4f &voxelToScanner, + const Eigen::Vector3f ¢reScanner, + std::optional mask) { + constexpr size_t kMomentCount = 6; + const WorkgroupSize workgroupSize{.x = 8, .y = 8, .z = 4}; + + const std::array matrixData = EigenHelpers::to_array(voxelToScanner); + const Buffer matrixBuffer = context.new_buffer_from_host_memory(matrixData); + + const MomentUniforms uniforms{ + .centre = {centreScanner.x(), centreScanner.y(), centreScanner.z(), 0.0f}, + }; + const Buffer centreBuffer = + 
context.new_buffer_from_host_memory(&uniforms, sizeof(uniforms), BufferType::UniformBuffer); + + Buffer momentBuffer = context.new_empty_buffer(kMomentCount); + context.clear_buffer(momentBuffer); + + const KernelSpec kernelSpec{ + .compute_shader = {.shader_source = ShaderFile{"shaders/registration/moments.slang"}, + .workgroup_size = workgroupSize, + .constants = {{"kUseMask", static_cast(mask.has_value())}}}, + .bindings_map = {{"momentBuffer", momentBuffer}, + {"voxelToScanner", matrixBuffer}, + {"centreScanner", centreBuffer}, + {"image", texture}, + {"mask", mask ? *mask : texture}}, + }; + + const Kernel kernel = context.new_kernel(kernelSpec); + const DispatchGrid dispatch_grid = DispatchGrid::element_wise_texture(texture, workgroupSize); + + context.dispatch_kernel(kernel, dispatch_grid); + + std::array momentBits{}; + context.download_buffer(momentBuffer, momentBits); + + std::array momentValues{}; + for (size_t i = 0; i < kMomentCount; ++i) { + std::memcpy(&momentValues[i], &momentBits[i], sizeof(float)); + } + + Eigen::Matrix3f moments; + moments << momentValues[0], momentValues[3], momentValues[4], momentValues[3], momentValues[1], momentValues[5], + momentValues[4], momentValues[5], momentValues[2]; + return moments; +} + +Texture transformTexture(const Texture &texture, + const ComputeContext &context, + tcb::span transformationMatrixData) { + const WorkgroupSize workgroupSize{.x = 8, .y = 8, .z = 4}; + + const Buffer transformationMatrixBuffer = context.new_buffer_from_host_memory(transformationMatrixData); + + const TextureSpec outputTextureSpec{.width = texture.spec.width, + .height = texture.spec.height, + .depth = texture.spec.depth, + .format = texture.spec.format, + .usage = {.storage_binding = true, .render_target = false}}; + const Texture outputTexture = context.new_empty_texture(outputTextureSpec); + + const KernelSpec transformKernelSpec{ + .compute_shader = {.shader_source = ShaderFile{"shaders/transform_image.slang"}, + .workgroup_size 
= workgroupSize}, + .bindings_map = {{"transformationMatrix", transformationMatrixBuffer}, + {"inputImage", texture}, + {"outputImage", outputTexture}, + {"linearSampler", context.new_linear_sampler()}}}; + + const Kernel transformKernel = context.new_kernel(transformKernelSpec); + const DispatchGrid dispatch_grid{ + .x = Utils::nextMultipleOf(texture.spec.width / workgroupSize.x, workgroupSize.x), + .y = Utils::nextMultipleOf(texture.spec.height / workgroupSize.y, workgroupSize.y), + .z = Utils::nextMultipleOf(texture.spec.depth / workgroupSize.z, workgroupSize.z), + }; + + context.dispatch_kernel(transformKernel, dispatch_grid); + + return outputTexture; +} + +Texture downsampleTexture(const Texture &texture, const ComputeContext &context) { + const WorkgroupSize workgroupSize{.x = 8, .y = 8, .z = 4}; + + const TextureSpec outputTextureSpec{.width = texture.spec.width / 2, + .height = texture.spec.height / 2, + .depth = texture.spec.depth / 2, + .format = texture.spec.format, + .usage = {.storage_binding = true, .render_target = false}}; + const Texture outputTexture = context.new_empty_texture(outputTextureSpec); + + const KernelSpec transformKernelSpec{ + .compute_shader = {.shader_source = ShaderFile{"shaders/downsample_image.slang"}, + .workgroup_size = workgroupSize}, + .bindings_map = {{"inputTexture", texture}, {"outputTexture", outputTexture}}}; + const Kernel transformKernel = context.new_kernel(transformKernelSpec); + const DispatchGrid dispatch_grid{ + .x = Utils::nextMultipleOf(outputTextureSpec.width / workgroupSize.x, workgroupSize.x), + .y = Utils::nextMultipleOf(outputTextureSpec.height / workgroupSize.y, workgroupSize.y), + .z = Utils::nextMultipleOf(outputTextureSpec.depth / workgroupSize.z, workgroupSize.z), + }; + + context.dispatch_kernel(transformKernel, dispatch_grid); + + return outputTexture; +} + +std::vector +createDownsampledPyramid(const Texture &fullResTexture, int32_t numLevels, const ComputeContext &context) { + if (numLevels == 0) 
+ return {}; + + std::vector pyramid(numLevels); + pyramid[numLevels - 1] = fullResTexture; + + for (int level = static_cast(numLevels) - 2; level >= 0; --level) { + pyramid[level] = downsampleTexture(pyramid[level + 1], context); + } + return pyramid; +} + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/imageoperations.h b/cpp/core/gpu/registration/imageoperations.h new file mode 100644 index 0000000000..5bba2697ca --- /dev/null +++ b/cpp/core/gpu/registration/imageoperations.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include "gpu/gpu.h" +#include +#include "transform.h" +#include "types.h" + +#include +#include +#include +#include + +namespace MR { +template +Eigen::Matrix image_centre_scanner_space(const ImageType &image) { + const ValueType half = static_cast(0.5); + const ValueType one = static_cast(1.0); + Eigen::Matrix centre_voxel; + centre_voxel[0] = static_cast(image.size(0)) * half - one; + centre_voxel[1] = static_cast(image.size(1)) * half - one; + centre_voxel[2] = static_cast(image.size(2)) * half - one; + const Transform transform(image); + return transform.voxel2scanner.template cast() * centre_voxel; +} +} // namespace MR + +namespace MR::GPU { +// Compute center of mass of a given image the image +std::array centerOfMass(const GPU::Texture &texture, + const GPU::ComputeContext &context, + const transform_type &imageTransform = transform_type::Identity(), + std::optional mask = std::nullopt); + +Eigen::Matrix3f computeScannerMoments(const GPU::Texture &texture, + const GPU::ComputeContext &context, + const Eigen::Matrix4f &voxelToScanner, + const Eigen::Vector3f ¢reScanner, + std::optional mask = std::nullopt); + +// Transform the image using the given transformation +// If you want to transform an image in scanner coordinates +// then this transformation must be equal = +// scanner to voxel mat * transformation * voxel to scanner mat +GPU::Texture transformTexture(const GPU::Texture &texture, + const GPU::ComputeContext &context, + tcb::span transformationMatrixData); + +GPU::Texture downsampleTexture(const GPU::Texture &texture, const GPU::ComputeContext &context); + +std::vector +createDownsampledPyramid(const GPU::Texture &fullResTexture, int numLevels, const GPU::ComputeContext &context); +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/initialisation.cpp b/cpp/core/gpu/registration/initialisation.cpp new file mode 100644 index 0000000000..289c708ecb --- /dev/null +++ b/cpp/core/gpu/registration/initialisation.cpp @@ 
-0,0 +1,292 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "gpu/registration/initialisation.h" +#include "gpu/registration/calculatorinterface.h" +#include "gpu/registration/eigenhelpers.h" +#include "gpu/gpu.h" +#include "gpu/registration/imageoperations.h" +#include "gpu/registration/initialisation_rotation_search.h" +#include "match_variant.h" +#include "math/math.h" +#include "mrtrix.h" +#include "gpu/registration/ncccalculator.h" +#include "gpu/registration/nmicalculator.h" +#include "gpu/registration/registrationtypes.h" +#include +#include "gpu/registration/ssdcalculator.h" +#include "types.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using vec3 = std::array; + +namespace { + +// Returns `num_samples` axis-angle vectors stored as {x, y, z} where the vector direction +// is the rotation axis (unit length) and the vector magnitude is the rotation angle theta. +// Angles are in radians. 
+// See https://stackoverflow.com/questions/9600801/evenly-distributing-n-points-on-a-sphere +std::vector fibonacci_sphere_samples(int32_t num_samples, float min_angle, float max_angle) { + std::vector out; + if (num_samples <= 0) { + throw std::invalid_argument("num_samples must be positive"); + } + + if (min_angle > max_angle) + std::swap(min_angle, max_angle); + + const int32_t n = num_samples; + + const double pi = std::acos(-1.0); + const double golden_angle = pi * (3.0 - std::sqrt(5.0)); // ~= 2.399963229728653 + + out.reserve(n); + + for (int i = 0; i < n; ++i) { + // y in [-1, 1]. If n == 1, place at north pole (y = 1). + const double y = (n == 1) ? 1.0 : 1.0 - ((2.0 * i) / static_cast(n - 1)); + const double radius = std::sqrt(std::max(0.0, 1.0 - (y * y))); + const double phi = i * golden_angle; + + const double x = std::cos(phi) * radius; + const double z = std::sin(phi) * radius; + + // compute angle for this sample using lerp in [min_angle, max_angle] + const double t = (n == 1) ? 
0.5 : (static_cast(i) / static_cast(n - 1)); + const double angle = + static_cast(min_angle) + (t * (static_cast(max_angle) - static_cast(min_angle))); + + // axis-angle vector = unit_axis * angle + const std::array axis = { + static_cast(x * angle), static_cast(y * angle), static_cast(z * angle)}; + out.push_back(axis); + } + + return out; +} + +bool compute_sorted_eigenvectors(const Eigen::Matrix3f &matrix, + Eigen::Matrix3f &eigenvectors, + Eigen::Vector3f &eigenvalues) { + if (!matrix.allFinite()) { + return false; + } + + const Eigen::SelfAdjointEigenSolver solver(matrix); + if (solver.info() != Eigen::Success) { + return false; + } + + const Eigen::Vector3f values = solver.eigenvalues(); + const Eigen::Matrix3f vectors = solver.eigenvectors(); + + std::array indices = {0, 1, 2}; + std::sort(indices.begin(), indices.end(), [&](int a, int b) { return values[a] > values[b]; }); + + for (size_t i = 0; i < indices.size(); ++i) { + eigenvalues[static_cast(i)] = values[indices[i]]; + eigenvectors.col(static_cast(i)) = vectors.col(indices[i]); + } + return eigenvectors.allFinite() && eigenvalues.allFinite(); +} +} // namespace + +namespace MR::GPU { +GlobalTransform initialise_transformation(const InitialisationConfig &config, const ComputeContext &context) { + const Texture &moving_texture = config.moving_texture; + const Texture &target_texture = config.target_texture; + const auto &voxel_scanner_matrices = config.voxel_scanner_matrices; + const auto &moving_mask = config.moving_mask; + const auto &target_mask = config.target_mask; + const InitialisationOptions &options = config.options; + + const auto com_target = EigenHelpers::to_vector3f( + centerOfMass(target_texture, context, transform_type::Identity(), target_mask)); + const Eigen::Map voxel_to_scanner_fixed(voxel_scanner_matrices.voxel_to_scanner_fixed.data()); + const Eigen::Vector4f com_target_scanner = voxel_to_scanner_fixed * com_target.homogeneous(); + + const std::array rigid_identity = {0.0F, 0.0F, 
0.0F, 0.0F, 0.0F, 0.0F}; + GlobalTransform initial_transform(rigid_identity, TransformationType::Rigid, com_target_scanner.head<3>()); + + switch (options.translation_choice) { + case InitTranslationChoice::None: + break; + case MR::InitTranslationChoice::Mass: { + INFO("Computing initial translation using center of mass."); + const auto com_moving = EigenHelpers::to_vector3f( + centerOfMass(moving_texture, context, transform_type::Identity(), moving_mask)); + const Eigen::Map voxel_to_scanner_moving( + voxel_scanner_matrices.voxel_to_scanner_moving.data()); + const Eigen::Vector4f com_moving_scanner = voxel_to_scanner_moving * com_moving.homogeneous(); + + initial_transform.set_translation(com_moving_scanner.head<3>() - com_target_scanner.head<3>()); + break; + } + case InitTranslationChoice::Geometric: { + INFO("Computing initial translation using geometric center."); + const Eigen::Vector4f geom_moving_voxel((static_cast(moving_texture.spec.width) - 1.0F) * 0.5F, + (static_cast(moving_texture.spec.height) - 1.0F) * 0.5F, + (static_cast(moving_texture.spec.depth) - 1.0F) * 0.5F, + 1.0F); + const Eigen::Vector4f geom_target_voxel((static_cast(target_texture.spec.width) - 1.0F) * 0.5F, + (static_cast(target_texture.spec.height) - 1.0F) * 0.5F, + (static_cast(target_texture.spec.depth) - 1.0F) * 0.5F, + 1.0F); + + const Eigen::Map voxel_scanner_fixed(voxel_scanner_matrices.voxel_to_scanner_fixed.data()); + const Eigen::Map voxel_scanner_moving(voxel_scanner_matrices.voxel_to_scanner_moving.data()); + const Eigen::Vector4f geom_moving_scanner = voxel_scanner_moving * geom_moving_voxel; + const Eigen::Vector4f geom_target_scanner = voxel_scanner_fixed * geom_target_voxel; + + initial_transform.set_translation(geom_moving_scanner.head<3>() - geom_target_scanner.head<3>()); + initial_transform.set_pivot(Eigen::Vector3f(geom_target_scanner.head<3>())); + break; + } + } + + switch (options.rotation_choice) { + case MR::InitRotationChoice::None: + break; + case 
MR::InitRotationChoice::Search: { + INFO("Computing initial rotation using spherical sampling."); + const auto make_calculator = [&]() -> Calculator { + return MR::match_v( + options.cost_metric, + [&](const NMIMetric &nmi_metric) -> Calculator { + const std::optional fixed_mask = target_mask; + const std::optional moving_mask_opt = moving_mask; + const NMICalculator::Config nmi_config{ + .transformation_type = TransformationType::Rigid, + .fixed = target_texture, + .moving = moving_texture, + .fixed_mask = fixed_mask, + .moving_mask = moving_mask_opt, + .voxel_scanner_matrices = voxel_scanner_matrices, + .num_bins = nmi_metric.num_bins, + .output = CalculatorOutput::Cost, + .context = &context, + }; + return NMICalculator(nmi_config); + }, + [&](const SSDMetric &) -> Calculator { + const std::optional fixed_mask = target_mask; + const std::optional moving_mask_opt = moving_mask; + const SSDCalculator::Config ssd_config{ + .transformation_type = TransformationType::Rigid, + .fixed = target_texture, + .moving = moving_texture, + .fixed_mask = fixed_mask, + .moving_mask = moving_mask_opt, + .voxel_scanner_matrices = voxel_scanner_matrices, + .output = CalculatorOutput::Cost, + .context = &context, + }; + return SSDCalculator(ssd_config); + }, + [&](const NCCMetric &ncc_metric) -> Calculator { + const std::optional fixed_mask = target_mask; + const std::optional moving_mask_opt = moving_mask; + const NCCCalculator::Config ncc_config{ + .transformation_type = TransformationType::Rigid, + .fixed = target_texture, + .moving = moving_texture, + .fixed_mask = fixed_mask, + .moving_mask = moving_mask_opt, + .voxel_scanner_matrices = voxel_scanner_matrices, + .window_radius = ncc_metric.window_radius, + .output = CalculatorOutput::Cost, + .context = &context, + }; + return NCCCalculator(ncc_config); + }); + }; + + constexpr float pi = MR::Math::pi; + const float max_angle_rad = std::clamp(options.max_search_angle_degrees, 0.0F, 180.0F) * (pi / 180.0F); + const auto samples = 
fibonacci_sphere_samples(500, 0.0F, max_angle_rad); + + const auto rotation_angle = [](const std::array &axis) { + return std::sqrt(std::inner_product(axis.cbegin(), axis.cend(), axis.cbegin(), 0.0F)); + }; + + INFO("max_search_angle_degrees=" + std::to_string(options.max_search_angle_degrees) + + " max_angle_rad=" + std::to_string(max_angle_rad)); + INFO("sample[0] norm=" + std::to_string(rotation_angle(samples[0])) + + " sample[last] norm=" + std::to_string(rotation_angle(samples.back()))); + + const auto make_rotation_calculator = [&]() -> RotationSearchCalculator { + auto calculator = std::make_shared(make_calculator()); + return RotationSearchCalculator{ + .update = [calculator](const GlobalTransform &transform) { calculator->update(transform); }, + .get_result = [calculator]() { return calculator->get_result(); }, + }; + }; + + const RotationSearchParams search_params{ + .parallel_calculators = 8, + .min_improvement = 1e-6F, + .tie_cost_eps = 1e-6F, + }; + const tcb::span> sample_span(samples.data(), samples.size()); + const auto best_rotation = + search_best_rotation(initial_transform, + sample_span, + make_rotation_calculator, + search_params, + [&](float best_cost, const std::array &best_rotation) { + INFO("New best initial rotation found with cost " + std::to_string(best_cost) + + " at axis-angle {" + std::to_string(best_rotation[0]) + ", " + + std::to_string(best_rotation[1]) + ", " + std::to_string(best_rotation[2]) + "}"); + }); + + std::array params{}; + const auto current_params = initial_transform.parameters(); + std::copy(current_params.begin(), current_params.end(), params.begin()); + params[3] = best_rotation[0]; + params[4] = best_rotation[1]; + params[5] = best_rotation[2]; + initial_transform = GlobalTransform(tcb::span(params.data(), initial_transform.param_count()), + initial_transform.type(), + initial_transform.pivot()); + break; + } + case InitRotationChoice::Moments: { + // TODO: implement moment-based initial rotation + throw 
std::logic_error("Moment-based initial rotation is not yet implemented."); + } + } + + INFO("Initial transformation matrix:\n" + EigenHelpers::to_string(initial_transform.to_matrix4f())); + + return initial_transform.as_affine(); +} +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/initialisation.h b/cpp/core/gpu/registration/initialisation.h new file mode 100644 index 0000000000..7f3f94247e --- /dev/null +++ b/cpp/core/gpu/registration/initialisation.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include + +namespace MR::GPU { + +struct InitialisationConfig { + Texture moving_texture; + Texture target_texture; + std::optional moving_mask; + std::optional target_mask; + VoxelScannerMatrices voxel_scanner_matrices; + InitialisationOptions options; +}; + +GlobalTransform initialise_transformation(const InitialisationConfig &config, const ComputeContext &context); +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/initialisation_rotation_search.cpp b/cpp/core/gpu/registration/initialisation_rotation_search.cpp new file mode 100644 index 0000000000..0440f26412 --- /dev/null +++ b/cpp/core/gpu/registration/initialisation_rotation_search.cpp @@ -0,0 +1,105 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/initialisation_rotation_search.h" +#include "gpu/registration/registrationtypes.h" +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { + +Eigen::Vector3f search_best_rotation(const GlobalTransform &initial_transform, + tcb::span> samples, + const std::function &make_calculator, + const RotationSearchParams ¶ms, + const std::function &)> &on_update) { + if (params.parallel_calculators == 0 || samples.empty()) { + return Eigen::Vector3f::Zero(); + } + + const auto rotation_angle = [](const std::array &axis) { + return std::sqrt(std::inner_product(axis.cbegin(), axis.cend(), axis.cbegin(), 0.0F)); + }; + + float best_cost = std::numeric_limits::infinity(); + std::array best_rotation = {0.0F, 0.0F, 0.0F}; + + const auto consider_candidate = [&](float candidate_cost, const std::array &sample) { + const float cost_delta = candidate_cost - best_cost; + const bool better_cost = cost_delta < -params.min_improvement; + const bool tie_with_smaller_angle = + std::abs(cost_delta) <= params.tie_cost_eps && rotation_angle(sample) < rotation_angle(best_rotation); + + if (better_cost || tie_with_smaller_angle) { + best_cost = candidate_cost; + best_rotation = sample; + if (on_update) { + on_update(best_cost, best_rotation); + } + } + }; + + std::vector calculators; + calculators.reserve(params.parallel_calculators); + for (size_t i = 0; i < params.parallel_calculators; ++i) { + calculators.push_back(make_calculator()); + } + + const size_t param_count = initial_transform.param_count(); + std::array base_params{}; + const auto initial_params = initial_transform.parameters(); + std::copy(initial_params.begin(), initial_params.end(), base_params.begin()); + + for (size_t chunk_start = 0; chunk_start < samples.size(); chunk_start += calculators.size()) { + const size_t chunk_size = std::min(calculators.size(), samples.size() - chunk_start); + + for (size_t local_index = 0; local_index < 
chunk_size; ++local_index) { + const auto &sample = samples[chunk_start + local_index]; + auto params = base_params; + params[3] = sample[0]; + params[4] = sample[1]; + params[5] = sample[2]; + const GlobalTransform candidate_transform( + tcb::span(params.data(), param_count), initial_transform.type(), initial_transform.pivot()); + calculators[local_index].update(candidate_transform); + } + + for (size_t local_index = 0; local_index < chunk_size; ++local_index) { + const auto result = calculators[local_index].get_result(); + consider_candidate(result.cost, samples[chunk_start + local_index]); + } + } + + return Eigen::Vector3f(best_rotation[0], best_rotation[1], best_rotation[2]); +} + +Eigen::Vector3f search_best_rotation(const GlobalTransform &initial_transform, + tcb::span> samples, + const std::function &make_calculator, + const RotationSearchParams ¶ms) { + return search_best_rotation( + initial_transform, samples, make_calculator, params, std::function &)>()); +} + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/initialisation_rotation_search.h b/cpp/core/gpu/registration/initialisation_rotation_search.h new file mode 100644 index 0000000000..20d1d67489 --- /dev/null +++ b/cpp/core/gpu/registration/initialisation_rotation_search.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include "gpu/registration/registrationtypes.h" +#include + +#include +#include +#include +#include + +namespace MR::GPU { + +struct RotationSearchParams { + size_t parallel_calculators = 8; + float min_improvement = 1e-6F; + float tie_cost_eps = 1e-6F; +}; + +struct RotationSearchCalculator { + std::function update; + std::function get_result; +}; + +Eigen::Vector3f search_best_rotation(const GlobalTransform &initial_transform, + tcb::span> samples, + const std::function &make_calculator, + const RotationSearchParams ¶ms, + const std::function &)> &on_update); + +Eigen::Vector3f search_best_rotation(const GlobalTransform &initial_transform, + tcb::span> samples, + const std::function &make_calculator, + const RotationSearchParams ¶ms); + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/ncccalculator.cpp b/cpp/core/gpu/registration/ncccalculator.cpp new file mode 100644 index 0000000000..39e2eeafac --- /dev/null +++ b/cpp/core/gpu/registration/ncccalculator.cpp @@ -0,0 +1,275 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/ncccalculator.h" + +#include "gpu/registration/calculatoroutput.h" +#include "gpu/registration/eigenhelpers.h" +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { +namespace { + +constexpr WorkgroupSize ncc_workgroup_size{8U, 4U, 4U}; +constexpr double kVarianceEps = 1e-8; +constexpr double kDenominatorEps = 1e-8; + +template struct alignas(16) NCCUniforms { + alignas(16) DispatchGrid dispatch_grid{}; + alignas(16) std::array transformation_pivot{}; + alignas(16) std::array current_transform{}; + alignas(16) VoxelScannerMatrices voxel_scanner_matrices{}; +}; + +using RigidNCCUniforms = NCCUniforms<6>; +using AffineNCCUniforms = NCCUniforms<12>; +static_assert(sizeof(RigidNCCUniforms) % 16 == 0, "RigidNCCUniforms must be 16-byte aligned"); +static_assert(sizeof(AffineNCCUniforms) % 16 == 0, "AffineNCCUniforms must be 16-byte aligned"); + +template +void upload_uniforms(const ComputeContext &context, + const Buffer &buffer, + const DispatchGrid &dispatch_grid, + const GlobalTransform &transform, + const VoxelScannerMatrices &matrices) { + NCCUniforms uniforms{}; + uniforms.dispatch_grid = dispatch_grid; + uniforms.transformation_pivot = EigenHelpers::to_array(transform.pivot()); + const auto params = transform.parameters(); + std::copy_n(params.begin(), N, uniforms.current_transform.begin()); + uniforms.voxel_scanner_matrices = matrices; + context.write_to_buffer(buffer, &uniforms, sizeof(uniforms)); +} + +} // namespace + +NCCCalculator::NCCCalculator(const Config &config) + : m_output(config.output), + m_compute_context(config.context), + m_use_local_window(config.window_radius > 0U), + m_window_radius(config.window_radius), + m_voxel_scanner_matrices(config.voxel_scanner_matrices), + m_fixed(config.fixed), + m_moving(config.moving), + 
m_fixed_mask(config.fixed_mask.value_or(config.fixed)), + m_moving_mask(config.moving_mask.value_or(config.moving)), + m_use_fixed_mask(config.fixed_mask.has_value()), + m_use_moving_mask(config.moving_mask.has_value()) { + assert(m_compute_context != nullptr); + const bool is_rigid = config.transformation_type == TransformationType::Rigid; + m_degrees_of_freedom = is_rigid ? 6U : 12U; + m_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, ncc_workgroup_size); + m_terms_per_workgroup = 1U + m_degrees_of_freedom; + m_global_terms_per_workgroup = 5U + 3U * m_degrees_of_freedom; + + const size_t uniformsSize = is_rigid ? sizeof(RigidNCCUniforms) : sizeof(AffineNCCUniforms); + m_uniforms_buffer = m_compute_context->new_empty_buffer(uniformsSize, BufferType::UniformBuffer); + m_num_contributing_voxels_buffer = m_compute_context->new_empty_buffer(1); + + if (m_use_local_window) { + m_lncc_partials_buffer = + m_compute_context->new_empty_buffer(m_terms_per_workgroup * m_dispatch_grid.workgroup_count()); + m_lncc_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/registration/ncc.slang"}, + .entryPoint = "lncc_main", + .workgroup_size = ncc_workgroup_size, + .constants = {{"kUseSourceMask", static_cast(m_use_moving_mask)}, + {"kUseTargetMask", static_cast(m_use_fixed_mask)}, + {"kComputeGradients", + static_cast(m_output == CalculatorOutput::CostAndGradients)}, + {"kWindowRadius", m_window_radius}}, + .entry_point_args = {is_rigid ? 
"RigidTransformation" : "AffineTransformation"}, + }, + .bindings_map = {{"uniforms", m_uniforms_buffer}, + {"sourceImage", m_moving}, + {"targetImage", m_fixed}, + {"sourceMask", m_moving_mask}, + {"targetMask", m_fixed_mask}, + {"linearSampler", m_compute_context->new_linear_sampler()}, + {"lnccPartials", m_lncc_partials_buffer}, + {"numContributingVoxels", m_num_contributing_voxels_buffer}}, + }); + } else { + m_global_partials_buffer = + m_compute_context->new_empty_buffer(m_global_terms_per_workgroup * m_dispatch_grid.workgroup_count()); + m_global_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/registration/ncc.slang"}, + .entryPoint = "global_ncc_main", + .workgroup_size = ncc_workgroup_size, + .constants = {{"kUseSourceMask", static_cast(m_use_moving_mask)}, + {"kUseTargetMask", static_cast(m_use_fixed_mask)}, + {"kComputeGradients", + static_cast(m_output == CalculatorOutput::CostAndGradients)}, + {"kWindowRadius", m_window_radius}}, + .entry_point_args = {is_rigid ? 
"RigidTransformation" : "AffineTransformation"}, + }, + .bindings_map = {{"uniforms", m_uniforms_buffer}, + {"sourceImage", m_moving}, + {"targetImage", m_fixed}, + {"sourceMask", m_moving_mask}, + {"targetMask", m_fixed_mask}, + {"linearSampler", m_compute_context->new_linear_sampler()}, + {"globalPartials", m_global_partials_buffer}, + {"numContributingVoxels", m_num_contributing_voxels_buffer}}, + }); + } +} + +void NCCCalculator::update(const GlobalTransform &transformation) { + assert(transformation.param_count() == m_degrees_of_freedom); + if (m_use_local_window) { + assert(m_window_radius > 0U); + } + + if (transformation.is_affine()) { + upload_uniforms<12>(*m_compute_context, m_uniforms_buffer, m_dispatch_grid, transformation, m_voxel_scanner_matrices); + } else { + upload_uniforms<6>(*m_compute_context, m_uniforms_buffer, m_dispatch_grid, transformation, m_voxel_scanner_matrices); + } + + m_compute_context->clear_buffer(m_num_contributing_voxels_buffer); + + if (m_use_local_window) { + m_compute_context->dispatch_kernel(m_lncc_kernel, m_dispatch_grid); + } else { + m_compute_context->dispatch_kernel(m_global_kernel, m_dispatch_grid); + } +} + +IterationResult NCCCalculator::get_result() const { + if (m_use_local_window) { + const auto partials = m_compute_context->download_buffer_as_vector(m_lncc_partials_buffer); + const auto contributing = m_compute_context->download_buffer_as_vector(m_num_contributing_voxels_buffer); + const uint32_t validCount = contributing.empty() ? 
0U : contributing[0]; + const size_t workgroups = m_dispatch_grid.workgroup_count(); + + double totalCost = 0.0; + std::vector gradients; + if (m_output == CalculatorOutput::CostAndGradients) { + gradients.assign(m_degrees_of_freedom, 0.0); + } + + for (size_t wg = 0; wg < workgroups; ++wg) { + const size_t base = wg * m_terms_per_workgroup; + totalCost += partials[base]; + if (m_output == CalculatorOutput::CostAndGradients) { + for (uint32_t i = 0; i < m_degrees_of_freedom; ++i) { + gradients[i] += partials[base + 1 + i]; + } + } + } + + const float invCount = validCount > 0U ? 1.0F / static_cast(validCount) : 0.0F; + const float cost = static_cast(totalCost) * invCount; + + if (m_output == CalculatorOutput::Cost) { + return IterationResult{cost, {}}; + } + + std::vector gradientsF(m_degrees_of_freedom, 0.0F); + for (uint32_t i = 0; i < m_degrees_of_freedom; ++i) { + gradientsF[i] = static_cast(gradients[i]) * invCount; + } + return IterationResult{cost, std::move(gradientsF)}; + } + + const auto partials = m_compute_context->download_buffer_as_vector(m_global_partials_buffer); + const auto contributing = m_compute_context->download_buffer_as_vector(m_num_contributing_voxels_buffer); + const double validCount = contributing.empty() ? 0.0 : static_cast(contributing[0]); + if (validCount == 0.0) { + return m_output == CalculatorOutput::CostAndGradients + ? 
IterationResult{0.0F, std::vector(m_degrees_of_freedom, 0.0F)} + : IterationResult{0.0F, {}}; + } + + const size_t workgroups = m_dispatch_grid.workgroup_count(); + double sumTarget = 0.0; + double sumMoving = 0.0; + double sumTargetSquared = 0.0; + double sumMovingSquared = 0.0; + double sumTargetMoving = 0.0; + std::vector sumTargetMovingPrime(m_degrees_of_freedom, 0.0); + std::vector sumMovingPrime(m_degrees_of_freedom, 0.0); + std::vector sumMovingSquaredPrime(m_degrees_of_freedom, 0.0); + + for (size_t wg = 0; wg < workgroups; ++wg) { + const size_t base = wg * m_global_terms_per_workgroup; + size_t offset = base; + sumTarget += partials[offset++]; + sumMoving += partials[offset++]; + sumTargetSquared += partials[offset++]; + sumMovingSquared += partials[offset++]; + sumTargetMoving += partials[offset++]; + + for (uint32_t i = 0; i < m_degrees_of_freedom; ++i) { + sumTargetMovingPrime[i] += partials[offset++]; + } + for (uint32_t i = 0; i < m_degrees_of_freedom; ++i) { + sumMovingPrime[i] += partials[offset++]; + } + for (uint32_t i = 0; i < m_degrees_of_freedom; ++i) { + sumMovingSquaredPrime[i] += partials[offset++]; + } + } + + const double invCount = 1.0 / validCount; + const double meanTarget = sumTarget * invCount; + const double meanMoving = sumMoving * invCount; + const double varianceTarget = std::max(0.0, sumTargetSquared * invCount - meanTarget * meanTarget); + const double varianceMoving = std::max(0.0, sumMovingSquared * invCount - meanMoving * meanMoving); + if (varianceTarget < kVarianceEps || varianceMoving < kVarianceEps) { + return m_output == CalculatorOutput::CostAndGradients + ? 
IterationResult{0.0F, std::vector(m_degrees_of_freedom, 0.0F)} + : IterationResult{0.0F, {}}; + } + + const double covariance = sumTargetMoving * invCount - meanTarget * meanMoving; + const double denom = std::sqrt(std::max(varianceTarget * varianceMoving, kVarianceEps)); + const float cost = static_cast(-covariance / denom); + + if (m_output == CalculatorOutput::Cost) { + return IterationResult{cost, {}}; + } + + const double denomGradBase = std::max(varianceMoving * denom, kDenominatorEps); + std::vector gradients(m_degrees_of_freedom, 0.0F); + for (uint32_t i = 0; i < m_degrees_of_freedom; ++i) { + const double cPrime = (sumTargetMovingPrime[i] * invCount) - (meanTarget * sumMovingPrime[i] * invCount); + const double varMovingPrime = + 2.0 * (sumMovingSquaredPrime[i] * invCount - meanMoving * sumMovingPrime[i] * invCount); + const double gradient = (cPrime * varianceMoving - 0.5 * covariance * varMovingPrime) / denomGradBase; + gradients[i] = static_cast(-gradient); + } + + return IterationResult{cost, std::move(gradients)}; +} + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/ncccalculator.h b/cpp/core/gpu/registration/ncccalculator.h new file mode 100644 index 0000000000..6c52482bc8 --- /dev/null +++ b/cpp/core/gpu/registration/ncccalculator.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include "gpu/registration/calculatoroutput.h" +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include +#include + +namespace MR::GPU { + +class NCCCalculator { +public: + struct Config { + TransformationType transformation_type = TransformationType::Affine; + Texture fixed; + Texture moving; + std::optional fixed_mask; + std::optional moving_mask; + VoxelScannerMatrices voxel_scanner_matrices{}; + uint32_t window_radius = 0U; + CalculatorOutput output = CalculatorOutput::CostAndGradients; + const ComputeContext *context = nullptr; + }; + + explicit NCCCalculator(const Config &config); + + void update(const GlobalTransform &transformation); + IterationResult get_result() const; + +private: + CalculatorOutput m_output = CalculatorOutput::CostAndGradients; + const ComputeContext *m_compute_context = nullptr; + bool m_use_local_window = false; + uint32_t m_window_radius = 0U; + uint32_t m_degrees_of_freedom = 0U; + + DispatchGrid m_dispatch_grid{}; + VoxelScannerMatrices m_voxel_scanner_matrices{}; + + Texture m_fixed{}; + Texture m_moving{}; + Texture m_fixed_mask{}; + Texture m_moving_mask{}; + bool m_use_fixed_mask = false; + bool m_use_moving_mask = false; + + Buffer m_uniforms_buffer{}; + Buffer m_lncc_partials_buffer{}; + Buffer m_global_partials_buffer{}; + Buffer m_num_contributing_voxels_buffer{}; + + Kernel m_lncc_kernel{}; + Kernel m_global_kernel{}; + + size_t m_terms_per_workgroup = 0U; + size_t m_global_terms_per_workgroup = 0U; +}; + +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/nmicalculator.cpp b/cpp/core/gpu/registration/nmicalculator.cpp new file mode 100644 index 0000000000..87b3bfaeea --- /dev/null +++ b/cpp/core/gpu/registration/nmicalculator.cpp @@ -0,0 +1,366 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. 
If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "gpu/registration/nmicalculator.h" + +#include "gpu/registration/calculatoroutput.h" +#include "gpu/registration/eigenhelpers.h" +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/utils.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { +namespace { +uint32_t float_to_ordered_uint(float v) { + uint32_t bits; + // Use std::bit_cast when we switch to C++20 + std::memcpy(&bits, &v, sizeof(bits)); + return (bits & 0x80000000u) ? ~bits : (bits ^ 0x80000000u); +} + +float ordered_uint_to_float(uint32_t v) { + const uint32_t bits = (v & 0x80000000u) ? 
(v ^ 0x80000000u) : ~v; + float out; + // Use std::bit_cast when we switch to C++20 + std::memcpy(&out, &bits, sizeof(out)); + return out; +} +} // namespace + +struct MinMaxUniforms { + alignas(16) DispatchGrid dispatch_grid{}; +}; +static_assert(sizeof(MinMaxUniforms) % 16 == 0, "MinMaxUniforms must be 16-byte aligned"); + +struct JointHistogramUniforms { + alignas(16) DispatchGrid dispatch_grid{}; + alignas(16) Intensities intensities{}; + alignas(16) std::array transformation_matrix{}; +}; +static_assert(sizeof(JointHistogramUniforms) % 16 == 0, "JointHistogramUniforms must be 16-byte aligned"); + +struct PrecomputeUniforms { + alignas(16) DispatchGrid dispatch_grid{}; +}; +static_assert(sizeof(PrecomputeUniforms) % 16 == 0, "PrecomputeUniforms must be 16-byte aligned"); + +template struct GradientsUniforms { + alignas(16) DispatchGrid dispatch_grid{}; + alignas(16) std::array transformation_pivot{}; + alignas(16) Intensities intensities{}; + alignas(16) std::array current_transform{}; + alignas(16) VoxelScannerMatrices voxel_scanner_matrices{}; +}; + +using RigidGradientsUniforms = GradientsUniforms<6>; +using AffineGradientsUniforms = GradientsUniforms<12>; + +constexpr WorkgroupSize gradientsWorkgroupSize{16, 8, 8}; +const std::array initialMinMax{ + float_to_ordered_uint(std::numeric_limits::max()), + float_to_ordered_uint(-std::numeric_limits::max()), +}; + +// Order of operations to drive GPU computation: +// 1. Find the min/max intensities of the fixed image and moving image (with current transformation applied) +// 2. Compute the joint histogram of the fixed and moving images +// 3. Precompute a coefficients table from the joint histogram to avoid redundant computations in the next stage and +// compute the mutual information cost. +// 4. Compute the gradients of the mutual information cost function with respect to the transformation parameters. 
+NMICalculator::NMICalculator(const Config &config) + : m_output(config.output), + m_compute_context(config.context), + m_fixed(config.fixed), + m_moving(config.moving), + m_fixed_mask(config.fixed_mask.value_or(config.fixed)), + m_moving_mask(config.moving_mask.value_or(config.moving)), + m_use_fixed_mask(config.fixed_mask.has_value()), + m_use_moving_mask(config.moving_mask.has_value()), + m_voxel_scanner_matrices(config.voxel_scanner_matrices), + m_num_bins(config.num_bins) + +{ + assert(m_compute_context != nullptr); + const bool is_rigid = config.transformation_type == TransformationType::Rigid; + const bool is_affine = config.transformation_type == TransformationType::Affine; + m_degrees_of_freedom = is_rigid ? 6 : 12; + + // The min/max reduction runs on encoded uint32_t values. We map floats to an order-preserving + // uint representation (flip sign bit for positives, bitwise-not for negatives), so unsigned + // comparisons match float ordering and atomics work for negative intensities too. + // TODO: Should we use shared memory reduction instead of atomics for better performance? 
+ m_min_max_uniforms_buffer = + m_compute_context->new_empty_buffer(sizeof(MinMaxUniforms), BufferType::UniformBuffer); + m_min_max_intensity_fixed_buffer = + m_compute_context->new_buffer_from_host_memory(initialMinMax.data(), sizeof(initialMinMax)); + m_min_max_intensity_moving_buffer = + m_compute_context->new_buffer_from_host_memory(initialMinMax.data(), sizeof(initialMinMax)); + m_raw_joint_histogram_buffer = m_compute_context->new_empty_buffer(m_num_bins * m_num_bins); + m_smoothed_joint_histogram_buffer = m_compute_context->new_empty_buffer(m_num_bins * m_num_bins); + m_joint_histogram_mass_buffer = m_compute_context->new_empty_buffer(1); + m_joint_histogram_uniforms_buffer = + m_compute_context->new_empty_buffer(sizeof(JointHistogramUniforms), BufferType::UniformBuffer); + m_precomputed_coefficients_buffer = m_compute_context->new_empty_buffer(m_num_bins * m_num_bins); + m_mutual_information_buffer = m_compute_context->new_empty_buffer(1); + + if (m_output == CalculatorOutput::CostAndGradients) { + m_gradients_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, gradientsWorkgroupSize); + const uint32_t gradients_uniform_size = + is_affine ? 
sizeof(AffineGradientsUniforms) : sizeof(RigidGradientsUniforms); + m_gradients_uniforms_buffer = + m_compute_context->new_empty_buffer(gradients_uniform_size, BufferType::UniformBuffer); + m_gradients_buffer = + m_compute_context->new_empty_buffer(m_degrees_of_freedom * m_gradients_dispatch_grid.workgroup_count()); + } + + const KernelSpec min_max_fixed_kernel_spec{ + .compute_shader = {.shader_source = ShaderFile{"shaders/reduction_image.slang"}, + .entryPoint = "minMaxAtomic"}, + .bindings_map = {{"uniforms", m_min_max_uniforms_buffer}, + {"inputTexture", m_fixed}, + {"outputBuffer", m_min_max_intensity_fixed_buffer}, + {"sampler", m_compute_context->new_linear_sampler()}}, + }; + const auto min_max_fixed_kernel = m_compute_context->new_kernel(min_max_fixed_kernel_spec); + const DispatchGrid fixed_dispatch_grid = + DispatchGrid::element_wise_texture(m_fixed, min_max_fixed_kernel.workgroup_size); + const MinMaxUniforms min_max_fixed_uniforms{ + .dispatch_grid = fixed_dispatch_grid, + }; + m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, &min_max_fixed_uniforms, sizeof(min_max_fixed_uniforms)); + m_compute_context->dispatch_kernel(min_max_fixed_kernel, fixed_dispatch_grid); + + const KernelSpec min_max_moving_kernel_spec{ + .compute_shader = {.shader_source = ShaderFile{"shaders/reduction_image.slang"}, + .entryPoint = "minMaxAtomic"}, + .bindings_map = { + {"uniforms", m_min_max_uniforms_buffer}, + {"inputTexture", m_moving}, + {"outputBuffer", m_min_max_intensity_moving_buffer}, + {"sampler", m_compute_context->new_linear_sampler()}, + }}; + + m_min_max_moving_kernel = m_compute_context->new_kernel(min_max_moving_kernel_spec); + const DispatchGrid moving_dispatch_grid = + DispatchGrid::element_wise_texture(m_moving, m_min_max_moving_kernel.workgroup_size); + const MinMaxUniforms min_max_moving_uniforms{ + .dispatch_grid = moving_dispatch_grid, + }; + m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, &min_max_moving_uniforms, 
sizeof(MinMaxUniforms)); + m_compute_context->dispatch_kernel(m_min_max_moving_kernel, moving_dispatch_grid); + + const std::vector min_max_fixed_bits = + m_compute_context->download_buffer_as_vector(m_min_max_intensity_fixed_buffer); + const std::vector min_max_moving_bits = + m_compute_context->download_buffer_as_vector(m_min_max_intensity_moving_buffer); + + m_intensities = {ordered_uint_to_float(min_max_moving_bits[0]), + ordered_uint_to_float(min_max_moving_bits[1]), + ordered_uint_to_float(min_max_fixed_bits[0]), + ordered_uint_to_float(min_max_fixed_bits[1])}; + + const WorkgroupSize joint_histogram_wg_size = {8, 8, 4}; + + m_joint_histogram_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, joint_histogram_wg_size); + const JointHistogramUniforms joint_histogram_uniforms{ + .dispatch_grid = m_joint_histogram_dispatch_grid, + .intensities = m_intensities, + .transformation_matrix = {}, + }; + m_compute_context->write_to_buffer( + m_joint_histogram_uniforms_buffer, &joint_histogram_uniforms, sizeof(JointHistogramUniforms)); + const uint32_t jointHistogramPartialsSize = (m_num_bins * m_num_bins) * m_joint_histogram_dispatch_grid.workgroup_count(); + m_joint_histogram_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/registration/joint_histogram.slang"}, + .entryPoint = "rawHistogram", + .workgroup_size = joint_histogram_wg_size, + .constants = {{"kNumBins", m_num_bins}, + {"kUseFixedMask", static_cast(m_use_fixed_mask)}, + {"kUseMovingMask", static_cast(m_use_moving_mask)}}, + }, + .bindings_map = {{"uniforms", m_joint_histogram_uniforms_buffer}, + {"fixedTexture", m_fixed}, + {"movingTexture", m_moving}, + {"fixedMaskTexture", m_fixed_mask}, + {"movingMaskTexture", m_moving_mask}, + {"jointHistogram", m_raw_joint_histogram_buffer}, + {"sampler", m_compute_context->new_linear_sampler()}}, + }); + + m_compute_total_mass_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + 
.shader_source = ShaderFile{"shaders/registration/joint_histogram.slang"}, + .entryPoint = "computeTotalMass", + .constants = {{"kNumBins", m_num_bins}, + {"kUseFixedMask", static_cast(m_use_fixed_mask)}, + {"kUseMovingMask", static_cast(m_use_moving_mask)}}, + }, + .bindings_map = {{"jointHistogramSmoothed", m_smoothed_joint_histogram_buffer}, + {"jointHistogramMass", m_joint_histogram_mass_buffer}}, + }); + + m_joint_histogram_smooth_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/registration/joint_histogram.slang"}, + .entryPoint = "smoothHistogram", + .workgroup_size = WorkgroupSize{8, 8, 1}, + .constants = {{"kNumBins", m_num_bins}, + {"kUseFixedMask", static_cast(m_use_fixed_mask)}, + {"kUseMovingMask", static_cast(m_use_moving_mask)}}, + }, + .bindings_map = {{"uniforms", m_joint_histogram_uniforms_buffer}, + {"jointHistogram", m_raw_joint_histogram_buffer}, + {"jointHistogramSmoothed", m_smoothed_joint_histogram_buffer}}, + }); + + m_precompute_kernel = m_compute_context->new_kernel({ + .compute_shader = {.shader_source = ShaderFile{"shaders/registration/nmi.slang"}, + .entryPoint = "precompute", + .constants = {{"kNumBins", m_num_bins}, + {"kUseTargetMask", static_cast(m_use_fixed_mask)}, + {"kUseSourceMask", static_cast(m_use_moving_mask)}}}, + .bindings_map = {{"jointHistogram", m_smoothed_joint_histogram_buffer}, + {"jointHistogramMass", m_joint_histogram_mass_buffer}, + {"coefficientsTable", m_precomputed_coefficients_buffer}, + {"mutualInformation", m_mutual_information_buffer}}, + }); + + if (m_output == CalculatorOutput::CostAndGradients) { + m_gradients_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/registration/nmi.slang"}, + .entryPoint = "main", + .workgroup_size = gradientsWorkgroupSize, + .constants = {{"kNumBins", m_num_bins}, + {"kUseTargetMask", static_cast(m_use_fixed_mask)}, + {"kUseSourceMask", 
static_cast(m_use_moving_mask)}}, + .entry_point_args = {is_affine ? "AffineTransformation" : "RigidTransformation"}, + }, + .bindings_map = {{"uniforms", m_gradients_uniforms_buffer}, + {"targetTexture", m_fixed}, + {"sourceTexture", m_moving}, + {"targetMaskTexture", m_fixed_mask}, + {"sourceMaskTexture", m_moving_mask}, + {"coefficientsTable", m_precomputed_coefficients_buffer}, + {"partialSumsGradients", m_gradients_buffer}, + {"sampler", m_compute_context->new_linear_sampler()}}, + }); + } +} + +void NMICalculator::update(const GlobalTransform &transformation) { + m_compute_context->clear_buffer(m_raw_joint_histogram_buffer); + m_compute_context->clear_buffer(m_joint_histogram_mass_buffer); + + assert(transformation.param_count() == m_degrees_of_freedom); + const auto moving_dispatch_grid = DispatchGrid::element_wise_texture(m_moving, m_min_max_moving_kernel.workgroup_size); + const auto fixed_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, m_joint_histogram_kernel.workgroup_size); + + const auto transformation_matrix = transformation.to_matrix4f(); + + const Eigen::Matrix4f transformation_matrix_voxel_space = + Eigen::Map(m_voxel_scanner_matrices.scanner_to_voxel_moving.data()) * + Eigen::Map(transformation_matrix.data()) * + Eigen::Map(m_voxel_scanner_matrices.voxel_to_scanner_fixed.data()); + + const JointHistogramUniforms joint_histogram_uniforms{ + .dispatch_grid = fixed_dispatch_grid, + .intensities = m_intensities, + .transformation_matrix = EigenHelpers::to_array(transformation_matrix_voxel_space), + }; + m_compute_context->write_to_buffer( + m_joint_histogram_uniforms_buffer, &joint_histogram_uniforms, sizeof(joint_histogram_uniforms)); + m_compute_context->dispatch_kernel(m_joint_histogram_kernel, m_joint_histogram_dispatch_grid); + + const WorkgroupSize smoothWGSize{8, 8, 1}; + const DispatchGrid smooth_grid = + DispatchGrid::element_wise({size_t(m_num_bins), size_t(m_num_bins), size_t(1)}, smoothWGSize); + 
m_compute_context->dispatch_kernel(m_joint_histogram_smooth_kernel, smooth_grid); + + const uint32_t histogramSize = m_num_bins * m_num_bins; + const DispatchGrid merge_grid{.x = histogramSize}; + m_compute_context->dispatch_kernel(m_compute_total_mass_kernel, DispatchGrid{1, 1, 1}); + + // Precompute coefficients and mutual information from the smoothed histogram + m_compute_context->dispatch_kernel(m_precompute_kernel, DispatchGrid{1, 1, 1}); + + const std::array pivot_array = EigenHelpers::to_array(transformation.pivot()); + + if (m_output == CalculatorOutput::CostAndGradients) { + if (transformation.is_affine()) { + std::array params; + const auto current = transformation.parameters(); + std::copy_n(current.begin(), 12, params.begin()); + const AffineGradientsUniforms gradients_uniforms{ + .dispatch_grid = m_gradients_dispatch_grid, + .transformation_pivot = pivot_array, + .intensities = m_intensities, + .current_transform = params, + .voxel_scanner_matrices = m_voxel_scanner_matrices, + }; + + m_compute_context->write_to_buffer(m_gradients_uniforms_buffer, &gradients_uniforms, sizeof(AffineGradientsUniforms)); + } else { + std::array params; + const auto current = transformation.parameters(); + std::copy_n(current.begin(), 6, params.begin()); + const RigidGradientsUniforms gradients_uniforms{ + .dispatch_grid = m_gradients_dispatch_grid, + .transformation_pivot = pivot_array, + .intensities = m_intensities, + .current_transform = params, + .voxel_scanner_matrices = m_voxel_scanner_matrices, + }; + m_compute_context->write_to_buffer(m_gradients_uniforms_buffer, &gradients_uniforms, sizeof(RigidGradientsUniforms)); + } + + m_compute_context->dispatch_kernel(m_gradients_kernel, m_gradients_dispatch_grid); + } +} + +IterationResult NMICalculator::get_result() const { + // Negate the cost and gradients since the class is used to maximize the mutual information + // while the optimisation framework minimises the cost function. 
+ const auto mi_cost = m_compute_context->download_buffer_as_vector(m_mutual_information_buffer); + if (m_output == CalculatorOutput::Cost) { + return IterationResult{-mi_cost[0], {}}; + } + + const auto gradients_partials_f = m_compute_context->download_buffer_as_vector(m_gradients_buffer); + const std::vector gradients_partials(gradients_partials_f.begin(), gradients_partials_f.end()); + auto gradients = Utils::chunkReduce(gradients_partials, m_degrees_of_freedom, std::plus<>{}); + std::transform(gradients.begin(), gradients.end(), gradients.begin(), std::negate<>{}); + + return IterationResult{-mi_cost[0], std::vector(gradients.begin(), gradients.end())}; +} +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/nmicalculator.h b/cpp/core/gpu/registration/nmicalculator.h new file mode 100644 index 0000000000..0e6bff73d1 --- /dev/null +++ b/cpp/core/gpu/registration/nmicalculator.h @@ -0,0 +1,97 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#pragma once + +#include "gpu/registration/calculatoroutput.h" +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include +#include + +namespace MR::GPU { + +// Match shader order: source (moving) then target (fixed) +struct Intensities { + float min_moving; + float max_moving; + float min_fixed; + float max_fixed; +}; + +class NMICalculator { +public: + struct Config { + TransformationType transformation_type = TransformationType::Affine; + Texture fixed; + Texture moving; + std::optional fixed_mask; + std::optional moving_mask; + VoxelScannerMatrices voxel_scanner_matrices{}; + uint32_t num_bins = 32; + CalculatorOutput output = CalculatorOutput::CostAndGradients; + const ComputeContext *context = nullptr; + }; + explicit NMICalculator(const Config &config); + + void update(const GlobalTransform &transformation); + IterationResult get_result() const; + +private: + CalculatorOutput m_output = CalculatorOutput::CostAndGradients; + const ComputeContext *m_compute_context = nullptr; + + Buffer m_raw_joint_histogram_buffer; + Buffer m_smoothed_joint_histogram_buffer; + Buffer m_joint_histogram_mass_buffer; + Buffer m_joint_histogram_uniforms_buffer; + Buffer m_min_max_uniforms_buffer; + Buffer m_min_max_intensity_fixed_buffer; + Buffer m_min_max_intensity_moving_buffer; + Buffer m_precomputed_coefficients_buffer; + Buffer m_mutual_information_buffer; + Buffer m_gradients_uniforms_buffer; + Buffer m_gradients_buffer; + + Kernel m_min_max_moving_kernel; + Kernel m_joint_histogram_kernel; + Kernel m_joint_histogram_smooth_kernel; + Kernel m_compute_total_mass_kernel; + Kernel m_precompute_kernel; + Kernel m_gradients_kernel; + + Texture m_fixed; + Texture m_moving; + Texture m_fixed_mask; + Texture m_moving_mask; + bool m_use_fixed_mask = false; + bool m_use_moving_mask = false; + + VoxelScannerMatrices m_voxel_scanner_matrices; + Eigen::Vector3f m_centre_scanner_fixed; + 
Eigen::Vector3f m_centre_scanner_moving; + + DispatchGrid m_joint_histogram_dispatch_grid; + DispatchGrid m_gradients_dispatch_grid; + + uint32_t m_num_bins = 32; + Intensities m_intensities; + uint32_t m_degrees_of_freedom; +}; +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/registrationtypes.cpp b/cpp/core/gpu/registration/registrationtypes.cpp new file mode 100644 index 0000000000..f65aa6b954 --- /dev/null +++ b/cpp/core/gpu/registration/registrationtypes.cpp @@ -0,0 +1,293 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "gpu/registration/registrationtypes.h" +#include +#include "types.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace MR { +// using transform_type = Eigen::Transform; +namespace { +constexpr size_t param_count_for_type(TransformationType type) { return type == TransformationType::Rigid ? 
6U : 12U; } +} // namespace + +TransformationType GlobalTransform::type() const { return m_type; } + +tcb::span GlobalTransform::parameters() const { + return tcb::span(m_params.data(), m_param_count); +} + +void GlobalTransform::set_params(tcb::span params) { + const size_t expected = param_count_for_type(m_type); + if (params.size() != expected) { + throw std::invalid_argument("Parameter count does not match transformation type."); + } + std::copy(params.begin(), params.end(), m_params.begin()); + m_param_count = expected; +} + +Eigen::Vector3f GlobalTransform::pivot() const { return m_pivot; } + +void GlobalTransform::set_pivot(const Eigen::Vector3f &pivot) { m_pivot = pivot; } + +transform_type GlobalTransform::to_affine_compact() const { + const Eigen::Translation3f translateToPivot(-m_pivot.x(), -m_pivot.y(), -m_pivot.z()); + const Eigen::Translation3f translateFromPivot(m_pivot.x(), m_pivot.y(), m_pivot.z()); + + Eigen::Transform shear = Eigen::Transform::Identity(); + if (is_affine()) { + // The linear part of the shear matrix: + // [ 1 sh_xy sh_xz ] + // [ 0 1 sh_yz ] + // [ 0 0 1 ] + shear.matrix()(0, 1) = m_params[9]; + shear.matrix()(0, 2) = m_params[10]; + shear.matrix()(1, 2) = m_params[11]; + } + + Eigen::Transform scale = Eigen::Transform::Identity(); + if (is_affine()) { + scale = Eigen::Scaling(m_params[6], m_params[7], m_params[8]); + } + + const Eigen::Vector3f rotationAxisAngleVec(m_params[3], m_params[4], m_params[5]); + const float angle = rotationAxisAngleVec.norm(); + Eigen::AngleAxisf rotation = Eigen::AngleAxisf::Identity(); + if (angle != 0.0f) { + rotation = Eigen::AngleAxisf(angle, rotationAxisAngleVec / angle); + } + + const Eigen::Translation3f globalTranslation(m_params[0], m_params[1], m_params[2]); + + // Combine transformations in the correct order: + // M_final = M6 * M5 * M4 * M3 * M2 * M1 + // (Applied to a point P as M_final * P) + const Eigen::Transform final_affine_transform = + globalTranslation * translateFromPivot * 
rotation * scale * shear * translateToPivot; + + return transform_type(final_affine_transform); +} + +GlobalTransform::GlobalTransform(tcb::span params, TransformationType type, const Eigen::Vector3f &pivot) + : m_type(type), m_pivot(pivot) { + const size_t expected = param_count_for_type(type); + if (params.size() != expected) { + throw std::invalid_argument("Parameter count does not match transformation type."); + } + m_params.fill(0.0F); + std::copy(params.begin(), params.end(), m_params.begin()); + m_param_count = expected; +} + +Eigen::Matrix4f GlobalTransform::to_matrix4f() const { + transform_type transform = to_affine_compact(); + Eigen::Matrix4f matrix = Eigen::Matrix4f::Identity(); + matrix.block<3, 4>(0, 0) = transform.matrix().template cast(); + return matrix; +} + +GlobalTransform GlobalTransform::inverse() const { + const auto eigenTransform = to_affine_compact(); + const auto inverseEigenTransform = eigenTransform.inverse(); + return GlobalTransform::from_affine_compact(inverseEigenTransform, m_pivot, m_type); +} + +GlobalTransform GlobalTransform::with_pivot(const Eigen::Vector3f &pivot) const { + return GlobalTransform(parameters(), m_type, pivot); +} +// Keeps translation and axis-angle rotation while dropping scale/shear. +GlobalTransform GlobalTransform::as_rigid() const { + if (is_rigid()) { + return *this; + } + const std::array rigid_params{m_params[0], m_params[1], m_params[2], m_params[3], m_params[4], m_params[5]}; + return GlobalTransform(rigid_params, TransformationType::Rigid, m_pivot); +} + +// Ensures scale defaults to 1 and shear to 0 when promoting a rigid transform to affine. 
+GlobalTransform GlobalTransform::as_affine() const { + if (is_affine()) { + return *this; + } + std::array affine_params{}; + const auto current = parameters(); + std::copy(current.begin(), current.end(), affine_params.begin()); + affine_params[6] = 1.0F; + affine_params[7] = 1.0F; + affine_params[8] = 1.0F; + return GlobalTransform(affine_params, TransformationType::Affine, m_pivot); +} + +bool GlobalTransform::is_rigid() const { return m_type == TransformationType::Rigid; } + +bool GlobalTransform::is_affine() const { return m_type == TransformationType::Affine; } + +size_t GlobalTransform::param_count() const { return m_param_count; } + +void GlobalTransform::set_translation(const Eigen::Vector3f &translation) { + m_params[0] = translation.x(); + m_params[1] = translation.y(); + m_params[2] = translation.z(); +} + +Eigen::Vector3f GlobalTransform::translation() const { return Eigen::Vector3f(m_params[0], m_params[1], m_params[2]); } + +void GlobalTransform::set_rotation(const Eigen::Vector3f &rotation_axis_angle) { + m_params[3] = rotation_axis_angle.x(); + m_params[4] = rotation_axis_angle.y(); + m_params[5] = rotation_axis_angle.z(); +} + +Eigen::Vector3f GlobalTransform::rotation() const { return Eigen::Vector3f(m_params[3], m_params[4], m_params[5]); } + +void GlobalTransform::set_scale(const Eigen::Vector3f &scale) { + if (is_rigid()) { + throw std::logic_error("Scale is only valid for affine transforms."); + } + m_params[6] = scale.x(); + m_params[7] = scale.y(); + m_params[8] = scale.z(); +} + +Eigen::Vector3f GlobalTransform::scale() const { + if (is_rigid()) { + return Eigen::Vector3f(1.0F, 1.0F, 1.0F); + } + return Eigen::Vector3f(m_params[6], m_params[7], m_params[8]); +} + +void GlobalTransform::set_shear(const Eigen::Vector3f &shear) { + if (is_rigid()) { + throw std::logic_error("Shear is only valid for affine transforms."); + } + m_params[9] = shear.x(); + m_params[10] = shear.y(); + m_params[11] = shear.z(); +} + +Eigen::Vector3f 
GlobalTransform::shear() const { + if (is_rigid()) { + return Eigen::Vector3f::Zero(); + } + return Eigen::Vector3f(m_params[9], m_params[10], m_params[11]); +} + +// Decomposes a 4x4 affine transformation matrix back into its constituent parameters +// (translation, rotation, scale, shear) defined relative to a pivot point. +// +// The forward transformation is composed as: p' = T_global * T_to_pivot * R * S * Sh * T_from_pivot * p +// This can be expressed as a standard affine matrix: p' = (LinearPart * p) + TranslationPart + +// The full translation vector is derived from: T_full = T_global - LinearPart*pivot + pivot +// Rearranging gives: T_global = T_full - pivot + LinearPart*pivot +// +// Linear Part Decomposition (to find R, S, Sh): +// The linear part is a product: LinearPart = R * (S * Sh). The (Scale * Shear) term +// forms an upper-triangular matrix U. This means LinearPart = R * U. +// A QR decomposition splits a matrix into an +// orthogonal matrix Q (our rotation R) and an upper-triangular matrix R_qr (U). +// For Affine (N=12): We perform the QR decomposition. Scale values are on the diagonal +// of R_qr, and shear values are the normalized off-diagonals. +// For Rigid (N=6): S and Sh are identity matrices, so the LinearPart is already the +// pure rotation matrix R. No decomposition is needed. + +// TODO: we need to write unit tests for this function. +// TODO: also should we use Eigen::ColPivHouseholderQR instead? 
+GlobalTransform GlobalTransform::from_affine_compact(const transform_type &transform, + const Eigen::Vector3f &pivot, + TransformationType type) { + using Scalar = typename transform_type::Scalar; + using Vector3 = Eigen::Matrix; + using Matrix3 = Eigen::Matrix; + + const Vector3 pivotVector = pivot.template cast(); + + const Matrix3 linearPart = transform.linear(); + const Vector3 translationPart = transform.translation(); + const Vector3 globalTranslation = translationPart - pivotVector + linearPart * pivotVector; + + std::vector parameters; + const size_t N = (type == TransformationType::Rigid) ? 6 : 12; + parameters.resize(N, 0.0F); + + parameters[0] = static_cast(globalTranslation.x()); + parameters[1] = static_cast(globalTranslation.y()); + parameters[2] = static_cast(globalTranslation.z()); + + Matrix3 rotationMatrix; + + if (N == 12) { + // Decompose the linear part using QR decomposition + const Eigen::HouseholderQR qr(linearPart); + rotationMatrix = qr.householderQ(); + Matrix3 upperTriangularPart = qr.matrixQR().template triangularView(); + + // Ensure the result is a proper rotation matrix (determinant = +1), not a reflection. + if (rotationMatrix.determinant() < Scalar(0)) { + rotationMatrix.col(0) *= -1; + upperTriangularPart.row(0) *= -1; + } + + // Force positive diagonal on R as we don't want negative scales. + for (int i = 0; i < 3; ++i) { + if (upperTriangularPart(i, i) < Scalar(0)) { + rotationMatrix.col(i) *= -1; + upperTriangularPart.row(i) *= -1; + } + } + + // Ensure Q is proper after diagonal fix + if (rotationMatrix.determinant() < Scalar(0)) { + rotationMatrix.col(2) *= -1; // flip one column (e.g. 
the last) + upperTriangularPart.row(2) *= -1; // and the matching row in R + } + + // Extract scale and shear from the upper triangular matrix + const Scalar scaleX = upperTriangularPart(0, 0); + const Scalar scaleY = upperTriangularPart(1, 1); + const Scalar scaleZ = upperTriangularPart(2, 2); + parameters[6] = static_cast(scaleX); + parameters[7] = static_cast(scaleY); + parameters[8] = static_cast(scaleZ); + + parameters[9] = (scaleX != 0) ? static_cast(upperTriangularPart(0, 1) / scaleX) : 0.0F; + parameters[10] = (scaleX != 0) ? static_cast(upperTriangularPart(0, 2) / scaleX) : 0.0F; + parameters[11] = (scaleY != 0) ? static_cast(upperTriangularPart(1, 2) / scaleY) : 0.0F; + + } else { + // For the rigid case the linear part is the rotation matrix. No decomposition is needed. + rotationMatrix = linearPart; + } + + Eigen::AngleAxis angleAxis(rotationMatrix); + const Vector3 axisAngleVector = angleAxis.axis() * angleAxis.angle(); + parameters[3] = static_cast(axisAngleVector.x()); + parameters[4] = static_cast(axisAngleVector.y()); + parameters[5] = static_cast(axisAngleVector.z()); + + return GlobalTransform(parameters, type, pivot); +} +} // namespace MR diff --git a/cpp/core/gpu/registration/registrationtypes.h b/cpp/core/gpu/registration/registrationtypes.h new file mode 100644 index 0000000000..06c276db01 --- /dev/null +++ b/cpp/core/gpu/registration/registrationtypes.h @@ -0,0 +1,154 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. 
+ * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "image.h" +#include +#include "types.h" + +#include +#include +#include +#include +#include +#include + +namespace MR { + +enum class TransformationType : uint8_t { Rigid, Affine }; + +// The parameters in order: +// - 3 translations +// - 3 rotations (axis-angle representation) +// - 3 scaling factors +// - 3 shearing factors +// Order of application: shear, scale, rotate, translates. +// All operations are assumed to be applied by taking the pivot +// point as the centre of the transformation. +struct GlobalTransform { + + // Throws if params.size() does not match type + explicit GlobalTransform(tcb::span params, + TransformationType type, + const Eigen::Vector3f &pivot = Eigen::Vector3f::Zero()); + + GlobalTransform inverse() const; + // Obtain a copy with a different pivot + GlobalTransform with_pivot(const Eigen::Vector3f &pivot) const; + + // Create GlobalTransform from an Eigen transform and pivot. + // If type is specified as rigid, scale and shear components are ignored. + static GlobalTransform from_affine_compact(const transform_type &tform, + const Eigen::Vector3f &pivot, + TransformationType type = TransformationType::Affine); + + // Returns a copy that keeps translation and axis-angle rotation, dropping any scale/shear terms. + GlobalTransform as_rigid() const; + // Returns a copy that includes all affine params with rigid inputs being identity scale and + // zero shear appended. 
+ GlobalTransform as_affine() const; + + TransformationType type() const; + + tcb::span parameters() const; + void set_params(tcb::span params); + + Eigen::Vector3f pivot() const; + void set_pivot(const Eigen::Vector3f &pivot); + + // Obtain a 3x4 Eigen affine-compact transform + transform_type to_affine_compact() const; + Eigen::Matrix4f to_matrix4f() const; + + bool is_rigid() const; + bool is_affine() const; + size_t param_count() const; + + Eigen::Vector3f translation() const; + void set_translation(const Eigen::Vector3f &translation); + + void set_rotation(const Eigen::Vector3f &rotation_axis_angle); + Eigen::Vector3f rotation() const; + + // For rigid case, scale defaults to (1,1,1) and shears to (0,0,0). + // Setters throw if called on rigid transforms. + Eigen::Vector3f scale() const; + void set_scale(const Eigen::Vector3f &scale); + + Eigen::Vector3f shear() const; + void set_shear(const Eigen::Vector3f &shear); + +private: + TransformationType m_type; + // We allocate space for the maximum number of parameters. + std::array m_params{}; + size_t m_param_count = 0U; + Eigen::Vector3f m_pivot; +}; + +struct IterationResult { + float cost; + std::vector gradients; +}; + +struct NMIMetric { + uint32_t num_bins = 32; +}; + +struct SSDMetric {}; + +struct NCCMetric { + uint32_t window_radius = 0U; +}; + +using Metric = std::variant; + +enum class MetricType : uint8_t { NMI, SSD, NCC }; +enum class InitTranslationChoice : uint8_t { None, Mass, Geometric }; +enum class InitRotationChoice : uint8_t { None, Search, Moments }; + +struct InitialisationOptions { + InitTranslationChoice translation_choice = InitTranslationChoice::Mass; + InitRotationChoice rotation_choice = InitRotationChoice::None; + Metric cost_metric = NMIMetric{}; + // Limits the maximum sampled rotation angle (degrees) for search-based initialisation. 
+ float max_search_angle_degrees = 90.0F; +}; + +using InitialGuess = std::variant; + +struct ChannelConfig { + Image image1; + Image image2; + std::optional> image1Mask; + std::optional> image2Mask; + float weight = 1.0F; +}; + +struct RegistrationConfig { + std::vector channels; + TransformationType transformation_type; + InitialGuess initial_guess; + Metric metric; + uint32_t max_iterations = 500; +}; + +struct RegistrationResult { + transform_type transformation; +}; + +} // namespace MR diff --git a/cpp/core/gpu/registration/ssdcalculator.cpp b/cpp/core/gpu/registration/ssdcalculator.cpp new file mode 100644 index 0000000000..1ce5b03f52 --- /dev/null +++ b/cpp/core/gpu/registration/ssdcalculator.cpp @@ -0,0 +1,163 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. 
+ */ + +#include "gpu/registration/ssdcalculator.h" + +#include "gpu/registration/calculatoroutput.h" +#include "gpu/registration/eigenhelpers.h" +#include "exception.h" +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { + +template struct SSDUniforms { + alignas(16) DispatchGrid dispatch_grid{}; + alignas(16) std::array transformationPivot{}; + alignas(16) std::array currentTransform{}; + alignas(16) VoxelScannerMatrices voxelScannerMatrices{}; +}; + +using RigidSSDUniforms = SSDUniforms<6>; +using AffineSSDUniforms = SSDUniforms<12>; +static_assert(sizeof(RigidSSDUniforms) % 16 == 0, "RigidSSDUniforms must be 16-byte aligned"); +static_assert(sizeof(AffineSSDUniforms) % 16 == 0, "AffineSSDUniforms must be 16-byte aligned"); + +constexpr WorkgroupSize ssd_workgroup_size{8, 8, 4}; + +SSDCalculator::SSDCalculator(const Config &config) + : m_output(config.output), + m_compute_context(config.context), + m_fixed(config.fixed), + m_moving(config.moving), + m_fixed_mask(config.fixed_mask.value_or(config.fixed)), + m_moving_mask(config.moving_mask.value_or(config.moving)), + m_use_fixed_mask(config.fixed_mask.has_value()), + m_use_moving_mask(config.moving_mask.has_value()), + m_voxel_scanner_matrices(config.voxel_scanner_matrices) { + assert(m_compute_context != nullptr); + const bool is_rigid = config.transformation_type == TransformationType::Rigid; + m_degrees_of_freedom = is_rigid ? 6U : 12U; + + m_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, ssd_workgroup_size); + + const uint32_t uniforms_size = is_rigid ? 
sizeof(RigidSSDUniforms) : sizeof(AffineSSDUniforms); + m_uniforms_buffer = m_compute_context->new_empty_buffer(uniforms_size, BufferType::UniformBuffer); + + const size_t params_per_workgroup = 1U + m_degrees_of_freedom; + m_partials_buffer = m_compute_context->new_empty_buffer(params_per_workgroup * m_dispatch_grid.workgroup_count()); + m_num_contributing_voxels_buffer = m_compute_context->new_empty_buffer(1); + + m_kernel = m_compute_context->new_kernel({ + .compute_shader = + { + .shader_source = ShaderFile{"shaders/registration/ssd.slang"}, + .entryPoint = "main", + .workgroup_size = ssd_workgroup_size, + .constants = {{"kUseSourceMask", static_cast(m_use_moving_mask)}, + {"kUseTargetMask", static_cast(m_use_fixed_mask)}, + {"kComputeGradients", + static_cast(m_output == CalculatorOutput::CostAndGradients)}}, + .entry_point_args = {is_rigid ? "RigidTransformation" : "AffineTransformation"}, + }, + .bindings_map = {{"uniforms", m_uniforms_buffer}, + {"sourceImage", m_moving}, + {"targetImage", m_fixed}, + {"sourceMask", m_moving_mask}, + {"targetMask", m_fixed_mask}, + {"linearSampler", m_compute_context->new_linear_sampler()}, + {"ssdAndGradientsPartials", m_partials_buffer}, + {"numContributingVoxels", m_num_contributing_voxels_buffer}}, + }); +} +void SSDCalculator::update(const GlobalTransform &transformation) { + assert(transformation.param_count() == m_degrees_of_freedom); + m_compute_context->clear_buffer(m_num_contributing_voxels_buffer); + + const std::array pivotArray = EigenHelpers::to_array(transformation.pivot()); + if (transformation.is_affine()) { + std::array params; + const auto current = transformation.parameters(); + std::copy_n(current.begin(), 12, params.begin()); + const AffineSSDUniforms uniforms{ + .dispatch_grid = m_dispatch_grid, + .transformationPivot = pivotArray, + .currentTransform = params, + .voxelScannerMatrices = m_voxel_scanner_matrices, + }; + m_compute_context->write_to_buffer(m_uniforms_buffer, &uniforms, 
sizeof(AffineSSDUniforms)); + } else { + std::array params; + const auto current = transformation.parameters(); + std::copy_n(current.begin(), 6, params.begin()); + const RigidSSDUniforms uniforms{ + .dispatch_grid = m_dispatch_grid, + .transformationPivot = pivotArray, + .currentTransform = params, + .voxelScannerMatrices = m_voxel_scanner_matrices, + }; + m_compute_context->write_to_buffer(m_uniforms_buffer, &uniforms, sizeof(RigidSSDUniforms)); + } + + m_compute_context->dispatch_kernel(m_kernel, m_dispatch_grid); +} + +IterationResult SSDCalculator::get_result() const { + const auto partials = m_compute_context->download_buffer_as_vector(m_partials_buffer); + const size_t paramsPerWorkgroup = 1U + m_degrees_of_freedom; + const size_t workgroups = m_dispatch_grid.workgroup_count(); + if (partials.size() < paramsPerWorkgroup * workgroups) { + throw MR::Exception("SSDCalculator: partials buffer size mismatch."); + } + + double cost = 0.0; + const bool computeGradients = m_output == CalculatorOutput::CostAndGradients; + std::vector gradients; + if (computeGradients) { + gradients.assign(m_degrees_of_freedom, 0.0); + } + for (size_t wg = 0; wg < workgroups; ++wg) { + const size_t base = wg * paramsPerWorkgroup; + cost += partials[base]; + if (computeGradients) { + for (size_t i = 0; i < m_degrees_of_freedom; ++i) { + gradients[i] += partials[base + 1 + i]; + } + } + } + + if (!computeGradients) { + return IterationResult{static_cast(cost), {}}; + } + + std::vector gradientsF; + gradientsF.reserve(m_degrees_of_freedom); + for (double value : gradients) { + gradientsF.push_back(static_cast(value)); + } + + return IterationResult{static_cast(cost), std::move(gradientsF)}; +} +} // namespace MR::GPU diff --git a/cpp/core/gpu/registration/ssdcalculator.h b/cpp/core/gpu/registration/ssdcalculator.h new file mode 100644 index 0000000000..078aa4a97f --- /dev/null +++ b/cpp/core/gpu/registration/ssdcalculator.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2008-2025 the MRtrix3 
contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "gpu/registration/calculatoroutput.h" +#include "gpu/gpu.h" +#include "gpu/registration/registrationtypes.h" +#include "gpu/registration/voxelscannermatrices.h" + +#include +#include + +namespace MR::GPU { +class SSDCalculator { +public: + struct Config { + TransformationType transformation_type = TransformationType::Affine; + Texture fixed; + Texture moving; + std::optional fixed_mask; + std::optional moving_mask; + VoxelScannerMatrices voxel_scanner_matrices{}; + CalculatorOutput output = CalculatorOutput::CostAndGradients; + const ComputeContext *context = nullptr; + }; + explicit SSDCalculator(const Config &config); + + void update(const GlobalTransform &transformation); + IterationResult get_result() const; + +private: + CalculatorOutput m_output = CalculatorOutput::CostAndGradients; + const ComputeContext *m_compute_context = nullptr; + Buffer m_uniforms_buffer; + Buffer m_partials_buffer; + Buffer m_num_contributing_voxels_buffer; + + Kernel m_kernel; + + Texture m_fixed; + Texture m_moving; + Texture m_fixed_mask; + Texture m_moving_mask; + bool m_use_fixed_mask = false; + bool m_use_moving_mask = false; + + DispatchGrid m_dispatch_grid; + VoxelScannerMatrices m_voxel_scanner_matrices; + uint32_t m_degrees_of_freedom = 0; +}; +} // namespace MR::GPU diff --git 
a/cpp/core/gpu/registration/utils.cpp b/cpp/core/gpu/registration/utils.cpp new file mode 100644 index 0000000000..474b8319a8 --- /dev/null +++ b/cpp/core/gpu/registration/utils.cpp @@ -0,0 +1,144 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#include "gpu/registration/utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) +#include +#elif defined(__APPLE__) +#include +#elif defined(__linux__) +#endif + +using namespace std::string_literals; + +uint32_t Utils::nextMultipleOf(const uint32_t value, const uint32_t multiple) { + if (value > std::numeric_limits::max() - multiple) { + return std::numeric_limits::max(); + } + return (value + multiple - 1) / multiple * multiple; +} + +std::string Utils::readFile(const std::filesystem::path &filePath, ReadFileMode mode) { + if (!std::filesystem::exists(filePath)) { + throw std::runtime_error("File not found: "s + filePath.string()); + } + + const auto openMode = (mode == ReadFileMode::Binary) ? 
std::ios::in | std::ios::binary : std::ios::in; + std::ifstream f(filePath, std::ios::in | openMode); + const auto fileSize64 = std::filesystem::file_size(filePath); + if (fileSize64 > static_cast(std::numeric_limits::max())) { + throw std::runtime_error("File too large to read into memory: "s + filePath.string()); + } + const std::streamsize fileSize = static_cast(fileSize64); + std::string result(static_cast(fileSize), '\0'); + f.read(result.data(), fileSize); + + return result; +} + +std::filesystem::path Utils::getExecutablePath() { +#if defined(_WIN32) + wchar_t buffer[MAX_PATH]; + const DWORD len = GetModuleFileNameW(NULL, buffer, MAX_PATH); + if (len == 0) { + throw std::runtime_error("GetModuleFileNameW failed. Error: " + std::to_string(GetLastError())); + } + if (len == MAX_PATH) { + throw std::runtime_error("GetModuleFileNameW failed: Buffer too small, path truncated."); + } + return std::filesystem::path(buffer); // buffer content is copied + +#elif defined(__APPLE__) + uint32_t size = 0; + if (_NSGetExecutablePath(nullptr, &size) != -1) { + throw std::runtime_error("_NSGetExecutablePath: Unexpected behavior when querying buffer size."); + } + std::vector bufferVec(size); + if (_NSGetExecutablePath(bufferVec.data(), &size) != 0) { + throw std::runtime_error("_NSGetExecutablePath: Failed to retrieve executable path."); + } + const std::filesystem::path initialPath(bufferVec.data()); + std::error_code ec; + const std::filesystem::path canonicalPath = std::filesystem::canonical(initialPath, ec); + if (ec) { + throw std::runtime_error("Failed to get canonical path for '" + initialPath.string() + "': " + ec.message()); + } + return canonicalPath; + +#elif defined(__linux__) + const std::string linkPathStr = "/proc/self/exe"; // const as it's not modified + std::error_code ec; + + const std::filesystem::path symlinkPath(linkPathStr); + const std::filesystem::path p = std::filesystem::read_symlink(symlinkPath, ec); + + if (ec) { + throw 
std::runtime_error("read_symlink(\"" + linkPathStr + "\") failed: " + ec.message()); + } + const std::filesystem::path canonicalP = std::filesystem::canonical(p, ec); + if (ec) { + throw std::runtime_error("canonical(\"" + p.string() + "\") failed: " + ec.message()); + } + return canonicalP; + +#else +#error Unsupported platform +#endif +} + +std::string Utils::randomString(size_t length) { + static const std::string characterSet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + + // 1. A static random engine that is seeded only once. + static std::mt19937 generator = []() { + std::random_device rd; + return std::mt19937(rd()); + }(); + + // 2. A distribution that maps to the indices of the character set. + // We create it once and reuse it. + static std::uniform_int_distribution distribution(0, characterSet.size() - 1); + + std::string result; + result.reserve(length); + + // 3. Fill the string with random characters. + for (size_t i = 0; i < length; ++i) { + result += characterSet[distribution(generator)]; + } + + return result; +} + +std::string Utils::hash_string(const std::string &input) { + const std::hash hasher; + const size_t hashValue = hasher(input); + return std::to_string(hashValue); +} diff --git a/cpp/core/gpu/registration/utils.h b/cpp/core/gpu/registration/utils.h new file mode 100644 index 0000000000..a4cc7de19e --- /dev/null +++ b/cpp/core/gpu/registration/utils.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include +#include +#include +#include + +namespace Utils { + +// Divides the input vector into equal-sized rows (each row having "chunkSize" elements) +// and then performs a column-wise accumulation using the provided binary operator. +// e.g. { 1, 2, 3, 4, 5, 6 } with chunkSize = 2, we form the "matrix" +// [1, 3, 5] +// [2, 4, 6] +// and then perform the operation on each column. +template +std::vector chunkReduce(const std::vector &data, size_t chunkSize, BinaryOp op) { + if (chunkSize == 0) { + throw std::invalid_argument("chunkSize cannot be zero."); + } + if (data.size() % chunkSize != 0) { + throw std::invalid_argument("vector size must be a multiple of chunkSize."); + } + + const size_t numRows = data.size() / chunkSize; + std::vector result(chunkSize, T{}); + + for (size_t row = 0; row < numRows; ++row) { + for (size_t col = 0; col < chunkSize; ++col) { + result[col] = op(result[col], data[row * chunkSize + col]); + } + } + return result; +} + +// Returns the smallest multiple of `multiple` that is greater or equal to `value`. 
+uint32_t nextMultipleOf(const uint32_t value, const uint32_t multiple); + +enum ReadFileMode { Text, Binary }; +std::string readFile(const std::filesystem::path &filePath, ReadFileMode mode = ReadFileMode::Text); + +std::filesystem::path getExecutablePath(); + +std::string randomString(size_t length); + +std::string hash_string(const std::string &input); +} // namespace Utils diff --git a/cpp/core/gpu/registration/voxelscannermatrices.h b/cpp/core/gpu/registration/voxelscannermatrices.h new file mode 100644 index 0000000000..a1d8ce20b6 --- /dev/null +++ b/cpp/core/gpu/registration/voxelscannermatrices.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2008-2025 the MRtrix3 contributors. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * Covered Software is provided under this License on an "as is" + * basis, without warranty of any kind, either expressed, implied, or + * statutory, including, without limitation, warranties that the + * Covered Software is free of defects, merchantable, fit for a + * particular purpose or non-infringing. + * See the Mozilla Public License v. 2.0 for more details. + * + * For more details, see http://www.mrtrix.org/. + */ + +#pragma once + +#include "gpu/registration/eigenhelpers.h" +#include "image.h" +#include "transform.h" +#include + +namespace MR::GPU { + +// This struct provides 4x4 matrices for converting between voxel and scanner spaces +// for both moving and fixed images. It's designed for use in GPU buffers. 
+ +struct alignas(16) VoxelScannerMatrices { + std::array voxel_to_scanner_moving; + std::array voxel_to_scanner_fixed; + std::array scanner_to_voxel_moving; + std::array scanner_to_voxel_fixed; + + static VoxelScannerMatrices from_image_pair(const Image &moving, const Image &fixed, + float scale_factor = 1.0F) { + const Eigen::Matrix4f scale_matrix = EigenHelpers::make_scaling_mat4f(scale_factor); + + const auto moving_transform = MR::Transform(moving); + const auto fixed_transform = MR::Transform(fixed); + + const Eigen::Matrix4f voxel_to_scanner_moving_mat = + EigenHelpers::to_homogeneous_mat4f(moving_transform.voxel2scanner) * scale_matrix; + const Eigen::Matrix4f voxel_to_scanner_fixed_mat = + EigenHelpers::to_homogeneous_mat4f(fixed_transform.voxel2scanner) * scale_matrix; + + const Eigen::Matrix4f scanner_to_voxel_moving_mat = voxel_to_scanner_moving_mat.inverse(); + const Eigen::Matrix4f scanner_to_voxel_fixed_mat = voxel_to_scanner_fixed_mat.inverse(); + + return VoxelScannerMatrices{.voxel_to_scanner_moving = EigenHelpers::to_array(voxel_to_scanner_moving_mat), + .voxel_to_scanner_fixed = EigenHelpers::to_array(voxel_to_scanner_fixed_mat), + .scanner_to_voxel_moving = EigenHelpers::to_array(scanner_to_voxel_moving_mat), + .scanner_to_voxel_fixed = EigenHelpers::to_array(scanner_to_voxel_fixed_mat)}; + } +}; + +} // namespace MR::GPU diff --git a/cpp/core/gpu/shaders/atomic_utils.slang b/cpp/core/gpu/shaders/atomic_utils.slang new file mode 100644 index 0000000000..2f31462277 --- /dev/null +++ b/cpp/core/gpu/shaders/atomic_utils.slang @@ -0,0 +1,17 @@ +module atomic_utils; + +// The use of __ref is necessary to pass a reference to an Atomic object. +// This is undocumented behaviour. 
+// See // https://github.com/shader-slang/slang/issues/5941#issuecomment-2564397693 +public float atomicAddF32InMemory(__ref Atomic sum, float value) +{ + for(uint oldBits = sum.load(); ; ) + { + uint newBits = asuint(asfloat(oldBits) + value); + uint prevBits = sum.compareExchange(oldBits, newBits); + if (prevBits == oldBits) { + return asfloat(newBits); + } + oldBits = prevBits; + } +} diff --git a/cpp/core/gpu/shaders/center_of_mass.slang b/cpp/core/gpu/shaders/center_of_mass.slang new file mode 100644 index 0000000000..59da3055c7 --- /dev/null +++ b/cpp/core/gpu/shaders/center_of_mass.slang @@ -0,0 +1,63 @@ +import atomic_utils; + +extern static const uint32_t kWorkgroupSizeX = 8; +extern static const uint32_t kWorkgroupSizeY = 8; +extern static const uint32_t kWorkgroupSizeZ = 4; +// TODO: Use extern static const bool once the Slang compiler bug is fixed. +extern static const uint32_t kUseMask; + +static const uint32_t kWorkgroupInvocations = kWorkgroupSizeX * kWorkgroupSizeY * kWorkgroupSizeZ; + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void main( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t localIndex: SV_GroupIndex, + RWStructuredBuffer> totalIntensity, + RWStructuredBuffer> weightedPositions, + Texture3D image, + Texture3D mask) +{ + static groupshared float localTotalIntensity[kWorkgroupInvocations]; + static groupshared float3 localWeightedPositions[kWorkgroupInvocations]; + + var imageDim : uint32_t3; + image.GetDimensions(imageDim.x, imageDim.y, imageDim.z); + + localWeightedPositions[localIndex] = float3(0.0f, 0.0f, 0.0f); + localTotalIntensity[localIndex] = 0.0f; + + if (all(globalId < imageDim)) { + var includeVoxel = true; + if (kUseMask != 0U) { + let maskValue = mask.Load(int4(globalId, 0)).r; + if (maskValue < 0.5f) { + includeVoxel = false; + } + } + + if (includeVoxel) { + let voxelValue = image.Load(int4(globalId, 0)).r; + let voxelCoord = float3(globalId); + 
localWeightedPositions[localIndex] = voxelValue * voxelCoord; + localTotalIntensity[localIndex] = voxelValue; + } + } + + GroupMemoryBarrierWithGroupSync(); + + for (uint offset = kWorkgroupInvocations / 2; offset > 0; offset /= 2) { + if (localIndex < offset) { + localWeightedPositions[localIndex] += localWeightedPositions[localIndex + offset]; + localTotalIntensity[localIndex] += localTotalIntensity[localIndex + offset]; + } + GroupMemoryBarrierWithGroupSync(); + } + + if (localIndex == 0) { + atomicAddF32InMemory(weightedPositions[0], localWeightedPositions[0].x); + atomicAddF32InMemory(weightedPositions[1], localWeightedPositions[0].y); + atomicAddF32InMemory(weightedPositions[2], localWeightedPositions[0].z); + atomicAddF32InMemory(totalIntensity[0], localTotalIntensity[0]); + } +} diff --git a/cpp/core/gpu/shaders/downsample_image.slang b/cpp/core/gpu/shaders/downsample_image.slang new file mode 100644 index 0000000000..42bddbb664 --- /dev/null +++ b/cpp/core/gpu/shaders/downsample_image.slang @@ -0,0 +1,54 @@ +// This shader downsamples by a factor of 2 in each dimension using a 4-tap binomial filter. +// TODO: Consider using shared memory for better performance. +// TODO: Handle anisotropic downsampling factors. 
+extern static const uint kWorkgroupSizeX = 8;
+extern static const uint kWorkgroupSizeY = 8;
+extern static const uint kWorkgroupSizeZ = 4;
+
+// One thread per OUTPUT voxel: each thread gathers a 4x4x4 neighbourhood of the
+// input (taps at offsets -1..+2 around 2*id), weights it with the separable
+// 1-3-3-1 binomial kernel, and writes the normalised average.
+[shader("compute")]
+[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)]
+void main(
+    uint3 id: SV_DispatchThreadID,
+    Texture3D inputTexture,
+    WTexture3D outputTexture)
+{
+    uint3 inputDims;
+    inputTexture.GetDimensions(inputDims.x, inputDims.y, inputDims.z);
+    uint3 outputDims;
+    outputTexture.GetDimensions(outputDims.x, outputDims.y, outputDims.z);
+
+    // Threads dispatched beyond the output grid have nothing to do.
+    if (any(id >= outputDims))
+    {
+        return;
+    }
+
+    // Centre of the 4-tap footprint in input space (factor-of-2 downsampling).
+    let baseCoord = 2 * id;
+
+    // Binomial coefficients for a 4-tap filter
+    let weightsX = float4(1.0, 3.0, 3.0, 1.0);
+    let weightsY = float4(1.0, 3.0, 3.0, 1.0);
+    let weightsZ = float4(1.0, 3.0, 3.0, 1.0);
+
+    var accum = 0.0;
+    var totalWeight = 0.0;
+    for(var dz = -1; dz <= 2; dz++) {
+        let wz = weightsZ[dz + 1];
+        // clamp() replicates the border voxel, so edge outputs keep full weight.
+        let z = clamp(int32_t(baseCoord.z) + dz, 0, int32_t(inputDims.z) - 1);
+        for(var dy = -1; dy <= 2; dy++) {
+            let wy = weightsY[dy + 1];
+            let y = clamp(int32_t(baseCoord.y) + dy, 0, int32_t(inputDims.y) - 1);
+            for(var dx = -1; dx <= 2; dx++) {
+                let wx = weightsX[dx + 1];
+                let x = clamp(int32_t(baseCoord.x) + dx, 0, int32_t(inputDims.x) - 1);
+
+                let weight = wx * wy * wz;
+                accum += weight * inputTexture[uint3(x, y, z)];
+                totalWeight += weight;
+            }
+        }
+    }
+
+    // Every tap contributes (clamped addressing), so totalWeight is the constant
+    // (1+3+3+1)^3 = 512; dividing normalises the kernel to unit gain.
+    let result = accum / totalWeight;
+
+    outputTexture.Store(id, result);
+}
diff --git a/cpp/core/gpu/shaders/parzen_binner.slang b/cpp/core/gpu/shaders/parzen_binner.slang
new file mode 100644
index 0000000000..8253c91a9f
--- /dev/null
+++ b/cpp/core/gpu/shaders/parzen_binner.slang
@@ -0,0 +1,88 @@
+import cubic_bspline;
+
+// Parzen binner helper as a templated struct.
+// Provides intensity -> bin mapping, neighbourhood computation and
+// precomputation of cubic B-spline weights and derivatives.
+ +struct ParzenBinner +{ + // Small epsilon for numeric stability + static const float epsilon = 1e-9F; + + // Effective number of bins excluding padding on both ends + static float effectiveRangeBins() + { + return float((NumBins - 1) - 2 * Padding); + } + + // Map an intensity into bin-space (clamped) + static float mapIntensityToBin(float intensity, float minVal, float maxVal) + { + let range = max(maxVal - minVal, epsilon); + let eff = effectiveRangeBins(); + var bin = (intensity - minVal) / range * eff + float(Padding); + return clamp(bin, float(Padding), float(NumBins - 1 - Padding)); + } + + // Compute start/end bin indices (4-wide neighbourhood for cubic B-spline) + static void computeBinNeighbourhood(float bin, out uint32_t start, out uint32_t end) + { + let centre = int32_t(floor(bin)); + start = uint32_t(max(centre - 1, 0)); + end = uint32_t(min(centre + 2, int32_t(NumBins - 1))); + } + + // Precompute B-spline weights for bins in [start..end] + static void computeWeights(float bin, uint32_t start, uint32_t end, + out float4 weights, out uint32_t count) + { + count = 0U; + for (uint32_t i = start; i <= end; ++i) { + weights[count] = cubicBSpline(bin - float(i)); + ++count; + } + + // Initialize remaining weights to zero + for (uint32_t i = count; i < 4; ++i) { + weights[i] = 0.0F; + } + } + + // Precompute weights and their derivatives for bins in [start..end] + static void computeWeightsAndDerivatives(float bin, uint32_t start, uint32_t end, + out float4 weights, + out float4 derivatives, + out uint32_t count) + { + count = 0U; + for (uint32_t j = start; j <= end; ++j) { + let rel = bin - float(j); + weights[count] = cubicBSpline(rel); + derivatives[count] = cubicBSplineDerivative(rel); + ++count; + } + + // Initialize remaining weights and derivatives to zero + for (uint32_t i = count; i < 4; ++i) { + weights[i] = 0.0F; + derivatives[i] = 0.0F; + } + } + + static void computeDerivatives(float bin, uint32_t start, uint32_t end, + out float4 
derivatives, + out uint32_t count) + { + count = 0U; + for (uint32_t j = start; j <= end; ++j) { + let rel = bin - float(j); + derivatives[count] = cubicBSplineDerivative(rel); + ++count; + } + + // Initialize remaining weights and derivatives to zero + for (uint32_t i = count; i < 4; ++i) { + derivatives[i] = 0.0F; + } + } +}; \ No newline at end of file diff --git a/cpp/core/gpu/shaders/reduction_image.slang b/cpp/core/gpu/shaders/reduction_image.slang new file mode 100644 index 0000000000..d2370ecd2c --- /dev/null +++ b/cpp/core/gpu/shaders/reduction_image.slang @@ -0,0 +1,214 @@ +// Compute shader performing per-workgroup reduction on a 3D texture across multiple operations. +// Threads load voxel intensities into shared memory, then a parallel reduction is applied for +// each IReduceOp in the generic pack. + +import reduction_utils; +import texture_utils; + +extern static const uint kWorkgroupSizeX = 8; +extern static const uint kWorkgroupSizeY = 8; +extern static const uint kWorkgroupSizeZ = 4; +static const uint kWorkgroupTotalSize = kWorkgroupSizeX * kWorkgroupSizeY * kWorkgroupSizeZ; + +typealias ToFloat = float; + + +// A "meta" reduction operation that applies a pack of operations +// to a tuple of values. 
+struct CompositeReduceOp> : +IReduceOp>> +{ + typealias ValueTuple = Tuple>; + static const int ValueTupleSize = countof(expand each Operation); + + static ValueTuple identityElement() + { + return makeTuple(expand(each Operation).identityElement()); + } + + static ValueTuple reduce(ValueTuple a, ValueTuple b) + { + return makeTuple(expand(each Operation).reduce(each a, each b)); + } +}; + +struct Uniforms +{ + uint32_t3 dispatchGrid; +}; + +struct UniformsWithTransform { + uint32_t3 dispatchGrid; + float4x4 imageTransform; +}; + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void main_sum( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + uint32_t3 workgroupId: SV_GroupID, + ConstantBuffer uniforms, + Texture3D inputImage, + RWStructuredBuffer output, + SamplerState sampler) +{ + reduce(globalId, localId, workgroupId, uniforms.dispatchGrid, inputImage, output, none, sampler); +} + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void main_sum_with_transform( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + uint32_t3 workgroupId: SV_GroupID, + ConstantBuffer uniforms, + Texture3D inputImage, + RWStructuredBuffer output, + SamplerState sampler) +{ + reduce(globalId, localId, workgroupId, uniforms.dispatchGrid, inputImage, output, uniforms.imageTransform, sampler); +} + +void reduce>( + uint32_t3 globalId, + uint32_t3 localId, + uint32_t3 workgroupId, + uint32_t3 dispatchGrid, + Texture3D inputImage, + RWStructuredBuffer output, + Optional transformation, + SamplerState sampler +) +{ + typealias ValueTuple = CompositeReduceOp.ValueTuple; + typealias CompositeOperation = CompositeReduceOp; + static const int operationsCount = countof(expand each Operation); + static groupshared Array localData; + + uint3 texDim; + inputImage.GetDimensions(texDim.x, texDim.y, texDim.z); + let localIndex = localId.x + localId.y * 
kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + + // Each thread stores its corresponding voxel intensity in the local data array. + ValueTuple threadValue; + if (all(globalId < texDim)) { + if (transformation.hasValue) { + let transformedVoxel4 : float4 = mul(transformation.value, float4(globalId, 1.0f)); + let transformedVoxel = transformedVoxel4.xyz / transformedVoxel4.w; + float voxelIntensity = 0.0F; + if (all(transformedVoxel >= 0.0F) && all(transformedVoxel <= float3(texDim - 1u))) { + voxelIntensity = inputImage.SampleLevel(sampler, (transformedVoxel.xyz + 0.5F) / float3(texDim), 0.0).r; + } + threadValue = makeTuple(expand ToFloat(voxelIntensity)); + } + else { + let voxelIntensity = inputImage.Load(int4(globalId, 0)); + threadValue = makeTuple(expand ToFloat(voxelIntensity)); + } + } + else { + threadValue = CompositeOperation.identityElement(); + } + localData[localIndex] = threadValue; + GroupMemoryBarrierWithGroupSync(); + + let resultValue = workgroupReduce(localData, localIndex); + + // The first thread in the workgroup writes the final result for the group. 
+ if (localIndex == 0) + { + let wgIndex = workgroupId.x + + workgroupId.y * dispatchGrid[0] + + workgroupId.z * dispatchGrid[0] * dispatchGrid[1]; + int i = 0; + expand output[wgIndex * operationsCount + i++] = each resultValue; + } +} + + +void reduceAtomic>( + uint32_t3 globalId, + uint32_t3 localId, + uint32_t3 workgroupId, + uint32_t3 dispatchGrid, + Texture3D inputImage, + RWStructuredBuffer> output, // in global memory + Optional transformation, + SamplerState sampler) +{ + static const int operationsCount = countof(expand each Operation); + static_assert(operationsCount > 0, "At least one operation must be provided"); + + static groupshared Array, operationsCount> localData; + + let localIndex = localId.x + localId.y * kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + if(localIndex == 0) { + int i = 0; + expand localData[i++].store((each Operation).identityElement()); + } + + GroupMemoryBarrierWithGroupSync(); + + uint32_t3 textureDims = textureSize(inputImage); + + if(all(globalId < textureDims)) { + float value; + if (transformation.hasValue) + { + let samplingField = VoxelSamplingField(inputImage, sampler); + let transformedVoxel = mul(transformation.value, float4(globalId, 1.0f)).xyz; + float voxelIntensity = 0.0F; + if (all(transformedVoxel >= 0.0F) && all(transformedVoxel <= float3(textureDims - 1u))) { + voxelIntensity = samplingField.sample(transformedVoxel); + } + value = voxelIntensity; + } + else + { + value = inputImage.Load(int4(globalId, 0)).r; + } + + int i = 0; + expand (each Operation).atomicReduce(localData[i++], value); + } + + GroupMemoryBarrierWithGroupSync(); + + // Global reduction + if(localIndex == 0) { + int j = 0; + expand (each Operation).atomicReduceEncoded(output[j], localData[j++].load()); + } +} + +// Min/Max reduction with atomic operations using ordered-float encoding to handle negatives +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void minMaxAtomic( + uint32_t3 
globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + uint32_t3 workgroupId: SV_GroupID, + ConstantBuffer uniforms, + Texture3D inputTexture, + RWStructuredBuffer> outputBuffer, + SamplerState sampler) +{ + reduceAtomic(globalId, localId, workgroupId, uniforms.dispatchGrid, inputTexture, outputBuffer, none, sampler); +} + +// Min/Max reduction with atomic operations and a given transformation from the target coordinates +// to the source coordinates. +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void minMaxAtomicWithTransform( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + uint32_t3 workgroupId: SV_GroupID, + ConstantBuffer uniforms, + Texture3D inputTexture, + RWStructuredBuffer> outputBuffer, + SamplerState sampler) +{ + reduceAtomic(globalId, localId, workgroupId, uniforms.dispatchGrid, inputTexture, outputBuffer, uniforms.imageTransform, sampler); +} diff --git a/cpp/core/gpu/shaders/reduction_utils.slang b/cpp/core/gpu/shaders/reduction_utils.slang new file mode 100644 index 0000000000..9746ad2bb2 --- /dev/null +++ b/cpp/core/gpu/shaders/reduction_utils.slang @@ -0,0 +1,243 @@ +module ReductionUtils; +import atomic_utils; + +public interface IReduceOp +{ + static T identityElement(); + static T reduce(T a, T b); +}; + +// Value: the logical value type (float | uint32_t | int32_t) +// Storage: the type used in the atomic variable (uint32_t for float; same as Value for ints) +public interface IAtomicReduceOp +{ + static Storage identityElement(); + static void atomicReduce(__ref Atomic dest, Value value); + static void atomicReduceEncoded(__ref Atomic dest, Storage value); +} + +public struct SumOp : IReduceOp where T : IArithmetic +{ + public static T identityElement() { return T(0); } + public static T reduce(T a, T b) { return a + b; } +}; + +public struct SumArrayOp : IReduceOp> where T : IArithmetic +{ + public static Array identityElement() { Array v; for (uint 
i = 0; i < N; i++) v[i] = T(0); return v; }
+    public static Array reduce(Array a, Array b)
+    {
+        Array r;
+        for (uint i = 0; i < N; i++)
+            r[i] = a[i] + b[i];
+        return r;
+    }
+};
+
+public struct SumFloatOp : IReduceOp
+{
+    public static float identityElement() { return 0.0F; }
+    public static float reduce(float a, float b) { return a + b; }
+};
+
+public struct MinFloatOp : IReduceOp
+{
+    public static float identityElement() { return float.maxValue; }
+    public static float reduce(float a, float b) { return min(a, b);
+    }
+};
+
+public struct MaxFloatOp : IReduceOp
+{
+    public static float identityElement() { return float.minValue; }
+    public static float reduce(float a, float b) { return max(a, b); }
+};
+
+// NOTE: These operations only work for non-negative floats.
+// We exploit the fact that given two non-negative floats a,b with a > b, we have
+// asuint(a) > asuint(b), so an unsigned atomic min/max on the raw bits gives the
+// float min/max without a compare-exchange loop.
+public struct AtomicMinPositiveFloatOp : IAtomicReduceOp
+{
+    // Max uint: identity for unsigned min. (Not a float encoding — 0xFFFFFFFF is a
+    // NaN bit pattern — it just compares above every non-negative float's bits.)
+    public static uint32_t identityElement() { return 0xFFFFFFFFu; }
+    public static void atomicReduce(__ref Atomic dst, float v)
+    {
+        // Negative inputs are clamped to 0 so the bit-ordering trick stays valid.
+        dst.min(asuint(max(v, 0.0F)));
+    }
+    public static void atomicReduceEncoded(__ref Atomic dst, uint32_t v)
+    {
+        dst.min(v);
+    }
+}
+
+public struct AtomicMaxPositiveFloatOp : IAtomicReduceOp
+{
+    // Zero (the bits of +0.0F): identity for max over non-negative float bits.
+    public static uint32_t identityElement() { return 0; }
+    public static void atomicReduce(__ref Atomic dst, float v)
+    {
+        dst.max(asuint(v));
+    }
+    public static void atomicReduceEncoded(__ref Atomic dst, uint32_t v)
+    {
+        dst.max(v);
+    }
+}
+
+// Encode floats into an order-preserving uint so unsigned comparisons match float ordering
+// (including negatives). Mapping: for positive values flip the sign bit; for negative values
+// bitwise-not the whole word.
+// See https://stackoverflow.com/a/72461459
+public uint32_t floatToOrderedUint(float v)
+{
+    let bits = asuint(v);
+    let mask = (bits & 0x80000000u) != 0u ? 0xFFFFFFFFu : 0x80000000u;
+    return bits ^ mask;
+}
+
+// Inverse of floatToOrderedUint: recovers the float from its order-preserving encoding.
+public float orderedUintToFloat(uint32_t v)
+{
+    let mask = (v & 0x80000000u) != 0u ? 0x80000000u : 0xFFFFFFFFu;
+    return asfloat(v ^ mask);
+}
+
+// Min over arbitrary (also negative) floats via the ordered-uint encoding.
+public struct AtomicMinFloatOp : IAtomicReduceOp
+{
+    public static uint32_t identityElement() { return floatToOrderedUint(float.maxValue); }
+    public static void atomicReduce(__ref Atomic dst, float v)
+    {
+        dst.min(floatToOrderedUint(v));
+    }
+    public static void atomicReduceEncoded(__ref Atomic dst, uint32_t v)
+    {
+        dst.min(v);
+    }
+}
+
+// Max over arbitrary (also negative) floats via the ordered-uint encoding.
+public struct AtomicMaxFloatOp : IAtomicReduceOp
+{
+    public static uint32_t identityElement() { return floatToOrderedUint(-float.maxValue); }
+    public static void atomicReduce(__ref Atomic dst, float v)
+    {
+        dst.max(floatToOrderedUint(v));
+    }
+    public static void atomicReduceEncoded(__ref Atomic dst, uint32_t v)
+    {
+        dst.max(v);
+    }
+}
+
+// A function to perform a parallel reduction using the provided operation within a workgroup.
+// Note: The size of the workgroup data must be a power of two.
+public DataType workgroupReduce, uint32_t workgroupDataSize> + (__ref Array localData,uint32_t localIndex) +{ + static_assert(workgroupDataSize <= 1024, "workgroupDataSize must be <= 1024"); + static_assert((workgroupDataSize & (workgroupDataSize - 1)) == 0, "workgroupDataSize must be a power of two"); + + for (uint offset = workgroupDataSize / 2; offset > 0; offset /= 2) + { + if (localIndex < offset) + { + localData[localIndex] = Operation.reduce(localData[localIndex], localData[localIndex + offset]); + } + GroupMemoryBarrierWithGroupSync(); + } + return localData[0]; +} + + +interface IReduceWaveOp +{ + inline static T identityElement(); + inline static T waveReduce(T value); +} + +public struct ISumArrayWaveOp : IReduceWaveOp> where T : __BuiltinArithmeticType { + inline static Array identityElement() { return Array(); } + inline static Array waveReduce(Array array) { + Array reduced; + for (var i = 0; i < N; ++i) { reduced[i] = WaveActiveSum(array[i]); } + return reduced; + } + inline static Array reduce(Array a, Array b) { + Array reduced; + for (var i = 0; i < N; ++i) { reduced[i] = a[i] + b[i]; } + return reduced; + } +} + +uint32_t wgslWorkgroupUniformLoad(uint32_t value) +{ + // See https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin + __intrinsic_asm "workgroupUniformLoad(&$0)"; +} + +[ForceInline] +public DataType workgroupReduceWithWaves> +(__ref DataType threadValue, __ref DataType[] wavePartials, uint32_t numWaves, uint32_t waveIndex, uint32_t localIndex) +{ + // TODO: if this function is called multiple times in a kernel, then we are unnecessarily reading/writing from/to + // shared memory. Should we take this parameters as function arguments instead? 
+    static groupshared uint32_t wgUniformNumWaves;
+    static groupshared uint32_t wgUniformWaveSize;
+
+    let waveReduced = Operation.waveReduce(threadValue);
+    let laneIndex = WaveGetLaneIndex();
+    let waveSize = WaveGetLaneCount();
+
+    // Lane 0 of each wave publishes that wave's partial result.
+    if (laneIndex == 0U) {
+        wavePartials[waveIndex] = waveReduced;
+    }
+
+    if (localIndex == 0) {
+        wgUniformNumWaves = numWaves;
+        wgUniformWaveSize = waveSize;
+    }
+
+    GroupMemoryBarrierWithGroupSync();
+
+    // We need this workaround because otherwise we break WGSL's subgroup uniformity analysis
+    // See https://github.com/shader-slang/slang/issues/8774
+    let uniformNumWaves = wgslWorkgroupUniformLoad(wgUniformNumWaves);
+    let uniformWaveSize = wgslWorkgroupUniformLoad(wgUniformWaveSize);
+    // We now have a single partial item for each wave, so we iterate in a loop
+    // to combine them into one, each round shrinking the problem by waveSize.
+    // For example, assume waveSize=4 and numWaves=10, then:
+    // Indices:   0 1 2 3   4 5 6 7   8 9
+    // Partials: [0 1 2 3] [4 5 6 7] [8 9]
+    //            ^block 0  ^block 1  ^block 2
+    // After the first iteration, we will have 3 remaining items.
+
+    if (uniformNumWaves > uniformWaveSize) {
+        var remaining = uniformNumWaves;
+        while (remaining > 1U) {
+            let baseIndex = waveIndex * uniformWaveSize;
+            // Each wave will reduce up to waveSize items, but the last wave may have fewer and
+            // some waves will have none.
+            let blockCount = (baseIndex < remaining) ? min(uniformWaveSize, remaining - baseIndex) : 0U;
+            let laneValue = laneIndex < blockCount ? wavePartials[baseIndex + laneIndex] : Operation.identityElement();
+            GroupMemoryBarrierWithGroupSync();
+
+            let blockReduced = Operation.waveReduce(laneValue);
+            let newRemaining = (remaining + uniformWaveSize - 1) / uniformWaveSize;
+
+            // The first newRemaining waves carry the surviving partials into the next round.
+            if (laneIndex == 0U && waveIndex < newRemaining) {
+                wavePartials[waveIndex] = blockReduced;
+            }
+
+            GroupMemoryBarrierWithGroupSync();
+            remaining = newRemaining;
+        }
+    }
+
+    // If the number of remaining elements fits in the subgroup size, we just need one extra wave reduction
+    else {
+        let laneValue = (waveIndex == 0 && laneIndex < uniformNumWaves) ? wavePartials[laneIndex] : Operation.identityElement();
+        let finalValue = Operation.waveReduce(laneValue);
+        if (waveIndex == 0) {
+            wavePartials[0] = finalValue;
+        }
+        GroupMemoryBarrierWithGroupSync();
+    }
+
+    return wavePartials[0];
+}
diff --git a/cpp/core/gpu/shaders/registration/coordinate_mapper.slang b/cpp/core/gpu/shaders/registration/coordinate_mapper.slang
new file mode 100644
index 0000000000..72408347cb
--- /dev/null
+++ b/cpp/core/gpu/shaders/registration/coordinate_mapper.slang
@@ -0,0 +1,87 @@
+import global_transformation;
+import voxelscannermatrices;
+
+// A helper class to map coordinates between source and target spaces.
+// It uses the voxel to scanner space and inverse matrices to perform the mapping.
+struct CoordinateMapper { + uint3 _sourceDim; + uint3 _targetDim; + VoxelScannerMatrices _vsm; + float3 _sourceCentreVoxelCoord; + float3 _sourceCentreScanner; + + __init(uint3 sourceDim, uint3 targetDim, in VoxelScannerMatrices vsm) + { + this._sourceDim = sourceDim; + this._targetDim = targetDim; + this._vsm = vsm; + this._sourceCentreVoxelCoord = float3(sourceDim) * 0.5F; + this._sourceCentreScanner = mul(vsm.sourceVoxelToScanner, float4(this._sourceCentreVoxelCoord, 1.0F)).xyz; + } + + // Maps a voxel coordinate in the source space to target space + __generic + float3 mapSourceVoxelToTarget(float3 sourceVoxelCoord, Transformation transformation) + { + float4 sourceScannerCoord = mul(_vsm.sourceVoxelToScanner, float4(sourceVoxelCoord, 1.0F)); + float3 targetScannerCoord = transformation.apply(sourceScannerCoord.xyz); + return mul(_vsm.targetScannerToVoxel, float4(targetScannerCoord, 1.0F)).xyz; + } + + // Maps a voxel coordinate in the target space to source space + __generic + float3 mapTargetVoxelToSource(float3 targetVoxelCoord, Transformation transformation) + { + float4 targetScannerCoord = mul(_vsm.targetVoxelToScanner, float4(targetVoxelCoord, 1.0F)); + float3 sourceScannerCoord = transformation.apply(targetScannerCoord.xyz); + return mul(_vsm.sourceScannerToVoxel, float4(sourceScannerCoord, 1.0F)).xyz; + } + + + // Maps a voxel coordinate in the target space to scanner space + float3 mapTargetVoxelToScanner(float3 targetVoxelCoord) + { + return mul(_vsm.targetVoxelToScanner, float4(targetVoxelCoord, 1.0F)).xyz; + } + + // Maps a voxel gradient in source space to scanner space + // NOTE: the gradient vector is covariant. So if coordinates transform as v' = M * v, + // the gradient transforms as g' = g * inverse(M). Since sourceScannerToVoxel is the inverse + // of sourceVoxelToScanner, this function maps a voxel gradient in source space to scanner space. + // Note that in Slang the notation mul(vec, matrix) is equivalent to matrix^T * vec. 
+    float3 mapVoxelGradientToScanner(float3 voxelGradient)
+    {
+        return mul(float4(voxelGradient, 0.0F), _vsm.sourceScannerToVoxel).xyz;
+    }
+
+
+    // Returns the source image centre in scanner space (precomputed in __init).
+    float3 sourceCentreScanner()
+    {
+        return _sourceCentreScanner;
+    }
+
+    // Checks if an integer coordinate is within the source dimensions
+    bool inSourceInt(int32_t3 coord)
+    {
+        return all(coord >= int32_t3(0, 0, 0)) && all(coord < int32_t3(_sourceDim));
+    }
+
+    // Checks if an integer coordinate is within the target dimensions
+    bool inTargetInt(int32_t3 coord)
+    {
+        return all(coord >= int32_t3(0, 0, 0)) && all(coord < int32_t3(_targetDim));
+    }
+
+    // Checks if a coordinate is within the source dimensions
+    bool inSource(float3 coord)
+    {
+        return all(coord >= float3(0.0F, 0.0F, 0.0F)) && all(coord < float3(_sourceDim));
+    }
+
+    // Checks if a coordinate is within the target dimensions
+    bool inTarget(float3 coord)
+    {
+        return all(coord >= float3(0.0F, 0.0F, 0.0F)) && all(coord < float3(_targetDim));
+    }
+};
diff --git a/cpp/core/gpu/shaders/registration/cubic_bspline.slang b/cpp/core/gpu/shaders/registration/cubic_bspline.slang
new file mode 100644
index 0000000000..9c79b5ffb7
--- /dev/null
+++ b/cpp/core/gpu/shaders/registration/cubic_bspline.slang
@@ -0,0 +1,22 @@
+// Cubic B–spline kernel with compact support = 2
+// Branch-free form: B(x) = ((2-|x|)+^3 - 4*(1-|x|)+^3) / 6, where (.)+ = max(., 0).
+float cubicBSpline(float x)
+{
+    let ax = abs(x);
+    let u = max(0.0, 2.0 - ax);
+    let v = max(0.0, 1.0 - ax);
+    let u3 = u * u * u;
+    let v3 = v * v * v;
+    return (u3 - 4.0 * v3) * (1.0 / 6.0);
+}
+
+// Derivative of cubicBSpline via the chain rule on u = (2-|x|)+ and v = (1-|x|)+;
+// du/dx and dv/dx are -sign(x) inside each term's support and 0 outside.
+float cubicBSplineDerivative(float x)
+{
+    let ax = abs(x);
+    let u = max(0.0, 2.0 - ax);
+    let v = max(0.0, 1.0 - ax);
+    let du_dx = (ax < 2.0) ? -sign(x) : 0.0;
+    let dv_dx = (ax < 1.0) ? 
-sign(x) : 0.0; + let u2 = u * u; + let v2 = v * v; + return (3.0 * u2 * du_dx - 12.0 * v2 * dv_dx) * (1.0 / 6.0); +} diff --git a/cpp/core/gpu/shaders/registration/global_transformation.slang b/cpp/core/gpu/shaders/registration/global_transformation.slang new file mode 100644 index 0000000000..df26d25fe5 --- /dev/null +++ b/cpp/core/gpu/shaders/registration/global_transformation.slang @@ -0,0 +1,359 @@ +module GlobalTransformation; + +public struct RigidTransformationParameters +{ + public float tx, ty, tz; // Translation + public float rx, ry, rz; // Axis angle parameterisation +} + +public struct AffineTransformationParameters +{ + public float tx, ty, tz; // Translation + public float rx, ry, rz; // Axis angle parameterisation + public float sx, sy, sz; // Scaling + public float shx, shy, shz; // Shearing +} + +public interface ITransformation +{ + associatedtype TParams; + associatedtype Jacobian : IArray; + static const uint32_t kParamCount; + + __init(TParams params, float3 pivot); + float3 apply(float3 coord); + float4x4 matrix(); + Jacobian jacobian(float3 coord); + // Returns the i-th column of the Jacobian matrix at the given coordinate + // Using this function can avoid constructing the full Jacobian array which + // sometimes lead to register spilling in a thread hurting performance. + float3 jacobianVector(uint32_t index, float3 coord); +} + +static const float kEps = 1e-5F; + +// Computes the rotation matrix from the exponential map using Rodrigues’ formula. +// See https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula +// Let v = (rx,ry,rz), θ = |v|. Then +// R = I + A * K + B * K^2, +// where +// A = sin(θ)/θ, B = (1-cos(θ))/(θ^2), +// and K is the skew-symmetric matrix of v. +// Note that the original Rodrigues’ formula is given as R = I + sin(θ)*K + (1-cos(θ))*K^2. +// but given v = (rx, ry, rz), θ = |v|, the unit axis is u = v/θ and K(u) = K(v/θ) = K(v)/θ. 
+public float3x3 rotationMatrix(float3 rotationVector) +{ + let theta = length(rotationVector); + float A, B; + if (theta < kEps) + { + A = 1.0f - theta * theta / 6.0f; + B = 0.5f - theta * theta / 24.0f; + } + else + { + A = sin(theta) / theta; + B = (1.0f - cos(theta)) / (theta * theta); + } + + let K = float3x3(0.0F, -rotationVector.z, rotationVector.y, + rotationVector.z, 0.0F, -rotationVector.x, + -rotationVector.y, rotationVector.x, 0.0F); + + let K2 = mul(K, K); + let I = float3x3(1.0F, 0.0F, 0.0F, + 0.0F, 1.0F, 0.0F, + 0.0F, 0.0F, 1.0F); + return I + A * K + B * K2; +} + +// Computes the derivative of the rotation matrix with respect to one of the three +// exponential map parameters. That is, for a given component (0 for rx, 1 for ry, 2 for rz), +// it returns dR/dv_i where R = I + A*K + B*K^2 and +// A = sin(θ)/θ, B = (1-cos(θ))/(θ^2), +// and K is the skew-symmetric matrix of v = (rx,ry,rz) with θ = |v|. +// dR/dv_i = dA/dv_i * K + A * dK/dv_i + dB/dv_i * K^2 + B * d(K^2)/dv_i. 
+public float3x3 rotationMatrixDerivative(float3 rotationVector, uint32_t component) +{ + let theta = length(rotationVector); + float A, B, dA_dtheta, dB_dtheta; + if (theta < kEps) + { + A = 1.0F - theta * theta / 6.0F; + B = 0.5F - theta * theta / 24.0F; + dA_dtheta = -theta / 3.0F; + dB_dtheta = -theta / 12.0F; + } + else + { + A = sin(theta) / theta; + B = (1.0f - cos(theta)) / (theta * theta); + dA_dtheta = (theta * cos(theta) - sin(theta)) / (theta * theta); + dB_dtheta = (theta * sin(theta) - 2.0f * (1.0f - cos(theta))) / (theta * theta * theta); + } + let K = float3x3(0.0F, -rotationVector.z, rotationVector.y, + rotationVector.z, 0.0F, -rotationVector.x, + -rotationVector.y, rotationVector.x, 0.0F); + let K2 = mul(K, K); + float3x3 dK = float3x3(0); + if (component == 0) + dK = float3x3(0.0F, 0.0F, 0.0F, 0.0F, 0.0F, -1.0F, 0.0F, 1.0F, 0.0f); + else if (component == 1) + dK = float3x3(0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, -1.0F, 0.0F, 0.0f); + else if (component == 2) + dK = float3x3(0.0F, -1.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 0.0F, 0.0f); + let dK2 = mul(dK, K) + mul(K, dK); + let vi = rotationVector[component]; + let dA_dvi = (theta < kEps) ? -vi / 3.0f : dA_dtheta * (vi / theta); + let dB_dvi = (theta < kEps) ? 
-vi / 12.0f : dB_dtheta * (vi / theta); + return dA_dvi * K + A * dK + dB_dvi * K2 + B * dK2; +} + +// Matrix layout: +// R R R Tx +// R R R Ty +// R R R Tz +// 0 0 0 1 +// The matrix is constructed assuming rotation is applied before translation +public float4x4 buildRigidMatrix(RigidTransformationParameters params) +{ + let R = rotationMatrix(float3(params.rx, params.ry, params.rz)); + return float4x4( + R[0][0], R[0][1], R[0][2], params.tx, + R[1][0], R[1][1], R[1][2], params.ty, + R[2][0], R[2][1], R[2][2], params.tz, + 0.0F, 0.0F, 0.0F, 1.0F); +} + +// Matrix layout: +// M M M Tx +// M M M Ty +// M M M Tz +// 0 0 0 1 +// M = R * S * Sh is the 3x3 linear transformation matrix (Rotation * Scale * Shear) +public float4x4 buildAffineMatrix(AffineTransformationParameters params, Optional precomputedRotation = none) +{ + let rotationVector = float3(params.rx, params.ry, params.rz); + let R = precomputedRotation.hasValue ? precomputedRotation.value : rotationMatrix(rotationVector); + let S = float3x3( + params.sx, 0.0F, 0.0F, + 0.0F, params.sy, 0.0F, + 0.0F, 0.0F, params.sz); + + let Sh = float3x3( + 1.0F, params.shx, params.shy, + 0.0F, 1.0F, params.shz, + 0.0F, 0.0F, 1.0F); + + // NOTE: in HLSL(and Slang), the * operator is for component-wise multiplication, + // so we need to use mul() to multiply matrices. 
+ let M = mul(R, mul(S, Sh)); + return float4x4( + M[0][0], M[0][1], M[0][2], params.tx, + M[1][0], M[1][1], M[1][2], params.ty, + M[2][0], M[2][1], M[2][2], params.tz, + 0.0F, 0.0F, 0.0F, 1.0F); +} + +// T(v;α) = Rotation matrix + Translation +public float3 applyRigidTransform3D(float3 coord, RigidTransformationParameters params) +{ + let coord4 = float4(coord, 1.0F); + let transformed = mul(buildRigidMatrix(params), coord4); + return transformed.xyz; +} + +public float3 applyRigidTransform3DAboutPivot(float3 coord, + float3 pivot, + RigidTransformationParameters params) +{ + let local = coord - pivot; + let transformed4 = mul(buildRigidMatrix(params), float4(local, 1.0F)); + return transformed4.xyz + pivot; +} + +// T(v;α) = Rotation Matrix * Scale Matrix * Shear Matrix + Translation +public float3 applyAffineTransform3D(float3 coord, AffineTransformationParameters params) +{ + let coord4 = float4(coord, 1.0f); + let transformed = mul(buildAffineMatrix(params), coord4); + return transformed.xyz; +} + +public float3 applyAffineTransform3DAboutPivot(float3 coord, float3 pivot, AffineTransformationParameters params) +{ + let local = coord - pivot; + let transformed4 = mul(buildAffineMatrix(params), float4(local, 1.0f)); + return transformed4.xyz + pivot; +} + +public float3 rigidJacobianVector(uint32_t index, float3 coord, RigidTransformationParameters params) +{ + if(index == 0) return float3(1.0F, 0.0F, 0.0F); + if(index == 1) return float3(0.0F, 1.0F, 0.0F); + if(index == 2) return float3(0.0F, 0.0F, 1.0F); + + let rotationVector = float3(params.rx, params.ry, params.rz); + if(index == 3) return mul(rotationMatrixDerivative(rotationVector, 0), coord); + if(index == 4) return mul(rotationMatrixDerivative(rotationVector, 1), coord); + if(index == 5) return mul(rotationMatrixDerivative(rotationVector, 2), coord); + + return float3(0.0F, 0.0F, 0.0F); // Invalid index +} + +public float3[6] rigidJacobian(float3 coord, RigidTransformationParameters params) +{ + let 
dTdtx = rigidJacobianVector(0U, coord, params);
+    let dTdty = rigidJacobianVector(1U, coord, params);
+    let dTdtz = rigidJacobianVector(2U, coord, params);
+    let dRdVx = rigidJacobianVector(3U, coord, params);
+    let dRdVy = rigidJacobianVector(4U, coord, params);
+    let dRdVz = rigidJacobianVector(5U, coord, params);
+
+    return float3[6](
+        dTdtx, dTdty, dTdtz,
+        dRdVx, dRdVy, dRdVz);
+}
+
+// Derivative of the affine mapping T(x) = (R * S * Sh) * x + t (as built by buildAffineMatrix)
+// with respect to a single parameter, evaluated at `coord`:
+//   index 0-2  : translations (tx, ty, tz)    -> constant unit basis vectors
+//   index 3-5  : rotation vector (rx, ry, rz) -> (dR/dv_i * S * Sh) * x
+//   index 6-8  : scales (sx, sy, sz)          -> (R * dS/ds_i * Sh) * x
+//   index 9-11 : shears (shx, shy, shz)       -> (R * S * dSh/dsh_i) * x
+// Returns the zero vector for an out-of-range index.
+// FIX: this function previously had an empty body (no return statement) although
+// affineJacobian() calls it for all 12 parameters.
+public float3 affineJacobianVector(uint32_t index, float3 coord, AffineTransformationParameters params)
+{
+    // Translation derivatives are constant unit vectors.
+    if(index == 0U) return float3(1.0F, 0.0F, 0.0F);
+    if(index == 1U) return float3(0.0F, 1.0F, 0.0F);
+    if(index == 2U) return float3(0.0F, 0.0F, 1.0F);
+
+    let scaleMatrix = float3x3(
+        params.sx, 0.0F, 0.0F,
+        0.0F, params.sy, 0.0F,
+        0.0F, 0.0F, params.sz);
+
+    let shearMatrix = float3x3(
+        1.0F, params.shx, params.shy,
+        0.0F, 1.0F, params.shz,
+        0.0F, 0.0F, 1.0F);
+
+    let rotationVector = float3(params.rx, params.ry, params.rz);
+
+    // Rotation derivatives : (dT/dv_i) = dR/dv_i * S * Sh * x
+    if(index >= 3U && index <= 5U) {
+        let shearedThenScaledCoord = mul(scaleMatrix, mul(shearMatrix, coord));
+        return mul(rotationMatrixDerivative(rotationVector, index - 3U), shearedThenScaledCoord);
+    }
+
+    let R = rotationMatrix(rotationVector);
+
+    // Scale derivatives : (dT/dscale_i) = (R * dS/ds_i * Sh) * x
+    if(index >= 6U && index <= 8U) {
+        let coordSheared = mul(shearMatrix, coord);
+        if(index == 6U) return mul(R, float3(coordSheared.x, 0.0F, 0.0F));
+        if(index == 7U) return mul(R, float3(0.0F, coordSheared.y, 0.0F));
+        return mul(R, float3(0.0F, 0.0F, coordSheared.z));
+    }
+
+    // Shear derivatives : (dT/dshear_i) = (R * S * dSh/dsh_i) * x
+    // Sh*x = (x + shx*y + shy*z, y + shz*z, z), hence the single surviving component each.
+    if(index >= 9U && index <= 11U) {
+        if(index == 9U) return mul(R, mul(scaleMatrix, float3(coord.y, 0.0F, 0.0F)));
+        if(index == 10U) return mul(R, mul(scaleMatrix, float3(coord.z, 0.0F, 0.0F)));
+        return mul(R, mul(scaleMatrix, float3(0.0F, coord.z, 0.0F)));
+    }
+
+    return float3(0.0F, 0.0F, 0.0F); // Invalid index
+}
+
+// Full 12-parameter Jacobian of the affine transform at `coord`.
+public float3[12] affineJacobian(float3 coord, AffineTransformationParameters params)
+{
+    let dTdtx = affineJacobianVector(0U, coord, params);
+    let dTdty = affineJacobianVector(1U, coord, params);
+    let dTdtz = affineJacobianVector(2U, coord, params);
+    let dTdRx = affineJacobianVector(3U, coord, params);
+    let dTdRy = affineJacobianVector(4U, coord, params);
+    let dTdRz = affineJacobianVector(5U, coord, params);
+    let dTdSx = affineJacobianVector(6U, coord, params);
+    let dTdSy = affineJacobianVector(7U, coord, params);
+    let dTdSz = affineJacobianVector(8U, coord, params);
+    let dTdShx = affineJacobianVector(9U, coord, params);
+    let dTdShy = affineJacobianVector(10U, coord, params);
+    let dTdShz = affineJacobianVector(11U, coord, params);
+
+    return float3[12](
+        dTdtx, dTdty, dTdtz,
+        dTdRx, dTdRy, dTdRz,
+        dTdSx, dTdSy, dTdSz,
+        dTdShx, dTdShy, dTdShz);
+}
+
+public struct RigidTransformation : ITransformation
+{
+    public typedef RigidTransformationParameters TParams;
+    public typedef float3[6] Jacobian;
+    public static const uint kParamCount = 6;
+    public TParams params;
+    public float3 pivot;
+
+    __init(TParams params, float3 pivot)
+    {
+        this.params = params;
+        this.pivot = pivot;
+    }
+    public float3 apply(float3 coord)
+    {
+        return applyRigidTransform3DAboutPivot(coord, pivot, params);
+    }
+
+    public float4x4 matrix()
+    {
+        return buildRigidMatrix(params);
+    }
+
+    public Jacobian jacobian(float3 coord)
+    {
+        let evalCoord = coord - pivot;
+        return
rigidJacobian(evalCoord, params); + } + + public float3 jacobianVector(uint32_t index, float3 coord) + { + let evalCoord = coord - pivot; + return rigidJacobianVector(index, evalCoord, params); + } +} + +public struct AffineTransformation : ITransformation +{ + public typedef AffineTransformationParameters TParams; + public typedef float3[12] Jacobian; + public static const uint kParamCount = 12; + public TParams params; + public float3 pivot; + float3x3 rotationMat; + + __init(TParams params, float3 pivot) + { + this.params = params; + this.pivot = pivot; + this.rotationMat = rotationMatrix(float3(params.rx, params.ry, params.rz)); + } + + public float3 apply(float3 coord) + { + let local = coord - pivot; + let transformed4 = mul(buildAffineMatrix(params, rotationMat), float4(local, 1.0f)); + return transformed4.xyz + pivot; + } + + float4x4 matrix() + { + return buildAffineMatrix(params, rotationMat); + } + + public Jacobian jacobian(float3 coord) + { + let evalCoord = coord - pivot; + return affineJacobian(evalCoord, params); + } + + public float3 jacobianVector(uint32_t index, float3 coord) + { + let evalCoord = coord - pivot; + if(index == 0U) return float3(1.0F, 0.0F, 0.0F); + if(index == 1U) return float3(0.0F, 1.0F, 0.0F); + if(index == 2U) return float3(0.0F, 0.0F, 1.0F); + + let scaleMatrix = float3x3( + params.sx, 0.0F, 0.0F, + 0.0F, params.sy, 0.0F, + 0.0F, 0.0F, params.sz); + + let shearMatrix = float3x3( + 1.0F, params.shx, params.shy, + 0.0F, 1.0F, params.shz, + 0.0F, 0.0F, 1.0F); + + + // Rotation derivatives : (∂T/∂v_i) = ∂R/∂v_i * S * Sh * x + if(index >= 3U && index <= 5U) { + let rotationVector = float3(params.rx, params.ry, params.rz); + let shearedThenScaledCoord = mul(scaleMatrix, mul(shearMatrix, evalCoord)); + return mul(rotationMatrixDerivative(rotationVector, index - 3U), shearedThenScaledCoord); + } + + // Scale derivatives : (∂T/∂scale_i) = (R * ∂S/∂s_i * Sh) * x + if(index >= 6U && index <= 8U) { + let coordSheared = mul(shearMatrix, 
evalCoord); + if(index == 6U) return mul(rotationMat, float3(coordSheared.x, 0.0, 0.0)); + if(index == 7U) return mul(rotationMat, float3(0.0, coordSheared.y, 0.0)); + return mul(rotationMat, float3(0.0, 0.0, coordSheared.z)); + } + + // Shear derivatives : (∂T/∂shear_i) = (R * S * ∂Sh/∂sh_i) * x + if(index >= 9U && index <= 11U) { + if(index == 9U) return mul(rotationMat, mul(scaleMatrix, float3(evalCoord.y, 0.0, 0.0))); + if(index == 10U) return mul(rotationMat, mul(scaleMatrix, float3(evalCoord.z, 0.0, 0.0))); + return mul(rotationMat, mul(scaleMatrix, float3(0.0, evalCoord.z, 0.0))); + } + + return float3(0.0F, 0.0F, 0.0F); // Invalid index + + } +} \ No newline at end of file diff --git a/cpp/core/gpu/shaders/registration/joint_histogram.slang b/cpp/core/gpu/shaders/registration/joint_histogram.slang new file mode 100644 index 0000000000..5664fe2dc9 --- /dev/null +++ b/cpp/core/gpu/shaders/registration/joint_histogram.slang @@ -0,0 +1,206 @@ +// This shader computes a joint histogram of intensities from a given source and target volume. +import atomic_utils; +import cubic_bspline; +import texture_utils; +import reduction_utils; + +extern static const uint32_t kWorkgroupSizeX = 8; +extern static const uint32_t kWorkgroupSizeY = 8; +extern static const uint32_t kWorkgroupSizeZ = 4; + + +extern static const int32_t kNumBins; +// TODO: Use extern static const bool once the Slang compiler bug is fixed. +extern static const uint32_t kUseFixedMask; +extern static const uint32_t kUseMovingMask; +static const int32_t kPadding = 2; // reserve bins at each end so center never touches histogram edge +static const int32_t kWorkgroupInvocations = kWorkgroupSizeX * kWorkgroupSizeY * kWorkgroupSizeZ; + +// A small epsilon to avoid division by zero errors. 
+static const float epsilon = 1e-9; + +struct Intensities +{ + float sourceMin; + float sourceMax; + float targetMin; + float targetMax; +} + +struct Uniforms +{ + uint32_t3 dispatchGrid; + Intensities intensities; + // Optional linear transformation matrix from target image coordinates to source image coordinates. + float4x4 transformationMatrix; +} + +// Map intensity to nearest bin while preserving padding at edges +uint32_t mapIntensityToNearestBin(float intensity, float inMin, float inMax) +{ + let range = inMax - inMin; + let denom = (range <= epsilon) ? epsilon : range; + let normalized = clamp((intensity - inMin) / denom, 0.0F, 1.0F); + let usableBins = float(kNumBins - 2 * kPadding); + let idxF = normalized * (usableBins - 1.0F) + 0.5F; + let idxClamped = clamp(idxF, 0.0F, usableBins - 1.0F); + return uint32_t(kPadding) + uint32_t(idxClamped); +} + + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void rawHistogram( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + uint32_t3 workgroupId: SV_GroupID, + ConstantBuffer uniforms, + Texture3D movingTexture, + Texture3D fixedTexture, + Texture3D movingMaskTexture, + Texture3D fixedMaskTexture, + RWStructuredBuffer> jointHistogram, + SamplerState sampler) +{ + static groupshared Array, kNumBins * kNumBins> localHistogram; + static_assert(kNumBins > 0, "kNumBins must be greater than 0"); + + let sourceDimensions = textureSize(movingTexture); + let targetDimensions = textureSize(fixedTexture); + let sourceDimensionsF = float3(sourceDimensions); + + let localIndex = localId.x + localId.y * kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + let totalBins = kNumBins * kNumBins; + + // Initialise workgroup-shared histogram and total mass to zero + for (var i = localIndex; i < totalBins; i += kWorkgroupInvocations) { + localHistogram[i].store(0U); + } + GroupMemoryBarrierWithGroupSync(); + let movingSamplingField = 
VoxelSamplingField(movingTexture, sampler); + let movingMaskField = VoxelSamplingField(movingMaskTexture, sampler); + let fixedMaskField = VoxelSamplingField(fixedMaskTexture, sampler); + let useFixedMask = kUseFixedMask != 0U; + let useMovingMask = kUseMovingMask != 0U; + + // Nearest-bin deposition: one atomic add per voxel into the workgroup-local histogram + bool includeVoxel = all(globalId < targetDimensions); + float3 movingVoxelCoord = float3(0.0F, 0.0F, 0.0F); + if (includeVoxel) { + let targetVoxelCoord = float3(globalId); + if (useFixedMask) { + let fixedMaskValue = fixedMaskField.sample(targetVoxelCoord); + if (fixedMaskValue < 0.5F) { + includeVoxel = false; + } + } + if (includeVoxel) { + movingVoxelCoord = mul(uniforms.transformationMatrix, float4(targetVoxelCoord, 1.0F)).xyz; + let isMovingVoxelInSource = all(movingVoxelCoord >= 0.0F) && all(movingVoxelCoord < sourceDimensionsF); + if (!isMovingVoxelInSource) { + includeVoxel = false; + } + } + if (includeVoxel && useMovingMask) { + let movingMaskValue = movingMaskField.sample(movingVoxelCoord); + if (movingMaskValue < 0.5F) { + includeVoxel = false; + } + } + } + + if (includeVoxel) { + let sourceIntensity = movingSamplingField.sample(movingVoxelCoord); + let targetIntensity = fixedTexture.Load(int4(globalId, 0)).r; + + let targetBin = mapIntensityToNearestBin(targetIntensity, uniforms.intensities.targetMin, uniforms.intensities.targetMax); + let sourceBin = mapIntensityToNearestBin(sourceIntensity, uniforms.intensities.sourceMin, uniforms.intensities.sourceMax); + + let binIndex = targetBin * kNumBins + sourceBin; + localHistogram[binIndex].increment(); + } + + GroupMemoryBarrierWithGroupSync(); + + let dispatchGrid = uniforms.dispatchGrid; + let wgIndex = workgroupId.x + workgroupId.y * dispatchGrid.x + workgroupId.z * dispatchGrid.x * dispatchGrid.y; + let numHistograms = dispatchGrid.x * dispatchGrid.y * dispatchGrid.z; + + // Merge workgroup-local histograms into global histogram + for (var i = 
localIndex; i < totalBins; i += kWorkgroupInvocations) { + let localCount = localHistogram[i].load(); + if (localCount > 0U) { + jointHistogram[i].add(localCount); + } + } +} + + +// Smooth the merged histogram by convolving with a cubic B-spline kernel +// NOTE: This only yields an approximation of the Parzen-windowed histogram, +// but it's a good trade-off between accuracy and computational efficiency. +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, 1)] +void smoothHistogram( + uint32_t3 globalId: SV_DispatchThreadID, + ConstantBuffer uniforms, + StructuredBuffer jointHistogram, + RWStructuredBuffer jointHistogramSmoothed +) +{ + let targetBin = globalId.x; + let sourceBin = globalId.y; + if (targetBin >= uint32_t(kNumBins) || sourceBin >= uint32_t(kNumBins)) { + return; + } + + var accumulated = 0.0F; + // cubic B-spline has compact support [-2, 2] + for (int dTarget = -2; dTarget <= 2; ++dTarget) { + for (int dSource = -2; dSource <= 2; ++dSource) { + let neighbourTargetBinInt = int(targetBin) + dTarget; + let neighbourSourceBinInt = int(sourceBin) + dSource; + let neighbourTargetBin = uint32_t(clamp(neighbourTargetBinInt, 0, kNumBins - 1)); + let neighbourSourceBin = uint32_t(clamp(neighbourSourceBinInt, 0, kNumBins - 1)); + let weight = cubicBSpline(float(dTarget)) * cubicBSpline(float(dSource)); + let histIndex = neighbourTargetBin * kNumBins + neighbourSourceBin; + accumulated += weight * float(jointHistogram[histIndex]); + } + } + + let outIndex = targetBin * kNumBins + sourceBin; + jointHistogramSmoothed[outIndex] = accumulated; +} + + +// This entry point is dispatched with a single workgroup +// to compute the total mass of the joint histogram +static const uint32_t kTotalMassWorkgroupSize = 1024; +[shader("compute")] +[numthreads(kTotalMassWorkgroupSize, 1, 1)] +void computeTotalMass( + uint32_t3 localId: SV_GroupThreadID, + StructuredBuffer jointHistogramSmoothed, + RWStructuredBuffer jointHistogramMass +) +{ + static groupshared 
Array sharedSums; // one slot per thread + static const uint32_t totalBins = kNumBins * kNumBins; + + let localIndex = localId.x; + + // Each thread sums a strided range of bins + var threadSum = 0.0F; + for (var i = localIndex; i < totalBins; i += kTotalMassWorkgroupSize) { + threadSum += jointHistogramSmoothed[i]; + } + + sharedSums[localIndex] = threadSum; + GroupMemoryBarrierWithGroupSync(); + + let totalMass = workgroupReduce(sharedSums, localIndex); + + if (localIndex == 0u) { + jointHistogramMass[0] = totalMass; + } +} diff --git a/cpp/core/gpu/shaders/registration/moments.slang b/cpp/core/gpu/shaders/registration/moments.slang new file mode 100644 index 0000000000..8950693322 --- /dev/null +++ b/cpp/core/gpu/shaders/registration/moments.slang @@ -0,0 +1,59 @@ +import atomic_utils; + +extern static const uint32_t kWorkgroupSizeX = 8; +extern static const uint32_t kWorkgroupSizeY = 8; +extern static const uint32_t kWorkgroupSizeZ = 4; +static const uint32_t kMomentCount = 6; +// TODO: Use extern static const bool once the Slang compiler bug is fixed. 
+extern static const uint32_t kUseMask; + +struct MomentUniforms { + float4 centre; +}; + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void main( + uint32_t3 globalId : SV_DispatchThreadID, + RWStructuredBuffer> momentBuffer, + StructuredBuffer voxelToScanner, + ConstantBuffer centreScanner, + Texture3D image, + Texture3D mask) +{ + var imageDimensions : uint32_t3; + image.GetDimensions(imageDimensions.x, imageDimensions.y, imageDimensions.z); + + if (!all(globalId < imageDimensions)) { + return; + } + + if (kUseMask != 0U) { + let maskValue = mask.Load(int32_t4(globalId, 0)).r; + if (maskValue < 0.5F) { + return; + } + } + + let voxelValue = image.Load(int32_t4(globalId, 0)).r; + if (!isfinite(voxelValue)) { + return; + } + + let scannerPos4 = mul(voxelToScanner[0], float4(float3(globalId), 1.0F)); + let scannerPos = scannerPos4.xyz / scannerPos4.w; + let centered = scannerPos - centreScanner.centre.xyz; + + let contributions = Array( + centered.x * centered.x * voxelValue, + centered.y * centered.y * voxelValue, + centered.z * centered.z * voxelValue, + centered.x * centered.y * voxelValue, + centered.x * centered.z * voxelValue, + centered.y * centered.z * voxelValue + ); + + for (var i = 0U; i < kMomentCount; ++i) { + atomicAddF32InMemory(momentBuffer[i], contributions[i]); + } +} diff --git a/cpp/core/gpu/shaders/registration/ncc.slang b/cpp/core/gpu/shaders/registration/ncc.slang new file mode 100644 index 0000000000..54356cf005 --- /dev/null +++ b/cpp/core/gpu/shaders/registration/ncc.slang @@ -0,0 +1,336 @@ +// NCC = Normalized Cross Correlation +// LNCC = Local Normalized Cross Correlation + +// Definition (for a neighborhood N around voxel x): +// LNCC(x) = Cov(I_t, I_m) / sqrt[ Var_t * Var_m ] +// +// where: +// - I_t are target intensities in neighborhood N +// - I_m are moving image intensities in neighborhood N (after transformation) +// - mu_t, mu_m are the means of I_t and I_m in N +// - Cov(I_t, I_m) = 
(1/|N|) * Σ_p (I_t(p) - mu_t)(I_m(p) - mu_m) +// - Var_t = (1/|N|) * Σ_p (I_t(p) - mu_t)^2 +// - Var_m = (1/|N|) * Σ_p (I_m(p) - mu_m)^2 +// dLNCC/dalpha = (C' * Var_m - 0.5 * C * Var_m') / (Var_m * sqrt(Var_t * Var_m)) +// Global NCC: same equations, but the sums are accumulated over the whole valid domain. +// TODO: write derivation of gradients. + +import global_transformation; +import reduction_utils; +import texture_utils; +import coordinate_mapper; +import voxelscannermatrices; + +extern static const uint32_t kWorkgroupSizeX = 8; +extern static const uint32_t kWorkgroupSizeY = 4; +extern static const uint32_t kWorkgroupSizeZ = 4; +// TODO: Use extern static const bool once the Slang compiler bug is fixed. +extern static const uint32_t kUseSourceMask; +extern static const uint32_t kUseTargetMask; +extern static const uint32_t kComputeGradients; +extern static const uint32_t kWindowRadius = 0; +static const uint32_t kWorkgroupInvocations = kWorkgroupSizeX * kWorkgroupSizeY * kWorkgroupSizeZ; +// For numerical stability in variance and denominator calculations +static const float kEpsVar = 1e-6F; +static const float kEpsDenom = 1e-6F; + +struct LNCCParameters { + float cost; + Array gradients; +}; + +struct LNCCReductionOP : IReduceOp> { + static LNCCParameters identityElement() { + var params : LNCCParameters; + params.cost = 0.0F; + for (uint32_t i = 0U; i < N; ++i) { + params.gradients[i] = 0.0F; + } + return params; + } + + static LNCCParameters reduce(LNCCParameters a, LNCCParameters b) { + var params : LNCCParameters; + params.cost = a.cost + b.cost; + for (uint32_t i = 0U; i < N; ++i) { + params.gradients[i] = a.gradients[i] + b.gradients[i]; + } + return params; + } +}; + +struct GlobalNCCParameters { + float sumTarget; + float sumMoving; + float sumTargetSquared; + float sumMovingSquared; + float sumTargetMoving; + Array sumTargetMovingPrime; + Array sumMovingPrime; + Array sumMovingSquaredPrime; +}; + +struct GlobalNCCReductionOP : IReduceOp> { + 
static GlobalNCCParameters identityElement() { + var zeros : Array; + for (uint32_t i = 0U; i < N; ++i) { zeros[i] = 0.0F; } + return GlobalNCCParameters(0.0F, 0.0F, 0.0F, 0.0F, 0.0F, zeros, zeros, zeros); + } + + static GlobalNCCParameters reduce(GlobalNCCParameters a, GlobalNCCParameters b) { + var tmPrime : Array; + var mPrime : Array; + var mmPrime : Array; + for (uint32_t i = 0U; i < N; ++i) { + tmPrime[i] = a.sumTargetMovingPrime[i] + b.sumTargetMovingPrime[i]; + mPrime[i] = a.sumMovingPrime[i] + b.sumMovingPrime[i]; + mmPrime[i] = a.sumMovingSquaredPrime[i] + b.sumMovingSquaredPrime[i]; + } + return GlobalNCCParameters(a.sumTarget + b.sumTarget, + a.sumMoving + b.sumMoving, + a.sumTargetSquared + b.sumTargetSquared, + a.sumMovingSquared + b.sumMovingSquared, + a.sumTargetMoving + b.sumTargetMoving, + tmPrime, + mPrime, + mmPrime); + } +}; + +struct NCCUniforms where Transformation : ITransformation { + uint32_t3 dispatchGrid; + float3 transformationPivot; + Transformation.TParams currentTransform; + VoxelScannerMatrices voxelScannerMatrices; +}; + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void lncc_main( + uint32_t3 globalId : SV_DispatchThreadID, + uint32_t3 localId : SV_GroupThreadID, + uint32_t3 workgroupId : SV_GroupID, + ConstantBuffer> uniforms, + Texture3D sourceImage, + Texture3D targetImage, + Texture3D sourceMask, + Texture3D targetMask, + SamplerState linearSampler, + RWStructuredBuffer lnccPartials, + RWStructuredBuffer> numContributingVoxels) + where Transformation : ITransformation { + static const uint32_t paramsCount = Transformation.kParamCount; + typedef LNCCParameters CurrentParams; + static groupshared Array workgroupPartials; + static groupshared Atomic localValidCount; + + let localIndex = localId.x + localId.y * kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + if (localIndex == 0U) { localValidCount.store(0U); } + GroupMemoryBarrierWithGroupSync(); + + let radius = 
int32_t(kWindowRadius); + var threadParams = LNCCReductionOP::identityElement(); + if (radius > 0) { + let sourceDim = textureSize(sourceImage); + let targetDim = textureSize(targetImage); + let coordMapper = CoordinateMapper(sourceDim, targetDim, uniforms.voxelScannerMatrices); + let transformation = Transformation(uniforms.currentTransform, uniforms.transformationPivot); + let sourceField = VoxelSamplingField(sourceImage, linearSampler); + let targetField = VoxelSamplingField(targetImage, linearSampler); + let sourceMaskField = VoxelSamplingField(sourceMask, linearSampler); + let targetMaskField = VoxelSamplingField(targetMask, linearSampler); + + let center = int3(globalId); + if (coordMapper.inTargetInt(center) && targetMaskField.maskAccepts(float3(center), kUseTargetMask != 0U)) { + var sumTarget : float = 0.0F; + var sumMoving : float = 0.0F; + var sumTargetSquared : float = 0.0F; + var sumMovingSquared : float = 0.0F; + var sumTargetMoving : float = 0.0F; + // NOTE: We treat valid samples as fixed, but the true derivative should include + // boundary/mask discontinuities. 
+ var validSamples : float = 0.0F; + var sumTargetMovingPrime : Array; + var sumMovingPrime : Array; + var sumMovingSquaredPrime : Array; + for (uint32_t i = 0U; i < paramsCount; ++i) { + sumTargetMovingPrime[i] = 0.0F; + sumMovingPrime[i] = 0.0F; + sumMovingSquaredPrime[i] = 0.0F; + } + + for (int32_t dz = -radius; dz <= radius; ++dz) { + for (int32_t dy = -radius; dy <= radius; ++dy) { + for (int32_t dx = -radius; dx <= radius; ++dx) { + let neighbor = center + int3(dx, dy, dz); + if (!coordMapper.inTargetInt(neighbor)) { + continue; + } + let neighborFloat = float3(neighbor); + if (!targetMaskField.maskAccepts(neighborFloat, kUseTargetMask != 0U)) { + continue; + } + + let sourceCoord = coordMapper.mapTargetVoxelToSource(neighborFloat, transformation); + if (!coordMapper.inSource(sourceCoord)) { + continue; + } + if (!sourceMaskField.maskAccepts(sourceCoord, kUseSourceMask != 0U)) { + continue; + } + let targetVal = targetField.sample(neighborFloat); + let movingVal = sourceField.sample(sourceCoord); + + validSamples += 1.0F; + sumTarget += targetVal; + sumMoving += movingVal; + sumTargetSquared += targetVal * targetVal; + sumMovingSquared += movingVal * movingVal; + sumTargetMoving += targetVal * movingVal; + + if (kComputeGradients != 0U) { + let gradVoxel = sourceField.spatialGradient(sourceCoord); + let gradScanner = coordMapper.mapVoxelGradientToScanner(gradVoxel); + let scannerCoord = coordMapper.mapTargetVoxelToScanner(neighborFloat); + for (uint32_t p = 0U; p < paramsCount; ++p) { + let imPrime = dot(gradScanner, transformation.jacobianVector(p, scannerCoord)); + sumTargetMovingPrime[p] += targetVal * imPrime; + sumMovingPrime[p] += imPrime; + sumMovingSquaredPrime[p] += movingVal * imPrime; + } + } + } + } + } + + if (validSamples > 0.0F) { + let invN = 1.0F / validSamples; + let meanTarget = sumTarget * invN; + let meanMoving = sumMoving * invN; + let varianceTarget = max(0.0F, sumTargetSquared * invN - meanTarget * meanTarget); + let varianceMoving = 
max(0.0F, sumMovingSquared * invN - meanMoving * meanMoving); + if (varianceTarget > kEpsVar && varianceMoving > kEpsVar) { + let covariance = sumTargetMoving * invN - meanTarget * meanMoving; + let costDenom = sqrt(max(varianceTarget * varianceMoving, kEpsVar)); + threadParams.cost = -covariance / costDenom; + + if (kComputeGradients != 0U) { + let denomGradBase = max(varianceMoving * sqrt(max(varianceTarget * varianceMoving, kEpsVar)), kEpsDenom); + for (uint32_t p = 0U; p < paramsCount; ++p) { + let cPrime = (sumTargetMovingPrime[p] * invN) - (meanTarget * sumMovingPrime[p] * invN); + let varMovingPrime = 2.0F * (sumMovingSquaredPrime[p] * invN - meanMoving * sumMovingPrime[p] * invN); + let gradValue = (cPrime * varianceMoving - 0.5F * covariance * varMovingPrime) / denomGradBase; + threadParams.gradients[p] = -gradValue; + } + } + + localValidCount.increment(); + } + } + } + } + + workgroupPartials[localIndex] = threadParams; + GroupMemoryBarrierWithGroupSync(); + let reduced = workgroupReduce, kWorkgroupInvocations>(workgroupPartials, localIndex); + + if (localIndex == 0U) { + let wgIndex = workgroupId.x + + workgroupId.y * uniforms.dispatchGrid[0] + + workgroupId.z * uniforms.dispatchGrid[0] * uniforms.dispatchGrid[1]; + let base = wgIndex * (1 + paramsCount); + lnccPartials[base] = reduced.cost; + for (uint32_t i = 0U; i < paramsCount; ++i) { + lnccPartials[base + 1 + i] = reduced.gradients[i]; + } + numContributingVoxels[0].add(localValidCount.load()); + } +} + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void global_ncc_main( + uint32_t3 globalId : SV_DispatchThreadID, + uint32_t3 localId : SV_GroupThreadID, + uint32_t3 workgroupId : SV_GroupID, + ConstantBuffer> uniforms, + Texture3D sourceImage, + Texture3D targetImage, + Texture3D sourceMask, + Texture3D targetMask, + SamplerState linearSampler, + RWStructuredBuffer globalPartials, + RWStructuredBuffer> numContributingVoxels) + where Transformation : 
ITransformation { + static const uint32_t paramsCount = Transformation.kParamCount; + typedef GlobalNCCParameters CurrentParams; + static groupshared Array workgroupPartials; + static groupshared Atomic localValidCount; + + let localIndex = localId.x + localId.y * kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + if (localIndex == 0U) { localValidCount.store(0U); } + GroupMemoryBarrierWithGroupSync(); + + let sourceDim = textureSize(sourceImage); + let targetDim = textureSize(targetImage); + let coordMapper = CoordinateMapper(sourceDim, targetDim, uniforms.voxelScannerMatrices); + let transformation = Transformation(uniforms.currentTransform, uniforms.transformationPivot); + let sourceField = VoxelSamplingField(sourceImage, linearSampler); + let targetField = VoxelSamplingField(targetImage, linearSampler); + let sourceMaskField = VoxelSamplingField(sourceMask, linearSampler); + let targetMaskField = VoxelSamplingField(targetMask, linearSampler); + + CurrentParams threadParams = GlobalNCCReductionOP::identityElement(); + + if (coordMapper.inTargetInt(int3(globalId)) && targetMaskField.maskAccepts(float3(globalId), kUseTargetMask != 0U)) { + let sourceCoord = coordMapper.mapTargetVoxelToSource(float3(globalId), transformation); + if (coordMapper.inSource(sourceCoord) && sourceMaskField.maskAccepts(sourceCoord, kUseSourceMask != 0U)) { + let targetVal = targetField.sample(float3(globalId)); + let movingVal = sourceField.sample(sourceCoord); + threadParams.sumTarget = targetVal; + threadParams.sumMoving = movingVal; + threadParams.sumTargetSquared = targetVal * targetVal; + threadParams.sumMovingSquared = movingVal * movingVal; + threadParams.sumTargetMoving = targetVal * movingVal; + + if (kComputeGradients != 0U) { + let gradVoxel = sourceField.spatialGradient(sourceCoord); + let gradScanner = coordMapper.mapVoxelGradientToScanner(gradVoxel); + let scannerCoord = coordMapper.mapTargetVoxelToScanner(float3(globalId)); + for (uint32_t p = 0U; p < 
paramsCount; ++p) { + let imPrime = dot(gradScanner, transformation.jacobianVector(p, scannerCoord)); + threadParams.sumTargetMovingPrime[p] = targetVal * imPrime; + threadParams.sumMovingPrime[p] = imPrime; + threadParams.sumMovingSquaredPrime[p] = movingVal * imPrime; + } + } + + localValidCount.increment(); + } + } + + workgroupPartials[localIndex] = threadParams; + GroupMemoryBarrierWithGroupSync(); + let reduced = workgroupReduce, kWorkgroupInvocations>(workgroupPartials, localIndex); + + if (localIndex == 0U) { + let wgIndex = workgroupId.x + + workgroupId.y * uniforms.dispatchGrid[0] + + workgroupId.z * uniforms.dispatchGrid[0] * uniforms.dispatchGrid[1]; + let base = wgIndex * (5 + 3 * paramsCount); + globalPartials[base + 0] = reduced.sumTarget; + globalPartials[base + 1] = reduced.sumMoving; + globalPartials[base + 2] = reduced.sumTargetSquared; + globalPartials[base + 3] = reduced.sumMovingSquared; + globalPartials[base + 4] = reduced.sumTargetMoving; + + var offset : uint32_t = 5U; + for (uint32_t i = 0U; i < paramsCount; ++i) { globalPartials[base + offset + i] = reduced.sumTargetMovingPrime[i]; } + offset += paramsCount; + for (uint32_t i = 0U; i < paramsCount; ++i) { globalPartials[base + offset + i] = reduced.sumMovingPrime[i]; } + offset += paramsCount; + for (uint32_t i = 0U; i < paramsCount; ++i) { globalPartials[base + offset + i] = reduced.sumMovingSquaredPrime[i]; } + + numContributingVoxels[0].add(localValidCount.load()); + } +} diff --git a/cpp/core/gpu/shaders/registration/nmi.slang b/cpp/core/gpu/shaders/registration/nmi.slang new file mode 100644 index 0000000000..d8e19c6c40 --- /dev/null +++ b/cpp/core/gpu/shaders/registration/nmi.slang @@ -0,0 +1,375 @@ +// Normalised Mutual Information (NMI) = (H_t + H_m) / H_j +// where +// - H_j = -Σ_ij p_ij * log(p_ij) is the joint entropy. +// - H_t = -Σ_i p_i * log(p_i) is the target marginal entropy. +// - H_m = -Σ_j p_j * log(p_j) is the moving marginal entropy. 
+//
+// Joint / marginal probabilities and normalisation:
+// - H_ij = Σ_v h_ij(v) is joint histogram counts
+// - Z = Σ_ab H_{ab} is total mass
+// - p_ij = H_ij / Z
+// - p_i = Σ_j p_ij, p_j = Σ_i p_ij
+// where v is the voxel index.
+//
+// Using the quotient rule for NMI = (H_t + H_m) / H_j, after some tedious algebra
+// (using dp_i/dq = Σ_j dp_ij/dq and dp_j/dq = Σ_i dp_ij/dq), you can show that:
+// dNMI/dq = 1/(H_j*Z) Σ_ij (dH_ij/dq * [NMI * log p_ij - log p_i - log p_j])
+// The fact that the term in square brackets is independent of the voxel index v is used in this
+// shader to precompute it once per-bin to avoid redundant calculations.
+//
+// To calculate dH_ij/dq, we first note the parzen-window per-voxel contribution:
+// h_ij(v) = B(b_t(v) - i) * B(b_m(v) - j)
+// where
+// - b_t(v) maps the target intensity at voxel v into bin-space (target bins indexed by i)
+// - b_m(v) maps the (transformed) moving intensity at voxel v into bin-space (moving/source bins indexed by j)
+// - B(x) is the cubic B-spline kernel and B'(x) = dB/dx.
+//
+// Only the moving image depends on the transform parameters q. Therefore:
+// dH_ij/dq = Σ_v d h_ij(v)/dq
+// and
+// d h_ij(v)/dq = B( b_t(v) - i ) * B'( b_m(v) - j ) * d b_m(v)/dq
+// Since b_m depends linearly on the moving image intensity I_m(x'):
+// d b_m / dq = (d b_m / d I_m) * d I_m(x')/dq
+// and by the chain rule
+// d I_m(x')/dq = grad I_m(x') * d x'/dq
+// where grad I_m is the image gradient in scanner coordinates at x'.
+
+import coordinate_mapper;
+import cubic_bspline;
+import global_transformation;
+import reduction_utils;
+import texture_utils;
+import parzen_binner;
+import voxelscannermatrices;
+
+extern static const int32_t kNumBins;
+// TODO: Use extern static const bool once the Slang compiler bug is fixed.
+extern static const uint32_t kUseTargetMask; +extern static const uint32_t kUseSourceMask; + +static const float divisionEpsilon = 1e-7F; // To avoid divisions by zero +static const float logEpsilon = 1e-9F; // To avoid log(0) + +static const uint32_t kPrecomputeWorkgroupSizeX = 16; +static const uint32_t kPrecomputeWorkgroupSizeY = 16; +static const uint32_t kPrecomputeWorkgroupInvocations = kPrecomputeWorkgroupSizeX * kPrecomputeWorkgroupSizeY; + +// Computes marginals p_i, p_j, entropies H_t, H_m, H_j, the scalar NMI +// value, and a per-bin coefficient table coefficientsTable[i*kNumBins + j] = dNMI/dp_ij. +// NOTE: this entry point MUST be dispatched with a single workgroup. +[shader("compute")] +[numthreads(kPrecomputeWorkgroupSizeX, kPrecomputeWorkgroupSizeY, 1)] +void precompute( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + // Raw (unnormalised) joint histogram H and scalar total mass Z + StructuredBuffer jointHistogram, + StructuredBuffer jointHistogramMass, + // coefficientsTable[i*kNumBins + j] = ( nmi*log p_ij - log p_i - log p_j ) / (H_j * Z) + RWStructuredBuffer coefficientsTable, + RWStructuredBuffer mutualInformation) +{ + static groupshared Array pTarget; + static groupshared Array pMoving; + static groupshared Array logP_row; // log p_i (row / target marginal) + static groupshared Array logP_col; // log p_j (column / moving marginal) + static groupshared Array workgroupEntropies; // H_j, H_t, H_m + + let localIndex = localId.y * kPrecomputeWorkgroupSizeX + localId.x; + // We iterate over the histogram in a strided manner so that each thread + // is responsible for accumulating a single row/column of the joint histogram. + // Each thread will process kNumBins/kPrecomputeWorkgroupSizeY (for p_i) rows or + // kNumBins/kPrecomputeWorkgroupSizeX (for p_j) columns. + // p_i = Σ_j p_ij + // Each thread sums its own row of the joint histogram. 
+ if (localId.x == 0) { + for (uint32_t i = localId.y; i < kNumBins; i += kPrecomputeWorkgroupSizeY) { + // guard in case kNumBins < workgroup size or not divisible by stride + if (i >= kNumBins) break; + var rowSum = 0.0F; + for (uint32_t j = 0; j < kNumBins; ++j) { + rowSum += jointHistogram[i * kNumBins + j]; + } + pTarget[i] = rowSum; // H_i + } + } + + // p_j = Σ_i p_ij + // Each thread sums its own column of the joint histogram. + if(localId.y == 0) { + for(uint32_t j = localId.x; j < kNumBins; j += kPrecomputeWorkgroupSizeX) { + // guard in case kNumBins < workgroup size or not divisible by stride + if (j >= kNumBins) break; + var colSum = 0.0F; + for(uint32_t i = 0; i < kNumBins; ++i) { + colSum += jointHistogram[i*kNumBins + j]; + } + pMoving[j] = colSum; // H_j + } + } + + GroupMemoryBarrierWithGroupSync(); + + + var entropiesVec = float3(0.0F, 0.0F, 0.0F); // H_j, H_t, H_m + var localHJ = 0.0F; + var localHT = 0.0F; + var localHM = 0.0F; + let Z = max(jointHistogramMass[0], 0.0F); + + // Accumulate joint entropy H_j and H_t contributions + for (uint32_t i = localId.y; i < kNumBins; i += kPrecomputeWorkgroupSizeY) { + if (i >= kNumBins) break; + if (localId.x == 0) { + let Hi = max(pTarget[i], 0.0F); + let pi = Hi / max(Z, divisionEpsilon); + entropiesVec.y += -pi * log(max(pi, logEpsilon)); + } + for(uint32_t j = localId.x; j < kNumBins; j += kPrecomputeWorkgroupSizeX) { + if (j >= kNumBins) break; + let index = i * kNumBins + j; + let Hij = max(jointHistogram[index], 0.0F); + let p_ij = Hij / max(Z, divisionEpsilon); + entropiesVec.x += -p_ij * log(max(p_ij, logEpsilon)); + } + } + + // Accumulate H_m contributions (strided over columns) + if (localId.y == 0) { + for (uint32_t j = localId.x; j < kNumBins; j += kPrecomputeWorkgroupSizeX) { + if (j >= kNumBins) break; + let Hj = max(pMoving[j], 0.0F); + let pj = Hj / max(Z, divisionEpsilon); + entropiesVec.z += -pj * log(max(pj, logEpsilon)); + } + } + + workgroupEntropies[localIndex] = entropiesVec; + // 
Reduce H_j, H_t, H_m across the workgroup + GroupMemoryBarrierWithGroupSync(); + + let entropies = workgroupReduce, kPrecomputeWorkgroupInvocations>(workgroupEntropies, localIndex); + let H_j = entropies.x; + let H_t = entropies.y; + let H_m = entropies.z; + + if(localId.x == 0 && localId.y == 0) { + let denom = max(H_j, divisionEpsilon); + mutualInformation[0] = (H_t + H_m) / denom; + } + + GroupMemoryBarrierWithGroupSync(); + + // Compute per-bin coefficient table coefficientsTable[index] = (nmi*log p_ij - log p_i - log p_j) / (H_j * Z) + let nmi = (H_t + H_m) / max(H_j, divisionEpsilon); + let Zsafe = max(Z, divisionEpsilon); + let denom = H_j * Zsafe; + + var invNorm = denom > divisionEpsilon ? 1.0F / denom : 0.0F; + + // Precompute log p_i and log p_j into shared memory to avoid recomputing inside inner loop + // Each thread along the x-axis of the workgroup computes log p_j + // Each thread along the y-axis of the workgroup computes log p_i + + if(localId.x == 0) { + for (uint32_t i = localId.y; i < kNumBins; i += kPrecomputeWorkgroupSizeY) { + if (i >= kNumBins) break; + let pi = max(pTarget[i] / Zsafe, 0.0F); + logP_row[i] = log(max(pi, logEpsilon)); + } + } + + if(localId.y == 0) { + for (uint32_t j = localId.x; j < kNumBins; j += kPrecomputeWorkgroupSizeX) { + if (j >= kNumBins) break; + let pj = max(pMoving[j] / Zsafe, 0.0F); + logP_col[j] = log(max(pj, logEpsilon)); + } + } + + GroupMemoryBarrierWithGroupSync(); + + for(uint32_t i = localId.y; i < kNumBins; i += kPrecomputeWorkgroupSizeY) { + if (i >= kNumBins) break; + for(uint32_t j = localId.x; j < kNumBins; j += kPrecomputeWorkgroupSizeX) { + if (j >= kNumBins) break; + let index = i * kNumBins + j; + let Hij = max(jointHistogram[index], 0.0F); + let p_ij = max(Hij / Zsafe, 0.0F); + let logHij = log(max(p_ij, logEpsilon)); + let val = nmi * logHij - logP_row[i] - logP_col[j]; + coefficientsTable[index] = invNorm * val; + } + } +} + +extern static const uint32_t kWorkgroupSizeX = 8; +extern static 
const uint32_t kWorkgroupSizeY = 8; +extern static const uint32_t kWorkgroupSizeZ = 8; +static const uint32_t kWorkgroupInvocations = kWorkgroupSizeX * kWorkgroupSizeY * kWorkgroupSizeZ; +static const uint32_t kMinSubgroupSize = 4; + +typealias MIGradients = Array; + +struct Intensities +{ + float sourceMin; + float sourceMax; + float targetMin; + float targetMax; +} + +struct Uniforms where Transformation : ITransformation{ + uint32_t3 dispatchGrid; + float3 transformationPivot; + Intensities intensities; + Transformation.TParams currentTransform; + VoxelScannerMatrices voxelScannerMatrices; +}; + + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void main( + uint32_t3 globalId: SV_DispatchThreadID, + uint32_t3 localId: SV_GroupThreadID, + uint32_t3 workgroupId: SV_GroupID, + ConstantBuffer> uniforms, + Texture3D sourceTexture, + Texture3D targetTexture, + Texture3D sourceMaskTexture, + Texture3D targetMaskTexture, + SamplerState sampler, + StructuredBuffer coefficientsTable, + RWStructuredBuffer> partialSumsGradients) + where Transformation : ITransformation +{ + typealias Gradients = MIGradients; + // Wave-partial gradient sums: one Gradients vector per possible wave in the WG. 
+ static const uint32_t kWaveMax = kWorkgroupInvocations / kMinSubgroupSize; + static groupshared Array wavePartials; + typealias Binner = ParzenBinner; + + let sourceDimensions = textureSize(sourceTexture); + let targetDimensions = textureSize(targetTexture); + let coordMapper = CoordinateMapper( + uint3(sourceDimensions), uint3(targetDimensions), uniforms.voxelScannerMatrices + ); + + let transformation = Transformation(uniforms.currentTransform, uniforms.transformationPivot); + let sourceTextureField = VoxelSamplingField(sourceTexture, sampler); + let sourceMaskField = VoxelSamplingField(sourceMaskTexture, sampler); + let targetMaskField = VoxelSamplingField(targetMaskTexture, sampler); + let effectiveRangeBins = Binner.effectiveRangeBins(); + + let localIndex = localId.x + localId.y * kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + + // Thread-local gradient accumulator kept in registers; must be defined for all threads + var gradients: Gradients; + for (uint32_t q = 0U; q < Transformation.kParamCount; ++q) { gradients[q] = 0.0F; } + + bool includeVoxel = coordMapper.inTargetInt(int3(globalId)); + float3 targetVoxelCoord = float3(globalId); + float3 movingVoxelCoord = 0.0F; + + if (includeVoxel && targetMaskField.maskAccepts(targetVoxelCoord, kUseTargetMask != 0U)) { + movingVoxelCoord = coordMapper.mapTargetVoxelToSource(targetVoxelCoord, transformation); + if (!coordMapper.inSource(movingVoxelCoord) || !sourceMaskField.maskAccepts(movingVoxelCoord, kUseSourceMask != 0U)) { + includeVoxel = false; + } + } else { + includeVoxel = false; + } + + if (includeVoxel) { + let targetIntensity = targetTexture.Load(int4(globalId, 0)).r; + let gradMovingVoxel = sourceTextureField.spatialGradient(movingVoxelCoord); + let gradMovingScanner = coordMapper.mapVoxelGradientToScanner(gradMovingVoxel); + let targetScannerCoord = coordMapper.mapTargetVoxelToScanner(targetVoxelCoord); + + let targetMin = uniforms.intensities.targetMin; + let targetMax = 
uniforms.intensities.targetMax; + let sourceMin = uniforms.intensities.sourceMin; + let sourceMax = uniforms.intensities.sourceMax; + + let movingIntensity = sourceTextureField.sample(movingVoxelCoord); + + let targetBin = Binner.mapIntensityToBin(targetIntensity, targetMin, targetMax); + let sourceBin = Binner.mapIntensityToBin(movingIntensity, sourceMin, sourceMax); + + uint32_t targetBinStart; uint32_t targetBinEnd; + Binner.computeBinNeighbourhood(targetBin, targetBinStart, targetBinEnd); + uint32_t sourceBinStart; uint32_t sourceBinEnd; + Binner.computeBinNeighbourhood(sourceBin, sourceBinStart, sourceBinEnd); + + // Precompute B-spline weights and derivatives for available bins + var targetWeights: float4; + var sourceDerivWeights: float4; + var targetCount: uint32_t; + var sourceCount: uint32_t; + + Binner.computeWeights(targetBin, targetBinStart, targetBinEnd, targetWeights, targetCount); + Binner.computeDerivatives(sourceBin, sourceBinStart, sourceBinEnd, sourceDerivWeights, sourceCount); + + // Derivative: dNMI/dq = Σ_ij C_ij * B_t(i) * B'_s(j) * (d b_m / d I_m) * (d I_m / d q) + // Since b_m(I_m) = padding + effectiveRangeBins * (I_m - sourceMin) / (sourceMax - sourceMin) + // d b_m / d I_m = effectiveRangeBins / (sourceMax - sourceMin) + let dBin_dIm = effectiveRangeBins / max(sourceMax - sourceMin, divisionEpsilon); + var accumulatedContribution = 0.0F; // C_ij * B_t(i) * B'_s(j) + + for (uint32_t tIdx = 0U; tIdx < targetCount; ++tIdx) { + let rowIndex = (targetBinStart + tIdx) * kNumBins; + let targetWeight = targetWeights[tIdx]; + let baseIndex = rowIndex + sourceBinStart; + + var coefficientsVector : float4 = 0.0F; + [unroll] + for (uint32_t off = 0U; off < 4U; ++off) { + uint32_t col = sourceBinStart + off; + if (col < kNumBins) { + coefficientsVector[off] = coefficientsTable[rowIndex + col]; + } + } + + accumulatedContribution += targetWeight * dot(coefficientsVector, sourceDerivWeights); + } + + let factor : float = accumulatedContribution * 
dBin_dIm; + // Early-out if no contribution + if (abs(factor) > 0.0F) { + // Compute per-parameter image intensity Jacobian dIm/dq: reuse jacobian vectors + // to avoid recomputing inside both gradient and Hessian paths. + var dImdq: Gradients; + [ForceUnroll] + for (uint32_t q = 0U; q < Transformation.kParamCount; ++q) { + let jvec = transformation.jacobianVector(q, targetScannerCoord); + dImdq[q] = dot(gradMovingScanner, jvec); + } + + for (uint32_t q = 0U; q < Transformation.kParamCount; ++q) { gradients[q] = factor * dImdq[q]; } + } + } + + // Wave-level reduction for gradients + let waveSize = WaveGetLaneCount(); + let lane = WaveGetLaneIndex(); + [unroll] + for (uint32_t q = 0U; q < Transformation.kParamCount; ++q) { + let sum = WaveActiveSum(gradients[q]); + if (lane == 0) { + let waveIndex = localIndex / waveSize; + if (waveIndex < kWaveMax) { wavePartials[waveIndex][q] += sum; } + } + } + + GroupMemoryBarrierWithGroupSync(); + + let workgroupGradients = workgroupReduce, kWaveMax>(wavePartials, localIndex); + + if (localIndex == 0U) { + let wgIndex = workgroupId.x + + workgroupId.y * uniforms.dispatchGrid[0] + + workgroupId.z * uniforms.dispatchGrid[0] * uniforms.dispatchGrid[1]; + partialSumsGradients[wgIndex] = workgroupGradients; + } +} diff --git a/cpp/core/gpu/shaders/registration/ssd.slang b/cpp/core/gpu/shaders/registration/ssd.slang new file mode 100644 index 0000000000..afbe5698c6 --- /dev/null +++ b/cpp/core/gpu/shaders/registration/ssd.slang @@ -0,0 +1,137 @@ +// TODO: Write math and derivation of gradients +import global_transformation; +import reduction_utils; +import texture_utils; +import coordinate_mapper; +import voxelscannermatrices; + +struct SSDParameters +{ + float ssd; + Array gradients; +}; + +struct SSDReductionOP : IReduceOp> +{ + static SSDParameters identityElement() { + Array gradients; + for (uint32_t i = 0U; i < N; ++i) + { + gradients[i] = 0.0F; + } + return SSDParameters(0.0F, gradients); + } + + static SSDParameters 
reduce(SSDParameters a, SSDParameters b) { + let ssd = a.ssd + b.ssd; + Array gradients; + for (uint32_t i = 0U; i < N; ++i) + { + gradients[i] = a.gradients[i] + b.gradients[i]; + } + return SSDParameters(ssd, gradients); + } +}; + + +extern static const uint32_t kWorkgroupSizeX = 8; +extern static const uint32_t kWorkgroupSizeY = 8; +extern static const uint32_t kWorkgroupSizeZ = 4; +// TODO: Use extern static const bool once the Slang compiler bug is fixed. +extern static const uint32_t kUseSourceMask; +extern static const uint32_t kUseTargetMask; + +extern static const uint32_t kComputeGradients; +static const uint32_t workgroupInvocations = kWorkgroupSizeX * kWorkgroupSizeY * kWorkgroupSizeZ; + + +struct Uniforms where Transformation : ITransformation{ + uint3 dispatchGrid; + float3 transformationPivot; + Transformation.TParams currentTransform; + VoxelScannerMatrices voxelScannerMatrices; +}; + +[shader("compute")] +[numthreads(kWorkgroupSizeX, kWorkgroupSizeY, kWorkgroupSizeZ)] +void main( + uint3 globalId: SV_DispatchThreadID, + uint3 localId: SV_GroupThreadID, + uint3 workgroupId: SV_GroupID, + ConstantBuffer> uniforms, + Texture3D sourceImage, + Texture3D targetImage, + Texture3D sourceMask, + Texture3D targetMask, + SamplerState linearSampler, + RWStructuredBuffer> ssdAndGradientsPartials, + RWStructuredBuffer> numContributingVoxels +) where Transformation : ITransformation +{ + static const uint32_t paramsCount = Transformation.kParamCount; + typedef SSDParameters CurrentSSDParameters; + static groupshared Array localParameters; + static groupshared Atomic localNumContributingVoxels; + + var ssdParams = CurrentSSDParameters(0.0F, Array()); + let localIndex = localId.x + localId.y * kWorkgroupSizeX + localId.z * kWorkgroupSizeX * kWorkgroupSizeY; + let sourceDim = textureSize(sourceImage); + let targetDim = textureSize(targetImage); + + let coordMapper = CoordinateMapper( + sourceDim, targetDim, uniforms.voxelScannerMatrices + ); + + let transformation = 
Transformation(uniforms.currentTransform, uniforms.transformationPivot); + let sourceTextureField = VoxelSamplingField(sourceImage, linearSampler); + let sourceMaskField = VoxelSamplingField(sourceMask, linearSampler); + let targetMaskField = VoxelSamplingField(targetMask, linearSampler); + + if (localIndex == 0U) { + localNumContributingVoxels.store(0U); + } + GroupMemoryBarrierWithGroupSync(); + + bool includeVoxel = coordMapper.inTargetInt(int3(globalId)); + let targetVoxelCoord = float3(globalId); + + if (includeVoxel && targetMaskField.maskAccepts(targetVoxelCoord, kUseTargetMask != 0U)) { + let movingVoxelCoord = coordMapper.mapTargetVoxelToSource(targetVoxelCoord, transformation); + if (coordMapper.inSource(movingVoxelCoord) && sourceMaskField.maskAccepts(movingVoxelCoord, kUseSourceMask != 0U)) { + let targetIntensity = targetImage.Load(int4(globalId, 0)).r; + let movingIntensity = sourceTextureField.sample(movingVoxelCoord); + let error = movingIntensity - targetIntensity; + let gradMovingVoxel = sourceTextureField.spatialGradient(movingVoxelCoord); + let gradMovingScanner = coordMapper.mapVoxelGradientToScanner(gradMovingVoxel); + let targetScannerCoord = coordMapper.mapTargetVoxelToScanner(targetVoxelCoord); + + var voxelGradients : Array; + if (kComputeGradients != 0U) { + for (uint32_t i = 0U; i < paramsCount; ++i) { + voxelGradients[i] = error * dot(gradMovingScanner, transformation.jacobianVector(i, targetScannerCoord)); + } + } + else { + for (uint32_t i = 0U; i < paramsCount; ++i) { + voxelGradients[i] = 0.0F; + } + } + ssdParams = CurrentSSDParameters(0.5F * error * error, voxelGradients); + localNumContributingVoxels.increment(); + } + } + localParameters[localIndex] = ssdParams; + + GroupMemoryBarrierWithGroupSync(); + + let finalValue = workgroupReduce< CurrentSSDParameters, SSDReductionOP, workgroupInvocations>( + localParameters, localIndex); + + if (localIndex == 0U) { + let wgIndex = workgroupId.x + + workgroupId.y * 
uniforms.dispatchGrid[0] +
+                  workgroupId.z * uniforms.dispatchGrid[0] * uniforms.dispatchGrid[1];
+    ssdAndGradientsPartials[wgIndex] = finalValue;
+    numContributingVoxels[0].add(localNumContributingVoxels.load());
+  }
+}
diff --git a/cpp/core/gpu/shaders/registration/voxelscannermatrices.slang b/cpp/core/gpu/shaders/registration/voxelscannermatrices.slang
new file mode 100644
index 0000000000..1770cd6c4b
--- /dev/null
+++ b/cpp/core/gpu/shaders/registration/voxelscannermatrices.slang
@@ -0,0 +1,7 @@
+struct VoxelScannerMatrices
+{
+  float4x4 sourceVoxelToScanner;
+  float4x4 targetVoxelToScanner;
+  float4x4 sourceScannerToVoxel;
+  float4x4 targetScannerToVoxel;
+}
diff --git a/cpp/core/gpu/shaders/texture_utils.slang b/cpp/core/gpu/shaders/texture_utils.slang
new file mode 100644
index 0000000000..7926458c9e
--- /dev/null
+++ b/cpp/core/gpu/shaders/texture_utils.slang
@@ -0,0 +1,135 @@
+module texture_utils;
+
+public uint3 textureSize<T>(Texture3D<T> tex) where T : ITexelElement {
+  uint3 size;
+  tex.GetDimensions(size.x, size.y, size.z);
+  return size;
+}
+
+public uint3 textureSize<T>(RWTexture3D<T> tex) where T : ITexelElement {
+  uint3 size;
+  tex.GetDimensions(size.x, size.y, size.z);
+  return size;
+}
+
+public uint3 textureSize<T>(WTexture3D<T> tex) where T : ITexelElement {
+  uint3 size;
+  tex.GetDimensions(size.x, size.y, size.z);
+  return size;
+}
+
+
+// Defines different coordinate systems for 3D volume sampling.
+// The underlying texture sampler always expects normalized coordinates [0,1],
+// so each coordinate system applies the appropriate transformation before sampling.
+public enum SamplingCoordinateSystem {
+  // Voxel-centered coordinate system.
+  // In this system, voxel centers are located at integer positions (i, j, k).
+  // When sampling, coordinates are shifted by +0.5 before normalization.
+  // GPU samplers expect coordinates where the center of a texel
+  // is at (i+0.5, j+0.5), so adding 0.5 aligns it correctly.
+  Voxel,
+  // Texel-centered coordinate system.
+  // In this system, voxel corners are at integer positions, meaning voxel centers
+  // are at half-integer positions (i+0.5, j+0.5, k+0.5).
+  // Coordinates are used directly without shifting during sampling.
+  Texel
+}
+
+public typealias TexelSamplingField = VolumeSamplingField<SamplingCoordinateSystem.Texel>;
+public typealias VoxelSamplingField = VolumeSamplingField<SamplingCoordinateSystem.Voxel>;
+
+
+// A convenient class for sampling and gradient computation on 3D images.
+public struct VolumeSamplingField<let coordSys : SamplingCoordinateSystem> {
+  Texture3D<float> image;
+  uint3 dimensions;
+  SamplerState sampler;
+
+  public __init(Texture3D<float> img, SamplerState smpl) {
+    image = img;
+    sampler = smpl;
+    image.GetDimensions(dimensions.x, dimensions.y, dimensions.z);
+  }
+
+  public float sample(float3 coord) {
+    if (coordSys == SamplingCoordinateSystem.Voxel) {
+      return image.SampleLevel(sampler, (coord + 0.5F) / float3(dimensions), 0.0F);
+    } else if (coordSys == SamplingCoordinateSystem.Texel) {
+      return image.SampleLevel(sampler, coord / float3(dimensions), 0.0F);
+    }
+    return 0.0F;
+  }
+
+  public bool maskAccepts(float3 coord, bool useMask) {
+    if (!useMask) { return true; }
+    return sample(coord) >= 0.5F;
+  }
+
+  public float3 spatialGradient(float3 voxelCoord) {
+    let offsetX = float3(1.0F, 0.0F, 0.0F);
+    let offsetY = float3(0.0F, 1.0F, 0.0F);
+    let offsetZ = float3(0.0F, 0.0F, 1.0F);
+
+    let gradX = (sample(voxelCoord + offsetX) - sample(voxelCoord - offsetX)) * 0.5f;
+    let gradY = (sample(voxelCoord + offsetY) - sample(voxelCoord - offsetY)) * 0.5f;
+    let gradZ = (sample(voxelCoord + offsetZ) - sample(voxelCoord - offsetZ)) * 0.5f;
+
+    return float3(gradX, gradY, gradZ);
+  }
+
+  public float3 spatialGradientTrilinearAnalytic(float3 coord) {
+    float3 localCoord = coord;
+    if (coordSys == SamplingCoordinateSystem.Texel) {
+      localCoord = coord - 0.5F;
+    }
+
+    // Clamp localCoord so i0 >= 0 and i0+1 stays in-bounds
+    int3 i0 = int3(floor(localCoord));
+    i0 = max(i0, int3(0, 0, 0));
+    let i1 : int3 = min(i0 + 1,
int3(dimensions) - 1); + let t : float3 = localCoord - float3(i0); + + // Fetch the 8 cell corners + let v000 = image.Load(int4(i0, 0)).r; + let v100 = image.Load(int4(int3(i1.x, i0.y, i0.z),0)).r; + let v010 = image.Load(int4(int3(i0.x, i1.y, i0.z),0)).r; + let v110 = image.Load(int4(int3(i1.x, i1.y, i0.z),0)).r; + let v001 = image.Load(int4(int3(i0.x, i0.y, i1.z),0)).r; + let v101 = image.Load(int4(int3(i1.x, i0.y, i1.z),0)).r; + let v011 = image.Load(int4(int3(i0.x, i1.y, i1.z),0)).r; + let v111 = image.Load(int4(int3(i1.x, i1.y, i1.z),0)).r; + + // Analytic partials of the trilinear form + let ddx0 = lerp(v100 - v000, v110 - v010, t.y); + let ddx1 = lerp(v101 - v001, v111 - v011, t.y); + let ddx = lerp(ddx0, ddx1, t.z); + + let ddy0 = lerp(v010 - v000, v110 - v100, t.x); + let ddy1 = lerp(v011 - v001, v111 - v101, t.x); + let ddy = lerp(ddy0, ddy1, t.z); + + let ddz0 = lerp(v001 - v000, v101 - v100, t.x); + let ddz1 = lerp(v011 - v010, v111 - v110, t.x); + let ddz = lerp(ddz0, ddz1, t.y); + + return float3(ddx, ddy, ddz); + } +} + + +public struct GradientSamplingField { + Texture3D gradientImage; + uint3 dimensions; + SamplerState sampler; + + public __init(Texture3D img, SamplerState smpl) { + gradientImage = img; + sampler = smpl; + gradientImage.GetDimensions(dimensions.x, dimensions.y, dimensions.z); + } + + public float3 sample(float3 coord) { + return gradientImage.SampleLevel(sampler, coord / float3(dimensions), 0.0F).rgb; + } +} From 022149fc039fe751fa95d03baa1a92d433870354 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Wed, 14 Jan 2026 08:48:50 +0000 Subject: [PATCH 03/12] Symlink shader code + hack to find registration shaders --- cpp/core/CMakeLists.txt | 14 ++++++++++++++ cpp/core/gpu/gpu.cpp | 8 ++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt index 1c776027fa..a9b9b2efe9 100644 --- a/cpp/core/CMakeLists.txt +++ b/cpp/core/CMakeLists.txt @@ -142,6 +142,20 @@ if(WIN32) 
endif() +# TODO: Add install rules for shaders +set(GPU_SHADER_SRC "${CMAKE_CURRENT_SOURCE_DIR}/gpu/shaders") +set(SHADER_BIN_DEST "${PROJECT_BINARY_DIR}/bin/shaders") + +# Custom target to symlink shaders to the build/bin/shaders directory +add_custom_target(copy-gpu-shaders ALL + COMMAND ${CMAKE_COMMAND} -E create_symlink ${GPU_SHADER_SRC} ${SHADER_BIN_DEST} + COMMENT "Symlinking GPU shaders to build directory" + VERBATIM +) + +# Ensure the core library depends on this target +add_dependencies(mrtrix-core copy-gpu-shaders) + install(TARGETS mrtrix-core RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/cpp/core/gpu/gpu.cpp b/cpp/core/gpu/gpu.cpp index fa52ff17c5..ef7a8071bd 100644 --- a/cpp/core/gpu/gpu.cpp +++ b/cpp/core/gpu/gpu.cpp @@ -246,7 +246,11 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique search_paths = {executable_dir_cstr, registration_dir_cstr}; std::vector slang_compiler_options; { @@ -261,8 +265,8 @@ ComputeContext::ComputeContext() : m_slang_session_info(std::make_unique Date: Wed, 14 Jan 2026 08:49:29 +0000 Subject: [PATCH 04/12] Use exec path for finding shaders --- cpp/core/gpu/slangcodegen.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/cpp/core/gpu/slangcodegen.cpp b/cpp/core/gpu/slangcodegen.cpp index 5680f4c92e..4e83b59069 100644 --- a/cpp/core/gpu/slangcodegen.cpp +++ b/cpp/core/gpu/slangcodegen.cpp @@ -20,6 +20,7 @@ #include "gpu/gpu.h" #include "match_variant.h" #include "shadercache.h" +#include "platform.h" #include #include @@ -49,15 +50,27 @@ namespace { enum ReadFileMode : uint8_t { Text, Binary }; std::string read_file(const std::filesystem::path &filePath, ReadFileMode mode = ReadFileMode::Text) { using namespace std::string_literals; - if (!std::filesystem::exists(filePath)) { + + std::filesystem::path path_to_open = filePath; + if (!std::filesystem::exists(path_to_open)) { + // Try to find the file relative 
to the executable path + const auto exe_path = MR::Platform::get_executable_path(); + const auto exe_dir = exe_path.parent_path(); + const auto relative_path = exe_dir / filePath; + if (std::filesystem::exists(relative_path)) { + path_to_open = relative_path; + } + } + + if (!std::filesystem::exists(path_to_open)) { throw std::runtime_error("File not found: "s + filePath.string()); } const auto openMode = (mode == ReadFileMode::Binary) ? std::ios::in | std::ios::binary : std::ios::in; - std::ifstream f(filePath, std::ios::in | openMode); - const auto fileSize64 = std::filesystem::file_size(filePath); + std::ifstream f(path_to_open, std::ios::in | openMode); + const auto fileSize64 = std::filesystem::file_size(path_to_open); if (fileSize64 > static_cast(std::numeric_limits::max())) { - throw std::runtime_error("File too large to read into memory: "s + filePath.string()); + throw std::runtime_error("File too large to read into memory: "s + path_to_open.string()); } const std::streamsize fileSize = static_cast(fileSize64); std::string result(static_cast(fileSize), '\0'); From 0f76ded1fadf48f08ceb1681341660c26bcc8136 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Wed, 14 Jan 2026 08:50:31 +0000 Subject: [PATCH 05/12] Export boolean link-time constants --- cpp/core/gpu/slangcodegen.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/core/gpu/slangcodegen.cpp b/cpp/core/gpu/slangcodegen.cpp index 4e83b59069..02d3e8d93e 100644 --- a/cpp/core/gpu/slangcodegen.cpp +++ b/cpp/core/gpu/slangcodegen.cpp @@ -340,6 +340,7 @@ CompiledKernelWGSL compile_kernel_code_to_wgsl(const MR::GPU::KernelSpec &kernel for (const auto &[name, value] : kernel_spec.compute_shader.constants) { MR::match_v( value, + [&oss, name = name](bool v) { oss << "export static const bool " << name << " = " << v << ";\n"; }, [&oss, name = name](int32_t v) { oss << "export static const int32_t " << name << " = " << v << ";\n"; }, [&oss, name = name](uint32_t v) { oss << "export static const uint32_t " 
<< name << " = " << v << ";\n"; }, [&oss, name = name](float v) { oss << "export static const float " << name << " = " << v << ";\n"; }); From 8866f5f6995e09b8ba16074fc6c54ddb8b7f04e6 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Mon, 19 Jan 2026 15:43:31 +0000 Subject: [PATCH 06/12] Update registration uniforms for new GPU API --- cpp/core/gpu/registration/imageoperations.cpp | 2 +- cpp/core/gpu/registration/ncccalculator.cpp | 9 +++-- cpp/core/gpu/registration/nmicalculator.cpp | 39 ++++++++++++------- cpp/core/gpu/registration/ssdcalculator.cpp | 13 +++++-- 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/cpp/core/gpu/registration/imageoperations.cpp b/cpp/core/gpu/registration/imageoperations.cpp index 032c0266c2..40da106596 100644 --- a/cpp/core/gpu/registration/imageoperations.cpp +++ b/cpp/core/gpu/registration/imageoperations.cpp @@ -103,7 +103,7 @@ Eigen::Matrix3f computeScannerMoments(const Texture &texture, .centre = {centreScanner.x(), centreScanner.y(), centreScanner.z(), 0.0f}, }; const Buffer centreBuffer = - context.new_buffer_from_host_memory(&uniforms, sizeof(uniforms), BufferType::UniformBuffer); + context.new_buffer_from_host_object(uniforms, BufferType::UniformBuffer); Buffer momentBuffer = context.new_empty_buffer(kMomentCount); context.clear_buffer(momentBuffer); diff --git a/cpp/core/gpu/registration/ncccalculator.cpp b/cpp/core/gpu/registration/ncccalculator.cpp index 39e2eeafac..2bd414762f 100644 --- a/cpp/core/gpu/registration/ncccalculator.cpp +++ b/cpp/core/gpu/registration/ncccalculator.cpp @@ -62,7 +62,7 @@ void upload_uniforms(const ComputeContext &context, const auto params = transform.parameters(); std::copy_n(params.begin(), N, uniforms.current_transform.begin()); uniforms.voxel_scanner_matrices = matrices; - context.write_to_buffer(buffer, &uniforms, sizeof(uniforms)); + context.write_to_buffer(buffer, tcb::as_bytes(tcb::span>(&uniforms, 1))); } } // namespace @@ -86,8 +86,11 @@ 
NCCCalculator::NCCCalculator(const Config &config) m_terms_per_workgroup = 1U + m_degrees_of_freedom; m_global_terms_per_workgroup = 5U + 3U * m_degrees_of_freedom; - const size_t uniformsSize = is_rigid ? sizeof(RigidNCCUniforms) : sizeof(AffineNCCUniforms); - m_uniforms_buffer = m_compute_context->new_empty_buffer(uniformsSize, BufferType::UniformBuffer); + if (is_rigid) { + m_uniforms_buffer = m_compute_context->new_buffer_from_host_object(RigidNCCUniforms{}, BufferType::UniformBuffer); + } else { + m_uniforms_buffer = m_compute_context->new_buffer_from_host_object(AffineNCCUniforms{}, BufferType::UniformBuffer); + } m_num_contributing_voxels_buffer = m_compute_context->new_empty_buffer(1); if (m_use_local_window) { diff --git a/cpp/core/gpu/registration/nmicalculator.cpp b/cpp/core/gpu/registration/nmicalculator.cpp index 87b3bfaeea..05c59ad662 100644 --- a/cpp/core/gpu/registration/nmicalculator.cpp +++ b/cpp/core/gpu/registration/nmicalculator.cpp @@ -115,25 +115,28 @@ NMICalculator::NMICalculator(const Config &config) // comparisons match float ordering and atomics work for negative intensities too. // TODO: Should we use shared memory reduction instead of atomics for better performance? 
m_min_max_uniforms_buffer = - m_compute_context->new_empty_buffer(sizeof(MinMaxUniforms), BufferType::UniformBuffer); + m_compute_context->new_buffer_from_host_object(MinMaxUniforms{}, BufferType::UniformBuffer); m_min_max_intensity_fixed_buffer = - m_compute_context->new_buffer_from_host_memory(initialMinMax.data(), sizeof(initialMinMax)); + m_compute_context->new_buffer_from_host_memory(tcb::span(initialMinMax)); m_min_max_intensity_moving_buffer = - m_compute_context->new_buffer_from_host_memory(initialMinMax.data(), sizeof(initialMinMax)); + m_compute_context->new_buffer_from_host_memory(tcb::span(initialMinMax)); m_raw_joint_histogram_buffer = m_compute_context->new_empty_buffer(m_num_bins * m_num_bins); m_smoothed_joint_histogram_buffer = m_compute_context->new_empty_buffer(m_num_bins * m_num_bins); m_joint_histogram_mass_buffer = m_compute_context->new_empty_buffer(1); m_joint_histogram_uniforms_buffer = - m_compute_context->new_empty_buffer(sizeof(JointHistogramUniforms), BufferType::UniformBuffer); + m_compute_context->new_buffer_from_host_object(JointHistogramUniforms{}, BufferType::UniformBuffer); m_precomputed_coefficients_buffer = m_compute_context->new_empty_buffer(m_num_bins * m_num_bins); m_mutual_information_buffer = m_compute_context->new_empty_buffer(1); if (m_output == CalculatorOutput::CostAndGradients) { m_gradients_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, gradientsWorkgroupSize); - const uint32_t gradients_uniform_size = - is_affine ? 
sizeof(AffineGradientsUniforms) : sizeof(RigidGradientsUniforms); - m_gradients_uniforms_buffer = - m_compute_context->new_empty_buffer(gradients_uniform_size, BufferType::UniformBuffer); + if (is_affine) { + m_gradients_uniforms_buffer = + m_compute_context->new_buffer_from_host_object(AffineGradientsUniforms{}, BufferType::UniformBuffer); + } else { + m_gradients_uniforms_buffer = + m_compute_context->new_buffer_from_host_object(RigidGradientsUniforms{}, BufferType::UniformBuffer); + } m_gradients_buffer = m_compute_context->new_empty_buffer(m_degrees_of_freedom * m_gradients_dispatch_grid.workgroup_count()); } @@ -152,7 +155,8 @@ NMICalculator::NMICalculator(const Config &config) const MinMaxUniforms min_max_fixed_uniforms{ .dispatch_grid = fixed_dispatch_grid, }; - m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, &min_max_fixed_uniforms, sizeof(min_max_fixed_uniforms)); + m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, + tcb::as_bytes(tcb::span(&min_max_fixed_uniforms, 1))); m_compute_context->dispatch_kernel(min_max_fixed_kernel, fixed_dispatch_grid); const KernelSpec min_max_moving_kernel_spec{ @@ -171,7 +175,8 @@ NMICalculator::NMICalculator(const Config &config) const MinMaxUniforms min_max_moving_uniforms{ .dispatch_grid = moving_dispatch_grid, }; - m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, &min_max_moving_uniforms, sizeof(MinMaxUniforms)); + m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, + tcb::as_bytes(tcb::span(&min_max_moving_uniforms, 1))); m_compute_context->dispatch_kernel(m_min_max_moving_kernel, moving_dispatch_grid); const std::vector min_max_fixed_bits = @@ -193,7 +198,8 @@ NMICalculator::NMICalculator(const Config &config) .transformation_matrix = {}, }; m_compute_context->write_to_buffer( - m_joint_histogram_uniforms_buffer, &joint_histogram_uniforms, sizeof(JointHistogramUniforms)); + m_joint_histogram_uniforms_buffer, + tcb::as_bytes(tcb::span(&joint_histogram_uniforms, 1))); 
const uint32_t jointHistogramPartialsSize = (m_num_bins * m_num_bins) * m_joint_histogram_dispatch_grid.workgroup_count(); m_joint_histogram_kernel = m_compute_context->new_kernel({ .compute_shader = @@ -299,7 +305,8 @@ void NMICalculator::update(const GlobalTransform &transformation) { .transformation_matrix = EigenHelpers::to_array(transformation_matrix_voxel_space), }; m_compute_context->write_to_buffer( - m_joint_histogram_uniforms_buffer, &joint_histogram_uniforms, sizeof(joint_histogram_uniforms)); + m_joint_histogram_uniforms_buffer, + tcb::as_bytes(tcb::span(&joint_histogram_uniforms, 1))); m_compute_context->dispatch_kernel(m_joint_histogram_kernel, m_joint_histogram_dispatch_grid); const WorkgroupSize smoothWGSize{8, 8, 1}; @@ -329,7 +336,9 @@ void NMICalculator::update(const GlobalTransform &transformation) { .voxel_scanner_matrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer(m_gradients_uniforms_buffer, &gradients_uniforms, sizeof(AffineGradientsUniforms)); + m_compute_context->write_to_buffer( + m_gradients_uniforms_buffer, + tcb::as_bytes(tcb::span(&gradients_uniforms, 1))); } else { std::array params; const auto current = transformation.parameters(); @@ -341,7 +350,9 @@ void NMICalculator::update(const GlobalTransform &transformation) { .current_transform = params, .voxel_scanner_matrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer(m_gradients_uniforms_buffer, &gradients_uniforms, sizeof(RigidGradientsUniforms)); + m_compute_context->write_to_buffer( + m_gradients_uniforms_buffer, + tcb::as_bytes(tcb::span(&gradients_uniforms, 1))); } m_compute_context->dispatch_kernel(m_gradients_kernel, m_gradients_dispatch_grid); diff --git a/cpp/core/gpu/registration/ssdcalculator.cpp b/cpp/core/gpu/registration/ssdcalculator.cpp index 1ce5b03f52..e08cc43520 100644 --- a/cpp/core/gpu/registration/ssdcalculator.cpp +++ b/cpp/core/gpu/registration/ssdcalculator.cpp @@ -63,8 +63,11 @@ SSDCalculator::SSDCalculator(const 
Config &config) m_dispatch_grid = DispatchGrid::element_wise_texture(m_fixed, ssd_workgroup_size); - const uint32_t uniforms_size = is_rigid ? sizeof(RigidSSDUniforms) : sizeof(AffineSSDUniforms); - m_uniforms_buffer = m_compute_context->new_empty_buffer(uniforms_size, BufferType::UniformBuffer); + if (is_rigid) { + m_uniforms_buffer = m_compute_context->new_buffer_from_host_object(RigidSSDUniforms{}, BufferType::UniformBuffer); + } else { + m_uniforms_buffer = m_compute_context->new_buffer_from_host_object(AffineSSDUniforms{}, BufferType::UniformBuffer); + } const size_t params_per_workgroup = 1U + m_degrees_of_freedom; m_partials_buffer = m_compute_context->new_empty_buffer(params_per_workgroup * m_dispatch_grid.workgroup_count()); @@ -107,7 +110,8 @@ void SSDCalculator::update(const GlobalTransform &transformation) { .currentTransform = params, .voxelScannerMatrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer(m_uniforms_buffer, &uniforms, sizeof(AffineSSDUniforms)); + m_compute_context->write_to_buffer(m_uniforms_buffer, + tcb::as_bytes(tcb::span(&uniforms, 1))); } else { std::array params; const auto current = transformation.parameters(); @@ -118,7 +122,8 @@ void SSDCalculator::update(const GlobalTransform &transformation) { .currentTransform = params, .voxelScannerMatrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer(m_uniforms_buffer, &uniforms, sizeof(RigidSSDUniforms)); + m_compute_context->write_to_buffer(m_uniforms_buffer, + tcb::as_bytes(tcb::span(&uniforms, 1))); } m_compute_context->dispatch_kernel(m_kernel, m_dispatch_grid); From cff396bca48ca64dcc79d5d456a2b3d52138cf69 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Mon, 19 Jan 2026 15:46:17 +0000 Subject: [PATCH 07/12] Add write_object_to_buffer helper --- cpp/core/gpu/gpu.h | 9 +++++++++ cpp/core/gpu/registration/ncccalculator.cpp | 2 +- cpp/core/gpu/registration/nmicalculator.cpp | 22 ++++++--------------- cpp/core/gpu/registration/ssdcalculator.cpp 
| 6 ++---- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/cpp/core/gpu/gpu.h b/cpp/core/gpu/gpu.h index de3ca91252..bd5e3f917f 100644 --- a/cpp/core/gpu/gpu.h +++ b/cpp/core/gpu/gpu.h @@ -241,6 +241,15 @@ struct ComputeContext { return Buffer{bufferType, std::move(buffer)}; } + // Writes a POD-like object into a byte buffer (e.g. uniform buffers). + template + void write_object_to_buffer(const Buffer &buffer, const Object &object, uint64_t offset_bytes = 0) const { + static_assert(std::is_trivially_copyable_v, "Object must be trivially copyable"); + static_assert(std::is_standard_layout_v, "Object must be standard layout"); + const auto bytes = tcb::as_bytes(tcb::span(&object, 1)); + write_to_buffer(buffer, bytes, offset_bytes); + } + // This function blocks until the download is complete. template [[nodiscard]] std::vector download_buffer_as_vector(const Buffer &buffer) const { std::vector result(buffer.wgpu_handle.GetSize() / sizeof(T)); diff --git a/cpp/core/gpu/registration/ncccalculator.cpp b/cpp/core/gpu/registration/ncccalculator.cpp index 2bd414762f..9be8983c64 100644 --- a/cpp/core/gpu/registration/ncccalculator.cpp +++ b/cpp/core/gpu/registration/ncccalculator.cpp @@ -62,7 +62,7 @@ void upload_uniforms(const ComputeContext &context, const auto params = transform.parameters(); std::copy_n(params.begin(), N, uniforms.current_transform.begin()); uniforms.voxel_scanner_matrices = matrices; - context.write_to_buffer(buffer, tcb::as_bytes(tcb::span>(&uniforms, 1))); + context.write_object_to_buffer(buffer, uniforms); } } // namespace diff --git a/cpp/core/gpu/registration/nmicalculator.cpp b/cpp/core/gpu/registration/nmicalculator.cpp index 05c59ad662..280069f236 100644 --- a/cpp/core/gpu/registration/nmicalculator.cpp +++ b/cpp/core/gpu/registration/nmicalculator.cpp @@ -155,8 +155,7 @@ NMICalculator::NMICalculator(const Config &config) const MinMaxUniforms min_max_fixed_uniforms{ .dispatch_grid = fixed_dispatch_grid, }; - 
m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, - tcb::as_bytes(tcb::span(&min_max_fixed_uniforms, 1))); + m_compute_context->write_object_to_buffer(m_min_max_uniforms_buffer, min_max_fixed_uniforms); m_compute_context->dispatch_kernel(min_max_fixed_kernel, fixed_dispatch_grid); const KernelSpec min_max_moving_kernel_spec{ @@ -175,8 +174,7 @@ NMICalculator::NMICalculator(const Config &config) const MinMaxUniforms min_max_moving_uniforms{ .dispatch_grid = moving_dispatch_grid, }; - m_compute_context->write_to_buffer(m_min_max_uniforms_buffer, - tcb::as_bytes(tcb::span(&min_max_moving_uniforms, 1))); + m_compute_context->write_object_to_buffer(m_min_max_uniforms_buffer, min_max_moving_uniforms); m_compute_context->dispatch_kernel(m_min_max_moving_kernel, moving_dispatch_grid); const std::vector min_max_fixed_bits = @@ -197,9 +195,7 @@ NMICalculator::NMICalculator(const Config &config) .intensities = m_intensities, .transformation_matrix = {}, }; - m_compute_context->write_to_buffer( - m_joint_histogram_uniforms_buffer, - tcb::as_bytes(tcb::span(&joint_histogram_uniforms, 1))); + m_compute_context->write_object_to_buffer(m_joint_histogram_uniforms_buffer, joint_histogram_uniforms); const uint32_t jointHistogramPartialsSize = (m_num_bins * m_num_bins) * m_joint_histogram_dispatch_grid.workgroup_count(); m_joint_histogram_kernel = m_compute_context->new_kernel({ .compute_shader = @@ -304,9 +300,7 @@ void NMICalculator::update(const GlobalTransform &transformation) { .intensities = m_intensities, .transformation_matrix = EigenHelpers::to_array(transformation_matrix_voxel_space), }; - m_compute_context->write_to_buffer( - m_joint_histogram_uniforms_buffer, - tcb::as_bytes(tcb::span(&joint_histogram_uniforms, 1))); + m_compute_context->write_object_to_buffer(m_joint_histogram_uniforms_buffer, joint_histogram_uniforms); m_compute_context->dispatch_kernel(m_joint_histogram_kernel, m_joint_histogram_dispatch_grid); const WorkgroupSize smoothWGSize{8, 8, 1}; @@ 
-336,9 +330,7 @@ void NMICalculator::update(const GlobalTransform &transformation) { .voxel_scanner_matrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer( - m_gradients_uniforms_buffer, - tcb::as_bytes(tcb::span(&gradients_uniforms, 1))); + m_compute_context->write_object_to_buffer(m_gradients_uniforms_buffer, gradients_uniforms); } else { std::array params; const auto current = transformation.parameters(); @@ -350,9 +342,7 @@ void NMICalculator::update(const GlobalTransform &transformation) { .current_transform = params, .voxel_scanner_matrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer( - m_gradients_uniforms_buffer, - tcb::as_bytes(tcb::span(&gradients_uniforms, 1))); + m_compute_context->write_object_to_buffer(m_gradients_uniforms_buffer, gradients_uniforms); } m_compute_context->dispatch_kernel(m_gradients_kernel, m_gradients_dispatch_grid); diff --git a/cpp/core/gpu/registration/ssdcalculator.cpp b/cpp/core/gpu/registration/ssdcalculator.cpp index e08cc43520..eddf007d2e 100644 --- a/cpp/core/gpu/registration/ssdcalculator.cpp +++ b/cpp/core/gpu/registration/ssdcalculator.cpp @@ -110,8 +110,7 @@ void SSDCalculator::update(const GlobalTransform &transformation) { .currentTransform = params, .voxelScannerMatrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer(m_uniforms_buffer, - tcb::as_bytes(tcb::span(&uniforms, 1))); + m_compute_context->write_object_to_buffer(m_uniforms_buffer, uniforms); } else { std::array params; const auto current = transformation.parameters(); @@ -122,8 +121,7 @@ void SSDCalculator::update(const GlobalTransform &transformation) { .currentTransform = params, .voxelScannerMatrices = m_voxel_scanner_matrices, }; - m_compute_context->write_to_buffer(m_uniforms_buffer, - tcb::as_bytes(tcb::span(&uniforms, 1))); + m_compute_context->write_object_to_buffer(m_uniforms_buffer, uniforms); } m_compute_context->dispatch_kernel(m_kernel, m_dispatch_grid); From 
983f1f155b5c9e0d6a451f09419dfe903170db7e Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Mon, 19 Jan 2026 15:46:53 +0000 Subject: [PATCH 08/12] Add test for write_object_to_buffer --- testing/unit_tests/gputests.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/testing/unit_tests/gputests.cpp b/testing/unit_tests/gputests.cpp index a609a24645..9f153311a6 100644 --- a/testing/unit_tests/gputests.cpp +++ b/testing/unit_tests/gputests.cpp @@ -102,6 +102,28 @@ TEST_F(GPUTest, BufferFromHostMemoryObject) { EXPECT_EQ(downloaded_data.c, host_data.c); } +TEST_F(GPUTest, WriteObjectToBuffer) { + struct Data { + float a; + float b; + float c; + }; + + const Data initial_data{0.0F, 0.0F, 0.0F}; + const Buffer buffer = context.new_buffer_from_host_object(initial_data, BufferType::UniformBuffer); + + const Data host_data{1.25F, -2.5F, 3.75F}; + context.write_object_to_buffer(buffer, host_data); + + Data downloaded_data{}; + auto downloaded_bytes = tcb::as_writable_bytes(tcb::span(&downloaded_data, 1)); + context.download_buffer(buffer, downloaded_bytes); + + EXPECT_EQ(downloaded_data.a, host_data.a); + EXPECT_EQ(downloaded_data.b, host_data.b); + EXPECT_EQ(downloaded_data.c, host_data.c); +} + TEST_F(GPUTest, BufferFromHostMemoryMultipleRegions) { std::vector region1 = {1, 2, 3}; std::vector region2 = {4, 5}; From b412e0300e5defc8c0d5701d54883c0ddff92b84 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Mon, 19 Jan 2026 15:50:19 +0000 Subject: [PATCH 09/12] Enforce 4-byte size for GPU object uploads --- cpp/core/gpu/gpu.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/core/gpu/gpu.h b/cpp/core/gpu/gpu.h index bd5e3f917f..d03c5cf1a3 100644 --- a/cpp/core/gpu/gpu.h +++ b/cpp/core/gpu/gpu.h @@ -222,6 +222,7 @@ struct ComputeContext { new_buffer_from_host_object(const Object &object, BufferType buffer_type = BufferType::StorageBuffer) const { static_assert(std::is_trivially_copyable_v, "Object must be trivially copyable"); 
static_assert(std::is_standard_layout_v, "Object must be standard layout"); + static_assert(sizeof(Object) % 4 == 0, "Object size must be a multiple of 4 bytes"); return {buffer_type, inner_new_buffer_from_host_memory(&object, sizeof(object), buffer_type)}; } @@ -246,6 +247,7 @@ struct ComputeContext { void write_object_to_buffer(const Buffer &buffer, const Object &object, uint64_t offset_bytes = 0) const { static_assert(std::is_trivially_copyable_v, "Object must be trivially copyable"); static_assert(std::is_standard_layout_v, "Object must be standard layout"); + static_assert(sizeof(Object) % 4 == 0, "Object size must be a multiple of 4 bytes"); const auto bytes = tcb::as_bytes(tcb::span(&object, 1)); write_to_buffer(buffer, bytes, offset_bytes); } From 3bb753372b837c7a27f0fab31fb70963b4d0459f Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Thu, 29 Jan 2026 14:20:32 +0000 Subject: [PATCH 10/12] Rename Metric class -> GlobalMetric --- cpp/cmd/mrreggpu.cpp | 4 ++-- cpp/core/gpu/registration/calculatorinterface.h | 2 +- cpp/core/gpu/registration/registrationtypes.h | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/cmd/mrreggpu.cpp b/cpp/cmd/mrreggpu.cpp index 3b2fbbdcaa..9e731248b7 100644 --- a/cpp/cmd/mrreggpu.cpp +++ b/cpp/cmd/mrreggpu.cpp @@ -286,7 +286,7 @@ const InitRotationChoice init_rotation = from_name(get_option_value("init_rotation", "none")); const float init_rotation_max_angle = get_option_value("init_rotation_max_angle", default_max_search_angle); - Metric init_metric; + GlobalMetric init_metric; switch (metric_type) { case MetricType::NMI: init_metric = NMIMetric{}; @@ -350,7 +350,7 @@ ++index; } - Metric metric; + GlobalMetric metric; switch (metric_type) { case MetricType::NMI: metric = NMIMetric{}; diff --git a/cpp/core/gpu/registration/calculatorinterface.h b/cpp/core/gpu/registration/calculatorinterface.h index 028f336617..db2c5d9344 100644 --- a/cpp/core/gpu/registration/calculatorinterface.h +++ 
b/cpp/core/gpu/registration/calculatorinterface.h @@ -38,7 +38,7 @@ class Calculator final : public std::variant; +using GlobalMetric = std::variant; enum class MetricType : uint8_t { NMI, SSD, NCC }; enum class InitTranslationChoice : uint8_t { None, Mass, Geometric }; @@ -124,7 +124,7 @@ enum class InitRotationChoice : uint8_t { None, Search, Moments }; struct InitialisationOptions { InitTranslationChoice translation_choice = InitTranslationChoice::Mass; InitRotationChoice rotation_choice = InitRotationChoice::None; - Metric cost_metric = NMIMetric{}; + GlobalMetric cost_metric = NMIMetric{}; // Limits the maximum sampled rotation angle (degrees) for search-based initialisation. float max_search_angle_degrees = 90.0F; }; @@ -143,7 +143,7 @@ struct RegistrationConfig { std::vector channels; TransformationType transformation_type; InitialGuess initial_guess; - Metric metric; + GlobalMetric metric; uint32_t max_iterations = 500; }; From 0a6b4e2034573428ae08084ebf6e8751b97f4b14 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Thu, 29 Jan 2026 14:28:34 +0000 Subject: [PATCH 11/12] Rename RegistrationConfig -> GlobalRegistrationConfig --- cpp/cmd/mrreggpu.cpp | 2 +- cpp/core/gpu/registration/globalregistration.cpp | 2 +- cpp/core/gpu/registration/globalregistration.h | 2 +- cpp/core/gpu/registration/registrationtypes.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/cmd/mrreggpu.cpp b/cpp/cmd/mrreggpu.cpp index 9e731248b7..f6b832f316 100644 --- a/cpp/cmd/mrreggpu.cpp +++ b/cpp/cmd/mrreggpu.cpp @@ -365,7 +365,7 @@ throw Exception("Unsupported metric type"); } - const RegistrationConfig registration_config{ + const GlobalRegistrationConfig registration_config{ .channels = channels, .transformation_type = transform_type, .initial_guess = initial_guess, diff --git a/cpp/core/gpu/registration/globalregistration.cpp b/cpp/core/gpu/registration/globalregistration.cpp index b7fce61c83..1d373f74d8 100644 --- 
a/cpp/core/gpu/registration/globalregistration.cpp +++ b/cpp/core/gpu/registration/globalregistration.cpp @@ -127,7 +127,7 @@ struct ChannelData { float weight = 1.0F; }; -RegistrationResult run_registration(const RegistrationConfig &config, const GPU::ComputeContext &context) { +RegistrationResult run_registration(const GlobalRegistrationConfig &config, const GPU::ComputeContext &context) { constexpr uint32_t num_levels = 3U; const bool is_affine = config.transformation_type == TransformationType::Affine; const uint32_t degrees_of_freedom = is_affine ? 12U : 6U; diff --git a/cpp/core/gpu/registration/globalregistration.h b/cpp/core/gpu/registration/globalregistration.h index 8c11df651f..3ae1909f92 100644 --- a/cpp/core/gpu/registration/globalregistration.h +++ b/cpp/core/gpu/registration/globalregistration.h @@ -20,5 +20,5 @@ #include "gpu/registration/registrationtypes.h" namespace MR::GPU { -RegistrationResult run_registration(const RegistrationConfig &config, const GPU::ComputeContext &context); +RegistrationResult run_registration(const GlobalRegistrationConfig &config, const GPU::ComputeContext &context); } diff --git a/cpp/core/gpu/registration/registrationtypes.h b/cpp/core/gpu/registration/registrationtypes.h index 88ed589136..b6ae83364e 100644 --- a/cpp/core/gpu/registration/registrationtypes.h +++ b/cpp/core/gpu/registration/registrationtypes.h @@ -139,7 +139,7 @@ struct ChannelConfig { float weight = 1.0F; }; -struct RegistrationConfig { +struct GlobalRegistrationConfig { std::vector channels; TransformationType transformation_type; InitialGuess initial_guess; From cfb2dca1a91e5599c2ccadac32276171e0b78996 Mon Sep 17 00:00:00 2001 From: Daljit Singh Date: Thu, 29 Jan 2026 14:30:10 +0000 Subject: [PATCH 12/12] Add missing include header in registrationtypes.h --- cpp/core/gpu/registration/registrationtypes.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/core/gpu/registration/registrationtypes.h b/cpp/core/gpu/registration/registrationtypes.h index 
b6ae83364e..4f4361f5e2 100644 --- a/cpp/core/gpu/registration/registrationtypes.h +++ b/cpp/core/gpu/registration/registrationtypes.h @@ -17,6 +17,7 @@ #pragma once #include "image.h" +#include #include #include "types.h"