@@ -87,7 +87,7 @@ class OpShardingRuleBuilder {
87
87
FactorType factorType = FactorType::kPassThrough );
88
88
89
89
// Adds a pointwise factor for all dimensions of all operands/results that
90
- // have rank at least 1. The factor type is determined by `predFactorType `.
90
+ // have rank at least 1. The factor type is determined by `getFactorType `.
91
91
OpShardingRuleBuilder& addPointwise (
92
92
ArrayRef<int64_t > shape,
93
93
std::function<FactorType(int64_t )> getFactorType = [](int64_t ) {
@@ -96,7 +96,7 @@ class OpShardingRuleBuilder {
96
96
97
97
// Adds a pointwise factor for all dimensions that satisfy `pred` of all
98
98
// operands/results that have rank at least 1. The factor type is determined
99
- // by `predFactorType `.
99
+ // by `getFactorType `.
100
100
OpShardingRuleBuilder& addPointwiseIf (
101
101
ArrayRef<int64_t > shape, std::function<bool (int64_t )> pred,
102
102
std::function<FactorType(int64_t )> getFactorType = [](int64_t ) {
@@ -113,7 +113,7 @@ class OpShardingRuleBuilder {
113
113
onMismatchFn = [](int64_t dim, OpShardingRuleBuilder& builder) {});
114
114
115
115
// Adds a pointwise factor for all dimensions of all operands/results that
116
- // have rank at least 1. The factor type is determined by `predFactorType`.
116
+ // have rank at least 1.
117
117
//
118
118
// Each dimension whose size in `inShape` and `outShape` is different, gets a
119
119
// `mismatchFactorType` factor type.
0 commit comments