diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h index a6e47473..55acd77d 100644 --- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h +++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_builder.h @@ -87,7 +87,7 @@ class OpShardingRuleBuilder { FactorType factorType = FactorType::kPassThrough); // Adds a pointwise factor for all dimensions of all operands/results that - // have rank at least 1. The factor type is determined by `predFactorType`. + // have rank at least 1. The factor type is determined by `getFactorType`. OpShardingRuleBuilder& addPointwise( ArrayRef shape, std::function getFactorType = [](int64_t) { @@ -96,7 +96,7 @@ class OpShardingRuleBuilder { // Adds a pointwise factor for all dimensions that satisfy `pred` of all // operands/results that have rank at least 1. The factor type is determined - // by `predFactorType`. + // by `getFactorType`. OpShardingRuleBuilder& addPointwiseIf( ArrayRef shape, std::function pred, std::function getFactorType = [](int64_t) { @@ -113,7 +113,7 @@ class OpShardingRuleBuilder { onMismatchFn = [](int64_t dim, OpShardingRuleBuilder& builder) {}); // Adds a pointwise factor for all dimensions of all operands/results that - // have rank at least 1. The factor type is determined by `predFactorType`. + // have rank at least 1. // // Each dimension whose size in `inShape` and `outShape` is different, gets a // `mismatchFactorType` factor type.