Skip to content

[flang][acc] Ensure fir.class is handled in type categorization #146174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,10 @@ static bool isArrayLike(mlir::Type type) {
}

static bool isCompositeLike(mlir::Type type) {
// class(*) is not a composite type since it does not have a determined type.
if (fir::isUnlimitedPolymorphicType(type))
return false;

return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type);
}

Expand All @@ -320,8 +324,18 @@ template <>
mlir::acc::VariableTypeCategory
OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
mlir::Value var) const {
// Class-type does not behave like a normal box because it does not hold an
// element type. Thus special handle it here.
if (mlir::isa<fir::ClassType>(type)) {
// class(*) is not a composite type since it does not have a determined
// type.
if (fir::isUnlimitedPolymorphicType(type))
return mlir::acc::VariableTypeCategory::uncategorized;
return mlir::acc::VariableTypeCategory::composite;
}

mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
assert(eleTy && "expect to be able to unwrap the element type");

// If the type enclosed by the box is a mappable type, then have it
// provide the type category.
Expand All @@ -346,7 +360,7 @@ OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
return mlir::acc::VariableTypeCategory::nonscalar;
}

static mlir::TypedValue<mlir::acc::PointerLikeType>
static mlir::Value
getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// If there is no defining op - the unwrapped reference is the base one.
mlir::Operation *op = varPtr.getDefiningOp();
Expand All @@ -372,7 +386,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
})
.Default([&](mlir::Operation *) { return varPtr; });

return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
return baseRef;
}

static mlir::acc::VariableTypeCategory
Expand All @@ -384,10 +398,17 @@ categorizePointee(mlir::Type pointer,
// value would both be represented as !fir.ref<f32>. We do not want to treat
// such a reference as a scalar. Thus unwrap interior pointer calculations.
auto baseRef = getBaseRef(varPtr);
mlir::Type eleTy = baseRef.getType().getElementType();

if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableTy.getTypeCategory(varPtr);
if (auto mappableTy =
mlir::dyn_cast<mlir::acc::MappableType>(baseRef.getType()))
return mappableTy.getTypeCategory(baseRef);

// It must be a pointer-like type since it is not a MappableType.
auto ptrLikeTy = mlir::cast<mlir::acc::PointerLikeType>(baseRef.getType());
mlir::Type eleTy = ptrLikeTy.getElementType();

if (auto mappableEleTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableEleTy.getTypeCategory(varPtr);

if (isScalarLike(eleTy))
return mlir::acc::VariableTypeCategory::scalar;
Expand All @@ -397,8 +418,12 @@ categorizePointee(mlir::Type pointer,
return mlir::acc::VariableTypeCategory::composite;
if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy))
return mlir::acc::VariableTypeCategory::nonscalar;
// Assumed-type (type(*))does not have a determined type that can be
// categorized.
if (mlir::isa<mlir::NoneType>(eleTy))
return mlir::acc::VariableTypeCategory::uncategorized;
// "pointers" - in the sense of raw address point-of-view, are considered
// scalars. However
// scalars.
if (mlir::isa<fir::LLVMPointerType>(eleTy))
return mlir::acc::VariableTypeCategory::scalar;

Expand Down
46 changes: 46 additions & 0 deletions flang/test/Fir/OpenACC/openacc-type-categories-class.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s

module mm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test should probably go in flang/test/Lower/OpenACC or starts from mlir directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I just saw that you added another test like that in this folder as well. No strong opinion on this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the test this way - changing FIR directly, especially for derived types - is not quite the nicest experience due to the verbosity. I hope you can be OK with this!

type, public :: polyty
real :: field
end type
contains
subroutine init(this)
class(polyty), intent(inout) :: this
!$acc enter data copyin(this, this%field)
Comment on lines +9 to +10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you check if it works as well with an unlimited polymorphic dummy?

class(*), intent(inout) :: this
!$acc enter data copyin(this)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It did not - and now I fixed it as I described below.

end subroutine
subroutine init_assumed_type(var)
type(*), intent(inout) :: var
!$acc enter data copyin(var)
end subroutine
subroutine init_unlimited(this)
class(*), intent(inout) :: this
!$acc enter data copyin(this)
select type(this)
type is(real)
!$acc enter data copyin(this)
class is(polyty)
!$acc enter data copyin(this, this%field)
end select
end subroutine
end module

! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this", structured = false}
! CHECK: Mappable: !fir.class<!fir.type<_QMmmTpolyty{field:f32}>>
! CHECK: Type category: composite
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this%field", structured = false}
! CHECK: Pointer-like: !fir.ref<f32>
! CHECK: Type category: composite

! For unlimited polymorphic entities and assumed types - they effectively have
! no declared type. Thus the type categorizer cannot categorize it.
! CHECK: Visiting: {{.*}} = acc.copyin {{.*}} {name = "var", structured = false}
! CHECK: Pointer-like: !fir.ref<none>
! CHECK: Type category: uncategorized
! CHECK: Visiting: {{.*}} = acc.copyin {{.*}} {name = "this", structured = false}
! CHECK: Mappable: !fir.class<none>
! CHECK: Type category: uncategorized

! TODO: After using select type - the appropriate type category should be
! possible. Add the rest of the test once OpenACC lowering correctly handles
! unlimited polymorhic.
9 changes: 9 additions & 0 deletions flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Support/DataLayout.h"

using namespace mlir;
Expand All @@ -25,6 +29,11 @@ struct TestFIROpenACCInterfaces
StringRef getDescription() const final {
return "Test FIR implementation of the OpenACC interfaces.";
}
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<fir::FIROpsDialect, hlfir::hlfirDialect,
mlir::arith::ArithDialect, mlir::acc::OpenACCDialect,
mlir::DLTIDialect>();
}
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
auto datalayout =
Expand Down
Loading