Skip to content

CUDA Packet Scatter Reduce of f16x2 #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from

Conversation

DoeringChristian
Copy link
Contributor

This PR enables us to render a specialized version of the PacketScatter op in CUDA when a reduction operation is specified, with f16 types. It uses the red.global.add.noftz.f16x2 instruction.

@wjakob
Copy link
Member

wjakob commented Jun 11, 2025

Potentially relevant: sm_90+ supports red.global.add.v2.f16x2 and red.global.add.v4.f16x2 (so 4 or 8 FP16 accumulations in one instruction). This should be available on RTX5090 and similar Blackwell GPUs.

fmt(" @$v ", mask);
else
put(" ");
fmt("red.global.$s.noftz.f16x2 [%rd3+$u], %tmp;\n", op,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this works when op != ".add"?
Looking carefully at the PTX docs for red, it seems that non-vector variants of red only support add for f16 and f16x2. Do you agree or am I misreading it?

If there are only 1 or 2 values left to reduce (e.g. packets of size 2 , 6, 10, etc), I don't think we could use the vector variants?

void jitc_cuda_render_scatter_packet(const Variable *v, const Variable *ptr,
const Variable *index, const Variable *mask) {
bool is_masked = !mask->is_literal() || mask->literal != 1;
PacketScatterData *psd = (PacketScatterData *) v->data;
const std::vector<uint32_t> &values = psd->values;
const Variable *v0 = jitc_var(values[0]);

// Handle non-Identitiy reduction case
if (psd->op != ReduceOp::Identity){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting

if (v0->type != (uint32_t)VarType::Float16)
jitc_fail("Packeted scatter reductions are only supported with f16 "
"variables.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this check earlier + include the name of the type you actually got in the error message (type_name[v0->type] I think)

@@ -15,6 +15,10 @@
#include "op.h"
#include "log.h"

static const char *reduce_op_name[(int) ReduceOp::Count] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to not re-define this? (vs cuda_scatter.cpp).
Having it duplicated in separate locations makes it error-prone when the enum changes.
Although I see it's already defined another time in llvm_scatter.cpp, so maybe it's not easy to have it only once.

const Variable *mask) {
bool is_masked = !mask->is_literal() || mask->literal != 1;
PacketScatterData *psd = (PacketScatterData *) v->data;
const std::vector<uint32_t> &values = psd->values;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this could use dr::vector?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine to use std::vector. I just prefer dr::vector in header files, so that we don' t have to pull in STL code everywhere.

@DoeringChristian DoeringChristian force-pushed the scatter-reduce-f16x2 branch 5 times, most recently from 2e4993e to 2620cc7 Compare June 18, 2025 14:17
@DoeringChristian DoeringChristian force-pushed the scatter-reduce-f16x2 branch 5 times, most recently from bd67bc4 to 27702ea Compare July 9, 2025 15:18
…packed functions

Specialization for vector scatter reductions on sm_90

Improved failure messages

Removed include of llvm in cuda

Cleanup packet scatter reduce

Improved packet scatter reduce gating
@DoeringChristian DoeringChristian marked this pull request as ready for review July 18, 2025 07:09
Copy link
Member

@wjakob wjakob left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool, some feedback from me. The first two are from an earlier partial review and may no longer apply. (It says the files are "Outdated").

byte_offset);
}
} else {
jitc_fail("jitc_cuda_render_scatter_reduce_packet(): Number of elements not supported for reduction.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this cause some existing reductions to fail, or does the logic in drjit-core preclude this case from being reachable?

src/op.cpp Outdated
@@ -2639,6 +2639,11 @@ uint32_t jitc_var_scatter_packet(size_t n, uint32_t target_,
mode == ReduceMode::NoConflicts ||
(mode == ReduceMode::Auto &&
target_info.size <= llvm_expand_threshold));
} else if (op == ReduceOp::Add && backend == JitBackend::CUDA) {
use_packet_op = (mode == ReduceMode::Expand ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have ReduceMode::Expand in CUDA.

const Variable *mask) {
bool is_masked = !mask->is_literal() || mask->literal != 1;
PacketScatterData *psd = (PacketScatterData *) v->data;
const std::vector<uint32_t> &values = psd->values;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine to use std::vector. I just prefer dr::vector in header files, so that we don' t have to pull in STL code everywhere.

"elements not supported for reduction.");

if (ts->compute_capability >= 90) {
// Use the new `red.global.v2` instructions. This enables both min & max
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment up-to-date? .v2 is only two elements, and we're doing wider ones AFAIK.

@@ -1905,11 +1905,11 @@ void jitc_var_gather_packet(size_t n, uint32_t src_, uint32_t index, uint32_t ma
auto [var_info, index_v, mask_v] =
jitc_var_check("jit_var_gather_packet", index, mask);

if ((n & (n-1)) || n == 1)
if (n == 1)
jitc_raise("jitc_var_gather_packet(): vector size must be a power of two "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error message seems out of date now.

jitc_raise("jitc_var_gather_packet(): vector size must be a power of two "
"and >= 1 (got %zu)!", n);

if ((src_info.size & (n-1)) != 0 && src_info.size != 1)
if (src_info.size % 2 != 0 && src_info.size != 1)
jitc_raise("jitc_var_gather_packet(): source r%u has size %u, which is not "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error message seems out of date now.

src/op.cpp Outdated
}
}

// If the packet size is not divisible by two we cannot use packet ops.
use_packet_op = use_packet_op && n > 1 && n % 2 == 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we move the logic to code generation, perhaps it's easier to not even have special handling for the n=1 case here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants