Skip to content

[flang] Add general symbol dependence collection utility #144618

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions flang/include/flang/Semantics/symbol-set-closure.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===-- include/flang/Semantics/symbol-set-closure.h ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_
#define FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_

#include "flang/Semantics/symbol.h"

namespace Fortran::semantics {

// For a set or scope of symbols, computes the transitive closure of their
// dependences due to their types, bounds, specific procedures, interfaces,
// initialization, storage association, &c. Includes the original symbol
// or members of the original set. Does not include dependences from
// subprogram definitions, only their interfaces.
enum DependenceCollectionFlags {
NoDependenceCollectionFlags = 0,
IncludeOriginalSymbols = 1 << 0,
FollowUseAssociations = 1 << 1,
IncludeSpecificsOfGenerics = 1 << 2,
IncludeComponentsInExprs = 1 << 3,
};
UnorderedSymbolSet CollectAllDependences(
const UnorderedSymbolSet &, int = NoDependenceCollectionFlags);
UnorderedSymbolSet CollectAllDependences(
const Scope &, int = NoDependenceCollectionFlags);

} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_SYMBOLS_SET_CLOSURE_H_
1 change: 1 addition & 0 deletions flang/lib/Semantics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_flang_library(FortranSemantics
scope.cpp
semantics.cpp
symbol.cpp
symbol-set-closure.cpp
tools.cpp
type.cpp
unparse-with-symbols.cpp
Expand Down
68 changes: 4 additions & 64 deletions flang/lib/Semantics/mod-file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "flang/Parser/unparse.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol-set-closure.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "llvm/Support/FileSystem.h"
Expand Down Expand Up @@ -223,71 +224,10 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
// Collect symbols from constant and specification expressions that are being
// referenced directly from other modules; they may require new USE
// associations.
static void HarvestSymbolsNeededFromOtherModules(
SourceOrderedSymbolSet &, const Scope &);
static void HarvestSymbolsNeededFromOtherModules(
SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) {
auto HarvestBound{[&](const Bound &bound) {
if (const auto &expr{bound.GetExplicit()}) {
for (SymbolRef ref : evaluate::CollectSymbols(*expr)) {
set.emplace(*ref);
}
}
}};
auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) {
HarvestBound(shapeSpec.lbound());
HarvestBound(shapeSpec.ubound());
}};
auto HarvestArraySpec{[&](const ArraySpec &arraySpec) {
for (const auto &shapeSpec : arraySpec) {
HarvestShapeSpec(shapeSpec);
}
}};

if (symbol.has<DerivedTypeDetails>()) {
if (symbol.scope()) {
HarvestSymbolsNeededFromOtherModules(set, *symbol.scope());
}
} else if (const auto &generic{symbol.detailsIf<GenericDetails>()};
generic && generic->derivedType()) {
const Symbol &dtSym{*generic->derivedType()};
if (dtSym.has<DerivedTypeDetails>()) {
if (dtSym.scope()) {
HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope());
}
} else {
CHECK(dtSym.has<UseDetails>() || dtSym.has<UseErrorDetails>());
}
} else if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()}) {
HarvestArraySpec(object->shape());
HarvestArraySpec(object->coshape());
if (IsNamedConstant(symbol) || scope.IsDerivedType()) {
if (object->init()) {
for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) {
set.emplace(*ref);
}
}
}
} else if (const auto *proc{symbol.detailsIf<ProcEntityDetails>()}) {
if (proc->init() && *proc->init() && scope.IsDerivedType()) {
set.emplace(**proc->init());
}
} else if (const auto *subp{symbol.detailsIf<SubprogramDetails>()}) {
for (const Symbol *dummy : subp->dummyArgs()) {
if (dummy) {
HarvestSymbolsNeededFromOtherModules(set, *dummy, scope);
}
}
if (subp->isFunction()) {
HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope);
}
}
}

