|
19 | 19 | #include "mlir/Support/TypeID.h"
|
20 | 20 | #include "llvm/ADT/STLExtras.h"
|
21 | 21 | #include "llvm/Support/MathExtras.h"
|
| 22 | +#include "llvm/Support/raw_ostream.h" |
22 | 23 | #include <numeric>
|
23 | 24 | #include <optional>
|
24 | 25 |
|
@@ -1177,10 +1178,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
|
1177 | 1178 | if (flatExprs[numDims + numSymbols + it.index()] == 0)
|
1178 | 1179 | continue;
|
1179 | 1180 | AffineExpr expr = it.value();
|
1180 |
| - auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr); |
1181 |
| - if (!binaryExpr) |
1182 |
| - continue; |
1183 |
| - |
| 1181 | + // A Local expression cannot be a dimension, symbol or a constant -- it |
| 1182 | + // should be a binary op expression. |
| 1183 | + auto binaryExpr = cast<AffineBinaryOpExpr>(expr); |
1184 | 1184 | AffineExpr lhs = binaryExpr.getLHS();
|
1185 | 1185 | AffineExpr rhs = binaryExpr.getRHS();
|
1186 | 1186 | if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
|
@@ -1295,7 +1295,8 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
|
1295 | 1295 | localExprs, context);
|
1296 | 1296 | AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
1297 | 1297 | localExprs, context);
|
1298 |
| - return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size()); |
| 1298 | + AffineExpr mulExpr = a * b; |
| 1299 | + return addExprToFlattenedList(mulLhs, rhs, lhs, mulExpr); |
1299 | 1300 | }
|
1300 | 1301 |
|
1301 | 1302 | // Get the RHS constant.
|
@@ -1348,7 +1349,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
|
1348 | 1349 | AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
1349 | 1350 | localExprs, context);
|
1350 | 1351 | AffineExpr modExpr = dividendExpr % divisorExpr;
|
1351 |
| - return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size()); |
| 1352 | + return addExprToFlattenedList(modLhs, rhs, lhs, modExpr); |
1352 | 1353 | }
|
1353 | 1354 |
|
1354 | 1355 | int64_t rhsConst = rhs[getConstantIndex()];
|
@@ -1450,6 +1451,27 @@ LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
|
1450 | 1451 | return success();
|
1451 | 1452 | }
|
1452 | 1453 |
|
| 1454 | +LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList( |
| 1455 | + ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, |
| 1456 | + SmallVectorImpl<int64_t> &result, AffineExpr expr) { |
| 1457 | + if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { |
| 1458 | + std::fill(result.begin(), result.end(), 0); |
| 1459 | + result[getConstantIndex()] = constExpr.getValue(); |
| 1460 | + return success(); |
| 1461 | + } |
| 1462 | + if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { |
| 1463 | + std::fill(result.begin(), result.end(), 0); |
| 1464 | + result[getDimStartIndex() + dimExpr.getPosition()] = 1; |
| 1465 | + return success(); |
| 1466 | + } |
| 1467 | + if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) { |
| 1468 | + std::fill(result.begin(), result.end(), 0); |
| 1469 | + result[getSymbolStartIndex() + symExpr.getPosition()] = 1; |
| 1470 | + return success(); |
| 1471 | + } |
| 1472 | + return addLocalVariableSemiAffine(lhs, rhs, expr, result, result.size()); |
| 1473 | +} |
| 1474 | + |
1453 | 1475 | // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
|
1454 | 1476 | // A floordiv is thus flattened by introducing a new local variable q, and
|
1455 | 1477 | // replacing that expression with 'q' while adding the constraints
|
@@ -1482,7 +1504,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
|
1482 | 1504 | AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
|
1483 | 1505 | localExprs, context);
|
1484 | 1506 | AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
|
1485 |
| - return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size()); |
| 1507 | + return addExprToFlattenedList(divLhs, rhs, lhs, divExpr); |
1486 | 1508 | }
|
1487 | 1509 |
|
1488 | 1510 | // This is a pure affine expr; the RHS is a positive constant.
|
|
0 commit comments