Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 49 additions & 37 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@ class MapInfoFinalizationPass
/// | |
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;

// List of deferrable descriptors to process at the end of
// the pass.
/// List of deferrable descriptors to process at the end of
/// the pass.
llvm::SmallVector<mlir::Operation *> deferrableDesc;

/// List of base addresses already expanded from their
/// descriptors within a parent, currently used to
/// prevent incorrect member index generation.
std::map<mlir::Operation *, llvm::SmallVector<uint64_t>> expandedBaseAddr;

/// Return true if the given path exists in a list of paths.
static bool
containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
Expand Down Expand Up @@ -403,26 +408,38 @@ class MapInfoFinalizationPass
/// of the base address index.
void adjustMemberIndices(
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &memberIndices,
size_t memberIndex) {
llvm::SmallVector<int64_t> baseAddrIndex = memberIndices[memberIndex];
ParentAndPlacement parentAndPlacement) {
llvm::SmallVector<int64_t> baseAddrIndex =
memberIndices[parentAndPlacement.index];
auto &expansionIndexes = expandedBaseAddr[parentAndPlacement.parent];

// If we find another member that is "derived/a member of" the descriptor
// that is not the descriptor itself, we must insert a 0 for the new base
// address we have just added for the descriptor into the list at the
// appropriate position to maintain correctness of the positional/index data
// for that member.
for (llvm::SmallVector<int64_t> &member : memberIndices)
for (auto [i, member] : llvm::enumerate(memberIndices)) {
if (std::find(expansionIndexes.begin(), expansionIndexes.end(), i) !=
expansionIndexes.end())
if (member.size() == baseAddrIndex.size() + 1 &&
member[baseAddrIndex.size()] == 0)
continue;

if (member.size() > baseAddrIndex.size() &&
std::equal(baseAddrIndex.begin(), baseAddrIndex.end(),
member.begin()))
member.insert(std::next(member.begin(), baseAddrIndex.size()), 0);
}

// Add the base address index to the main base address member data
baseAddrIndex.push_back(0);

// Insert our newly created baseAddrIndex into the larger list of indices at
// the correct location.
memberIndices.insert(std::next(memberIndices.begin(), memberIndex + 1),
uint64_t newIdxInsert = parentAndPlacement.index + 1;
expansionIndexes.push_back(newIdxInsert);

// Insert our newly created baseAddrIndex into the larger list of
// indices at the correct location.
memberIndices.insert(std::next(memberIndices.begin(), newIdxInsert),
baseAddrIndex);
}

Expand All @@ -449,30 +466,23 @@ class MapInfoFinalizationPass
/// descriptor tag to it as it's used differently to a regular mapping
/// and some of the runtime descriptor behaviour at the moment can cause
/// issues.
mlir::omp::ClauseMapFlags getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
mlir::Operation *target) {
mlir::omp::ClauseMapFlags
getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
mlir::Operation *target, bool isHasDeviceAddr) {
using mapFlags = mlir::omp::ClauseMapFlags;
mapFlags flags = mapFlags::none;
if (!isHasDeviceAddr)
flags |= mapFlags::attach;

if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp,
mlir::omp::TargetUpdateOp>(target))
return mapTypeFlag;

mapFlags flags = mapFlags::to | mapFlags::descriptor |
(mapTypeFlag & mapFlags::implicit);
// Descriptors for objects will always be copied. This is because the
// descriptor can be rematerialized by the compiler, and so the addres
// of the descriptor for a given object at one place in the code may
// differ from that address in another place. The contents of the
// descriptor (the base address in particular) will remain unchanged
// though.
// TODO/FIXME: We currently cannot have MAP_CLOSE and MAP_ALWAYS on
// the descriptor at once, these are mutually exclusive and when
// both are applied the runtime will fail to map.
flags |= ((mapTypeFlag & mapFlags::close) == mapFlags::close)
? mapFlags::close
: mapFlags::always;
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
// to ensure device-local placement where required by tests relying on USM +
// close semantics.
mlir::omp::TargetUpdateOp>(target)) {
flags |= mapTypeFlag | mapFlags::descriptor;
return flags;
}

flags |= mapFlags::to | mapFlags::descriptor | mapFlags::always |
(mapTypeFlag & mapFlags::implicit);

if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
flags |= mapFlags::close;
return flags;
Expand Down Expand Up @@ -532,10 +542,7 @@ class MapInfoFinalizationPass
mlir::Value varPtr = op.getVarPtr();
mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create(
builder, op.getLoc(), varPtr.getType(), varPtr,
mlir::TypeAttr::get(boxCharType.getEleTy()),
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
mlir::omp::ClauseMapFlags::to |
mlir::omp::ClauseMapFlags::implicit),
mlir::TypeAttr::get(boxCharType.getEleTy()), op.getMapTypeAttr(),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/boxAddr,
Expand All @@ -550,7 +557,11 @@ class MapInfoFinalizationPass
mlir::TypeAttr::get(
llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType())
.getElementType()),
op.getMapTypeAttr(), op.getMapCaptureTypeAttr(),
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
mlir::omp::ClauseMapFlags::attach | mlir::omp::ClauseMapFlags::to |
mlir::omp::ClauseMapFlags::always |
mlir::omp::ClauseMapFlags::descriptor),
op.getMapCaptureTypeAttr(),
/*varPtrPtr=*/mlir::Value{},
/*members=*/llvm::SmallVector<mlir::Value>{memberMapInfoOp},
/*member_index=*/newMembersAttr,
Expand Down Expand Up @@ -676,7 +687,7 @@ class MapInfoFinalizationPass
auto baseAddr =
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
ParentAndPlacement mapUser = mapMemberUsers[0];
adjustMemberIndices(memberIndices, mapUser.index);
adjustMemberIndices(memberIndices, mapUser);
llvm::SmallVector<mlir::Value> newMemberOps;
for (auto v : mapUser.parent.getMembers()) {
newMemberOps.push_back(v);
Expand Down Expand Up @@ -706,7 +717,7 @@ class MapInfoFinalizationPass
builder, op->getLoc(), op.getResult().getType(), descriptor,
mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
getDescriptorMapType(op.getMapType(), target)),
getDescriptorMapType(op.getMapType(), target, isHasDeviceAddrFlag)),
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
Expand Down Expand Up @@ -1005,6 +1016,7 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();
deferrableDesc.clear();
expandedBaseAddr.clear();

// First, walk `omp.map.info` ops to see if any of them have varPtrs
// with an underlying type of fir.char<k, ?>, i.e a character
Expand Down
Loading