diff --git a/xla/backends/gpu/codegen/triton/fusion.cc b/xla/backends/gpu/codegen/triton/fusion.cc index d54187c55fe26..643f103ced798 100644 --- a/xla/backends/gpu/codegen/triton/fusion.cc +++ b/xla/backends/gpu/codegen/triton/fusion.cc @@ -75,7 +75,8 @@ TritonFusion::GenerateTritonKernelAndWrapper( absl::string_view fusion_kind = backend_config.kind(); TritonWrapperResult triton_wrapper_result; - if (fusion_kind == kTritonFusionKind) { + if (fusion_kind == kTritonFusionKind || + fusion_kind == kTritonNestedGemmFusionKind) { std::optional launch_config = this->launch_config(); if (!launch_config.has_value()) { return absl::InvalidArgumentError(absl::StrCat( @@ -145,7 +146,8 @@ absl::StatusOr TritonFusion::Emit( absl::string_view fusion_kind = backend_config.kind(); LaunchDimensions launch_dimensions; - if (fusion_kind == kTritonFusionKind) { + if (fusion_kind == kTritonFusionKind || + fusion_kind == kTritonNestedGemmFusionKind) { std::optional launch_config = this->launch_config(); // This check should be enforced by `GenerateTritonKernelWrapper`. CHECK(launch_config.has_value()); diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/xla/backends/gpu/codegen/triton/fusion_emitter.cc index 02644b9dc4733..410916d329f9b 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1444,7 +1444,8 @@ absl::StatusOr CreateTritonModule( TF_ASSIGN_OR_RETURN(tma_metadata, EmitMatMul(b, libdevice_path, device_info, fusion, fn, block_level_parameters)); - } else if (fusion_kind == kTritonFusionKind) { + } else if (fusion_kind == kTritonFusionKind || + fusion_kind == kTritonNestedGemmFusionKind) { TF_RETURN_IF_ERROR(EmitGeneric(b, libdevice_path, device_info, fusion, fn, block_level_parameters)); } else { diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index b5ca33fa4f48a..4313c9d9d6fe0 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -1784,14 +1784,14 @@ fdot { fdot.p1 = f32[16,16] parameter(1) fdot.lhs = f32[16,16] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "16"]}] } } } fdot.rhs = f32[16,16]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "16"]}] } } @@ -1806,7 +1806,7 @@ ENTRY entry { ROOT fusion = f32[16,16] fusion(entry.p0, entry.p1), kind=kCustom, calls=fdot, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "16"]}], "num_warps":"1" } } @@ -1837,14 +1837,14 @@ fdot { fdot.p1 = f32[256,512] parameter(1) fdot.lhs = f32[32,256] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "32"]}] } } } fdot.rhs = f32[256,512]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["32", "64"]}] } } @@ -1859,7 +1859,7 @@ ENTRY entry { ROOT fusion = f32[32,512] fusion(entry.p0, entry.p1), kind=kCustom, calls=fdot, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "64"]}], "num_warps":"1" } } @@ -1903,14 +1903,14 @@ fdot { fdot.p1 = f32[299,512] parameter(1) fdot.lhs = f32[32,299] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "32"]}] } } } fdot.rhs = f32[299,512]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["32", "64"]}] } } @@ -1925,7 +1925,7 @@ ENTRY entry { ROOT fusion = f32[32,512] fusion(entry.p0, entry.p1), kind=kCustom, calls=fdot, backend_config={ "fusion_backend_config":{ - "kind":"__triton", "block_level_fusion_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ "output_tiles":[{"sizes":["16", "64"]}], "num_warps":"1" } } diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index b3e1205134ebd..62c666fc94682 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto @@ -201,8 +201,8 @@ message FusionBackendConfig { // present, we use the default Triton config. AutotuneResult.TritonGemmKey triton_gemm_config = 2; - // Only valid when kind == "__triton" for now. Code generation of such - // fusions will fail if this field is not set. + // Only valid when kind is "__triton" or "__triton_nested_fusion_gemm". Code + // generation of such fusions will fail if this field is not set. BlockLevelFusionConfig block_level_fusion_config = 6; // Only valid when kind == "__custom_fusion". diff --git a/xla/service/gpu/gpu_fusible.cc b/xla/service/gpu/gpu_fusible.cc index 8140177db4887..3644f02667612 100644 --- a/xla/service/gpu/gpu_fusible.cc +++ b/xla/service/gpu/gpu_fusible.cc @@ -819,6 +819,8 @@ std::vector GetFusionRoots( } bool IsGenericTritonFusion(const HloInstruction& instr) { + // Note that we don't accept kTritonNestedGemmFusionKind here as they should + // not be fused with anything else. return instr.opcode() == HloOpcode::kFusion && instr.fusion_kind() == HloInstruction::FusionKind::kCustom && instr.backend_config().ok() && diff --git a/xla/service/gpu/hlo_fusion_analysis.cc b/xla/service/gpu/hlo_fusion_analysis.cc index 40a0d1117ca3f..2aa6996a3694b 100644 --- a/xla/service/gpu/hlo_fusion_analysis.cc +++ b/xla/service/gpu/hlo_fusion_analysis.cc @@ -233,7 +233,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() } if (fusion_backend_config_.kind() == kTritonFusionKind || - fusion_backend_config_.kind() == kTritonGemmFusionKind) { + fusion_backend_config_.kind() == kTritonGemmFusionKind || + fusion_backend_config_.kind() == kTritonNestedGemmFusionKind) { return EmitterFusionKind::kTriton; } diff --git a/xla/service/gpu/ir_emission_utils.h b/xla/service/gpu/ir_emission_utils.h index 0677ee6bdac01..40d5024adfe99 100644 --- a/xla/service/gpu/ir_emission_utils.h +++ b/xla/service/gpu/ir_emission_utils.h @@ -75,12 +75,17 @@ inline constexpr absl::string_view kCustomFusionKind = "__custom_fusion"; // Generic fusions that use Triton have FusionBackendConfig.kind equal to this // string. This fusion kind will eventually subsume all usages of -// kTritonGemmFusionKind and kTritonSoftmaxFusionKind. +// kTritonGemmFusionKind. inline constexpr absl::string_view kTritonFusionKind = "__triton"; // Fusions that use Triton have FusionBackendConfig.kind equal to this string. inline constexpr absl::string_view kTritonGemmFusionKind = "__triton_gemm"; +// Generic fusions that use Triton have FusionBackendConfig.kind equal to this +// string. Used for fusions that implement a dot expressed as nested fusions. +inline constexpr absl::string_view kTritonNestedGemmFusionKind = + "__triton_nested_gemm_fusion"; + inline constexpr absl::string_view kCuDnnFusionKind = "__cudnn$fusion"; // Fusions that can be emitted using a dynamic memcpy. A dynamic memcpy depends diff --git a/xla/service/gpu/transforms/nest_gemm_fusion.cc b/xla/service/gpu/transforms/nest_gemm_fusion.cc index e217e4ab494bc..9a6ccc584db85 100644 --- a/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -118,7 +118,7 @@ absl::Status FuseInstructionsForConsumer( fusion->backend_config()); FusionBackendConfig& backend_config = *gpu_config.mutable_fusion_backend_config(); - backend_config.set_kind(std::string(kTritonFusionKind)); + backend_config.set_kind(std::string(kTritonNestedGemmFusionKind)); TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); for (int64_t operand_index : consumer.OperandIndices(old_root)) { @@ -317,7 +317,7 @@ absl::Status MakeNestedFusionFromGemmFusion(HloFusionInstruction* fusion, FusionBackendConfig& backend_config = *gpu_config.mutable_fusion_backend_config(); backend_config.clear_triton_gemm_config(); - backend_config.set_kind(std::string(kTritonFusionKind)); + backend_config.set_kind(std::string(kTritonNestedGemmFusionKind)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = { diff --git a/xla/service/gpu/transforms/nest_gemm_fusion.h b/xla/service/gpu/transforms/nest_gemm_fusion.h index aee2ece23afd3..134810c1600b6 100644 --- a/xla/service/gpu/transforms/nest_gemm_fusion.h +++ b/xla/service/gpu/transforms/nest_gemm_fusion.h @@ -27,6 +27,14 @@ namespace xla::gpu { // Rewrites Triton GEMM fusions to generic Triton fusions. Any other fusions are // left unchanged. // +// Fusions with kind kCustom and fusion_backend_config.kind "__triton_gemm" are +// rewritten to fusion_backend_config.kind +// "__triton_nested_fusion_gemm". +// +// While this new fusion kind is supported by generic triton emitter we want +// to distinguish it from "__triton" as we don't want other passes to modify the +// resulting fusions. +// // The fusion's backend config is set to a BlockLevelFusionConfig, derived from // a previously set TritonGemmConfig. // diff --git a/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/xla/service/gpu/transforms/nest_gemm_fusion_test.cc index 1898713fb60ab..4fbd8ab09bebe 100644 --- a/xla/service/gpu/transforms/nest_gemm_fusion_test.cc +++ b/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -56,6 +56,10 @@ MATCHER_P(OutputTileSizesIs, matcher, "") { *result_listener << "has no block level fusion config"; return false; } + if (fusion_backend_config.kind() != "__triton_nested_gemm_fusion") { + *result_listener << "fusion kind is not __triton_nested_gemm_fusion"; + return false; + } auto output_tile_sizes = fusion_backend_config.block_level_fusion_config().output_tiles(0).sizes(); return ExplainMatchResult(matcher, output_tile_sizes, result_listener); diff --git a/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index 01f6c891a48cc..04e4da53f0fc0 100644 --- a/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc +++ b/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -72,7 +72,8 @@ absl::StatusOr AsTritonFusion( fusion->backend_config()); const FusionBackendConfig& backend_config = gpu_config.fusion_backend_config(); - if (backend_config.kind() == kTritonFusionKind) { + if (backend_config.kind() == kTritonFusionKind || + backend_config.kind() == kTritonNestedGemmFusionKind) { return fusion; } return nullptr;