
Conversation

@daljit46 (Member) commented May 14, 2025

This PR is a preview of a highly experimental, cross-platform GPU abstraction layer built on top of WebGPU. The aim is to enable high-performance GPU compute in MRtrix3. The intention is not to merge this into dev, but to gather feedback on the API and potentially start a discussion about future integration into the codebase.

WebGPU is a new graphics and compute API developed by the W3C (the World Wide Web Consortium), providing a common interface over Vulkan, Metal, and DirectX. While originally designed for the web, there are native implementations written in C++ (Dawn) and Rust (wgpu). This work builds on top of Dawn, wrapping common tasks such as buffer uploads, shader loading, and dispatch to avoid boilerplate. The goal is to keep the API simple and ergonomic, extending functionality later as needed.

Modern GPU programming requires a significantly different mental model than CPU programming. For the uninitiated (with much oversimplification due to brevity), here are some key concepts:

  • The GPU executes compute shaders (or “kernels” in CUDA parlance) by launching many lightweight “threads” in parallel.

  • Threads are grouped into workgroups. You define a workgroup by its 3D dimensions (x, y, z), so that each workgroup contains x × y × z threads.

  • To actually run your kernel, you issue a launch with a “dispatch grid,” specifying how many workgroups to spawn in each dimension (wgCountX × wgCountY × wgCountZ). The GPU scheduler then maps those workgroups onto its compute units.

  • Each thread in a workgroup has built-in IDs like:

    • global_invocation_id: a unique ID for the thread across the entire dispatch,
    • local_invocation_id: its index within the workgroup,
    • workgroup_id: the index of the workgroup that the thread belongs to.

  • The total number of threads per launch is (wgCountX × wgCountY × wgCountZ) × (x × y × z).

  • Kernels are written from the point of view of a single thread: you author your shader as if it were running on exactly one thread, using those IDs to process data. Within a workgroup, threads can share fast local memory and synchronize with barriers, but there is no cross-workgroup synchronization within a dispatch; all workgroups are only guaranteed to have completed once the dispatch itself finishes.

For example, to cover a 2D image of size (W, H) with a 16 × 16 workgroup, you’d dispatch (ceil(W/16), ceil(H/16), 1) workgroups, yielding at least one thread per pixel.
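That ceiling division is easy to get wrong by one. As a hedged illustration (plain standard C++, independent of the GPU API proposed in this PR; `workgroupCount` is a hypothetical helper, not part of the codebase), the workgroup count per dimension can be computed like this:

```cpp
#include <cstdint>

// Ceiling division: smallest count such that
// workgroupCount(n, size) * size >= n.
constexpr std::uint32_t workgroupCount(std::uint32_t itemCount, std::uint32_t workgroupSize) {
  return (itemCount + workgroupSize - 1) / workgroupSize;
}
```

For a 1920 × 1080 image with 16 × 16 workgroups this gives a (120, 68, 1) dispatch grid; since 68 × 16 = 1088 > 1080, the bottom row of workgroups spawns threads past the image, which the shader must mask out with a bounds check.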

For a complete overview of the API, see the gpu.h header.


Here is a usage example that uploads a buffer to the GPU, performs an element-wise sqrt operation, and writes the results into a second buffer of the same size.

In C++, we create the buffers, schedule the GPU kernel, and then download the output:

using namespace MR::GPU;

// Create a GPU context
ComputeContext context;

// CPU-side input
std::vector<float> inputVector(1'000'000);
std::iota(inputVector.begin(), inputVector.end(), 0.0F);

// Create GPU buffers
Buffer<float> inputBuffer  = context.newBufferFromHostMemory<float>(inputVector);
Buffer<float> outputBuffer = context.newEmptyBuffer<float>(1'000'000);

// Define our kernel
const KernelSpec kernelSpec {
  .computeShader    = { .shaderSource = ShaderFile { "sqrt_kernel.wgsl" } },
  .readOnlyBuffers  = { inputBuffer },
  .readWriteBuffers = { outputBuffer }
};

// Create the kernel and dispatch it
// We assume a 1D workgroup size of (64, 1, 1) = 64 threads per workgroup.
// Since we have 1'000'000 elements, we need 1'000'000/64 = 15'625 workgroups so each thread handles one element.
const Kernel       kernel       = context.newKernel(kernelSpec);
const DispatchGrid dispatchGrid = { 1'000'000/64, 1, 1 };
context.dispatchKernel(kernel, dispatchGrid);

// Download the data and wait for all work to finish
std::vector<float> outputData(inputVector.size());
context.downloadBuffer<float>(outputBuffer, outputData);
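GPU bugs are often silent, so it can be worth checking the downloaded buffer against a CPU reference. A minimal sketch (`sqrtReference` is a hypothetical helper, not part of the PR's API):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// CPU oracle for the element-wise sqrt kernel: applies std::sqrt
// to every element of the input vector.
std::vector<float> sqrtReference(const std::vector<float> &input) {
  std::vector<float> out(input.size());
  for (std::size_t i = 0; i < input.size(); ++i)
    out[i] = std::sqrt(input[i]);
  return out;
}
```

When comparing against `outputData`, a small tolerance is advisable: GPU `sqrt` is not guaranteed to be bit-identical to `std::sqrt`.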

The kernel itself is written in WGSL, WebGPU's shading language, whose syntax loosely resembles Rust:

// Define the “slots” of our kernel: two arrays of floats.
@group(0) @binding(0) var<storage, read>      input:  array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

// Define the size of the workgroup: a 1D workgroup with 64 threads along x
@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) globalId: vec3<u32>) {
  output[globalId.x] = sqrt(input[globalId.x]);
}
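Here globalId.x directly indexes the element. For a 1D dispatch, its relationship to the other built-ins is plain arithmetic; an illustrative (non-API) sketch:

```cpp
#include <cstdint>

// In a 1D dispatch, WGSL defines:
//   global_invocation_id.x = workgroup_id.x * workgroup_size_x + local_invocation_id.x
constexpr std::uint32_t globalInvocationId(std::uint32_t workgroupId,
                                           std::uint32_t workgroupSizeX,
                                           std::uint32_t localId) {
  return workgroupId * workgroupSizeX + localId;
}
```

With the 64-thread workgroups above, the last thread of the last workgroup (workgroup 15'624, local index 63) lands exactly on element 999'999, so the whole 1'000'000-element buffer is covered with no index out of range.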

I have also created a test command test_gpu that performs the slightly more complicated task of applying an affine transformation to an image on the GPU.

As of now, the API is still in its early stages and has been shaped primarily by my work on performing affine registration using WebGPU, which builds on top of this work. It may change significantly based on feedback and refinements needed for the registration work, but it’s in a usable state, so feel free to experiment with it. You can create a shader in cpp/gpu/shaders (automatically symlinked to bin/shaders) and then inspect how the test_gpu command works.

There are many questions for discussion, notably how to integrate the required dependencies into our project structure. Currently, everything lives in cpp/gpu without much organization. The unit tests, which (inappropriately) reside in cpp/gpu/testing, use the Google Test framework; I found them very helpful for debugging GPU code, and they can also serve as examples of API usage. On a slightly tangential note, I think we should integrate a unit testing framework like Google Test into the codebase independently of this PR.

Additionally, Dawn, the C++ library used to interface with the GPU, is fetched over the network and built from source on every fresh build, which takes considerable time and clutters CMake's output at configure time. We should consider providing prebuilt binaries to avoid long build times.

@MRtrix3/mrtrix3-devs any feedback is much appreciated!

@daljit46 daljit46 self-assigned this May 14, 2025
@daljit46 daljit46 marked this pull request as draft May 14, 2025 22:18
@github-actions github-actions bot left a comment

clang-tidy made some suggestions. There were too many comments to post at once; showing the first 25 out of 78. Check the log or trigger a new build to see more.

  • warning: variable 'link' of type 'std::string' (aka 'basic_string') can be declared 'const' [misc-const-correctness], on `std::string link = "/proc/self/exe";`

  • warning: constness of 'buffer' prevents automatic move [performance-no-automatic-move], on `return buffer;`

  • warning: constness of 'texture' prevents automatic move [performance-no-automatic-move], on `return texture;`

  • warning: member variable 'bindGroupEntries' has public visibility [misc-non-private-member-variables-in-classes]

  • warning: member variable 'bindGroupLayoutEntries' has public visibility [misc-non-private-member-variables-in-classes]

  • warning: do not declare C-style arrays, use std::array<> instead [cppcoreguidelines-avoid-c-arrays], four occurrences in the span implementation (the enable_if convertibility constraints, the deduction guide `span(T (&)[N])->span<T, N>;`, and `make_span(T (&arr)[N])`)

  • warning: do not use reinterpret_cast [cppcoreguidelines-pro-type-reinterpret-cast], on `return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()};` in as_bytes

@Cewein commented May 15, 2025

This is pretty nice. Have you considered looking into https://shader-slang.org/ instead of WGSL for the shading language? It is pretty new, but it can be used with other applications very simply, and it would also avoid rewriting all the previous GLSL shaders.

I have worked with Vulkan before, and I think WebGPU is the right middle ground between complexity and simplicity, as opposed to a pure Metal, Vulkan, or DirectX implementation.

@daljit46 (Member, Author) commented May 15, 2025

Hi. Yes, I’m aware of Slang. In fact, I discussed with @jdtournier the possibility of using it instead of WGSL a few months ago.

In terms of language capabilities, Slang is clearly a step up from WGSL. Its support for generics and modules would eliminate the need for the custom preprocessor mechanism used in the current implementation and could potentially reduce duplication in some tasks (for example, handling images with different dimensions). Automatic differentiation could also be quite handy.

However, there are some downsides to consider:

  • Adding Slang would introduce an extra dependency to the project. The Slang team does provide prebuilt binaries, so this may not be a serious issue, but it’s nonetheless worth keeping in mind.
  • WGSL acts as the lowest common denominator across all graphics APIs, meaning it supports only the features common to its targets. As a result, if you find WGSL code online, you can copy and paste it without worrying about whether it will be supported on your target platform (although, due to extensions, this isn't strictly guaranteed). Slang, on the other hand, advertises features that may be supported by only a subset of its targets.

That said, I haven’t given Slang a thorough trial yet, but I’m not ruling it out. If the benefits outweigh the issues I’ve raised, we should consider adopting it for MRtrix3.

@Cewein commented May 23, 2025

I looked at gpu.h and gpu.cpp. It's great, but I think that if MRtrix3 moves deeper into a GPU-driven design, we should create a dedicated file for each part of the GPU pipeline, such as textures, kernels, buffers, and so on. Nevertheless, I understand this is just the start.

I built the branch under Ubuntu 24; it builds just fine. I also agree that Dawn adds a lot to the build time. I think maintaining a fork of Dawn that we update only for major WebGPU releases could address this; and since the WebGPU spec is not fully defined yet, it could also be a good way to work around that caveat.

Are you also thinking of moving away from Qt6 and, for example, using Dear ImGui or something similar? Qt6 is nice, but in my personal opinion ImGui provides a better and faster UI, since it is designed for GPU-driven applications. There are bindings for WebGPU, and it allows interactive CPU-GPU data exploration.

@github-actions github-actions bot left a comment

clang-tidy made some suggestions. There were too many comments to post at once; showing the first 25 out of 52. Check the log or trigger a new build to see more.

  • warning: do not use reinterpret_cast [cppcoreguidelines-pro-type-reinterpret-cast], on `return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};` in as_writable_bytes

  • warning: 'gtest/gtest.h' file not found [clang-diagnostic-error], on `#include <gtest/gtest.h>`

  • warning: all parameters should be named in a function [readability-named-parameter], on the `TEST_F(GPUTest, ...)` macros (MakeEmptyBuffer, BufferFromHostMemory, WriteToBufferWithOffset), with the suggested change of appending `/*unused*/` to the macro arguments

  • warning: variable is not initialized [cppcoreguidelines-init-variables], on several `std::vector` declarations in the tests (initialData, newData, downloadedData, expectedData); the accompanying suggested fixes were garbled (e.g. `std::vector<float> downloadedData = 0(bufferSize);`)

@daljit46 (Member, Author) commented:

Closing in favour of #3238

@daljit46 daljit46 closed this Nov 26, 2025