diff --git a/xla/service/call_inliner.cc b/xla/service/call_inliner.cc index 579e944270269..bf6777637b803 100644 --- a/xla/service/call_inliner.cc +++ b/xla/service/call_inliner.cc @@ -185,19 +185,6 @@ bool InlineInstruction(HloInstruction* instruction) { return true; } -bool InlineStreamAnnotation(HloInstruction* instruction) { - if (instruction->GetModule() - ->config() - .debug_options() - .xla_gpu_experimental_stream_annotation()) { - if (instruction->frontend_attributes().map().contains( - kXlaStreamAnnotationAttr)) { - return false; - } - } - return true; -} - } // namespace /* static */ absl::StatusOr @@ -247,12 +234,19 @@ CallInliner::Inline(HloInstruction* call) { } bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kCall && - !instruction->has_backend_config() && - !instruction->parent()->IsAsyncComputation() && - InlineInstruction(instruction) && InlineUnderShardy(instruction) && - InlineComposites(instruction, composites_to_preserve_) && - InlineStreamAnnotation(instruction); + bool prerequisite = instruction->opcode() == HloOpcode::kCall && + !instruction->has_backend_config() && + !instruction->parent()->IsAsyncComputation(); + if (!prerequisite) { + return false; + } + if (!InlineInstruction(instruction)) { + // Always prioritize user's explicit requests after fulfilling the + // prerequisites. + return false; + } + return InlineUnderShardy(instruction) && + InlineComposites(instruction, composites_to_preserve_); } absl::StatusOr CallInliner::Run( diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index 3c3b99b78279a..6995c7a971dfe 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -522,59 +522,6 @@ TEST_F(CallInlinerTest, UseShardManualComputationBodySurroundedNotInlined) { "my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234"); } -TEST_F(CallInlinerTest, DontInlineStreamAnnotationCall) { - const absl::string_view hlo_string = R"( - HloModule composite - - %add (lhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] constant(2) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) - } - - %sub (lhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] constant(1) - ROOT %sub = f32[] subtract(f32[] %lhs, f32[] %rhs) - } - - ENTRY %main () -> f32[] { - %lhs = f32[] constant(42) - %call1 = f32[] call(f32[] %lhs), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} - ROOT %call2 = f32[] call(f32[] %call1), to_apply=%add - })"; - - auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_experimental_stream_annotation(true); - auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - module->mutable_config().set_debug_options(debug_options); - CallInliner call_inliner(/*single_call_site=*/true); - - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); - absl::StatusOr filecheck_result = RunFileCheck(module->ToString({}), R"( - //CHECK: %lhs.2 = f32[] constant(42) - //CHECK: %call1 = f32[] call(%lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} - //CHECK: %rhs.2 = f32[] constant(2) - //CHECK: ROOT %add.1 = f32[] add(%call1, %rhs.2) - )"); - TF_ASSERT_OK(filecheck_result.status()); - EXPECT_TRUE(*filecheck_result); - - ASSERT_TRUE(mutated); - ASSERT_EQ(module->entry_computation()->instruction_count(), 4); - auto inst = module->entry_computation()->instructions().begin(); - EXPECT_THAT(*inst, op::Constant()); - // Check that the annotated call isn't inlined - ++inst; - EXPECT_THAT(*inst, op::Call()); - - // Check that the non-annotated call is still inlined - ++inst; - EXPECT_THAT(*inst, op::Constant()); - ++inst; - EXPECT_THAT(*inst, op::Add()); -} - TEST_F(CallInlinerTest, ControlDepsPropagateToRootOfInlinedInstructions) { const char* hlo = R"( HloModule test