1
1
import operator
2
- from typing import Any , Dict , List , Union , cast
2
+ from typing import Any , Dict , List , cast
3
3
4
4
import torch
5
- import torch ._dynamo as dynamo
6
5
import torch .fx as fx
7
6
import torch .nn as nn
8
7
import torch .utils ._pytree as pytree
@@ -22,27 +21,26 @@ def run_node(self, n: fx.node.Node) -> Any:
22
21
23
22
24
23
def parse (model : nn .Module , * example_args : List [torch .Tensor ]) -> GraphIR :
25
- gm : fx .GraphModule = dynamo . export (model , * example_args )[ 0 ]
24
+ gm : fx .GraphModule = fx . symbolic_trace (model )
26
25
# store shape info on nodes
27
26
sp = PropInterpreter (gm )
28
27
sp .run (* example_args )
29
28
30
- def load_args (args : Union [List , Dict [str , Any ]]) -> Union [List , Dict [str , Any ]]:
31
- """
32
- Map nodes to their outputs while keeping structures and other values the same.
33
- """
34
- return torch .fx .graph .map_arg (args , lambda n : n .meta ["res" ])
35
-
36
29
named_modules = dict (gm .named_modules ())
37
30
ir = GraphIR ()
38
31
name2retvals : Dict [str , List [str ]] = {}
39
- for i_node , node in enumerate ( gm .graph .nodes ) :
32
+ for node in gm .graph .nodes :
40
33
node = cast (fx .node .Node , node )
41
34
if node .op == "placeholder" :
42
- iexpr = InstExpr (Input (dim = len (node .meta ["res" ].shape )), [])
35
+ input_node = Input (dim = len (node .meta ["res" ].shape ))
36
+ input_node .abs_tensor = AbsTensor (
37
+ shape = list (node .meta ["res" ].shape ),
38
+ dtype = DType .from_torch (node .meta ["res" ].dtype ),
39
+ )
40
+ iexpr = InstExpr (input_node , [])
43
41
else :
44
- args_flatten , args_treespec = pytree .tree_flatten (node .args )
45
- kwargs_flatten , kwargs_treespec = pytree .tree_flatten (node .kwargs )
42
+ args_flatten , _ = pytree .tree_flatten (node .args )
43
+ kwargs_flatten , _ = pytree .tree_flatten (node .kwargs )
46
44
input_nodes = [
47
45
a
48
46
for a in (args_flatten + kwargs_flatten )
@@ -67,8 +65,8 @@ def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
67
65
pytree .tree_flatten (node .meta ["res" ])[0 ],
68
66
)
69
67
)
70
- nodes2empty = (
71
- lambda n : ConcreteOp .empty if isinstance (n , fx .node .Node ) else n
68
+ nodes2empty = lambda n : (
69
+ ConcreteOp .empty if isinstance (n , fx .node .Node ) else n
72
70
)
73
71
args_wo_nodes = pytree .tree_map (nodes2empty , node .args )
74
72
kwargs_wo_nodes = pytree .tree_map (nodes2empty , node .kwargs )
0 commit comments