 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements the lowering of omp.workdistribute.
+// This file implements the lowering and optimisations of omp.workdistribute.
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
...
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <variant>
 
|
@@ -87,25 +92,6 @@ static bool shouldParallelize(Operation *op) {
   return false;
 }
 
-struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                PatternRewriter &rewriter) const override {
-    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-    if (!workdistributeOp) {
-      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-      return failure();
-    }
-
-    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-    rewriter.eraseOp(workdistributeBlock->getTerminator());
-    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-    rewriter.eraseOp(teamsOp);
-    workdistributeOp.emitWarning("unable to parallelize coexecute");
-    return success();
-  }
-};
-
 /// If B() and D() are parallelizable,
 ///
 /// omp.teams {
@@ -210,22 +196,161 @@ struct FissionWorkdistribute
   }
 };
 
+static void
+genLoopNestClauseOps(mlir::Location loc,
+                     mlir::PatternRewriter &rewriter,
+                     fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void
+genWsLoopOp(mlir::PatternRewriter &rewriter,
+            fir::DoLoopOp doLoop,
+            const mlir::omp::LoopNestOperands &clauseOps) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase fir.result op of do loop and create yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    rewriter.eraseOp(terminatorOp);
+  }
+  return;
+}
+
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     omp.parallel {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+
+struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto teamsLoc = teamsOp->getLoc();
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    assert(teamsOp.getReductionVars().empty());
+
+    auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
+    if (doLoop && shouldParallelize(doLoop)) {
+
+      auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
+      rewriter.createBlock(&parallelOp.getRegion());
+      rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
+
+      mlir::omp::LoopNestOperands loopNestClauseOps;
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
+                           loopNestClauseOps);
+
+      genWsLoopOp(rewriter, doLoop, loopNestClauseOps);
+      rewriter.setInsertionPoint(doLoop);
+      rewriter.eraseOp(doLoop);
+      return success();
+    }
+    return failure();
+  }
+};
+
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// A()
+/// B()
+///
+
+struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+    rewriter.eraseOp(workdistributeBlock->getTerminator());
+    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+    rewriter.eraseOp(teamsOp);
+    return success();
+  }
+};
+
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
   void runOnOperation() override {
     MLIRContext &context = getContext();
-    RewritePatternSet patterns(&context);
     GreedyRewriteConfig config;
     // prevent the pattern driver form merging blocks
     config.setRegionSimplificationLevel(
         GreedySimplifyRegionLevel::Disabled);
-
-    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+
     Operation *op = getOperation();
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
-      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
-      signalPassFailure();
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
+    }
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
     }
   }
 };
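Note for readers of this excerpt: both the removed WorkdistributeToSingle pattern and the new patterns above call a getPerfectlyNested<T>(op) helper that is defined earlier in this file, outside the hunks shown here. The following is only an illustrative sketch of what such a helper typically does, written as an assumption for context rather than the file's actual definition: it returns the op of type T nested inside op's single-block region when that op is the only operation in the block apart from the terminator.

// Sketch only (not part of this patch); assumes the file's `using namespace mlir;`.
template <typename T>
static T getPerfectlyNested(Operation *op) {
  // Expect exactly one region holding exactly one block.
  if (op->getNumRegions() != 1)
    return nullptr;
  Region &region = op->getRegion(0);
  if (!region.hasOneBlock())
    return nullptr;
  Block &block = region.front();
  if (block.empty())
    return nullptr;
  // The candidate op must be immediately followed by the block terminator,
  // i.e. it is the only "real" operation in the block.
  Operation *firstOp = &block.front();
  if (firstOp->getNextNode() != block.getTerminator())
    return nullptr;
  return dyn_cast<T>(firstOp);
}

With that helper, the two greedy-rewrite phases in runOnOperation compose as shown above: the first phase applies FissionWorkdistribute and TeamsWorkdistributeLowering to split workdistribute regions and turn parallelizable fir.do_loop nests into omp.parallel/omp.wsloop/omp.loop_nest, and the second phase lowers whatever is still perfectly nested inside omp.teams { omp.workdistribute { ... } } to plain sequential code via TeamsWorkdistributeToSingle.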
|