static void HarvestSymbolsNeededFromOtherModules(
SourceOrderedSymbolSet &set, const Scope &scope) {
for (const auto &[_, symbol] : scope) {
HarvestSymbolsNeededFromOtherModules(set, *symbol, scope);
for (const Symbol &symbol : CollectAllDependences(scope)) {
set.insert(symbol);
}
}

Expand Down Expand Up @@ -369,7 +309,7 @@ void ModFileWriter::PutSymbols(
PrepareRenamings(scope);
SourceOrderedSymbolSet modules;
CollectSymbols(scope, sorted, uses, modules);
// Write module files for dependencies first so that their
// Write module files for dependences first so that their
// hashes are known.
for (const Symbol &mod : modules) {
if (hermeticModules) {
Expand Down
10 changes: 6 additions & 4 deletions flang/lib/Semantics/resolve-names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7416,7 +7416,8 @@ void DeclarationVisitor::SetType(
std::optional<DerivedTypeSpec> DeclarationVisitor::ResolveDerivedType(
const parser::Name &name) {
Scope &outer{NonDerivedTypeScope()};
Symbol *symbol{FindSymbol(outer, name)};
Symbol *original{FindSymbol(outer, name)};
Symbol *symbol{original};
Symbol *ultimate{symbol ? &symbol->GetUltimate() : nullptr};
auto *generic{ultimate ? ultimate->detailsIf<GenericDetails>() : nullptr};
if (generic) {
Expand All @@ -7429,11 +7430,12 @@ std::optional<DerivedTypeSpec> DeclarationVisitor::ResolveDerivedType(
(generic && &ultimate->owner() == &outer)) {
if (allowForwardReferenceToDerivedType()) {
if (!symbol) {
symbol = &MakeSymbol(outer, name.source, Attrs{});
symbol = original = &MakeSymbol(outer, name.source, Attrs{});
Resolve(name, *symbol);
} else if (generic) {
// forward ref to type with later homonymous generic
symbol = &outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{});
symbol = original =
&outer.MakeSymbol(name.source, Attrs{}, UnknownDetails{});
generic->set_derivedType(*symbol);
name.symbol = symbol;
}
Expand All @@ -7453,7 +7455,7 @@ std::optional<DerivedTypeSpec> DeclarationVisitor::ResolveDerivedType(
if (CheckUseError(name)) {
return std::nullopt;
} else if (symbol->GetUltimate().has<DerivedTypeDetails>()) {
return DerivedTypeSpec{name.source, *symbol};
return DerivedTypeSpec{name.source, *original};
} else {
Say(name, "'%s' is not a derived type"_err_en_US);
return std::nullopt;
Expand Down
187 changes: 187 additions & 0 deletions flang/lib/Semantics/symbol-set-closure.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//===-- lib/Semantics/symbol-set-closure.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Semantics/symbol-set-closure.h"
#include "flang/Common/idioms.h"
#include "flang/Common/visit.h"

namespace Fortran::semantics {

class Collector {
public:
explicit Collector(int flags) : flags_{flags} {}

UnorderedSymbolSet Collected() { return std::move(set_); }

void operator()(const Symbol &x) { set_.insert(x); }
void operator()(SymbolRef x) { (*this)(*x); }
template <typename A> void operator()(const std::optional<A> &x) {
if (x) {
(*this)(*x);
}
}
template <typename A> void operator()(const A *x) {
if (x) {
(*this)(*x);
}
}
void operator()(const UnorderedSymbolSet &x) {
for (const Symbol &symbol : x) {
(*this)(symbol);
}
}
void operator()(const SourceOrderedSymbolSet &x) {
for (const Symbol &symbol : x) {
(*this)(symbol);
}
}
void operator()(const Scope &x) {
for (const auto &[_, ref] : x) {
(*this)(*ref);
}
}
template <typename T> void operator()(const evaluate::Expr<T> &x) {
UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)};
for (const Symbol &sym : exprSyms) {
if (!sym.owner().IsDerivedType() || sym.has<DerivedTypeDetails>() ||
(flags_ & IncludeComponentsInExprs)) {
(*this)(sym);
}
}
}
void operator()(const DeclTypeSpec &type) {
if (type.category() == DeclTypeSpec::Category::Character) {
(*this)(type.characterTypeSpec().length());
} else {
(*this)(type.AsDerived());
}
}
void operator()(const DerivedTypeSpec &type) {
(*this)(type.originalTypeSymbol());
for (const auto &[_, value] : type.parameters()) {
(*this)(value);
}
}
void operator()(const ParamValue &x) { (*this)(x.GetExplicit()); }
void operator()(const Bound &x) { (*this)(x.GetExplicit()); }
void operator()(const ShapeSpec &x) {
(*this)(x.lbound());
(*this)(x.ubound());
}
void operator()(const ArraySpec &x) {
for (const ShapeSpec &shapeSpec : x) {
(*this)(shapeSpec);
}
}

private:
UnorderedSymbolSet set_;
int flags_{NoDependenceCollectionFlags};
};

UnorderedSymbolSet CollectAllDependences(const Scope &scope, int flags) {
UnorderedSymbolSet basis;
for (const auto &[_, symbol] : scope) {
basis.insert(*symbol);
}
return CollectAllDependences(basis, flags);
}

UnorderedSymbolSet CollectAllDependences(
const UnorderedSymbolSet &original, int flags) {
UnorderedSymbolSet result;
if (flags & IncludeOriginalSymbols) {
result = original;
}
UnorderedSymbolSet work{original};
while (!work.empty()) {
Collector collect{flags};
for (const Symbol &symbol : work) {
if (symbol.test(Symbol::Flag::CompilerCreated)) {
continue;
}
collect(symbol.GetType());
common::visit(
common::visitors{
[&collect, &symbol](const ObjectEntityDetails &x) {
collect(x.shape());
collect(x.coshape());
if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {
collect(x.init());
}
collect(x.commonBlock());
if (const auto *set{FindEquivalenceSet(symbol)}) {
for (const EquivalenceObject &equivObject : *set) {
collect(equivObject.symbol);
}
}
},
[&collect, &symbol](const ProcEntityDetails &x) {
collect(x.rawProcInterface());
if (symbol.owner().IsDerivedType()) {
collect(x.init());
}
// TODO: worry about procedure pointers in common blocks?
},
[&collect](const ProcBindingDetails &x) { collect(x.symbol()); },
[&collect](const SubprogramDetails &x) {
for (const Symbol *dummy : x.dummyArgs()) {
collect(dummy);
}
if (x.isFunction()) {
collect(x.result());
}
},
[&collect, &symbol](
const DerivedTypeDetails &) { collect(symbol.scope()); },
[&collect, flags](const GenericDetails &x) {
collect(x.derivedType());
collect(x.specific());
for (const Symbol &use : x.uses()) {
collect(use);
}
if (flags & IncludeSpecificsOfGenerics) {
for (const Symbol &specific : x.specificProcs()) {
collect(specific);
}
}
},
[&collect](const NamelistDetails &x) {
for (const Symbol &symbol : x.objects()) {
collect(symbol);
}
},
[&collect](const CommonBlockDetails &x) {
for (auto ref : x.objects()) {
collect(*ref);
}
},
[&collect, &symbol, flags](const UseDetails &x) {
if (flags & FollowUseAssociations) {
collect(x.symbol());
}
},
[&collect](const HostAssocDetails &x) { collect(x.symbol()); },
[](const auto &) {},
},
symbol.details());
}
work.clear();
for (const Symbol &symbol : collect.Collected()) {
if (result.find(symbol) == result.end() &&
((flags & IncludeOriginalSymbols) ||
original.find(symbol) == original.end())) {
result.insert(symbol);
work.insert(symbol);
}
}
}
return result;
}

} // namespace Fortran::semantics
6 changes: 3 additions & 3 deletions flang/lib/Semantics/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//

#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
#include "flang/Common/indirection.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/message.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/StringSwitch.h"
Expand Down Expand Up @@ -2117,4 +2117,4 @@ bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) {
return false;
}
}
} // namespace Fortran::semantics
} // namespace Fortran::semantics
Loading
Loading