@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1177
1177
if (flatExprs[numDims + numSymbols + it.index ()] == 0 )
1178
1178
continue ;
1179
1179
AffineExpr expr = it.value ();
1180
- auto binaryExpr = dyn_cast <AffineBinaryOpExpr>(expr);
1181
- if (!binaryExpr)
1182
- continue ;
1183
-
1180
+ assert (isa <AffineBinaryOpExpr>(expr) &&
1181
+ " local expression cannot be a dimension, symbol or a constant -- it "
1182
+ " should be a binary op expression " ) ;
1183
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
1184
1184
AffineExpr lhs = binaryExpr.getLHS ();
1185
1185
AffineExpr rhs = binaryExpr.getRHS ();
1186
1186
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1348
1348
AffineExpr divisorExpr = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1349
1349
localExprs, context);
1350
1350
AffineExpr modExpr = dividendExpr % divisorExpr;
1351
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
1352
+ std::fill (lhs.begin (), lhs.end (), 0 );
1353
+ lhs[getConstantIndex ()] = constModExpr.getValue ();
1354
+ return success ();
1355
+ }
1351
1356
return addLocalVariableSemiAffine (modLhs, rhs, modExpr, lhs, lhs.size ());
1352
1357
}
1353
1358
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1482
1487
AffineExpr b = getAffineExprFromFlatForm (rhs, numDims, numSymbols,
1483
1488
localExprs, context);
1484
1489
AffineExpr divExpr = isCeil ? a.ceilDiv (b) : a.floorDiv (b);
1490
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
1491
+ std::fill (lhs.begin (), lhs.end (), 0 );
1492
+ lhs[getConstantIndex ()] = constDivExpr.getValue ();
1493
+ return success ();
1494
+ }
1485
1495
return addLocalVariableSemiAffine (divLhs, rhs, divExpr, lhs, lhs.size ());
1486
1496
}
1487
1497
0 commit comments