Skip to content

[flang][acc] Ensure fir.class is handled in type categorization #146174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

razvanlupusoru
Copy link
Contributor

fir.class is treated similarly as fir.box - but it has one key distinction which is that it doesn't hold an element type. Thus the categorization logic was mishandling this case for this reason (and also the fact that it assumed that a base object is always a fir.ref).

This PR improves this handling and adds appropriate test exercising both a class and a class field to ensure categorization works.

fir.class is treated similarly as fir.box - but it has one key
distinction which is that it doesn't hold an element type. Thus
the categorization logic was mishandling this case for this reason
(and also the fact that it assumed that a base object is always a
fir.ref).

This PR improves this handling and adds appropriate test exercising
both a class and a class field to ensure categorization works.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir openacc labels Jun 27, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 27, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Razvan Lupusoru (razvanlupusoru)

Changes

fir.class is treated similarly as fir.box - but it has one key distinction which is that it doesn't hold an element type. Thus the categorization logic was mishandling this case for this reason (and also the fact that it assumed that a base object is always a fir.ref).

This PR improves this handling and adds appropriate test exercising both a class and a class field to ensure categorization works.


Full diff: https://github.com/llvm/llvm-project/pull/146174.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp (+17-5)
  • (added) flang/test/Fir/OpenACC/openacc-type-categories-class.f90 (+18)
  • (modified) flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp (+9)
diff --git a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
index 673d7e86c7ba0..2702f7e8c185e 100644
--- a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
@@ -320,8 +320,13 @@ template <>
 mlir::acc::VariableTypeCategory
 OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
                                                         mlir::Value var) const {
+  // Class-type does not behave like a normal box because it does not hold an
+  // element type. Thus special handle it here.
+  if (mlir::isa<fir::ClassType>(type))
+    return mlir::acc::VariableTypeCategory::composite;
 
   mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
+  assert(eleTy && "expect to be able to unwrap the element type");
 
   // If the type enclosed by the box is a mappable type, then have it
   // provide the type category.
@@ -346,7 +351,7 @@ OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
   return mlir::acc::VariableTypeCategory::nonscalar;
 }
 
-static mlir::TypedValue<mlir::acc::PointerLikeType>
+static mlir::Value
 getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
   // If there is no defining op - the unwrapped reference is the base one.
   mlir::Operation *op = varPtr.getDefiningOp();
@@ -372,7 +377,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
           })
           .Default([&](mlir::Operation *) { return varPtr; });
 
-  return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
+  return baseRef;
 }
 
 static mlir::acc::VariableTypeCategory
@@ -384,10 +389,17 @@ categorizePointee(mlir::Type pointer,
   // value would both be represented as !fir.ref<f32>. We do not want to treat
   // such a reference as a scalar. Thus unwrap interior pointer calculations.
   auto baseRef = getBaseRef(varPtr);
-  mlir::Type eleTy = baseRef.getType().getElementType();
 
-  if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
-    return mappableTy.getTypeCategory(varPtr);
+  if (auto mappableTy =
+          mlir::dyn_cast<mlir::acc::MappableType>(baseRef.getType()))
+    return mappableTy.getTypeCategory(baseRef);
+
+  // It must be a pointer-like type since it is not a MappableType.
+  auto ptrLikeTy = mlir::cast<mlir::acc::PointerLikeType>(baseRef.getType());
+  mlir::Type eleTy = ptrLikeTy.getElementType();
+
+  if (auto mappableEleTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
+    return mappableEleTy.getTypeCategory(varPtr);
 
   if (isScalarLike(eleTy))
     return mlir::acc::VariableTypeCategory::scalar;
diff --git a/flang/test/Fir/OpenACC/openacc-type-categories-class.f90 b/flang/test/Fir/OpenACC/openacc-type-categories-class.f90
new file mode 100644
index 0000000000000..0a38ab96a0315
--- /dev/null
+++ b/flang/test/Fir/OpenACC/openacc-type-categories-class.f90
@@ -0,0 +1,18 @@
+module mm
+  type, public :: polyty
+    real :: field
+  end type
+contains
+  subroutine init(this)
+    class(polyty), intent(inout) :: this
+    !$acc enter data copyin(this, this%field)
+  end subroutine
+end module
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this", structured = false}
+! CHECK: Mappable: !fir.class<!fir.type<_QMmmTpolyty{field:f32}>>
+! CHECK: Type category: composite
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this%field", structured = false}
+! CHECK: Pointer-like: !fir.ref<f32>
+! CHECK: Type category: composite
diff --git a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
index 90aabd7d40d44..11567d1c0c6a3 100644
--- a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
+++ b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
@@ -6,12 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/Support/DataLayout.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 
 using namespace mlir;
 
@@ -25,6 +29,11 @@ struct TestFIROpenACCInterfaces
   StringRef getDescription() const final {
     return "Test FIR implementation of the OpenACC interfaces.";
   }
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<fir::FIROpsDialect, hlfir::hlfirDialect,
+        mlir::arith::ArithDialect, mlir::acc::OpenACCDialect,
+        mlir::DLTIDialect>();
+  }
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
     auto datalayout =

@llvmbot
Copy link
Member

llvmbot commented Jun 27, 2025

@llvm/pr-subscribers-openacc

Author: Razvan Lupusoru (razvanlupusoru)

Changes

fir.class is treated similarly as fir.box - but it has one key distinction which is that it doesn't hold an element type. Thus the categorization logic was mishandling this case for this reason (and also the fact that it assumed that a base object is always a fir.ref).

This PR improves this handling and adds appropriate test exercising both a class and a class field to ensure categorization works.


Full diff: https://github.com/llvm/llvm-project/pull/146174.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp (+17-5)
  • (added) flang/test/Fir/OpenACC/openacc-type-categories-class.f90 (+18)
  • (modified) flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp (+9)
diff --git a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
index 673d7e86c7ba0..2702f7e8c185e 100644
--- a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
@@ -320,8 +320,13 @@ template <>
 mlir::acc::VariableTypeCategory
 OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
                                                         mlir::Value var) const {
+  // Class-type does not behave like a normal box because it does not hold an
+  // element type. Thus special handle it here.
+  if (mlir::isa<fir::ClassType>(type))
+    return mlir::acc::VariableTypeCategory::composite;
 
   mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
+  assert(eleTy && "expect to be able to unwrap the element type");
 
   // If the type enclosed by the box is a mappable type, then have it
   // provide the type category.
@@ -346,7 +351,7 @@ OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
   return mlir::acc::VariableTypeCategory::nonscalar;
 }
 
