From 04f59f1b5b285342e536596053f6f70cdc43b646 Mon Sep 17 00:00:00 2001 From: Peter Klausler Date: Tue, 17 Jun 2025 11:51:02 -0700 Subject: [PATCH] [flang] Add general symbol dependence collection utility Replace HarvestSymbolsNeededFromOtherModules() in mod-file.cpp with a general utility function in Semantics. This new code will find other uses in further rework of hermetic module file generation as the means by which the necessary subsets of symbols in dependency modules are collected. --- .../flang/Semantics/symbol-set-closure.h | 34 ++++ flang/lib/Semantics/CMakeLists.txt | 1 + flang/lib/Semantics/mod-file.cpp | 68 +------ flang/lib/Semantics/resolve-names.cpp | 10 +- flang/lib/Semantics/symbol-set-closure.cpp | 187 ++++++++++++++++++ flang/lib/Semantics/tools.cpp | 6 +- flang/lib/Semantics/type.cpp | 12 +- 7 files changed, 246 insertions(+), 72 deletions(-) create mode 100644 flang/include/flang/Semantics/symbol-set-closure.h create mode 100644 flang/lib/Semantics/symbol-set-closure.cpp diff --git a/flang/include/flang/Semantics/symbol-set-closure.h b/flang/include/flang/Semantics/symbol-set-closure.h new file mode 100644 index 0000000000000..d7f2f74c47e9a --- /dev/null +++ b/flang/include/flang/Semantics/symbol-set-closure.h @@ -0,0 +1,34 @@ +//===-- include/flang/Semantics/symbol-set-closure.h ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_ +#define FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_ + +#include "flang/Semantics/symbol.h" + +namespace Fortran::semantics { + +// For a set or scope of symbols, computes the transitive closure of their +// dependences due to their types, bounds, specific procedures, interfaces, +// initialization, storage association, &c. Includes the original symbol +// or members of the original set. Does not include dependences from +// subprogram definitions, only their interfaces. +enum DependenceCollectionFlags { + NoDependenceCollectionFlags = 0, + IncludeOriginalSymbols = 1 << 0, + FollowUseAssociations = 1 << 1, + IncludeSpecificsOfGenerics = 1 << 2, + IncludeComponentsInExprs = 1 << 3, +}; +UnorderedSymbolSet CollectAllDependences( + const UnorderedSymbolSet &, int = NoDependenceCollectionFlags); +UnorderedSymbolSet CollectAllDependences( + const Scope &, int = NoDependenceCollectionFlags); + +} // namespace Fortran::semantics +#endif // FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_ diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt index 18c89587843a9..c1be83b0c744c 100644 --- a/flang/lib/Semantics/CMakeLists.txt +++ b/flang/lib/Semantics/CMakeLists.txt @@ -46,6 +46,7 @@ add_flang_library(FortranSemantics scope.cpp semantics.cpp symbol.cpp + symbol-set-closure.cpp tools.cpp type.cpp unparse-with-symbols.cpp diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 82c8536902eb2..05a4ee2ea21e3 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -15,6 +15,7 @@ #include "flang/Parser/unparse.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" +#include "flang/Semantics/symbol-set-closure.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "llvm/Support/FileSystem.h" @@ -223,71 +224,10 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) { // Collect symbols from constant and specification expressions that are being // referenced directly from other modules; they may require new USE // associations. -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &, const Scope &); -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) { - auto HarvestBound{[&](const Bound &bound) { - if (const auto &expr{bound.GetExplicit()}) { - for (SymbolRef ref : evaluate::CollectSymbols(*expr)) { - set.emplace(*ref); - } - } - }}; - auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) { - HarvestBound(shapeSpec.lbound()); - HarvestBound(shapeSpec.ubound()); - }}; - auto HarvestArraySpec{[&](const ArraySpec &arraySpec) { - for (const auto &shapeSpec : arraySpec) { - HarvestShapeSpec(shapeSpec); - } - }}; - - if (symbol.has()) { - if (symbol.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *symbol.scope()); - } - } else if (const auto &generic{symbol.detailsIf()}; - generic && generic->derivedType()) { - const Symbol &dtSym{*generic->derivedType()}; - if (dtSym.has()) { - if (dtSym.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope()); - } - } else { - CHECK(dtSym.has() || dtSym.has()); - } - } else if (const auto *object{symbol.detailsIf()}) { - HarvestArraySpec(object->shape()); - HarvestArraySpec(object->coshape()); - if (IsNamedConstant(symbol) || scope.IsDerivedType()) { - if (object->init()) { - for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) { - set.emplace(*ref); - } - } - } - } else if (const auto *proc{symbol.detailsIf()}) { - if (proc->init() && *proc->init() && scope.IsDerivedType()) { - set.emplace(**proc->init()); - } - } else if (const auto *subp{symbol.detailsIf()}) { - for (const Symbol *dummy : subp->dummyArgs()) { - if (dummy) { - HarvestSymbolsNeededFromOtherModules(set, *dummy, scope); - } - } - if (subp->isFunction()) { - HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope); - } - } -} - static void HarvestSymbolsNeededFromOtherModules( SourceOrderedSymbolSet &set, const Scope &scope) { - for (const auto &[_, symbol] : scope) { - HarvestSymbolsNeededFromOtherModules(set, *symbol, scope); + for (const Symbol &symbol : CollectAllDependences(scope)) { + set.insert(symbol); } } @@ -369,7 +309,7 @@ void ModFileWriter::PutSymbols( PrepareRenamings(scope); SourceOrderedSymbolSet modules; CollectSymbols(scope, sorted, uses, modules); - // Write module files for dependencies first so that their + // Write module files for dependences first so that their // hashes are known. for (const Symbol &mod : modules) { if (hermeticModules) { diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index f66918e5c140e..f6cbe49f56543 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -7416,7 +7416,8 @@ void DeclarationVisitor::SetType( std::optional DeclarationVisitor::ResolveDerivedType( const parser::Name &name) { Scope &outer{NonDerivedTypeScope()}; - Symbol *symbol{FindSymbol(outer, name)}; + Symbol *original{FindSymbol(outer, name)}; + Symbol *symbol{original}; Symbol *ultimate{symbol ? &symbol->GetUltimate() : nullptr}; auto *generic{ultimate ? ultimate->detailsIf() : nullptr}; if (generic) { @@ -7429,11 +7430,12 @@ std::optional DeclarationVisitor::ResolveDerivedType( (generic && &ultimate->owner() == &outer)) { if (allowForwardReferenceToDerivedType()) { if (!symbol) { - symbol = &MakeSymbol(outer, name.source, Attrs{}); + symbol = original = &MakeSymbol(outer, name.source, Attrs{}); Resolve(name, *symbol); } else if (generic) { // forward ref to type with later homonymous generic - symbol = &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{}); + symbol = original = + &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{}); generic->set_derivedType(*symbol); name.symbol = symbol; } @@ -7453,7 +7455,7 @@ std::optional DeclarationVisitor::ResolveDerivedType( if (CheckUseError(name)) { return std::nullopt; } else if (symbol->GetUltimate().has()) { - return DerivedTypeSpec{name.source, *symbol}; + return DerivedTypeSpec{name.source, *original}; } else { Say(name, "'%s' is not a derived type"_err_en_US); return std::nullopt; diff --git a/flang/lib/Semantics/symbol-set-closure.cpp b/flang/lib/Semantics/symbol-set-closure.cpp new file mode 100644 index 0000000000000..fb928460b86b5 --- /dev/null +++ b/flang/lib/Semantics/symbol-set-closure.cpp @@ -0,0 +1,187 @@ +//===-- lib/Semantics/symbol-set-closure.cpp ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Semantics/symbol-set-closure.h" +#include "flang/Common/idioms.h" +#include "flang/Common/visit.h" + +namespace Fortran::semantics { + +class Collector { +public: + explicit Collector(int flags) : flags_{flags} {} + + UnorderedSymbolSet Collected() { return std::move(set_); } + + void operator()(const Symbol &x) { set_.insert(x); } + void operator()(SymbolRef x) { (*this)(*x); } + template void operator()(const std::optional &x) { + if (x) { + (*this)(*x); + } + } + template void operator()(const A *x) { + if (x) { + (*this)(*x); + } + } + void operator()(const UnorderedSymbolSet &x) { + for (const Symbol &symbol : x) { + (*this)(symbol); + } + } + void operator()(const SourceOrderedSymbolSet &x) { + for (const Symbol &symbol : x) { + (*this)(symbol); + } + } + void operator()(const Scope &x) { + for (const auto &[_, ref] : x) { + (*this)(*ref); + } + } + template void operator()(const evaluate::Expr &x) { + UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)}; + for (const Symbol &sym : exprSyms) { + if (!sym.owner().IsDerivedType() || sym.has() || + (flags_ & IncludeComponentsInExprs)) { + (*this)(sym); + } + } + } + void operator()(const DeclTypeSpec &type) { + if (type.category() == DeclTypeSpec::Category::Character) { + (*this)(type.characterTypeSpec().length()); + } else { + (*this)(type.AsDerived()); + } + } + void operator()(const DerivedTypeSpec &type) { + (*this)(type.originalTypeSymbol()); + for (const auto &[_, value] : type.parameters()) { + (*this)(value); + } + } + void operator()(const ParamValue &x) { (*this)(x.GetExplicit()); } + void operator()(const Bound &x) { (*this)(x.GetExplicit()); } + void operator()(const ShapeSpec &x) { + (*this)(x.lbound()); + (*this)(x.ubound()); + } + void operator()(const ArraySpec &x) { + for (const ShapeSpec &shapeSpec : x) { + (*this)(shapeSpec); + } + } + +private: + UnorderedSymbolSet set_; + int flags_{NoDependenceCollectionFlags}; +}; + +UnorderedSymbolSet CollectAllDependences(const Scope &scope, int flags) { + UnorderedSymbolSet basis; + for (const auto &[_, symbol] : scope) { + basis.insert(*symbol); + } + return CollectAllDependences(basis, flags); +} + +UnorderedSymbolSet CollectAllDependences( + const UnorderedSymbolSet &original, int flags) { + UnorderedSymbolSet result; + if (flags & IncludeOriginalSymbols) { + result = original; + } + UnorderedSymbolSet work{original}; + while (!work.empty()) { + Collector collect{flags}; + for (const Symbol &symbol : work) { + if (symbol.test(Symbol::Flag::CompilerCreated)) { + continue; + } + collect(symbol.GetType()); + common::visit( + common::visitors{ + [&collect, &symbol](const ObjectEntityDetails &x) { + collect(x.shape()); + collect(x.coshape()); + if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) { + collect(x.init()); + } + collect(x.commonBlock()); + if (const auto *set{FindEquivalenceSet(symbol)}) { + for (const EquivalenceObject &equivObject : *set) { + collect(equivObject.symbol); + } + } + }, + [&collect, &symbol](const ProcEntityDetails &x) { + collect(x.rawProcInterface()); + if (symbol.owner().IsDerivedType()) { + collect(x.init()); + } + // TODO: worry about procedure pointers in common blocks? + }, + [&collect](const ProcBindingDetails &x) { collect(x.symbol()); }, + [&collect](const SubprogramDetails &x) { + for (const Symbol *dummy : x.dummyArgs()) { + collect(dummy); + } + if (x.isFunction()) { + collect(x.result()); + } + }, + [&collect, &symbol]( + const DerivedTypeDetails &) { collect(symbol.scope()); }, + [&collect, flags](const GenericDetails &x) { + collect(x.derivedType()); + collect(x.specific()); + for (const Symbol &use : x.uses()) { + collect(use); + } + if (flags & IncludeSpecificsOfGenerics) { + for (const Symbol &specific : x.specificProcs()) { + collect(specific); + } + } + }, + [&collect](const NamelistDetails &x) { + for (const Symbol &symbol : x.objects()) { + collect(symbol); + } + }, + [&collect](const CommonBlockDetails &x) { + for (auto ref : x.objects()) { + collect(*ref); + } + }, + [&collect, &symbol, flags](const UseDetails &x) { + if (flags & FollowUseAssociations) { + collect(x.symbol()); + } + }, + [&collect](const HostAssocDetails &x) { collect(x.symbol()); }, + [](const auto &) {}, + }, + symbol.details()); + } + work.clear(); + for (const Symbol &symbol : collect.Collected()) { + if (result.find(symbol) == result.end() && + ((flags & IncludeOriginalSymbols) || + original.find(symbol) == original.end())) { + result.insert(symbol); + work.insert(symbol); + } + } + } + return result; +} + +} // namespace Fortran::semantics diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index bf520d04a50cc..adaad89b0bcfd 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "flang/Parser/tools.h" +#include "flang/Semantics/tools.h" #include "flang/Common/indirection.h" #include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/message.h" #include "flang/Parser/parse-tree.h" +#include "flang/Parser/tools.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" -#include "flang/Semantics/tools.h" #include "flang/Semantics/type.h" #include "flang/Support/Fortran.h" #include "llvm/ADT/StringSwitch.h" @@ -2117,4 +2117,4 @@ bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) { return false; } } -} // namespace Fortran::semantics \ No newline at end of file +} // namespace Fortran::semantics diff --git a/flang/lib/Semantics/type.cpp b/flang/lib/Semantics/type.cpp index 964a37e1c822b..4a56902524417 100644 --- a/flang/lib/Semantics/type.cpp +++ b/flang/lib/Semantics/type.cpp @@ -22,9 +22,19 @@ namespace Fortran::semantics { +static const Symbol &ResolveOriginalTypeSymbol(const Symbol *symbol) { + symbol = &symbol->GetUltimate(); + if (const auto *generic{symbol->detailsIf()}) { + CHECK(generic->derivedType() != nullptr); + return generic->derivedType()->GetUltimate(); + } else { + return *symbol; + } +} + DerivedTypeSpec::DerivedTypeSpec(SourceName name, const Symbol &typeSymbol) : name_{name}, originalTypeSymbol_{typeSymbol}, - typeSymbol_{typeSymbol.GetUltimate()} { + typeSymbol_{ResolveOriginalTypeSymbol(&typeSymbol)} { CHECK(typeSymbol_.has()); } DerivedTypeSpec::DerivedTypeSpec(const DerivedTypeSpec &that) = default;