Skip to content

Commit ebe48a8

Browse files
[XLA] Convert simple conditionals in to array select instructions to allow for
fusion and avoid copies in buffer assignment. PiperOrigin-RevId: 251500037
1 parent f589c25 commit ebe48a8

File tree

3 files changed

+111
-19
lines changed

3 files changed

+111
-19
lines changed

tensorflow/compiler/xla/service/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -1841,6 +1841,7 @@ cc_library(
18411841
"//tensorflow/compiler/xla:types",
18421842
"//tensorflow/compiler/xla:util",
18431843
"//tensorflow/core:lib",
1844+
"@com_google_absl//absl/algorithm:container",
18441845
"@com_google_absl//absl/strings",
18451846
],
18461847
)

tensorflow/compiler/xla/service/conditional_simplifier.cc

+93-16
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <utility>
2020
#include <vector>
2121

22+
#include "absl/algorithm/container.h"
2223
#include "absl/strings/str_cat.h"
2324
#include "tensorflow/compiler/xla/literal.h"
2425
#include "tensorflow/compiler/xla/service/call_graph.h"
@@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
5556
}
5657

5758
// We can always inline a 1-branch conditional due to default branch fallback.
58-
int branch_index = 0;
59-
if (conditional->branch_count() > 1) {
60-
if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
61-
VLOG(2) << "Not attempting to remove conditional as its branch_index is "
62-
"not a compile-time constant: "
63-
<< conditional->ToShortString();
64-
return false;
65-
}
59+
auto computation = conditional->parent();
60+
auto create_call = [&](int64 branch) {
61+
auto call = computation->AddInstruction(HloInstruction::CreateCall(
62+
conditional->shape(), {conditional->mutable_operand(1 + branch)},
63+
conditional->branch_computation(branch)));
64+
conditional->SetupDerivedInstruction(call);
65+
return call;
66+
};
67+
68+
if (conditional->branch_count() == 1) {
69+
HloInstruction* call_op = create_call(0);
70+
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
71+
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
72+
return true;
73+
}
6674

75+
if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
76+
int branch_index = 0;
6777
if (conditional->operand(0)->shape().element_type() == PRED) {
6878
branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1;
6979
} else {
@@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
7282
branch_index = conditional->branch_count() - 1;
7383
}
7484
}
85+
HloInstruction* call_op = create_call(branch_index);
86+
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
87+
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
88+
89+
return true;
7590
}
76-
auto computation = conditional->parent();
77-
HloInstruction* call_op;
78-
call_op = computation->AddInstruction(HloInstruction::CreateCall(
79-
conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
80-
conditional->branch_computation(branch_index)));
81-
conditional->SetupDerivedInstruction(call_op);
82-
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
83-
TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
8491

92+
auto instruction_is_expensive = [](const HloInstruction* hlo) {
93+
switch (hlo->opcode()) {
94+
case HloOpcode::kBroadcast:
95+
case HloOpcode::kConcatenate:
96+
case HloOpcode::kDynamicSlice:
97+
case HloOpcode::kDynamicUpdateSlice:
98+
case HloOpcode::kGetTupleElement:
99+
case HloOpcode::kReduce:
100+
case HloOpcode::kReshape:
101+
case HloOpcode::kPad:
102+
case HloOpcode::kParameter:
103+
case HloOpcode::kSlice:
104+
case HloOpcode::kTuple:
105+
return false;
106+
default:
107+
return !hlo->IsElementwise();
108+
}
109+
};
110+
111+
if (conditional->branch_count() != 2 ||
112+
conditional->operand(0)->shape().element_type() != PRED ||
113+
absl::c_any_of(conditional->branch_computation(0)->instructions(),
114+
instruction_is_expensive) ||
115+
absl::c_any_of(conditional->branch_computation(1)->instructions(),
116+
instruction_is_expensive)) {
117+
VLOG(2)
118+
<< "Not attempting to remove conditional as its branch_index is not a "
119+
"compile-time constant or contains expensive instructions: "
120+
<< conditional->ToShortString();
121+
return false;
122+
}
123+
124+
HloInstruction* true_call_op = create_call(0);
125+
HloInstruction* false_call_op = create_call(1);
126+
auto condition_broadcast = [&](const Shape& shape) {
127+
if (ShapeUtil::IsScalar(shape)) {
128+
return conditional->mutable_operand(0);
129+
}
130+
return computation->AddInstruction(HloInstruction::CreateBroadcast(
131+
ShapeUtil::ChangeElementType(shape, PRED),
132+
conditional->mutable_operand(0), {}));
133+
};
134+
135+
auto gte = [&](HloInstruction* hlo, int64 i) {
136+
return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
137+
hlo->shape().tuple_shapes(i), hlo, i));
138+
};
139+
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
140+
[&](HloInstruction* t, HloInstruction* f) {
141+
if (f->shape().IsArray()) {
142+
return computation->AddInstruction(HloInstruction::CreateTernary(
143+
f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
144+
t, f));
145+
}
146+
std::vector<HloInstruction*> selects;
147+
const int64 tuple_element_count =
148+
ShapeUtil::TupleElementCount(f->shape());
149+
selects.reserve(tuple_element_count);
150+
for (int64 i = 0; i < tuple_element_count; ++i) {
151+
selects.push_back(select(gte(t, i), gte(f, i)));
152+
}
153+
return computation->AddInstruction(
154+
HloInstruction::CreateTuple(selects));
155+
};
156+
157+
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
158+
conditional, select(true_call_op, false_call_op)));
159+
160+
TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
161+
TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
85162
return true;
86163
}
87164
StatusOr<bool> TryRemoveUnusedConditionalOperands(

tensorflow/compiler/xla/service/conditional_simplifier_test.cc

+17-3
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ namespace op = xla::testing::opcode_matchers;
4141
class ConditionalSimplifierTest : public HloTestBase {
4242
public:
4343
// Makes a computation that contains a conditional with constant predicate.
44-
HloComputation* MakeConditional(HloModule* module);
44+
HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
4545
};
4646

47-
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
47+
HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
48+
bool is_constant) {
4849
HloComputation::Builder builder(TestName());
4950

5051
// true_computation returns param+1.
@@ -83,7 +84,10 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
8384
}
8485

8586
auto false_instrn = builder.AddInstruction(
86-
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
87+
is_constant
88+
? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
89+
: HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
90+
"cond"));
8791
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
8892
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
8993
auto one = builder.AddInstruction(
@@ -104,6 +108,16 @@ TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
104108
op::Add(op::Parameter(), op::Constant()));
105109
}
106110

111+
TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
112+
auto m = CreateNewVerifiedModule();
113+
HloComputation* computation = MakeConditional(m.get(), /*is_constant=*/false);
114+
ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie());
115+
EXPECT_THAT(
116+
computation->root_instruction(),
117+
op::Select(op::Parameter(1), op::Add(op::Constant(), op::Constant()),
118+
op::Add(op::Parameter(0), op::Constant())));
119+
}
120+
107121
TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
108122
auto m = CreateNewVerifiedModule();
109123
HloComputation* computation = MakeConditional(m.get());

0 commit comments

Comments
 (0)