From 72402bf4c12eaf548a9feafed0b69e5ad44d0c26 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Thu, 9 Jan 2025 08:31:47 -0800
Subject: [PATCH] Add register sharing to warp-specialized circular buffering (#3669)

This PR implements register sharing for warp-specialized circular buffering. Registers in the load warp group are moved to the compute warp group using the `setmaxnreg` ptx instruction. It is an optimization for matmul kernels.

## Changes
1. Add `__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/)` to the CUDA kernel declaration.
2. Add `kir::SetMaxNReg` and `kir::Return` nodes to warp-specialized circular buffering.
3. `TensorView::circularBuffer` allows setting the number of registers for load and compute warp groups through `struct WarpSpecialized`.
4. Require Hopper architecture for TensorViews using warp-specialized circular buffering.

## Why is `__launch_bounds__` necessary?
> The setmaxnreg instruction requires that the kernel has been launched with a valid value of maximum number of per-thread registers specified via the appropriate compile-time option or the appropriate performance tuning directive. Otherwise, the setmaxnreg instruction may have no effect.

From https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

## Generated Code
```cuda
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/64)
nvfuser_none_f0_c0_r0_g0(Tensor T0, const __grid_constant__ TensorMap var0, Tensor T1) {
  // do something
  if ((((nvfuser_index_t)threadIdx.y) == 1)) {
    asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n"::"n"(24));
    // load something
    return;
  } else {
    asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n"::"n"(240));
    // compute something
  }
  // do something
}
```
---
 csrc/codegen.cpp                               |  23 ++-
 csrc/codegen.h                                 |   3 +-
 .../device_lower/analysis/circular_buffer.cpp  |  23 +++
 csrc/device_lower/analysis/device_version.cpp  |   9 ++
 csrc/device_lower/analysis/device_version.h    |   5 +
 csrc/device_lower/pass/circular_buffer.cpp     |  37 +++++
 csrc/ir/interface_nodes.h                      |  46 +++++-
 csrc/runtime/executor.cpp                      |  26 +--
 tests/cpp/test_circular_buffering.cpp          | 149 +++++++++++++++++-
 9 files changed, 297 insertions(+), 24 deletions(-)

diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 3a5f31c74d5..fe00c1f8105 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -158,9 +158,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
  public:
   static std::string generateKernelDefinition(
       const kir::Kernel* kernel,
-      const std::string& kernel_name) {
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
     CudaKernelGenerator codegen(kernel);
-    codegen.genDeclaration(kernel_name);
+    codegen.genDeclaration(kernel_name, num_threads_per_cta);
     codegen.startBlock();
     codegen.genPrologue();
     codegen.genBody();
@@ -272,8 +273,18 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   }

   // Generates the kernel function declaration
-  void genDeclaration(const std::string& kernel_name) {
+  void genDeclaration(
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
+    if (kernel_->hasManaged("enable_register_sharing") &&
+        kernel_->getManaged<bool>("enable_register_sharing")) {
+      NVF_ERROR(
+          num_threads_per_cta.has_value(),
+          "__launch_bounds__ must be set for register sharing warp specialization");
+      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+            << num_threads_per_cta.value() << ") ";
+    }
     if (kernel_->hasManaged("cluster_dims")) {
       auto cluster_dims = kernel_->getManaged<std::vector<int64_t>>(
           "cluster_dims");
@@ -3542,9 +3553,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

 std::string generateCudaKernel(
     const kir::Kernel* kernel,
-    const std::string& kernel_name) {
+    const std::string& kernel_name,
+    std::optional<int64_t> num_threads_per_cta) {
   FUSER_PERF_SCOPE("generateCudaKernel");
-  return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
+  return CudaKernelGenerator::generateKernelDefinition(
+      kernel, kernel_name, num_threads_per_cta);
 }

 } // namespace codegen
diff --git a/csrc/codegen.h b/csrc/codegen.h
index 8c0e89663d1..e2f1382c8d2 100644
--- a/csrc/codegen.h
+++ b/csrc/codegen.h
@@ -19,7 +19,8 @@ namespace codegen {
 //! Generates a CUDA kernel definition for the given kernel
 NVF_API std::string generateCudaKernel(
     const kir::Kernel* kernel,
-    const std::string& kernel_name = "CUDAGeneratedKernel");
+    const std::string& kernel_name = "CUDAGeneratedKernel",
+    std::optional<int64_t> num_threads_per_cta = std::nullopt);

 } // namespace codegen
 } // namespace nvfuser
diff --git a/csrc/device_lower/analysis/circular_buffer.cpp b/csrc/device_lower/analysis/circular_buffer.cpp
index 58f35a1f8f0..21f58c8d43d 100644
--- a/csrc/device_lower/analysis/circular_buffer.cpp
+++ b/csrc/device_lower/analysis/circular_buffer.cpp
@@ -143,6 +143,29 @@ void validateCircularBufferedTensor(const TensorView* tv) {
       ". Consumer memory type: ",
       c_mem_type);

+  // Ensure that the warp-specialized circular buffer loop is the outer-most
+  // for-loop if register sharing is enabled.
+  if (std::holds_alternative<WarpSpecialized>(
+          tv->circularBufferOptions().type) &&
+      std::get<WarpSpecialized>(tv->circularBufferOptions().type)
+          .num_registers.has_value()) {
+    for (int64_t axis : c10::irange((int64_t)tv->getLoopDomain().size())) {
+      // short-circuit: only check IterDomains to the left of the circular
+      // buffer position
+      if (axis >= circular_buffer_pos) {
+        break;
+      }
+      NVF_ERROR(
+          tv->getLoopDomain().at(axis)->isThread() ||
+              tv->getLoopDomain().at(axis)->isDeviceDim() ||
+              tv->getLoopDomain().at(axis)->isBroadcast() ||
+              tv->getLoopDomain().at(axis)->isOneInt(),
+          "When using register sharing with warp-specialized circular "
+          "buffering, the circular buffer loop must be the outer-most "
+          "for-loop.");
+    }
+  }
+
   return;
 }
