@@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1177
1177
if (flatExprs[numDims + numSymbols + it.index ()] == 0 )
1178
1178
continue ;
1179
1179
AffineExpr expr = it.value ();
1180
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
1181
- if (!binaryExpr)
1182
- continue ;
1183
-
1180
+ // Local expression cannot be a dimension, symbol or a constant -- it
1181
+ // should be a binary op expression.
1182
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
1184
1183
AffineExpr lhs = binaryExpr.getLHS ();
1185
1184
AffineExpr rhs = binaryExpr.getRHS ();
1186
1185
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1295,7 +1294,23 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1295
1294
localExprs, context);
1296
1295
AffineExpr b = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1297
1296
localExprs, context);
1298
- return addLocalVariableSemiAffine (mulLhs, rhs, a * b, lhs, lhs.size ());
1297
+ AffineExpr mulExpr = a * b;
1298
+ if (auto constMulExpr = dyn_cast<AffineConstantExpr>(mulExpr)) {
1299
+ std::fill (lhs.begin (), lhs.end (), 0 );
1300
+ lhs[getConstantIndex ()] = constMulExpr.getValue ();
1301
+ return success ();
1302
+ }
1303
+ if (auto dimMulExpr = dyn_cast<AffineDimExpr>(mulExpr)) {
1304
+ std::fill (lhs.begin (), lhs.end (), 0 );
1305
+ lhs[getDimStartIndex () + dimMulExpr.getPosition ()] = 1 ;
1306
+ return success ();
1307
+ }
1308
+ if (auto symbolMulExpr = dyn_cast<AffineSymbolExpr>(mulExpr)) {
1309
+ std::fill (lhs.begin (), lhs.end (), 0 );
1310
+ lhs[getSymbolStartIndex () + symbolMulExpr.getPosition ()] = 1 ;
1311
+ return success ();
1312
+ }
1313
+ return addLocalVariableSemiAffine (mulLhs, rhs, mulExpr, lhs, lhs.size ());
1299
1314
}
1300
1315
1301
1316
// Get the RHS constant.
@@ -1348,6 +1363,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1348
1363
AffineExpr divisorExpr = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1349
1364
localExprs, context);
1350
1365
AffineExpr modExpr = dividendExpr % divisorExpr;
1366
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1367
+ std::fill (lhs.begin (), lhs.end (), 0 );
1368
+ lhs[getConstantIndex ()] = constModExpr.getValue ();
1369
+ return success ();
1370
+ }
1371
+ if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
1372
+ std::fill (lhs.begin (), lhs.end (), 0 );
1373
+ lhs[getDimStartIndex () + dimModExpr.getPosition ()] = 1 ;
1374
+ return success ();
1375
+ }
1376
+ if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
1377
+ std::fill (lhs.begin (), lhs.end (), 0 );
1378
+ lhs[getSymbolStartIndex () + symbolModExpr.getPosition ()] = 1 ;
1379
+ return success ();
1380
+ }
1351
1381
return addLocalVariableSemiAffine (modLhs, rhs, modExpr, lhs, lhs.size ());
1352
1382
}
1353
1383
@@ -1482,6 +1512,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1482
1512
AffineExpr b = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1483
1513
localExprs, context);
1484
1514
AffineExpr divExpr = isCeil ? a.ceilDiv (b) : a.floorDiv (b);
1515
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1516
+ std::fill (lhs.begin (), lhs.end (), 0 );
1517
+ lhs[getConstantIndex ()] = constDivExpr.getValue ();
1518
+ return success ();
1519
+ }
1520
+ if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
1521
+ std::fill (lhs.begin (), lhs.end (), 0 );
1522
+ lhs[getDimStartIndex () + dimDivExpr.getPosition ()] = 1 ;
1523
+ return success ();
1524
+ }
1525
+ if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
1526
+ std::fill (lhs.begin (), lhs.end (), 0 );
1527
+ lhs[getSymbolStartIndex () + symbolDivExpr.getPosition ()] = 1 ;
1528
+ return success ();
1529
+ }
1485
1530
return addLocalVariableSemiAffine (divLhs, rhs, divExpr, lhs, lhs.size ());
1486
1531
}
1487
1532
0 commit comments