Merge branch 'main' of https://github.com/nvidia/fuser into hopper_matmul_cta_k_fix
rdspring1 committed Jan 9, 2025
2 parents 7966599 + 72402bf commit e3826e2
Showing 22 changed files with 435 additions and 52 deletions.
25 changes: 18 additions & 7 deletions csrc/codegen.cpp
@@ -158,9 +158,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
public:
static std::string generateKernelDefinition(
const kir::Kernel* kernel,
const std::string& kernel_name) {
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
CudaKernelGenerator codegen(kernel);
codegen.genDeclaration(kernel_name);
codegen.genDeclaration(kernel_name, num_threads_per_cta);
codegen.startBlock();
codegen.genPrologue();
codegen.genBody();
@@ -272,10 +273,18 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
void genDeclaration(
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
code_ << "__global__ void ";
// TODO Fix hardcoded values
code_ << "__launch_bounds__(384, 1) ";
if (kernel_->hasManaged("enable_register_sharing") &&
kernel_->getManaged<bool>("enable_register_sharing")) {
NVF_ERROR(
num_threads_per_cta.has_value(),
"__launch_bounds__ must be set for register sharing warp specialization");
code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
<< num_threads_per_cta.value() << ") ";
}
if (kernel_->hasManaged("cluster_dims")) {
auto cluster_dims =
kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
@@ -3550,9 +3559,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

std::string generateCudaKernel(
const kir::Kernel* kernel,
const std::string& kernel_name) {
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
FUSER_PERF_SCOPE("generateCudaKernel");
return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
return CudaKernelGenerator::generateKernelDefinition(
kernel, kernel_name, num_threads_per_cta);
}

} // namespace codegen
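With this change, `__launch_bounds__` is only emitted when register sharing is enabled, using the CTA size inferred at compile time instead of the old hardcoded `384, 1`. As a rough illustration (a hand-written sketch, not actual codegen output; the kernel name, parameter list, and thread count are illustrative), the emitted declaration looks like:

```cuda
// Illustrative sketch of the emitted declaration when register sharing is
// enabled and num_threads_per_cta == 384. Without register sharing, no
// __launch_bounds__ qualifier is emitted at all.
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/384)
    CUDAGeneratedKernel(const float* in, float* out) {
  // ... generated kernel body ...
}
```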
3 changes: 2 additions & 1 deletion csrc/codegen.h
@@ -19,7 +19,8 @@ namespace codegen {
//! Generates a CUDA kernel definition for the given kernel
NVF_API std::string generateCudaKernel(
const kir::Kernel* kernel,
const std::string& kernel_name = "CUDAGeneratedKernel");
const std::string& kernel_name = "CUDAGeneratedKernel",
std::optional<int64_t> num_threads_per_cta = std::nullopt);

} // namespace codegen
} // namespace nvfuser
23 changes: 23 additions & 0 deletions csrc/device_lower/analysis/circular_buffer.cpp
@@ -143,6 +143,29 @@ void validateCircularBufferedTensor(const TensorView* tv) {
". Consumer memory type: ",
c_mem_type);

// Ensure that the warp-specialized circular buffer loop is the outer-most
// for-loop if register sharing is enabled.
if (std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type) &&
std::get<WarpSpecialized>(tv->circularBufferOptions().type)
.num_registers.has_value()) {
for (int64_t axis : c10::irange((int64_t)tv->getLoopDomain().size())) {
// short-circuit: only check IterDomains to the left of the circular
// buffer position
if (axis >= circular_buffer_pos) {
break;
}
NVF_ERROR(
tv->getLoopDomain().at(axis)->isThread() ||
tv->getLoopDomain().at(axis)->isDeviceDim() ||
tv->getLoopDomain().at(axis)->isBroadcast() ||
tv->getLoopDomain().at(axis)->isOneInt(),
"When using register sharing with warp-specialized circular "
"buffering, the circular buffer loop must be the outer-most "
"for-loop.");
}
}

return;
}

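In scheduling terms, the check means that with register sharing only trivially-parallelized axes may sit to the left of the circular-buffer axis. A minimal sketch, assuming the `TensorView::circularBuffer` overload that takes a `CircularBufferType` (the exact signature, axis layout, and register counts are illustrative; 24/240 mirror the previously hardcoded values):

```cuda
// Sketch: a block-parallel axis to the left of the circular buffer axis is
// fine; an unparallelized serial axis there would trip the NVF_ERROR above.
tv->axis(0)->parallelize(ParallelType::BIDx);
tv->circularBuffer(
    /*number_of_stages=*/4,
    /*prefetch_distance=*/3,
    WarpSpecialized(ParallelType::TIDy, /*num_registers=*/{24, 240}));
```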
9 changes: 9 additions & 0 deletions csrc/device_lower/analysis/device_version.cpp
@@ -69,6 +69,15 @@ void MinimumDeviceVersion::handle(LoadStoreOp* ls_op) {
}
}

void MinimumDeviceVersion::handle(TensorView* tv) {
if (std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type)) {
ensureVersion(
{9, 0},
"Warp Specialized Circular Buffering uses the setmaxnreg ptx instruction, which requires Hopper (9.0) or newer");
}
}

void MinimumDeviceVersion::ensureVersion(
std::pair<int, int> version,
std::string reason) {
5 changes: 5 additions & 0 deletions csrc/device_lower/analysis/device_version.h
@@ -48,6 +48,11 @@ class MinimumDeviceVersion : private IterVisitor {
//! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
void handle(LoadStoreOp* ls_op) final;

//! If TensorView has warp specialized circular buffering, it will use the
//! setmaxnreg ptx instruction that requires Hopper (9.0+).
//! https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
void handle(TensorView* tv) final;

//! bump min_version_ to at least this value
void ensureVersion(std::pair<int, int> version, std::string reason);

46 changes: 35 additions & 11 deletions csrc/device_lower/pass/circular_buffer.cpp
@@ -1394,23 +1394,47 @@ class CircularBufferInserter : private kir::ExprMutator {
warp_specialize_on),
circular_buffer_loop->fusion()->oneVal()))));

kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(24, DataType::Index),
/*increase_registers=*/false);
warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);
// Set default value
auto& circular_buffer_options =
GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
circular_buffer_loop->iter_domain());
bool enable_register_sharing =
std::holds_alternative<WarpSpecialized>(circular_buffer_options.type) &&
std::get<WarpSpecialized>(circular_buffer_options.type)
.num_registers.has_value();

GpuLower::current()->kernel()->manage(
"enable_register_sharing", enable_register_sharing);

if (enable_register_sharing) {
auto&& [decrease_num_registers, increase_num_registers] =
std::get<WarpSpecialized>(circular_buffer_options.type)
.num_registers.value();

// Decrease registers in load warp group
kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(decrease_num_registers, DataType::Index),
/*increase_registers=*/false);
warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);

// Increase registers in compute warp group
kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(increase_num_registers, DataType::Index),
/*increase_registers*/ true);
warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
}

// Load loop:
ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
warp_dispatch_ite->thenBody().push_back(load_loop);

kir::Return* ret = IrBuilder::create<kir::Return>();
warp_dispatch_ite->thenBody().push_back(ret);

kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(240, DataType::Index),
/*increase_registers*/ true);
warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
if (enable_register_sharing) {
// Terminate the warp group handling Load loop immediately after
// finishing its work.
kir::Return* ret = IrBuilder::create<kir::Return>();
warp_dispatch_ite->thenBody().push_back(ret);
}

// Prefetch:
auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
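Put together, the pass now generates a warp-dispatch structure roughly like the following hand-written approximation (not actual codegen output; the dispatch predicate is illustrative, 24/240 are the previously hardcoded register counts, and the inline PTX assumes an sm_90a target):

```cuda
__global__ void warp_specialized_sketch() {
  if (threadIdx.y == blockDim.y - 1) {  // load warp group (illustrative predicate)
    // Register sharing: shrink the load warps' register budget...
    asm volatile("setmaxnreg.dec.sync.aligned.u32 24;");
    // ... circular-buffered TMA load loop elided ...
    return;  // load warps terminate as soon as loading is finished
  } else {  // compute warp groups
    // ...so the compute warps can grow theirs.
    asm volatile("setmaxnreg.inc.sync.aligned.u32 240;");
    // ... circular-buffered compute loop elided ...
  }
}
```

When register sharing is disabled, the same warp dispatch is still emitted, but without the `setmaxnreg` calls and without the early `return` in the load branch.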
46 changes: 42 additions & 4 deletions csrc/ir/interface_nodes.h
@@ -227,11 +227,41 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
}

struct WarpSpecialized {
ParallelType on;
explicit WarpSpecialized(ParallelType on) : on(on) {}
ParallelType on = ParallelType::Serial;
// The number of registers for load and compute warps respectively.
std::optional<std::pair<int64_t, int64_t>> num_registers = std::nullopt;

explicit WarpSpecialized(
ParallelType on,
std::pair<int64_t, int64_t> num_registers)
: on(on), num_registers(num_registers) {
validateRegisterSharing();
}
explicit WarpSpecialized(ParallelType on)
: on(on), num_registers(std::nullopt) {}
WarpSpecialized() = default;

void validateRegisterSharing() {
// short-circuit: register sharing is not used.
if (!num_registers.has_value()) {
return;
}
auto validate_num_registers = [](int64_t a) {
NVF_ERROR(
a >= 24 && a <= 256 && a % 8 == 0,
"The number of registers for setmaxnreg must be between 24 and",
" 256 (inclusive) and be a multiple of 8.");
};
validate_num_registers(num_registers.value().first);
validate_num_registers(num_registers.value().second);
NVF_ERROR(
num_registers.value().first <= num_registers.value().second,
"The number of registers for load warp group must be <= to the number",
" of registers for the compute warp groups.");
}

bool operator==(const WarpSpecialized& other) const {
return on == other.on;
return on == other.on && num_registers == other.num_registers;
}
};

@@ -252,7 +282,15 @@ inline std::ostream& operator<<(
default:
NVF_THROW("Invalid parallel type");
}
return os << "WarpSpecializedOn" << parallel_type_str;
std::string num_registers = "RegisterSharing_None";
if (warp_specialized.num_registers.has_value()) {
auto&& [decrease_num_reg, increase_num_reg] =
warp_specialized.num_registers.value();
std::stringstream s;
s << "RegisterSharing_" << decrease_num_reg << "_" << increase_num_reg;
num_registers = s.str();
}
return os << "WarpSpecializedOn" << parallel_type_str << num_registers;
}

using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
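A short usage sketch of the extended struct (values illustrative; the constraints come from `validateRegisterSharing` above):

```cuda
// Warp specialization with register sharing: both counts must be multiples
// of 8 in [24, 256], and the load-warp count must not exceed the
// compute-warp count.
WarpSpecialized with_sharing(ParallelType::TIDy, /*num_registers=*/{24, 240});
// WarpSpecialized(ParallelType::TIDy, {30, 240});  // rejected: not a multiple of 8
// WarpSpecialized(ParallelType::TIDy, {240, 24});  // rejected: load > compute
WarpSpecialized without_sharing(ParallelType::TIDy);  // no setmaxnreg emitted
```

The stream operator above would render the first case as something like `WarpSpecializedOnTIDyRegisterSharing_24_240`.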
4 changes: 2 additions & 2 deletions csrc/ops/composite.cpp
@@ -80,9 +80,9 @@ TensorView* triu(TensorView* tv, Val* offset) {

NVF_CHECK(
dims >= 2,
"triu is only supported for 2+D tensors, but got ",
"input tensor for triu must have 2 or more dims, but got ",
dims,
"D tensor");
" dims");

auto fusion = tv->fusion();

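For reference, a minimal C++ sketch of the entry point being checked, inside the usual nvfuser test scaffolding (`makeSymbolicTensor` is the customary test-utility helper, assumed here):

```cuda
Fusion fusion;
FusionGuard fg(&fusion);
// A 2-D input satisfies the "2 or more dims" check; a 1-D input would throw.
TensorView* tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
// Keep elements on and above the main diagonal (offset 0).
TensorView* tv1 = triu(tv0, IrBuilder::create<Val>(0L, DataType::Int));
fusion.addOutput(tv1);
```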
46 changes: 46 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
@@ -1693,6 +1693,52 @@ void initNvFuserPythonBindings(PyObject* module) {
NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
#undef NVFUSER_PYTHON_BINDING_UNARY_OP

nvf_ops.def(
"triu",
[](FusionDefinition::Operators& self,
Tensor input,
int64_t diagonal) -> Tensor {
FUSER_PERF_SCOPE("Operators.triu");
NVF_CHECK(
self.validUse(), "Attempting to add to a completed definition!");
FusionDefinition* fd = self.fusion_definition;
Tensor output = fd->defineTensor(input.dims);

auto diagonal_ = fd->defineScalar();
fd->defineRecord(new ScalarRecord(
{fd->recordingState(diagonal_())}, diagonal, DataType::Int, true));

fd->defineRecord(new OpRecord<TensorView*, TensorView*, Val*>(
{fd->recordingState(input()), fd->recordingState(diagonal_())},
{fd->recordingState(output())},
("ops.triu"),
serde::RecordType::Binary_TV_VAL,
static_cast<TensorView* (*)(TensorView*, Val*)>(triu)));

return output;
},
py::arg("input"),
py::arg("diagonal") = 0,
py::return_value_policy::reference,
R"doc(
Returns the upper triangular part of a 2+D tensor.

Parameters
----------
input : Tensor
The input tensor.
diagonal : int, optional
The diagonal to consider. Default is 0.

Returns
-------
Tensor
The upper triangular part of the input tensor.

>>> a = torch.randn(3, 3)
>>> fd.ops.triu(a)
)doc");

// overload to
nvf_ops.def(
"stride_order",
26 changes: 13 additions & 13 deletions csrc/runtime/executor.cpp
@@ -461,7 +461,19 @@ void KernelExecutor::compile(
}
}

kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
// TODO: pass block_size here;
std::optional<int64_t> dynamic_smem = std::nullopt;
std::optional<int64_t> block_size = std::nullopt;
if (!args.empty()) {
auto expr_eval = executor_utils::bindInputs(args, kernel);
auto launch_params = computeLaunchParams(
launch_constraints, expr_eval, warp_size_, kernel->indexType());
block_size = launch_params.nThreads();
dynamic_smem = launch_params.smem();
NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
}

kernel_code_ = codegen::generateCudaKernel(kernel, kernelName(), block_size);

// If NVFUSER_EXTERNAL_SRC is set, utilize the external source code.
// If the loaded external source code is empty, revert to the default codegen.
@@ -525,18 +537,6 @@ void KernelExecutor::compile(
NVF_THROW(ss.str());
}

// TODO: pass block_size here;
std::optional<int64_t> dynamic_smem = std::nullopt;
std::optional<int64_t> block_size = std::nullopt;
if (!args.empty()) {
auto expr_eval = executor_utils::bindInputs(args, kernel);
auto launch_params = computeLaunchParams(
launch_constraints, expr_eval, warp_size_, kernel->indexType());
block_size = launch_params.nThreads();
dynamic_smem = launch_params.smem();
NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
}

// TODO: high water mark should be computed via occupancy API after
// compilation.

3 changes: 2 additions & 1 deletion csrc/scheduler/ampere_multi_matmul.cpp
@@ -489,7 +489,8 @@ void AmpereMultipleMatmulScheduler::cacheInputsAndOutputs() {
scheduler_utils::clearMemorySpace(fusion_);

// Cache inputs
scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
scheduler_utils::cacheInputs(
fusion_, /*unroll=*/true, /*propagate_allocation=*/true);

// Cache and fork outputs
cached_outputs_ =
3 changes: 2 additions & 1 deletion csrc/scheduler/hopper_multi_matmul.cpp
@@ -95,7 +95,8 @@ void HopperMultipleMatmulScheduler::cacheInputsAndOutputs() {
scheduler_utils::clearMemorySpace(fusion_);

// Cache inputs
scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
scheduler_utils::cacheInputs(
fusion_, /*unroll=*/true, /*propagate_allocation=*/true);

// Cache and fork outputs
scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true);
13 changes: 8 additions & 5 deletions csrc/scheduler/utils.cpp
@@ -1187,7 +1187,10 @@ void clearMemorySpace(Fusion* fusion) {

// Returns cached after tensors of the fusion inputs if unrolled. Otherwise
// return empty vector.
std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
std::vector<TensorView*> cacheInputs(
Fusion* fusion,
bool unroll,
bool propagate_allocation) {
if (!unroll) {
return {};
}
@@ -1224,10 +1227,10 @@ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
}

auto cached_tv = tv->cacheAfter(
/*op_type=*/LoadStoreOpType::Set,
/*cache_op=*/CacheOp::Unspecified,
/*propagate_allocation_domain=*/true,
/*cached_uses=*/cached_uses);
LoadStoreOpType::Set,
CacheOp::Unspecified,
propagate_allocation,
cached_uses);
cached_inputs.emplace_back(cached_tv);
}
return cached_inputs;