Skip to content

Commit d87c06c

Browse files
Remove split configs from evicted tensors.
PiperOrigin-RevId: 736670739
1 parent c7461fd commit d87c06c

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

xla/service/memory_space_assignment/allocation.cc

+13
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ limitations under the License.
3636
#include "absl/strings/str_join.h"
3737
#include "absl/strings/string_view.h"
3838
#include "absl/types/span.h"
39+
#include "xla/hlo/ir/hlo_casting_utils.h"
3940
#include "xla/hlo/ir/hlo_instruction.h"
41+
#include "xla/hlo/ir/hlo_instructions.h"
4042
#include "xla/hlo/ir/hlo_opcode.h"
4143
#include "xla/layout.h"
4244
#include "xla/layout_util.h"
@@ -215,6 +217,13 @@ absl::Status Allocation::UpdateUses(HloComputation* computation,
215217
}
216218
TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
217219
use.operand_number, replacement_instruction));
220+
if (use.instruction->opcode() == HloOpcode::kFusion &&
221+
replacement_instruction->shape().has_layout() &&
222+
replacement_instruction->shape().layout().split_configs_size() > 0) {
223+
HloInstruction* fusion = Cast<HloFusionInstruction>(use.instruction);
224+
HloInstruction* param = fusion->fused_parameter(use.operand_number);
225+
*param->mutable_shape() = replacement_instruction->shape();
226+
}
218227
}
219228
return absl::OkStatus();
220229
}
@@ -398,6 +407,10 @@ absl::Status CopyAllocation::Process(const BitcastSplitFn& bitcast_split_fn) {
398407
if (memory_space() == MemorySpace::kAlternate &&
399408
mutable_split_shape().has_value()) {
400409
dest_shape = mutable_split_shape().value();
410+
} else if (memory_space() == MemorySpace::kDefault && shape.has_layout() &&
411+
shape.layout().split_configs_size() > 0) {
412+
dest_shape = shape;
413+
dest_shape.mutable_layout()->clear_split_configs();
401414
} else {
402415
dest_shape = shape;
403416
}

0 commit comments

Comments
 (0)