
Commit cff46a5

fix(materialize): onnx loading with torch model available (#134)
1 parent 4061eb7 commit cff46a5

File tree: 6 files changed, +24 -26 lines

.github/workflows/ci.yaml

+6 -6

@@ -25,14 +25,14 @@ jobs:
           cd doc && python bug_summary.py
       - name: Test core
         run: |
-          pytest tests/core
+          pytest -x tests/core
       - name: Test PyTorch
         run: |
           pip install -r requirements/sys/torch.txt --pre --upgrade
           pip install -r requirements/sys/onnx.txt --pre --upgrade
           pip install -r requirements/sys/tvm.txt --pre --upgrade
           pip install -r requirements/sys/onnxruntime.txt --pre --upgrade
-          pytest tests/torch
+          pytest -x tests/torch
           yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch mgen.method=symbolic
           yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch mgen.method=symbolic-cinit
           yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch backend.type="pt2 backend@inductor" mgen.method=concolic
@@ -42,20 +42,20 @@ jobs:
           yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py backend.type=pt2 mgen.grad_check=true
       - name: Test ONNX + ONNXRuntime
         run: |
-          pytest tests/onnxruntime
+          pytest -x tests/onnxruntime
           yes | python nnsmith/cli/model_gen.py model.type=onnx mgen.method=symbolic
           yes | python nnsmith/cli/model_gen.py model.type=onnx backend.type=onnxruntime mgen.method=concolic
           python nnsmith/cli/model_exec.py model.type=onnx backend.type=onnxruntime model.path=nnsmith_output/model.onnx
       - name: Test ONNX + TVM
         run: |
-          pytest tests/tvm
+          pytest -x tests/tvm
       - name: Test ONNX + TRT
         run: |
-          pytest tests/tensorrt
+          pytest -x tests/tensorrt
       - name: Test TensorFlow
         run: |
           pip install -r requirements/sys/tensorflow.txt --pre --upgrade
-          pytest tests/tensorflow --log-cli-level=DEBUG
+          pytest -x tests/tensorflow --log-cli-level=DEBUG
           yes | python nnsmith/cli/model_gen.py model.type=tensorflow mgen.method=symbolic
           python nnsmith/cli/model_exec.py model.type=tensorflow backend.type=xla model.path=nnsmith_output/model/
           yes | python nnsmith/cli/model_gen.py model.type=tensorflow mgen.method=concolic
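
Note: the added `-x` flag makes pytest stop after the first failing test, so each CI step fails fast instead of running the remaining tests. For reference, a minimal sketch of the same behavior through pytest's Python entry point (the test path is taken from the workflow above):

import sys

import pytest

# Equivalent to running `pytest -x tests/core` on the command line:
# stop on the first failure and propagate pytest's exit code to CI.
sys.exit(pytest.main(["-x", "tests/core"]))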

experiments/README.md

+1 -1

@@ -107,7 +107,7 @@ Next, we will run NNSmith to dump a bunch of random test-cases for an SUT (say P
 >
 > ```shell
 > # PyTorch
-> pip install --extra-index-url https://download.pytorch.org/whl/nightly/cpu --pre torch
+> pip install --index-url https://download.pytorch.org/whl/nightly/cpu --pre torch
 > # TensorFlow
 > pip install tf-nightly
 > ```

nnsmith/materialize/onnx/__init__.py

+1 -1

@@ -260,7 +260,7 @@ def load(cls, path: PathLike) -> "ONNXModel":
         # FIXME: missing key(s) in state_dict: "mlist.0.data", "mlist.1.data".
         if os.path.exists(torch_path):
             ret.with_torch = True
-            ret.torch_model = cls.PTType.load(torch_path)
+            ret.torch_model = cls.PTType.load(torch_path).torch_model
             ret.full_input_like = ret.torch_model.input_like
             ret.full_output_like = ret.torch_model.output_like
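
The bug being fixed: `cls.PTType.load(torch_path)` returns a model wrapper object rather than the underlying `nn.Module`, so `ret.torch_model` previously held the wrapper itself. A minimal sketch of that relationship, with a hypothetical `TorchModelWrapper` standing in for the real nnsmith class:

import torch.nn as nn


class TorchModelWrapper:
    # Hypothetical stand-in for the object returned by cls.PTType.load():
    # a wrapper carrying metadata alongside the underlying nn.Module.
    def __init__(self, torch_model: nn.Module):
        self.torch_model = torch_model
        self.input_like = {"i0": "f32[1, 3]"}  # illustrative metadata only


loaded = TorchModelWrapper(nn.Linear(3, 3))

# Before the fix: the wrapper itself was stored, so later lookups that
# expect an nn.Module landed on the wrong object.
broken = loaded
# After the fix: the wrapped module is unwrapped once at load time.
fixed: nn.Module = loaded.torch_model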

nnsmith/materialize/torch/parse.py

+13 -15

@@ -1,8 +1,7 @@
 import operator
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, List, cast
 
 import torch
-import torch._dynamo as dynamo
 import torch.fx as fx
 import torch.nn as nn
 import torch.utils._pytree as pytree
@@ -22,27 +21,26 @@ def run_node(self, n: fx.node.Node) -> Any:
 
 
 def parse(model: nn.Module, *example_args: List[torch.Tensor]) -> GraphIR:
-    gm: fx.GraphModule = dynamo.export(model, *example_args)[0]
+    gm: fx.GraphModule = fx.symbolic_trace(model)
     # store shape info on nodes
     sp = PropInterpreter(gm)
     sp.run(*example_args)
 
-    def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
-        """
-        Map nodes to their outputs while keeping structures and other values the same.
-        """
-        return torch.fx.graph.map_arg(args, lambda n: n.meta["res"])
-
     named_modules = dict(gm.named_modules())
     ir = GraphIR()
     name2retvals: Dict[str, List[str]] = {}
-    for i_node, node in enumerate(gm.graph.nodes):
+    for node in gm.graph.nodes:
         node = cast(fx.node.Node, node)
         if node.op == "placeholder":
-            iexpr = InstExpr(Input(dim=len(node.meta["res"].shape)), [])
+            input_node = Input(dim=len(node.meta["res"].shape))
+            input_node.abs_tensor = AbsTensor(
+                shape=list(node.meta["res"].shape),
+                dtype=DType.from_torch(node.meta["res"].dtype),
+            )
+            iexpr = InstExpr(input_node, [])
         else:
-            args_flatten, args_treespec = pytree.tree_flatten(node.args)
-            kwargs_flatten, kwargs_treespec = pytree.tree_flatten(node.kwargs)
+            args_flatten, _ = pytree.tree_flatten(node.args)
+            kwargs_flatten, _ = pytree.tree_flatten(node.kwargs)
             input_nodes = [
                 a
                 for a in (args_flatten + kwargs_flatten)
@@ -67,8 +65,8 @@ def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
                 pytree.tree_flatten(node.meta["res"])[0],
             )
         )
-            nodes2empty = (
-                lambda n: ConcreteOp.empty if isinstance(n, fx.node.Node) else n
+            nodes2empty = lambda n: (
+                ConcreteOp.empty if isinstance(n, fx.node.Node) else n
             )
             args_wo_nodes = pytree.tree_map(nodes2empty, node.args)
             kwargs_wo_nodes = pytree.tree_map(nodes2empty, node.kwargs)
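
The parser now builds the `fx.GraphModule` with `fx.symbolic_trace` instead of `torch._dynamo.export`, and `PropInterpreter` runs the example inputs so that concrete results land on `node.meta["res"]`, which the placeholder branch uses to record shape and dtype. A standalone sketch of that trace-then-propagate pattern (the `ShapeRecorder` class below is an assumed analogue of `PropInterpreter`, not the real implementation):

import torch
import torch.fx as fx
import torch.nn as nn


class ShapeRecorder(fx.Interpreter):
    """Assumed analogue of PropInterpreter: execute each fx node with real
    tensors and stash the concrete result on node.meta["res"]."""

    def run_node(self, n: fx.node.Node):
        res = super().run_node(n)
        n.meta["res"] = res
        return res


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
gm: fx.GraphModule = fx.symbolic_trace(model)  # replaces dynamo.export(...)[0]
ShapeRecorder(gm).run(torch.randn(1, 3, 16, 16))

for node in gm.graph.nodes:
    if node.op == "placeholder":
        # This is the shape/dtype information the commit now stores in AbsTensor.
        print(node.name, list(node.meta["res"].shape), node.meta["res"].dtype)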

requirements/sys/torch.txt

+1 -1

@@ -1,4 +1,4 @@
 # TODO(@ganler): make other platform/device distribution also work.
---extra-index-url https://download.pytorch.org/whl/nightly/cpu
+--index-url https://download.pytorch.org/whl/nightly/cpu
 --pre
 torch

tests/torch/test_dump_load.py

+2 -2

@@ -45,7 +45,7 @@ def test_onnx_load_dump(tmp_path):
     # check oracle
     compare_two_oracle(oracle, loaded_testcase.oracle)
 
-    loaded_model = loaded_testcase.model.torch_model
+    loaded_model = loaded_testcase.model
     loaded_model.sat_inputs = {k: torch.from_numpy(v) for k, v in oracle.input.items()}
     rerun_oracle = loaded_model.make_oracle()
     compare_two_oracle(oracle, rerun_oracle)
@@ -77,7 +77,7 @@ def test_bug_report_load_dump(tmp_path):
     # check oracle
     compare_two_oracle(oracle, loaded_testcase.oracle)
 
-    loaded_model = loaded_testcase.model.torch_model
+    loaded_model = loaded_testcase.model
     loaded_model.sat_inputs = {k: torch.from_numpy(v) for k, v in oracle.input.items()}
     rerun_oracle = loaded_model.make_oracle()
     compare_two_oracle(oracle, rerun_oracle)
