#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
- #include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/OpenMP/Utils.h"
+ #include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+ #include "mlir/Transforms/RegionUtils.h"
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Dialect/Utils/IndexingUtils.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>
- #include "mlir/Transforms/RegionUtils.h"
#include <optional>
#include <variant>
@@ -66,30 +66,30 @@ static T getPerfectlyNested(Operation *op) {
/// This is the single source of truth about whether we should parallelize an
/// operation nested in an omp.workdistribute region.
static bool shouldParallelize(Operation *op) {
-   // Currently we cannot parallelize operations with results that have uses
-   if (llvm::any_of(op->getResults(),
-                    [](OpResult v) -> bool { return !v.use_empty(); }))
+   // Currently we cannot parallelize operations with results that have uses
+   if (llvm::any_of(op->getResults(),
+                    [](OpResult v) -> bool { return !v.use_empty(); }))
+     return false;
+   // We will parallelize unordered loops - these come from array syntax
+   if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+     auto unordered = loop.getUnordered();
+     if (!unordered)
      return false;
-   // We will parallelize unordered loops - these come from array syntax
-   if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
-     auto unordered = loop.getUnordered();
-     if (!unordered)
-       return false;
-     return *unordered;
-   }
-   if (auto callOp = dyn_cast<fir::CallOp>(op)) {
-     auto callee = callOp.getCallee();
-     if (!callee)
-       return false;
-     auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
-     // TODO need to insert a check here whether it is a call we can actually
-     // parallelize currently
-     if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
-       return true;
+     return *unordered;
+   }
+   if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+     auto callee = callOp.getCallee();
+     if (!callee)
      return false;
-   }
-   // We cannot parallise anything else
+     auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+     // TODO need to insert a check here whether it is a call we can actually
+     // parallelize currently
+     if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+       return true;
    return false;
+   }
+   // We cannot parallelize anything else
+   return false;
}
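
Reader's note: shouldParallelize is the predicate the fission pattern below uses to decide where to split a workdistribute body. A minimal sketch of how a splitter might group ops into homogeneous runs; collectRuns is a hypothetical helper, not part of this patch, and assumes the file's usual MLIR using-declarations:

static SmallVector<SmallVector<Operation *>> collectRuns(Block &body) {
  // Walk the non-terminator ops in program order and start a new run each
  // time the shouldParallelize verdict flips, so each run is homogeneous.
  SmallVector<SmallVector<Operation *>> runs;
  bool lastParallel = false;
  for (Operation &op : body.without_terminator()) {
    bool parallel = shouldParallelize(&op);
    if (runs.empty() || parallel != lastParallel)
      runs.emplace_back();
    runs.back().push_back(&op);
    lastParallel = parallel;
  }
  return runs;
}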

/// If B() and D() are parallelizable,
@@ -120,12 +120,10 @@ static bool shouldParallelize(Operation *op) {
/// }
/// E()

- struct FissionWorkdistribute
-     : public OpRewritePattern<omp::WorkdistributeOp> {
+ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
  using OpRewritePattern::OpRewritePattern;
-   LogicalResult
-   matchAndRewrite(omp::WorkdistributeOp workdistribute,
-                   PatternRewriter &rewriter) const override {
+   LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                                 PatternRewriter &rewriter) const override {
    auto loc = workdistribute->getLoc();
    auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
    if (!teams) {
@@ -185,7 +183,7 @@ struct FissionWorkdistribute
    auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
    rewriter.create<omp::TerminatorOp>(loc);
    rewriter.createBlock(&newWorkdistribute.getRegion(),
-                          newWorkdistribute.getRegion().begin(), {}, {});
+                         newWorkdistribute.getRegion().begin(), {}, {});
    auto *cloned = rewriter.clone(*parallelize);
    rewriter.replaceOp(parallelize, cloned);
    rewriter.create<omp::TerminatorOp>(loc);
@@ -197,8 +195,7 @@ struct FissionWorkdistribute
};

static void
- genLoopNestClauseOps(mlir::Location loc,
-                      mlir::PatternRewriter &rewriter,
+ genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
                     fir::DoLoopOp loop,
                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
  assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -209,10 +206,8 @@ genLoopNestClauseOps(mlir::Location loc,
  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
}
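
The middle of genLoopNestClauseOps falls between the two hunks and is not shown. A plausible reconstruction, assuming the clause operands simply mirror the fir.do_loop bounds (an assumption, not something this diff shows):

  // Assumed elided body: copy the loop bounds and step from the fir.do_loop.
  // loopInclusive is set afterwards because fir.do_loop iterates up to and
  // including its upper bound.
  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
  loopNestClauseOps.loopSteps.push_back(loop.getStep());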

- static void
- genWsLoopOp(mlir::PatternRewriter &rewriter,
-             fir::DoLoopOp doLoop,
-             const mlir::omp::LoopNestOperands &clauseOps) {
+ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
+                         const mlir::omp::LoopNestOperands &clauseOps) {

  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
  rewriter.createBlock(&wsloopOp.getRegion());
@@ -236,7 +231,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
  return;
}
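
The lines between rewriter.createBlock and the return are elided by the hunk. Presumably they build the omp.loop_nest inside the new wsloop block and move the fir.do_loop body into it; a rough sketch under that assumption (the exact builder arguments may differ in the actual patch):

  // Assumed step: create the loop nest from the gathered clause operands,
  // then splice the fir.do_loop body into its region, remapping the
  // induction variable to the loop_nest block argument.
  auto loopNestOp =
      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);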

- /// If fir.do_loop id present inside teams workdistribute
+ /// If fir.do_loop is present inside teams workdistribute
///
/// omp.teams {
///   omp.workdistribute {
@@ -246,7 +241,7 @@ genWsLoopOp(mlir::PatternRewriter &rewriter,
///   }
/// }
///
- /// Then, its lowered to
+ /// Then, it's lowered to
///
/// omp.teams {
///   omp.workdistribute {
@@ -277,7 +272,7 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {

    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
    rewriter.createBlock(&parallelOp.getRegion());
-     rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
+     rewriter.setInsertionPoint(
+         rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));

    mlir::omp::LoopNestOperands loopNestClauseOps;
    genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
@@ -292,7 +288,6 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
  }
};

- 
/// If A() and B() are present inside teams workdistribute
///
/// omp.teams {
@@ -311,17 +306,17 @@ struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                 PatternRewriter &rewriter) const override {
-     auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-     if (!workdistributeOp) {
-       LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-       return failure();
-     }
-     Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-     rewriter.eraseOp(workdistributeBlock->getTerminator());
-     rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-     rewriter.eraseOp(teamsOp);
-     return success();
+                                 PatternRewriter &rewriter) const override {
+     auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+     if (!workdistributeOp) {
+       LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+       return failure();
+     }
+     Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+     rewriter.eraseOp(workdistributeBlock->getTerminator());
+     rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+     rewriter.eraseOp(teamsOp);
+     return success();
  }
};

@@ -332,26 +327,27 @@ class LowerWorkdistributePass
    MLIRContext &context = getContext();
    GreedyRewriteConfig config;
    // prevent the pattern driver from merging blocks
-     config.setRegionSimplificationLevel(
-         GreedySimplifyRegionLevel::Disabled);
- 
+     config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
+ 
    Operation *op = getOperation();
    {
      RewritePatternSet patterns(&context);
-       patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+       patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(
+           &context);
      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
        signalPassFailure();
      }
    }
    {
      RewritePatternSet patterns(&context);
-       patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+       patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
+           &context);
      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
        signalPassFailure();
      }
    }
  }
};
- }
+ } // namespace
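
For anyone trying the pass locally, a minimal driver sketch. The factory name follows the usual flangomp pass conventions and is an assumption, not something this diff defines; context and moduleOp are assumed to be in scope:

  // Hypothetical usage (pass factory name assumed): run the lowering over a
  // module and report failure.
  mlir::PassManager pm(&context);
  pm.addPass(flangomp::createLowerWorkdistribute());
  if (mlir::failed(pm.run(moduleOp)))
    llvm::errs() << "lower-workdistribute failed\n";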