Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[WIP]Debug print graph #20743

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/common/exec_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,25 @@ bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx) {
return true;
}

void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os) {
auto node_str = [&idx](uint32_t nid) {
return std::to_string(nid) + " " + idx[nid].source->attrs.name;
};
for (size_t i = 0; i < idx.num_nodes(); ++i) {
const auto& attrs = idx[i].source->attrs;
os << "node " << node_str(i) << " " << (attrs.op ? attrs.op->name : "(var)") << "\n";
for (auto [k, v] : attrs.dict)
os << "attr " << k << " " << v << "\n";
for (const auto& inp : idx[i].inputs)
os << "inp " << node_str(inp.node_id) << " " << inp.index << " " << inp.version << "\n";
for (auto dep : idx[i].control_deps)
os << "dep " << node_str(dep) << "\n";
for (const auto& sub : attrs.subgraphs) {
std::string name;
os << "sub " << (sub->GetAttr("name", &name) ? name : "(noname)") << "\n";
}
}
}

} // namespace common
} // namespace mxnet
9 changes: 9 additions & 0 deletions src/common/exec_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <map>
#include <ostream>
#include <vector>
#include <string>
#include <utility>
Expand Down Expand Up @@ -570,6 +571,14 @@ void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables);
*/
bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx);

/*!
* \brief Prints graph to the specified stream.
*
* \param idx Indexed graph to print
* \param os Output stream
*/
void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os);

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_EXEC_UTILS_H_
24 changes: 24 additions & 0 deletions src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define MXNET_IMPERATIVE_CACHED_OP_H_

#include <mxnet/imperative.h>
#include <fstream>
#include <vector>
#include <numeric>
#include <atomic>
Expand All @@ -29,6 +30,7 @@
#include <unordered_map>
#include <map>
#include "../common/alm.h"
#include "../common/exec_utils.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"
#include "./imperative_utils.h"
Expand Down Expand Up @@ -330,13 +332,34 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
std::make_shared<dmlc::any>(std::move(full_ref_count));
}

void MaybePrintGraph(const nnvm::IndexedGraph& idx, const std::string& msg) {
if (!dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH", false))
return;

std::ofstream f;
std::ostream* dest = &std::cout;
std::string dest_name = dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH_PATH", std::string("stdout"));
if (dest_name == "stderr") {
dest = &std::cerr;
} else if (dest_name != "stdout") {
f.open(dest_name.c_str(), std::ios::app);
CHECK(f.good());
dest = &f;
}

*dest << "[[[ " << msg << "\n";
common::PrintGraph(idx, *dest);
*dest << "]]] " << msg << "\n";
}

void OptimizeGraph(nnvm::Graph* full_graph,
nnvm::Graph* fwd_graph,
nnvm::Graph* grad_graph,
std::vector<size_t>* input_map,
const Context& context,
size_t num_forward_outputs,
const bool inlining) {
MaybePrintGraph(full_graph->indexed_graph(), "graph before optimization");
input_map->resize(full_graph->indexed_graph().input_nodes().size());
std::iota(input_map->begin(), input_map->end(), 0);
#if MXNET_USE_CUDA && !defined(_WIN32)
Expand Down Expand Up @@ -386,6 +409,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
grad_graph->outputs = std::vector<nnvm::NodeEntry>(
full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end());
SetRefCounts(fwd_graph, *full_graph);
MaybePrintGraph(full_graph->indexed_graph(), "graph after optimization");
}

/* \brief Check if param indices and data indices are set, if not then set data indices */
Expand Down
44 changes: 44 additions & 0 deletions tools/print_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

import re
import sys


RE_NODE = re.compile(r'node\s(.+)\n')
RE_ATTR = re.compile(r'attr\s(.+)\n')
RE_INP = re.compile(r'inp\s(.+)\n')
RE_DEP = re.compile(r'dep\s(.+)\n')
RE_SUB = re.compile(r'node\s(.+)\n')


def to_dot(f):
print('digraph Net {')
for line in f:
m = RE_NODE.fullmatch(line)
if m:
nid, name, op = m.group(1).split()
shape = 'ellipse' if op == '(var)' else 'rectangle'
print(f' node_{nid} [shape={shape}, label={name}]')
continue
m = RE_ATTR.fullmatch(line)
if m:
continue
m = RE_INP.fullmatch(line)
if m:
njd, _name, index, _version = m.group(1).split()
print(f' node_{njd} -> node_{nid} [label={index}, style=solid]')
continue
m = RE_DEP.fullmatch(line)
if m:
njd, _name = m.group(1).split()
print(f' node_{njd} -> node_{nid} [style=dashed]')
continue
m = RE_SUB.fullmatch(line)
if m:
continue
break
print('}')


if __name__ == '__main__':
to_dot(sys.stdin)