[flang] Implement workdistribute construct lowering #140523
You can test this locally with the following command:

git-clang-format --diff HEAD~1 HEAD --extensions cpp -- flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp flang/lib/Optimizer/Passes/Pipelines.cpp mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

View the diff from clang-format here.

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 1fe2592d1..155cf220f 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -24,6 +24,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Dialect/Utils/IndexingUtils.h>
@@ -34,7 +35,6 @@
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include <optional>
#include <variant>
@@ -373,7 +373,8 @@ std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
RewriterBase &rewriter) {
auto loc = targetOp->getLoc();
if (targetOp.getMapVars().empty()) {
- LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " target region has no data maps\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << DEBUG_TYPE << " target region has no data maps\n");
return std::nullopt;
}
@@ -475,9 +476,9 @@ static Type getPtrTypeForOmp(Type ty) {
return fir::LLVMPointerType::get(ty);
}
-static TempOmpVar
-allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) {
- MLIRContext& ctx = *ty.getContext();
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+ RewriterBase &rewriter) {
+ MLIRContext &ctx = *ty.getContext();
Value alloc;
Type allocType;
auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
@@ -487,28 +488,30 @@ allocateTempOmpVar(Location loc, Type ty, RewriterBase &rewriter) {
allocType = llvmPtrTy;
alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
allocType = intTy;
- }
- else {
+ } else {
allocType = ty;
alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
}
auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
return rewriter.create<omp::MapInfoOp>(
- loc, alloc.getType(), alloc,
- TypeAttr::get(allocType),
- rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), mappingFlags),
- rewriter.getAttr<omp::VariableCaptureKindAttr>(
- omp::VariableCaptureKind::ByRef),
- /*varPtrPtr=*/Value{},
- /*members=*/SmallVector<Value>{},
- /*member_index=*/mlir::ArrayAttr{},
- /*bounds=*/ValueRange(),
- /*mapperId=*/mlir::FlatSymbolRefAttr(),
- /*name=*/rewriter.getStringAttr(name),
- rewriter.getBoolAttr(false));
+ loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+ rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+ mappingFlags),
+ rewriter.getAttr<omp::VariableCaptureKindAttr>(
+ omp::VariableCaptureKind::ByRef),
+ /*varPtrPtr=*/Value{},
+ /*members=*/SmallVector<Value>{},
+ /*member_index=*/mlir::ArrayAttr{},
+ /*bounds=*/ValueRange(),
+ /*mapperId=*/mlir::FlatSymbolRefAttr(),
+ /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
};
- uint64_t mapFrom = static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
- uint64_t mapTo = static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+ uint64_t mapFrom =
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+ uint64_t mapTo =
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
return TempOmpVar{mapInfoFrom, mapInfoTo};
@@ -550,11 +553,10 @@ struct SplitResult {
omp::TargetOp postTargetOp;
};
-static void collectNonRecomputableDeps(Value& v,
- omp::TargetOp targetOp,
- SetVector<Operation *>& nonRecomputable,
- SetVector<Operation *>& toCache,
- SetVector<Operation *>& toRecompute) {
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+ SetVector<Operation *> &nonRecomputable,
+ SetVector<Operation *> &toCache,
+ SetVector<Operation *> &toRecompute) {
Operation *op = v.getDefiningOp();
if (!op) {
assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
@@ -566,16 +568,16 @@ static void collectNonRecomputableDeps(Value& v,
}
toRecompute.insert(op);
for (auto opr : op->getOperands())
- collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, toRecompute);
+ collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+ toRecompute);
}
-
static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
- MLIRContext& ctx,
- IRMapping &mapping, Operation *splitBefore,
- Block *targetBlock, Block *newTargetBlock,
- SmallVector<Value>& allocs,
- SetVector<Operation *>& toRecompute) {
+ MLIRContext &ctx, IRMapping &mapping,
+ Operation *splitBefore, Block *targetBlock,
+ Block *newTargetBlock,
+ SmallVector<Value> &allocs,
+ SetVector<Operation *> &toRecompute) {
for (unsigned i = 0; i < targetBlock->getNumArguments(); i++) {
auto originalArg = targetBlock->getArgument(i);
auto newArg = newTargetBlock->addArgument(originalArg.getType(),
@@ -585,15 +587,15 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
for (auto original : allocs) {
Value newArg = newTargetBlock->addArgument(
- getPtrTypeForOmp(original.getType()), original.getLoc());
+ getPtrTypeForOmp(original.getType()), original.getLoc());
Value restored;
if (isPtr(original.getType())) {
restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
if (!isa<LLVM::LLVMPointerType>(original.getType()))
- restored = rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
- }
- else {
- restored = rewriter.create<fir::LoadOp>(loc, newArg);
+ restored =
+ rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+ } else {
+ restored = rewriter.create<fir::LoadOp>(loc, newArg);
}
mapping.map(original, restored);
}
@@ -604,14 +606,14 @@ static void reloadCacheAndRecompute(Location loc, RewriterBase &rewriter,
}
static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter) {
auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
- MLIRContext& ctx = *targetOp.getContext();
+ MLIRContext &ctx = *targetOp.getContext();
assert(targetOp);
auto loc = targetOp.getLoc();
auto *targetBlock = &targetOp.getRegion().front();
rewriter.setInsertionPoint(targetOp);
-
+
auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
@@ -621,21 +623,24 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
SetVector<Operation *> nonRecomputable;
SmallVector<Value> allocs;
- for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); it++) {
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+ it++) {
for (auto res : it->getResults()) {
if (usedOutsideSplit(res, splitBeforeOp))
requiredVals.push_back(res);
}
if (!isRecomputableAfterFission(&*it, splitBeforeOp))
- nonRecomputable.insert(&*it);
+ nonRecomputable.insert(&*it);
}
for (auto requiredVal : requiredVals)
- collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, toRecompute);
-
+ collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+ toRecompute);
+
for (Operation *op : toCache) {
for (auto res : op->getResults()) {
- auto alloc = allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+ auto alloc =
+ allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
allocs.push_back(res);
preMapOperands.push_back(alloc.from);
postMapOperands.push_back(alloc.to);
@@ -645,16 +650,16 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
rewriter.setInsertionPoint(targetOp);
auto preTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
- targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
- targetOp.getDependVars(), targetOp.getDevice(),
- targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
- targetOp.getIfExpr(), targetOp.getInReductionVars(),
- targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
- targetOp.getIsDevicePtrVars(), preMapOperands,
- targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
- targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ preMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
IRMapping preMapping;
@@ -669,7 +674,6 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
-
for (auto original : allocs) {
Value toStore = preMapping.lookup(original);
auto newArg = preTargetBlock->addArgument(
@@ -687,53 +691,52 @@ static SplitResult isolateOp(Operation *splitBeforeOp, bool splitAfter,
rewriter.setInsertionPoint(targetOp);
auto isolatedTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
- targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
- targetOp.getDependVars(), targetOp.getDevice(),
- targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
- targetOp.getIfExpr(), targetOp.getInReductionVars(),
- targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
- targetOp.getIsDevicePtrVars(), postMapOperands,
- targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
- rewriter.createBlock(&isolatedTargetOp.getRegion(),
- isolatedTargetOp.getRegion().begin(), {}, {});
+ rewriter.createBlock(&isolatedTargetOp.getRegion(),
+ isolatedTargetOp.getRegion().begin(), {}, {});
IRMapping isolatedMapping;
reloadCacheAndRecompute(loc, rewriter, ctx, isolatedMapping, splitBeforeOp,
- targetBlock, isolatedTargetBlock,
- allocs, toRecompute);
+ targetBlock, isolatedTargetBlock, allocs,
+ toRecompute);
rewriter.clone(*splitBeforeOp, isolatedMapping);
rewriter.create<omp::TerminatorOp>(loc);
omp::TargetOp postTargetOp = nullptr;
-
+
if (splitAfter) {
- rewriter.setInsertionPoint(targetOp);
+ rewriter.setInsertionPoint(targetOp);
postTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(),
- targetOp.getBareAttr(), targetOp.getDependKindsAttr(),
- targetOp.getDependVars(), targetOp.getDevice(),
- targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(),
- targetOp.getIfExpr(), targetOp.getInReductionVars(),
- targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
- targetOp.getIsDevicePtrVars(), postMapOperands,
- targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ postMapOperands, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
auto *postTargetBlock = rewriter.createBlock(
- &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+ &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
IRMapping postMapping;
- reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
- targetBlock, postTargetBlock,
- allocs, toRecompute);
+ reloadCacheAndRecompute(loc, rewriter, ctx, postMapping, splitBeforeOp,
+ targetBlock, postTargetBlock, allocs, toRecompute);
assert(splitBeforeOp->getNumResults() == 0 ||
- llvm::all_of(splitBeforeOp->getResults(),
- [](Value result) { return result.use_empty(); }));
+ llvm::all_of(splitBeforeOp->getResults(),
+ [](Value result) { return result.use_empty(); }));
for (auto it = std::next(splitBeforeOp->getIterator());
it != targetBlock->end(); it++)
@@ -912,7 +915,8 @@ public:
IRRewriter rewriter(&context);
for (auto targetOp : targetOps) {
auto res = splitTargetData(targetOp, rewriter);
- if (res) fissionTarget(res->targetOp, rewriter);
+ if (res)
+ fissionTarget(res->targetOp, rewriter);
}
}
}
Lowering logic inspired by ivanradanov's coexecute lowering: f56da1a
Fission logic inspired by ivanradanov's implementation: c97eca4
Logic inspired by ivanradanov commit 5682e9e
Logic inspired by ivanradanov's llvm branch flang_workdistribute_iwomp_2024, commit a774515
This commit is cherry-picked from ivanradanov commit be860ac
This commit is cherry-picked from ivanradanov commit be860ac
This commit is cherry-picked from ivanradanov commit d7e4499
Cherry-picked from ivanradanov commit 73fd865
I am not sure that using rewrite patterns is suitable here, mainly because we have a list of transformations that we know we need to apply in a known order, so there is no need to involve patterns, which add additional overhead and complexity (guarding against matching again, making sure the order is correct, etc.). That's how the LowerWorkshare pass is implemented.
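To illustrate that fixed-order driver style, here is a minimal sketch (not the exact code in this PR): collect the omp.target ops once, then apply the known transformations in sequence instead of registering greedy rewrite patterns. splitTargetData and fissionTarget are the helpers added by this patch; the SplitTargetResult fields and the surrounding driver are simplified assumptions based on the final hunk of the diff above.

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>

using namespace mlir;

// Defined in LowerWorkdistribute.cpp; the result struct shown here is
// simplified for this sketch.
struct SplitTargetResult { omp::TargetOp targetOp; };
std::optional<SplitTargetResult> splitTargetData(omp::TargetOp, RewriterBase &);
void fissionTarget(omp::TargetOp, RewriterBase &);

static void lowerWorkdistributeTargets(ModuleOp module) {
  MLIRContext &context = *module.getContext();
  // Collect the targets up front so the rewrites do not invalidate the walk.
  SmallVector<omp::TargetOp> targetOps;
  module.walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });

  IRRewriter rewriter(&context);
  for (omp::TargetOp targetOp : targetOps) {
    // First wrap the target's data maps in an enclosing omp.target_data,
    // then fission the target region around the parallelizable ops.
    auto res = splitTargetData(targetOp, rewriter);
    if (res)
      fissionTarget(res->targetOp, rewriter);
  }
}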
Prerequisite PRs:
[flang] Introduce omp_target_allocmem and omp_target_freemem fir ops.
[flang-rt] Add Assign_omp RT call.
[flang] Add support for workdistribute construct in flang frontend
This PR introduces a new pass, "lower-workdistribute", which identifies parallelizable ops inside a workdistribute region and moves them to a new omp.target region.
This pass implements the following pattern matches and optimizations:
FissionWorkdistribute, WorkdistributeDoLower, and TeamsWorkdistributeToSingle.
After the pattern match and rewrite, the omp.target is moved under an omp.target_data region; then only the parallelizable ops are moved to a new omp.target region, while all other ops are moved to the host.
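For reference, a hedged sketch of how the nest being lowered could be located: walk the module for omp.workdistribute regions directly nested in omp.teams (assuming the omp.workdistribute op introduced by the prerequisite frontend PR) and hand each one to the lowering. The helper name and callback shape below are illustrative assumptions, not the exact code in this pass.

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Find every omp.workdistribute directly nested in an omp.teams region and
// invoke `fn` on it (e.g. to apply FissionWorkdistribute followed by
// WorkdistributeDoLower / TeamsWorkdistributeToSingle).
static void forEachTeamsWorkdistribute(
    ModuleOp module, llvm::function_ref<void(omp::WorkdistributeOp)> fn) {
  module.walk([&](omp::WorkdistributeOp wdOp) {
    if (isa<omp::TeamsOp>(wdOp->getParentOp()))
      fn(wdOp);
  });
}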
The work in this PR is cherry-picked and updated from @ivanradanov's commits from the coexecute implementation:
flang_workdistribute_iwomp_2024