@@ -19,6 +19,7 @@ limitations under the License.
1919#include < utility>
2020#include < vector>
2121
22+ #include " absl/algorithm/container.h"
2223#include " absl/strings/str_cat.h"
2324#include " tensorflow/compiler/xla/literal.h"
2425#include " tensorflow/compiler/xla/service/call_graph.h"
@@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
5556 }
5657
5758 // We can always inline a 1-branch conditional due to default branch fallback.
58- int branch_index = 0 ;
59- if (conditional->branch_count () > 1 ) {
60- if (conditional->operand (0 )->opcode () != HloOpcode::kConstant ) {
61- VLOG (2 ) << " Not attempting to remove conditional as its branch_index is "
62- " not a compile-time constant: "
63- << conditional->ToShortString ();
64- return false ;
65- }
59+ auto computation = conditional->parent ();
60+ auto create_call = [&](int64 branch) {
61+ auto call = computation->AddInstruction (HloInstruction::CreateCall (
62+ conditional->shape (), {conditional->mutable_operand (1 + branch)},
63+ conditional->branch_computation (branch)));
64+ conditional->SetupDerivedInstruction (call);
65+ return call;
66+ };
67+
68+ if (conditional->branch_count () == 1 ) {
69+ HloInstruction* call_op = create_call (0 );
70+ TF_RETURN_IF_ERROR (computation->ReplaceInstruction (conditional, call_op));
71+ TF_RETURN_IF_ERROR (CallInliner::Inline (call_op).status ());
72+ return true ;
73+ }
6674
75+ if (conditional->operand (0 )->opcode () == HloOpcode::kConstant ) {
76+ int branch_index = 0 ;
6777 if (conditional->operand (0 )->shape ().element_type () == PRED) {
6878 branch_index = conditional->operand (0 )->literal ().Get <bool >({}) ? 0 : 1 ;
6979 } else {
@@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
7282 branch_index = conditional->branch_count () - 1 ;
7383 }
7484 }
85+ HloInstruction* call_op = create_call (branch_index);
86+ TF_RETURN_IF_ERROR (computation->ReplaceInstruction (conditional, call_op));
87+ TF_RETURN_IF_ERROR (CallInliner::Inline (call_op).status ());
88+
89+ return true ;
7590 }
76- auto computation = conditional->parent ();
77- HloInstruction* call_op;
78- call_op = computation->AddInstruction (HloInstruction::CreateCall (
79- conditional->shape (), {conditional->mutable_operand (branch_index + 1 )},
80- conditional->branch_computation (branch_index)));
81- conditional->SetupDerivedInstruction (call_op);
82- TF_RETURN_IF_ERROR (computation->ReplaceInstruction (conditional, call_op));
83- TF_RETURN_IF_ERROR (CallInliner::Inline (call_op).status ());
8491
92+ auto instruction_is_expensive = [](const HloInstruction* hlo) {
93+ switch (hlo->opcode ()) {
94+ case HloOpcode::kBroadcast :
95+ case HloOpcode::kConcatenate :
96+ case HloOpcode::kDynamicSlice :
97+ case HloOpcode::kDynamicUpdateSlice :
98+ case HloOpcode::kGetTupleElement :
99+ case HloOpcode::kReduce :
100+ case HloOpcode::kReshape :
101+ case HloOpcode::kPad :
102+ case HloOpcode::kParameter :
103+ case HloOpcode::kSlice :
104+ case HloOpcode::kTuple :
105+ return false ;
106+ default :
107+ return !hlo->IsElementwise ();
108+ }
109+ };
110+
111+ if (conditional->branch_count () != 2 ||
112+ conditional->operand (0 )->shape ().element_type () != PRED ||
113+ absl::c_any_of (conditional->branch_computation (0 )->instructions (),
114+ instruction_is_expensive) ||
115+ absl::c_any_of (conditional->branch_computation (1 )->instructions (),
116+ instruction_is_expensive)) {
117+ VLOG (2 )
118+ << " Not attempting to remove conditional as its branch_index is not a "
119+ " compile-time constant or contains expensive instructions: "
120+ << conditional->ToShortString ();
121+ return false ;
122+ }
123+
124+ HloInstruction* true_call_op = create_call (0 );
125+ HloInstruction* false_call_op = create_call (1 );
126+ auto condition_broadcast = [&](const Shape& shape) {
127+ if (ShapeUtil::IsScalar (shape)) {
128+ return conditional->mutable_operand (0 );
129+ }
130+ return computation->AddInstruction (HloInstruction::CreateBroadcast (
131+ ShapeUtil::ChangeElementType (shape, PRED),
132+ conditional->mutable_operand (0 ), {}));
133+ };
134+
135+ auto gte = [&](HloInstruction* hlo, int64 i) {
136+ return computation->AddInstruction (HloInstruction::CreateGetTupleElement (
137+ hlo->shape ().tuple_shapes (i), hlo, i));
138+ };
139+ std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
140+ [&](HloInstruction* t, HloInstruction* f) {
141+ if (f->shape ().IsArray ()) {
142+ return computation->AddInstruction (HloInstruction::CreateTernary (
143+ f->shape (), HloOpcode::kSelect , condition_broadcast (f->shape ()),
144+ t, f));
145+ }
146+ std::vector<HloInstruction*> selects;
147+ const int64 tuple_element_count =
148+ ShapeUtil::TupleElementCount (f->shape ());
149+ selects.reserve (tuple_element_count);
150+ for (int64 i = 0 ; i < tuple_element_count; ++i) {
151+ selects.push_back (select (gte (t, i), gte (f, i)));
152+ }
153+ return computation->AddInstruction (
154+ HloInstruction::CreateTuple (selects));
155+ };
156+
157+ TF_RETURN_IF_ERROR (computation->ReplaceInstruction (
158+ conditional, select (true_call_op, false_call_op)));
159+
160+ TF_RETURN_IF_ERROR (CallInliner::Inline (false_call_op).status ());
161+ TF_RETURN_IF_ERROR (CallInliner::Inline (true_call_op).status ());
85162 return true ;
86163}
87164StatusOr<bool > TryRemoveUnusedConditionalOperands (
0 commit comments