diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h index 6e150ef4e8e82..4d77fff912f4f 100644 --- a/flang/include/flang/Lower/OpenMP.h +++ b/flang/include/flang/Lower/OpenMP.h @@ -79,7 +79,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &, void genOpenMPSymbolProperties(AbstractConverter &converter, const pft::Variable &var); -int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList); void genThreadprivateOp(AbstractConverter &, const pft::Variable &); void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &); bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &); diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index 61f97b855b0e5..f0b0b2cef1b10 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -5024,8 +5024,10 @@ struct OpenMPBlockConstruct { struct OpenMPLoopConstruct { TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct); OpenMPLoopConstruct(OmpBeginLoopDirective &&a) - : t({std::move(a), std::nullopt, std::nullopt}) {} + : t({std::move(a), std::nullopt, std::nullopt, std::nullopt}) {} std::tuple, + // Inner loop construct used to handle tiling for now. 
+ std::optional>, std::optional> t; }; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 2595a08f626e8..9cd95497fc384 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -45,6 +45,7 @@ using namespace Fortran::lower::omp; using namespace Fortran::common::openmp; +using namespace Fortran::semantics; static llvm::cl::opt DumpAtomicAnalysis("fdebug-dump-atomic-analysis"); @@ -456,6 +457,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, return; const parser::OmpClauseList *beginClauseList = nullptr; + const parser::OmpClauseList *middleClauseList = nullptr; const parser::OmpClauseList *endClauseList = nullptr; common::visit( common::visitors{ @@ -473,6 +475,23 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, beginClauseList = &std::get(beginDirective.t); + // FIXME(JAN): For now we check if there is an inner + // OpenMPLoopConstruct, and extract the size clause from there + const auto &innerOptional = std::get>>( + ompConstruct.t); + if (innerOptional.has_value()) { + const auto &innerLoopDirective = innerOptional.value().value(); + const auto &innerBegin = + std::get( + innerLoopDirective.t); + const auto &innerDirective = + std::get(innerBegin.t); + if (innerDirective.v == llvm::omp::Directive::OMPD_tile) { + middleClauseList = + &std::get(innerBegin.t); + } + } if (auto &endDirective = std::get>( ompConstruct.t)) @@ -485,6 +504,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, assert(beginClauseList && "expected begin directive"); clauses.append(makeClauses(*beginClauseList, semaCtx)); + if (middleClauseList) + clauses.append(makeClauses(*middleClauseList, semaCtx)); + if (endClauseList) clauses.append(makeClauses(*endClauseList, semaCtx)); }; @@ -960,6 +982,7 @@ static void genLoopVars( storeOp = createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); } + 
firOpBuilder.setInsertionPointAfter(storeOp); } @@ -1712,6 +1735,30 @@ genLoopNestClauses(lower::AbstractConverter &converter, cp.processCollapse(loc, eval, clauseOps, iv); clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); + + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + for (auto &clause : clauses) { + if (clause.id == llvm::omp::Clause::OMPC_collapse) { + const auto &collapse = std::get(clause.u); + int64_t collapseValue = evaluate::ToInt64(collapse.v).value(); + clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue); + } else if (clause.id == llvm::omp::Clause::OMPC_sizes) { + // This case handles the stand-alone tiling construct + const auto &sizes = std::get(clause.u); + llvm::SmallVector sizeValues; + for (auto &size : sizes.v) { + int64_t sizeValue = evaluate::ToInt64(size).value(); + sizeValues.push_back(sizeValue); + } + clauseOps.tileSizes = sizeValues; + } + } + + llvm::SmallVector sizeValues; + auto *ompCons{eval.getIf()}; + collectTileSizesFromOpenMPConstruct(ompCons, sizeValues, semaCtx); + if (sizeValues.size() > 0) + clauseOps.tileSizes = sizeValues; } static void genLoopClauses( @@ -2085,9 +2132,9 @@ static mlir::omp::LoopNestOp genLoopNestOp( return llvm::SmallVector(iv); }; - auto *nestedEval = - getCollapsedLoopEval(eval, getCollapseValue(item->clauses)); - + uint64_t nestValue = getCollapseValue(item->clauses); + nestValue = nestValue < iv.size() ? 
iv.size() : nestValue; + auto *nestedEval = getCollapsedLoopEval(eval, nestValue); return genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval, directive) @@ -3610,6 +3657,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter, item); break; case llvm::omp::Directive::OMPD_tile: + newOp = genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item); + break; case llvm::omp::Directive::OMPD_unroll: { unsigned version = semaCtx.langOptions().OpenMPVersion; TODO(loc, "Unhandled loop directive (" + @@ -4186,6 +4235,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, std::get(loopConstruct.t); List clauses = makeClauses( std::get(beginLoopDirective.t), semaCtx); + if (auto &endLoopDirective = std::get>( loopConstruct.t)) { @@ -4292,18 +4342,6 @@ void Fortran::lower::genOpenMPSymbolProperties( lower::genDeclareTargetIntGlobal(converter, var); } -int64_t -Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) { - for (const parser::OmpClause &clause : clauseList.v) { - if (const auto &collapseClause = - std::get_if(&clause.u)) { - const auto *expr = semantics::GetExpr(collapseClause->v); - return evaluate::ToInt64(*expr).value(); - } - } - return 1; -} - void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter, const lower::pft::Variable &var) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index c226c2558e7aa..64dbdc3075656 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -15,6 +15,7 @@ #include "Clauses.h" #include "ClauseFinder.h" +#include "flang/Evaluate/fold.h" #include #include #include @@ -24,10 +25,30 @@ #include #include #include +#include #include #include +using namespace Fortran::semantics; + +template +MaybeIntExpr EvaluateIntExpr(SemanticsContext &context, const T &expr) { + if (MaybeExpr maybeExpr{ + 
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) { + if (auto *intExpr{Fortran::evaluate::UnwrapExpr(*maybeExpr)}) { + return std::move(*intExpr); + } + } + return std::nullopt; +} + +template +std::optional EvaluateInt64(SemanticsContext &context, + const T &expr) { + return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr)); +} + llvm::cl::opt treatIndexAsSection( "openmp-treat-index-as-section", llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."), @@ -38,14 +59,21 @@ namespace lower { namespace omp { int64_t getCollapseValue(const List &clauses) { - auto iter = llvm::find_if(clauses, [](const Clause &clause) { - return clause.id == llvm::omp::Clause::OMPC_collapse; - }); - if (iter != clauses.end()) { - const auto &collapse = std::get(iter->u); - return evaluate::ToInt64(collapse.v).value(); + int64_t collapseValue = 1; + int64_t numTileSizes = 0; + for (auto &clause : clauses) { + if (clause.id == llvm::omp::Clause::OMPC_collapse) { + const auto &collapse = std::get(clause.u); + collapseValue = evaluate::ToInt64(collapse.v).value(); + } else if (clause.id == llvm::omp::Clause::OMPC_sizes) { + const auto &sizes = std::get(clause.u); + numTileSizes = sizes.v.size(); + } } - return 1; + + collapseValue = collapseValue - numTileSizes; + int64_t result = collapseValue > numTileSizes ? collapseValue : numTileSizes; + return result; } void genObjectList(const ObjectList &objects, @@ -606,11 +634,48 @@ static void convertLoopBounds(lower::AbstractConverter &converter, } } +// Populates the sizes vector with values if the given OpenMPConstruct +// Contains a loop construct with an inner tiling construct. 
+void collectTileSizesFromOpenMPConstruct( + const parser::OpenMPConstruct *ompCons, + llvm::SmallVectorImpl &tileSizes, SemanticsContext &semaCtx) { + if (!ompCons) + return; + + if (auto *ompLoop{std::get_if(&ompCons->u)}) { + const auto &innerOptional = std::get< + std::optional>>( + ompLoop->t); + if (innerOptional.has_value()) { + const auto &innerLoopDirective = innerOptional.value().value(); + const auto &innerBegin = + std::get(innerLoopDirective.t); + const auto &innerDirective = + std::get(innerBegin.t).v; + + if (innerDirective == llvm::omp::Directive::OMPD_tile) { + // Get the size values from parse tree and convert to a vector + const auto &innerClauseList{ + std::get(innerBegin.t)}; + for (const auto &clause : innerClauseList.v) + if (const auto tclause{ + std::get_if(&clause.u)}) { + for (auto &tval : tclause->v) { + if (const auto v{EvaluateInt64(semaCtx, tval)}) + tileSizes.push_back(*v); + } + } + } + } + } +} + bool collectLoopRelatedInfo( lower::AbstractConverter &converter, mlir::Location currentLocation, lower::pft::Evaluation &eval, const omp::List &clauses, mlir::omp::LoopRelatedClauseOps &result, llvm::SmallVectorImpl &iv) { + bool found = false; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -627,6 +692,38 @@ bool collectLoopRelatedInfo( found = true; } + // Collect sizes from tile directive if present + std::int64_t sizesLengthValue = 0l; + if (auto *ompCons{eval.getIf()}) { + if (auto *ompLoop{std::get_if(&ompCons->u)}) { + const auto &innerOptional = std::get< + std::optional>>( + ompLoop->t); + if (innerOptional.has_value()) { + const auto &innerLoopDirective = innerOptional.value().value(); + const auto &innerBegin = + std::get(innerLoopDirective.t); + const auto &innerDirective = + std::get(innerBegin.t).v; + + if (innerDirective == llvm::omp::Directive::OMPD_tile) { + // Get the size values from parse tree and convert to a vector + const auto &innerClauseList{ + std::get(innerBegin.t)}; + for (const auto &clause : 
innerClauseList.v) + if (const auto tclause{ + std::get_if(&clause.u)}) { + sizesLengthValue = tclause->v.size(); + found = true; + } + } + } + } + } + + collapseValue = collapseValue - sizesLengthValue; + collapseValue = + collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue; std::size_t loopVarTypeSize = 0; do { lower::pft::Evaluation *doLoop = @@ -659,7 +756,6 @@ bool collectLoopRelatedInfo( } while (collapseValue > 0); convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); - return found; } } // namespace omp diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index a7eb2dc5ee664..d54c9eff653cc 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -166,6 +166,12 @@ bool collectLoopRelatedInfo( lower::pft::Evaluation &eval, const omp::List &clauses, mlir::omp::LoopRelatedClauseOps &result, llvm::SmallVectorImpl &iv); + +void collectTileSizesFromOpenMPConstruct( + const parser::OpenMPConstruct *ompCons, + llvm::SmallVectorImpl &tileSizes, + Fortran::semantics::SemanticsContext &semaCtx); + } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index ed0f227fd5b98..75ac84fc23300 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2926,6 +2926,8 @@ class UnparseVisitor { Walk(std::get(x.t)); Put("\n"); EndOpenMP(); + Walk( + std::get>>(x.t)); Walk(std::get>(x.t)); Walk(std::get>(x.t)); } diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp index 5164f1dc6faab..08c791f78412d 100644 --- a/flang/lib/Semantics/canonicalize-omp.cpp +++ b/flang/lib/Semantics/canonicalize-omp.cpp @@ -8,7 +8,6 @@ #include "canonicalize-omp.h" #include "flang/Parser/parse-tree-visitor.h" - // After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP // Constructs more structured which provide explicit scopes for later // structural checks and 
semantic analysis. @@ -112,15 +111,19 @@ class CanonicalizationOfOmp { // in the same iteration // // Original: - // ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct - // OmpBeginLoopDirective + // ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t-> + // OmpBeginLoopDirective t-> OmpLoopDirective + // [ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct u-> + // OmpBeginLoopDirective t-> OmpLoopDirective t-> Tile v-> OMP_tile] // ExecutableConstruct -> DoConstruct + // [ExecutableConstruct -> OmpEndLoopDirective] (note: tile) // ExecutableConstruct -> OmpEndLoopDirective (if available) // // After rewriting: - // ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct - // OmpBeginLoopDirective - // DoConstruct + // ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct t-> + // [OpenMPLoopConstruct t -> OmpBeginLoopDirective -> OmpLoopDirective + // OmpEndLoopDirective] (note: tile) + // OmpBeginLoopDirective t -> OmpLoopDirective -> DoConstruct // OmpEndLoopDirective (if available) parser::Block::iterator nextIt; auto &beginDir{std::get(x.t)}; @@ -131,20 +134,55 @@ class CanonicalizationOfOmp { // Ignore compiler directives. if (GetConstructIf(*nextIt)) continue; + // Keep track of the loops to handle the end loop directives + llvm::SmallVector loops; + loops.push_back(&x); + if (auto *innerOmpLoop{GetOmpIf(*nextIt)}) { + auto &innerBeginDir{ + std::get(innerOmpLoop->t)}; + auto &innerDir{std::get(innerBeginDir.t)}; + if (innerDir.v == llvm::omp::Directive::OMPD_tile) { + auto &innerLoop = std::get< + std::optional>>( + loops.back()->t); + innerLoop = std::move(*innerOmpLoop); + // Retrieveing the address so that DoConstruct or inner loop can be + // set later. 
+ loops.push_back(&(innerLoop.value().value())); + nextIt = block.erase(nextIt); + } + } if (auto *doCons{GetConstructIf(*nextIt)}) { if (doCons->GetLoopControl()) { - // move DoConstruct - std::get>(x.t) = + std::get>(loops.back()->t) = std::move(*doCons); nextIt = block.erase(nextIt); // try to match OmpEndLoopDirective - if (nextIt != block.end()) { + while (nextIt != block.end() && !loops.empty()) { if (auto *endDir{ GetConstructIf(*nextIt)}) { - std::get>(x.t) = - std::move(*endDir); - block.erase(nextIt); + auto &endOmpDirective{ + std::get(endDir->t)}; + auto &loopBegin{ + std::get(loops.back()->t)}; + auto &loopDir{std::get(loopBegin.t)}; + + // If the directive is a tile we try to match the corresponding + // end tile if it exsists. If it is not a tile directive we + // always assign the end loop directive and fall back on the + // existing directive structure checks. + if (loopDir.v != llvm::omp::Directive::OMPD_tile || + loopDir.v == endOmpDirective.v) { + std::get>( + loops.back()->t) = std::move(*endDir); + nextIt = block.erase(nextIt); + } + + loops.pop_back(); + } else { + // If there is a mismatch bail out. 
+ break; } } } else { diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 885c02e6ec74b..6c4c360aba493 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -745,7 +745,20 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor { const parser::OmpClause *GetAssociatedClause() { return associatedClause; } private: + std::int64_t SetAssociatedMaxClause(llvm::SmallVector &, + llvm::SmallVector &); + std::int64_t GetAssociatedLoopLevelFromLoopConstruct( + const parser::OpenMPLoopConstruct &); std::int64_t GetAssociatedLoopLevelFromClauses(const parser::OmpClauseList &); + void CollectAssociatedLoopLevelsFromLoopConstruct( + const parser::OpenMPLoopConstruct &, llvm::SmallVector &, + llvm::SmallVector &); + void CollectAssociatedLoopLevelsFromInnerLoopContruct( + const parser::OpenMPLoopConstruct &, llvm::SmallVector &, + llvm::SmallVector &); + void CollectAssociatedLoopLevelsFromClauses(const parser::OmpClauseList &, + llvm::SmallVector &, + llvm::SmallVector &); Symbol::Flags dataSharingAttributeFlags{Symbol::Flag::OmpShared, Symbol::Flag::OmpPrivate, Symbol::Flag::OmpFirstPrivate, @@ -1742,7 +1755,6 @@ bool OmpAttributeVisitor::Pre( bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) { const auto &beginLoopDir{std::get(x.t)}; const auto &beginDir{std::get(beginLoopDir.t)}; - const auto &clauseList{std::get(beginLoopDir.t)}; switch (beginDir.v) { case llvm::omp::Directive::OMPD_distribute: case llvm::omp::Directive::OMPD_distribute_parallel_do: @@ -1793,7 +1805,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) { beginDir.v == llvm::omp::Directive::OMPD_target_loop) IssueNonConformanceWarning(beginDir.v, beginDir.source, 52); ClearDataSharingAttributeObjects(); - SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList)); + SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromLoopConstruct(x)); if 
(beginDir.v == llvm::omp::Directive::OMPD_do) { if (const auto &doConstruct{ @@ -1804,7 +1816,7 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) { } } PrivatizeAssociatedLoopIndexAndCheckLoopLevel(x); - ordCollapseLevel = GetAssociatedLoopLevelFromClauses(clauseList) + 1; + ordCollapseLevel = GetAssociatedLoopLevelFromLoopConstruct(x) + 1; return true; } @@ -1892,44 +1904,118 @@ bool OmpAttributeVisitor::Pre(const parser::DoConstruct &x) { return true; } +static bool isSizesClause(const parser::OmpClause *clause) { + return std::holds_alternative(clause->u); +} + +std::int64_t OmpAttributeVisitor::SetAssociatedMaxClause( + llvm::SmallVector &levels, + llvm::SmallVector &clauses) { + + // Find the tile level to know how much to reduce the level for collapse + std::int64_t tileLevel = 0; + for (auto [level, clause] : llvm::zip_equal(levels, clauses)) { + if (isSizesClause(clause)) { + tileLevel = level; + } + } + + std::int64_t maxLevel = 1; + const parser::OmpClause *maxClause = nullptr; + for (auto [level, clause] : llvm::zip_equal(levels, clauses)) { + if (tileLevel > 0 && tileLevel < level) { + context_.Say(clause->source, + "The value of the parameter in the COLLAPSE clause must" + " not be larger than the number of the number of tiled loops" + " because collapse relies on independent loop iterations."_err_en_US); + return 1; + } + + if (!isSizesClause(clause)) { + level = level - tileLevel; + } + + if (level > maxLevel) { + maxLevel = level; + maxClause = clause; + } + } + if (maxClause) + SetAssociatedClause(maxClause); + return maxLevel; +} + +std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromLoopConstruct( + const parser::OpenMPLoopConstruct &x) { + llvm::SmallVector levels; + llvm::SmallVector clauses; + + CollectAssociatedLoopLevelsFromLoopConstruct(x, levels, clauses); + return SetAssociatedMaxClause(levels, clauses); +} + std::int64_t OmpAttributeVisitor::GetAssociatedLoopLevelFromClauses( const parser::OmpClauseList &x) { 
- std::int64_t orderedLevel{0}; - std::int64_t collapseLevel{0}; + llvm::SmallVector levels; + llvm::SmallVector clauses; - const parser::OmpClause *ordClause{nullptr}; - const parser::OmpClause *collClause{nullptr}; + CollectAssociatedLoopLevelsFromClauses(x, levels, clauses); + return SetAssociatedMaxClause(levels, clauses); +} + +void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromLoopConstruct( + const parser::OpenMPLoopConstruct &x, + llvm::SmallVector &levels, + llvm::SmallVector &clauses) { + const auto &beginLoopDir{std::get(x.t)}; + const auto &clauseList{std::get(beginLoopDir.t)}; + + CollectAssociatedLoopLevelsFromClauses(clauseList, levels, clauses); + CollectAssociatedLoopLevelsFromInnerLoopContruct(x, levels, clauses); +} +void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromInnerLoopContruct( + const parser::OpenMPLoopConstruct &x, + llvm::SmallVector &levels, + llvm::SmallVector &clauses) { + const auto &innerOptional = + std::get>>( + x.t); + if (innerOptional.has_value()) { + CollectAssociatedLoopLevelsFromLoopConstruct( + innerOptional.value().value(), levels, clauses); + } +} + +void OmpAttributeVisitor::CollectAssociatedLoopLevelsFromClauses( + const parser::OmpClauseList &x, llvm::SmallVector &levels, + llvm::SmallVector &clauses) { for (const auto &clause : x.v) { - if (const auto *orderedClause{ + if (const auto oclause{ std::get_if(&clause.u)}) { - if (const auto v{EvaluateInt64(context_, orderedClause->v)}) { - orderedLevel = *v; + std::int64_t level = 0; + if (const auto v{EvaluateInt64(context_, oclause->v)}) { + level = *v; } - ordClause = &clause; + levels.push_back(level); + clauses.push_back(&clause); } - if (const auto *collapseClause{ + + if (const auto cclause{ std::get_if(&clause.u)}) { - if (const auto v{EvaluateInt64(context_, collapseClause->v)}) { - collapseLevel = *v; + std::int64_t level = 0; + if (const auto v{EvaluateInt64(context_, cclause->v)}) { + level = *v; } - collClause = &clause; + 
levels.push_back(level); + clauses.push_back(&clause); } - } - if (orderedLevel && (!collapseLevel || orderedLevel >= collapseLevel)) { - SetAssociatedClause(ordClause); - return orderedLevel; - } else if (!orderedLevel && collapseLevel) { - SetAssociatedClause(collClause); - return collapseLevel; - } else { - SetAssociatedClause(nullptr); + if (const auto tclause{std::get_if(&clause.u)}) { + levels.push_back(tclause->v.size()); + clauses.push_back(&clause); + } } - // orderedLevel < collapseLevel is an error handled in structural - // checks - - return 1; // default is outermost loop } // 2.15.1.1 Data-sharing Attribute Rules - Predetermined @@ -1961,8 +2047,16 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel( const parser::OmpClause *clause{GetAssociatedClause()}; bool hasCollapseClause{ clause ? (clause->Id() == llvm::omp::OMPC_collapse) : false}; + const parser::OpenMPLoopConstruct *innerMostLoop = &x; - const auto &outer{std::get>(x.t)}; + while (auto &innerLoop{ + std::get>>( + innerMostLoop->t)}) { + innerMostLoop = &innerLoop.value().value(); + } + + const auto &outer{ + std::get>(innerMostLoop->t)}; if (outer.has_value()) { for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) { if (loop->IsDoConcurrent()) { diff --git a/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 index 2890e78e9d17f..faf8f717f6308 100644 --- a/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 +++ b/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 @@ -108,7 +108,7 @@ subroutine omp_do_lastprivate_collapse2(a) ! CHECK-NEXT: %[[UB2:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref ! CHECK-NEXT: %[[STEP2:.*]] = arith.constant 1 : i32 ! CHECK-NEXT: omp.wsloop private(@{{.*}} %{{.*}}#0 -> %[[A_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[I_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[J_PVT_REF:.*]] : !fir.ref, !fir.ref, !fir.ref) { - ! 
CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]]) : i32 = (%[[LB1]], %[[LB2]]) to (%[[UB1]], %[[UB2]]) inclusive step (%[[STEP1]], %[[STEP2]]) { + ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]]) : i32 = (%[[LB1]], %[[LB2]]) to (%[[UB1]], %[[UB2]]) inclusive step (%[[STEP1]], %[[STEP2]]) collapse(2) { ! CHECK: %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ea"} : (!fir.ref) -> (!fir.ref, !fir.ref) ! CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) ! CHECK: %[[J_PVT_DECL:.*]]:2 = hlfir.declare %[[J_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ej"} : (!fir.ref) -> (!fir.ref, !fir.ref) @@ -174,7 +174,7 @@ subroutine omp_do_lastprivate_collapse3(a) ! CHECK-NEXT: %[[UB3:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref ! CHECK-NEXT: %[[STEP3:.*]] = arith.constant 1 : i32 ! CHECK-NEXT: omp.wsloop private(@{{.*}} %{{.*}}#0 -> %[[A_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[I_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[J_PVT_REF:.*]], @{{.*}} %{{.*}}#0 -> %[[K_PVT_REF:.*]] : !fir.ref, !fir.ref, !fir.ref, !fir.ref) { - ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) : i32 = (%[[LB1]], %[[LB2]], %[[LB3]]) to (%[[UB1]], %[[UB2]], %[[UB3]]) inclusive step (%[[STEP1]], %[[STEP2]], %[[STEP3]]) { + ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) : i32 = (%[[LB1]], %[[LB2]], %[[LB3]]) to (%[[UB1]], %[[UB2]], %[[UB3]]) inclusive step (%[[STEP1]], %[[STEP2]], %[[STEP3]]) collapse(3) { ! CHECK: %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ea"} : (!fir.ref) -> (!fir.ref, !fir.ref) ! CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) ! 
CHECK: %[[J_PVT_DECL:.*]]:2 = hlfir.declare %[[J_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ej"} : (!fir.ref) -> (!fir.ref, !fir.ref) diff --git a/flang/test/Lower/OpenMP/simd.f90 b/flang/test/Lower/OpenMP/simd.f90 index d815474b84b31..3572b9baff00b 100644 --- a/flang/test/Lower/OpenMP/simd.f90 +++ b/flang/test/Lower/OpenMP/simd.f90 @@ -175,7 +175,7 @@ subroutine simd_with_collapse_clause(n) ! CHECK-NEXT: omp.loop_nest (%[[ARG_0:.*]], %[[ARG_1:.*]]) : i32 = ( ! CHECK-SAME: %[[LOWER_I]], %[[LOWER_J]]) to ( ! CHECK-SAME: %[[UPPER_I]], %[[UPPER_J]]) inclusive step ( - ! CHECK-SAME: %[[STEP_I]], %[[STEP_J]]) { + ! CHECK-SAME: %[[STEP_I]], %[[STEP_J]]) collapse(2) { !$OMP SIMD COLLAPSE(2) do i = 1, n do j = 1, n diff --git a/flang/test/Lower/OpenMP/wsloop-collapse.f90 b/flang/test/Lower/OpenMP/wsloop-collapse.f90 index a4d5cbdc03d3e..0f27cd6f6cd7c 100644 --- a/flang/test/Lower/OpenMP/wsloop-collapse.f90 +++ b/flang/test/Lower/OpenMP/wsloop-collapse.f90 @@ -57,7 +57,7 @@ program wsloop_collapse !CHECK: %[[VAL_31:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref !CHECK: %[[VAL_32:.*]] = arith.constant 1 : i32 !CHECK: omp.wsloop private(@{{.*}} %{{.*}}#0 -> %[[VAL_4:.*]], @{{.*}} %{{.*}}#0 -> %[[VAL_2:.*]], @{{.*}} %{{.*}}#0 -> %[[VAL_0:.*]] : !fir.ref, !fir.ref, !fir.ref) { -!CHECK-NEXT: omp.loop_nest (%[[VAL_33:.*]], %[[VAL_34:.*]], %[[VAL_35:.*]]) : i32 = (%[[VAL_24]], %[[VAL_27]], %[[VAL_30]]) to (%[[VAL_25]], %[[VAL_28]], %[[VAL_31]]) inclusive step (%[[VAL_26]], %[[VAL_29]], %[[VAL_32]]) { +!CHECK-NEXT: omp.loop_nest (%[[VAL_33:.*]], %[[VAL_34:.*]], %[[VAL_35:.*]]) : i32 = (%[[VAL_24]], %[[VAL_27]], %[[VAL_30]]) to (%[[VAL_25]], %[[VAL_28]], %[[VAL_31]]) inclusive step (%[[VAL_26]], %[[VAL_29]], %[[VAL_32]]) collapse(3) { !$omp do collapse(3) do i = 1, a do j= 1, b diff --git a/flang/test/Lower/OpenMP/wsloop-tile.f90 b/flang/test/Lower/OpenMP/wsloop-tile.f90 new file mode 100644 index 0000000000000..c9bf18e3b278d --- /dev/null +++ 
b/flang/test/Lower/OpenMP/wsloop-tile.f90 @@ -0,0 +1,39 @@ +! This test checks lowering of OpenMP DO Directive(Worksharing) with collapse. + +! RUN: bbc -fopenmp -fopenmp-version=51 -emit-hlfir %s -o - | FileCheck %s + +!CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "wsloop_tile"} { +program wsloop_tile + integer :: i, j, k + integer :: a, b, c + integer :: x + + a=30 + b=20 + c=50 + x=0 + + !CHECK: omp.loop_nest (%[[IV_0:.*]], %[[IV_1:.*]], %[[IV_2:.*]]) : i32 + !CHECK-SAME: tiles(2, 5, 10) + + !$omp do + !$omp tile sizes(2,5,10) + do i = 1, a + do j= 1, b + do k = 1, c + !CHECK: hlfir.assign %[[IV_0]] to %[[IV_0A:.*]] : i32 + !CHECK: hlfir.assign %[[IV_1]] to %[[IV_1A:.*]] : i32 + !CHECK: hlfir.assign %[[IV_2]] to %[[IV_2A:.*]] : i32 + !CHECK: %[[IVV_0:.*]] = fir.load %[[IV_0A]] + !CHECK: %[[SUM0:.*]] = arith.addi %{{.*}}, %[[IVV_0]] : i32 + !CHECK: %[[IVV_1:.*]] = fir.load %[[IV_1A]] + !CHECK: %[[SUM1:.*]] = arith.addi %[[SUM0]], %[[IVV_1]] : i32 + !CHECK: %[[IVV_2:.*]] = fir.load %[[IV_2A]] + !CHECK: %[[SUM2:.*]] = arith.addi %[[SUM1]], %[[IVV_2]] : i32 + x = x + i + j + k + end do + end do + end do + !$omp end tile + !$omp end do +end program wsloop_tile diff --git a/flang/test/Lower/OpenMP/wsloop-variable.f90 b/flang/test/Lower/OpenMP/wsloop-variable.f90 index a7fb5fb8936e7..cceb77b974fee 100644 --- a/flang/test/Lower/OpenMP/wsloop-variable.f90 +++ b/flang/test/Lower/OpenMP/wsloop-variable.f90 @@ -23,7 +23,7 @@ program wsloop_variable !CHECK: %[[TMP6:.*]] = fir.convert %[[TMP1]] : (i32) -> i64 !CHECK: %[[TMP7:.*]] = fir.convert %{{.*}} : (i32) -> i64 !CHECK: omp.wsloop private({{.*}}) { -!CHECK-NEXT: omp.loop_nest (%[[ARG0:.*]], %[[ARG1:.*]]) : i64 = (%[[TMP2]], %[[TMP5]]) to (%[[TMP3]], %[[TMP6]]) inclusive step (%[[TMP4]], %[[TMP7]]) { +!CHECK-NEXT: omp.loop_nest (%[[ARG0:.*]], %[[ARG1:.*]]) : i64 = (%[[TMP2]], %[[TMP5]]) to (%[[TMP3]], %[[TMP6]]) inclusive step (%[[TMP4]], %[[TMP7]]) collapse(2) { !CHECK: %[[ARG0_I16:.*]] = fir.convert 
%[[ARG0]] : (i64) -> i16 !CHECK: hlfir.assign %[[ARG0_I16]] to %[[STORE_IV0:.*]]#0 : i16, !fir.ref !CHECK: hlfir.assign %[[ARG1]] to %[[STORE_IV1:.*]]#0 : i64, !fir.ref diff --git a/flang/test/Parser/OpenMP/do-tile-size.f90 b/flang/test/Parser/OpenMP/do-tile-size.f90 new file mode 100644 index 0000000000000..886ee4a2a680c --- /dev/null +++ b/flang/test/Parser/OpenMP/do-tile-size.f90 @@ -0,0 +1,29 @@ +! RUN: %flang_fc1 -fdebug-unparse -fopenmp -fopenmp-version=51 %s | FileCheck --ignore-case %s +! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp -fopenmp-version=51 %s | FileCheck --check-prefix="PARSE-TREE" %s + +subroutine openmp_do_tiles(x) + + integer, intent(inout)::x + + +!CHECK: !$omp do +!CHECK: !$omp tile sizes +!$omp do +!$omp tile sizes(2) +!CHECK: do + do x = 1, 100 + call F1() +!CHECK: end do + end do +!CHECK: !$omp end tile +!$omp end tile +!$omp end do + +!PARSE-TREE:| | ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct +!PARSE-TREE:| | | OmpBeginLoopDirective +!PARSE-TREE:| | | OpenMPLoopConstruct +!PARSE-TREE:| | | | OmpBeginLoopDirective +!PARSE-TREE:| | | | | OmpLoopDirective -> llvm::omp::Directive = tile +!PARSE-TREE:| | | | | OmpClauseList -> OmpClause -> Sizes -> Scalar -> Integer -> Expr = '2_4' +!PARSE-TREE: | | | | DoConstruct +END subroutine openmp_do_tiles diff --git a/flang/test/Semantics/OpenMP/do-collapse.f90 b/flang/test/Semantics/OpenMP/do-collapse.f90 index 480bd45b79b83..ec6a3bdad3686 100644 --- a/flang/test/Semantics/OpenMP/do-collapse.f90 +++ b/flang/test/Semantics/OpenMP/do-collapse.f90 @@ -31,6 +31,7 @@ program omp_doCollapse end do end do + !ERROR: The value of the parameter in the COLLAPSE or ORDERED clause must not be larger than the number of nested loops following the construct. 
!ERROR: At most one COLLAPSE clause can appear on the SIMD directive !$omp simd collapse(2) collapse(1) do i = 1, 4 diff --git a/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90 b/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90 index bb1929249183b..355626f6e73b9 100644 --- a/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90 +++ b/flang/test/Semantics/OpenMP/do-concurrent-collapse.f90 @@ -1,6 +1,7 @@ !RUN: %python %S/../test_errors.py %s %flang -fopenmp integer :: i, j +! ERROR: DO CONCURRENT loops cannot be used with the COLLAPSE clause. !$omp parallel do collapse(2) do i = 1, 1 ! ERROR: DO CONCURRENT loops cannot form part of a loop nest. diff --git a/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h b/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h index cdc80c88b7425..91012ba6868d1 100644 --- a/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h +++ b/llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h @@ -209,6 +209,8 @@ struct ConstructDecompositionT { bool applyClause(const tomp::clause::CollapseT<TypeTy, ExprTy> &clause, const ClauseTy *); + bool applyClause(const tomp::clause::SizesT<TypeTy, ExprTy> &clause, + const ClauseTy *); bool applyClause(const tomp::clause::PrivateT<TypeTy, ExprTy> &clause, const ClauseTy *); bool @@ -482,6 +484,24 @@ bool ConstructDecompositionT<C, H>::applyClause( return false; } +// FIXME(JAN): Do the correct thing, but for now we'll do the same as collapse +template <typename C, typename H> +bool ConstructDecompositionT<C, H>::applyClause( + const tomp::clause::SizesT<TypeTy, ExprTy> &clause, + const ClauseTy *node) { + // Apply "sizes" to the innermost directive. If that directive does not + // allow the clause, report failure by returning false.
+ if (!leafs.empty()) { + auto &last = leafs.back(); + + if (llvm::omp::isAllowedClauseForDirective(last.id, node->id, version)) { + last.clauses.push_back(node); + return true; + } + } + + return false; +} // PRIVATE // [5.2:111:5-7] diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 93fb0d8e8d078..cdc78a4d104c9 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2162,6 +2162,9 @@ class OpenMPIRBuilder { /// Return the function that contains the region to be outlined. Function *getFunction() const { return EntryBB->getParent(); } + + /// Print the names of EntryBB, ExitBB and OuterAllocaBB to errs(). + void dump(); }; /// Collection of regions that need to be outlined during finalization. @@ -2179,6 +2182,9 @@ class OpenMPIRBuilder { /// Add a new region that will be outlined later. void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); } + /// Print every entry of OutlineInfos to errs(). + void dumpOutlineInfos(); + /// An ordered map of auto-generated variables to their unique names. /// It stores variables with the following names: 1) ".gomp_critical_user_" + /// <critical_section_name> + ".var" for "omp critical" directives; 2) @@ -3762,6 +3768,9 @@ class CanonicalLoopInfo { /// Invalidate this loop. That is, the underlying IR does not fulfill the /// requirements of an OpenMP canonical loop anymore.
LLVM_ABI void invalidate(); + + /// Print the Header, Cond, Latch and Exit block names to errs(). + void dump(); }; } // end namespace llvm diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index ddc9c5392f922..4f04af50b3a45 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -8725,6 +8725,15 @@ Error OpenMPIRBuilder::emitOffloadingArrays( return Error::success(); } +void OpenMPIRBuilder::dumpOutlineInfos() { + errs() << "=== Outline Infos Begin ===\n"; + for (auto En : enumerate(OutlineInfos)) { + errs() << "[" << En.index() << "]: "; + En.value().dump(); + } + errs() << "=== Outline Infos End ===\n"; +} + void OpenMPIRBuilder::emitBranch(BasicBlock *Target) { BasicBlock *CurBB = Builder.GetInsertBlock(); @@ -9633,6 +9642,14 @@ void OpenMPIRBuilder::OutlineInfo::collectBlocks( } } +void OpenMPIRBuilder::OutlineInfo::dump() { + errs() << "=== OutlineInfo == " + << " EntryBB: " << (EntryBB ? EntryBB->getName() : "n/a") + << " ExitBB: " << (ExitBB ? ExitBB->getName() : "n/a") + << " OuterAllocaBB: " + << (OuterAllocaBB ? OuterAllocaBB->getName() : "n/a") << "\n"; +} + void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr, uint64_t Size, int32_t Flags, GlobalValue::LinkageTypes, @@ -10410,3 +10427,10 @@ void CanonicalLoopInfo::invalidate() { Latch = nullptr; Exit = nullptr; } + +void CanonicalLoopInfo::dump() { + errs() << "CanonicalLoop == Header: " << (Header ? Header->getName() : "n/a") + << " Cond: " << (Cond ? Cond->getName() : "n/a") + << " Latch: " << (Latch ? Latch->getName() : "n/a") + << " Exit: " << (Exit ?
Exit->getName() : "n/a") << "\n"; +} diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 16c14ef085d6d..09c77beb857d4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -266,6 +266,38 @@ class OpenMP_DeviceClauseSkip< def OpenMP_DeviceClause : OpenMP_DeviceClauseSkip<>; +//===----------------------------------------------------------------------===// +// V5.2: [4.4.3] `collapse` clause +//===----------------------------------------------------------------------===// + +class OpenMP_CollapseClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + let arguments = (ins + DefaultValuedOptionalAttr:$num_collapse + ); +} + +def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>; + +//===----------------------------------------------------------------------===// +// V5.2: [9.1.1] `sizes` clause +//===----------------------------------------------------------------------===// + +class OpenMP_TileSizesClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + let arguments = (ins + OptionalAttr<DenseI64ArrayAttr>:$tile_sizes + ); +} + +def OpenMP_TileSizesClause : OpenMP_TileSizesClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [11.6.1] `dist_schedule` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ac80926053a2d..53e9a29123bd1 100644 --- 
a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -407,7 +407,9 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
def LoopNestOp : OpenMP_Op<"loop_nest", traits = [ RecursiveMemoryEffects, SameVariadicOperandSize ], clauses = [ - OpenMP_LoopRelatedClause + OpenMP_LoopRelatedClause, + OpenMP_CollapseClause, + OpenMP_TileSizesClause ], singleRegion = true> { let summary = "rectangular loop nest"; let description = [{ diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 71786e856c6db..cc955e55ee88b 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -493,7 +493,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { // Create loop nest and populate region with contents of scf.parallel. auto loopOp = rewriter.create<omp::LoopNestOp>( parallelOp.getLoc(), parallelOp.getLowerBound(), - parallelOp.getUpperBound(), parallelOp.getStep()); + parallelOp.getUpperBound(), parallelOp.getStep(), false, + parallelOp.getLowerBound().size(), nullptr); rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), loopOp.getRegion().begin()); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index e94d570b57122..a5c9a05d53c22 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -55,6 +55,11 @@ makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) { return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray); } +static DenseI64ArrayAttr +makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef<int64_t> intArray) { + return intArray.empty() ?
nullptr : DenseI64ArrayAttr::get(ctx, intArray); +} + namespace { struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel steps; @@ -2950,6 +2955,35 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) return failure(); + // Parse optional `collapse(n)`; the attribute is only stored when n > 1. + int64_t value = 0; + if (!parser.parseOptionalKeyword("collapse") && + (parser.parseLParen() || parser.parseInteger(value) || + parser.parseRParen())) + return failure(); + if (value > 1) + result.addAttribute( + "num_collapse", + IntegerAttr::get(parser.getBuilder().getI64Type(), value)); + + // Parse the optional `tiles(s0, s1, ...)` list of tile sizes. + SmallVector<int64_t> tiles; + auto parseTiles = [&]() -> ParseResult { + int64_t tile; + if (parser.parseInteger(tile)) + return failure(); + tiles.push_back(tile); + return success(); + }; + + if (!parser.parseOptionalKeyword("tiles") && + (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) || + parser.parseRParen())) + return failure(); + + if (!tiles.empty()) + result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles)); + // Parse the body.
Region *region = result.addRegion(); if (parser.parseRegion(*region, ivs)) @@ -2973,14 +3007,23 @@ void LoopNestOp::print(OpAsmPrinter &p) { if (getLoopInclusive()) p << "inclusive "; p << "step (" << getLoopSteps() << ") "; + if (int64_t numCollapse = getNumCollapse()) + if (numCollapse > 1) + p << "collapse(" << numCollapse << ") "; + + if (const auto tiles = getTileSizes()) + p << "tiles(" << tiles.value() << ") "; + p.printRegion(region, /*printEntryBlockArgs=*/false); } void LoopNestOp::build(OpBuilder &builder, OperationState &state, const LoopNestOperands &clauses) { + MLIRContext *ctx = builder.getContext(); LoopNestOp::build(builder, state, clauses.loopLowerBounds, clauses.loopUpperBounds, clauses.loopSteps, - clauses.loopInclusive); + clauses.loopInclusive, clauses.numCollapse, + makeDenseI64ArrayAttr(ctx, clauses.tileSizes)); } LogicalResult LoopNestOp::verify() { @@ -2996,6 +3039,18 @@ LogicalResult LoopNestOp::verify() { << "range argument type does not match corresponding IV type"; } + uint64_t numIVs = getIVs().size(); + + if (const auto &numCollapse = getNumCollapse()) + if (numCollapse > numIVs) + return emitOpError() + << "collapse value is larger than the number of loops"; + + if (const auto &tiles = getTileSizes()) + if (tiles.value().size() > numIVs) + return emitOpError() + << "number of tilings is larger than the number of loops"; + if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp())) return emitOpError() << "expects parent op to be a loop wrapper"; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 90ce06a0345c0..064925fa2b2a2 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2926,7 +2926,6 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { 
llvm::OpenMPIRBuilder
*ompBuilder = moduleTranslation.getOpenMPBuilder(); auto loopOp = cast<omp::LoopNestOp>(opInst); - // Set up the source location value for OpenMP runtime. llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); @@ -2992,18 +2991,59 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, loopInfos.push_back(*loopResult); } - // Collapse loops. Store the insertion point because LoopInfos may get - // invalidated. llvm::OpenMPIRBuilder::InsertPointTy afterIP = loopInfos.front()->getAfterIP(); - // Update the stack frame created for this loop to point to the resulting loop - // after applying transformations. - moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>( - [&](OpenMPLoopInfoStackFrame &frame) { - frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); - return WalkResult::interrupt(); - }); + // Top-level loop produced by tiling/collapsing. Stays null when no + // transformation is applied, leaving the stack frame untouched. + llvm::CanonicalLoopInfo *newTopLoopInfo = nullptr; + + // Do tiling + if (const auto &tiles = loopOp.getTileSizes()) { + llvm::Type *ivType = loopInfos.front()->getIndVarType(); + SmallVector<llvm::Value *> tileSizes; + + for (auto tile : tiles.value()) { + llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile); + tileSizes.push_back(tileVal); + } + + std::vector<llvm::CanonicalLoopInfo *> newLoops = + ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes); + + // Update afterIP to get the correct insertion point after + // tiling.
+ llvm::BasicBlock *afterBB = newLoops.front()->getAfter(); + llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor(); + afterIP = {afterAfterBB, afterAfterBB->begin()}; + newTopLoopInfo = newLoops[0]; + + // Update the loop infos + loopInfos.clear(); + for (const auto &newLoop : newLoops) { + loopInfos.push_back(newLoop); + } + } // Tiling done + + // Do collapse + if (const auto &numCollapse = loopOp.getNumCollapse()) { + SmallVector<llvm::CanonicalLoopInfo *> collapseLoopInfos( + loopInfos.begin(), loopInfos.begin() + (numCollapse)); + + auto newLoopInfo = + ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {}); + newTopLoopInfo = newLoopInfo; + } // Collapse done + + // Update the stack frame created for this loop to point to the resulting + // loop after applying transformations. + if (newTopLoopInfo) { + moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>( + [&](OpenMPLoopInfoStackFrame &frame) { + frame.loopInfo = newTopLoopInfo; + return WalkResult::interrupt(); + }); + } // Continue building IR after the loop. Note that the LoopInfo returned by // `collapseLoops` points inside the outermost loop and is intended for diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir index a722acbf2c347..d362bb6092419 100644 --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -6,7 +6,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index, // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) { // CHECK: omp.wsloop { - // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { + // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) collapse(2) { // CHECK: memref.alloca_scope scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, 
%arg5) { // CHECK: "test.payload"(%[[LVAR1]],
%[[LVAR2]]) : (index, index) -> () diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 060b3cd2455a0..2fb540e542ab0 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -157,6 +157,29 @@ func.func @no_loops(%lb : index, %ub : index, %step : index) { } } +// ----- + +func.func @collapse_size(%lb : index, %ub : index, %step : index) { + omp.wsloop { + // expected-error@+1 {{collapse value is larger than the number of loops}} + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) collapse(4) { + omp.yield + } + } +} + +// ----- + +func.func @tiles_length(%lb : index, %ub : index, %step : index) { + omp.wsloop { + // expected-error@+1 {{number of tilings is larger than the number of loops}} + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) tiles(2, 4) { + omp.yield + } + } +} + + // ----- func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 47cfc5278a5d0..951040af3d422 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -376,6 +376,60 @@ func.func @omp_loop_nest_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32, return } +// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse +func.func @omp_loop_nest_pretty_multiple_collapse(%lb1 : i32, %ub1 : i32, %step1 : i32, + %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref) -> () { + + omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2) + omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { + %1 = "test.payload"(%iv1) : (i32) -> (index) + %2 = "test.payload"(%iv2) : (i32) -> (index) + memref.store %iv1, %data1[%1] : memref + memref.store %iv2, %data1[%2] : memref + omp.yield + } + } + + return +} + +// CHECK-LABEL: 
omp_loop_nest_pretty_multiple_tiles +func.func @omp_loop_nest_pretty_multiple_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32, + %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref) -> () { + + omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) tiles(5, 10) + omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) tiles(5, 10) { + %1 = "test.payload"(%iv1) : (i32) -> (index) + %2 = "test.payload"(%iv2) : (i32) -> (index) + memref.store %iv1, %data1[%1] : memref + memref.store %iv2, %data1[%2] : memref + omp.yield + } + } + + return +} + +// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse_tiles +func.func @omp_loop_nest_pretty_multiple_collapse_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32, + %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref) -> () { + + omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2) tiles(5, 10) + omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) tiles(5, 10) { + %1 = "test.payload"(%iv1) : (i32) -> (index) + %2 = "test.payload"(%iv2) : (i32) -> (index) + memref.store %iv1, %data1[%1] : memref + memref.store %iv2, %data1[%2] : memref + omp.yield + } + } + + return +} + // CHECK-LABEL: omp_wsloop func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref, %linear_var : i32, %chunk_var : i32) -> () { diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir index 0ebcec0e0ec31..189e3bc57db86 100644 --- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir @@ -9,7 +9,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo %loop_lb = llvm.mlir.constant(0 : i32) : i32 %loop_step = llvm.mlir.constant(1 : index) : i32 
omp.wsloop { - omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) { + omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) collapse(2) { %1 = llvm.add %arg1, %arg2 : i32 %2 = llvm.mul %arg2, %loop_ub overflow : i32 %3 = llvm.add %arg1, %2 :i32 diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 32f0ba5b105ff..439e315a23ce1 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -698,7 +698,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) { // CHECK-LABEL: @simd_simple_multiple llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 // The form of the emitted IR is controlled by OpenMPIRBuilder and // tested there. Just check that the right metadata is added and collapsed @@ -736,7 +736,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64 // CHECK-LABEL: @simd_simple_multiple_simdlen llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd simdlen(2) { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 // The form of the emitted IR is controlled by OpenMPIRBuilder and // tested there. 
Just check that the right metadata is added. @@ -760,7 +760,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l // CHECK-LABEL: @simd_simple_multiple_safelen llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd safelen(2) { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32 @@ -779,7 +779,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l // CHECK-LABEL: @simd_simple_multiple_simdlen_safelen llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd simdlen(1) safelen(2) { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32 @@ -1179,7 +1179,7 @@ llvm.func @collapse_wsloop( // CHECK: store i32 %[[TOTAL_SUB_1]], ptr // CHECK: call void @__kmpc_for_static_init_4u omp.wsloop { - omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) { + omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) { %31 = llvm.load %20 : !llvm.ptr -> i32 %32 = llvm.add %31, %arg0 : i32 %33 = llvm.add %32, %arg1 : i32 @@ -1241,7 +1241,7 @@ llvm.func 
@collapse_wsloop_dynamic( // CHECK: store i32 %[[TOTAL]], ptr // CHECK: call void @__kmpc_dispatch_init_4u omp.wsloop schedule(dynamic) { - omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) { + omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) { %31 = llvm.load %20 : !llvm.ptr -> i32 %32 = llvm.add %31, %arg0 : i32 %33 = llvm.add %32, %arg1 : i32