-static mlir::TypedValue<mlir::acc::PointerLikeType>
+static mlir::Value
 getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
   // If there is no defining op - the unwrapped reference is the base one.
   mlir::Operation *op = varPtr.getDefiningOp();
@@ -372,7 +377,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
           })
           .Default([&](mlir::Operation *) { return varPtr; });
 
-  return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
+  return baseRef;
 }
 
 static mlir::acc::VariableTypeCategory
@@ -384,10 +389,17 @@ categorizePointee(mlir::Type pointer,
   // value would both be represented as !fir.ref<f32>. We do not want to treat
   // such a reference as a scalar. Thus unwrap interior pointer calculations.
   auto baseRef = getBaseRef(varPtr);
-  mlir::Type eleTy = baseRef.getType().getElementType();
 
-  if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
-    return mappableTy.getTypeCategory(varPtr);
+  if (auto mappableTy =
+          mlir::dyn_cast<mlir::acc::MappableType>(baseRef.getType()))
+    return mappableTy.getTypeCategory(baseRef);
+
+  // It must be a pointer-like type since it is not a MappableType.
+  auto ptrLikeTy = mlir::cast<mlir::acc::PointerLikeType>(baseRef.getType());
+  mlir::Type eleTy = ptrLikeTy.getElementType();
+
+  if (auto mappableEleTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
+    return mappableEleTy.getTypeCategory(varPtr);
 
   if (isScalarLike(eleTy))
     return mlir::acc::VariableTypeCategory::scalar;
diff --git a/flang/test/Fir/OpenACC/openacc-type-categories-class.f90 b/flang/test/Fir/OpenACC/openacc-type-categories-class.f90
new file mode 100644
index 0000000000000..0a38ab96a0315
--- /dev/null
+++ b/flang/test/Fir/OpenACC/openacc-type-categories-class.f90
@@ -0,0 +1,18 @@
+module mm
+  type, public :: polyty
+    real :: field
+  end type
+contains
+  subroutine init(this)
+    class(polyty), intent(inout) :: this
+    !$acc enter data copyin(this, this%field)
+  end subroutine
+end module
+
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this", structured = false}
+! CHECK: Mappable: !fir.class<!fir.type<_QMmmTpolyty{field:f32}>>
+! CHECK: Type category: composite
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "this%field", structured = false}
+! CHECK: Pointer-like: !fir.ref<f32>
+! CHECK: Type category: composite
diff --git a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
index 90aabd7d40d44..11567d1c0c6a3 100644
--- a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
+++ b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
@@ -6,12 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/Support/DataLayout.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 
 using namespace mlir;
 
@@ -25,6 +29,11 @@ struct TestFIROpenACCInterfaces
   StringRef getDescription() const final {
     return "Test FIR implementation of the OpenACC interfaces.";
   }
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<fir::FIROpsDialect, hlfir::hlfirDialect,
+        mlir::arith::ArithDialect, mlir::acc::OpenACCDialect,
+        mlir::DLTIDialect>();
+  }
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
     auto datalayout =

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
View the diff from clang-format here.
diff --git a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
index 11567d1c0..e72b96fe7 100644
--- a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
+++ b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -15,7 +16,6 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/Support/DataLayout.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
 
 using namespace mlir;
 

Comment on lines +7 to +8
class(polyty), intent(inout) :: this
!$acc enter data copyin(this, this%field)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you check if it works as well with an unlimited polymorphic dummy?

class(*), intent(inout) :: this
!$acc enter data copyin(this)

@@ -0,0 +1,18 @@
module mm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test should probably go in flang/test/Lower/OpenACC or starts from mlir directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I just saw that you added another test like that in this folder as well. No strong opinion on this.

end subroutine
end module

! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to not add the run line at the top?

Comment on lines +325 to +326
if (mlir::isa<fir::ClassType>(type))
return mlir::acc::VariableTypeCategory::composite;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make a distinction for unlimited polymorphic that can be smth else than a composite?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants