From 12e078fafe9c324fc69411828e30e1552c6de9d1 Mon Sep 17 00:00:00 2001
From: Vladimir Cherepanov <vcherepanov@nvidia.com>
Date: Wed, 17 Nov 2021 22:43:53 -0800
Subject: [PATCH] [WIP]Debug print graph

---
 src/common/exec_utils.cc   | 20 +++++++++++++++++
 src/common/exec_utils.h    |  9 ++++++++
 src/imperative/cached_op.h | 24 +++++++++++++++++++++
 tools/print_graph.py       | 44 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 97 insertions(+)
 create mode 100755 tools/print_graph.py

diff --git a/src/common/exec_utils.cc b/src/common/exec_utils.cc
index bbc11e12a708..20245aebe04c 100644
--- a/src/common/exec_utils.cc
+++ b/src/common/exec_utils.cc
@@ -75,5 +75,25 @@ bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx) {
   return true;
 }
 
+// Writes a line-oriented dump: a "node" record per node, then its "attr"/"inp"/"dep"/"sub" lines.
+// tools/print_graph.py parses this format, so keep the record keywords and field order in sync.
+void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os) {
+  auto ref = [&idx](uint32_t id) { return std::to_string(id) + " " + idx[id].source->attrs.name; };
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const auto& attrs = idx[nid].source->attrs;
+    os << "node " << ref(nid) << " " << (attrs.op ? attrs.op->name : "(var)") << "\n";
+    for (const auto& [k, v] : attrs.dict)  // by reference: avoids copying each string pair
+      os << "attr " << k << " " << v << "\n";
+    for (const auto& inp : idx[nid].inputs)
+      os << "inp " << ref(inp.node_id) << " " << inp.index << " " << inp.version << "\n";
+    for (const auto dep : idx[nid].control_deps)
+      os << "dep " << ref(dep) << "\n";
+    for (const auto& sub : attrs.subgraphs) {
+      std::string name;
+      os << "sub " << (sub->GetAttr("name", &name) ? name : "(noname)") << "\n";
+    }
+  }
+}
+
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 21a97130c183..4052977ed869 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -27,6 +27,7 @@
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
 #include <map>
+#include <ostream>
 #include <vector>
 #include <string>
 #include <utility>
@@ -570,6 +571,14 @@ void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables);
  */
 bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx);
 
+/*!
+ * \brief Prints graph to the specified stream, one "node", "attr", "inp", "dep" or "sub"
+ * record per line.
+ * \param idx Indexed graph to print
+ * \param os Output stream
+ */
+void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os);
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_EXEC_UTILS_H_
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 079a56e20a12..9fc8ae653c30 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -21,6 +21,7 @@
 #define MXNET_IMPERATIVE_CACHED_OP_H_
 
 #include <mxnet/imperative.h>
+#include <fstream>
 #include <vector>
 #include <numeric>
 #include <atomic>
@@ -29,6 +30,7 @@
 #include <unordered_map>
 #include <map>
 #include "../common/alm.h"
+#include "../common/exec_utils.h"
 #include "../operator/operator_common.h"
 #include "../operator/subgraph/common.h"
 #include "./imperative_utils.h"
@@ -330,6 +332,26 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
       std::make_shared<dmlc::any>(std::move(full_ref_count));
 }
 
+// Debug aid: if MXNET_DEBUG_PRINT_GRAPH is enabled, dumps idx between "[[[ msg"/"]]] msg" lines
+// to MXNET_DEBUG_PRINT_GRAPH_PATH: "stdout" (default), "stderr", or a file opened in append mode.
+void MaybePrintGraph(const nnvm::IndexedGraph& idx, const std::string& msg) {
+  if (!dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH", false))
+    return;
+  std::ofstream f;
+  std::ostream* dest    = &std::cout;
+  std::string dest_name = dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH_PATH", std::string("stdout"));
+  if (dest_name == "stderr") {
+    dest = &std::cerr;
+  } else if (dest_name != "stdout") {
+    f.open(dest_name.c_str(), std::ios::app);
+    CHECK(f.is_open()) << "MXNET_DEBUG_PRINT_GRAPH_PATH: cannot open '" << dest_name << "'";
+    dest = &f;
+  }
+  *dest << "[[[ " << msg << "\n";
+  common::PrintGraph(idx, *dest);
+  *dest << "]]] " << msg << "\n";
+}
+
 void OptimizeGraph(nnvm::Graph* full_graph,
                    nnvm::Graph* fwd_graph,
                    nnvm::Graph* grad_graph,
@@ -337,6 +359,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
                    const Context& context,
                    size_t num_forward_outputs,
                    const bool inlining) {
+  MaybePrintGraph(full_graph->indexed_graph(), "graph before optimization");
   input_map->resize(full_graph->indexed_graph().input_nodes().size());
   std::iota(input_map->begin(), input_map->end(), 0);
 #if MXNET_USE_CUDA && !defined(_WIN32)
@@ -386,6 +409,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
   grad_graph->outputs = std::vector<nnvm::NodeEntry>(
       full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end());
   SetRefCounts(fwd_graph, *full_graph);
+  MaybePrintGraph(full_graph->indexed_graph(), "graph after optimization");
 }
 
 /* \brief Check if param indices and data indices are set, if not then set data indices */
diff --git a/tools/print_graph.py b/tools/print_graph.py
new file mode 100755
index 000000000000..f7a109460625
--- /dev/null
+++ b/tools/print_graph.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+
+import re
+import sys
+
+
+RE_NODE = re.compile(r'node\s(.+)\n')  # node <id> <name> <op>
+RE_ATTR = re.compile(r'attr\s(.+)\n')  # attr <key> <value>
+RE_INP = re.compile(r'inp\s(.+)\n')    # inp <id> <name> <index> <version>
+RE_DEP = re.compile(r'dep\s(.+)\n')    # dep <id> <name>
+RE_SUB = re.compile(r'sub\s(.+)\n')    # sub <name>
+
+
+def to_dot(f):
+    """Print a Graphviz digraph for the graph dump read line by line from `f`.
+
+    A '[[[' marker line is skipped; the first unrecognized line (e.g. ']]]') ends the graph.
+    """
+    print('digraph Net {')
+    for line in f:
+        m = RE_NODE.fullmatch(line)
+        if m:
+            # Star-unpacking tolerates whitespace inside the node name.
+            nid, *name, op = m.group(1).split()
+            shape = 'ellipse' if op == '(var)' else 'rectangle'
+            label = ' '.join(name).replace('"', r'\"')
+            print(f'  node_{nid} [shape={shape}, label="{label}"]')
+            continue
+        if RE_ATTR.fullmatch(line) or RE_SUB.fullmatch(line) or line.startswith('[[['):
+            continue
+        m = RE_INP.fullmatch(line)
+        if m:
+            njd, *_name, index, _version = m.group(1).split()
+            print(f'  node_{njd} -> node_{nid} [label={index}, style=solid]')
+            continue
+        m = RE_DEP.fullmatch(line)
+        if m:
+            print(f'  node_{m.group(1).split()[0]} -> node_{nid} [style=dashed]')
+            continue
+        break
+    print('}')
+
+
+if __name__ == '__main__':
+    to_dot(sys.stdin)  # read the dump from stdin; the DOT text goes to stdout