Skip to content

[flang] Emit fir.global in the global address space #146653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,10 @@ uint64_t getAllocaAddressSpace(const mlir::DataLayout *dataLayout);
llvm::SmallVector<mlir::Value> deduceOptimalExtents(mlir::ValueRange extents1,
mlir::ValueRange extents2);

uint64_t getGlobalAddressSpace(mlir::DataLayout *dataLayout);

uint64_t getProgramAddressSpace(mlir::DataLayout *dataLayout);

/// Given array extents generate code that sets them all to zeroes,
/// if the array is empty, e.g.:
/// %false = arith.constant false
Expand Down
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
unsigned
getProgramAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

unsigned
getGlobalAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

const fir::FIRToLLVMPassOptions &options;

using ConvertToLLVMPattern::matchAndRewrite;
Expand Down
14 changes: 14 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,20 @@ fir::factory::deduceOptimalExtents(mlir::ValueRange extents1,
return extents;
}

uint64_t fir::factory::getGlobalAddressSpace(mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getGlobalMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
Comment on lines +1871 to +1874
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there so much machinery? This should be direct read of constant

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follow the same pattern as getAllocaAddressSpace which I think was made this way to allow for testing modules that do not have data layout attached.

Added @tblah to verify if my understanding is correct or not (since it seems he wrote getAllocaAddressSpace).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes lots of flang lit tests do not have data layout strings. It would be inconvenient to add them.

}

uint64_t fir::factory::getProgramAddressSpace(mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getProgramMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

llvm::SmallVector<mlir::Value> fir::factory::updateRuntimeExtentsForEmptyArrays(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents) {
if (extents.size() <= 1)
Expand Down
61 changes: 52 additions & 9 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,54 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
}

namespace {

mlir::Value replaceWithAddrOfOrASCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc,
std::uint64_t globalAS,
std::uint64_t programAS,
llvm::StringRef symName, mlir::Type type,
mlir::Operation *replaceOp = nullptr) {
if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
if (globalAS != programAS) {
auto llvmAddrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
loc, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
replaceOp, ::getLlvmPtrType(rewriter.getContext(), programAS),
llvmAddrOp);
return rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
loc, getLlvmPtrType(rewriter.getContext(), programAS), llvmAddrOp);
}

if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
replaceOp, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
return rewriter.create<mlir::LLVM::AddressOfOp>(
loc, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
}

if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(replaceOp, type,
symName);
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, type, symName);
}

/// Lower `fir.address_of` operation to `llvm.address_of` operation.
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
using FIROpConversion::FIROpConversion;

llvm::LogicalResult
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto ty = convertType(addr.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
addr, ty, addr.getSymbol().getRootReference().getValue());
auto global = addr->getParentOfType<mlir::ModuleOp>()
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
replaceWithAddrOfOrASCast(
rewriter, addr->getLoc(),
global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
getProgramAddressSpace(rewriter),
global ? global.getSymName()
: addr.getSymbol().getRootReference().getValue(),
convertType(addr.getType()), addr);
return mlir::success();
}
};
Expand Down Expand Up @@ -1306,13 +1344,18 @@ getTypeDescriptor(ModOpTy mod, mlir::ConversionPatternRewriter &rewriter,
? fir::NameUniquer::getTypeDescriptorAssemblyName(recType.getName())
: fir::NameUniquer::getTypeDescriptorName(recType.getName());
mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext());
mlir::DataLayout dataLayout(mod);
if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name))
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
return replaceWithAddrOfOrASCast(
rewriter, loc, fir::factory::getGlobalAddressSpace(&dataLayout),
fir::factory::getProgramAddressSpace(&dataLayout), global.getSymName(),
llvmPtrTy);
// The global may have already been translated to LLVM.
if (auto global = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(name))
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
return replaceWithAddrOfOrASCast(
rewriter, loc, global.getAddrSpace(),
fir::factory::getProgramAddressSpace(&dataLayout), global.getSymName(),
llvmPtrTy);
// Type info derived types do not have type descriptors since they are the
// types defining type descriptors.
if (options.ignoreMissingTypeDescriptors ||
Expand Down Expand Up @@ -3130,8 +3173,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
mlir::SymbolRefAttr comdat;
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto g = rewriter.create<mlir::LLVM::GlobalOp>(
loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0, 0,
false, false, comdat, attrs, dbgExprs);
loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0,
getGlobalAddressSpace(rewriter), false, false, comdat, attrs, dbgExprs);

if (global.getAlignment() && *global.getAlignment() > 0)
g.setAlignment(*global.getAlignment());
Expand Down
10 changes: 10 additions & 0 deletions flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/CodeGen/FIROpPatterns.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -365,4 +366,13 @@ unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace(
return defaultAddressSpace;
}

unsigned ConvertFIRToLLVMPattern::getGlobalAddressSpace(
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
auto dataLayout = mlir::DataLayout::closest(parentOp);
return fir::factory::getGlobalAddressSpace(&dataLayout);
}

} // namespace fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
!REQUIRES: amdgpu-registered-target

!RUN: %flang_fc1 -emit-llvm -triple amdgcn-amd-amdhsa -target-cpu gfx908 %s -o - | FileCheck %s

subroutine maintest
implicit none

type r1_t
end type r1_t

type(r1_t), pointer :: A
end subroutine

! CHECK: @[[TYPE_DESC:.*XdtXr1_t]] = linkonce_odr addrspace(1) constant %_QM__fortran_type_infoTderivedtype

! CHECK: define void @maintest_() {{.*}} {
! CHECK: store { {{.*}} } { {{.*}}, ptr addrspacecast (ptr addrspace(1) @[[TYPE_DESC]] to ptr), {{.*}} }, {{.*}}
! CHECK: }
Loading