@@ -312,6 +312,19 @@ class DoConcurrentConversion
312
312
bool isComposite) const {
313
313
mlir::omp::WsloopOperands wsloopClauseOps;
314
314
315
// Local helper: clones the blocks of `firRegion` into `ompRegion` and then
// replaces the `fir::YieldOp` terminating the last block of `ompRegion` with
// an `mlir::omp::YieldOp` that has the same location and operands.
// An empty `firRegion` is a no-op.
//
// `rewriter` is captured by reference and its insertion point is moved to the
// yield being replaced; it is not restored here. The call sites visible in
// this function hold an `mlir::OpBuilder::InsertionGuard` around the calls.
+ auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
316
+ mlir::Region &ompRegion) {
317
+ if (!firRegion.empty ()) {
318
// Cloned blocks are placed before `ompRegion.begin ()`.
+ rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
319
// `mlir::cast` is used (not `dyn_cast`), so the terminator of the last block
// is expected to be a `fir::YieldOp`.
// NOTE(review): only the last block's terminator is rewritten; presumably the
// cloned regions carry a single yield there -- confirm for multi-block regions.
+ auto firYield =
320
+ mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
321
// Create the OpenMP yield right before the FIR yield, then drop the FIR one.
+ rewriter.setInsertionPoint (firYield);
322
+ rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
323
+ firYield.getOperands ());
324
+ rewriter.eraseOp (firYield);
325
+ }
326
+ };
327
+
315
328
// For `local` (and `local_init`) operands, emit corresponding `private`
316
329
// clauses and attach these clauses to the workshare loop.
317
330
if (!loop.getLocalVars ().empty ())
@@ -326,50 +339,65 @@ class DoConcurrentConversion
326
339
TODO (localizer.getLoc (),
327
340
" local_init conversion is not supported yet" );
328
341
329
- auto oldIP = rewriter. saveInsertionPoint ( );
342
+ mlir::OpBuilder::InsertionGuard guard (rewriter );
330
343
rewriter.setInsertionPointAfter (localizer);
344
+
331
345
auto privatizer = rewriter.create <mlir::omp::PrivateClauseOp>(
332
346
localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
333
347
localizer.getTypeAttr ().getValue (),
334
348
mlir::omp::DataSharingClauseType::Private);
335
349
336
- if (!localizer.getInitRegion ().empty ()) {
337
- rewriter.cloneRegionBefore (localizer.getInitRegion (),
338
- privatizer.getInitRegion (),
339
- privatizer.getInitRegion ().begin ());
340
- auto firYield = mlir::cast<fir::YieldOp>(
341
- privatizer.getInitRegion ().back ().getTerminator ());
342
- rewriter.setInsertionPoint (firYield);
343
- rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
344
- firYield.getOperands ());
345
- rewriter.eraseOp (firYield);
346
- }
347
-
348
- if (!localizer.getDeallocRegion ().empty ()) {
349
- rewriter.cloneRegionBefore (localizer.getDeallocRegion (),
350
- privatizer.getDeallocRegion (),
351
- privatizer.getDeallocRegion ().begin ());
352
- auto firYield = mlir::cast<fir::YieldOp>(
353
- privatizer.getDeallocRegion ().back ().getTerminator ());
354
- rewriter.setInsertionPoint (firYield);
355
- rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
356
- firYield.getOperands ());
357
- rewriter.eraseOp (firYield);
358
- }
359
-
360
- rewriter.restoreInsertionPoint (oldIP);
350
+ cloneFIRRegionToOMP (localizer.getInitRegion (),
351
+ privatizer.getInitRegion ());
352
+ cloneFIRRegionToOMP (localizer.getDeallocRegion (),
353
+ privatizer.getDeallocRegion ());
361
354
362
355
wsloopClauseOps.privateVars .push_back (op);
363
356
wsloopClauseOps.privateSyms .push_back (
364
357
mlir::SymbolRefAttr::get (privatizer));
365
358
}
366
359
360
+ if (!loop.getReduceVars ().empty ()) {
361
+ for (auto [op, byRef, sym, arg] : llvm::zip_equal (
362
+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
363
+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
364
+ loop.getRegionReduceArgs ())) {
365
+ auto firReducer =
366
+ mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
367
+ loop, sym);
368
+
369
+ mlir::OpBuilder::InsertionGuard guard (rewriter);
370
+ rewriter.setInsertionPointAfter (firReducer);
371
+
372
+ auto ompReducer = rewriter.create <mlir::omp::DeclareReductionOp>(
373
+ firReducer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
374
+ firReducer.getTypeAttr ().getValue ());
375
+
376
+ cloneFIRRegionToOMP (firReducer.getAllocRegion (),
377
+ ompReducer.getAllocRegion ());
378
+ cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
379
+ ompReducer.getInitializerRegion ());
380
+ cloneFIRRegionToOMP (firReducer.getReductionRegion (),
381
+ ompReducer.getReductionRegion ());
382
+ cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
383
+ ompReducer.getAtomicReductionRegion ());
384
+ cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
385
+ ompReducer.getCleanupRegion ());
386
+
387
+ wsloopClauseOps.reductionVars .push_back (op);
388
+ wsloopClauseOps.reductionByref .push_back (byRef);
389
+ wsloopClauseOps.reductionSyms .push_back (
390
+ mlir::SymbolRefAttr::get (ompReducer));
391
+ }
392
+ }
393
+
367
394
auto wsloopOp =
368
395
rewriter.create <mlir::omp::WsloopOp>(loop.getLoc (), wsloopClauseOps);
369
396
wsloopOp.setComposite (isComposite);
370
397
371
398
Fortran::common::openmp::EntryBlockArgs wsloopArgs;
372
399
wsloopArgs.priv .vars = wsloopClauseOps.privateVars ;
400
+ wsloopArgs.reduction .vars = wsloopClauseOps.reductionVars ;
373
401
Fortran::common::openmp::genEntryBlock (rewriter, wsloopArgs,
374
402
wsloopOp.getRegion ());
375
403
@@ -393,7 +421,8 @@ class DoConcurrentConversion
393
421
clauseOps.loopLowerBounds .size ())))
394
422
rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
395
423
396
- for (unsigned i = 0 ; i < loop.getLocalVars ().size (); ++i)
424
+ for (unsigned i = 0 ;
425
+ i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
397
426
loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
398
427
399
428
return loopNestOp;
0 commit comments