diff --git a/csrc/device_lower/analysis/device_version.cpp b/csrc/device_lower/analysis/device_version.cpp
index 4682adfaf75..faf16e91ad6 100644
--- a/csrc/device_lower/analysis/device_version.cpp
+++ b/csrc/device_lower/analysis/device_version.cpp
@@ -69,6 +69,15 @@ void MinimumDeviceVersion::handle(LoadStoreOp* ls_op) {
   }
 }

+void MinimumDeviceVersion::handle(TensorView* tv) {
+  if (std::holds_alternative<WarpSpecialized>(
+          tv->circularBufferOptions().type)) {
+    ensureVersion(
+        {9, 0},
+        "Warp Specialized Circular Buffering uses the setmaxnreg ptx instruction, which requires Hopper (9.0) or newer");
+  }
+}
+
 void MinimumDeviceVersion::ensureVersion(
     std::pair<int, int> version,
     std::string reason) {
diff --git a/csrc/device_lower/analysis/device_version.h b/csrc/device_lower/analysis/device_version.h
index 5fef6b36333..3ebe3a4fa34 100644
--- a/csrc/device_lower/analysis/device_version.h
+++ b/csrc/device_lower/analysis/device_version.h
@@ -48,6 +48,11 @@ class MinimumDeviceVersion : private IterVisitor {
   //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
   void handle(LoadStoreOp* ls_op) final;

+  //! If TensorView has warp specialized circular buffering, it will use the
+  //! setmaxnreg ptx instruction that requires Hopper (9.0+).
+  //! https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
+  void handle(TensorView* tv) final;
+
   //! bump min_version_ to at least this value
   void ensureVersion(std::pair<int, int> version, std::string reason);
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 15ed808b936..84de3c62bf5 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1394,11 +1394,48 @@ class CircularBufferInserter : private kir::ExprMutator {
                     warp_specialize_on),
                 circular_buffer_loop->fusion()->oneVal()))));

+    // Set default value
+    auto& circular_buffer_options =
+        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+            circular_buffer_loop->iter_domain());
+    bool enable_register_sharing =
+        std::holds_alternative<WarpSpecialized>(circular_buffer_options.type) &&
+        std::get<WarpSpecialized>(circular_buffer_options.type)
+            .num_registers.has_value();
+
+    GpuLower::current()->kernel()->manage(
+        "enable_register_sharing", enable_register_sharing);
+
+    if (enable_register_sharing) {
+      auto&& [decrease_num_registers, increase_num_registers] =
+          std::get<WarpSpecialized>(circular_buffer_options.type)
+              .num_registers.value();
+
+      // Decrease registers in load warp group
+      kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
+          IrBuilder::create<Val>(decrease_num_registers, DataType::Index),
+          /*increase_registers=*/false);
+      warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);
+
+      // Increase registers in compute warp group
+      kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
+          IrBuilder::create<Val>(increase_num_registers, DataType::Index),
+          /*increase_registers*/ true);
+      warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
+    }
+
     // Load loop:
     ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
         circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
     warp_dispatch_ite->thenBody().push_back(load_loop);

