Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Cleans up call inliner in the XLA shared code path #23964

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,6 @@ bool InlineInstruction(HloInstruction* instruction) {
return true;
}

// Decides whether stream annotations permit inlining `instruction`.
//
// Returns false only when the module's debug option
// `xla_gpu_experimental_stream_annotation` is set AND the instruction's
// frontend attributes contain `kXlaStreamAnnotationAttr`. Returns true in
// every other case.
bool InlineStreamAnnotation(HloInstruction* instruction) {
  const bool stream_annotation_enabled =
      instruction->GetModule()
          ->config()
          .debug_options()
          .xla_gpu_experimental_stream_annotation();
  // With the option off, the attribute is never consulted.
  if (!stream_annotation_enabled) {
    return true;
  }
  const bool carries_annotation =
      instruction->frontend_attributes().map().contains(
          kXlaStreamAnnotationAttr);
  return !carries_annotation;
}

} // namespace

/* static */ absl::StatusOr<CallInliner::InlinedInstructionMap>
Expand Down Expand Up @@ -247,12 +234,19 @@ CallInliner::Inline(HloInstruction* call) {
}

// Reports whether `instruction` is a call op that this inliner may inline.
//
// NOTE(review): this span appears to come from a flattened diff view in which
// the previous and the revised body of the function are interleaved without
// +/- markers -- confirm against the real file which version is intended.
// As written, the first `return` statement decides the result and every
// statement after it is unreachable.
bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
// All of the following must hold: the op is a kCall, it has no backend
// config, its parent is not an async computation, and each of the
// InlineInstruction / InlineUnderShardy / InlineComposites /
// InlineStreamAnnotation predicates returns true.
return instruction->opcode() == HloOpcode::kCall &&
!instruction->has_backend_config() &&
!instruction->parent()->IsAsyncComputation() &&
InlineInstruction(instruction) && InlineUnderShardy(instruction) &&
InlineComposites(instruction, composites_to_preserve_) &&
InlineStreamAnnotation(instruction);
// Staged form of the same decision (unreachable as written, see note above):
// structural prerequisites first, then InlineInstruction, then the Shardy and
// composite predicates. This form does not call InlineStreamAnnotation.
bool prerequisite = instruction->opcode() == HloOpcode::kCall &&
!instruction->has_backend_config() &&
!instruction->parent()->IsAsyncComputation();
if (!prerequisite) {
return false;
}
if (!InlineInstruction(instruction)) {
// Always prioritize user's explicit requests after fulfilling the
// prerequisites.
return false;
}
return InlineUnderShardy(instruction) &&
InlineComposites(instruction, composites_to_preserve_);
}

absl::StatusOr<bool> CallInliner::Run(
Expand Down
53 changes: 0 additions & 53 deletions xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,59 +522,6 @@ TEST_F(CallInlinerTest, UseShardManualComputationBodySurroundedNotInlined) {
"my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234");
}

// Checks that, with the `xla_gpu_experimental_stream_annotation` debug option
// enabled, a call carrying the `_xla_stream_annotation` frontend attribute is
// left in place while a plain call in the same computation is still inlined.
TEST_F(CallInlinerTest, DontInlineStreamAnnotationCall) {
// The entry computation chains two calls: %call1 (to %sub) carries the stream
// annotation, %call2 (to %add) does not.
const absl::string_view hlo_string = R"(
HloModule composite

%add (lhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] constant(2)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

%sub (lhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] constant(1)
ROOT %sub = f32[] subtract(f32[] %lhs, f32[] %rhs)
}

ENTRY %main () -> f32[] {
%lhs = f32[] constant(42)
%call1 = f32[] call(f32[] %lhs), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"}
ROOT %call2 = f32[] call(f32[] %call1), to_apply=%add
})";

// Turn the stream-annotation option on for this module before running the
// pass.
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_experimental_stream_annotation(true);
auto module = ParseAndReturnVerifiedModule(hlo_string).value();
module->mutable_config().set_debug_options(debug_options);
CallInliner call_inliner(/*single_call_site=*/true);

TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
// Textual check of the result: %call1 is still a call with its annotation,
// and the body of %add now appears directly in the entry computation.
absl::StatusOr<bool> filecheck_result = RunFileCheck(module->ToString({}), R"(
//CHECK: %lhs.2 = f32[] constant(42)
//CHECK: %call1 = f32[] call(%lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"}
//CHECK: %rhs.2 = f32[] constant(2)
//CHECK: ROOT %add.1 = f32[] add(%call1, %rhs.2)
)");
TF_ASSERT_OK(filecheck_result.status());
EXPECT_TRUE(*filecheck_result);

// Structural check: the pass reported a change and the entry computation
// holds exactly four instructions, walked below in iteration order.
ASSERT_TRUE(mutated);
ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
auto inst = module->entry_computation()->instructions().begin();
EXPECT_THAT(*inst, op::Constant());
// Check that the annotated call isn't inlined
++inst;
EXPECT_THAT(*inst, op::Call());

// Check that the non-annotated call is still inlined
++inst;
EXPECT_THAT(*inst, op::Constant());
++inst;
EXPECT_THAT(*inst, op::Add());
}

TEST_F(CallInlinerTest, ControlDepsPropagateToRootOfInlinedInstructions) {
const char* hlo = R"(
HloModule test
Expand Down