Skip to content

Commit db4de47

Browse files
committed
Lower affine modulo by powers of two using bitwise AND
1 parent bb982e7 commit db4de47

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,27 @@ class AffineApplyExpander
8080
/// let remainder = srem a, b;
8181
/// negative = a < 0 in
8282
/// select negative, remainder + b, remainder.
83+
///
84+
/// Special case for a power-of-2 modulus n: use bitwise AND (x & (n-1)). In
85+
/// two's complement this also holds for negative x, so no select is needed.
8386
Value visitModExpr(AffineBinaryOpExpr expr) {
8487
if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
8588
if (rhsConst.getValue() <= 0) {
8689
emitError(loc, "modulo by non-positive value is not supported");
8790
return nullptr;
8891
}
92+
93+
// Special case: x mod n where n is a power of 2 can be optimized to x &
94+
// (n-1)
95+
int64_t rhsValue = rhsConst.getValue();
96+
if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
97+
auto lhs = visit(expr.getLHS());
98+
assert(lhs && "unexpected affine expr lowering failure");
99+
100+
Value maskCst =
101+
builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
102+
return builder.create<arith::AndIOp>(loc, lhs, maskCst);
103+
}
89104
}
90105

91106
auto lhs = visit(expr.getLHS());

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,12 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
927927
// CHECK: scf.reduce.return %[[RES]] : i64
928928
// CHECK: }
929929
// CHECK: }
930+
931+
// Modulo by a power-of-two constant (8): the expected lowering of the
// affine.apply below is a constant mask of 7 followed by a single arith.andi.
#map_mod_8 = affine_map<(i) -> (i mod 8)>
932+
// CHECK-LABEL: func @affine_apply_mod_8
933+
func.func @affine_apply_mod_8(%arg0 : index) -> (index) {
934+
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
935+
// CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index
936+
%0 = affine.apply #map_mod_8 (%arg0)
937+
return %0 : index
938+
}

0 commit comments

Comments
 (0)