+    if (enable_register_sharing) {
+      // Terminate the warp group handling Load loop immediately after
+      // finishing its work.
+      kir::Return* ret = IrBuilder::create<kir::Return>();
+      warp_dispatch_ite->thenBody().push_back(ret);
+    }
+
     // Prefetch:
     auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
     warp_dispatch_ite->elseBody().push_back(prefetch_loop);
diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h
index 4e53a6207cc..98236fd3c5f 100644
--- a/csrc/ir/interface_nodes.h
+++ b/csrc/ir/interface_nodes.h
@@ -227,11 +227,41 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
 }

 struct WarpSpecialized {
-  ParallelType on;
-  explicit WarpSpecialized(ParallelType on) : on(on) {}
+  ParallelType on = ParallelType::Serial;
+  // The number of registers for load and compute warps respectively.
+  std::optional<std::pair<int64_t, int64_t>> num_registers = std::nullopt;
+
+  explicit WarpSpecialized(
+      ParallelType on,
+      std::pair<int64_t, int64_t> num_registers)
+      : on(on), num_registers(num_registers) {
+    validateRegisterSharing();
+  }
+  explicit WarpSpecialized(ParallelType on)
+      : on(on), num_registers(std::nullopt) {}
   WarpSpecialized() = default;
+
+  void validateRegisterSharing() {
+    // short-circuit: register sharing is not used.
+    if (!num_registers.has_value()) {
+      return;
+    }
+    auto validate_num_registers = [](int64_t a) {
+      NVF_ERROR(
+          a >= 24 && a <= 256 && a % 8 == 0,
+          "The number of registers for setmaxnreg must be between 24 and",
+          " 256 (inclusive) and be a multiple of 8.");
+    };
+    validate_num_registers(num_registers.value().first);
+    validate_num_registers(num_registers.value().second);
+    NVF_ERROR(
+        num_registers.value().first <= num_registers.value().second,
+        "The number of registers for load warp group must be <= to the number",
+        " of registers for the compute warp groups.");
+  }
+
   bool operator==(const WarpSpecialized& other) const {
-    return on == other.on;
+    return on == other.on && num_registers == other.num_registers;
   }
 };

@@ -252,7 +282,15 @@ inline std::ostream& operator<<(
     default:
       NVF_THROW("Invalid parallel type");
   }
-  return os << "WarpSpecializedOn" << parallel_type_str;
+  std::string num_registers = "RegisterSharing_None";
+  if (warp_specialized.num_registers.has_value()) {
+    auto&& [decrease_num_reg, increase_num_reg] =
+        warp_specialized.num_registers.value();
+    std::stringstream s;
+    s << "RegisterSharing_" << decrease_num_reg << "_" << increase_num_reg;
+    num_registers = s.str();
+  }
+  return os << "WarpSpecializedOn" << parallel_type_str << num_registers;
 }

 using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp
index 04f86b1edd0..36a0f9fc966 100644
--- a/csrc/runtime/executor.cpp
+++ b/csrc/runtime/executor.cpp
@@ -461,7 +461,19 @@ void KernelExecutor::compile(
     }
   }

-  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
+  // TODO: pass block_size here;
+  std::optional<int64_t> dynamic_smem = std::nullopt;
+  std::optional<int64_t> block_size = std::nullopt;
+  if (!args.empty()) {
+    auto expr_eval = executor_utils::bindInputs(args, kernel);
+    auto launch_params = computeLaunchParams(
+        launch_constraints, expr_eval, warp_size_, kernel->indexType());
+    block_size = launch_params.nThreads();
+    dynamic_smem = launch_params.smem();
+    NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
+  }
+
+  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName(), block_size);

   // If NVFUSER_EXTERNAL_SRC is set, utilize the external source code.
   // If the loaded external source code is empty, revert to the default codegen.
@@ -525,18 +537,6 @@ void KernelExecutor::compile(
     NVF_THROW(ss.str());
   }

-  // TODO: pass block_size here;
-  std::optional<int64_t> dynamic_smem = std::nullopt;
-  std::optional<int64_t> block_size = std::nullopt;
-  if (!args.empty()) {
-    auto expr_eval = executor_utils::bindInputs(args, kernel);
-    auto launch_params = computeLaunchParams(
-        launch_constraints, expr_eval, warp_size_, kernel->indexType());
-    block_size = launch_params.nThreads();
-    dynamic_smem = launch_params.smem();
-    NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
-  }
-
   // TODO: high water mark should be computed via occupancy API after
   // compilation.
diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp
index a5eb595af9c..193f0382727 100644
--- a/tests/cpp/test_circular_buffering.cpp
+++ b/tests/cpp/test_circular_buffering.cpp
@@ -12,9 +12,154 @@
 #include
 #include
 #include
