diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 3dd35ed9ae481..027e8d29a0515 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -369,6 +369,55 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) { llvm::report_fatal_error("Could not find symbol"); } +static mlir::Value getBaseAddr(Fortran::semantics::Symbol &symbol, + const fir::factory::AddrAndBoundsInfo &info) { + if (Fortran::semantics::IsAssumedSizeArray(symbol)) { + // Assumed-size arrays in FIR are represented as: + // func.func @func(%arg0: !fir.ref> {fir.bindc_name = "arr"}) { + // %arr:2 = hlfir.declare %arg0(%shape) ... -> (!fir.box>, !fir.ref>) + // The `rawInput` refers to the #1 output of the `hlfir.declare` operation. + // This is preferred since the Fortran variable properties do not contain + // any useful size information. + return info.rawInput; + } + + if (Fortran::semantics::IsOptional(symbol)) { + // When there is an optional argument for which there is a possibility + // to create a descriptor, pick the rawInput instead. This is done to + // avoid materializing the descriptor, which leads to the following pattern + // being generated at the FIR level and adds an extra indirection that makes + // recovering the original variable non-obvious. + // This is the pattern we want to avoid generating: + // %1 = fir.declare %arg0 ... {fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub1Eassumedshapeoptarr"} : (!fir.box>, !fir.dscope) -> !fir.box> + // %2 = fir.is_present %1 : (!fir.box>) -> i1 + // %3 = fir.if %2 -> (!fir.box>) { + // %5 = fir.rebox %1 : (!fir.box>) -> !fir.box> + // fir.result %5 : !fir.box> + // } else { + // %5 = fir.absent !fir.box> + // fir.result %5 : !fir.box> + // } + // %4 = acc.copyin var(%3 : !fir.box>) ... + // + // Instead, by picking the rawInput, we get the following pattern: + // %1 = fir.declare %arg0 ... 
{fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub1Eassumedshapeoptarr"} : (!fir.box>, !fir.dscope) -> !fir.box> + // %2 = acc.copyin var(%1 : !fir.box>) ... + if (fir::unwrapRefType(info.addr.getType()) != + fir::unwrapRefType(info.rawInput.getType())) { + return info.rawInput; + } + } + + // The `addr` field refers to the address of the Fortran entity, but with the + // SSA value that when lowered to FIR will include the tied Fortran variable + // properties. Additionally, in cases where `unwrapFirBox` is requested, + // it refers to the address of the data (either the result of fir.box_addr or + // the result of `fir.if` in case of optional). + // Therefore, use the processed address in all cases by default unless it was + // deemed through the earlier checks in this routine that it is not useful. + return info.addr; +} + template static void genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, @@ -399,13 +448,7 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, /*genDefaultBounds=*/generateDefaultBounds); LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); - // If the input value is optional and is not a descriptor, we use the - // rawInput directly. - mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) != - fir::unwrapRefType(info.rawInput.getType())) && - info.isPresent) - ? info.rawInput - : info.addr; + mlir::Value baseAddr = getBaseAddr(symbol, info); Op op = createDataEntryOp( builder, operandLocation, baseAddr, asFortran, bounds, structured, implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes, diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90 index 8fea357f116a2..e333cdff122a2 100644 --- a/flang/test/Lower/OpenACC/acc-bounds.f90 +++ b/flang/test/Lower/OpenACC/acc-bounds.f90 @@ -92,8 +92,7 @@ subroutine acc_undefined_extent(a) ! 
CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#0, %c0{{.*}} : (!fir.box>, index) -> (index, index, index) ! CHECK: %[[UB:.*]] = arith.subi %[[DIMS0]]#1, %c1{{.*}} : index ! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%[[UB]] : index) extent(%[[DIMS0]]#1 : index) stride(%[[DIMS0]]#2 : index) startIdx(%c1{{.*}} : index) {strideInBytes = true} -! CHECK: %[[ADDR:.*]] = fir.box_addr %[[DECL_ARG0]]#0 : (!fir.box>) -> !fir.ref> -! CHECK: %[[PRESENT:.*]] = acc.present varPtr(%[[ADDR]] : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a"} +! CHECK: %[[PRESENT:.*]] = acc.present varPtr(%[[DECL_ARG0]]#1 : !fir.ref>) bounds(%[[BOUND]]) -> !fir.ref> {name = "a"} ! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref>) subroutine acc_multi_strides(a)