diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index c8d9761511bec..cc81f9d19aca7 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef flatExprs, // the indices in `coefficients` map, and affine expression corresponding to // in indices in `indexToExprMap` map. for (const auto &it : llvm::enumerate(localExprs)) { - AffineExpr expr = it.value(); if (flatExprs[numDims + numSymbols + it.index()] == 0) continue; - AffineExpr lhs = cast(expr).getLHS(); - AffineExpr rhs = cast(expr).getRHS(); + AffineExpr expr = it.value(); + auto binaryExpr = dyn_cast(expr); + if (!binaryExpr) + continue; + + AffineExpr lhs = binaryExpr.getLHS(); + AffineExpr rhs = binaryExpr.getRHS(); if (!((isa(lhs) || isa(lhs)) && (isa(rhs) || isa(rhs) || isa(rhs)))) { diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir index e4a8512b002ee..6f2737a982752 100644 --- a/mlir/test/Dialect/Affine/simplify-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-structures.mlir @@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in //CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}] return %a : index } + +// ----- + +// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const +func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor) -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13 = arith.constant 13 : index + %dim = tensor.dim %arg0, %c0 : tensor + %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1] + %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1] + // CHECK: %[[C6:.*]] = arith.constant 6 : index + // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index + // CHECK-NEXT: return %[[C6]], %[[C7]] + return %a, %b : index, index +}