Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ option(MRTRIX_USE_PCH "Use precompiled headers" ON)
option(MRTRIX_PYTHON_SOFTLINK "Build directory softlink to Python source code rather than copying" ON)
option(MRTRIX_BUILD_STATIC "Build MRtrix's library statically" OFF)
option(MRTRIX_USE_LLD "Use lld as the linker" OFF)
# When ON, the cpp/gpu subdirectory is configured (which includes the FetchDawn
# module) and the test_gpu command is kept in the list of headless commands.
option(MRTRIX_BUILD_GPU "Build GPU support" ON)

set(MRTRIX_DEPENDENCIES_DIR "" CACHE PATH
"An optional local directory containing all thirdparty dependencies:\n \
Expand Down
45 changes: 45 additions & 0 deletions cmake/FetchDawn.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Adapted from https://github.com/eliemichel/WebGPU-distribution/blob/dawn/cmake/FetchDawn.cmake
#
# Fetches a shallow snapshot of Dawn (branch chromium/7060) and adds it to the
# build as an EXCLUDE_FROM_ALL subdirectory, with only the Metal backend enabled
# on Apple platforms and only the Vulkan backend enabled elsewhere.

include(FetchContent)

# Show download progress while Dawn is being fetched, but remember the value
# that was in effect so it can be restored afterwards instead of being forced
# to ON for every later FetchContent call in the including scope.
set(mrtrix_saved_fetchcontent_quiet "${FETCHCONTENT_QUIET}")
set(FETCHCONTENT_QUIET OFF)

FetchContent_Declare(
    dawn
    # Manual shallow fetch of a single ref instead of a full clone.
    DOWNLOAD_COMMAND
        cd ${FETCHCONTENT_BASE_DIR}/dawn-src &&
        git init &&
        git fetch --depth=1 https://dawn.googlesource.com/dawn chromium/7060 &&
        git reset --hard FETCH_HEAD
)

FetchContent_GetProperties(dawn)
if(NOT dawn_POPULATED)
    FetchContent_Populate(dawn)

    # Exactly one graphics backend is enabled, chosen by platform.
    if(APPLE)
        set(USE_VULKAN OFF)
        set(USE_METAL ON)
    else()
        set(USE_VULKAN ON)
        set(USE_METAL OFF)
    endif()

    # Dawn/Tint configuration: windowing integrations, unused backends,
    # samples, tests and command-line tools are all switched off.
    set(DAWN_FETCH_DEPENDENCIES ON)
    set(DAWN_USE_GLFW OFF)
    set(DAWN_USE_X11 OFF)
    set(DAWN_USE_WAYLAND OFF)
    set(DAWN_ENABLE_D3D11 OFF)
    set(DAWN_ENABLE_D3D12 OFF)
    set(DAWN_ENABLE_METAL ${USE_METAL})
    set(DAWN_ENABLE_NULL OFF)
    set(DAWN_ENABLE_DESKTOP_GL OFF)
    set(DAWN_ENABLE_OPENGLES OFF)
    set(DAWN_ENABLE_VULKAN ${USE_VULKAN})
    set(TINT_BUILD_SPV_READER OFF)
    set(DAWN_BUILD_SAMPLES OFF)
    set(TINT_BUILD_CMD_TOOLS OFF)
    set(TINT_BUILD_TESTS OFF)
    set(TINT_BUILD_IR_BINARY OFF)

    # EXCLUDE_FROM_ALL: only the Dawn targets that are actually linked get built.
    add_subdirectory(${dawn_SOURCE_DIR} ${dawn_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()

# Restore the verbosity setting that was active before this module ran.
set(FETCHCONTENT_QUIET "${mrtrix_saved_fetchcontent_quiet}")
unset(mrtrix_saved_fetchcontent_quiet)
4 changes: 4 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ add_subdirectory(core)
if(MRTRIX_BUILD_GUI)
add_subdirectory(gui)
endif()

# The gpu subdirectory is only configured when the MRTRIX_BUILD_GPU option is enabled.
if(MRTRIX_BUILD_GPU)
add_subdirectory(gpu)
endif()
17 changes: 17 additions & 0 deletions cpp/cmd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ endforeach(CMD)

add_custom_target(MRtrixCppCommands)

# test_gpu is linked against mrtrix-gpu-lib further down, and only when
# MRTRIX_BUILD_GPU is enabled; without GPU support, drop it from the list of
# headless commands so that no target is created for it.
if(NOT MRTRIX_BUILD_GPU)
list(REMOVE_ITEM HEADLESS_CMD_SRCS "test_gpu.cpp")
endif()

if(MRTRIX_USE_PCH)
file(GENERATE OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/pch_cmd.cpp CONTENT "int main(){}")
add_executable(pch_cmd ${CMAKE_CURRENT_BINARY_DIR}/pch_cmd.cpp)
Expand Down Expand Up @@ -60,6 +64,17 @@ foreach(CMD ${HEADLESS_CMD_SRCS})
add_cmd(${CMD} FALSE)
endforeach(CMD)

if(MRTRIX_BUILD_GPU)
    # test_gpu resolves its shaders relative to its own executable location,
    # so expose the source shaders directory next to the built binary.
    # Paths are quoted and VERBATIM is set so that directories containing
    # spaces or special characters are passed to the command unmodified.
    add_custom_command(
        TARGET test_gpu POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E create_symlink
                "${PROJECT_SOURCE_DIR}/cpp/gpu/shaders"
                "$<TARGET_FILE_DIR:test_gpu>/shaders"
        COMMENT "Symlinking shaders directory"
        VERBATIM
    )
    target_link_libraries(test_gpu PRIVATE mrtrix-gpu-lib)
endif()

# For the set of commands that take the longest to compile,
# we try to start their compilation as soon as possible by adding them as dependencies
# of a generated dummy file (which Ninja prioritises during compilation).
Expand All @@ -73,3 +88,5 @@ if(MRTRIX_BUILD_GUI)
add_cmd(${CMD} TRUE)
endforeach(CMD)
endif()


146 changes: 146 additions & 0 deletions cpp/cmd/test_gpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#include "app.h"
#include "cmdline_option.h"
#include "command.h"
#include "file/matrix.h"
#include "image.h"
#include "image_helpers.h"
#include "gpu.h"
#include "transform.h"
#include "types.h"
#include "utils.h"

#include <Eigen/Core>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <filesystem>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>

using namespace MR;
using namespace App;

// Platform headers are included at global scope: including them inside the
// anonymous namespace below would declare their contents within that namespace.
#if defined(_WIN32)
#include <windows.h>
#elif defined(__APPLE__)
#include <mach-o/dyld.h>
#elif defined(__linux__)
#include <unistd.h>
#endif

namespace {

// A cross-platform function to get the path of the executable.
//
// Returns the path of the currently running executable (canonicalised on
// macOS and Linux). Throws std::runtime_error if the platform query fails;
// std::filesystem::canonical may additionally throw on macOS and Linux.
std::filesystem::path getExecutablePath() {
#if defined(_WIN32)
  wchar_t buffer[MAX_PATH];
  const DWORD len = GetModuleFileNameW(NULL, buffer, MAX_PATH);
  // A return value equal to the buffer size is treated as truncation.
  if (len == 0 || len == MAX_PATH)
    throw std::runtime_error("GetModuleFileNameW failed");
  return std::filesystem::path(buffer);
#elif defined(__APPLE__)
  // First call with an empty buffer only queries the required size.
  uint32_t size = 0;
  if (_NSGetExecutablePath(nullptr, &size) != -1)
    throw std::runtime_error("Unexpected success getting buffer size");
  std::string buffer(size, '\0');
  if (_NSGetExecutablePath(buffer.data(), &size) != 0)
    throw std::runtime_error("Buffer still too small");
  // Build the path from the C string so the terminating NUL written into the
  // buffer does not become part of the path.
  return std::filesystem::canonical(std::filesystem::path(buffer.c_str()));
#elif defined(__linux__)
  const std::string link = "/proc/self/exe";
  std::error_code ec;
  const auto target = std::filesystem::read_symlink(link, ec);
  if (ec)
    throw std::runtime_error("read_symlink failed: " + ec.message());
  return std::filesystem::canonical(target);
#else
#error Unsupported platform
#endif
}

} // namespace

// Declares the command's author, synopsis and its three positional arguments:
// the input image, the output image and the transform file.
void usage() {
  // Each field is assigned in its own statement; the original chained AUTHOR
  // and SYNOPSIS with a comma, which only worked through the comma operator.
  AUTHOR = "Daljit Singh";

  SYNOPSIS = "Transforms an image given the input image and the backward transformation matrix";

  ARGUMENTS
  + Argument ("input", "input image").type_image_in()
  + Argument ("output", "the output image.").type_image_out()
  + Argument ("transform").type_file_in();
}


// Applies the transform loaded from the third argument to the input image on
// the GPU (via the transform_image.wgsl compute shader) and saves the result
// to the output image path.
void run()
{
  const std::filesystem::path inputPath { argument.at(0) };
  const std::filesystem::path outputPath { argument.at(1) };
  const std::filesystem::path transformPath { argument.at(2) };

  const transform_type transform = File::Matrix::load_transform(transformPath);
  const GPU::ComputeContext context;

  const auto inputImage = Image<float>::open(inputPath).with_direct_io();
  const GPU::Texture inputTexture = context.newTextureFromHostImage(inputImage);

  // Conjugate the loaded transform with the image's voxel<->scanner mappings
  // (scanner2voxel * transform * voxel2scanner) so the shader receives a
  // single matrix. The Transform helper is built once and reused.
  const Transform imageTransform(inputImage);
  const transform_type voxelSpaceTransform = imageTransform.scanner2voxel *
                                             transform *
                                             imageTransform.voxel2scanner;

  // Expand the 3x4 affine into a full 4x4 float matrix with a (0, 0, 0, 1) last row.
  auto to4x4Matrix = [](const transform_type& t) {
    const auto mat3x4 = t.matrix().cast<float>();
    Eigen::Matrix4f matrix;
    matrix.block<3, 4>(0, 0) = mat3x4;
    matrix.block<1, 4>(3, 0) = Eigen::RowVector4f(0, 0, 0, 1);
    return matrix;
  };

  const Eigen::Matrix4f transformationMat = to4x4Matrix(voxelSpaceTransform);
  constexpr MR::GPU::WorkgroupSize workgroupSize = {8, 8, 4};

  const MR::GPU::Buffer<float> transformBuffer = context.newBufferFromHostMemory<float>(transformationMat);

  // The output texture has the same extent as the input texture.
  const MR::GPU::TextureSpec outputTextureSpec {
    .width = inputTexture.spec.width,
    .height = inputTexture.spec.height,
    .depth = inputTexture.spec.depth,
    .format = GPU::TextureFormat::R32Float,
    .usage = { .storageBinding = true }
  };
  const GPU::Texture outputTexture = context.newEmptyTexture(outputTextureSpec);

  // The shader file is looked up in a "shaders" directory next to the executable.
  const auto currentPath = getExecutablePath().parent_path();
  const GPU::KernelSpec transformKernelSpec {
    .computeShader = {
      .shaderSource = GPU::ShaderFile { currentPath / "shaders/transform_image.wgsl" },
      .workgroupSize = workgroupSize,
    },
    .readOnlyBuffers = { transformBuffer },
    .readOnlyTextures = { inputTexture },
    .writeOnlyTextures = { outputTexture },
    .samplers = { context.newLinearSampler() }
  };

  const GPU::Kernel transformKernel = context.newKernel(transformKernelSpec);

  const auto width = inputTexture.spec.width;
  const auto height = inputTexture.spec.height;
  const auto depth = inputTexture.spec.depth;

  // Number of workgroups needed along one axis: ceil(extent / groupSize).
  // The previous expression divided first and then rounded the quotient up to
  // a multiple of the group size, which yields zero groups when the extent is
  // smaller than the group size and too few groups for extents such as 65
  // with a group size of 8.
  // NOTE(review): assumes the shader tolerates invocations beyond the texture
  // extent in the last, partially filled workgroup - confirm against the WGSL.
  const auto workgroupCount = [](auto extent, auto groupSize) {
    return (extent + groupSize - 1) / groupSize;
  };

  const GPU::DispatchGrid dispatchGrid {
    .wgCountX = workgroupCount(width, workgroupSize.x),
    .wgCountY = workgroupCount(height, workgroupSize.y),
    .wgCountZ = workgroupCount(depth, workgroupSize.z)
  };

  context.dispatchKernel(transformKernel, dispatchGrid);

  // Read the result back into host memory, one float per input voxel.
  std::vector<float> gpuData(MR::voxel_count(inputImage), 0.0F);
  context.downloadTexture(outputTexture, gpuData);

  // NOTE(review): the raw copy assumes the downloaded data uses the same
  // memory layout as the scratch image - confirm.
  Image<float> outputImage = Image<float>::scratch(inputImage);
  float *data = static_cast<float*>(outputImage.address());

  std::copy_n(gpuData.data(), gpuData.size(), data);

  MR::save(outputImage, outputPath);
}
25 changes: 25 additions & 0 deletions cpp/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Pull in Dawn before declaring the library; the dawncpp_headers and webgpu_dawn
# targets linked below are presumably created by that module - TODO confirm.
include(FetchDawn)

# Static library bundling the GPU helpers of this directory.
add_library(mrtrix-gpu-lib STATIC
    gpu.h
    gpu.cpp
    utils.h
    utils.cpp
    wgslprocessing.h
    wgslprocessing.cpp
    match.h
    span.h
)

# Consumers include the headers of this directory directly.
target_include_directories(mrtrix-gpu-lib
    PUBLIC
        ${CMAKE_CURRENT_SOURCE_DIR}
)

# PUBLIC: these dependencies are propagated to everything linking the library.
target_link_libraries(mrtrix-gpu-lib
    PUBLIC
        dawncpp_headers
        webgpu_dawn
        mrtrix::core
)

# Tests for this library are only configured when MRTRIX_BUILD_TESTS is set.
if(MRTRIX_BUILD_TESTS)
    enable_testing()
    add_subdirectory(testing)
endif()

Loading
Loading