@@ -19,6 +19,7 @@ limitations under the License.
19
19
#include < utility>
20
20
#include < vector>
21
21
22
+ #include " absl/algorithm/container.h"
22
23
#include " absl/strings/str_cat.h"
23
24
#include " tensorflow/compiler/xla/literal.h"
24
25
#include " tensorflow/compiler/xla/service/call_graph.h"
@@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
55
56
}
56
57
57
58
// We can always inline a 1-branch conditional due to default branch fallback.
58
- int branch_index = 0 ;
59
- if (conditional->branch_count () > 1 ) {
60
- if (conditional->operand (0 )->opcode () != HloOpcode::kConstant ) {
61
- VLOG (2 ) << " Not attempting to remove conditional as its branch_index is "
62
- " not a compile-time constant: "
63
- << conditional->ToShortString ();
64
- return false ;
65
- }
59
+ auto computation = conditional->parent ();
60
+ auto create_call = [&](int64 branch) {
61
+ auto call = computation->AddInstruction (HloInstruction::CreateCall (
62
+ conditional->shape (), {conditional->mutable_operand (1 + branch)},
63
+ conditional->branch_computation (branch)));
64
+ conditional->SetupDerivedInstruction (call);
65
+ return call;
66
+ };
67
+
68
+ if (conditional->branch_count () == 1 ) {
69
+ HloInstruction* call_op = create_call (0 );
70
+ TF_RETURN_IF_ERROR (computation->ReplaceInstruction (conditional, call_op));
71
+ TF_RETURN_IF_ERROR (CallInliner::Inline (call_op).status ());
72
+ return true ;
73
+ }
66
74
75
+ if (conditional->operand (0 )->opcode () == HloOpcode::kConstant ) {
76
+ int branch_index = 0 ;
67
77
if (conditional->operand (0 )->shape ().element_type () == PRED) {
68
78
branch_index = conditional->operand (0 )->literal ().Get <bool >({}) ? 0 : 1 ;
69
79
} else {
@@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
72
82
branch_index = conditional->branch_count () - 1 ;
73
83
}
74
84
}
85
+ HloInstruction* call_op = create_call (branch_index);
86
+ TF_RETURN_IF_ERROR (computation->ReplaceInstruction (conditional, call_op));
87
+ TF_RETURN_IF_ERROR (CallInliner::Inline (call_op).status ());
88
+
89
+ return true ;
75
90
}
76
- auto computation = conditional->parent ();
77
- HloInstruction* call_op;
78
- call_op = computation->AddInstruction (HloInstruction::CreateCall (
79
- conditional->shape (), {conditional->mutable_operand (branch_index + 1 )},
80
- conditional->branch_computation (branch_index)));
81
- conditional->SetupDerivedInstruction (call_op);
82
- TF_RETURN_IF_ERROR (computation->ReplaceInstruction (conditional, call_op));
83
- TF_RETURN_IF_ERROR (CallInliner::Inline (call_op).status ());
84
91
92
+ auto instruction_is_expensive = [](const HloInstruction* hlo) {
93
+ switch (hlo->opcode ()) {
94
+ case HloOpcode::kBroadcast :
95
+ case HloOpcode::kConcatenate :
96
+ case HloOpcode::kDynamicSlice :
97
+ case HloOpcode::kDynamicUpdateSlice :
98
+ case HloOpcode::kGetTupleElement :
99
+ case HloOpcode::kReduce :
100
+ case HloOpcode::kReshape :
101
+ case HloOpcode::kPad :
102
+ case HloOpcode::kParameter :
103
+ case HloOpcode::kSlice :
104
+ case HloOpcode::kTuple :
105
+ return false ;
106
+ default :
107
+ return !hlo->IsElementwise ();
108
+ }
109
+ };
110
+
111
+ if (conditional->branch_count () != 2 ||
112
+ conditional->operand (0 )->shape ().element_type () != PRED ||
113
+ absl::c_any_of (conditional->branch_computation (0 )->instructions (),
114
+ instruction_is_expensive) ||
115
+ absl::c_any_of (conditional->branch_computation (1 )->instructions (),
116
+ instruction_is_expensive)) {
117
+ VLOG (2 )
118
+ << " Not attempting to remove conditional as its branch_index is not a "
119
+ " compile-time constant or contains expensive instructions: "
120
+ << conditional->ToShortString ();
121
+ return false ;
122
+ }
123
+
124
+ HloInstruction* true_call_op = create_call (0 );
125
+ HloInstruction* false_call_op = create_call (1 );
126
+ auto condition_broadcast = [&](const Shape& shape) {
127
+ if (ShapeUtil::IsScalar (shape)) {
128
+ return conditional->mutable_operand (0 );
129
+ }
130
+ return computation->AddInstruction (HloInstruction::CreateBroadcast (
131
+ ShapeUtil::ChangeElementType (shape, PRED),
132
+ conditional->mutable_operand (0 ), {}));
133
+ };
134
+
135
+ auto gte = [&](HloInstruction* hlo, int64 i) {
136
+ return computation->AddInstruction (HloInstruction::CreateGetTupleElement (
137
+ hlo->shape ().tuple_shapes (i), hlo, i));
138
+ };
139
+ std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
140
+ [&](HloInstruction* t, HloInstruction* f) {
141
+ if (f->shape ().IsArray ()) {
142
+ return computation->AddInstruction (HloInstruction::CreateTernary (
143
+ f->shape (), HloOpcode::kSelect , condition_broadcast (f->shape ()),
144
+ t, f));
145
+ }
146
+ std::vector<HloInstruction*> selects;
147
+ const int64 tuple_element_count =
148
+ ShapeUtil::TupleElementCount (f->shape ());
149
+ selects.reserve (tuple_element_count);
150
+ for (int64 i = 0 ; i < tuple_element_count; ++i) {
151
+ selects.push_back (select (gte (t, i), gte (f, i)));
152
+ }
153
+ return computation->AddInstruction (
154
+ HloInstruction::CreateTuple (selects));
155
+ };
156
+
157
+ TF_RETURN_IF_ERROR (computation->ReplaceInstruction (
158
+ conditional, select (true_call_op, false_call_op)));
159
+
160
+ TF_RETURN_IF_ERROR (CallInliner::Inline (false_call_op).status ());
161
+ TF_RETURN_IF_ERROR (CallInliner::Inline (true_call_op).status ());
85
162
return true ;
86
163
}
87
164
StatusOr<bool > TryRemoveUnusedConditionalOperands (
0 commit comments