diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d2fd2ebb8..e7b9c1dee1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,7 @@ option(MRTRIX_USE_PCH "Use precompiled headers" ON) option(MRTRIX_PYTHON_SOFTLINK "Build directory softlink to Python source code rather than copying" ON) option(MRTRIX_BUILD_STATIC "Build MRtrix's library statically" OFF) option(MRTRIX_USE_LLD "Use lld as the linker" OFF) +option(MRTRIX_BUILD_GPU "Build GPU support" ON) set(MRTRIX_DEPENDENCIES_DIR "" CACHE PATH "An optional local directory containing all thirdparty dependencies:\n \ diff --git a/cmake/FetchDawn.cmake b/cmake/FetchDawn.cmake new file mode 100644 index 0000000000..e6c88fdb94 --- /dev/null +++ b/cmake/FetchDawn.cmake @@ -0,0 +1,45 @@ +# Adapted from https://github.com/eliemichel/WebGPU-distribution/blob/dawn/cmake/FetchDawn.cmake + +include(FetchContent) +set(FETCHCONTENT_QUIET OFF) +FetchContent_Declare( + dawn + DOWNLOAD_COMMAND + cd ${FETCHCONTENT_BASE_DIR}/dawn-src && + git init && + git fetch --depth=1 https://dawn.googlesource.com/dawn chromium/7060 && + git reset --hard FETCH_HEAD +) + +FetchContent_GetProperties(dawn) +if (NOT dawn_POPULATED) + FetchContent_Populate(dawn) + if (APPLE) + set(USE_VULKAN OFF) + set(USE_METAL ON) + else() + set(USE_VULKAN ON) + set(USE_METAL OFF) + endif() + + set(DAWN_FETCH_DEPENDENCIES ON) + set(DAWN_USE_GLFW OFF) + set(DAWN_USE_X11 OFF) + set(DAWN_USE_WAYLAND OFF) + set(DAWN_ENABLE_D3D11 OFF) + set(DAWN_ENABLE_D3D12 OFF) + set(DAWN_ENABLE_METAL ${USE_METAL}) + set(DAWN_ENABLE_NULL OFF) + set(DAWN_ENABLE_DESKTOP_GL OFF) + set(DAWN_ENABLE_OPENGLES OFF) + set(DAWN_ENABLE_VULKAN ${USE_VULKAN}) + set(TINT_BUILD_SPV_READER OFF) + set(DAWN_BUILD_SAMPLES OFF) + set(TINT_BUILD_CMD_TOOLS OFF) + set(TINT_BUILD_TESTS OFF) + set(TINT_BUILD_IR_BINARY OFF) + + add_subdirectory(${dawn_SOURCE_DIR} ${dawn_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() + +set(FETCHCONTENT_QUIET ON) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 680509d645..a7e37212a9 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -35,3 +35,7 @@ add_subdirectory(core) if(MRTRIX_BUILD_GUI) add_subdirectory(gui) endif() + +if(MRTRIX_BUILD_GPU) + add_subdirectory(gpu) +endif() diff --git a/cpp/cmd/CMakeLists.txt b/cpp/cmd/CMakeLists.txt index 5ef1eb6794..8e1f9ca953 100644 --- a/cpp/cmd/CMakeLists.txt +++ b/cpp/cmd/CMakeLists.txt @@ -11,6 +11,10 @@ endforeach(CMD) add_custom_target(MRtrixCppCommands) +if(NOT MRTRIX_BUILD_GPU) + list(REMOVE_ITEM HEADLESS_CMD_SRCS "test_gpu.cpp") +endif() + if(MRTRIX_USE_PCH) file(GENERATE OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/pch_cmd.cpp CONTENT "int main(){}") add_executable(pch_cmd ${CMAKE_CURRENT_BINARY_DIR}/pch_cmd.cpp) @@ -60,6 +64,17 @@ foreach(CMD ${HEADLESS_CMD_SRCS}) add_cmd(${CMD} FALSE) endforeach(CMD) +if(MRTRIX_BUILD_GPU) + add_custom_command( + TARGET test_gpu POST_BUILD + COMMAND ${CMAKE_COMMAND} -E create_symlink + ${PROJECT_SOURCE_DIR}/cpp/gpu/shaders + $/shaders + COMMENT "Symlinking shaders directory" + ) + target_link_libraries(test_gpu PRIVATE mrtrix-gpu-lib) +endif() + # For the set of commands that take the longest to compile, # we try to start their compilation as soon as possible by adding them as dependencies # of a generated dummy file (which Ninja prioritises during compilation). 
@@ -73,3 +88,5 @@ if(MRTRIX_BUILD_GUI) add_cmd(${CMD} TRUE) endforeach(CMD) endif() + + diff --git a/cpp/cmd/test_gpu.cpp b/cpp/cmd/test_gpu.cpp new file mode 100644 index 0000000000..5b661db089 --- /dev/null +++ b/cpp/cmd/test_gpu.cpp @@ -0,0 +1,146 @@ +#include "app.h" +#include "cmdline_option.h" +#include "command.h" +#include "file/matrix.h" +#include "image.h" +#include "image_helpers.h" +#include "gpu.h" +#include "transform.h" +#include "types.h" +#include "utils.h" + +#include +#include +#include +#include +#include + +using namespace MR; +using namespace App; + +namespace { + +#if defined(_WIN32) +#include +#elif defined(__APPLE__) +#include +#elif defined(__linux__) +#include +#endif + +// A cross-platform function to get the path of the executable +std::filesystem::path getExecutablePath() { +#if defined(_WIN32) + wchar_t buffer[MAX_PATH]; + DWORD len = GetModuleFileNameW(NULL, buffer, MAX_PATH); + if (len == 0 || len == MAX_PATH) + throw std::runtime_error("GetModuleFileNameW failed"); + return std::filesystem::path(buffer); +#elif defined(__APPLE__) + uint32_t size = 0; + if (_NSGetExecutablePath(nullptr, &size) != -1) + throw std::runtime_error("Unexpected success getting buffer size"); + std::string buffer; + buffer.resize(size); + if (_NSGetExecutablePath(&buffer[0], &size) != 0) + throw std::runtime_error("Buffer still too small"); + return std::filesystem::canonical(std::filesystem::path(buffer)); +#elif defined(__linux__) + std::string link = "/proc/self/exe"; + std::error_code ec; + auto p = std::filesystem::read_symlink(link, ec); + if (ec) + throw std::runtime_error("read_symlink failed: " + ec.message()); + return std::filesystem::canonical(p); +#else +#error Unsupported platform +#endif +} +} + +void usage() { + AUTHOR = "Daljit Singh", + SYNOPSIS = "Transforms an image given the input image and the backward transformation matrix"; + + ARGUMENTS + + Argument ("input", "input image").type_image_in() + + Argument ("output", "the output image.").type_image_out() + + Argument ("transform").type_file_in(); + +} + + +void run() +{ + const std::filesystem::path inputPath { argument.at(0) }; + const std::filesystem::path outputPath { argument.at(1) }; + const std::filesystem::path transformPath { argument.at(2) }; + + const transform_type transform = File::Matrix::load_transform(transformPath); + const GPU::ComputeContext context; + + const auto inputImage = Image::open(inputPath).with_direct_io(); + const GPU::Texture inputTexture = context.newTextureFromHostImage(inputImage); + + const transform_type transformationWorldCoords = Transform(inputImage).scanner2voxel * + transform * + Transform(inputImage).voxel2scanner; + + auto to4x4Matrix = [](const transform_type& t) { + const auto mat3x4 = t.matrix().cast(); + Eigen::Matrix4f matrix; + matrix.block<3, 4>(0, 0) = mat3x4; + matrix.block<1, 4>(3, 0) = Eigen::RowVector4f(0, 0, 0, 1); + return matrix; + }; + + const Eigen::Matrix4f transformationMat = to4x4Matrix(transformationWorldCoords); + constexpr MR::GPU::WorkgroupSize workgroupSize = {8, 8, 4}; + + const MR::GPU::Buffer transformBuffer = context.newBufferFromHostMemory(transformationMat); + + const MR::GPU::TextureSpec outputTextureSpec { + .width = inputTexture.spec.width, + .height = inputTexture.spec.height, + .depth = inputTexture.spec.depth, + .format = GPU::TextureFormat::R32Float, + .usage = { .storageBinding = true } + }; + const GPU::Texture outputTexture = context.newEmptyTexture(outputTextureSpec); + + const auto currentPath = 
getExecutablePath().parent_path(); + const GPU::KernelSpec transformKernelSpec { + .computeShader = { + .shaderSource = GPU::ShaderFile { currentPath / "shaders/transform_image.wgsl" }, + .workgroupSize = workgroupSize, + }, + .readOnlyBuffers = { transformBuffer}, + .readOnlyTextures = { inputTexture }, + .writeOnlyTextures = { outputTexture }, + .samplers = { context.newLinearSampler() } + }; + + const GPU::Kernel transformKernel = context.newKernel(transformKernelSpec); + + const auto width = inputTexture.spec.width; + const auto height = inputTexture.spec.height; + const auto depth = inputTexture.spec.depth; + + const GPU::DispatchGrid dispatchGrid { + .wgCountX = Utils::nextMultipleOf(width / workgroupSize.x, workgroupSize.x), + .wgCountY = Utils::nextMultipleOf(height / workgroupSize.y, workgroupSize.y), + .wgCountZ = Utils::nextMultipleOf(depth / workgroupSize.z, workgroupSize.z) + }; + + context.dispatchKernel(transformKernel, dispatchGrid); + + std::vector gpuData(MR::voxel_count(inputImage), 0.0F); + context.downloadTexture(outputTexture, gpuData); + + Image outputImage = Image::scratch(inputImage); + float *data = static_cast(outputImage.address()); + + std::copy_n(gpuData.data(), gpuData.size(), data); + + MR::save(outputImage, outputPath); +} diff --git a/cpp/gpu/CMakeLists.txt b/cpp/gpu/CMakeLists.txt new file mode 100644 index 0000000000..33f6081394 --- /dev/null +++ b/cpp/gpu/CMakeLists.txt @@ -0,0 +1,25 @@ +include(FetchDawn) + +add_library(mrtrix-gpu-lib STATIC + gpu.h gpu.cpp + utils.h utils.cpp + wgslprocessing.h wgslprocessing.cpp + match.h + span.h +) + +target_link_libraries(mrtrix-gpu-lib PUBLIC + dawncpp_headers + webgpu_dawn + mrtrix::core +) + +target_include_directories(mrtrix-gpu-lib PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(MRTRIX_BUILD_TESTS) + enable_testing() + add_subdirectory(testing) +endif() + diff --git a/cpp/gpu/gpu.cpp b/cpp/gpu/gpu.cpp new file mode 100644 index 0000000000..6a82b13b25 --- /dev/null +++ b/cpp/gpu/gpu.cpp @@ -0,0 +1,551 @@ +#include "gpu.h" +#include "exception.h" +#include "image_helpers.h" +#include "match.h" +#include "span.h" +#include "utils.h" +#include "wgslprocessing.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { + +constexpr auto GPUBackendType +#ifdef __APPLE__ + = wgpu::BackendType::Metal; +#else + = wgpu::BackendType::Vulkan; +#endif + +namespace { +uint32_t pixelSizeInBytes(const wgpu::TextureFormat format) { + switch(format) { + case wgpu::TextureFormat::R8Unorm: + return 1; + case wgpu::TextureFormat::R16Float: + return 2; + case wgpu::TextureFormat::R32Float: + return 4; + default: + throw MR::Exception("Only R8Unorm, R16Float and R32Float textures are supported!"); + } +} + + +wgpu::ShaderModule makeShaderModule(const std::string &name, const std::string &code, const wgpu::Device &device) +{ + wgpu::ShaderModuleWGSLDescriptor wgslDescriptor {}; + wgslDescriptor.code = code.c_str(); + wgpu::ShaderModuleDescriptor descriptor {}; + descriptor.nextInChain = &wgslDescriptor; + descriptor.label = name.c_str(); + + return device.CreateShaderModule(&descriptor); +} + +wgpu::TextureFormat toWGPUFormat(const MR::GPU::TextureFormat& format) { + switch(format) { + case MR::GPU::TextureFormat::R32Float: return wgpu::TextureFormat::R32Float; + default: wgpu::TextureFormat::Undefined; + }; +} + +wgpu::TextureUsage toWGPUUsage(const MR::GPU::TextureUsage& usage) { + wgpu::TextureUsage textureUsage = wgpu::TextureUsage::CopySrc | 
wgpu::TextureUsage::CopyDst | wgpu::TextureUsage::TextureBinding; + + if(usage.storageBinding) { + textureUsage |= wgpu::TextureUsage::StorageBinding; + } + if(usage.renderTarget) { + textureUsage |= wgpu::TextureUsage::RenderAttachment; + } + return textureUsage; +} + +} + +ComputeContext::ComputeContext() +{ + constexpr std::array dawnToggles { + "allow_unsafe_apis", + "enable_immediate_error_handling" + }; + + wgpu::DawnTogglesDescriptor dawnTogglesDesc; + dawnTogglesDesc.enabledToggles = dawnToggles.data(); + dawnTogglesDesc.enabledToggleCount = dawnToggles.size(); + + constexpr wgpu::InstanceDescriptor instanceDescriptor { + .nextInChain = nullptr, + .capabilities = {.timedWaitAnyEnable = true} + }; + const wgpu::Instance instance = wgpu::CreateInstance(&instanceDescriptor); + wgpu::Adapter adapter; + + const wgpu::RequestAdapterOptions adapterOptions { + .powerPreference = wgpu::PowerPreference::HighPerformance, + .backendType = GPUBackendType + }; + + + struct RequestAdapterResult { + wgpu::RequestAdapterStatus status = wgpu::RequestAdapterStatus::Error; + wgpu::Adapter adapter = nullptr; + std::string message; + } requestAdapterResult; + + const auto adapterCallback = [&requestAdapterResult](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message) { + requestAdapterResult = { status, std::move(adapter), std::string(message) }; + }; + + const wgpu::Future adapterRequest = instance.RequestAdapter(&adapterOptions, + wgpu::CallbackMode::WaitAnyOnly, + adapterCallback); + const wgpu::WaitStatus waitStatus = instance.WaitAny(adapterRequest, -1); + + if(waitStatus == wgpu::WaitStatus::Success) { + if(requestAdapterResult.status != wgpu::RequestAdapterStatus::Success) { + throw MR::Exception("Failed to get adapter: " + requestAdapterResult.message); + } + } else { + throw MR::Exception("Failed to get adapter: wgpu::Instance::WaitAny failed"); + } + + adapter = requestAdapterResult.adapter; + + const std::vector requiredDeviceFeatures = { + wgpu::FeatureName::R8UnormStorage, + wgpu::FeatureName::Float32Filterable + }; + + const wgpu::Limits requiredDeviceLimits { + .maxComputeWorkgroupStorageSize = 32768, + .maxComputeInvocationsPerWorkgroup = 1024, + }; + + wgpu::DeviceDescriptor deviceDescriptor {}; + deviceDescriptor.nextInChain = &dawnTogglesDesc; + deviceDescriptor.requiredFeatures = requiredDeviceFeatures.data(); + deviceDescriptor.requiredFeatureCount = requiredDeviceFeatures.size(); + deviceDescriptor.requiredLimits = &requiredDeviceLimits; + + deviceDescriptor.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device&, wgpu::DeviceLostReason reason, wgpu::StringView message) { + const char* reasonName = ""; + if(reason != wgpu::DeviceLostReason::Destroyed && + reason != wgpu::DeviceLostReason::InstanceDropped) { + throw MR::Exception("GPU device lost: " + std::string(reasonName) + " : " + message.data); + } + }); + deviceDescriptor.SetUncapturedErrorCallback( + [](const wgpu::Device&, wgpu::ErrorType type, wgpu::StringView message) { + (void)type; + FAIL("Uncaptured gpu error: " + std::string(message)); + throw MR::Exception("Uncaptured gpu error: " + std::string(message)); + }); + + this->instance = instance; + this->adapter = adapter; + this->device = adapter.CreateDevice(&deviceDescriptor); +} + +wgpu::Buffer ComputeContext::innerNewEmptyBuffer(size_t byteSize) const +{ + const wgpu::BufferDescriptor bufferDescriptor { + .usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + .size = 
byteSize, + }; + + return device.CreateBuffer(&bufferDescriptor); +} + +wgpu::Buffer ComputeContext::innerNewBufferFromHostMemory(const void *srcMemory, size_t srcByteSize) const +{ + const auto buffer = innerNewEmptyBuffer(srcByteSize); + innerWriteToBuffer(buffer, srcMemory, srcByteSize, 0); + return buffer; +} + +void ComputeContext::innerDownloadBuffer(const wgpu::Buffer &buffer, void *dstMemory, size_t dstByteSize) const +{ + assert(buffer.GetSize() == dstByteSize); + + const wgpu::BufferDescriptor stagingBufferDescriptor { + .usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + .size = dstByteSize, + }; + + const wgpu::Buffer stagingBuffer = device.CreateBuffer(&stagingBufferDescriptor); + const wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(buffer, 0, stagingBuffer, 0, dstByteSize); + const wgpu::CommandBuffer commands = encoder.Finish(); + device.GetQueue().Submit(1, &commands); + + auto mappingCallback = [](wgpu::MapAsyncStatus status, const char* message) { + if(status != wgpu::MapAsyncStatus::Success) { + throw MR::Exception("Failed to map buffer: " + std::string(message)); + } + }; + const wgpu::Future mappingFuture = stagingBuffer.MapAsync(wgpu::MapMode::Read, + 0, + stagingBuffer.GetSize(), + wgpu::CallbackMode::WaitAnyOnly, + mappingCallback); + const wgpu::WaitStatus waitStatus = instance.WaitAny(mappingFuture, std::numeric_limits::max()); + if(waitStatus != wgpu::WaitStatus::Success) { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Instance::WaitAny failed"); + } + + const void* data = stagingBuffer.GetConstMappedRange(); + if(dstMemory != nullptr) { + std::memcpy(dstMemory, data, dstByteSize); + stagingBuffer.Unmap(); + } else { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Buffer::GetMappedRange returned nullptr"); + } +} + +void ComputeContext::innerWriteToBuffer(const wgpu::Buffer &buffer, const void *data, size_t srcByteSize, uint64_t offset) const +{ + device.GetQueue().WriteBuffer(buffer, offset, data, srcByteSize); +} + + +Texture ComputeContext::newEmptyTexture(const TextureSpec &textureSpec) const +{ + const wgpu::TextureDescriptor wgpuTextureDesc { + .usage = toWGPUUsage(textureSpec.usage), + .dimension = textureSpec.depth > 1 ? 
wgpu::TextureDimension::e3D : wgpu::TextureDimension::e2D, + .size = { textureSpec.width, textureSpec.height, textureSpec.depth }, + .format = toWGPUFormat(textureSpec.format) + }; + return { device.CreateTexture(&wgpuTextureDesc), textureSpec }; +} + +Texture ComputeContext::newTextureFromHostMemory(const TextureSpec &textureSpec, tcb::span srcMemoryRegion) const +{ + const Texture texture = newEmptyTexture(textureSpec); + const wgpu::TexelCopyTextureInfo imageCopyTexture { texture.wgpuHandle }; + const wgpu::TexelCopyBufferLayout textureDataLayout { + .bytesPerRow = textureSpec.width * pixelSizeInBytes(texture.wgpuHandle.GetFormat()), + .rowsPerImage = textureSpec.height, + }; + + const wgpu::Extent3D textureSize { textureSpec.width, textureSpec.height, textureSpec.depth }; + device.GetQueue().WriteTexture(&imageCopyTexture, + srcMemoryRegion.data(), + srcMemoryRegion.size_bytes(), + &textureDataLayout, + &textureSize); + return texture; +} + +Texture ComputeContext::newTextureFromHostImage(const MR::Image &image, const TextureUsage &usage) const { + const TextureSpec textureSpec = { + .width = static_cast(image.size(0)), + .height = static_cast(image.size(1)), + .depth = static_cast(image.size(2)), + .usage = usage, + }; + const auto imageSize = MR::voxel_count(image); + return newTextureFromHostMemory(textureSpec, tcb::span(image.address(), imageSize)); +} + +void ComputeContext::downloadTexture(const Texture &texture, tcb::span dstMemoryRegion) const +{ + assert(dstMemoryRegion.size_bytes() >= static_cast(texture.wgpuHandle.GetWidth()) * + texture.wgpuHandle.GetHeight() * + texture.wgpuHandle.GetDepthOrArrayLayers() + && "Memory region size is too small for the texture"); + + const uint32_t bytesPerRow = Utils::nextMultipleOf(texture.wgpuHandle.GetWidth() * + pixelSizeInBytes(texture.wgpuHandle.GetFormat()), 256); + const size_t paddedDataSize = static_cast(bytesPerRow) * texture.wgpuHandle.GetHeight() * + texture.wgpuHandle.GetDepthOrArrayLayers(); + const wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + const wgpu::BufferDescriptor stagingBufferDesc { + .usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + .size = paddedDataSize, + }; + const wgpu::Buffer stagingBuffer = device.CreateBuffer(&stagingBufferDesc); + + const wgpu::TexelCopyTextureInfo imageCopyTexture { texture.wgpuHandle }; + const wgpu::TexelCopyBufferInfo imageCopyBuffer { + .layout = wgpu::TexelCopyBufferLayout { + .bytesPerRow = bytesPerRow, + .rowsPerImage = texture.wgpuHandle.GetHeight(), + }, + .buffer = stagingBuffer + }; + + const wgpu::Extent3D imageCopySize { + .width = texture.wgpuHandle.GetWidth(), + .height = texture.wgpuHandle.GetHeight(), + .depthOrArrayLayers = texture.wgpuHandle.GetDepthOrArrayLayers(), + }; + encoder.CopyTextureToBuffer(&imageCopyTexture, &imageCopyBuffer, &imageCopySize); + const wgpu::CommandBuffer commands = encoder.Finish(); + device.GetQueue().Submit(1, &commands); + + auto mappingCallback = [](wgpu::MapAsyncStatus status, const char* message) { + if(status != wgpu::MapAsyncStatus::Success) { + throw MR::Exception("Failed to map buffer: " + std::string(message)); + } + }; + + const wgpu::Future mappingFuture = stagingBuffer.MapAsync(wgpu::MapMode::Read, + 0, + stagingBuffer.GetSize(), + wgpu::CallbackMode::WaitAnyOnly, + mappingCallback); + + const wgpu::WaitStatus waitStatus = instance.WaitAny(mappingFuture, -1); + if(waitStatus != wgpu::WaitStatus::Success) { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Instance::WaitAny 
failed"); + } + + const void* data = stagingBuffer.GetConstMappedRange(); + + // Copy the unpadded data + if(data != nullptr) { + const size_t paddedRowWidth = bytesPerRow / sizeof(float); + const size_t numRows = static_cast(texture.wgpuHandle.GetDepthOrArrayLayers()) * texture.wgpuHandle.GetHeight(); + const tcb::span srcSpan(static_cast(data), paddedRowWidth * numRows); + const size_t width = texture.wgpuHandle.GetWidth(); + const tcb::span dstSpan(dstMemoryRegion.data(), width * numRows); + + for (size_t row = 0; row < numRows; ++row) { + const auto rowSrc = srcSpan.subspan(row * paddedRowWidth, width); + auto rowDst = dstSpan.subspan(row * width, width); + // copy exactly 'width' pixels + std::copy_n(rowSrc.begin(), width, rowDst.begin()); + } + } else { + throw MR::Exception("Failed to map buffer to host memory: wgpu::Buffer::GetMappedRange returned nullptr"); + } +} + +Kernel ComputeContext::newKernel(const KernelSpec &kernelSpec) const +{ + struct BindingEntries { + std::vector bindGroupEntries; + std::vector bindGroupLayoutEntries; + + void add(const wgpu::BindGroupEntry& bindGroupEntry, + const wgpu::BindGroupLayoutEntry& bindGroupLayoutEntry) { + bindGroupEntries.push_back(bindGroupEntry); + bindGroupLayoutEntries.push_back(bindGroupLayoutEntry); + } + }; + + BindingEntries bindingEntries; + + + uint32_t bindingIndex = 0; + + auto getBufferWGPUHandle = [](const BufferVariant& buffer) { + return MR::match(buffer, [](auto &&arg) { return arg.wgpuHandle; }); + }; + // Uniform buffers + for(const auto& buffer : kernelSpec.uniformBuffers) { + const wgpu::BindGroupLayoutEntry layoutEntry { + .binding = bindingIndex++, + .visibility = wgpu::ShaderStage::Compute, + .buffer = { .type = wgpu::BufferBindingType::Uniform } + }; + + const wgpu::BindGroupEntry bindGroupEntry { + .binding = layoutEntry.binding, + .buffer = getBufferWGPUHandle(buffer), + }; + bindingEntries.add(bindGroupEntry, layoutEntry); + } + + // Read-only buffers + for(const auto& buffer : kernelSpec.readOnlyBuffers) { + const wgpu::BindGroupLayoutEntry layoutEntry { + .binding = bindingIndex++, + .visibility = wgpu::ShaderStage::Compute, + .buffer = { .type = wgpu::BufferBindingType::ReadOnlyStorage } + }; + + const wgpu::BindGroupEntry bindGroupEntry { + .binding = layoutEntry.binding, + .buffer = getBufferWGPUHandle(buffer), + }; + bindingEntries.add(bindGroupEntry, layoutEntry); + } + + // Read-write (storage) buffers + for(const auto& buffer : kernelSpec.readWriteBuffers) { + const wgpu::BindGroupLayoutEntry layoutEntry { + .binding = bindingIndex++, + .visibility = wgpu::ShaderStage::Compute, + .buffer = { .type = wgpu::BufferBindingType::Storage } + }; + + const wgpu::BindGroupEntry bindGroupEntry { + .binding = layoutEntry.binding, + .buffer = getBufferWGPUHandle(buffer), + }; + bindingEntries.add(bindGroupEntry, layoutEntry); + } + + // Read-only textures + for(const auto& texture : kernelSpec.readOnlyTextures) { + const wgpu::BindGroupLayoutEntry layoutEntry { + .binding = bindingIndex++, + .visibility = wgpu::ShaderStage::Compute, + .texture = { + .sampleType = wgpu::TextureSampleType::Float, + .viewDimension = texture.wgpuHandle.GetDepthOrArrayLayers() > 1 + ? 
wgpu::TextureViewDimension::e3D + : wgpu::TextureViewDimension::e2D, + } + }; + + const wgpu::BindGroupEntry bindGroupEntry { + .binding = layoutEntry.binding, + .textureView = texture.wgpuHandle.CreateView(), + }; + bindingEntries.add(bindGroupEntry, layoutEntry); + } + + // Write-only textures + for(const auto& texture : kernelSpec.writeOnlyTextures) { + const wgpu::BindGroupLayoutEntry layoutEntry { + .binding = bindingIndex++, + .visibility = wgpu::ShaderStage::Compute, + .storageTexture = { + .access = wgpu::StorageTextureAccess::WriteOnly, + .format = texture.wgpuHandle.GetFormat(), + .viewDimension = texture.wgpuHandle.GetDepthOrArrayLayers() > 1 + ? wgpu::TextureViewDimension::e3D + : wgpu::TextureViewDimension::e2D + } + }; + + const wgpu::BindGroupEntry bindGroupEntry { + .binding = layoutEntry.binding, + .textureView = texture.wgpuHandle.CreateView(), + }; + bindingEntries.add(bindGroupEntry, layoutEntry); + } + + // Samplers + for(const auto& sampler : kernelSpec.samplers) { + const wgpu::BindGroupLayoutEntry layoutEntry { + .binding = bindingIndex++, + .visibility = wgpu::ShaderStage::Compute, + .sampler = { .type = wgpu::SamplerBindingType::Filtering } + }; + + const wgpu::BindGroupEntry bindGroupEntry { + .binding = layoutEntry.binding, + .sampler = sampler, + }; + bindingEntries.add(bindGroupEntry, layoutEntry); + } + + const auto layoutDescLabel = kernelSpec.computeShader.name + " layout descriptor"; + + const wgpu::BindGroupLayoutDescriptor bindGroupLayoutDesc { + .label = layoutDescLabel.c_str(), + .entryCount = bindingEntries.bindGroupLayoutEntries.size(), + .entries = bindingEntries.bindGroupLayoutEntries.data(), + }; + + const wgpu::BindGroupLayout bindGroupLayout = device.CreateBindGroupLayout(&bindGroupLayoutDesc); + + const wgpu::PipelineLayoutDescriptor pipelineLayoutDesc { + .bindGroupLayoutCount = 1, + .bindGroupLayouts = &bindGroupLayout, + }; + const wgpu::PipelineLayout pipelineLayout = device.CreatePipelineLayout(&pipelineLayoutDesc); + + auto shaderPlaceHolders = kernelSpec.computeShader.placeholders; + if(kernelSpec.computeShader.workgroupSize.has_value()) { + shaderPlaceHolders["workgroup_size"] = std::to_string(kernelSpec.computeShader.workgroupSize->x) + ", " + + std::to_string(kernelSpec.computeShader.workgroupSize->y) + ", " + + std::to_string(kernelSpec.computeShader.workgroupSize->z); + } + const std::string shaderSource = [&](){ + return MR::match(kernelSpec.computeShader.shaderSource, + [&](const ShaderFile &shaderFile) { + return preprocessWGSLFile(shaderFile.filePath, + shaderPlaceHolders, + kernelSpec.computeShader.macros); + }, + [&](const InlineShaderText &inlineString) { + return preprocessWGSLString(inlineString.text, + shaderPlaceHolders, + kernelSpec.computeShader.macros); + }); + }(); + + const std::string computePipelineLabel = kernelSpec.computeShader.name + " compute pipeline"; + const wgpu::ComputePipelineDescriptor computePipelineDesc { + .label = computePipelineLabel.c_str(), + .layout = pipelineLayout, + .compute = { + .module = makeShaderModule(kernelSpec.computeShader.name, shaderSource, device), + .entryPoint = kernelSpec.computeShader.entryPoint.c_str() + } + }; + + const wgpu::BindGroupDescriptor bindGroupDesc { + .layout = bindGroupLayout, + .entryCount = bindingEntries.bindGroupEntries.size(), + .entries = bindingEntries.bindGroupEntries.data(), + }; + + return Kernel { + .name = kernelSpec.computeShader.name, + .pipeline = device.CreateComputePipeline(&computePipelineDesc), + .bindGroup = 
device.CreateBindGroup(&bindGroupDesc), + .shaderSource = shaderSource + }; +} + +void ComputeContext::dispatchKernel(const Kernel &kernel, const DispatchGrid &dispatchGrid) const +{ + const wgpu::ComputePassDescriptor passDesc { + .label = kernel.name.c_str(), + }; + const wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + const wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&passDesc); + pass.SetPipeline(kernel.pipeline); + pass.SetBindGroup(0, kernel.bindGroup); + pass.DispatchWorkgroups(dispatchGrid.wgCountX, dispatchGrid.wgCountY, dispatchGrid.wgCountZ); + pass.End(); + + const wgpu::CommandBuffer commands = encoder.Finish(); + device.GetQueue().Submit(1, &commands); +} + +wgpu::Sampler ComputeContext::newLinearSampler() const +{ + const wgpu::SamplerDescriptor samplerDesc { + .magFilter = wgpu::FilterMode::Linear, + .minFilter = wgpu::FilterMode::Linear, + .mipmapFilter = wgpu::MipmapFilterMode::Linear, + .maxAnisotropy = 1 + }; + + return device.CreateSampler(&samplerDesc); +} + + +} diff --git a/cpp/gpu/gpu.h b/cpp/gpu/gpu.h new file mode 100644 index 0000000000..f4a33c572d --- /dev/null +++ b/cpp/gpu/gpu.h @@ -0,0 +1,218 @@ +#pragma once + +#include "image.h" +#include "match.h" +#include "span.h" +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace MR::GPU { +// A workgroup is a collection of threads that execute the same kernel +// function in parallel. Each thread within a workgroup can cooperate with others +// through shared memory. +struct WorkgroupSize { + uint32_t x = 1; + uint32_t y = 1; + uint32_t z = 1; + + // As a rule of thumb, for optimal performance across different hardware, the + // total number of threads in a workgroup should be a multiple of 64. + uint32_t threadCount() const { return x * y * z; } +}; + +// The dispatch grid defines the number of workgroups to be dispatched for a +// kernel. The total number of threads dispatched is the product of the number of +// workgroups in each dimension and the number of threads per workgroup. +struct DispatchGrid { + // Number of workgroups for each dimension. + uint32_t wgCountX = 1; + uint32_t wgCountY = 1; + uint32_t wgCountZ = 1; + + uint32_t workgroupCount() const { return wgCountX * wgCountY * wgCountZ; } +}; + +// Absolute/relative (to working dir) path of a WGSL file. +struct ShaderFile { std::filesystem::path filePath; }; + +struct InlineShaderText { std::string text; }; + +using ShaderSource = std::variant; + +struct ShaderEntry { + ShaderSource shaderSource; + + std::string entryPoint = "main"; + + std::string name = MR::match(shaderSource, + [](const ShaderFile& file) { return file.filePath.stem().string(); }, + [](const InlineShaderText&){ return std::string("inline_shader"); }); + + // Convenience property to set the {{workgroup_size}} placeholder. + // Only relevant for compute shaders. + std::optional workgroupSize; + + // Map of placeholders to their values. The values will be replaced in the shader source code. + // Placeholders must be in the format {{placeholder_name}}. + std::unordered_map placeholders; + + // Set of macro definitions to be defined in the shader. 
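// Illustrative use (a sketch based on the tests in gputests.cpp, not a prescribed API):
//   macros = { "MULTIPLY_MODE" } keeps the "#ifdef MULTIPLY_MODE" branch of the shader
//   and drops its "#else" branch, while placeholders such as {{value_to_add}} are
//   substituted textually before compilation (see wgslprocessing.cpp).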
+ std::unordered_set macros; +}; + + +template +struct Buffer { + wgpu::Buffer wgpuHandle; + + static_assert( + std::is_same_v || std::is_same_v || + std::is_same_v, + "GPU::Buffer only supports float, int32_t or uint32_t" + ); +}; + +using BufferVariant = std::variant, + Buffer, + Buffer>; + +struct TextureUsage { + bool storageBinding = false; + bool renderTarget = false; +}; + +enum class TextureFormat : uint8_t { + R32Float, +}; + +struct TextureSpec { + uint32_t width = 0; + uint32_t height = 0; + uint32_t depth = 1; + TextureFormat format = TextureFormat::R32Float; + TextureUsage usage; +}; + +struct Texture { + wgpu::Texture wgpuHandle; + TextureSpec spec; +}; + +struct KernelSpec { + // NOTE: The order in the shader must match the lists below: + // 1. Uniform buffers + // 2. Read-only buffers + // 3. Read-write buffers + // 4. Read-only textures + // 5. Write-only textures + // 6. Samplers + // List order must also match the shader's binding points. + ShaderEntry computeShader; + std::vector uniformBuffers; + std::vector readOnlyBuffers; + std::vector readWriteBuffers; + std::vector readOnlyTextures; + std::vector writeOnlyTextures; + std::vector samplers; +}; + +struct Kernel { + std::string name; + wgpu::ComputePipeline pipeline; + wgpu::BindGroup bindGroup; + // For debugging purposes, the shader source code is stored here. + std::string shaderSource; +}; + +struct ComputeContext { + explicit ComputeContext(); + + template + Buffer newEmptyBuffer(size_t size) const { + return { innerNewEmptyBuffer(size * sizeof(T)) }; + } + + template + Buffer newBufferFromHostMemory(tcb::span srcMemory) const { + return { innerNewBufferFromHostMemory(srcMemory.data(), srcMemory.size_bytes()) }; + } + + template + Buffer newBufferFromHostMemory(const void* srcMemory, size_t byteSize) const { + return { innerNewBufferFromHostMemory(srcMemory, byteSize) }; + } + + template + Buffer newBufferFromHostMemory(const std::vector> &srcMemoryRegions) const { + size_t totalBytes = 0; + for (const auto& region : srcMemoryRegions) totalBytes += region.size_bytes(); + + auto buffer = innerNewEmptyBuffer(totalBytes); + uint64_t offset = 0; + for (const auto& region : srcMemoryRegions) { + innerWriteToBuffer(buffer, region.data(), region.size_bytes(), offset); + offset += region.size_bytes(); + } + return Buffer{ std::move(buffer) }; + } + + // This function blocks until the download is complete. + template + void downloadBuffer(const Buffer& buffer, tcb::span dstMemoryRegion) const { + return downloadBuffer(buffer, dstMemoryRegion.data(), dstMemoryRegion.size_bytes()); + } + + // This function blocks until the download is complete. 
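// A minimal host-side round trip, as a sketch only (ctx, host and n are placeholder
// names; element counts are assumed to match):
//   Buffer<float> buf = ctx.newEmptyBuffer<float>(n);
//   ctx.writeToBuffer(buf, tcb::span<const float>(host.data(), n));
//   ctx.downloadBuffer(buf, tcb::span<float>(host.data(), n));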
+ template + void downloadBuffer(const Buffer& buffer, void* data, size_t dstByteSize) const { + return innerDownloadBuffer(buffer.wgpuHandle, data, dstByteSize); + } + + template + void writeToBuffer(const Buffer& buffer, tcb::span srcMemoryRegion, uint64_t offset = 0) const { + return writeToBuffer(buffer, srcMemoryRegion.data(), srcMemoryRegion.size_bytes(), offset * sizeof(T)); + } + + template + void writeToBuffer(const Buffer& buffer, const void* data, size_t size, uint64_t bytesOffset = 0) const { + return innerWriteToBuffer(buffer.wgpuHandle, data, size, bytesOffset); + } + + Texture newEmptyTexture(const TextureSpec& textureSpec) const; + + Texture newTextureFromHostMemory(const TextureSpec& textureDesc, + tcb::span srcMemoryRegion) const; + + Texture newTextureFromHostImage(const MR::Image& image, const TextureUsage& usage = {}) const; + + // This function blocks until the download is complete. + void downloadTexture(const Texture& texture, tcb::span dstMemoryRegion) const; + + Kernel newKernel(const KernelSpec& kernelSpec) const; + + void dispatchKernel(const Kernel& kernel, const DispatchGrid& dispatchGrid) const; + + wgpu::Sampler newLinearSampler() const; + +private: + wgpu::Buffer innerNewEmptyBuffer(size_t byteSize) const; + wgpu::Buffer innerNewBufferFromHostMemory(const void* srcMemory, size_t srcByteSize) const; + void innerDownloadBuffer(const wgpu::Buffer& buffer, void* dstMemory, size_t dstByteSize) const; + void innerWriteToBuffer(const wgpu::Buffer& buffer, const void* data, size_t srcByteSize, uint64_t offset) const; + + wgpu::Instance instance; + wgpu::Adapter adapter; + wgpu::Device device; +}; +} diff --git a/cpp/gpu/match.h b/cpp/gpu/match.h new file mode 100644 index 0000000000..7863820aa1 --- /dev/null +++ b/cpp/gpu/match.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace MR { +template struct overload : Ts... { using Ts::operator()...; }; +template overload(Ts...) -> overload; +template auto match(var_t & variant, Func &&... funcs) +{ + return std::visit(overload{ std::forward(funcs)... 
}, variant); +} +} diff --git a/cpp/gpu/shaders/transform_image.wgsl b/cpp/gpu/shaders/transform_image.wgsl new file mode 100644 index 0000000000..27cffc48b9 --- /dev/null +++ b/cpp/gpu/shaders/transform_image.wgsl @@ -0,0 +1,33 @@ +const workgroupSize = vec3({{workgroup_size}}); +const workgroupInvocations = workgroupSize.x * workgroupSize.y * workgroupSize.z; + +// Linear transformation matrix from output to input +@group(0) @binding(0) var transformationMat: mat4x4; +@group(0) @binding(1) var inputImage: texture_3d; +@group(0) @binding(2) var outputImage: texture_storage_3d; +@group(0) @binding(3) var linearSampler: sampler; + +@compute @workgroup_size(workgroupSize.x, workgroupSize.y, workgroupSize.z) +fn main(@builtin(global_invocation_id) globalId: vec3) +{ + let inputDims = textureDimensions(inputImage); + let outputDims = textureDimensions(outputImage); + + if (globalId.x >= outputDims.x || globalId.y >= outputDims.y || globalId.z >= outputDims.z) { + return; + } + + let dstVoxel = vec3(globalId); + let transformedVoxel4 = transformationMat * vec4(dstVoxel, 1.0); + let transformedVoxel = transformedVoxel4.xyz / transformedVoxel4.w; + + let inside = all(transformedVoxel >= vec3(0.0)) && + all(transformedVoxel < vec3(inputDims)); + + var outputValue = vec4(); + if(inside) { + let sampleCoord = (transformedVoxel + 0.5)/ vec3(inputDims); + outputValue = textureSampleLevel(inputImage, linearSampler, sampleCoord, 0.0); + } + textureStore(outputImage, globalId, outputValue); +} diff --git a/cpp/gpu/span.h b/cpp/gpu/span.h new file mode 100644 index 0000000000..a7b67c2837 --- /dev/null +++ b/cpp/gpu/span.h @@ -0,0 +1,618 @@ + +/* +This is an implementation of C++20's std::span +http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/n4820.pdf +*/ + +// Copyright Tristan Brindle 2018. +// Distributed under the Boost Software License, Version 1.0. 
+// (See accompanying file ../../LICENSE_1_0.txt or copy at +// https://www.boost.org/LICENSE_1_0.txt) + +#ifndef TCB_SPAN_HPP_INCLUDED +#define TCB_SPAN_HPP_INCLUDED + +#include +#include +#include +#include + +#ifndef TCB_SPAN_NO_EXCEPTIONS +// Attempt to discover whether we're being compiled with exception support +#if !(defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) +#define TCB_SPAN_NO_EXCEPTIONS +#endif +#endif + +#ifndef TCB_SPAN_NO_EXCEPTIONS +#include +#include +#endif + +// Various feature test macros + +#ifndef TCB_SPAN_NAMESPACE_NAME +#define TCB_SPAN_NAMESPACE_NAME tcb +#endif + +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define TCB_SPAN_HAVE_CPP17 +#endif + +#if __cplusplus >= 201402L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +#define TCB_SPAN_HAVE_CPP14 +#endif + +namespace TCB_SPAN_NAMESPACE_NAME { + +// Establish default contract checking behavior +#if !defined(TCB_SPAN_THROW_ON_CONTRACT_VIOLATION) && \ +!defined(TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION) && \ + !defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#if defined(NDEBUG) || !defined(TCB_SPAN_HAVE_CPP14) +#define TCB_SPAN_NO_CONTRACT_CHECKING +#else +#define TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION +#endif +#endif + +#if defined(TCB_SPAN_THROW_ON_CONTRACT_VIOLATION) + struct contract_violation_error : std::logic_error { + explicit contract_violation_error(const char* msg) : std::logic_error(msg) + {} +}; + +inline void contract_violation(const char* msg) +{ + throw contract_violation_error(msg); +} + +#elif defined(TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION) +[[noreturn]] inline void contract_violation(const char* /*unused*/) +{ + std::terminate(); +} +#endif + +#if !defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#define TCB_SPAN_STRINGIFY(cond) #cond +#define TCB_SPAN_EXPECT(cond) \ +cond ? 
(void) 0 : contract_violation("Expected " TCB_SPAN_STRINGIFY(cond)) +#else +#define TCB_SPAN_EXPECT(cond) +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_inline_variables) +#define TCB_SPAN_INLINE_VAR inline +#else +#define TCB_SPAN_INLINE_VAR +#endif + +#if defined(TCB_SPAN_HAVE_CPP14) || \ + (defined(__cpp_constexpr) && __cpp_constexpr >= 201304) +#define TCB_SPAN_HAVE_CPP14_CONSTEXPR +#endif + +#if defined(TCB_SPAN_HAVE_CPP14_CONSTEXPR) +#define TCB_SPAN_CONSTEXPR14 constexpr +#else +#define TCB_SPAN_CONSTEXPR14 +#endif + +#if defined(TCB_SPAN_HAVE_CPP14_CONSTEXPR) && \ + (!defined(_MSC_VER) || _MSC_VER > 1900) +#define TCB_SPAN_CONSTEXPR_ASSIGN constexpr +#else +#define TCB_SPAN_CONSTEXPR_ASSIGN +#endif + +#if defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#define TCB_SPAN_CONSTEXPR11 constexpr +#else +#define TCB_SPAN_CONSTEXPR11 TCB_SPAN_CONSTEXPR14 +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_deduction_guides) +#define TCB_SPAN_HAVE_DEDUCTION_GUIDES +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_byte) +#define TCB_SPAN_HAVE_STD_BYTE +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_array_constexpr) +#define TCB_SPAN_HAVE_CONSTEXPR_STD_ARRAY_ETC +#endif + +#if defined(TCB_SPAN_HAVE_CONSTEXPR_STD_ARRAY_ETC) +#define TCB_SPAN_ARRAY_CONSTEXPR constexpr +#else +#define TCB_SPAN_ARRAY_CONSTEXPR +#endif + +#ifdef TCB_SPAN_HAVE_STD_BYTE + using byte = std::byte; +#else +using byte = unsigned char; +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) +#define TCB_SPAN_NODISCARD [[nodiscard]] +#else +#define TCB_SPAN_NODISCARD +#endif + +TCB_SPAN_INLINE_VAR constexpr std::size_t dynamic_extent = SIZE_MAX; + +template +class span; + +namespace detail { + +template +struct span_storage { + constexpr span_storage() noexcept = default; + + constexpr span_storage(E* p_ptr, std::size_t /*unused*/) noexcept + : ptr(p_ptr) + {} + + E* ptr = nullptr; + static constexpr std::size_t size = S; +}; + +template +struct span_storage { + constexpr span_storage() noexcept = default; + + constexpr span_storage(E* p_ptr, std::size_t p_size) noexcept + : ptr(p_ptr), size(p_size) + {} + + E* ptr = nullptr; + std::size_t size = 0; +}; + +// Reimplementation of C++17 std::size() and std::data() +#if defined(TCB_SPAN_HAVE_CPP17) || \ +defined(__cpp_lib_nonmember_container_access) + using std::data; +using std::size; +#else +template +constexpr auto size(const C& c) -> decltype(c.size()) +{ + return c.size(); +} + +template +constexpr std::size_t size(const T (&)[N]) noexcept +{ + return N; +} + +template +constexpr auto data(C& c) -> decltype(c.data()) +{ + return c.data(); +} + +template +constexpr auto data(const C& c) -> decltype(c.data()) +{ + return c.data(); +} + +template +constexpr T* data(T (&array)[N]) noexcept +{ + return array; +} + +template +constexpr const E* data(std::initializer_list il) noexcept +{ + return il.begin(); +} +#endif // TCB_SPAN_HAVE_CPP17 + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_void_t) +using std::void_t; +#else +template +using void_t = void; +#endif + +template +using uncvref_t = + typename std::remove_cv::type>::type; + +template +struct is_span : std::false_type {}; + +template +struct is_span> : std::true_type {}; + +template +struct is_std_array : std::false_type {}; + +template +struct is_std_array> : std::true_type {}; + +template +struct has_size_and_data : std::false_type {}; + +template +struct has_size_and_data())), + decltype(detail::data(std::declval()))>> + : std::true_type {}; + +template > +struct 
is_container { + static constexpr bool value = + !is_span::value && !is_std_array::value && + !std::is_array::value && has_size_and_data::value; +}; + +template +using remove_pointer_t = typename std::remove_pointer::type; + +template +struct is_container_element_type_compatible : std::false_type {}; + +template +struct is_container_element_type_compatible< + T, E, + typename std::enable_if< + !std::is_same< + typename std::remove_cv()))>::type, + void>::value && + std::is_convertible< + remove_pointer_t()))> (*)[], + E (*)[]>::value + >::type> + : std::true_type {}; + +template +struct is_complete : std::false_type {}; + +template +struct is_complete : std::true_type {}; + +} // namespace detail + +template +class span { + static_assert(std::is_object::value, + "A span's ElementType must be an object type (not a " + "reference type or void)"); + static_assert(detail::is_complete::value, + "A span's ElementType must be a complete type (not a forward " + "declaration)"); + static_assert(!std::is_abstract::value, + "A span's ElementType cannot be an abstract class type"); + + using storage_type = detail::span_storage; + +public: + // constants and types + using element_type = ElementType; + using value_type = typename std::remove_cv::type; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using reverse_iterator = std::reverse_iterator; + + static constexpr size_type extent = Extent; + + // [span.cons], span constructors, copy, assignment, and destructor + template < + std::size_t E = Extent, + typename std::enable_if<(E == dynamic_extent || E <= 0), int>::type = 0> + constexpr span() noexcept + {} + + TCB_SPAN_CONSTEXPR11 span(pointer ptr, size_type count) + : storage_(ptr, count) + { + TCB_SPAN_EXPECT(extent == dynamic_extent || count == extent); + } + + TCB_SPAN_CONSTEXPR11 span(pointer first_elem, pointer last_elem) + : storage_(first_elem, last_elem - first_elem) + { + TCB_SPAN_EXPECT(extent == dynamic_extent || + last_elem - first_elem == + static_cast(extent)); + } + + template ::value, + int>::type = 0> + constexpr span(element_type (&arr)[N]) noexcept : storage_(arr, N) + {} + + template &, ElementType>::value, + int>::type = 0> + TCB_SPAN_ARRAY_CONSTEXPR span(std::array& arr) noexcept + : storage_(arr.data(), N) + {} + + template &, ElementType>::value, + int>::type = 0> + TCB_SPAN_ARRAY_CONSTEXPR span(const std::array& arr) noexcept + : storage_(arr.data(), N) + {} + + template < + typename Container, std::size_t E = Extent, + typename std::enable_if< + E == dynamic_extent && detail::is_container::value && + detail::is_container_element_type_compatible< + Container&, ElementType>::value, + int>::type = 0> + constexpr span(Container& cont) + : storage_(detail::data(cont), detail::size(cont)) + {} + + template < + typename Container, std::size_t E = Extent, + typename std::enable_if< + E == dynamic_extent && detail::is_container::value && + detail::is_container_element_type_compatible< + const Container&, ElementType>::value, + int>::type = 0> + constexpr span(const Container& cont) + : storage_(detail::data(cont), detail::size(cont)) + {} + + constexpr span(const span& other) noexcept = default; + + template ::value, + int>::type = 0> + constexpr span(const span& other) noexcept + : storage_(other.data(), other.size()) + {} + + ~span() noexcept = default; + + 
TCB_SPAN_CONSTEXPR_ASSIGN span& + operator=(const span& other) noexcept = default; + + // [span.sub], span subviews + template + TCB_SPAN_CONSTEXPR11 span first() const + { + TCB_SPAN_EXPECT(Count <= size()); + return {data(), Count}; + } + + template + TCB_SPAN_CONSTEXPR11 span last() const + { + TCB_SPAN_EXPECT(Count <= size()); + return {data() + (size() - Count), Count}; + } + + template + using subspan_return_t = + span; + + template + TCB_SPAN_CONSTEXPR11 subspan_return_t subspan() const + { + TCB_SPAN_EXPECT(Offset <= size() && + (Count == dynamic_extent || Offset + Count <= size())); + return {data() + Offset, + Count != dynamic_extent ? Count : size() - Offset}; + } + + TCB_SPAN_CONSTEXPR11 span + first(size_type count) const + { + TCB_SPAN_EXPECT(count <= size()); + return {data(), count}; + } + + TCB_SPAN_CONSTEXPR11 span + last(size_type count) const + { + TCB_SPAN_EXPECT(count <= size()); + return {data() + (size() - count), count}; + } + + TCB_SPAN_CONSTEXPR11 span + subspan(size_type offset, size_type count = dynamic_extent) const + { + TCB_SPAN_EXPECT(offset <= size() && + (count == dynamic_extent || offset + count <= size())); + return {data() + offset, + count == dynamic_extent ? size() - offset : count}; + } + + // [span.obs], span observers + constexpr size_type size() const noexcept { return storage_.size; } + + constexpr size_type size_bytes() const noexcept + { + return size() * sizeof(element_type); + } + + TCB_SPAN_NODISCARD constexpr bool empty() const noexcept + { + return size() == 0; + } + + // [span.elem], span element access + TCB_SPAN_CONSTEXPR11 reference operator[](size_type idx) const + { + TCB_SPAN_EXPECT(idx < size()); + return *(data() + idx); + } + + TCB_SPAN_CONSTEXPR11 reference front() const + { + TCB_SPAN_EXPECT(!empty()); + return *data(); + } + + TCB_SPAN_CONSTEXPR11 reference back() const + { + TCB_SPAN_EXPECT(!empty()); + return *(data() + (size() - 1)); + } + + constexpr pointer data() const noexcept { return storage_.ptr; } + + // [span.iterators], span iterator support + constexpr iterator begin() const noexcept { return data(); } + + constexpr iterator end() const noexcept { return data() + size(); } + + TCB_SPAN_ARRAY_CONSTEXPR reverse_iterator rbegin() const noexcept + { + return reverse_iterator(end()); + } + + TCB_SPAN_ARRAY_CONSTEXPR reverse_iterator rend() const noexcept + { + return reverse_iterator(begin()); + } + +private: + storage_type storage_{}; +}; + +#ifdef TCB_SPAN_HAVE_DEDUCTION_GUIDES + +/* Deduction Guides */ +template +span(T (&)[N])->span; + +template +span(std::array&)->span; + +template +span(const std::array&)->span; + +template +span(Container&)->span()))>::type>; + +template +span(const Container&)->span; + +#endif // TCB_HAVE_DEDUCTION_GUIDES + +template +constexpr span +make_span(span s) noexcept +{ + return s; +} + +template +constexpr span make_span(T (&arr)[N]) noexcept +{ + return {arr}; +} + +template +TCB_SPAN_ARRAY_CONSTEXPR span make_span(std::array& arr) noexcept +{ + return {arr}; +} + +template +TCB_SPAN_ARRAY_CONSTEXPR span +make_span(const std::array& arr) noexcept +{ + return {arr}; +} + +template +constexpr span()))>::type> +make_span(Container& cont) +{ + return {cont}; +} + +template +constexpr span +make_span(const Container& cont) +{ + return {cont}; +} + +template +span +as_bytes(span s) noexcept +{ + return {reinterpret_cast(s.data()), s.size_bytes()}; +} + +template < + class ElementType, size_t Extent, + typename std::enable_if::value, int>::type = 0> +span +as_writable_bytes(span s) 
noexcept +{ + return {reinterpret_cast(s.data()), s.size_bytes()}; +} + +template +constexpr auto get(span s) -> decltype(s[N]) +{ + return s[N]; +} + +} // namespace TCB_SPAN_NAMESPACE_NAME + +namespace std { + +template +class tuple_size> + : public integral_constant {}; + +template +class tuple_size>; // not defined + +template +class tuple_element> { +public: + static_assert(Extent != TCB_SPAN_NAMESPACE_NAME::dynamic_extent && + I < Extent, + ""); + using type = ElementType; +}; + +} // end namespace std + +#endif // TCB_SPAN_HPP_INCLUDED diff --git a/cpp/gpu/testing/CMakeLists.txt b/cpp/gpu/testing/CMakeLists.txt new file mode 100644 index 0000000000..5254398c39 --- /dev/null +++ b/cpp/gpu/testing/CMakeLists.txt @@ -0,0 +1,28 @@ +include(FetchContent) +set(GTEST_VERSION 1.16.0) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/releases/download/v${GTEST_VERSION}/googletest-${GTEST_VERSION}.tar.gz +) +FetchContent_MakeAvailable(googletest) + +add_executable(mrtrix-gpu-tests + gputests_common.h + gputests.cpp +) + +target_compile_definitions(mrtrix-gpu-tests PRIVATE + # To avoid conflicts with MRtrix's FAIL macro in exception.h + GTEST_DONT_DEFINE_FAIL +) +target_link_libraries(mrtrix-gpu-tests + GTest::gtest + mrtrix-gpu-lib +) + +set(TEST_SHADER_SOURCE "test_shader.wgsl") +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/${TEST_SHADER_SOURCE} ${CMAKE_CURRENT_BINARY_DIR}/${TEST_SHADER_SOURCE} COPYONLY) + + +include(GoogleTest) +gtest_discover_tests(mrtrix-gpu-tests) diff --git a/cpp/gpu/testing/gputests.cpp b/cpp/gpu/testing/gputests.cpp new file mode 100644 index 0000000000..3209251690 --- /dev/null +++ b/cpp/gpu/testing/gputests.cpp @@ -0,0 +1,275 @@ +#include + +#include "exception.h" +#include "gpu.h" +#include "gputests_common.h" +#include "span.h" + +#include +#include +#include +#include +#include +#include +#include + + +using namespace MR; +using namespace MR::GPU; + + +TEST_F(GPUTest, MakeEmptyBuffer) { + const size_t bufferElementCount = 1024; + const Buffer buffer = context.newEmptyBuffer(bufferElementCount); + + std::vector downloadedData(bufferElementCount, 1); // Initialize with non-zero + + context.downloadBuffer(buffer, downloadedData); + + for (auto val : downloadedData) { + EXPECT_EQ(val, 0); + } +} + +TEST_F(GPUTest, BufferFromHostMemory) { + std::vector hostData = {1, 2, 3, 4, 5}; + + const Buffer buffer = context.newBufferFromHostMemory(tcb::span(hostData)); + + std::vector downloadedData(hostData.size(), 0); + context.downloadBuffer(buffer, downloadedData); + + EXPECT_EQ(downloadedData, hostData); +} + +TEST_F(GPUTest, BufferFromHostMemoryVoidPtr) { + std::vector hostData = {1.0F, 2.5F, -3.0F}; + const Buffer buffer = context.newBufferFromHostMemory(hostData.data(), hostData.size() * sizeof(float)); + + std::vector downloadedData(hostData.size()); + context.downloadBuffer(buffer, downloadedData); + EXPECT_EQ(downloadedData, hostData); +} + + +TEST_F(GPUTest, BufferFromHostMemoryMultipleRegions) { + std::vector region1 = {1, 2, 3}; + std::vector region2 = {4, 5}; + std::vector region3 = {6, 7, 8, 9}; + + const std::vector> regions = {region1, region2, region3}; + const Buffer buffer = context.newBufferFromHostMemory(regions); + + const std::vector expectedData = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector downloadedData(expectedData.size()); + context.downloadBuffer(buffer, downloadedData); + + EXPECT_EQ(downloadedData, expectedData); +} + + +TEST_F(GPUTest, WriteToBuffer) { + std::vector newData = {0.1F, 0.2F, 0.3F, 0.4F}; + + const Buffer 
buffer = context.newEmptyBuffer(newData.size()); + std::vector downloadedData(newData.size(), 0.0F); + + context.writeToBuffer(buffer, newData); + context.downloadBuffer(buffer, downloadedData); + + for (size_t i = 0; i < newData.size(); i++) { + EXPECT_FLOAT_EQ(downloadedData[i], newData[i]); + } +} + +TEST_F(GPUTest, WriteToBufferWithOffset) { + const size_t bufferSize = 10; + std::vector initialData(bufferSize); + std::iota(initialData.begin(), initialData.end(), 0.0F); // 0, 1, ..., 9 + + const Buffer buffer = context.newBufferFromHostMemory(initialData); + + std::vector newData = {100.0F, 101.0F, 102.0F}; + const uint32_t offsetElements = 3; + const uint32_t offsetBytes = offsetElements * sizeof(float); + + context.writeToBuffer(buffer, newData, offsetBytes); + + std::vector downloadedData(bufferSize); + context.downloadBuffer(buffer, downloadedData); + + std::vector expectedData = {0.0F, 1.0F, 2.0F, 100.0F, 101.0F, 102.0F, 6.0F, 7.0F, 8.0F, 9.0F}; + for (size_t i = 0; i < bufferSize; ++i) { + EXPECT_FLOAT_EQ(downloadedData[i], expectedData[i]); + } +} + + +TEST_F(GPUTest, EmptyTexture) { + const MR::GPU::TextureSpec textureSpec = { + .width = 4, .height = 4, .depth = 1, .format = TextureFormat::R32Float, + }; + + const auto texture = context.newEmptyTexture(textureSpec); + + const uint32_t bytesPerPixel = 4; // R32Float + const size_t downloadedSizeBytes = textureSpec.width * textureSpec.height * textureSpec.depth * bytesPerPixel; + std::vector downloadedData(downloadedSizeBytes / sizeof(float), 1.0f); // Init with non-zero + + context.downloadTexture(texture, downloadedData); + + for (uint32_t z = 0; z < textureSpec.depth; ++z) { + for (uint32_t y = 0; y < textureSpec.height; ++y) { + for (uint32_t x = 0; x < textureSpec.width; ++x) { + const size_t idx = (z * textureSpec.height + y) * textureSpec.width + x; + EXPECT_FLOAT_EQ(downloadedData[idx], 0.0F); + } + } + } +} + + +TEST_F(GPUTest, KernelWithInlineShader) { + const std::string shaderCode = R"wgsl( + @group(0) @binding(0) var data: array; + + @compute @workgroup_size(64) + fn main(@builtin(global_invocation_id) id: vec3) { + let idx = id.x; + if (idx < arrayLength(&data)) { + data[idx] = data[idx] * 3.0; + } + } + )wgsl"; + + const std::vector hostData = {1.0F, 2.0F, 3.0F, 4.0F}; + const std::vector expectedData = {3.0F, 6.0F, 9.0F, 12.0F}; + Buffer buffer = context.newBufferFromHostMemory(hostData); + + const KernelSpec kernelSpec { + .computeShader = { + .shaderSource = InlineShaderText{ shaderCode }, + }, + .readWriteBuffers = { buffer } + }; + + const Kernel kernel = context.newKernel(kernelSpec); + const DispatchGrid dispatchGrid = { static_cast((hostData.size() + 63)), 1, 1 }; + context.dispatchKernel(kernel, dispatchGrid); + + std::vector resultData(hostData.size()); + context.downloadBuffer(buffer, resultData); + EXPECT_EQ(resultData, expectedData); +} + + +TEST_F(GPUTest, KernelWithPlaceholders) { + const std::string shaderCode = R"wgsl( + @group(0) @binding(0) var data: array; + + @compute @workgroup_size(64) + fn main(@builtin(global_invocation_id) id: vec3) { + let idx = id.x; + if (idx < arrayLength(&data)) { + data[idx] = data[idx] + {{value_to_add}}; + } + } + )wgsl"; + + const std::vector hostData = {10.0F, 20.0F}; + const float valueToAdd = 5.5F; + const std::vector expectedData = {15.5F, 25.5F}; + const Buffer buffer = context.newBufferFromHostMemory(hostData); + + const KernelSpec kernelSpec { + .computeShader = { + .shaderSource = InlineShaderText{ shaderCode }, + .placeholders = {{"value_to_add", 
std::to_string(valueToAdd)}} + }, + .readWriteBuffers = { buffer } + }; + + const Kernel kernel = context.newKernel(kernelSpec); + constexpr DispatchGrid dispatchGrid = { 1, 1, 1 }; + context.dispatchKernel(kernel, dispatchGrid); + + std::vector resultData(hostData.size()); + context.downloadBuffer(buffer, resultData); + EXPECT_EQ(resultData, expectedData); +} + +TEST_F(GPUTest, KernelWithMacros) { + const std::string shaderCode = R"wgsl( + @group(0) @binding(0) var data: array; + + @compute @workgroup_size(64) + fn main_macro(@builtin(global_invocation_id) id: vec3) { + let idx = id.x; + if (idx < arrayLength(&data)) { + #ifdef MULTIPLY_MODE + data[idx] = data[idx] * 2.0; + #else + data[idx] = data[idx] + 1.0; + #endif + } + } + )wgsl"; + + std::vector hostData = {5.0F, 10.0F}; + Buffer buffer = context.newBufferFromHostMemory(hostData); + + // Test with MULTIPLY_MODE defined + const KernelSpec specMultiply { + .computeShader = { + .shaderSource = InlineShaderText{ shaderCode }, + .entryPoint = "main_macro", + .macros = {"MULTIPLY_MODE"} + }, + .readWriteBuffers = { buffer } + }; + const Kernel kernelMultiply = context.newKernel(specMultiply); + const DispatchGrid dispatchGrid = { 1, 1, 1 }; + context.dispatchKernel(kernelMultiply, dispatchGrid); + + std::vector resultDataMultiply(hostData.size()); + context.downloadBuffer(buffer, resultDataMultiply); + const std::vector expectedDataMultiply = {10.0F, 20.0F}; + EXPECT_EQ(resultDataMultiply, expectedDataMultiply); + + // Test without MULTIPLY_MODE (ADD_MODE) + context.writeToBuffer(buffer, hostData); // Reset buffer + const KernelSpec specAdd { + .computeShader = { + .shaderSource = InlineShaderText{ shaderCode }, + .entryPoint = "main_macro" + // .macros is empty + }, + .readWriteBuffers = { buffer } + }; + const Kernel kernelAdd = context.newKernel(specAdd); + context.dispatchKernel(kernelAdd, dispatchGrid); + + std::vector resultDataAdd(hostData.size()); + context.downloadBuffer(buffer, resultDataAdd); + const std::vector expectedDataAdd = {6.0F, 11.0F}; + EXPECT_EQ(resultDataAdd, expectedDataAdd); +} + +int main(int argc, char **argv) { + try { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); + } + catch (const MR::Exception &e) { + e.display(); + return 1; + } + catch (const std::exception &e) { + std::cerr << "Uncaught exception: " << e.what() << "\n"; + return 1; + } + catch (...) 
diff --git a/cpp/gpu/testing/gputests_common.h b/cpp/gpu/testing/gputests_common.h
new file mode 100644
index 0000000000..f89f5af9f4
--- /dev/null
+++ b/cpp/gpu/testing/gputests_common.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <gtest/gtest.h>
+#include "gpu.h"
+
+class GPUTest : public ::testing::Test {
+protected:
+  MR::GPU::ComputeContext context;
+
+  void SetUp() override {
+    ASSERT_NO_THROW(
+      context = MR::GPU::ComputeContext();
+    );
+  }
+};
diff --git a/cpp/gpu/testing/test_shader.wgsl b/cpp/gpu/testing/test_shader.wgsl
new file mode 100644
index 0000000000..326bd830da
--- /dev/null
+++ b/cpp/gpu/testing/test_shader.wgsl
@@ -0,0 +1,9 @@
+@group(0) @binding(0) var<storage, read> input : array<f32>;
+@group(0) @binding(1) var<storage, read_write> output : array<f32>;
+
+@compute @workgroup_size(256, 1, 1)
+fn main(@builtin(global_invocation_id) globalId: vec3<u32>)
+{
+  let value = input[globalId.x];
+  output[globalId.x] = sqrt(value);
+}
diff --git a/cpp/gpu/utils.cpp b/cpp/gpu/utils.cpp
new file mode 100644
index 0000000000..179f7d800b
--- /dev/null
+++ b/cpp/gpu/utils.cpp
@@ -0,0 +1,35 @@
+#include "utils.h"
+
+#include <cstdint>
+#include <filesystem>
+#include <fstream>
+#include <ios>
+#include <limits>
+#include <stdexcept>
+#include <string>
+
+using namespace std::string_literals;
+
+uint32_t Utils::nextMultipleOf(const uint32_t value, const uint32_t multiple)
+{
+  if (value > std::numeric_limits<uint32_t>::max() - multiple) {
+    return std::numeric_limits<uint32_t>::max();
+  }
+  return (value + multiple - 1) / multiple * multiple;
+}
+
+std::string Utils::readFile(const std::filesystem::path &filePath, ReadFileMode mode)
+{
+  if (!std::filesystem::exists(filePath)) {
+    throw std::runtime_error("File not found: "s + filePath.string());
+  }
+
+  const auto openMode = (mode == ReadFileMode::Binary) ? std::ios::in | std::ios::binary : std::ios::in;
+  std::ifstream f(filePath, openMode);
+  const auto fileSize = std::filesystem::file_size(filePath);
+  std::string result(fileSize, '\0');
+  f.read(result.data(), fileSize);
+
+  return result;
+}
+
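For illustration only (hypothetical values and path, not part of the patch), the two helpers above behave as follows:

    // Round a size up to the next multiple: handy for padding buffer sizes and dispatch counts.
    const uint32_t padded = Utils::nextMultipleOf(10, 8);   // == 16; already-aligned values are unchanged
    // Read a whole shader file into a string; Binary mode skips newline translation on Windows.
    const std::string source = Utils::readFile("shaders/test_shader.wgsl", Utils::ReadFileMode::Text);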
diff --git a/cpp/gpu/utils.h b/cpp/gpu/utils.h
new file mode 100644
index 0000000000..7af47948be
--- /dev/null
+++ b/cpp/gpu/utils.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <filesystem>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace Utils {
+
+// Divides the input vector into equal-sized rows (each row having "chunkSize" elements)
+// and then performs a column-wise accumulation using the provided binary operator.
+// e.g. { 1, 2, 3, 4, 5, 6 } with chunkSize = 2 forms the rows
+//   [1, 2]  [3, 4]  [5, 6]
+// and reducing each column with op = addition gives
+//   { 1+3+5, 2+4+6 } = { 9, 12 }.
+template <typename T, typename BinaryOp>
+std::vector<T> chunkReduce(const std::vector<T>& data, size_t chunkSize, BinaryOp op) {
+  if (chunkSize == 0) {
+    throw std::invalid_argument("chunkSize cannot be zero.");
+  }
+  if (data.size() % chunkSize != 0) {
+    throw std::invalid_argument("vector size must be a multiple of chunkSize.");
+  }
+
+  const size_t numRows = data.size() / chunkSize;
+  std::vector<T> result(chunkSize, T{});
+
+  for (size_t row = 0; row < numRows; ++row) {
+    for (size_t col = 0; col < chunkSize; ++col) {
+      result[col] = op(result[col], data[row * chunkSize + col]);
+    }
+  }
+  return result;
+}
+
+// Returns the smallest multiple of `multiple` that is greater or equal to `value`.
+uint32_t nextMultipleOf(const uint32_t value, const uint32_t multiple);
+
+enum ReadFileMode {
+  Text,
+  Binary
+};
+std::string readFile(const std::filesystem::path &filePath, ReadFileMode mode = ReadFileMode::Text);
+
+} // namespace Utils
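A minimal usage sketch of Utils::chunkReduce (hypothetical values; any binary operator with a sensible identity for T{} works):

    #include <functional>
    #include "utils.h"

    // Column-wise sum of two rows of three partial results each:
    // {1, 2, 3, 4, 5, 6} with chunkSize = 3 -> {1+4, 2+5, 3+6} = {5, 7, 9}
    const std::vector<double> partials = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
    const std::vector<double> columnSums = Utils::chunkReduce(partials, 3, std::plus<double>{});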
diff --git a/cpp/gpu/wgslprocessing.cpp b/cpp/gpu/wgslprocessing.cpp
new file mode 100644
index 0000000000..b2808ad711
--- /dev/null
+++ b/cpp/gpu/wgslprocessing.cpp
@@ -0,0 +1,191 @@
+#include "wgslprocessing.h"
+#include "utils.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <filesystem>
+#include <regex>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+using namespace std::string_literals;
+
+namespace MR::GPU {
+
+namespace {
+std::string replacePlaceholders(const std::string& line,
+                                const std::unordered_map<std::string, std::string>& substitutions) {
+  const std::regex placeholder_regex(R"(\{\{([^{}]+)\}\})");
+  std::string result;
+  size_t last_pos = 0;
+
+  std::sregex_iterator it(line.begin(), line.end(), placeholder_regex);
+  const std::sregex_iterator end;
+
+  for (; it != end; ++it) {
+    const std::smatch& match = *it;
+    const size_t start = match.position();
+    const size_t length = match.length();
+    const std::string key = match[1].str();
+
+    result += line.substr(last_pos, start - last_pos);
+
+    auto sub_it = substitutions.find(key);
+    if (sub_it != substitutions.end()) {
+      result += sub_it->second;
+    } else {
+      result += match.str(); // Leave unknown placeholders intact
+    }
+
+    last_pos = start + length;
+  }
+
+  result += line.substr(last_pos);
+
+  return result;
+}
+
+std::string_view trimLeadingWhitespace(std::string_view s) {
+  size_t const start = s.find_first_not_of(" \t");
+  return (start == std::string::npos) ? "" : s.substr(start);
+}
+
+std::string preprocessRecursive(
+    const std::filesystem::path& currentPathContext,
+    std::unordered_set<std::string>& visitedFilePaths,
+    const MacroDefinitions& definedMacros,
+    const std::string* initialCode = nullptr)
+{
+  const std::filesystem::path normalizedPathKey = currentPathContext.lexically_normal();
+  const std::string pathKeyStr = normalizedPathKey.string();
+
+  // Detect cycles
+  if (visitedFilePaths.count(pathKeyStr) > 0) {
+    throw std::runtime_error("Detected recursive include of " + pathKeyStr);
+  }
+  visitedFilePaths.insert(pathKeyStr);
+
+  std::string code;
+  if (initialCode != nullptr) {
+    code = *initialCode;
+  }
+  else {
+    if (!std::filesystem::exists(normalizedPathKey)) {
+      throw std::runtime_error("File not found: " + pathKeyStr);
+    }
+    code = Utils::readFile(normalizedPathKey);
+  }
+
+  std::stringstream inputStream(code);
+  std::stringstream outputStream;
+  // Tracks nesting of #ifdef/#else blocks
+  std::vector<bool> conditionStack;
+  std::string line;
+
+  while (std::getline(inputStream, line)) {
+    const std::string_view trimmedLine = trimLeadingWhitespace(line);
+
+    if (trimmedLine.rfind("#ifdef", 0) == 0) {
+      std::istringstream iss((std::string(trimmedLine)));
+      std::string directive;
+      std::string macro;
+      iss >> directive >> macro;
+      const bool parentActive = std::all_of(conditionStack.begin(), conditionStack.end(), [](bool b){ return b; });
+      const bool isDefined = parentActive && ((definedMacros.count(macro)) != 0U);
+      conditionStack.push_back(isDefined);
+      continue;
+    }
+    else if (trimmedLine.rfind("#else", 0) == 0) {
+      if (conditionStack.empty()) {
+        throw std::runtime_error("Unmatched #else directive in " + pathKeyStr);
+      }
+      const bool previousCondition = conditionStack.back();
+      conditionStack.pop_back();
+      const bool parentActive = std::all_of(conditionStack.begin(), conditionStack.end(), [](bool b){ return b; });
+      const bool newCondition = parentActive ? !previousCondition : false;
+      conditionStack.push_back(newCondition);
+      continue;
+    }
+    else if (trimmedLine.rfind("#endif", 0) == 0) {
+      if (conditionStack.empty()) {
+        throw std::runtime_error("Unmatched #endif directive in " + pathKeyStr);
+      }
+      conditionStack.pop_back();
+      continue;
+    }
+
+    const bool currentActive = std::all_of(conditionStack.begin(), conditionStack.end(), [](bool b){ return b; });
+    if (currentActive && trimmedLine.rfind("#include", 0) == 0) {
+      const auto startQuote = trimmedLine.find_first_of("\"<");
+      const auto endQuote = trimmedLine.find_last_of("\">");
+      if (startQuote != std::string::npos && endQuote != std::string::npos && endQuote > startQuote) {
+        const std::string_view includePathStrView = trimmedLine.substr(startQuote + 1, endQuote - (startQuote + 1));
+        const std::filesystem::path includeDirectivePath(std::string{includePathStrView});
+
+        std::filesystem::path fullPathToInclude;
+        const std::filesystem::path baseDir = normalizedPathKey.parent_path();
+
+        if (includeDirectivePath.is_absolute()) {
+          fullPathToInclude = includeDirectivePath;
+        } else {
+          fullPathToInclude = baseDir / includeDirectivePath;
+        }
+
+        // Recursively process the included file content. initialCode is nullptr, so it will be read from disk.
+        const std::string includedCode = preprocessRecursive(fullPathToInclude, visitedFilePaths, definedMacros, nullptr);
+        outputStream << includedCode << "\n";
+        continue;
+      } else {
+        throw std::runtime_error("Malformed #include directive in " + pathKeyStr + ": " + std::string(trimmedLine));
+      }
+    }
+
+    if (currentActive) {
+      outputStream << line << "\n";
+    }
+  }
+
+  if (!conditionStack.empty()) {
+    throw std::runtime_error("Unterminated conditional block in " + pathKeyStr);
+  }
+  // NOTE: pathKeyStr is NOT removed from visitedFilePaths here, as it tracks files processed
+  // within the entire scope of one top-level preprocessWGSLFile/preprocessWGSLString call.
+  return outputStream.str();
+}
+
+} // namespace
+
+std::string preprocessWGSLFile(const std::filesystem::path& filePath,
+                               const PlaceHoldersMap& placeholders,
+                               const MacroDefinitions& macros)
+{
+  std::unordered_set<std::string> visitedFilePaths;
+  // preprocessRecursive will read the file specified by filePath.
+  const std::string combinedCode = preprocessRecursive(filePath, visitedFilePaths, macros, nullptr);
+  const std::string finalCode = replacePlaceholders(combinedCode, placeholders);
+  return finalCode;
+}
+
+std::string preprocessWGSLString(const std::string& shaderText,
+                                 const PlaceHoldersMap& placeholders,
+                                 const MacroDefinitions& macros)
+{
+  std::unordered_set<std::string> visitedFilePaths;
+  // Use a conceptual path for the inline shader. Relative #includes will be resolved
+  // against this path's parent directory (i.e. the current working directory).
+  const std::filesystem::path inlineContextPath = std::filesystem::current_path() / "";
+  const std::string combinedCode = preprocessRecursive(inlineContextPath, visitedFilePaths, macros, &shaderText);
+  const std::string finalCode = replacePlaceholders(combinedCode, placeholders);
+  return finalCode;
+}
+
+} // namespace MR::GPU
diff --git a/cpp/gpu/wgslprocessing.h b/cpp/gpu/wgslprocessing.h
new file mode 100644
index 0000000000..07014343f0
--- /dev/null
+++ b/cpp/gpu/wgslprocessing.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <filesystem>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace MR::GPU {
+using PlaceHoldersMap = std::unordered_map<std::string, std::string>;
+using MacroDefinitions = std::unordered_set<std::string>;
+std::string preprocessWGSLFile(const std::filesystem::path& filePath,
+                               const PlaceHoldersMap& placeholders,
+                               const MacroDefinitions& macros);
+
+std::string preprocessWGSLString(const std::string& shaderText,
+                                 const PlaceHoldersMap& placeholders,
+                                 const MacroDefinitions& macros);
+}
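As a usage illustration (not part of the diff): this preprocessor is what expands the `{{placeholder}}` and `#ifdef` constructs exercised by the kernel tests earlier in the patch into plain WGSL. A minimal sketch, assuming `shaderCode` is the macro-bearing shader text from the KernelWithMacros test:

    #include "wgslprocessing.h"

    // Keeps only the #ifdef MULTIPLY_MODE branch and substitutes any {{value_to_add}} occurrences;
    // the returned string is ordinary WGSL ready for shader-module creation.
    const std::string wgsl = MR::GPU::preprocessWGSLString(
        shaderCode,
        MR::GPU::PlaceHoldersMap{{"value_to_add", "5.5"}},
        MR::GPU::MacroDefinitions{"MULTIPLY_MODE"});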