+#include

 namespace nvfuser {

+TEST_F(NVFuserTest, RegisterSharingCircularBufferingPointwiseCustom) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  int64_t number_of_stages = 4;
+  int64_t prefetch_distance = 1;
+  int64_t tensor_outer_dim = 128;
+  int64_t tensor_inner_dim = 128;
+  CircularBufferType circular_buffer_type =
+      WarpSpecialized(ParallelType::TIDy, std::make_pair(160L, 160L));
+
+  TensorView* tv0 = makeContigTensor(2);
+  TensorView* tv1 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+
+  TensorView* tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  // Use TMA to load TV0 into shared memory
+  TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv3->setMemoryType(MemoryType::Shared);
+
+  TensorView* tv4 = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv4->setMemoryType(MemoryType::Shared);
+
+  TensorView* reference = tv2;
+
+  // Constants
+  constexpr int64_t bulk_inner_dim = 32;
+
+  // [M, N] -> [M, N/bid, bid]
+  reference->split(-1, bulk_inner_dim);
+
+  TransformPropagatorWithCheck propagator(reference);
+  MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
+
+  // Set computeAt position
+  inlineAllAt(tv2, /*pos=*/2);
+
+  // Circular Buffer with TMA loads
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->axis(2)->parallelize(ParallelType::Bulk);
+  tv3->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Load TV1 into shared memory
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(2)->parallelize(ParallelType::Bulk);
+  tv4->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Split reference to parallelize TMA tile
+  reference->split(-1, 32);
+  reference->axis(0)->parallelize(ParallelType::BIDx);
+  reference->axis(-1)->parallelize(ParallelType::TIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t2 = t0 + t1;
+
+  KernelExecutor ke;
+  ke.compile(fusion.get(), {t0, t1});
+
+  std::vector<at::Tensor> cg_outputs = ke.run({t0, t1});
+  testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, RegisterSharingCircularBufferingPointwiseNested) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  int64_t number_of_stages = 4;
+  int64_t prefetch_distance = 1;
+  int64_t tensor_outer_dim = 128;
+  int64_t tensor_inner_dim = 128;
+  CircularBufferType circular_buffer_type =
+      WarpSpecialized(ParallelType::TIDy, std::make_pair(160L, 160L));
+
+  TensorView* tv0 = makeContigTensor(2);
+  TensorView* tv1 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+
+  TensorView* tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  // Use TMA to load TV0 into shared memory
+  TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv3->setMemoryType(MemoryType::Shared);
+
+  TensorView* tv4 = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv4->setMemoryType(MemoryType::Shared);
+
+  TensorView* reference = tv2;
+
+  // Constants
+  constexpr int64_t bulk_inner_dim = 32;
+
+  // [M, N] -> [M, N/bid, bid]
+  reference->split(-1, bulk_inner_dim);
+
+  TransformPropagatorWithCheck propagator(reference);
+  MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
+
+  // Set computeAt position
+  inlineAllAt(tv2, /*pos=*/2);
+
+  // Circular Buffer with TMA loads
+  // tv3->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->axis(2)->parallelize(ParallelType::Bulk);
+  tv3->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Load TV1 into shared memory
+  // tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(2)->parallelize(ParallelType::Bulk);
+  tv4->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Split reference to parallelize TMA tile
+  reference->split(-1, 32);
+  // reference->axis(0)->parallelize(ParallelType::BIDx);
+  reference->axis(-1)->parallelize(ParallelType::TIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t2 = t0 + t1;
+
+  KernelExecutor ke;
+  try {
+    ke.compile(fusion.get(), {t0, t1});
+  } catch (const std::exception& e) {
+    const char* reference =
+        R"(When using register sharing with warp-specialized circular buffering, the circular buffer loop must be the outer-most for-loop.)";
+    const char* str_match_pointer = strstr(e.what(), reference);
+    ASSERT_TRUE(str_match_pointer != nullptr);
+  }
+}
+
 using StageAndPrefetch = std::pair<int64_t, int64_t>;

 class CircularBufferingTest : public NVFuserFixtureParamTest<StageAndPrefetch> {
@@ -1855,7 +2000,9 @@ auto tmaCircularBufferingParams() {
       Pipelined(false),
       Pipelined(true),
       WarpSpecialized(ParallelType::TIDx),
-      WarpSpecialized(ParallelType::TIDy)};
+      WarpSpecialized(ParallelType::TIDy),
+      WarpSpecialized(ParallelType::TIDx, std::make_pair(40, 240)),
+      WarpSpecialized(ParallelType::TIDy, std::make_pair(40, 240))};
   std::vector values;
   for (int64_t i : {2, 4}) {
     for (int64_t j : c10::irange(-i, i)) {