#sdy support ShapedType in Shardy as long as it has a rank and static shape. #45

Merged (1 commit) on Aug 7, 2024
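
The recurring pattern in this change: code paths that previously required a `RankedTensorType` now accept any `ShapedType`, guarded by a static-shape check (which implies the type is ranked). A minimal before/after sketch of that pattern; the function names are illustrative and not part of the patch:

#include "mlir/IR/BuiltinTypeInterfaces.h"  // mlir::ShapedType
#include "mlir/IR/BuiltinTypes.h"           // mlir::RankedTensorType

// Before: only RankedTensorType was recognized; any other type got rank 0.
int64_t rankBefore(mlir::Type type) {
  if (auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(type))
    return tensorType.getRank();
  return 0;
}

// After: any ShapedType qualifies, as long as its shape is fully static
// (hasStaticShape() returns false for unranked and dynamically shaped types).
int64_t rankAfter(mlir::Type type) {
  if (auto shapedType = mlir::dyn_cast<mlir::ShapedType>(type);
      shapedType && shapedType.hasStaticShape())
    return shapedType.getRank();
  return 0;
}
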
4 changes: 3 additions & 1 deletion shardy/dialect/sdy/ir/dialect.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
@@ -430,7 +431,8 @@ TensorShardingPerValueAttr TensorShardingPerValueAttr::getFullyOpen(
for (Type type : types) {
int64_t rank = 0;
// TODO(tomnatan): remove mlir:: once Attribute::dyn_cast is removed.
if (auto tensorType = mlir::dyn_cast<RankedTensorType>(type)) {
if (auto tensorType = mlir::dyn_cast<ShapedType>(type)) {
assert(tensorType.hasStaticShape());
rank = tensorType.getRank();
}
shardingPerResult.push_back(
12 changes: 6 additions & 6 deletions shardy/dialect/sdy/ir/ops.td
@@ -135,12 +135,12 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation",
}];

let arguments = (ins
Variadic<AnyTensor>:$tensors,
Variadic<AnyRankedTensor>:$tensors,
Sdy_TensorShardingPerValue:$in_shardings,
Sdy_TensorShardingPerValue:$out_shardings,
Sdy_ManualAxes:$manual_axes
);
let results = (outs Variadic<AnyTensor>:$results);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$body);

let assemblyFormat = [{
@@ -337,10 +337,10 @@ def DataFlowEdgeOp : Sdy_Op<"data_flow_edge",
}];

let arguments = (ins
AnyRankedTensor:$input,
AnyShaped:$input,
OptionalAttr<Sdy_TensorSharding>:$sharding);

let results = (outs AnyRankedTensor:$result);
let results = (outs AnyShaped:$result);

let assemblyFormat = "$input (`sharding````=``` $sharding^)? attr-dict `:` type($result)";

@@ -381,10 +381,10 @@ def PropagationBarrierOp : Sdy_Op<"propagation_barrier",
}];

let arguments = (ins
AnyTensor:$input,
AnyRankedTensor:$input,
Sdy_PropagationDirection:$allowed_direction
);
let results = (outs AnyTensor:$result);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = "$input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)";
let hasVerifier = 1;
}
9 changes: 9 additions & 0 deletions shardy/dialect/sdy/ir/test/data_flow_edge_verification.mlir
@@ -12,6 +12,15 @@ func.func @invalid_sharding(%arg0 : tensor<8xf32>) -> tensor<8xf32> {

// -----

func.func @dynamic_shaped_type(%arg0: tensor<?x?xf32>)
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
// expected-error @+1 {{expected sdy.data_flow_edge to have a static-shaped result}}
%0 = sdy.data_flow_edge %arg0 : tensor<?x?xf32>
return %arg0, %0 : tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

func.func @input_has_multiple_users(%arg0: tensor<32x96xf32>)
-> (tensor<32x96xf32>, tensor<32x96xf32>) {
// expected-error @+1 {{expected input of sdy.data_flow_edge to have a single user}}
16 changes: 16 additions & 0 deletions shardy/dialect/sdy/ir/test/sharding_rule_verification.mlir
@@ -16,6 +16,22 @@ func.func @sharding_rule_wrong_attr_type(%arg0: tensor<8xf32>) -> tensor<8xf32>

// -----

func.func @unranked_tensor_type(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// expected-error@+1 {{operand 0 - expected a ranked tensor with a static shape}}
%0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=2, j=4}>} : tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

func.func @dynamic_shaped_tensor_type(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1 {{operand 0 - expected a ranked tensor with a static shape}}
%0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=2, j=4}>} : tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// -----

func.func @operand_mappings_wrong_rank(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// expected-error@+1 {{operand 1 - mapping rank must match: 1 != 2}}
%0 = stablehlo.add %arg0, %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i])->([i, j]) {i=2, j=4}>} : tensor<2x4xf32>
22 changes: 20 additions & 2 deletions shardy/dialect/sdy/ir/test/tensor_sharding_verification.mlir
@@ -2,7 +2,7 @@

sdy.mesh @mesh = <"a"=2>

// expected-error @+1 {{'func.func' op arg 0 - non-ranked tensors can only have a sharding with rank 0 and no replicated axes}}
// expected-error @+1 {{'func.func' op arg 0 - non-shaped tensors can only have a sharding with rank 0 and no replicated axes}}
func.func @token_sharding_rank_non_zero(%arg0: !stablehlo.token {sdy.sharding=#sdy.sharding<@mesh, [{}]>}) -> !stablehlo.token {
return %arg0 : !stablehlo.token
}
@@ -11,13 +11,31 @@ func.func @token_sharding_rank_non_zero(%arg0: !stablehlo.token {sdy.sharding=#s

sdy.mesh @mesh = <"a"=2>

// expected-error @+1 {{'func.func' op arg 0 - non-ranked tensors can only have a sharding with rank 0 and no replicated axes}}
// expected-error @+1 {{'func.func' op arg 0 - non-shaped tensors can only have a sharding with rank 0 and no replicated axes}}
func.func @token_sharding_with_replicated_axes(%arg0: !stablehlo.token {sdy.sharding=#sdy.sharding<@mesh, [], replicated={"a"}>}) -> !stablehlo.token {
return %arg0 : !stablehlo.token
}

// -----

sdy.mesh @mesh = <"a"=2>

// expected-error @+1 {{'func.func' op arg 0 - only ranked tensors with a static shape can have a sharding}}
func.func @unranked_tensor_with_sharding(%arg0: tensor<*xf32> {sdy.sharding=#sdy.sharding<@mesh, []>}) -> tensor<*xf32> {
return %arg0 : tensor<*xf32>
}

// -----

sdy.mesh @mesh = <"a"=2>

// expected-error @+1 {{'func.func' op arg 0 - only ranked tensors with a static shape can have a sharding}}
func.func @dynamic_shaped_tensor_with_sharding(%arg0: tensor<?x?xf32> {sdy.sharding=#sdy.sharding<@mesh, [{}, {}]>}) -> tensor<?x?xf32> {
return %arg0 : tensor<?x?xf32>
}

// -----

sdy.mesh @mesh = <"a"=2, "b"=2>

func.func @dim_shardings_rank_mismatch(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
22 changes: 17 additions & 5 deletions shardy/dialect/sdy/ir/utils.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@@ -91,26 +92,37 @@ std::string operationToString(Operation* op) {
return mlirToString(op);
}

std::string valueToString(Value value) {
return mlirToString(&value);
std::string valueToString(Value value) { return mlirToString(&value); }

ShapedType dynCastStaticShapedType(Type type) {
if (auto shapedType = dyn_cast<ShapedType>(type);
shapedType && shapedType.hasStaticShape()) {
return shapedType;
}
return nullptr;
}

bool isStaticShapedType(Type type) {
return dynCastStaticShapedType(type) != nullptr;
}

ArrayRef<int64_t> getTensorShape(Value value) {
if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
if (auto tensorType = dyn_cast<ShapedType>(value.getType())) {
return tensorType.getShape();
}
return {};
}

int64_t getTensorRank(Value value) {
if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
if (auto tensorType = dyn_cast<ShapedType>(value.getType())) {
return tensorType.getRank();
}
return 0;
}

int64_t isScalar(Value value) {
if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
if (auto tensorType = dyn_cast<ShapedType>(value.getType());
tensorType && tensorType.hasRank()) {
return tensorType.getRank() == 0;
}
return false;
16 changes: 14 additions & 2 deletions shardy/dialect/sdy/ir/utils.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -65,12 +66,23 @@ std::string operationToString(Operation* op);
// Converts `value` to string with location information.
std::string valueToString(Value value);

// Returns the shape of the given `value` if its type is a `RankedTensorType`,
// If the given `type` is a `ShapedType` with a static shape, returns it,
// otherwise returns nullptr.
ShapedType dynCastStaticShapedType(Type type);

// Returns true if the given `type` is a `ShapedType` with a static shape.
bool isStaticShapedType(Type type);

// Returns the shape of the given `value` if its type is a `ShapedType`,
// otherwise returns an empty array.
//
// Assumes the `ShapedType` has a rank.
ArrayRef<int64_t> getTensorShape(Value value);

// Returns the rank of the given `value` if its type is a `RankedTensorType`,
// Returns the rank of the given `value` if its type is a `ShapedType`,
// otherwise returns 0.
//
// Assumes the `ShapedType` has a rank.
int64_t getTensorRank(Value value);

// Returns true if the value is a tensor with rank 0.
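
A hypothetical caller-side sketch of the helpers declared above (the wrapping function and the use of the mlir::sdy namespace are assumptions, not part of the patch):

#include "shardy/dialect/sdy/ir/utils.h"

// Illustrative only: skip values that cannot carry a sharding (tokens,
// unranked tensors such as tensor<*xf32>, or dynamically shaped tensors
// such as tensor<?x?xf32>), then read the statically known shape.
void processIfStaticShaped(mlir::Value value) {
  mlir::ShapedType type = mlir::sdy::dynCastStaticShapedType(value.getType());
  if (!type)  // same condition as !isStaticShapedType(value.getType())
    return;
  llvm::ArrayRef<int64_t> shape = mlir::sdy::getTensorShape(value);
  int64_t rank = mlir::sdy::getTensorRank(value);
  // ... every entry of `shape` (and hence `rank`) is known statically here.
  (void)shape;
  (void)rank;
}
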
35 changes: 26 additions & 9 deletions shardy/dialect/sdy/ir/verifiers.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/SymbolTable.h"
@@ -199,11 +200,11 @@ LogicalResult emitBoundAxisInManualComputationError(EmitErrorFn emitError,

// Verifies the following for `shardingAttr`:
//
// If `type` isn't a `RankedTensorType`, the sharding must have rank 0 and no
// If `type` isn't a `ShapedType`, the sharding must have rank 0 and no
// replicated axes.
//
// - The number of dimension shardings is equal to the rank of the tensor
// (specified by `type`, which should be a `RankedTensorType`).
// - The tensor should have a rank and static shape.
// - The number of dimension shardings is equal to the rank of the tensor.
// - Dimensions of size 0 aren't sharded.
// - Replicated axes are ordered w.r.t. `mesh` (see
// AxisRefAttr::getMeshComparator).
@@ -220,17 +221,22 @@ LogicalResult verifyTensorShardingAttr(
TensorShardingAttr shardingAttr, Type type, MeshAttr mesh,
EmitErrorFn emitError,
ManualAxisToOwner alreadyManualAxes = ManualAxisToOwner()) {
auto tensorType = dyn_cast<RankedTensorType>(type);
auto tensorType = dyn_cast<ShapedType>(type);
if (!tensorType) {
if (shardingAttr.getRank() != 0 ||
!shardingAttr.getReplicatedAxes().empty()) {
return emitError(
"non-ranked tensors can only have a sharding with rank 0 ")
"non-shaped tensors can only have a sharding with rank 0 ")
<< "and no replicated axes. type: " << type
<< ", sharding: " << shardingAttr;
}
return success();
}
if (!tensorType.hasStaticShape()) {
return emitError(
"only ranked tensors with a static shape can have a sharding. ")
<< "type: " << type;
}
int64_t rank = tensorType.getRank();
if (shardingAttr.getRank() != rank) {
return emitError("sharding doesn't match tensor rank: ")
@@ -426,14 +432,20 @@ LogicalResult verifyShardingRuleMapping(Operation* op, TypeRange types,
// doesn't reuse the same factor.
BitVector valueSeenFactorIndices(factorSizes.size());
auto [type, mapping] = typeAndMapping;
auto tensorType = cast<RankedTensorType>(type);

EmitErrorFn valueEmitError = getEmitValueInRangeErrorFn(
[op, valueKindStr](StringRef msg) {
return op->emitOpError(valueKindStr) << " " << msg;
},
types.size(), index);

auto tensorType = dynCastStaticShapedType(type);
if (!tensorType) {
return valueEmitError(
"expected a ranked tensor with a static shape. type: ")
<< type;
}

if (mapping.getRank() != tensorType.getRank()) {
return valueEmitError("mapping rank must match: ")
<< mapping.getRank() << " != " << tensorType.getRank();
@@ -559,6 +571,11 @@ LogicalResult ReshardOp::verify() {
}

LogicalResult DataFlowEdgeOp::verify() {
if (!getType().hasStaticShape()) {
return emitOpError(
"expected sdy.data_flow_edge to have a static-shaped result. ")
<< "type: " << getType();
}
if (!getInput().hasOneUse()) {
return emitOpError(
"expected input of sdy.data_flow_edge to have a single user");
@@ -665,8 +682,8 @@ LogicalResult verifyManualComputationValue(
for (auto [valueIndex, valueEntry] : llvm::enumerate(llvm::zip_equal(
globalTypes, localTypes, shardingPerValueAttr.getShardings()))) {
auto [globalType, localType, sharding] = valueEntry;
auto globalRankedType = globalType.template cast<RankedTensorType>();
auto localRankedType = localType.template cast<RankedTensorType>();
auto globalRankedType = cast<RankedTensorType>(globalType);
auto localRankedType = cast<RankedTensorType>(localType);

// 5. Verify the manual axes come before any free axes in each dim sharding.
for (auto [dim, dimSharding] :
@@ -693,7 +710,7 @@ LogicalResult verifyManualComputationValue(
accumulatedManualAxesSize(op, dimSharding.getAxes(),
manualAxes, mesh));
}
RankedTensorType expectedLocalRankedType =
auto expectedLocalRankedType =
RankedTensorType::get(newDimSizes, globalRankedType.getElementType());
if (expectedLocalRankedType != localRankedType) {
return op->emitOpError(valueKindStr)
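
A condensed restatement of the type checks verifyTensorShardingAttr now performs; this is a sketch only (the helper name and the simplified error handling are not from the patch):

#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/LogicalResult.h"

// Non-shaped types (e.g. !stablehlo.token) may carry at most a rank-0
// sharding; shaped types must be static-shaped, and the sharding rank must
// equal the tensor rank.
static mlir::LogicalResult checkShardingShape(mlir::Type type,
                                              int64_t shardingRank) {
  auto shapedType = mlir::dyn_cast<mlir::ShapedType>(type);
  if (!shapedType)
    return mlir::success(shardingRank == 0);
  if (!shapedType.hasStaticShape())
    return mlir::failure();
  return mlir::success(shardingRank == shapedType.getRank());
}
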
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
@@ -58,8 +59,7 @@ namespace {
// - [{"y","x"}] : tensor<4xf32> -> [{"y","x":(1)2}] : tensor<4xf32>
// See update_non_divisible_input_output_shardings.mlir for more examples.
TensorShardingAttr getEvenlySharded(TensorShardingAttr sharding,
RankedTensorType type,
func::FuncOp funcOp) {
ShapedType type, func::FuncOp funcOp) {
StringRef meshName = sharding.getMeshName();
MeshAttr mesh = getMeshAttr(funcOp, meshName);
assert(mesh && "unknown mesh");
@@ -130,7 +130,7 @@ void updateValueShardings(
func::FuncOp funcOp) {
for (auto [index, type] : llvm::enumerate(types)) {
TensorShardingAttr sharding = getSharding(index);
if (auto tensorType = dyn_cast<RankedTensorType>(type);
if (auto tensorType = dynCastStaticShapedType(type);
sharding && tensorType) {
setSharding(index, getEvenlySharded(sharding, tensorType, funcOp));
}
4 changes: 2 additions & 2 deletions shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc
@@ -47,8 +47,8 @@ struct AddDataFlowEdgesPass
ValueRange edgeRoots = getDataFlowEdgeRoots(op);
rewriter.setInsertionPointAfter(op);
for (Value edgeRoot : edgeRoots) {
if (!isa<RankedTensorType>(edgeRoot.getType())) {
// Skip non-tensor values, e.g., tokens.
if (!isStaticShapedType(edgeRoot.getType())) {
// Skip non-static-shaped tensors, e.g., tokens.
continue;
}
TensorShardingAttr sharding = nullptr;
10 changes: 10 additions & 0 deletions shardy/dialect/sdy/transforms/import/test/add_data_flow_edges.mlir
@@ -66,6 +66,16 @@ func.func @optimization_barrier(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf3
return %0#0, %0#1 : tensor<32x96xf32>, tensor<32x96xf32>
}

// CHECK-LABEL: func @optimization_barrier_dynamic_shaped_tensor_skipped
func.func @optimization_barrier_dynamic_shaped_tensor_skipped(%arg0: tensor<32x96xf32>, %arg1: tensor<?x?xf32>)
-> (tensor<32x96xf32>, tensor<?x?xf32>) {
// CHECK-NEXT: %[[OPT_BARRIER:.*]]:2 = stablehlo.optimization_barrier %arg0, %arg1
// CHECK: %[[EDGE_1:.*]] = sdy.data_flow_edge %[[OPT_BARRIER]]#0
// CHECK-NEXT: return %[[EDGE_1]], %[[OPT_BARRIER]]#1
%0:2 = stablehlo.optimization_barrier %arg0, %arg1 : tensor<32x96xf32>, tensor<?x?xf32>
return %0#0, %0#1 : tensor<32x96xf32>, tensor<?x?xf32>
}

// CHECK-LABEL: func @while_unused_result
func.func @while_unused_result(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> {
// CHECK: %[[C0:.*]] = stablehlo.constant dense<0>
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
@@ -45,7 +46,6 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/data_flow_utils.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h"
#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h"
@@ -328,9 +328,9 @@ LogicalResult propagateFuncResults(FuncOp funcOp,
const FactorPropagation& factorPropagation) {
for (OpOperand& returnOperand : getBodyTerminatorOpOperands(funcOp)) {
Value returnValue = returnOperand.get();
auto tensorType = dyn_cast<RankedTensorType>(returnValue.getType());
auto tensorType = dynCastStaticShapedType(returnValue.getType());
if (!tensorType) {
// Skip non-tensor values, e.g., tokens.
// Skip non-static-shaped tensors, e.g., tokens.
continue;
}
int64_t resNum = returnOperand.getOperandNumber();
@@ -436,7 +436,7 @@ class PropagateDataFlowEdgeOp : public OpRewritePattern<DataFlowEdgeOp> {
return propagateTensorShardings(
sources, dataFlowEdgeOp.getResult(),
createIdentityShardingRule(
cast<RankedTensorType>(dataFlowEdgeOp.getType()), sources.size()),
cast<ShapedType>(dataFlowEdgeOp.getType()), sources.size()),
dataFlowEdgeOp, rewriter, factorPropagation);
}

Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ limitations under the License.
#include <optional>

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@@ -87,12 +88,12 @@ OpShardingRuleBuilder::OpShardingRuleBuilder(
resultMappings.reserve(resultTypes.size());
int64_t maxRank = 0;
for (Type operandType : operandTypes) {
int64_t rank = cast<RankedTensorType>(operandType).getRank();
int64_t rank = cast<ShapedType>(operandType).getRank();
maxRank = std::max(maxRank, rank);
operandMappings.push_back(TensorMapping(rank));
}
for (Type resultType : resultTypes) {
int64_t rank = cast<RankedTensorType>(resultType).getRank();
int64_t rank = cast<ShapedType>(resultType).getRank();
maxRank = std::max(maxRank, rank);
resultMappings.push_back(TensorMapping(rank));
}
@@ -125,7 +126,7 @@ OpShardingRuleAttr OpShardingRuleBuilder::build() {
OpShardingRuleAttr OpShardingRuleBuilder::buildPointwise(Operation* op) {
// All results should have the same shape, so we look at the first.
ArrayRef<int64_t> shape =
cast<RankedTensorType>(op->getResultTypes().front()).getShape();
cast<ShapedType>(op->getResultTypes().front()).getShape();

OpShardingRuleBuilder builder(op);

@@ -200,7 +201,7 @@ OpShardingRuleBuilder& OpShardingRuleBuilder::addPointwiseIfDimSizesMatch(
return *this;
}

OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type,
OpShardingRuleAttr createIdentityShardingRule(ShapedType type,
size_t numOperands,
size_t numResults) {
return OpShardingRuleBuilder(SmallVector<Type>(numOperands, type),
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ limitations under the License.
#include <functional>
#include <optional>

#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@@ -118,7 +119,7 @@ class OpShardingRuleBuilder {
// i.e., all operands/results have the same mapping.
//
// NOTE: an empty rule {([])->([])} will be created for scalar ops.
OpShardingRuleAttr createIdentityShardingRule(RankedTensorType type,
OpShardingRuleAttr createIdentityShardingRule(ShapedType type,
size_t numOperands = 1,
size_t numResults = 1);
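
A small usage sketch for the relaxed signature, assuming the mlir::sdy namespace and some value `result` whose type may now be any static-shaped ShapedType rather than specifically a RankedTensorType:

// Illustrative only: build an identity sharding rule for a binary pointwise
// op whose operands and result all share the same static-shaped type.
mlir::sdy::OpShardingRuleAttr rule = mlir::sdy::createIdentityShardingRule(
    mlir::cast<mlir::ShapedType>(result.getType()), /*numOperands=*/2,
    /*numResults=*/1);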

Original file line number Diff line number Diff line change
@@ -581,17 +581,28 @@ func.func @func_out_sharding(%arg0: tensor<8x8xf32>, %arg1: tensor<8x16xf32>)
return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @token_func_output_token_skipped(
// CHECK-LABEL: func @token_func_output_skipped(
// CHECK-SAME: %arg0: !stablehlo.token,
// CHECK-SAME: %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>})
// CHECK-SAME: -> (!stablehlo.token, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) {
func.func @token_func_output_token_skipped(%arg0: !stablehlo.token, %arg1: tensor<8x16xf32>)
func.func @token_func_output_skipped(%arg0: !stablehlo.token, %arg1: tensor<8x16xf32>)
-> (!stablehlo.token, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) {
// CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>]>}
%0 = stablehlo.add %arg1, %arg1 : tensor<8x16xf32>
return %arg0, %0 : !stablehlo.token, tensor<8x16xf32>
}

// CHECK-LABEL: func @dynamic_shaped_func_output_skipped(
// CHECK-SAME: %arg0: tensor<?x?xf32>,
// CHECK-SAME: %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>})
// CHECK-SAME: -> (tensor<?x?xf32>, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) {
func.func @dynamic_shaped_func_output_skipped(%arg0: tensor<?x?xf32>, %arg1: tensor<8x16xf32>)
-> (tensor<?x?xf32>, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}, {"b"}]>}) {
// CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>]>}
%0 = stablehlo.add %arg1, %arg1 : tensor<8x16xf32>
return %arg0, %0 : tensor<?x?xf32>, tensor<8x16xf32>
}

// CHECK-LABEL: func @func_result_intermediate_op_both_updated(
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>})
// CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a", ?}, {"b", ?}]>}) {