diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 79012dbd32f80..7992e561895d8 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3133,6 +3133,16 @@ struct ParallelOpSingleOrZeroIterationDimsFolder newSteps, op.getInitVals(), nullptr); // Erase the empty block that was inserted by the builder. rewriter.eraseBlock(newOp.getBody()); + + // The new ParallelOp needs to keep all attributes from the old one, except + // for "operandSegmentSizes" which will be outdated. + for (const auto &namedAttr : op->getAttrs()) { + if (namedAttr.getName() == ParallelOp::getOperandSegmentSizeAttr()) + continue; + rewriter.modifyOpInPlace(newOp, [&]() { + newOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + }); + } // Clone the loop body and remap the block arguments of the collapsed loops // (inlining does not support a cancellable block argument mapping). rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 5e32a3a78c032..41f608f8f0f30 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -94,6 +94,38 @@ func.func @single_iteration_reduce(%A: index, %B: index) -> (index, index) { // ----- +func.func @single_iteration_with_attributes(%A: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c10 = arith.constant 10 : index + scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) { + %c42 = arith.constant 42 : i32 + memref.store %c42, %A[%i0, %i1, %i2] : memref + scf.reduce + } {some_attr} + return +} + +// CHECK-LABEL: func @single_iteration_with_attributes( +// CHECK-SAME: [[ARG0:%.*]]: memref) { +// CHECK-DAG: [[C42:%.*]] = arith.constant 42 : i32 +// CHECK-DAG: [[C7:%.*]] = arith.constant 7 : index +// CHECK-DAG: [[C6:%.*]] = arith.constant 6 : index +// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK: scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) { +// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref +// CHECK: scf.reduce +// CHECK: } {some_attr} +// CHECK: return + +// ----- + func.func @nested_parallel(%0: memref) -> memref { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index