@@ -36,7 +36,9 @@ limitations under the License.
36
36
#include " absl/strings/str_join.h"
37
37
#include " absl/strings/string_view.h"
38
38
#include " absl/types/span.h"
39
+ #include " xla/hlo/ir/hlo_casting_utils.h"
39
40
#include " xla/hlo/ir/hlo_instruction.h"
41
+ #include " xla/hlo/ir/hlo_instructions.h"
40
42
#include " xla/hlo/ir/hlo_opcode.h"
41
43
#include " xla/layout.h"
42
44
#include " xla/layout_util.h"
@@ -215,6 +217,13 @@ absl::Status Allocation::UpdateUses(HloComputation* computation,
215
217
}
216
218
TF_RETURN_IF_ERROR (use.instruction ->ReplaceOperandWith (
217
219
use.operand_number , replacement_instruction));
220
+ if (use.instruction ->opcode () == HloOpcode::kFusion &&
221
+ replacement_instruction->shape ().has_layout () &&
222
+ replacement_instruction->shape ().layout ().split_configs_size () > 0 ) {
223
+ HloInstruction* fusion = Cast<HloFusionInstruction>(use.instruction );
224
+ HloInstruction* param = fusion->fused_parameter (use.operand_number );
225
+ *param->mutable_shape () = replacement_instruction->shape ();
226
+ }
218
227
}
219
228
return absl::OkStatus ();
220
229
}
@@ -398,6 +407,10 @@ absl::Status CopyAllocation::Process(const BitcastSplitFn& bitcast_split_fn) {
398
407
if (memory_space () == MemorySpace::kAlternate &&
399
408
mutable_split_shape ().has_value ()) {
400
409
dest_shape = mutable_split_shape ().value ();
410
+ } else if (memory_space () == MemorySpace::kDefault && shape.has_layout () &&
411
+ shape.layout ().split_configs_size () > 0 ) {
412
+ dest_shape = shape;
413
+ dest_shape.mutable_layout ()->clear_split_configs ();
401
414
} else {
402
415
dest_shape = shape;
403
416
}
0 commit comments