Add register sharing to warp-specialized circular buffering (#3669)
This PR implements register sharing for warp-specialized circular
buffering. Registers are moved from the load warp group to the compute
warp group with the `setmaxnreg` PTX instruction. It is an optimization
for matmul kernels.

## Changes
1. Add `__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/)` to the CUDA kernel
declaration.
2. Add `kir::SetMaxNReg` and `kir::Return` nodes to warp-specialized
circular buffering.
3. `TensorView::circularBuffer` now accepts the number of registers for the load
and compute warp groups through `struct WarpSpecialized` (see the sketch after
this list).
4. Require the Hopper architecture (9.0 or newer) for TensorViews using
warp-specialized circular buffering.
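
The snippet below is a minimal sketch of how a schedule could opt into register
sharing. The `WarpSpecialized(ParallelType, std::pair<int64_t, int64_t>)`
constructor is added by this PR; the exact `TensorView::circularBuffer`
overload, the stage/prefetch values, and the header path are assumptions made
for illustration only.

```cpp
// Illustrative sketch: header path and circularBuffer overload are assumed.
#include <ir/interface_nodes.h>

void scheduleSmemOperandWithRegisterSharing(TensorView* smem_tv) {
  // Specialize the async-load warps on TIDy. The load warp group drops to 24
  // registers and the compute warp group grows to 240; both values must be
  // multiples of 8 in [24, 256], with load <= compute (checked by
  // WarpSpecialized::validateRegisterSharing()).
  WarpSpecialized ws(ParallelType::TIDy, {24, 240});

  // Circular-buffer the shared-memory operand. With register sharing enabled,
  // the circular buffer loop must be the outer-most for-loop.
  smem_tv->circularBuffer(/*number_of_stages=*/4, /*prefetch_distance=*/1, ws);
}
```

The 24/240 split matches the generated code shown below and satisfies the
constraints enforced in `validateRegisterSharing()`.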

## Why is `__launch_bounds__` necessary?

> The setmaxnreg instruction requires that the kernel has been launched
with a valid value of maximum number of per-thread registers specified
via the appropriate compile-time option or the appropriate performance
tuning directive. Otherwise, the setmaxnreg instruction may have no
effect.

From
https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

## Generated Code
```cuda
__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/64) 
  nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0,
                                  const __grid_constant__ TensorMap var0,
                                  Tensor<float, 1, 1> T1) {
  // do something
  if ((((nvfuser_index_t)threadIdx.y) == 1)) {
    asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n"::"n"(24));
    // load something
    return;
  } else {
    asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n"::"n"(240));
    // compute something
  }
  // do something
}
```
rdspring1 authored Jan 9, 2025
1 parent 23ab1ad commit 72402bf
Showing 9 changed files with 297 additions and 24 deletions.
23 changes: 18 additions & 5 deletions csrc/codegen.cpp
@@ -158,9 +158,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
public:
static std::string generateKernelDefinition(
const kir::Kernel* kernel,
const std::string& kernel_name) {
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
CudaKernelGenerator codegen(kernel);
codegen.genDeclaration(kernel_name);
codegen.genDeclaration(kernel_name, num_threads_per_cta);
codegen.startBlock();
codegen.genPrologue();
codegen.genBody();
@@ -272,8 +273,18 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

// Generates the kernel function declaration
void genDeclaration(const std::string& kernel_name) {
void genDeclaration(
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
code_ << "__global__ void ";
if (kernel_->hasManaged("enable_register_sharing") &&
kernel_->getManaged<bool>("enable_register_sharing")) {
NVF_ERROR(
num_threads_per_cta.has_value(),
"__launch_bounds__ must be set for register sharing warp specialization");
code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
<< num_threads_per_cta.value() << ") ";
}
if (kernel_->hasManaged("cluster_dims")) {
auto cluster_dims =
kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
@@ -3542,9 +3553,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

std::string generateCudaKernel(
const kir::Kernel* kernel,
const std::string& kernel_name) {
const std::string& kernel_name,
std::optional<int64_t> num_threads_per_cta) {
FUSER_PERF_SCOPE("generateCudaKernel");
return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
return CudaKernelGenerator::generateKernelDefinition(
kernel, kernel_name, num_threads_per_cta);
}

} // namespace codegen
3 changes: 2 additions & 1 deletion csrc/codegen.h
@@ -19,7 +19,8 @@ namespace codegen {
//! Generates a CUDA kernel definition for the given kernel
NVF_API std::string generateCudaKernel(
const kir::Kernel* kernel,
const std::string& kernel_name = "CUDAGeneratedKernel");
const std::string& kernel_name = "CUDAGeneratedKernel",
std::optional<int64_t> num_threads_per_cta = std::nullopt);

} // namespace codegen
} // namespace nvfuser
23 changes: 23 additions & 0 deletions csrc/device_lower/analysis/circular_buffer.cpp
@@ -143,6 +143,29 @@ void validateCircularBufferedTensor(const TensorView* tv) {
". Consumer memory type: ",
c_mem_type);

// Ensure that the warp-specialized circular buffer loop is the outer-most
// for-loop if register sharing is enabled.
if (std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type) &&
std::get<WarpSpecialized>(tv->circularBufferOptions().type)
.num_registers.has_value()) {
for (int64_t axis : c10::irange((int64_t)tv->getLoopDomain().size())) {
// short-circuit: only check IterDomains to the left of the circular
// buffer position
if (axis >= circular_buffer_pos) {
break;
}
NVF_ERROR(
tv->getLoopDomain().at(axis)->isThread() ||
tv->getLoopDomain().at(axis)->isDeviceDim() ||
tv->getLoopDomain().at(axis)->isBroadcast() ||
tv->getLoopDomain().at(axis)->isOneInt(),
"When using register sharing with warp-specialized circular "
"buffering, the circular buffer loop must be the outer-most "
"for-loop.");
}
}

return;
}

9 changes: 9 additions & 0 deletions csrc/device_lower/analysis/device_version.cpp
@@ -69,6 +69,15 @@ void MinimumDeviceVersion::handle(LoadStoreOp* ls_op) {
}
}

void MinimumDeviceVersion::handle(TensorView* tv) {
if (std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type)) {
ensureVersion(
{9, 0},
"Warp Specialized Circular Buffering uses the setmaxnreg ptx instruction, which requires Hopper (9.0) or newer");
}
}

void MinimumDeviceVersion::ensureVersion(
std::pair<int, int> version,
std::string reason) {
5 changes: 5 additions & 0 deletions csrc/device_lower/analysis/device_version.h
@@ -48,6 +48,11 @@ class MinimumDeviceVersion : private IterVisitor {
//! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
void handle(LoadStoreOp* ls_op) final;

//! If TensorView has warp specialized circular buffering, it will use the
//! setmaxnreg ptx instruction that requires Hopper (9.0+).
//! https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
void handle(TensorView* tv) final;

//! bump min_version_ to at least this value
void ensureVersion(std::pair<int, int> version, std::string reason);

37 changes: 37 additions & 0 deletions csrc/device_lower/pass/circular_buffer.cpp
@@ -1394,11 +1394,48 @@ class CircularBufferInserter : private kir::ExprMutator {
warp_specialize_on),
circular_buffer_loop->fusion()->oneVal()))));

// Set default value
auto& circular_buffer_options =
GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
circular_buffer_loop->iter_domain());
bool enable_register_sharing =
std::holds_alternative<WarpSpecialized>(circular_buffer_options.type) &&
std::get<WarpSpecialized>(circular_buffer_options.type)
.num_registers.has_value();

GpuLower::current()->kernel()->manage(
"enable_register_sharing", enable_register_sharing);

if (enable_register_sharing) {
auto&& [decrease_num_registers, increase_num_registers] =
std::get<WarpSpecialized>(circular_buffer_options.type)
.num_registers.value();

// Decrease registers in load warp group
kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(decrease_num_registers, DataType::Index),
/*increase_registers=*/false);
warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);

// Increase registers in compute warp group
kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(increase_num_registers, DataType::Index),
/*increase_registers*/ true);
warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp);
}

// Load loop:
ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
warp_dispatch_ite->thenBody().push_back(load_loop);

if (enable_register_sharing) {
// Terminate the warp group handling Load loop immediately after
// finishing its work.
kir::Return* ret = IrBuilder::create<kir::Return>();
warp_dispatch_ite->thenBody().push_back(ret);
}

// Prefetch:
auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
warp_dispatch_ite->elseBody().push_back(prefetch_loop);
46 changes: 42 additions & 4 deletions csrc/ir/interface_nodes.h
@@ -227,11 +227,41 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
}

struct WarpSpecialized {
ParallelType on;
explicit WarpSpecialized(ParallelType on) : on(on) {}
ParallelType on = ParallelType::Serial;
// The number of registers for load and compute warps respectively.
std::optional<std::pair<int64_t, int64_t>> num_registers = std::nullopt;

explicit WarpSpecialized(
ParallelType on,
std::pair<int64_t, int64_t> num_registers)
: on(on), num_registers(num_registers) {
validateRegisterSharing();
}
explicit WarpSpecialized(ParallelType on)
: on(on), num_registers(std::nullopt) {}
WarpSpecialized() = default;

void validateRegisterSharing() {
// short-circuit: register sharing is not used.
if (!num_registers.has_value()) {
return;
}
auto validate_num_registers = [](int64_t a) {
NVF_ERROR(
a >= 24 && a <= 256 && a % 8 == 0,
"The number of registers for setmaxnreg must be between 24 and",
" 256 (inclusive) and be a multiple of 8.");
};
validate_num_registers(num_registers.value().first);
validate_num_registers(num_registers.value().second);
NVF_ERROR(
num_registers.value().first <= num_registers.value().second,
"The number of registers for load warp group must be <= to the number",
" of registers for the compute warp groups.");
}

bool operator==(const WarpSpecialized& other) const {
return on == other.on;
return on == other.on && num_registers == other.num_registers;
}
};

@@ -252,7 +282,15 @@ inline std::ostream& operator<<(
default:
NVF_THROW("Invalid parallel type");
}
return os << "WarpSpecializedOn" << parallel_type_str;
std::string num_registers = "RegisterSharing_None";
if (warp_specialized.num_registers.has_value()) {
auto&& [decrease_num_reg, increase_num_reg] =
warp_specialized.num_registers.value();
std::stringstream s;
s << "RegisterSharing_" << decrease_num_reg << "_" << increase_num_reg;
num_registers = s.str();
}
return os << "WarpSpecializedOn" << parallel_type_str << num_registers;
}

using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
26 changes: 13 additions & 13 deletions csrc/runtime/executor.cpp
@@ -461,7 +461,19 @@ void KernelExecutor::compile(
}
}

kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
// TODO: pass block_size here;
std::optional<int64_t> dynamic_smem = std::nullopt;
std::optional<int64_t> block_size = std::nullopt;
if (!args.empty()) {
auto expr_eval = executor_utils::bindInputs(args, kernel);
auto launch_params = computeLaunchParams(
launch_constraints, expr_eval, warp_size_, kernel->indexType());
block_size = launch_params.nThreads();
dynamic_smem = launch_params.smem();
NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
}

kernel_code_ = codegen::generateCudaKernel(kernel, kernelName(), block_size);

// If NVFUSER_EXTERNAL_SRC is set, utilize the external source code.
// If the loaded external source code is empty, revert to the default codegen.
@@ -525,18 +537,6 @@ void KernelExecutor::compile(
NVF_THROW(ss.str());
}

// TODO: pass block_size here;
std::optional<int64_t> dynamic_smem = std::nullopt;
std::optional<int64_t> block_size = std::nullopt;
if (!args.empty()) {
auto expr_eval = executor_utils::bindInputs(args, kernel);
auto launch_params = computeLaunchParams(
launch_constraints, expr_eval, warp_size_, kernel->indexType());
block_size = launch_params.nThreads();
dynamic_smem = launch_params.smem();
NVF_ERROR(block_size > 0, "launch param inferred block size < 0");
}

// TODO: high water mark should be computed via occupancy API after
// compilation.

Expand Down
Loading

0 comments on commit 72402bf

Please sign in to comment.