diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 6fca981c191..1bb287cc1e0 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -158,9 +158,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
  public:
   static std::string generateKernelDefinition(
       const kir::Kernel* kernel,
-      const std::string& kernel_name) {
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
     CudaKernelGenerator codegen(kernel);
-    codegen.genDeclaration(kernel_name);
+    codegen.genDeclaration(kernel_name, num_threads_per_cta);
     codegen.startBlock();
     codegen.genPrologue();
     codegen.genBody();
@@ -272,10 +273,18 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   }
 
   // Generates the kernel function declaration
-  void genDeclaration(const std::string& kernel_name) {
+  void genDeclaration(
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
     code_ << "__global__ void ";
-    // TODO Fix hardcoded values
-    code_ << "__launch_bounds__(384, 1) ";
+    if (kernel_->hasManaged("enable_register_sharing") &&
+        kernel_->getManaged<bool>("enable_register_sharing")) {
+      NVF_ERROR(
+          num_threads_per_cta.has_value(),
+          "__launch_bounds__ must be set for register sharing warp specialization");
+      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+            << num_threads_per_cta.value() << ") ";
+    }
     if (kernel_->hasManaged("cluster_dims")) {
       auto cluster_dims = kernel_->getManaged>(
@@ -3550,9 +3559,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
 
 std::string generateCudaKernel(
     const kir::Kernel* kernel,
-    const std::string& kernel_name) {
+    const std::string& kernel_name,
+    std::optional<int64_t> num_threads_per_cta) {
   FUSER_PERF_SCOPE("generateCudaKernel");
-  return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
+  return CudaKernelGenerator::generateKernelDefinition(
+      kernel, kernel_name, num_threads_per_cta);
 }
 
 } // namespace codegen
diff --git a/csrc/codegen.h b/csrc/codegen.h
index 8c0e89663d1..e2f1382c8d2 100644
--- a/csrc/codegen.h
+++ b/csrc/codegen.h
@@ -19,7 +19,8 @@ namespace codegen {
 //! Generates a CUDA kernel definition for the given kernel
 NVF_API std::string generateCudaKernel(
     const kir::Kernel* kernel,
-    const std::string& kernel_name = "CUDAGeneratedKernel");
+    const std::string& kernel_name = "CUDAGeneratedKernel",
+    std::optional<int64_t> num_threads_per_cta = std::nullopt);
 
 } // namespace codegen
 } // namespace nvfuser
diff --git a/csrc/device_lower/analysis/circular_buffer.cpp b/csrc/device_lower/analysis/circular_buffer.cpp
index 58f35a1f8f0..21f58c8d43d 100644
--- a/csrc/device_lower/analysis/circular_buffer.cpp
+++ b/csrc/device_lower/analysis/circular_buffer.cpp
@@ -143,6 +143,29 @@ void validateCircularBufferedTensor(const TensorView* tv) {
       ". Consumer memory type: ",
       c_mem_type);
 
+  // Ensure that the warp-specialized circular buffer loop is the outer-most
+  // for-loop if register sharing is enabled.
+  if (std::holds_alternative<WarpSpecialized>(
+          tv->circularBufferOptions().type) &&
+      std::get<WarpSpecialized>(tv->circularBufferOptions().type)
+          .num_registers.has_value()) {
+    for (int64_t axis : c10::irange((int64_t)tv->getLoopDomain().size())) {
+      // short-circuit: only check IterDomains to the left of the circular
+      // buffer position
+      if (axis >= circular_buffer_pos) {
+        break;
+      }
+      NVF_ERROR(
+          tv->getLoopDomain().at(axis)->isThread() ||
+              tv->getLoopDomain().at(axis)->isDeviceDim() ||
+              tv->getLoopDomain().at(axis)->isBroadcast() ||
+              tv->getLoopDomain().at(axis)->isOneInt(),
+          "When using register sharing with warp-specialized circular "
+          "buffering, the circular buffer loop must be the outer-most "
+          "for-loop.");
+    }
+  }
+
   return;
 }
 
diff --git a/csrc/device_lower/analysis/device_version.cpp b/csrc/device_lower/analysis/device_version.cpp
index 4682adfaf75..faf16e91ad6 100644
--- a/csrc/device_lower/analysis/device_version.cpp
+++ b/csrc/device_lower/analysis/device_version.cpp
@@ -69,6 +69,15 @@ void MinimumDeviceVersion::handle(LoadStoreOp* ls_op) {
   }
 }
 
+void MinimumDeviceVersion::handle(TensorView* tv) {
+  if (std::holds_alternative<WarpSpecialized>(
+          tv->circularBufferOptions().type)) {
+    ensureVersion(
+        {9, 0},
+        "Warp Specialized Circular Buffering uses the setmaxnreg ptx instruction, which requires Hopper (9.0) or newer");
+  }
+}
+
 void MinimumDeviceVersion::ensureVersion(
     std::pair<int, int> version,
     std::string reason) {
diff --git a/csrc/device_lower/analysis/device_version.h b/csrc/device_lower/analysis/device_version.h
index 5fef6b36333..3ebe3a4fa34 100644
--- a/csrc/device_lower/analysis/device_version.h
+++ b/csrc/device_lower/analysis/device_version.h
@@ -48,6 +48,11 @@ class MinimumDeviceVersion : private IterVisitor {
   //! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
   void handle(LoadStoreOp* ls_op) final;
 
+  //! If a TensorView uses warp-specialized circular buffering, it relies on
+  //! the setmaxnreg ptx instruction, which requires Hopper (9.0+).
+  //! https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
+  void handle(TensorView* tv) final;
+
   //! bump min_version_ to at least this value
  void ensureVersion(std::pair<int, int> version, std::string reason);
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 3eff299d462..84de3c62bf5 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -1394,23 +1394,47 @@ class CircularBufferInserter : private kir::ExprMutator {
             warp_specialize_on),
         circular_buffer_loop->fusion()->oneVal()))));
 
-    kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
-        IrBuilder::create<Val>(24, DataType::Index),
-        /*increase_registers=*/false);
-    warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);
+    // Check whether register sharing is enabled for this circular buffer loop.
+    auto& circular_buffer_options =
+        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+            circular_buffer_loop->iter_domain());
+    bool enable_register_sharing =
+        std::holds_alternative<WarpSpecialized>(circular_buffer_options.type) &&
+        std::get<WarpSpecialized>(circular_buffer_options.type)
+            .num_registers.has_value();
+
+    GpuLower::current()->kernel()->manage(
+        "enable_register_sharing", enable_register_sharing);
+
+    if (enable_register_sharing) {
+      auto&& [decrease_num_registers, increase_num_registers] =
+          std::get<WarpSpecialized>(circular_buffer_options.type)
+              .num_registers.value();
+
+      // Decrease registers in load warp group
+      kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
+          IrBuilder::create<Val>(decrease_num_registers, DataType::Index),
+          /*increase_registers=*/false);
+      warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);
+
+      // Increase registers in compute warp group
+      kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
+          IrBuilder::create<Val>(increase_num_registers, DataType::Index),
+          /*increase_registers=*/true);
+      warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
+    }
 
     // Load loop:
     ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
         circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
     warp_dispatch_ite->thenBody().push_back(load_loop);
 
-    kir::Return* ret = IrBuilder::create<kir::Return>();
-    warp_dispatch_ite->thenBody().push_back(ret);
-
-    kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
-        IrBuilder::create<Val>(240, DataType::Index),
-        /*increase_registers*/ true);
-    warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
+    if (enable_register_sharing) {
+      // Terminate the warp group handling the load loop immediately after it
+      // finishes its work.
+      kir::Return* ret = IrBuilder::create<kir::Return>();
+      warp_dispatch_ite->thenBody().push_back(ret);
+    }
 
     // Prefetch:
     auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h
index 4e53a6207cc..98236fd3c5f 100644
--- a/csrc/ir/interface_nodes.h
+++ b/csrc/ir/interface_nodes.h
@@ -227,11 +227,41 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
 }
 
 struct WarpSpecialized {
-  ParallelType on;
-  explicit WarpSpecialized(ParallelType on) : on(on) {}
+  ParallelType on = ParallelType::Serial;
+  // The number of registers for load and compute warps respectively.
+  std::optional<std::pair<int64_t, int64_t>> num_registers = std::nullopt;
+
+  explicit WarpSpecialized(
+      ParallelType on,
+      std::pair<int64_t, int64_t> num_registers)
+      : on(on), num_registers(num_registers) {
+    validateRegisterSharing();
+  }
+  explicit WarpSpecialized(ParallelType on)
+      : on(on), num_registers(std::nullopt) {}
   WarpSpecialized() = default;
+
+  void validateRegisterSharing() {
+    // short-circuit: register sharing is not used.
+    if (!num_registers.has_value()) {
+      return;
+    }
+    auto validate_num_registers = [](int64_t a) {
+      NVF_ERROR(
+          a >= 24 && a <= 256 && a % 8 == 0,
+          "The number of registers for setmaxnreg must be between 24 and",
+          " 256 (inclusive) and be a multiple of 8.");
+    };
+    validate_num_registers(num_registers.value().first);
+    validate_num_registers(num_registers.value().second);
+    NVF_ERROR(
+        num_registers.value().first <= num_registers.value().second,
+        "The number of registers for the load warp group must be <= the",
+        " number of registers for the compute warp groups.");
+  }
+
   bool operator==(const WarpSpecialized& other) const {
-    return on == other.on;
+    return on == other.on && num_registers == other.num_registers;
   }
 };
 
@@ -252,7 +282,15 @@ inline std::ostream& operator<<(
     default:
       NVF_THROW("Invalid parallel type");
   }
-  return os << "WarpSpecializedOn" << parallel_type_str;
+  std::string num_registers = "RegisterSharing_None";
+  if (warp_specialized.num_registers.has_value()) {
+    auto&& [decrease_num_reg, increase_num_reg] =
+        warp_specialized.num_registers.value();
+    std::stringstream s;
+    s << "RegisterSharing_" << decrease_num_reg << "_" << increase_num_reg;
+    num_registers = s.str();
+  }
+  return os << "WarpSpecializedOn" << parallel_type_str << num_registers;
 }
 
 using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index d2f0d9277d2..ed8986ff817 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -80,9 +80,9 @@ TensorView* triu(TensorView* tv, Val* offset) {
 
   NVF_CHECK(
       dims >= 2,
-      "triu is only supported for 2+D tensors, but got ",
+      "input tensor for triu must have 2 or more dims, but got ",
       dims,
-      "D tensor");
+      " dims");
 
   auto fusion = tv->fusion();
 
diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp
index ea061b094f1..31996b58e15 100644
--- a/csrc/python_frontend/python_bindings.cpp
+++ b/csrc/python_frontend/python_bindings.cpp
@@ -1693,6 +1693,52 @@ void initNvFuserPythonBindings(PyObject* module) {
   NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
 #undef NVFUSER_PYTHON_BINDING_UNARY_OP
 
+  nvf_ops.def(
+      "triu",
+      [](FusionDefinition::Operators& self,
+         Tensor input,
+         int64_t diagonal) -> Tensor {
+        FUSER_PERF_SCOPE("Operators.triu");
+        NVF_CHECK(
+            self.validUse(), "Attempting to add to a completed definition!");
+        FusionDefinition* fd = self.fusion_definition;
+        Tensor output = fd->defineTensor(input.dims);
+
+        auto diagonal_ = fd->defineScalar();
+        fd->defineRecord(new ScalarRecord(
+            {fd->recordingState(diagonal_())}, diagonal, DataType::Int, true));
+
+        fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
+            {fd->recordingState(input()), fd->recordingState(diagonal_())},
+            {fd->recordingState(output())},
+            ("ops.triu"),
+            serde::RecordType::Binary_TV_VAL,
+            static_cast<TensorView* (*)(TensorView*, Val*)>(triu)));
+
+        return output;
+      },
+      py::arg("input"),
+      py::arg("diagonal") = 0,
+      py::return_value_policy::reference,
+      R"doc(
+    Returns the upper triangular part of a 2+D tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input tensor.
+    diagonal : int, optional
+        The diagonal to consider. Default is 0.
+
+    Returns
+    -------
+    Tensor
+        The upper triangular part of the input tensor.
+
+    >>> a = torch.randn(3, 3)
+    >>> fd.ops.triu(a)
+    )doc");
+
   // overload to
   nvf_ops.def(
       "stride_order",
diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp
index 04f86b1edd0..36a0f9fc966 100644
--- a/csrc/runtime/executor.cpp
+++ b/csrc/runtime/executor.cpp
@@ -461,7 +461,19 @@ void KernelExecutor::compile(
     }
   }
 
-  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
+  // TODO: pass block_size here;
+  std::optional<int64_t> dynamic_smem = std::nullopt;
+  std::optional<int64_t> block_size = std::nullopt;
+  if (!args.empty()) {
+    auto expr_eval = executor_utils::bindInputs(args, kernel);
+    auto launch_params = computeLaunchParams(
+        launch_constraints, expr_eval, warp_size_, kernel->indexType());
+    block_size = launch_params.nThreads();
+    dynamic_smem = launch_params.smem();
+    NVF_ERROR(block_size > 0, "launch param inferred block size <= 0");
+  }
+
+  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName(), block_size);
 
   // If NVFUSER_EXTERNAL_SRC is set, utilize the external source code.
   // If the loaded external source code is empty, revert to the default codegen.
@@ -525,18 +537,6 @@ void KernelExecutor::compile(
     NVF_THROW(ss.str());
   }
 
-  // TODO: pass block_size here;
-  std::optional<int64_t> dynamic_smem = std::nullopt;
-  std::optional<int64_t> block_size = std::nullopt;
-  if (!args.empty()) {
-    auto expr_eval = executor_utils::bindInputs(args, kernel);
-    auto launch_params = computeLaunchParams(
-        launch_constraints, expr_eval, warp_size_, kernel->indexType());
-    block_size = launch_params.nThreads();
-    dynamic_smem = launch_params.smem();
-    NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
-  }
-
   // TODO: high water mark should be computed via occupancy API after
   // compilation.
 
diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp
index 598482b76c9..e051607257a 100644
--- a/csrc/scheduler/ampere_multi_matmul.cpp
+++ b/csrc/scheduler/ampere_multi_matmul.cpp
@@ -489,7 +489,8 @@ void AmpereMultipleMatmulScheduler::cacheInputsAndOutputs() {
   scheduler_utils::clearMemorySpace(fusion_);
 
   // Cache inputs
-  scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
+  scheduler_utils::cacheInputs(
+      fusion_, /*unroll=*/true, /*propagate_allocation=*/true);
 
   // Cache and fork outputs
   cached_outputs_ =
diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index a68b6714980..93e9321bcfa 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -95,7 +95,8 @@ void HopperMultipleMatmulScheduler::cacheInputsAndOutputs() {
   scheduler_utils::clearMemorySpace(fusion_);
 
   // Cache inputs
-  scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
+  scheduler_utils::cacheInputs(
+      fusion_, /*unroll=*/true, /*propagate_allocation=*/true);
 
   // Cache and fork outputs
   scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true);
diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp
index cd22d935a52..6b840c96b8c 100644
--- a/csrc/scheduler/utils.cpp
+++ b/csrc/scheduler/utils.cpp
@@ -1187,7 +1187,10 @@ void clearMemorySpace(Fusion* fusion) {
 
 // Returns cached after tensors of the fusion inputs if unrolled. Otherwise
 // return empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
+std::vector<TensorView*> cacheInputs(
+    Fusion* fusion,
+    bool unroll,
+    bool propagate_allocation) {
   if (!unroll) {
     return {};
   }
@@ -1224,10 +1227,10 @@ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
     }
 
     auto cached_tv = tv->cacheAfter(
-        /*op_type=*/LoadStoreOpType::Set,
-        /*cache_op=*/CacheOp::Unspecified,
-        /*propagate_allocation_domain=*/true,
-        /*cached_uses=*/cached_uses);
+        LoadStoreOpType::Set,
+        CacheOp::Unspecified,
+        propagate_allocation,
+        cached_uses);
     cached_inputs.emplace_back(cached_tv);
   }
   return cached_inputs;
diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h
index 62a359816d2..dbad708f003 100644
--- a/csrc/scheduler/utils.h
+++ b/csrc/scheduler/utils.h
@@ -334,7 +334,10 @@ void clearMemorySpace(Fusion* fusion);
 
 // Returns cached after tensors of the fusion inputs if unrolled. Otherwise
 // return empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
+std::vector<TensorView*> cacheInputs(
+    Fusion* fusion,
+    bool unroll,
+    bool propagate_allocation = false);
 
 // Returns the pairs of for
 // all outputs.
diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp
index 7e23adf2b69..253270f0b9d 100644
--- a/csrc/serde/fusion_record.cpp
+++ b/csrc/serde/fusion_record.cpp
@@ -653,6 +653,11 @@ void RecordFunctorFactory::setupFunctionMaps() {
       ("ops." op_str), static_cast<TensorView* (*)(TensorView*)>(op_name)); \
   unary_val.emplace(("ops." op_str), static_cast<Val* (*)(Val*)>(op_name));
 
+#define NVFUSER_UNARY_TV_ALPHA_OP(op_str, op_name) \
+  binary_tv_val.emplace(                           \
+      ("ops." op_str),                             \
+      static_cast<TensorView* (*)(TensorView*, Val*)>(op_name));
+
 #define NVFUSER_BINARY_TV_ONLY_OP(op_str, op_name) \
   binary_tv.emplace(                               \
       ("ops." op_str),                             \
@@ -808,6 +813,8 @@ void RecordFunctorFactory::setupFunctionMaps() {
   NVFUSER_UNARY_TV_OP("real", real)
   NVFUSER_UNARY_TV_OP("imag", imag)
 
+  NVFUSER_UNARY_TV_ALPHA_OP("triu", triu)
+
   NVFUSER_BINARY_TV_ONLY_OP("matmul", matmul)
   NVFUSER_BINARY_TV_ONLY_OP("linear", linear)
   NVFUSER_TERNARY_TV_ONLY_OP("linear", linear)
diff --git a/tests/cpp/test_allocation_domain.cpp b/tests/cpp/test_allocation_domain.cpp
index bff62bb98e1..a726dd6b262 100644
--- a/tests/cpp/test_allocation_domain.cpp
+++ b/tests/cpp/test_allocation_domain.cpp
@@ -1426,11 +1426,9 @@ TEST_F(AllocationDomainTest, InputAllocationIsSplit_Concrete) {
   fusion->addInput(in);
   fusion->addOutput(out);
 
-  // Ideally, loop should stay the same as logical because a fusion input comes
-  // from outside and isn't generated by a loop in the containing kernel (cf.
-  // #3479).
   in->split(0, 2);
   in->setAllocationDomain(in->getLoopDomain(), true);
+  in->setLoopDomain(in->getLogicalDomain());
 
   FusionExecutorCache executor_cache(std::move(fusion));
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp
index a5eb595af9c..193f0382727 100644
--- a/tests/cpp/test_circular_buffering.cpp
+++ b/tests/cpp/test_circular_buffering.cpp
@@ -12,9 +12,154 @@
 #include
 #include
 #include
+#include
 
 namespace nvfuser {
 
+TEST_F(NVFuserTest, RegisterSharingCircularBufferingPointwiseCustom) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  int64_t number_of_stages = 4;
+  int64_t prefetch_distance = 1;
+  int64_t tensor_outer_dim = 128;
+  int64_t tensor_inner_dim = 128;
+  CircularBufferType circular_buffer_type =
+      WarpSpecialized(ParallelType::TIDy, std::make_pair(160L, 160L));
+
+  TensorView* tv0 = makeContigTensor(2);
+  TensorView* tv1 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+
+  TensorView* tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  // Use TMA to load TV0 into shared memory
+  TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv3->setMemoryType(MemoryType::Shared);
+
+  TensorView* tv4 = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv4->setMemoryType(MemoryType::Shared);
+
+  TensorView* reference = tv2;
+
+  // Constants
+  constexpr int64_t bulk_inner_dim = 32;
+
+  // [M, N] -> [M, N/bid, bid]
+  reference->split(-1, bulk_inner_dim);
+
+  TransformPropagatorWithCheck propagator(reference);
+  MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
+
+  // Set computeAt position
+  inlineAllAt(tv2, /*pos=*/2);
+
+  // Circular Buffer with TMA loads
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->axis(2)->parallelize(ParallelType::Bulk);
+  tv3->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Load TV1 into shared memory
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(2)->parallelize(ParallelType::Bulk);
+  tv4->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Split reference to parallelize TMA tile
+  reference->split(-1, 32);
+  reference->axis(0)->parallelize(ParallelType::BIDx);
+  reference->axis(-1)->parallelize(ParallelType::TIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t2 = t0 + t1;
+
+  KernelExecutor ke;
+  ke.compile(fusion.get(), {t0, t1});
+
+  std::vector<at::Tensor> cg_outputs = ke.run({t0, t1});
+  testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, RegisterSharingCircularBufferingPointwiseNested) {
+  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  int64_t number_of_stages = 4;
+  int64_t prefetch_distance = 1;
+  int64_t tensor_outer_dim = 128;
+  int64_t tensor_inner_dim = 128;
+  CircularBufferType circular_buffer_type =
+      WarpSpecialized(ParallelType::TIDy, std::make_pair(160L, 160L));
+
+  TensorView* tv0 = makeContigTensor(2);
+  TensorView* tv1 = makeContigTensor(2);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+
+  TensorView* tv2 = add(tv0, tv1);
+  fusion->addOutput(tv2);
+
+  // Use TMA to load TV0 into shared memory
+  TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv3->setMemoryType(MemoryType::Shared);
+
+  TensorView* tv4 = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
+  tv4->setMemoryType(MemoryType::Shared);
+
+  TensorView* reference = tv2;
+
+  // Constants
+  constexpr int64_t bulk_inner_dim = 32;
+
+  // [M, N] -> [M, N/bid, bid]
+  reference->split(-1, bulk_inner_dim);
+
+  TransformPropagatorWithCheck propagator(reference);
+  MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
+
+  // Set computeAt position
+  inlineAllAt(tv2, /*pos=*/2);
+
+  // Circular Buffer with TMA loads
+  // tv3->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->axis(2)->parallelize(ParallelType::Bulk);
+  tv3->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Load TV1 into shared memory
+  // tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(2)->parallelize(ParallelType::Bulk);
+  tv4->circularBuffer(
+      number_of_stages, prefetch_distance, circular_buffer_type);
+
+  // Split reference to parallelize TMA tile
+  reference->split(-1, 32);
+  // reference->axis(0)->parallelize(ParallelType::BIDx);
+  reference->axis(-1)->parallelize(ParallelType::TIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
+  at::Tensor t2 = t0 + t1;
+
+  KernelExecutor ke;
+  try {
+    ke.compile(fusion.get(), {t0, t1});
+  } catch (const std::exception& e) {
+    const char* reference =
+        R"(When using register sharing with warp-specialized circular buffering, the circular buffer loop must be the outer-most for-loop.)";
+    const char* str_match_pointer = strstr(e.what(), reference);
+    ASSERT_TRUE(str_match_pointer != nullptr);
+  }
+}
+
 using StageAndPrefetch = std::pair<int64_t, int64_t>;
 
 class CircularBufferingTest : public NVFuserFixtureParamTest<StageAndPrefetch> {
@@ -1855,7 +2000,9 @@ auto tmaCircularBufferingParams() {
       Pipelined(false),
       Pipelined(true),
       WarpSpecialized(ParallelType::TIDx),
-      WarpSpecialized(ParallelType::TIDy)};
+      WarpSpecialized(ParallelType::TIDy),
+      WarpSpecialized(ParallelType::TIDx, std::make_pair(40, 240)),
+      WarpSpecialized(ParallelType::TIDy, std::make_pair(40, 240))};
   std::vector values;
   for (int64_t i : {2, 4}) {
     for (int64_t j : c10::irange(-i, i)) {
diff --git a/tests/python/opinfo_fusion_definitions.py b/tests/python/opinfo_fusion_definitions.py
index 95abad9b7f4..768aa3a2953 100644
--- a/tests/python/opinfo_fusion_definitions.py
+++ b/tests/python/opinfo_fusion_definitions.py
@@ -28,7 +28,7 @@ def parse_inputs_fusion_definition(fd: FusionDefinition, opinfo: OpInfo, *args):
     )
 
     num_symbolic_parameters = len(symbolic_parameter_list)
-    assert num_symbolic_parameters == len(
+    assert num_symbolic_parameters >= len(
         args
     ), f"{num_symbolic_parameters} vs {len(args)}"
 
diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py
index d3222aea4b4..472d5109059 100644
--- a/tests/python/opinfo_input_generators.py
+++ b/tests/python/opinfo_input_generators.py
@@ -1591,3 +1591,39 @@ def div_input_generator(
     denom = torch.where(denom_is_small, denom_scaled_to_minabs, denom).detach()
     denom.requires_grad_(requires_grad)
     yield SampleInput(numer, denom)
+
+
+def triu_input_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
+    offsets = (0, 1, -1, 2, 3, -3, 1024, -1024)
+
+    for element in elementwise_unary_generator(
+        op,
+        dtype,
+        requires_grad,
+        enable_extremal_value_testing=False,
+        enable_large_value_testing=False,
+        enable_small_value_testing=False,
+    ):
+        if element.args[0].ndim < 2:
+            continue
+        # to test cases where offset is not passed as an argument
+        yield element
+        # to test cases where offset is passed as an argument
+        for offset in offsets:
+            yield SampleInput(*element.args, offset)
+
+
+def triu_error_generator(op: OpInfo, dtype: torch.dtype, requires_grad: bool = False):
+    make_arg = partial(
+        make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
+    )
+
+    invalid_shapes = (
+        (),
+        (4,),
+    )
+
+    for shape in invalid_shapes:
+        yield SampleInput(
+            make_arg(shape),
+        ), RuntimeError, f"input tensor for triu must have 2 or more dims, but got {len(shape)} dims"
diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py
index 9031a9bd091..f0bbd649b87 100644
--- a/tests/python/opinfos.py
+++ b/tests/python/opinfos.py
@@ -50,6 +50,8 @@
     matmul_input_generator,
     linear_input_generator,
     linear_error_generator,
+    triu_input_generator,
+    triu_error_generator,
 )
 from utils import (
     bool_int_dtypes,
@@ -1218,6 +1220,19 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
 )
 linear_ops.append(linear_opinfo)
 
+tv_val_ops = []
+
+triu_opinfo = OpInfo(
+    lambda fd: fd.ops.triu,
+    "triu",
+    sample_input_generator=triu_input_generator,
+    error_input_generator=triu_error_generator,
+    reference=torch.triu,
+    symbolic_parameter_list=[ArgumentType.Symbolic, ArgumentType.Constant],
+)
+
+tv_val_ops.append(triu_opinfo)
+
 """ End Tensor Creation """
 
 # Puts all opinfos into the "opinfos" list
@@ -1231,3 +1246,4 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor):
 opinfos.extend(tensor_creation_ops)
 opinfos.extend(matmul_ops)
 opinfos.extend(linear_ops)
+opinfos.extend(tv_val_ops)
diff --git a/tests/python/test_ops.py b/tests/python/test_ops.py
index d653e005736..bc842ea29dc 100644
--- a/tests/python/test_ops.py
+++ b/tests/python/test_ops.py
@@ -63,7 +63,7 @@ def parse_args_fusion_execution(opinfo: OpInfo, *args):
         else [ArgumentType.Symbolic] * len(args)
     )
 
-    assert len(symbolic_parameter_list) == len(args)
+    assert len(symbolic_parameter_list) >= len(args)
 
     result = []
     for arg_type, a in zip(symbolic_parameter_list, args):
diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py
index 0b74fddeae6..7b7e1a40218 100644
--- a/tests/python/test_python_frontend.py
+++ b/tests/python/test_python_frontend.py
@@ -1204,6 +1204,20 @@ def fusion_func(fd: FusionDefinition):
         self.assertEqual(eager_out2, nvf_out[1])
         # self.assertEqual(eager_out3, nvf_out[2])
 
+    def test_triu(self):
+        inputs = [
+            torch.randn(4, 16, device="cuda", dtype=torch.float16),
+        ]
+
+        def fusion_func(fd: FusionDefinition) -> None:
+            t0 = fd.from_pytorch(inputs[0])
+            t1 = fd.ops.triu(t0, -1)
+            fd.add_output(t1)
+
+        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
+        eager_out0 = torch.triu(inputs[0], -1)
+        self.assertEqual(eager_out0, nvf_out[0])
+
     def test_complex_rsqrt(self):
         inputs = [
             torch.randn(4, device="cuda", dtype=torch.complex64),
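
Usage note (illustration, not part of the patch): the sketch below condenses how a schedule opts into the register-sharing path added above, following the new RegisterSharingCircularBufferingPointwiseCustom test. The tensor name `tv_smem`, the stage/prefetch counts, and the register pair are illustrative assumptions; the pair must satisfy WarpSpecialized::validateRegisterSharing (each value a multiple of 8 in [24, 256], load count <= compute count), and the circular buffer loop must be the outer-most for-loop.

    // Assumed: tv_smem is a TMA-loaded, shared-memory cache of a fusion input,
    // scheduled so its circular buffer axis is the outer-most for-loop.
    int64_t stages = 4;
    int64_t prefetch = 1;
    // Warp-specialize on TIDy and repartition registers: the load warp group
    // drops to 40 registers while the compute warp groups get 240.
    CircularBufferType type =
        WarpSpecialized(ParallelType::TIDy, std::make_pair(40L, 240L));
    tv_smem->circularBuffer(stages, prefetch, type);

With this option set, lowering manages "enable_register_sharing" on the kernel, codegen emits __launch_bounds__ with the inferred CTA thread count, and the warp dispatch branches issue the setmaxnreg decrease/increase plus an early return for the load warp group, as implemented in csrc/device_lower/pass/circular_buffer.cpp and csrc/codegen.cpp above.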