19
19
#include " mlir/Support/TypeID.h"
20
20
#include " llvm/ADT/STLExtras.h"
21
21
#include " llvm/Support/MathExtras.h"
22
+ #include " llvm/Support/raw_ostream.h"
22
23
#include < numeric>
23
24
#include < optional>
24
25
@@ -1177,10 +1178,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1177
1178
if (flatExprs[numDims + numSymbols + it.index ()] == 0 )
1178
1179
continue ;
1179
1180
AffineExpr expr = it.value ();
1180
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181
- if (!binaryExpr)
1182
- continue ;
1183
-
1181
+ // A Local expression cannot be a dimension, symbol or a constant -- it
1182
+ // should be a binary op expression.
1183
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
1184
1184
AffineExpr lhs = binaryExpr.getLHS ();
1185
1185
AffineExpr rhs = binaryExpr.getRHS ();
1186
1186
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1274,6 +1274,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1274
1274
operandExprStack.reserve (8 );
1275
1275
}
1276
1276
1277
+ LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList (
1278
+ ArrayRef<int64_t > lhs, ArrayRef<int64_t > rhs,
1279
+ SmallVectorImpl<int64_t > &result, AffineExpr expr) {
1280
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
1281
+ std::fill (result.begin (), result.end (), 0 );
1282
+ result[getConstantIndex ()] = constExpr.getValue ();
1283
+ return success ();
1284
+ }
1285
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1286
+ std::fill (result.begin (), result.end (), 0 );
1287
+ result[getDimStartIndex () + dimExpr.getPosition ()] = 1 ;
1288
+ return success ();
1289
+ }
1290
+ if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
1291
+ std::fill (result.begin (), result.end (), 0 );
1292
+ result[getSymbolStartIndex () + symExpr.getPosition ()] = 1 ;
1293
+ return success ();
1294
+ }
1295
+ return addLocalVariableSemiAffine (lhs, rhs, expr, result, result.size ());
1296
+ }
1297
+
1277
1298
// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1278
1299
//
1279
1300
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
@@ -1295,7 +1316,8 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1295
1316
localExprs, context);
1296
1317
AffineExpr b = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1297
1318
localExprs, context);
1298
- return addLocalVariableSemiAffine (mulLhs, rhs, a * b, lhs, lhs.size ());
1319
+ AffineExpr mulExpr = a * b;
1320
+ return addExprToFlattenedList (mulLhs, rhs, lhs, mulExpr);
1299
1321
}
1300
1322
1301
1323
// Get the RHS constant.
@@ -1348,7 +1370,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1348
1370
AffineExpr divisorExpr = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1349
1371
localExprs, context);
1350
1372
AffineExpr modExpr = dividendExpr % divisorExpr;
1351
- return addLocalVariableSemiAffine (modLhs, rhs, modExpr, lhs, lhs. size () );
1373
+ return addExprToFlattenedList (modLhs, rhs, lhs, modExpr );
1352
1374
}
1353
1375
1354
1376
int64_t rhsConst = rhs[getConstantIndex ()];
@@ -1482,7 +1504,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1482
1504
AffineExpr b = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1483
1505
localExprs, context);
1484
1506
AffineExpr divExpr = isCeil ? a.ceilDiv (b) : a.floorDiv (b);
1485
- return addLocalVariableSemiAffine (divLhs, rhs, divExpr, lhs, lhs. size () );
1507
+ return addExprToFlattenedList (divLhs, rhs, lhs, divExpr );
1486
1508
}
1487
1509
1488
1510
// This is a pure affine expr; the RHS is a positive constant.
0 commit comments