Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] new fusion kind for nested dot fusions #23937

Merged
merged 1 commit into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions xla/backends/gpu/codegen/triton/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ TritonFusion::GenerateTritonKernelAndWrapper(
absl::string_view fusion_kind = backend_config.kind();
TritonWrapperResult triton_wrapper_result;

if (fusion_kind == kTritonFusionKind) {
if (fusion_kind == kTritonFusionKind ||
fusion_kind == kTritonNestedGemmFusionKind) {
std::optional<LaunchConfig> launch_config = this->launch_config();
if (!launch_config.has_value()) {
return absl::InvalidArgumentError(absl::StrCat(
Expand Down Expand Up @@ -145,7 +146,8 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
absl::string_view fusion_kind = backend_config.kind();

LaunchDimensions launch_dimensions;
if (fusion_kind == kTritonFusionKind) {
if (fusion_kind == kTritonFusionKind ||
fusion_kind == kTritonNestedGemmFusionKind) {
std::optional<LaunchConfig> launch_config = this->launch_config();
// This check should be enforced by `GenerateTritonKernelWrapper`.
CHECK(launch_config.has_value());
Expand Down
3 changes: 2 additions & 1 deletion xla/backends/gpu/codegen/triton/fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,8 @@ absl::StatusOr<TritonModule> CreateTritonModule(
TF_ASSIGN_OR_RETURN(tma_metadata,
EmitMatMul(b, libdevice_path, device_info, fusion, fn,
block_level_parameters));
} else if (fusion_kind == kTritonFusionKind) {
} else if (fusion_kind == kTritonFusionKind ||
fusion_kind == kTritonNestedGemmFusionKind) {
TF_RETURN_IF_ERROR(EmitGeneric(b, libdevice_path, device_info, fusion, fn,
block_level_parameters));
} else {
Expand Down
18 changes: 9 additions & 9 deletions xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1784,14 +1784,14 @@ fdot {
fdot.p1 = f32[16,16] parameter(1)
fdot.lhs = f32[16,16] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "16"]}]
}
}
}
fdot.rhs = f32[16,16]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "16"]}]
}
}
Expand All @@ -1806,7 +1806,7 @@ ENTRY entry {
ROOT fusion = f32[16,16] fusion(entry.p0, entry.p1),
kind=kCustom, calls=fdot, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "16"]}], "num_warps":"1"
}
}
Expand Down Expand Up @@ -1837,14 +1837,14 @@ fdot {
fdot.p1 = f32[256,512] parameter(1)
fdot.lhs = f32[32,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "32"]}]
}
}
}
fdot.rhs = f32[256,512]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["32", "64"]}]
}
}
Expand All @@ -1859,7 +1859,7 @@ ENTRY entry {
ROOT fusion = f32[32,512] fusion(entry.p0, entry.p1),
kind=kCustom, calls=fdot, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "64"]}], "num_warps":"1"
}
}
Expand Down Expand Up @@ -1903,14 +1903,14 @@ fdot {
fdot.p1 = f32[299,512] parameter(1)
fdot.lhs = f32[32,299] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "32"]}]
}
}
}
fdot.rhs = f32[299,512]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["32", "64"]}]
}
}
Expand All @@ -1925,7 +1925,7 @@ ENTRY entry {
ROOT fusion = f32[32,512] fusion(entry.p0, entry.p1),
kind=kCustom, calls=fdot, backend_config={
"fusion_backend_config":{
"kind":"__triton", "block_level_fusion_config":{
"kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{
"output_tiles":[{"sizes":["16", "64"]}], "num_warps":"1"
}
}
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/backend_configs.proto
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ message FusionBackendConfig {
// present, we use the default Triton config.
AutotuneResult.TritonGemmKey triton_gemm_config = 2;

// Only valid when kind == "__triton" for now. Code generation of such
// fusions will fail if this field is not set.
// Only valid when kind is "__triton" or "__triton_nested_gemm_fusion". Code
// generation of such fusions will fail if this field is not set.
BlockLevelFusionConfig block_level_fusion_config = 6;

// Only valid when kind == "__custom_fusion".
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/gpu_fusible.cc
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,8 @@ std::vector<const HloInstruction*> GetFusionRoots(
}

bool IsGenericTritonFusion(const HloInstruction& instr) {
// Note that we don't accept kTritonNestedGemmFusionKind here as they should
// not be fused with anything else.
return instr.opcode() == HloOpcode::kFusion &&
instr.fusion_kind() == HloInstruction::FusionKind::kCustom &&
instr.backend_config<GpuBackendConfig>().ok() &&
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/hlo_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
}

if (fusion_backend_config_.kind() == kTritonFusionKind ||
fusion_backend_config_.kind() == kTritonGemmFusionKind) {
fusion_backend_config_.kind() == kTritonGemmFusionKind ||
fusion_backend_config_.kind() == kTritonNestedGemmFusionKind) {
return EmitterFusionKind::kTriton;
}

Expand Down
7 changes: 6 additions & 1 deletion xla/service/gpu/ir_emission_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,17 @@ inline constexpr absl::string_view kCustomFusionKind = "__custom_fusion";

// Generic fusions that use Triton have FusionBackendConfig.kind equal to this
// string. This fusion kind will eventually subsume all usages of
// kTritonGemmFusionKind and kTritonSoftmaxFusionKind.
// kTritonGemmFusionKind.
inline constexpr absl::string_view kTritonFusionKind = "__triton";

// Fusions that use Triton have FusionBackendConfig.kind equal to this string.
inline constexpr absl::string_view kTritonGemmFusionKind = "__triton_gemm";

// Generic fusions that use Triton have FusionBackendConfig.kind equal to this
// string. Used for fusions that implement a dot expressed as nested fusions.
inline constexpr absl::string_view kTritonNestedGemmFusionKind =
"__triton_nested_gemm_fusion";

inline constexpr absl::string_view kCuDnnFusionKind = "__cudnn$fusion";

// Fusions that can be emitted using a dynamic memcpy. A dynamic memcpy depends
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/transforms/nest_gemm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ absl::Status FuseInstructionsForConsumer(
fusion->backend_config<GpuBackendConfig>());
FusionBackendConfig& backend_config =
*gpu_config.mutable_fusion_backend_config();
backend_config.set_kind(std::string(kTritonFusionKind));
backend_config.set_kind(std::string(kTritonNestedGemmFusionKind));
TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config));

for (int64_t operand_index : consumer.OperandIndices(old_root)) {
Expand Down Expand Up @@ -317,7 +317,7 @@ absl::Status MakeNestedFusionFromGemmFusion(HloFusionInstruction* fusion,
FusionBackendConfig& backend_config =
*gpu_config.mutable_fusion_backend_config();
backend_config.clear_triton_gemm_config();
backend_config.set_kind(std::string(kTritonFusionKind));
backend_config.set_kind(std::string(kTritonNestedGemmFusionKind));

BlockLevelParameters block_level_parameters;
block_level_parameters.output_tile_sizes = {
Expand Down
8 changes: 8 additions & 0 deletions xla/service/gpu/transforms/nest_gemm_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ namespace xla::gpu {
// Rewrites Triton GEMM fusions to generic Triton fusions. Any other fusions are
// left unchanged.
//
// Fusions with kind kCustom and fusion_backend_config.kind "__triton_gemm" are
// rewritten to fusion_backend_config.kind
// "__triton_nested_gemm_fusion".
//
// While this new fusion kind is supported by the generic Triton emitter, we
// want to distinguish it from "__triton" as we don't want other passes to
// modify the resulting fusions.
//
// The fusion's backend config is set to a BlockLevelFusionConfig, derived from
// a previously set TritonGemmConfig.
//
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/transforms/nest_gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ MATCHER_P(OutputTileSizesIs, matcher, "") {
*result_listener << "has no block level fusion config";
return false;
}
if (fusion_backend_config.kind() != "__triton_nested_gemm_fusion") {
*result_listener << "fusion kind is not __triton_nested_gemm_fusion";
return false;
}
auto output_tile_sizes =
fusion_backend_config.block_level_fusion_config().output_tiles(0).sizes();
return ExplainMatchResult(matcher, output_tile_sizes, result_listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ absl::StatusOr<const HloFusionInstruction*> AsTritonFusion(
fusion->backend_config<GpuBackendConfig>());
const FusionBackendConfig& backend_config =
gpu_config.fusion_backend_config();
if (backend_config.kind() == kTritonFusionKind) {
if (backend_config.kind() == kTritonFusionKind ||
backend_config.kind() == kTritonNestedGemmFusionKind) {
return fusion;
}
return nullptr;
Expand Down
Loading