Commit 5b30d3d

[OpenMP][Flang] Lower teams workdistribute do_loop to wsloop.
Logic inspired by ivanradanov's commit 5682e9e.
1 parent 048c3f2 commit 5b30d3d

File tree: 3 files changed (+193 -34 lines)
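In short, where a parallelizable fir.do_loop sits directly inside omp.teams { omp.workdistribute { ... } }, the pass now emits an omp.parallel / omp.wsloop / omp.loop_nest nest instead of serializing the loop. A before/after sketch with illustrative value names; the exact IR is pinned down by the two tests in this commit:

  // Before lowering:
  omp.teams {
    omp.workdistribute {
      fir.do_loop %iv = %lb to %ub step %step unordered {
        // loop body
      }
      omp.terminator
    }
    omp.terminator
  }

  // After the full pass (the teams/workdistribute wrappers are removed
  // by TeamsWorkdistributeToSingle):
  omp.parallel {
    omp.wsloop {
      omp.loop_nest (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
        // loop body
        omp.yield
      }
    }
    omp.terminator
  }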

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 151 additions & 26 deletions
@@ -6,18 +6,22 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements the lowering of omp.workdistribute.
+// This file implements the lowering and optimisations of omp.workdistribute.
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <mlir/Dialect/Arith/IR/Arith.h>
 #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
@@ -29,6 +33,7 @@
 #include <mlir/IR/PatternMatch.h>
 #include <mlir/Interfaces/SideEffectInterfaces.h>
 #include <mlir/Support/LLVM.h>
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <variant>

@@ -87,25 +92,6 @@ static bool shouldParallelize(Operation *op) {
   return false;
 }
 
-struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
-                                PatternRewriter &rewriter) const override {
-    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
-    if (!workdistributeOp) {
-      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
-      return failure();
-    }
-
-    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
-    rewriter.eraseOp(workdistributeBlock->getTerminator());
-    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
-    rewriter.eraseOp(teamsOp);
-    workdistributeOp.emitWarning("unable to parallelize coexecute");
-    return success();
-  }
-};
-
 /// If B() and D() are parallelizable,
 ///
 /// omp.teams {
@@ -210,22 +196,161 @@ struct FissionWorkdistribute
   }
 };
 
+static void
+genLoopNestClauseOps(mlir::Location loc,
+                     mlir::PatternRewriter &rewriter,
+                     fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void
+genWsLoopOp(mlir::PatternRewriter &rewriter,
+            fir::DoLoopOp doLoop,
+            const mlir::omp::LoopNestOperands &clauseOps) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase the fir.result op of the do loop and create a yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    rewriter.eraseOp(terminatorOp);
+  }
+  return;
+}
+
+/// If a fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     omp.parallel {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+
+struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto teamsLoc = teamsOp->getLoc();
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    assert(teamsOp.getReductionVars().empty());
+
+    auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
+    if (doLoop && shouldParallelize(doLoop)) {
+
+      auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
+      rewriter.createBlock(&parallelOp.getRegion());
+      rewriter.setInsertionPoint(
+          rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
+
+      mlir::omp::LoopNestOperands loopNestClauseOps;
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
+                           loopNestClauseOps);
+
+      genWsLoopOp(rewriter, doLoop, loopNestClauseOps);
+      rewriter.setInsertionPoint(doLoop);
+      rewriter.eraseOp(doLoop);
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// then they are lowered to
+///
+/// A()
+/// B()
+///
+struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
+    }
+    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+    rewriter.eraseOp(workdistributeBlock->getTerminator());
+    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+    rewriter.eraseOp(teamsOp);
+    return success();
+  }
+};
+
 class LowerWorkdistributePass
     : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
 public:
   void runOnOperation() override {
     MLIRContext &context = getContext();
-    RewritePatternSet patterns(&context);
     GreedyRewriteConfig config;
     // prevent the pattern driver from merging blocks
     config.setRegionSimplificationLevel(
         GreedySimplifyRegionLevel::Disabled);
-
-    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+
     Operation *op = getOperation();
-    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
-      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
-      signalPassFailure();
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
+    }
+    {
+      RewritePatternSet patterns(&context);
+      patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(&context);
+      if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+        emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+        signalPassFailure();
+      }
     }
   }
 };
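To make the two-phase driver concrete: after the first greedy phase (FissionWorkdistribute plus TeamsWorkdistributeLowering) a lowered loop still sits inside its teams/workdistribute wrappers, roughly as sketched below with illustrative names. The second phase then applies TeamsWorkdistributeToSingle, which inlines the workdistribute block before the teams op and erases both wrappers, yielding the bare omp.parallel seen in the tests.

  // Intermediate IR after phase 1 (sketch):
  omp.teams {
    omp.workdistribute {
      omp.parallel {
        omp.wsloop {
          omp.loop_nest (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
            // loop body
            omp.yield
          }
        }
        omp.terminator
      }
      omp.terminator
    }
    omp.terminator
  }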
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @x({{.*}})
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: omp.parallel {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
+// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  omp.teams {
+    omp.workdistribute {
+      fir.do_loop %iv = %lb to %ub step %step unordered {
+        %zero = arith.constant 0 : index
+        fir.store %zero to %addr : !fir.ref<index>
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
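Reassembled from the CHECK lines, the expected output of the pass for @x is roughly the function below (a sketch: value names follow the FileCheck captures, not actual printer output). Note that the hoisted constant stays defined above omp.parallel and is referenced from inside the loop_nest region, which MLIR region semantics allow without any explicit capture.

  func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
    %c0 = arith.constant 0 : index
    omp.parallel {
      omp.wsloop {
        omp.loop_nest (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
          fir.store %c0 to %addr : !fir.ref<index>
          omp.yield
        }
      }
      omp.terminator
    }
    return
  }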

flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir

Lines changed: 14 additions & 8 deletions
@@ -6,20 +6,26 @@
 // CHECK: %[[VAL_2:.*]] = arith.constant 9 : index
 // CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
 // CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK: fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered {
-// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
-// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
-// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK: omp.parallel {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: }
+// CHECK: omp.terminator
 // CHECK: }
 // CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
 // CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
 // CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
-// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
 // CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
 // CHECK: }
-// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
-// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref<f32>
+// CHECK: fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref<f32>
 // CHECK: return
 // CHECK: }
 module {
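The input this test exercises lies outside the hunk shown above. A hedged sketch consistent with the CHECK lines (function and value names invented here): everything sits in omp.teams { omp.workdistribute { ... } }, and after fission only the unordered copy loop becomes a worksharing loop, while the calls, the ordered loop, and the scalar load/store stay serial.

  func.func private @regular_side_effect_func(!fir.ref<f32>)
  func.func private @my_fir_parallel_runtime_func(!fir.ref<f32>)

  // Hedged sketch of the test input (not the actual test source):
  func.func @test(%arg0 : !fir.ref<!fir.array<10xf32>>, %arg1 : !fir.ref<!fir.array<10xf32>>, %arg2 : !fir.ref<f32>, %arg3 : !fir.ref<f32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c9 = arith.constant 9 : index
    %cst = arith.constant 5.000000e+00 : f32
    omp.teams {
      omp.workdistribute {
        fir.store %cst to %arg2 : !fir.ref<f32>
        // Parallelizable (unordered): lowered to omp.parallel/wsloop/loop_nest.
        fir.do_loop %i = %c0 to %c9 step %c1 unordered {
          %src = fir.coordinate_of %arg0, %i : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
          %v = fir.load %src : !fir.ref<f32>
          %dst = fir.coordinate_of %arg1, %i : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
          fir.store %v to %dst : !fir.ref<f32>
        }
        fir.call @regular_side_effect_func(%arg2) : (!fir.ref<f32>) -> ()
        fir.call @my_fir_parallel_runtime_func(%arg3) : (!fir.ref<f32>) -> ()
        // Not unordered: stays a serial fir.do_loop.
        fir.do_loop %j = %c0 to %c9 step %c1 {
          %p = fir.coordinate_of %arg0, %j : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
          fir.store %cst to %p : !fir.ref<f32>
        }
        %r = fir.load %arg2 : !fir.ref<f32>
        fir.store %r to %arg3 : !fir.ref<f32>
        omp.terminator
      }
      omp.terminator
    }
    return
  }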
