diff --git a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp index 673d7e86c7ba0..317a41a2129c3 100644 --- a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp @@ -306,6 +306,10 @@ static bool isArrayLike(mlir::Type type) { } static bool isCompositeLike(mlir::Type type) { + // class(*) is not a composite type since it does not have a determined type. + if (fir::isUnlimitedPolymorphicType(type)) + return false; + return mlir::isa(type); } @@ -320,8 +324,18 @@ template <> mlir::acc::VariableTypeCategory OpenACCMappableModel::getTypeCategory(mlir::Type type, mlir::Value var) const { + // Class-type does not behave like a normal box because it does not hold an + // element type. Thus special handle it here. + if (mlir::isa(type)) { + // class(*) is not a composite type since it does not have a determined + // type. + if (fir::isUnlimitedPolymorphicType(type)) + return mlir::acc::VariableTypeCategory::uncategorized; + return mlir::acc::VariableTypeCategory::composite; + } mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type); + assert(eleTy && "expect to be able to unwrap the element type"); // If the type enclosed by the box is a mappable type, then have it // provide the type category. @@ -346,7 +360,7 @@ OpenACCMappableModel::getTypeCategory(mlir::Type type, return mlir::acc::VariableTypeCategory::nonscalar; } -static mlir::TypedValue +static mlir::Value getBaseRef(mlir::TypedValue varPtr) { // If there is no defining op - the unwrapped reference is the base one. mlir::Operation *op = varPtr.getDefiningOp(); @@ -372,7 +386,7 @@ getBaseRef(mlir::TypedValue varPtr) { }) .Default([&](mlir::Operation *) { return varPtr; }); - return mlir::cast>(baseRef); + return baseRef; } static mlir::acc::VariableTypeCategory @@ -384,10 +398,17 @@ categorizePointee(mlir::Type pointer, // value would both be represented as !fir.ref. We do not want to treat // such a reference as a scalar. Thus unwrap interior pointer calculations. auto baseRef = getBaseRef(varPtr); - mlir::Type eleTy = baseRef.getType().getElementType(); - if (auto mappableTy = mlir::dyn_cast(eleTy)) - return mappableTy.getTypeCategory(varPtr); + if (auto mappableTy = + mlir::dyn_cast(baseRef.getType())) + return mappableTy.getTypeCategory(baseRef); + + // It must be a pointer-like type since it is not a MappableType. + auto ptrLikeTy = mlir::cast(baseRef.getType()); + mlir::Type eleTy = ptrLikeTy.getElementType(); + + if (auto mappableEleTy = mlir::dyn_cast(eleTy)) + return mappableEleTy.getTypeCategory(varPtr); if (isScalarLike(eleTy)) return mlir::acc::VariableTypeCategory::scalar; @@ -397,8 +418,12 @@ categorizePointee(mlir::Type pointer, return mlir::acc::VariableTypeCategory::composite; if (mlir::isa(eleTy)) return mlir::acc::VariableTypeCategory::nonscalar; + // Assumed-type (type(*))does not have a determined type that can be + // categorized. + if (mlir::isa(eleTy)) + return mlir::acc::VariableTypeCategory::uncategorized; // "pointers" - in the sense of raw address point-of-view, are considered - // scalars. However + // scalars. if (mlir::isa(eleTy)) return mlir::acc::VariableTypeCategory::scalar; diff --git a/flang/test/Fir/OpenACC/openacc-type-categories-class.f90 b/flang/test/Fir/OpenACC/openacc-type-categories-class.f90 new file mode 100644 index 0000000000000..58025bfa556a5 --- /dev/null +++ b/flang/test/Fir/OpenACC/openacc-type-categories-class.f90 @@ -0,0 +1,46 @@ +! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s + +module mm + type, public :: polyty + real :: field + end type +contains + subroutine init(this) + class(polyty), intent(inout) :: this + !$acc enter data copyin(this, this%field) + end subroutine + subroutine init_assumed_type(var) + type(*), intent(inout) :: var + !$acc enter data copyin(var) + end subroutine + subroutine init_unlimited(this) + class(*), intent(inout) :: this + !$acc enter data copyin(this) + select type(this) + type is(real) + !$acc enter data copyin(this) + class is(polyty) + !$acc enter data copyin(this, this%field) + end select + end subroutine +end module + +! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this", structured = false} +! CHECK: Mappable: !fir.class> +! CHECK: Type category: composite +! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this%field", structured = false} +! CHECK: Pointer-like: !fir.ref +! CHECK: Type category: composite + +! For unlimited polymorphic entities and assumed types - they effectively have +! no declared type. Thus the type categorizer cannot categorize it. +! CHECK: Visiting: {{.*}} = acc.copyin {{.*}} {name = "var", structured = false} +! CHECK: Pointer-like: !fir.ref +! CHECK: Type category: uncategorized +! CHECK: Visiting: {{.*}} = acc.copyin {{.*}} {name = "this", structured = false} +! CHECK: Mappable: !fir.class +! CHECK: Type category: uncategorized + +! TODO: After using select type - the appropriate type category should be +! possible. Add the rest of the test once OpenACC lowering correctly handles +! unlimited polymorhic. diff --git a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp index 90aabd7d40d44..e72b96fe7cd10 100644 --- a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp +++ b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp @@ -6,11 +6,15 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/Support/DataLayout.h" using namespace mlir; @@ -25,6 +29,11 @@ struct TestFIROpenACCInterfaces StringRef getDescription() const final { return "Test FIR implementation of the OpenACC interfaces."; } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } void runOnOperation() override { mlir::ModuleOp mod = getOperation(); auto datalayout =