From c6e0978faffb751ce098eb05cb56d9c068c8569a Mon Sep 17 00:00:00 2001
From: shardy authors
Date: Thu, 27 Mar 2025 09:33:34 -0700
Subject: [PATCH] Use source tensor sizes after the sharding created so far
 when tie-breaking between candidates during explicit reshards.

Previously, the tie-break used the source tensor sizes of the unsharded
tensors. For example, it now prefers:

  reshard lhs: {"x":(2)2}, {"y"} -> {}, {"y"}
  reshard rhs: {"y"}, {} -> {"y"}, {"x"}
  dot to obtain the result in sharding {}, {"x"}
  all-reduce along "y"
  return all-reduce

instead of:

  reshard rhs: {"y"}, {} -> {"y"}, {"x":(1)2}
  dot to obtain the result in sharding {"x":(2)2}, {"x":(1)2}
  all-reduce along "y"
  reshard to {}, {"x"}
  return reshard

for the following example:

func.func @main(
    %arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(2)2}, {"y"}]>},
    %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>})
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
  %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

PiperOrigin-RevId: 741182516
---
 .../export/insert_explicit_reshards.cc        | 41 +++++++++++++++++++
 .../export/test/insert_explicit_reshards.mlir | 32 +++++++++++----
 2 files changed, 64 insertions(+), 9 deletions(-)

diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
index e1c46c8c..5b302eeb 100644
--- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
+++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -370,6 +370,32 @@ class FactorAxesCandidateBag {
     bestCandidate = std::max(bestCandidate, candidate);
   }
 
+  // Updates the source tensor sizes of all candidates.
+  // TODO(enver): Optimize updating source tensor sizes.
+  void updateSourceTensorSizes(const ShardingProjection& projection,
+                               OpShardingRuleAttr shardingRule,
+                               const SmallVector& factorAxisRefs) {
+    for (const auto& [tensorIndex, tensorFactorSharding] :
+         llvm::enumerate(llvm::concat(
+             projection.getOperands(), projection.getResults()))) {
+      int64_t localTensorSize = shardingRule.getTensorSizes()[tensorIndex];
+      for (const auto& [factorIndex, _] :
+           tensorFactorSharding.factorIndexToSharding) {
+        // TODO(enver): Consider cases tensor size may not be divisable.
+        localTensorSize /= factorAxisRefs[factorIndex].getShardingSize(mesh);
+      }
+      for (const auto& [factorIndex, _] :
+           tensorFactorSharding.factorIndexToSharding) {
+        int64_t candidateIndex = 0;
+        while (candidateIndex < size()) {
+          updateSourceTensorSizeAt(factorIndex, candidateIndex,
+                                   localTensorSize);
+          candidateIndex++;
+        }
+      }
+    }
+  }
+
   // Resets best. Performs in constant-time.
   void resetBest() { bestCandidate = FactorAxesCandidate(); }
 
@@ -394,6 +420,16 @@ class FactorAxesCandidateBag {
   int64_t size() const { return candidates.size(); }
 
  private:
+  void updateSourceTensorSizeAt(const int64_t factorIndex, const int64_t index,
+                                const int64_t sourceTensorSize) {
+    FactorAxesCandidate& candidate = candidates[index];
+    if (candidate.factorAxes.factorIndex == factorIndex) {
+      candidate.sourceTensorSize =
+          std::max(candidate.sourceTensorSize, sourceTensorSize);
+      bestCandidate = std::max(bestCandidate, candidate);
+    }
+  }
+
   SmallVector candidates;
   FactorAxesCandidate bestCandidate;
   // Used for recalculating sharding size of a candidate.
@@ -543,6 +579,11 @@ AxesPerFactor findCommonAxesUsingMajorityVoteHeuristic(
       }
       factorAxesCandidates.updateShardingSizeAt(candidateIndex++);
     }
+
+    // TODO(enver): Optimize updating source tensor sizes.
+    factorAxesCandidates.resetBest();
+    factorAxesCandidates.updateSourceTensorSizes(projection, shardingRule,
+                                                 factorAxisRefs);
   }
 
   // TODO(enver): Consider to keep factorAxisRefs for longer until acutall
diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
index 9f17cb1f..87b68f23 100644
--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
@@ -313,14 +313,28 @@ func.func @dot_incompatible_i_mismatch(%arg0: tensor<8x32xf32> {sdy.sharding = #
 
 // CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded
 func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
-  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{"x"}, {}]> : tensor<8x16xf32>
-  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh, [{}, {"x"}]> : tensor<8x16xf32>
-  // CHECK-NEXT: return %[[RESHARD]] : tensor<8x16xf32>
+  // TODO(b/404475296): The cost of a2a is smaller than all-gather, hence it could reshard the result instead.
+  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
+  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<32x16xf32>
+  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32>
   %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
 }
 
+// CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim
+func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_lhs_non_contracting_dim_is_sharded_smaller_local_contracting_dim(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
+  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x16xf32>
+  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<16x16xf32>
+  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32>
+  %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x16xf32>, tensor<16x16xf32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
 // CHECK-LABEL: func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_rhs_non_contracting_dim_is_sharded
 func.func @dot_incompatible_in_out_mismatch_same_axis_on_different_factors_rhs_non_contracting_dim_is_sharded(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) {
   // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
@@ -371,11 +385,11 @@ func.func @dot_incompatible_in_out_mismatch_i_j_swapped_large_k(%arg0: tensor<8x
 
 // CHECK-LABEL: func @dot_incompatible_sub_axis_overlaps
 func.func @dot_incompatible_sub_axis_overlaps(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(2)2}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}]>}) {
-  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x":(1)2}]> : tensor<32x16xf32>
-  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %[[RESHARD1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x":(2)2}, {"x":(1)2}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{"x":(2)2}, {"x":(1)2}]> : tensor<8x16xf32>
-  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[ALL_REDUCE]] <@mesh, [{}, {"x"}]> : tensor<8x16xf32>
-  // CHECK-NEXT: return %[[RESHARD2]] : tensor<8x16xf32>
+  // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}]> : tensor<8x32xf32>
+  // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %arg1 <@mesh, [{"y"}, {"x"}]> : tensor<32x16xf32>
+  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %[[RESHARD1]], %[[RESHARD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: %[[ALL_REDUCE:.*]] = sdy.all_reduce {"y"} %[[DOT]] out_sharding=<@mesh, [{}, {"x"}]> : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[ALL_REDUCE]] : tensor<8x16xf32>
   %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}]>]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
 }