diff --git a/flang/include/flang/Semantics/symbol-set-closure.h b/flang/include/flang/Semantics/symbol-set-closure.h new file mode 100644 index 0000000000000..d7f2f74c47e9a --- /dev/null +++ b/flang/include/flang/Semantics/symbol-set-closure.h @@ -0,0 +1,34 @@ +//===-- include/flang/Semantics/symbol-set-closure.h ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_ +#define FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_ + +#include "flang/Semantics/symbol.h" + +namespace Fortran::semantics { + +// For a set or scope of symbols, computes the transitive closure of their +// dependences due to their types, bounds, specific procedures, interfaces, +// initialization, storage association, &c. Includes the original symbol +// or members of the original set. Does not include dependences from +// subprogram definitions, only their interfaces. +enum DependenceCollectionFlags { + NoDependenceCollectionFlags = 0, + IncludeOriginalSymbols = 1 << 0, + FollowUseAssociations = 1 << 1, + IncludeSpecificsOfGenerics = 1 << 2, + IncludeComponentsInExprs = 1 << 3, +}; +UnorderedSymbolSet CollectAllDependences( + const UnorderedSymbolSet &, int = NoDependenceCollectionFlags); +UnorderedSymbolSet CollectAllDependences( + const Scope &, int = NoDependenceCollectionFlags); + +} // namespace Fortran::semantics +#endif // FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_ diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt index 18c89587843a9..c1be83b0c744c 100644 --- a/flang/lib/Semantics/CMakeLists.txt +++ b/flang/lib/Semantics/CMakeLists.txt @@ -46,6 +46,7 @@ add_flang_library(FortranSemantics scope.cpp semantics.cpp symbol.cpp + symbol-set-closure.cpp tools.cpp type.cpp unparse-with-symbols.cpp diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 82c8536902eb2..05a4ee2ea21e3 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -15,6 +15,7 @@ #include "flang/Parser/unparse.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" +#include "flang/Semantics/symbol-set-closure.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "llvm/Support/FileSystem.h" @@ -223,71 +224,10 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) { // Collect symbols from constant and specification expressions that are being // referenced directly from other modules; they may require new USE // associations. -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &, const Scope &); -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) { - auto HarvestBound{[&](const Bound &bound) { - if (const auto &expr{bound.GetExplicit()}) { - for (SymbolRef ref : evaluate::CollectSymbols(*expr)) { - set.emplace(*ref); - } - } - }}; - auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) { - HarvestBound(shapeSpec.lbound()); - HarvestBound(shapeSpec.ubound()); - }}; - auto HarvestArraySpec{[&](const ArraySpec &arraySpec) { - for (const auto &shapeSpec : arraySpec) { - HarvestShapeSpec(shapeSpec); - } - }}; - - if (symbol.has()) { - if (symbol.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *symbol.scope()); - } - } else if (const auto &generic{symbol.detailsIf()}; - generic && generic->derivedType()) { - const Symbol &dtSym{*generic->derivedType()}; - if (dtSym.has()) { - if (dtSym.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope()); - } - } else { - CHECK(dtSym.has() || dtSym.has()); - } - } else if (const auto *object{symbol.detailsIf()}) { - HarvestArraySpec(object->shape()); - HarvestArraySpec(object->coshape()); - if (IsNamedConstant(symbol) || scope.IsDerivedType()) { - if (object->init()) { - for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) { - set.emplace(*ref); - } - } - } - } else if (const auto *proc{symbol.detailsIf()}) { - if (proc->init() && *proc->init() && scope.IsDerivedType()) { - set.emplace(**proc->init()); - } - } else if (const auto *subp{symbol.detailsIf()}) { - for (const Symbol *dummy : subp->dummyArgs()) { - if (dummy) { - HarvestSymbolsNeededFromOtherModules(set, *dummy, scope); - } - } - if (subp->isFunction()) { - HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope); - } - } -} - static void HarvestSymbolsNeededFromOtherModules( SourceOrderedSymbolSet &set, const Scope &scope) { - for (const auto &[_, symbol] : scope) { - HarvestSymbolsNeededFromOtherModules(set, *symbol, scope); + for (const Symbol &symbol : CollectAllDependences(scope)) { + set.insert(symbol); } } @@ -369,7 +309,7 @@ void ModFileWriter::PutSymbols( PrepareRenamings(scope); SourceOrderedSymbolSet modules; CollectSymbols(scope, sorted, uses, modules); - // Write module files for dependencies first so that their + // Write module files for dependences first so that their // hashes are known. for (const Symbol &mod : modules) { if (hermeticModules) { diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index f66918e5c140e..f6cbe49f56543 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -7416,7 +7416,8 @@ void DeclarationVisitor::SetType( std::optional DeclarationVisitor::ResolveDerivedType( const parser::Name &name) { Scope &outer{NonDerivedTypeScope()}; - Symbol *symbol{FindSymbol(outer, name)}; + Symbol *original{FindSymbol(outer, name)}; + Symbol *symbol{original}; Symbol *ultimate{symbol ? &symbol->GetUltimate() : nullptr}; auto *generic{ultimate ? ultimate->detailsIf() : nullptr}; if (generic) { @@ -7429,11 +7430,12 @@ std::optional DeclarationVisitor::ResolveDerivedType( (generic && &ultimate->owner() == &outer)) { if (allowForwardReferenceToDerivedType()) { if (!symbol) { - symbol = &MakeSymbol(outer, name.source, Attrs{}); + symbol = original = &MakeSymbol(outer, name.source, Attrs{}); Resolve(name, *symbol); } else if (generic) { // forward ref to type with later homonymous generic - symbol = &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{}); + symbol = original = + &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{}); generic->set_derivedType(*symbol); name.symbol = symbol; } @@ -7453,7 +7455,7 @@ std::optional DeclarationVisitor::ResolveDerivedType( if (CheckUseError(name)) { return std::nullopt; } else if (symbol->GetUltimate().has()) { - return DerivedTypeSpec{name.source, *symbol}; + return DerivedTypeSpec{name.source, *original}; } else { Say(name, "'%s' is not a derived type"_err_en_US); return std::nullopt; diff --git a/flang/lib/Semantics/symbol-set-closure.cpp b/flang/lib/Semantics/symbol-set-closure.cpp new file mode 100644 index 0000000000000..fb928460b86b5 --- /dev/null +++ b/flang/lib/Semantics/symbol-set-closure.cpp @@ -0,0 +1,187 @@ +//===-- lib/Semantics/symbol-set-closure.cpp ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Semantics/symbol-set-closure.h" +#include "flang/Common/idioms.h" +#include "flang/Common/visit.h" + +namespace Fortran::semantics { + +class Collector { +public: + explicit Collector(int flags) : flags_{flags} {} + + UnorderedSymbolSet Collected() { return std::move(set_); } + + void operator()(const Symbol &x) { set_.insert(x); } + void operator()(SymbolRef x) { (*this)(*x); } + template void operator()(const std::optional &x) { + if (x) { + (*this)(*x); + } + } + template void operator()(const A *x) { + if (x) { + (*this)(*x); + } + } + void operator()(const UnorderedSymbolSet &x) { + for (const Symbol &symbol : x) { + (*this)(symbol); + } + } + void operator()(const SourceOrderedSymbolSet &x) { + for (const Symbol &symbol : x) { + (*this)(symbol); + } + } + void operator()(const Scope &x) { + for (const auto &[_, ref] : x) { + (*this)(*ref); + } + } + template void operator()(const evaluate::Expr &x) { + UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)}; + for (const Symbol &sym : exprSyms) { + if (!sym.owner().IsDerivedType() || sym.has() || + (flags_ & IncludeComponentsInExprs)) { + (*this)(sym); + } + } + } + void operator()(const DeclTypeSpec &type) { + if (type.category() == DeclTypeSpec::Category::Character) { + (*this)(type.characterTypeSpec().length()); + } else { + (*this)(type.AsDerived()); + } + } + void operator()(const DerivedTypeSpec &type) { + (*this)(type.originalTypeSymbol()); + for (const auto &[_, value] : type.parameters()) { + (*this)(value); + } + } + void operator()(const ParamValue &x) { (*this)(x.GetExplicit()); } + void operator()(const Bound &x) { (*this)(x.GetExplicit()); } + void operator()(const ShapeSpec &x) { + (*this)(x.lbound()); + (*this)(x.ubound()); + } + void operator()(const ArraySpec &x) { + for (const ShapeSpec &shapeSpec : x) { + (*this)(shapeSpec); + } + } + +private: + UnorderedSymbolSet set_; + int flags_{NoDependenceCollectionFlags}; +}; + +UnorderedSymbolSet CollectAllDependences(const Scope &scope, int flags) { + UnorderedSymbolSet basis; + for (const auto &[_, symbol] : scope) { + basis.insert(*symbol); + } + return CollectAllDependences(basis, flags); +} + +UnorderedSymbolSet CollectAllDependences( + const UnorderedSymbolSet &original, int flags) { + UnorderedSymbolSet result; + if (flags & IncludeOriginalSymbols) { + result = original; + } + UnorderedSymbolSet work{original}; + while (!work.empty()) { + Collector collect{flags}; + for (const Symbol &symbol : work) { + if (symbol.test(Symbol::Flag::CompilerCreated)) { + continue; + } + collect(symbol.GetType()); + common::visit( + common::visitors{ + [&collect, &symbol](const ObjectEntityDetails &x) { + collect(x.shape()); + collect(x.coshape()); + if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) { + collect(x.init()); + } + collect(x.commonBlock()); + if (const auto *set{FindEquivalenceSet(symbol)}) { + for (const EquivalenceObject &equivObject : *set) { + collect(equivObject.symbol); + } + } + }, + [&collect, &symbol](const ProcEntityDetails &x) { + collect(x.rawProcInterface()); + if (symbol.owner().IsDerivedType()) { + collect(x.init()); + } + // TODO: worry about procedure pointers in common blocks? + }, + [&collect](const ProcBindingDetails &x) { collect(x.symbol()); }, + [&collect](const SubprogramDetails &x) { + for (const Symbol *dummy : x.dummyArgs()) { + collect(dummy); + } + if (x.isFunction()) { + collect(x.result()); + } + }, + [&collect, &symbol]( + const DerivedTypeDetails &) { collect(symbol.scope()); }, + [&collect, flags](const GenericDetails &x) { + collect(x.derivedType()); + collect(x.specific()); + for (const Symbol &use : x.uses()) { + collect(use); + } + if (flags & IncludeSpecificsOfGenerics) { + for (const Symbol &specific : x.specificProcs()) { + collect(specific); + } + } + }, + [&collect](const NamelistDetails &x) { + for (const Symbol &symbol : x.objects()) { + collect(symbol); + } + }, + [&collect](const CommonBlockDetails &x) { + for (auto ref : x.objects()) { + collect(*ref); + } + }, + [&collect, &symbol, flags](const UseDetails &x) { + if (flags & FollowUseAssociations) { + collect(x.symbol()); + } + }, + [&collect](const HostAssocDetails &x) { collect(x.symbol()); }, + [](const auto &) {}, + }, + symbol.details()); + } + work.clear(); + for (const Symbol &symbol : collect.Collected()) { + if (result.find(symbol) == result.end() && + ((flags & IncludeOriginalSymbols) || + original.find(symbol) == original.end())) { + result.insert(symbol); + work.insert(symbol); + } + } + } + return result; +} + +} // namespace Fortran::semantics diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index bf520d04a50cc..adaad89b0bcfd 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "flang/Parser/tools.h" +#include "flang/Semantics/tools.h" #include "flang/Common/indirection.h" #include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/message.h" #include "flang/Parser/parse-tree.h" +#include "flang/Parser/tools.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/symbol.h" -#include "flang/Semantics/tools.h" #include "flang/Semantics/type.h" #include "flang/Support/Fortran.h" #include "llvm/ADT/StringSwitch.h" @@ -2117,4 +2117,4 @@ bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) { return false; } } -} // namespace Fortran::semantics \ No newline at end of file +} // namespace Fortran::semantics diff --git a/flang/lib/Semantics/type.cpp b/flang/lib/Semantics/type.cpp index 964a37e1c822b..4a56902524417 100644 --- a/flang/lib/Semantics/type.cpp +++ b/flang/lib/Semantics/type.cpp @@ -22,9 +22,19 @@ namespace Fortran::semantics { +static const Symbol &ResolveOriginalTypeSymbol(const Symbol *symbol) { + symbol = &symbol->GetUltimate(); + if (const auto *generic{symbol->detailsIf()}) { + CHECK(generic->derivedType() != nullptr); + return generic->derivedType()->GetUltimate(); + } else { + return *symbol; + } +} + DerivedTypeSpec::DerivedTypeSpec(SourceName name, const Symbol &typeSymbol) : name_{name}, originalTypeSymbol_{typeSymbol}, - typeSymbol_{typeSymbol.GetUltimate()} { + typeSymbol_{ResolveOriginalTypeSymbol(&typeSymbol)} { CHECK(typeSymbol_.has()); } DerivedTypeSpec::DerivedTypeSpec(const DerivedTypeSpec &that) = default;