From 12e078fafe9c324fc69411828e30e1552c6de9d1 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov <vcherepanov@nvidia.com>
Date: Wed, 17 Nov 2021 22:43:53 -0800
Subject: [PATCH] [WIP]Debug print graph

---
 src/common/exec_utils.cc   | 20 +++++++++++++++
 src/common/exec_utils.h    | 13 ++++++++++
 src/imperative/cached_op.h | 33 ++++++++++++++++++++++++
 tools/print_graph.py       | 51 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 117 insertions(+)
 create mode 100755 tools/print_graph.py

diff --git a/src/common/exec_utils.cc b/src/common/exec_utils.cc
index bbc11e12a708..20245aebe04c 100644
--- a/src/common/exec_utils.cc
+++ b/src/common/exec_utils.cc
@@ -75,5 +75,25 @@ bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx) {
   return true;
 }
 
+void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os) {
+  auto node_str = [&idx](uint32_t nid) {
+    return std::to_string(nid) + " " + idx[nid].source->attrs.name;
+  };
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    const auto& attrs = idx[i].source->attrs;
+    os << "node " << node_str(i) << " " << (attrs.op ? attrs.op->name : "(var)") << "\n";
+    for (const auto& [k, v] : attrs.dict)
+      os << "attr " << k << " " << v << "\n";
+    for (const auto& inp : idx[i].inputs)
+      os << "inp " << node_str(inp.node_id) << " " << inp.index << " " << inp.version << "\n";
+    for (auto dep : idx[i].control_deps)
+      os << "dep " << node_str(dep) << "\n";
+    for (const auto& sub : attrs.subgraphs) {
+      std::string name;
+      os << "sub " << (sub->GetAttr("name", &name) ? name : "(noname)") << "\n";
+    }
+  }
+}
+
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 21a97130c183..4052977ed869 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -27,6 +27,7 @@
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
 #include <map>
+#include <ostream>
 #include <vector>
 #include <string>
 #include <utility>
@@ -570,6 +571,18 @@ void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables);
  */
 bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx);
 
+/*!
+ * \brief Prints graph to the specified stream.
+ *
+ * Emits one "node <id> <name> <op>" line per node (op is "(var)" for nodes
+ * without an operator), followed by that node's "attr", "inp", "dep" and
+ * "sub" lines.
+ *
+ * \param idx Indexed graph to print
+ * \param os Output stream
+ */
+void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os);
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_EXEC_UTILS_H_
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 079a56e20a12..9fc8ae653c30 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -21,6 +21,7 @@
 #define MXNET_IMPERATIVE_CACHED_OP_H_
 
 #include <mxnet/imperative.h>
+#include <fstream>
 #include <vector>
 #include <numeric>
 #include <atomic>
@@ -29,6 +30,7 @@
 #include <unordered_map>
 #include <map>
 #include "../common/alm.h"
+#include "../common/exec_utils.h"
 #include "../operator/operator_common.h"
 #include "../operator/subgraph/common.h"
 #include "./imperative_utils.h"
@@ -330,6 +332,35 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
       std::make_shared<dmlc::any>(std::move(full_ref_count));
 }
 
+/*!
+ * \brief Prints the graph when MXNET_DEBUG_PRINT_GRAPH is enabled.
+ *
+ * The destination is taken from MXNET_DEBUG_PRINT_GRAPH_PATH: "stdout" (default),
+ * "stderr", or a file path that is opened in append mode.
+ *
+ * \param idx Indexed graph to print
+ * \param msg Label written on the opening "[[[" and closing "]]]" marker lines
+ */
+void MaybePrintGraph(const nnvm::IndexedGraph& idx, const std::string& msg) {
+  if (!dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH", false))
+    return;
+
+  std::ofstream f;
+  std::ostream* dest = &std::cout;
+  std::string dest_name = dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH_PATH", std::string("stdout"));
+  if (dest_name == "stderr") {
+    dest = &std::cerr;
+  } else if (dest_name != "stdout") {
+    f.open(dest_name.c_str(), std::ios::app);
+    CHECK(f.good()) << "Failed to open graph dump file " << dest_name;
+    dest = &f;
+  }
+
+  *dest << "[[[ " << msg << "\n";
+  common::PrintGraph(idx, *dest);
+  *dest << "]]] " << msg << "\n";
+}
+
 void OptimizeGraph(nnvm::Graph* full_graph,
                    nnvm::Graph* fwd_graph,
                    nnvm::Graph* grad_graph,
@@ -337,6 +368,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
                    const Context& context,
                    size_t num_forward_outputs,
                    const bool inlining) {
+  MaybePrintGraph(full_graph->indexed_graph(), "graph before optimization");
   input_map->resize(full_graph->indexed_graph().input_nodes().size());
   std::iota(input_map->begin(), input_map->end(), 0);
 #if MXNET_USE_CUDA && !defined(_WIN32)
@@ -386,6 +418,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
   grad_graph->outputs = std::vector<nnvm::NodeEntry>(
       full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end());
   SetRefCounts(fwd_graph, *full_graph);
+  MaybePrintGraph(full_graph->indexed_graph(), "graph after optimization");
 }
 
 /* \brief Check if param indices and data indices are set, if not then set data indices */
diff --git a/tools/print_graph.py b/tools/print_graph.py
new file mode 100755
index 000000000000..f7a109460625
--- /dev/null
+++ b/tools/print_graph.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+"""Convert a MXNET_DEBUG_PRINT_GRAPH text dump (stdin) to Graphviz DOT (stdout)."""
+
+import re
+import sys
+
+
+RE_NODE = re.compile(r'node\s(.+)\n')
+RE_ATTR = re.compile(r'attr\s(.+)\n')
+RE_INP = re.compile(r'inp\s(.+)\n')
+RE_DEP = re.compile(r'dep\s(.+)\n')
+RE_SUB = re.compile(r'sub\s(.+)\n')
+
+
+def to_dot(f):
+    """Translate a graph dump into a Graphviz digraph printed to stdout.
+
+    `f` is an iterable of newline-terminated lines in the format written by
+    the C++ PrintGraph helper. Parsing stops at the first line that is not a
+    node/attr/inp/dep/sub record.
+    """
+    print('digraph Net {')
+    for line in f:
+        m = RE_NODE.fullmatch(line)
+        if m:
+            nid, name, op = m.group(1).split()
+            shape = 'ellipse' if op == '(var)' else 'rectangle'
+            print(f'  node_{nid} [shape={shape}, label="{name}"]')
+            continue
+        m = RE_ATTR.fullmatch(line)
+        if m:
+            continue
+        m = RE_INP.fullmatch(line)
+        if m:
+            njd, _name, index, _version = m.group(1).split()
+            print(f'  node_{njd} -> node_{nid} [label={index}, style=solid]')
+            continue
+        m = RE_DEP.fullmatch(line)
+        if m:
+            njd, _name = m.group(1).split()
+            print(f'  node_{njd} -> node_{nid} [style=dashed]')
+            continue
+        m = RE_SUB.fullmatch(line)
+        if m:
+            continue
+        break
+    print('}')
+
+
+if __name__ == '__main__':
+    to_dot(sys.stdin)