Skip to content

Commit 17ce1f0

Browse files
Fix bug in visitDivExpr, visitMulExpr and visitModExpr
Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.
1 parent 529662a commit 17ce1f0

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
11771177
if (flatExprs[numDims + numSymbols + it.index()] == 0)
11781178
continue;
11791179
AffineExpr expr = it.value();
1180-
auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181-
if (!binaryExpr)
1182-
continue;
1183-
1180+
// Local expression cannot be a dimension, symbol or a constant -- it
1181+
// should be a binary op expression.
1182+
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
11841183
AffineExpr lhs = binaryExpr.getLHS();
11851184
AffineExpr rhs = binaryExpr.getRHS();
11861185
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1295,7 +1294,23 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
12951294
localExprs, context);
12961295
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
12971296
localExprs, context);
1298-
return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
1297+
AffineExpr mulExpr = a * b;
1298+
if (auto constMulExpr = dyn_cast<AffineConstantExpr>(mulExpr)) {
1299+
std::fill(lhs.begin(), lhs.end(), 0);
1300+
lhs[getConstantIndex()] = constMulExpr.getValue();
1301+
return success();
1302+
}
1303+
if (auto dimMulExpr = dyn_cast<AffineDimExpr>(mulExpr)) {
1304+
std::fill(lhs.begin(), lhs.end(), 0);
1305+
lhs[getDimStartIndex() + dimMulExpr.getPosition()] = 1;
1306+
return success();
1307+
}
1308+
if (auto symbolMulExpr = dyn_cast<AffineSymbolExpr>(mulExpr)) {
1309+
std::fill(lhs.begin(), lhs.end(), 0);
1310+
lhs[getSymbolStartIndex() + symbolMulExpr.getPosition()] = 1;
1311+
return success();
1312+
}
1313+
return addLocalVariableSemiAffine(mulLhs, rhs, mulExpr, lhs, lhs.size());
12991314
}
13001315

13011316
// Get the RHS constant.
@@ -1348,6 +1363,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
13481363
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
13491364
localExprs, context);
13501365
AffineExpr modExpr = dividendExpr % divisorExpr;
1366+
if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1367+
std::fill(lhs.begin(), lhs.end(), 0);
1368+
lhs[getConstantIndex()] = constModExpr.getValue();
1369+
return success();
1370+
}
1371+
if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
1372+
std::fill(lhs.begin(), lhs.end(), 0);
1373+
lhs[getDimStartIndex() + dimModExpr.getPosition()] = 1;
1374+
return success();
1375+
}
1376+
if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
1377+
std::fill(lhs.begin(), lhs.end(), 0);
1378+
lhs[getSymbolStartIndex() + symbolModExpr.getPosition()] = 1;
1379+
return success();
1380+
}
13511381
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
13521382
}
13531383

@@ -1482,6 +1512,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
14821512
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
14831513
localExprs, context);
14841514
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1515+
if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1516+
std::fill(lhs.begin(), lhs.end(), 0);
1517+
lhs[getConstantIndex()] = constDivExpr.getValue();
1518+
return success();
1519+
}
1520+
if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
1521+
std::fill(lhs.begin(), lhs.end(), 0);
1522+
lhs[getDimStartIndex() + dimDivExpr.getPosition()] = 1;
1523+
return success();
1524+
}
1525+
if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
1526+
std::fill(lhs.begin(), lhs.end(), 0);
1527+
lhs[getSymbolStartIndex() + symbolDivExpr.getPosition()] = 1;
1528+
return success();
1529+
}
14851530
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
14861531
}
14871532

0 commit comments

Comments
 (0)