Affine registration on the GPU #3258

Draft
daljit46 wants to merge 12 commits into dev from affine_reg_gpu

@daljit46 daljit46 commented Jan 14, 2026

This work builds on top of #3238. It introduces a new C++ command called mrreggpu (chosen randomly and without much thought) that performs affine registration of 3D images on the GPU using WebGPU compute shaders. The code is completely independent of mrregister.
It's not ready to be merged and not ready for review yet. It needs much refinement, but I'm posting this PR to gather early feedback.
The utility of this command is rather limited since it only performs affine registration on scalar images, with some other limitations. The primary aim, however, is to provide a first real-world example of the GPU compute API introduced in #3238 so that we have a reference on how to use the GPU in the codebase.
This is also intended as a stepping stone towards non-linear registration on the GPU. I spent some time experimenting with SVF/SyN-style deformation using the current GPU API and I think the approach is feasible, but I'd appreciate guidance on which direction would be most useful.

mrreggpu currently supports:

  • 3D affine image registration on scalar images (4D images are not supported).
  • Three metrics: NMI, SSD and NCC (global and local, using a sliding window).
  • Multi-contrast registration.
  • An interface mostly similar to mrregister, with minor differences (not yet sure these are justified).
  • A symmetric registration strategy. Unlike mrregister, it doesn't register both images into an average space. Instead, it registers in both directions and then uses Lie algebra averaging to compute the transform for the next step (see here).
  • An optimiser based on a slightly enhanced version of Adam. I also experimented with other optimisers (e.g. Levenberg–Marquardt), but stuck with Adam for simplicity (affine registration on the GPU is fast enough). The optimiser runs on the CPU while the update step is on the GPU. I also experimented with running everything on the GPU, but the added complexity didn't feel worth it.
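
For readers unfamiliar with the optimiser mentioned above, here is a minimal CPU-side sketch of one textbook Adam update. It is not the enhanced variant used in this PR; the names AdamState, AdamConfig and adam_step and the default hyperparameters are assumptions made for illustration only.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical optimiser state for N parameters (illustrative names only).
template <std::size_t N>
struct AdamState {
  std::array<float, N> m{};  // running mean of the gradient
  std::array<float, N> v{};  // running mean of the squared gradient
  int step = 0;              // number of updates applied so far
};

// Commonly quoted Adam defaults; not taken from the PR.
struct AdamConfig {
  float learning_rate = 1e-3F;
  float beta1 = 0.9F;
  float beta2 = 0.999F;
  float epsilon = 1e-8F;
};

// Applies one plain Adam update in place.
template <std::size_t N>
void adam_step(std::array<float, N> &params,
               const std::array<float, N> &grad,
               AdamState<N> &state,
               const AdamConfig &cfg = AdamConfig()) {
  assert(cfg.beta1 >= 0.0F && cfg.beta1 < 1.0F);
  assert(cfg.beta2 >= 0.0F && cfg.beta2 < 1.0F);
  ++state.step;
  // Bias correction compensates for the zero-initialised moment estimates.
  const float bias1 = 1.0F - std::pow(cfg.beta1, static_cast<float>(state.step));
  const float bias2 = 1.0F - std::pow(cfg.beta2, static_cast<float>(state.step));
  for (std::size_t i = 0; i < N; ++i) {
    state.m[i] = cfg.beta1 * state.m[i] + (1.0F - cfg.beta1) * grad[i];
    state.v[i] = cfg.beta2 * state.v[i] + (1.0F - cfg.beta2) * grad[i] * grad[i];
    const float m_hat = state.m[i] / bias1;
    const float v_hat = state.v[i] / bias2;
    params[i] -= cfg.learning_rate * m_hat / (std::sqrt(v_hat) + cfg.epsilon);
  }
}
```

On the first step the bias-corrected update reduces to roughly learning_rate times the sign of each gradient component, which makes the step size largely independent of the gradient scale.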

Notes on the current state

  • #3108 (Add option for canonical direct I/O layout) needs to be resolved before this can be merged.
  • New code lives in cpp/core/gpu/registration and cpp/core/gpu/shaders.
  • The code is still rough around the edges and needs refinement, but I hope it's understandable enough to give a general idea of how things work. The registration logic may have holes.
  • There are some hacks / temporary solutions to get things working, which I hope to clean up over time.
  • The (L)NCC logic seems buggy (not sure why yet) and probably needs fixing.
  • As in #3238 (GPU compute API abstraction on top of WebGPU), I've made extensive use of designated initialisers (a C++20 feature, though supported by Clang and GCC).
  • This PR also includes support for magic_enum (for enum <-> string conversion), but that should likely be split into a separate PR.
  • Some code is not specific to registration per se (e.g. logic for performing reduction operations in workgroups, computing the CoM or downsampling), but was necessary for building the commands. Perhaps I should move that logic into a separate PR?
  • I've tested the code manually with the help of a rudimentary Python script, but I'm hoping to add a comprehensive enough set of unit tests to test out the core logic. On this note, we probably need to identify a suitable set of testing data to use for the NMI and NCC metrics.
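
As one possible starting point for the unit tests mentioned above, a CPU reference for the global NCC metric could be compared against the GPU result on small synthetic volumes. The sketch below is an assumption about how such a reference might look; reference_ncc is a hypothetical name, and the sign or scaling convention used by the GPU cost function may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Global normalised cross-correlation of two equally sized intensity arrays.
// Returns 0 when either input has zero variance (a convention chosen for this sketch).
double reference_ncc(const std::vector<float> &a, const std::vector<float> &b) {
  assert(a.size() == b.size() && !a.empty());
  const double n = static_cast<double>(a.size());

  // First pass: means of both inputs.
  double mean_a = 0.0;
  double mean_b = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    mean_a += a[i];
    mean_b += b[i];
  }
  mean_a /= n;
  mean_b /= n;

  // Second pass: covariance and variances about the means.
  double cov = 0.0;
  double var_a = 0.0;
  double var_b = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    const double da = a[i] - mean_a;
    const double db = b[i] - mean_b;
    cov += da * db;
    var_a += da * da;
    var_b += db * db;
  }

  const double denom = std::sqrt(var_a * var_b);
  return denom > 0.0 ? cov / denom : 0.0;
}
```

Inputs related by a positive linear intensity mapping should give a value of 1, and a negated copy should give -1, which makes for simple sanity checks before moving to real data.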

I'm aware that this is a rather large PR, but any feedback would be welcome.

@daljit46 daljit46 self-assigned this Jan 14, 2026
@daljit46 daljit46 marked this pull request as draft January 14, 2026 08:59
@github-actions github-actions bot left a comment

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 25 out of 57. Check the log or trigger a new build to see more.

#include "gpu/registration/globalregistration.h"
#include "gpu/registration/registrationtypes.h"
#include "gpu/registration/imageoperations.h"
#include "gpu/registration/imageoperations.h"

warning: duplicate include [readability-duplicate-include]

cpp/cmd/mrreggpu.cpp:28:

-  #include "gpu/registration/imageoperations.h"
-  #include "gpu/registration/imageoperations.h"
+  #include "gpu/registration/imageoperations.h"

File::Matrix::save_transform(registration_result.transformation, centre, matrix_filename);
}
if (!matrix_1tomid_filename.empty()) {
File::Matrix::save_transform(halfway_transforms->half, centre, matrix_1tomid_filename);

warning: unchecked access to optional value [bugprone-unchecked-optional-access]

       File::Matrix::save_transform(halfway_transforms->half, centre, matrix_1tomid_filename);
                                    ^

File::Matrix::save_transform(halfway_transforms->half, centre, matrix_1tomid_filename);
}
if (!matrix_2tomid_filename.empty()) {
File::Matrix::save_transform(halfway_transforms->half_inverse, centre, matrix_2tomid_filename);

warning: unchecked access to optional value [bugprone-unchecked-optional-access]

       File::Matrix::save_transform(halfway_transforms->half_inverse, centre, matrix_2tomid_filename);
                                    ^

if (!transformed_midway1_filenames.empty()) {
// Compute midpioint transforms in scanner space and then build a midway output header that can hold both images
using ProjectiveTransform = Eigen::Transform<default_type, 3, Eigen::Projective>;
const ProjectiveTransform half_projective(halfway_transforms->half_matrix);

warning: unchecked access to optional value [bugprone-unchecked-optional-access]

     const ProjectiveTransform half_projective(halfway_transforms->half_matrix);
                                               ^

// Compute midpioint transforms in scanner space and then build a midway output header that can hold both images
using ProjectiveTransform = Eigen::Transform<default_type, 3, Eigen::Projective>;
const ProjectiveTransform half_projective(halfway_transforms->half_matrix);
const ProjectiveTransform half_inverse_projective(halfway_transforms->half_inverse_matrix);

warning: unchecked access to optional value [bugprone-unchecked-optional-access]

     const ProjectiveTransform half_inverse_projective(halfway_transforms->half_inverse_matrix);
                                                       ^
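
One common way to address the unchecked-optional-access warnings above is to test the optional once and bind a reference before any member access. The sketch below assumes halfway_transforms is a std::optional; HalfwayTransforms, PendingOutput and collect_halfway_outputs are hypothetical stand-ins, and doubles replace the real matrix members.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the halfway-transform result in this PR.
struct HalfwayTransforms {
  double half = 0.0;
  double half_inverse = 0.0;
};

// A (path, value) pair describing something that would be written to disk.
struct PendingOutput {
  std::string path;
  double value;
};

// Collects the outputs to save; nothing is written by this sketch.
std::vector<PendingOutput> collect_halfway_outputs(const std::optional<HalfwayTransforms> &halfway,
                                                   const std::string &path_1tomid,
                                                   const std::string &path_2tomid) {
  std::vector<PendingOutput> outputs;
  if (!halfway.has_value()) {
    return outputs;  // nothing to save; throwing an exception is another option
  }
  const HalfwayTransforms &transforms = *halfway;  // safe: checked just above
  if (!path_1tomid.empty()) {
    outputs.push_back({path_1tomid, transforms.half});
  }
  if (!path_2tomid.empty()) {
    outputs.push_back({path_2tomid, transforms.half_inverse});
  }
  return outputs;
}

// Usage: an empty optional yields no outputs, and an empty path is skipped.
inline void collect_halfway_outputs_example() {
  assert(collect_halfway_outputs(std::nullopt, "a.txt", "b.txt").empty());
  assert(collect_halfway_outputs(HalfwayTransforms{0.5, 2.0}, "a.txt", "").size() == 1);
}
```

Checking once at the top keeps every later access inside a scope where the analysis can see that the optional holds a value.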

const Buffer<float> matrixBuffer = context.new_buffer_from_host_memory<float>(matrixData);

const MomentUniforms uniforms{
.centre = {centreScanner.x(), centreScanner.y(), centreScanner.z(), 0.0f},

warning: floating point literal has suffix 'f', which is not uppercase [readability-uppercase-literal-suffix]

Suggested change
- .centre = {centreScanner.x(), centreScanner.y(), centreScanner.z(), 0.0f},
+ .centre = {centreScanner.x(), centreScanner.y(), centreScanner.z(), 0.0F},


std::array<float, kMomentCount> momentValues{};
for (size_t i = 0; i < kMomentCount; ++i) {
std::memcpy(&momentValues[i], &momentBits[i], sizeof(float));

warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]

    std::memcpy(&momentValues[i], &momentBits[i], sizeof(float));
                                   ^


std::array<float, kMomentCount> momentValues{};
for (size_t i = 0; i < kMomentCount; ++i) {
std::memcpy(&momentValues[i], &momentBits[i], sizeof(float));

warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]

    std::memcpy(&momentValues[i], &momentBits[i], sizeof(float));
                 ^
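
Both subscript warnings on this loop could be avoided by copying the whole array in a single call. The sketch below assumes momentBits holds 32-bit unsigned integers carrying float bit patterns (the source does not show its type); bits_to_floats is a hypothetical helper name.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Reinterprets an array of 32-bit patterns as floats with one bulk copy.
template <std::size_t N>
std::array<float, N> bits_to_floats(const std::array<std::uint32_t, N> &bits) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32 bits wide");
  std::array<float, N> values{};
  // std::array storage is contiguous, so a single memcpy replaces the per-element loop.
  std::memcpy(values.data(), bits.data(), sizeof(float) * N);
  return values;
}

// Usage: 0x3F800000 and 0x40000000 are the IEEE-754 bit patterns of 1.0F and 2.0F.
inline void bits_to_floats_example() {
  const std::array<std::uint32_t, 2> bits{{0x3F800000U, 0x40000000U}};
  const std::array<float, 2> values = bits_to_floats(bits);
  assert(values[0] == 1.0F && values[1] == 2.0F);
}
```

With C++20 available, std::bit_cast on the whole array would be another option, at the cost of requiring both array types to be trivially copyable and of equal size.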


context.dispatch_kernel(transformKernel, dispatch_grid);

return outputTexture;

warning: constness of 'outputTexture' prevents automatic move [performance-no-automatic-move]

  return outputTexture;
         ^


context.dispatch_kernel(transformKernel, dispatch_grid);

return outputTexture;

warning: constness of 'outputTexture' prevents automatic move [performance-no-automatic-move]

  return outputTexture;
         ^
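
To illustrate what this warning is about, the sketch below shows a const local being copied into the return value where a non-const local would be moved. Tracked and Holder are hypothetical types; the actual types of outputTexture and of the function's return value are not shown in the excerpt, so this only demonstrates the general mechanism.

```cpp
#include <cassert>
#include <utility>

// Counts copies and moves so the difference is observable.
struct Tracked {
  static int copies;
  static int moves;
  Tracked() = default;
  Tracked(const Tracked &) { ++copies; }
  Tracked(Tracked &&) noexcept { ++moves; }
};
int Tracked::copies = 0;
int Tracked::moves = 0;

// Hypothetical return type that is constructed from the local variable.
struct Holder {
  Tracked value;
  Holder(const Tracked &t) : value(t) {}
  Holder(Tracked &&t) : value(std::move(t)) {}
};

Holder make_from_const_local() {
  const Tracked local{};
  return local;  // const blocks the implicit move, so Holder(const Tracked &) runs
}

Holder make_from_mutable_local() {
  Tracked local{};
  return local;  // treated as an rvalue, so Holder(Tracked &&) runs
}

// Usage: one copy for the const local, one move for the mutable local.
inline void automatic_move_example() {
  Tracked::copies = 0;
  Tracked::moves = 0;
  const Holder copied = make_from_const_local();
  assert(Tracked::copies == 1 && Tracked::moves == 0);
  const Holder moved = make_from_mutable_local();
  assert(Tracked::copies == 1 && Tracked::moves == 1);
  (void)copied;
  (void)moved;
}
```

Dropping const from the local is usually the simplest fix when the object is expensive to copy.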

Copy link

@github-actions github-actions bot left a comment

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 25 out of 34. Check the log or trigger a new build to see more.

#include <string>
#include <vector>
#include <utility>
#include <vector>

warning: duplicate include [readability-duplicate-include]

cpp/core/gpu/registration/initialisation.cpp:46:

- #include <utility>
- #include <vector>
+ #include <utility>

return false;
}

const Eigen::Vector3f values = solver.eigenvalues();

warning: the const qualified variable 'values' is copy-constructed from a const reference; consider making it a const reference [performance-unnecessary-copy-initialization]

Suggested change
- const Eigen::Vector3f values = solver.eigenvalues();
+ const Eigen::Vector3f& values = solver.eigenvalues();

}

const Eigen::Vector3f values = solver.eigenvalues();
const Eigen::Matrix3f vectors = solver.eigenvectors();

warning: the const qualified variable 'vectors' is copy-constructed from a const reference; consider making it a const reference [performance-unnecessary-copy-initialization]

Suggested change
- const Eigen::Matrix3f vectors = solver.eigenvectors();
+ const Eigen::Matrix3f& vectors = solver.eigenvectors();

std::sort(indices.begin(), indices.end(), [&](int a, int b) { return values[a] > values[b]; });

for (size_t i = 0; i < indices.size(); ++i) {
eigenvalues[static_cast<int>(i)] = values[indices[i]];

warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]

    eigenvalues[static_cast<int>(i)] = values[indices[i]];
                                              ^


for (size_t i = 0; i < indices.size(); ++i) {
eigenvalues[static_cast<int>(i)] = values[indices[i]];
eigenvectors.col(static_cast<int>(i)) = vectors.col(indices[i]);

warning: do not use array subscript when the index is not an integer constant expression [cppcoreguidelines-pro-bounds-constant-array-index]

    eigenvectors.col(static_cast<int>(i)) = vectors.col(indices[i]);
                                                        ^
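
The flagged loop reorders eigenvalues and eigenvectors by descending eigenvalue through an index permutation. A dependency-free sketch of that pattern is given below, using plain std::array in place of the Eigen types; descending_order and permuted are hypothetical helper names. It uses at() for runtime indices, which is bounds-checked and, as far as I know, not flagged by this particular check.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <numeric>

// Returns the indices that order `values` from largest to smallest.
template <std::size_t N>
std::array<std::size_t, N> descending_order(const std::array<float, N> &values) {
  std::array<std::size_t, N> indices{};
  std::iota(indices.begin(), indices.end(), std::size_t{0});
  std::sort(indices.begin(), indices.end(),
            [&values](std::size_t a, std::size_t b) { return values.at(a) > values.at(b); });
  return indices;
}

// Applies a permutation: result[i] = input[order[i]].
template <typename T, std::size_t N>
std::array<T, N> permuted(const std::array<T, N> &input, const std::array<std::size_t, N> &order) {
  std::array<T, N> result{};
  for (std::size_t i = 0; i < N; ++i) {
    result.at(i) = input.at(order.at(i));
  }
  return result;
}

// Usage: the largest value (index 1) comes first after reordering.
inline void descending_order_example() {
  const std::array<float, 3> values{{0.5F, 3.0F, 1.5F}};
  const std::array<std::size_t, 3> order = descending_order(values);
  assert(order.at(0) == 1);
  assert(permuted(values, order).at(0) == 3.0F);
}
```

The same permutation can then be applied to any per-eigenvalue data, such as the matching eigenvector columns.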

.transformation_matrix = {},
};
m_compute_context->write_object_to_buffer(m_joint_histogram_uniforms_buffer, joint_histogram_uniforms);
const uint32_t jointHistogramPartialsSize = (m_num_bins * m_num_bins) * m_joint_histogram_dispatch_grid.workgroup_count();

warning: Value stored to 'jointHistogramPartialsSize' during its initialization is never read [clang-analyzer-deadcode.DeadStores]

  const uint32_t jointHistogramPartialsSize = (m_num_bins * m_num_bins) * m_joint_histogram_dispatch_grid.workgroup_count();
                 ^
Additional context

cpp/core/gpu/registration/nmicalculator.cpp:198: Value stored to 'jointHistogramPartialsSize' during its initialization is never read

  const uint32_t jointHistogramPartialsSize = (m_num_bins * m_num_bins) * m_joint_histogram_dispatch_grid.workgroup_count();
                 ^

m_compute_context->clear_buffer(m_joint_histogram_mass_buffer);

assert(transformation.param_count() == m_degrees_of_freedom);
const auto moving_dispatch_grid = DispatchGrid::element_wise_texture(m_moving, m_min_max_moving_kernel.workgroup_size);

warning: Value stored to 'moving_dispatch_grid' during its initialization is never read [clang-analyzer-deadcode.DeadStores]

  const auto moving_dispatch_grid = DispatchGrid::element_wise_texture(m_moving, m_min_max_moving_kernel.workgroup_size);
             ^
Additional context

cpp/core/gpu/registration/nmicalculator.cpp:287: Value stored to 'moving_dispatch_grid' during its initialization is never read

  const auto moving_dispatch_grid = DispatchGrid::element_wise_texture(m_moving, m_min_max_moving_kernel.workgroup_size);
             ^

m_compute_context->dispatch_kernel(m_joint_histogram_smooth_kernel, smooth_grid);

const uint32_t histogramSize = m_num_bins * m_num_bins;
const DispatchGrid merge_grid{.x = histogramSize};

warning: Value stored to 'merge_grid' during its initialization is never read [clang-analyzer-deadcode.DeadStores]

  const DispatchGrid merge_grid{.x = histogramSize};
                     ^
Additional context

cpp/core/gpu/registration/nmicalculator.cpp:311: Value stored to 'merge_grid' during its initialization is never read

  const DispatchGrid merge_grid{.x = histogramSize};
                     ^


if (m_output == CalculatorOutput::CostAndGradients) {
if (transformation.is_affine()) {
std::array<float, 12> params;

warning: uninitialized record type: 'params' [cppcoreguidelines-pro-type-member-init]

Suggested change
- std::array<float, 12> params;
+ std::array<float, 12> params{};


m_compute_context->write_object_to_buffer(m_gradients_uniforms_buffer, gradients_uniforms);
} else {
std::array<float, 6> params;

warning: uninitialized record type: 'params' [cppcoreguidelines-pro-type-member-init]

Suggested change
- std::array<float, 6> params;
+ std::array<float, 6> params{};

Base automatically changed from webgpu to dev February 5, 2026 07:30