Use source tensor sizes after the sharding created so far when tie-breaking between candidates during explicit reshards. #429

Merged (1 commit, Mar 27, 2025)
shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc (41 additions, 0 deletions)
@@ -370,6 +370,32 @@ class FactorAxesCandidateBag {
bestCandidate = std::max(bestCandidate, candidate);
}

// Updates the source tensor sizes of all candidates.
// TODO(enver): Optimize updating source tensor sizes.
void updateSourceTensorSizes(const ShardingProjection& projection,
OpShardingRuleAttr shardingRule,
const SmallVector<AxisListRef>& factorAxisRefs) {
for (const auto& [tensorIndex, tensorFactorSharding] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
projection.getOperands(), projection.getResults()))) {
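      // Start from the full tensor size and divide out the sharding size of
      // each of the tensor's factors, yielding its local (per-device) size
      // under the sharding created so far.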
int64_t localTensorSize = shardingRule.getTensorSizes()[tensorIndex];
for (const auto& [factorIndex, _] :
tensorFactorSharding.factorIndexToSharding) {
        // TODO(enver): Consider cases where the tensor size may not be divisible.
localTensorSize /= factorAxisRefs[factorIndex].getShardingSize(mesh);
}
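      // Propagate this local size to every candidate keyed on one of the
      // tensor's factors.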
for (const auto& [factorIndex, _] :
tensorFactorSharding.factorIndexToSharding) {
int64_t candidateIndex = 0;
while (candidateIndex < size()) {
updateSourceTensorSizeAt(factorIndex, candidateIndex,
localTensorSize);
candidateIndex++;
}
}
}
}

  // Resets best. Runs in constant time.
void resetBest() { bestCandidate = FactorAxesCandidate(); }

@@ -394,6 +420,16 @@ class FactorAxesCandidateBag {
int64_t size() const { return candidates.size(); }

private:
void updateSourceTensorSizeAt(const int64_t factorIndex, const int64_t index,
const int64_t sourceTensorSize) {
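    // Only candidates on factorIndex are updated; the size only grows via
    // std::max, and the running best is refreshed with the updated candidate.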
FactorAxesCandidate& candidate = candidates[index];
if (candidate.factorAxes.factorIndex == factorIndex) {
candidate.sourceTensorSize =
std::max(candidate.sourceTensorSize, sourceTensorSize);
bestCandidate = std::max(bestCandidate, candidate);
}
}

SmallVector<FactorAxesCandidate> candidates;
FactorAxesCandidate bestCandidate;
// Used for recalculating sharding size of a candidate.
@@ -543,6 +579,11 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
}
factorAxesCandidates.updateShardingSizeAt(candidateIndex++);
}

// TODO(enver): Optimize updating source tensor sizes.
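    // The updates below can reorder candidates, so drop the stale best;
    // updateSourceTensorSizes re-derives it as it updates each candidate.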
factorAxesCandidates.resetBest();
factorAxesCandidates.updateSourceTensorSizes(projection, shardingRule,
factorAxisRefs);
}

  // TODO(enver): Consider keeping factorAxisRefs for longer until actually …
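
To make the tie-breaking rule concrete: candidates are still ranked by their sharding size, and the source tensor size (after the sharding created so far) breaks ties. The sketch below is a minimal, self-contained illustration; the struct, its ordering, and all values are assumptions for this example (the actual FactorAxesCandidate comparison is not part of this diff), with only the sourceTensorSize field and the std::max-based selection mirroring the patch.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for FactorAxesCandidate.
struct CandidateSketch {
  int64_t factorIndex = -1;
  int64_t shardingSize = 0;      // how many ways the candidate shards
  int64_t sourceTensorSize = 0;  // local size of its largest source tensor
  // Assumed ordering: sharding size first, source tensor size as tie-break.
  bool operator<(const CandidateSketch& other) const {
    if (shardingSize != other.shardingSize)
      return shardingSize < other.shardingSize;
    return sourceTensorSize < other.sourceTensorSize;
  }
};

int main() {
  std::vector<CandidateSketch> candidates = {
      {/*factorIndex=*/0, /*shardingSize=*/4, /*sourceTensorSize=*/64},
      {/*factorIndex=*/2, /*shardingSize=*/4, /*sourceTensorSize=*/256},
  };
  // Both candidates shard 4 ways; the second wins the tie because its source
  // tensor, after the sharding created so far, is larger.
  CandidateSketch best =
      *std::max_element(candidates.begin(), candidates.end());
  return best.factorIndex == 2 ? 0 : 1;
}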
@@ -313,14 +313,28 @@ func.func @dot_incompatible_i_mismatch(%arg0: tensor<8x32xf32> {sdy.sharding = #

// CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded
func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
-// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
-// CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{"x"}, {}]> : tensor<8x16xf32>
-// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh, [{}, {"x"}]> : tensor<8x16xf32>
-// CHECK-NEXT: return %[[RESHARD]] : tensor<8x16xf32>
+// TODO(b/404475296): The cost of a2a is smaller than all-gather, hence it could reshard the result instead.
+// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
+// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<32x16xf32>
+// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
+// CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32>
+// CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32>
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}


// CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim
func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x16xf32>
// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<16x16xf32>
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32>
// CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32>
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_rhs_non_contracting_dim_is_sharded
func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_rhs_non_contracting_dim_is_sharded(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
@@ -371,11 +385,11 @@ func.func @dot_incompatible_in_out_mismatch_i_j_swapped_large_k(%arg0: tensor<8x

// CHECK-LABEL: func @dot_incompatible_sub_axis_overlaps
func.func @dot_incompatible_sub_axis_overlaps(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(2)2}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
-// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x":(1)2}]> : tensor<32x16xf32>
-// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x":(2)2}, {"x":(1)2}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
-// CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{"x":(2)2}, {"x":(1)2}]> : tensor<8x16xf32>
-// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh, [{}, {"x"}]> : tensor<8x16xf32>
-// CHECK-NEXT: return %[[RESHARD2]] : tensor<8x16xf32>
+// CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
+// CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<32x16xf32>
+// CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
+// CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32>
+// CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32>
%0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
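
The new @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim test added above shrinks the contracting dimension from 32 to 16, exercising the local tensor sizes that updateSourceTensorSizes computes. Below is a minimal worked version of that computation, assuming a hypothetical mesh where axes "x" and "y" each have size 2 (the test file's @mesh definition is not shown above).

#include <cstdint>
#include <cstdio>

int main() {
  // lhs of the new test: tensor<8x16xf32> sharded as [{"x"}, {"y"}].
  int64_t localTensorSize = 8 * 16;  // 128 elements in total
  // Mirror the patch's loop:
  //   localTensorSize /= factorAxisRefs[factorIndex].getShardingSize(mesh);
  localTensorSize /= 2;  // non-contracting factor sharded on "x" (assumed size 2)
  localTensorSize /= 2;  // contracting factor sharded on "y" (assumed size 2)
  std::printf("local lhs size: %lld\n", (long long)localTensorSize);  // 32
  return 0;
}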