forked from torch/nngraph
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraphinspecting.lua
More file actions
119 lines (105 loc) · 3.57 KB
/
graphinspecting.lua
File metadata and controls
119 lines (105 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
-- Returns the graph node currently being evaluated, or nil.
--
-- Walks up the call stack, starting at the caller, looking for a frame of a
-- function named "neteval". For such a frame, it returns the value of its
-- first local variable, but only when that local is named "node". Frames
-- that do not match are skipped. When the walk runs off the top of the
-- stack, nil is returned.
--
-- NOTE: this depends on the function and local variable names used in the
-- nngraph.gModule source code. Renaming them there silently breaks the lookup.
local function findCurrentNode()
   local depth = 2
   local frame = debug.getinfo(depth, "n")
   while frame ~= nil do
      if frame.name == "neteval" then
         local localName, localValue = debug.getlocal(depth, 1)
         if localName == "node" then
            return localValue
         end
      end
      depth = depth + 1
      frame = debug.getinfo(depth, "n")
   end
   return nil
end
-- Packs all arguments into a table, recording the true count in `n` so that
-- nil values (including trailing ones) are not lost.
local function packValues(...)
   return {n = select('#', ...), ...}
end

-- Runs func() and returns all of its results.
--
-- If func raises an error, onError(failedNode, ...) is called and then the
-- original error is re-raised. failedNode is the graph node that was being
-- evaluated when the error happened. It is located with findCurrentNode()
-- while the failing stack is still live, and is nil if no node was found.
--
-- String (or number) error messages get a stack traceback appended unless
-- one is already present. Other error objects, such as tables, are re-raised
-- unchanged. A failure inside onError is reported on stderr so that it
-- cannot mask the original error.
local function runChecked(func, onError, ...)
   -- Lua 5.1/LuaJIT provide a global unpack; Lua 5.2+ use table.unpack.
   local unpackResults = table.unpack or unpack
   local failedNode
   local function errorHandler(message)
      -- Only string-like messages can be searched and given a traceback.
      -- string.find on a table would raise inside the handler and replace
      -- the real error.
      local kind = type(message)
      if (kind == 'string' or kind == 'number')
            and not string.find(message, 'stack traceback:\n', 1, true) then
         message = debug.traceback(message, 2)
      end
      -- The current node has to be searched for here, before xpcall returns
      -- and the stack of the failing call is unwound.
      failedNode = findCurrentNode()
      return message
   end
   -- Keep every value xpcall returns, so multiple results of func survive.
   local results = packValues(xpcall(func, errorHandler))
   if results[1] then
      return unpackResults(results, 2, results.n)
   end
   -- onError is best-effort diagnostics. If it fails, report that failure
   -- and still raise the original error.
   local reported, reportErr = pcall(onError, failedNode, ...)
   if not reported then
      io.stderr:write(string.format(
         'nngraph: could not save the debug graph: %s\n', tostring(reportErr)))
   end
   -- Passing the level 0 avoids adding an additional error position info
   -- to the message.
   error(results[2], 0)
end
-- Renders `graph` as Graphviz dot text via graph:todot(title).
--
-- When failedNode is given and some node in graph.nodes shares its `data`
-- (compared with ==), the trailing '}' of the dot text is replaced by a
-- statement that fills that node ("n<id>") red, followed by a new closing
-- '}'. Otherwise the todot output is returned as-is.
local function customToDot(graph, title, failedNode)
   local dot = graph:todot(title)
   if not failedNode then
      return dot
   end

   local matchId
   for _, candidate in ipairs(graph.nodes) do
      if candidate.data == failedNode.data then
         matchId = candidate.id
         break
      end
   end
   if matchId == nil then
      return dot
   end

   -- Drop the closing brace (and trailing whitespace) so one more statement
   -- can be appended inside the graph body.
   local body = (string.gsub(dot, '}%s*$', ''))
   local highlight = string.format('n%s[style=filled, fillcolor=red];\n}', matchId)
   return body .. highlight
end
-- Quotes s as a single POSIX shell word. The text is wrapped in single
-- quotes, and each embedded single quote is written as '\''.
local function shellQuote(s)
   return "'" .. (string.gsub(s, "'", "'\\''")) .. "'"
end

-- Writes dotStr to svgPathPrefix .. '.dot', then renders it to
-- svgPathPrefix .. '.svg' with the Graphviz `dot` command.
--
-- Failures (unwritable .dot path, `dot` missing or failing) are reported on
-- stderr rather than raised. This is best-effort debug output that is
-- produced while another error is being handled.
local function saveSvg(svgPathPrefix, dotStr)
   io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
   local dotPath = svgPathPrefix .. '.dot'
   local dotFile, openErr = io.open(dotPath, 'w')
   if not dotFile then
      io.stderr:write(string.format("could not write %s: %s\n",
         dotPath, tostring(openErr)))
      return
   end
   dotFile:write(dotStr)
   dotFile:close()
   local svgPath = svgPathPrefix .. '.svg'
   -- Quote both paths: the prefix may come from gmodule.name and can contain
   -- spaces or shell metacharacters.
   local cmd = string.format('dot -Tsvg -o %s %s',
      shellQuote(svgPath), shellQuote(dotPath))
   local status = os.execute(cmd)
   -- Lua 5.1/LuaJIT return the exit status (0 on success); Lua 5.2+ return
   -- true on success.
   if status ~= true and status ~= 0 then
      io.stderr:write(string.format(
         "could not render %s (is Graphviz 'dot' installed?)\n", svgPath))
   end
end
-- Error callback used with runChecked: dumps the module's forward graph
-- (gmodule.fg) with failedNode highlighted.
--
-- The output prefix is gmodule.name when set. Otherwise it is
-- 'nngraph_<inputs>in_<outputs>out', where <inputs> is gmodule.nInputs
-- (falling back to the number of children of gmodule.innode) and <outputs>
-- is the number of children of gmodule.outnode.
local function onError(failedNode, gmodule)
   local inCount = gmodule.nInputs or #gmodule.innode.children
   local prefix = gmodule.name
   if not prefix then
      prefix = string.format('nngraph_%sin_%sout',
         inCount, #gmodule.outnode.children)
   end
   saveSvg(prefix, customToDot(gmodule.fg, prefix, failedNode))
end
-- The undecorated nn.gModule methods, keyed by method name and captured once
-- when this file is loaded. nngraph.setDebug wraps these (never the
-- currently installed methods) and puts them back when debugging is off.
local origFuncs = {}
for _, methodName in ipairs({'runForwardFunction', 'updateGradInput', 'accGradParameters'}) do
   origFuncs[methodName] = nn.gModule[methodName]
end
-- Enables or disables graph debugging for every nn.gModule.
--
-- When enabled, each method listed in origFuncs is replaced by a wrapper.
-- The wrapper runs the original method through runChecked, with onError as
-- the error callback and the module (the first argument) as its extra
-- argument. onError writes a Graphviz .dot file and tries to render an .svg
-- of the module's forward graph. The files are named after gmodule.name, or
-- 'nngraph_<N>in_<M>out' when the module has no name. The failing node is
-- filled red when it can be located. The original error is then re-raised.
--
-- The wrappers always call the methods captured in origFuncs, so enabling
-- debugging twice does not stack wrappers.
--
-- @param enable  truthy to install the wrappers; falsy to restore the
--                original methods.
function nngraph.setDebug(enable)
   if not enable then
      -- When debug is disabled,
      -- the origFuncs are restored on nn.gModule.
      for funcName, origFunc in pairs(origFuncs) do
         nn.gModule[funcName] = origFunc
      end
      return
   end
   -- Lua 5.1/LuaJIT provide a global unpack; Lua 5.2+ use table.unpack.
   local unpackArgs = table.unpack or unpack
   for funcName, origFunc in pairs(origFuncs) do
      nn.gModule[funcName] = function(...)
         -- Record the real argument count: `#` on a table with nil holes is
         -- unspecified, so unpack(args) alone could drop nil arguments.
         local nArgs = select('#', ...)
         local args = {...}
         local gmodule = args[1]
         return runChecked(function()
            return origFunc(unpackArgs(args, 1, nArgs))
         end, onError, gmodule)
      end
   end
end