Skip to content

Commit cb4125e

Browse files
committed
[Prototype][Flang][OpenMP] Swap to attach semantics for descriptor mapping
1 parent 1ab2948 commit cb4125e

21 files changed

+615
-515
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,15 @@ class MapInfoFinalizationPass
8484
/// | |
8585
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
8686

87-
// List of deferrable descriptors to process at the end of
88-
// the pass.
87+
/// List of deferrable descriptors to process at the end of
88+
/// the pass.
8989
llvm::SmallVector<mlir::Operation *> deferrableDesc;
9090

91+
/// List of base addresses already expanded from their
92+
/// descriptors within a parent, currently used to
93+
/// prevent incorrect member index generation.
94+
std::map<mlir::Operation *, llvm::SmallVector<uint64_t>> expandedBaseAddr;
95+
9196
/// Return true if the given path exists in a list of paths.
9297
static bool
9398
containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
@@ -403,26 +408,38 @@ class MapInfoFinalizationPass
403408
/// of the base address index.
404409
void adjustMemberIndices(
405410
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &memberIndices,
406-
size_t memberIndex) {
407-
llvm::SmallVector<int64_t> baseAddrIndex = memberIndices[memberIndex];
411+
ParentAndPlacement parentAndPlacement) {
412+
llvm::SmallVector<int64_t> baseAddrIndex =
413+
memberIndices[parentAndPlacement.index];
414+
auto &expansionIndexes = expandedBaseAddr[parentAndPlacement.parent];
408415

409416
// If we find another member that is "derived/a member of" the descriptor
410417
// that is not the descriptor itself, we must insert a 0 for the new base
411418
// address we have just added for the descriptor into the list at the
412419
// appropriate position to maintain correctness of the positional/index data
413420
// for that member.
414-
for (llvm::SmallVector<int64_t> &member : memberIndices)
421+
for (auto [i, member] : llvm::enumerate(memberIndices)) {
422+
if (std::find(expansionIndexes.begin(), expansionIndexes.end(), i) !=
423+
expansionIndexes.end())
424+
if (member.size() == baseAddrIndex.size() + 1 &&
425+
member[baseAddrIndex.size()] == 0)
426+
continue;
427+
415428
if (member.size() > baseAddrIndex.size() &&
416429
std::equal(baseAddrIndex.begin(), baseAddrIndex.end(),
417430
member.begin()))
418431
member.insert(std::next(member.begin(), baseAddrIndex.size()), 0);
432+
}
419433

420434
// Add the base address index to the main base address member data
421435
baseAddrIndex.push_back(0);
422436

423-
// Insert our newly created baseAddrIndex into the larger list of indices at
424-
// the correct location.
425-
memberIndices.insert(std::next(memberIndices.begin(), memberIndex + 1),
437+
uint64_t newIdxInsert = parentAndPlacement.index + 1;
438+
expansionIndexes.push_back(newIdxInsert);
439+
440+
// Insert our newly created baseAddrIndex into the larger list of
441+
// indices at the correct location.
442+
memberIndices.insert(std::next(memberIndices.begin(), newIdxInsert),
426443
baseAddrIndex);
427444
}
428445

@@ -449,30 +466,23 @@ class MapInfoFinalizationPass
449466
/// descriptor tag to it as it's used differently to a regular mapping
450467
/// and some of the runtime descriptor behaviour at the moment can cause
451468
/// issues.
452-
mlir::omp::ClauseMapFlags getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
453-
mlir::Operation *target) {
469+
mlir::omp::ClauseMapFlags
470+
getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
471+
mlir::Operation *target, bool isHasDeviceAddr) {
454472
using mapFlags = mlir::omp::ClauseMapFlags;
473+
mapFlags flags = mapFlags::none;
474+
if (!isHasDeviceAddr)
475+
flags |= mapFlags::attach;
476+
455477
if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp,
456-
mlir::omp::TargetUpdateOp>(target))
457-
return mapTypeFlag;
458-
459-
mapFlags flags = mapFlags::to | mapFlags::descriptor |
460-
(mapTypeFlag & mapFlags::implicit);
461-
// Descriptors for objects will always be copied. This is because the
462-
// descriptor can be rematerialized by the compiler, and so the address
463-
// of the descriptor for a given object at one place in the code may
464-
// differ from that address in another place. The contents of the
465-
// descriptor (the base address in particular) will remain unchanged
466-
// though.
467-
// TODO/FIXME: We currently cannot have MAP_CLOSE and MAP_ALWAYS on
468-
// the descriptor at once, these are mutually exclusive and when
469-
// both are applied the runtime will fail to map.
470-
flags |= ((mapTypeFlag & mapFlags::close) == mapFlags::close)
471-
? mapFlags::close
472-
: mapFlags::always;
473-
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
474-
// to ensure device-local placement where required by tests relying on USM +
475-
// close semantics.
478+
mlir::omp::TargetUpdateOp>(target)) {
479+
flags |= mapTypeFlag | mapFlags::descriptor;
480+
return flags;
481+
}
482+
483+
flags |= mapFlags::to | mapFlags::descriptor | mapFlags::always |
484+
(mapTypeFlag & mapFlags::implicit);
485+
476486
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
477487
flags |= mapFlags::close;
478488
return flags;
@@ -532,10 +542,7 @@ class MapInfoFinalizationPass
532542
mlir::Value varPtr = op.getVarPtr();
533543
mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create(
534544
builder, op.getLoc(), varPtr.getType(), varPtr,
535-
mlir::TypeAttr::get(boxCharType.getEleTy()),
536-
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
537-
mlir::omp::ClauseMapFlags::to |
538-
mlir::omp::ClauseMapFlags::implicit),
545+
mlir::TypeAttr::get(boxCharType.getEleTy()), op.getMapTypeAttr(),
539546
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
540547
mlir::omp::VariableCaptureKind::ByRef),
541548
/*varPtrPtr=*/boxAddr,
@@ -550,7 +557,11 @@ class MapInfoFinalizationPass
550557
mlir::TypeAttr::get(
551558
llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType())
552559
.getElementType()),
553-
op.getMapTypeAttr(), op.getMapCaptureTypeAttr(),
560+
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
561+
mlir::omp::ClauseMapFlags::attach | mlir::omp::ClauseMapFlags::to |
562+
mlir::omp::ClauseMapFlags::always |
563+
mlir::omp::ClauseMapFlags::descriptor),
564+
op.getMapCaptureTypeAttr(),
554565
/*varPtrPtr=*/mlir::Value{},
555566
/*members=*/llvm::SmallVector<mlir::Value>{memberMapInfoOp},
556567
/*member_index=*/newMembersAttr,
@@ -676,7 +687,7 @@ class MapInfoFinalizationPass
676687
auto baseAddr =
677688
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
678689
ParentAndPlacement mapUser = mapMemberUsers[0];
679-
adjustMemberIndices(memberIndices, mapUser.index);
690+
adjustMemberIndices(memberIndices, mapUser);
680691
llvm::SmallVector<mlir::Value> newMemberOps;
681692
for (auto v : mapUser.parent.getMembers()) {
682693
newMemberOps.push_back(v);
@@ -706,7 +717,7 @@ class MapInfoFinalizationPass
706717
builder, op->getLoc(), op.getResult().getType(), descriptor,
707718
mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
708719
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
709-
getDescriptorMapType(op.getMapType(), target)),
720+
getDescriptorMapType(op.getMapType(), target, isHasDeviceAddrFlag)),
710721
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
711722
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
712723
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
@@ -1005,6 +1016,7 @@ class MapInfoFinalizationPass
10051016
// iterations from previous function scopes.
10061017
localBoxAllocas.clear();
10071018
deferrableDesc.clear();
1019+
expandedBaseAddr.clear();
10081020

10091021
// First, walk `omp.map.info` ops to see if any of them have varPtrs
10101022
// with an underlying type of fir.char<k, ?>, i.e. a character

0 commit comments

Comments
 (0)