Skip to content

Commit f51da54

Browse files
Fix bug in visitDivExpr and visitModExpr
Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.
1 parent 529662a commit f51da54

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
11771177
if (flatExprs[numDims + numSymbols + it.index()] == 0)
11781178
continue;
11791179
AffineExpr expr = it.value();
1180-
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181-
if (!binaryExpr)
1182-
continue;
1183-
1180+
assert(isa<AffineBinaryOpExpr>(expr) &&
1181+
"local expression cannot be a dimension, symbol or a constant -- it "
1182+
"should be a binary op expression");
1183+
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
11841184
AffineExpr lhs = binaryExpr.getLHS();
11851185
AffineExpr rhs = binaryExpr.getRHS();
11861186
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13481348
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
13491349
localExprs, context);
13501350
AffineExpr modExpr = dividendExpr % divisorExpr;
1351+
if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1352+
std::fill(lhs.begin(), lhs.end(), 0);
1353+
lhs[getConstantIndex()] = constModExpr.getValue();
1354+
return success();
1355+
}
13511356
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
13521357
}
13531358

@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
14821487
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
14831488
localExprs, context);
14841489
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1490+
if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1491+
std::fill(lhs.begin(), lhs.end(), 0);
1492+
lhs[getConstantIndex()] = constDivExpr.getValue();
1493+
return success();
1494+
}
14851495
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
14861496
}
14871497

0 commit comments

Comments
 (0)