Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(materialize): onnx loading with torch model available #134

Merged
merged 6 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ jobs:
cd doc && python bug_summary.py
- name: Test core
run: |
pytest tests/core
pytest -x tests/core
- name: Test PyTorch
run: |
pip install -r requirements/sys/torch.txt --pre --upgrade
pip install -r requirements/sys/onnx.txt --pre --upgrade
pip install -r requirements/sys/tvm.txt --pre --upgrade
pip install -r requirements/sys/onnxruntime.txt --pre --upgrade
pytest tests/torch
pytest -x tests/torch
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch mgen.method=symbolic
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch mgen.method=symbolic-cinit
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch backend.type="pt2 backend@inductor" mgen.method=concolic
Expand All @@ -42,20 +42,20 @@ jobs:
yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py backend.type=pt2 mgen.grad_check=true
- name: Test ONNX + ONNXRuntime
run: |
pytest tests/onnxruntime
pytest -x tests/onnxruntime
yes | python nnsmith/cli/model_gen.py model.type=onnx mgen.method=symbolic
yes | python nnsmith/cli/model_gen.py model.type=onnx backend.type=onnxruntime mgen.method=concolic
python nnsmith/cli/model_exec.py model.type=onnx backend.type=onnxruntime model.path=nnsmith_output/model.onnx
- name: Test ONNX + TVM
run: |
pytest tests/tvm
pytest -x tests/tvm
- name: Test ONNX + TRT
run: |
pytest tests/tensorrt
pytest -x tests/tensorrt
- name: Test TensorFlow
run: |
pip install -r requirements/sys/tensorflow.txt --pre --upgrade
pytest tests/tensorflow --log-cli-level=DEBUG
pytest -x tests/tensorflow --log-cli-level=DEBUG
yes | python nnsmith/cli/model_gen.py model.type=tensorflow mgen.method=symbolic
python nnsmith/cli/model_exec.py model.type=tensorflow backend.type=xla model.path=nnsmith_output/model/
yes | python nnsmith/cli/model_gen.py model.type=tensorflow mgen.method=concolic
Expand Down
2 changes: 1 addition & 1 deletion experiments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Next, we will run NNSmith to dump a bunch of random test-cases for an SUT (say P
>
> ```shell
> # PyTorch
> pip install --extra-index-url https://download.pytorch.org/whl/nightly/cpu --pre torch
> pip install --index-url https://download.pytorch.org/whl/nightly/cpu --pre torch
> # TensorFlow
> pip install tf-nightly
> ```
Expand Down
2 changes: 1 addition & 1 deletion nnsmith/materialize/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def load(cls, path: PathLike) -> "ONNXModel":
# FIXME: missing key(s) in state_dict: "mlist.0.data", "mlist.1.data".
if os.path.exists(torch_path):
ret.with_torch = True
ret.torch_model = cls.PTType.load(torch_path)
ret.torch_model = cls.PTType.load(torch_path).torch_model
ret.full_input_like = ret.torch_model.input_like
ret.full_output_like = ret.torch_model.output_like

Expand Down
28 changes: 13 additions & 15 deletions nnsmith/materialize/torch/parse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import operator
from typing import Any, Dict, List, Union, cast
from typing import Any, Dict, List, cast

import torch
import torch._dynamo as dynamo
import torch.fx as fx
import torch.nn as nn
import torch.utils._pytree as pytree
Expand All @@ -22,27 +21,26 @@ def run_node(self, n: fx.node.Node) -> Any:


def parse(model: nn.Module, *example_args: List[torch.Tensor]) -> GraphIR:
gm: fx.GraphModule = dynamo.export(model, *example_args)[0]
gm: fx.GraphModule = fx.symbolic_trace(model)
# store shape info on nodes
sp = PropInterpreter(gm)
sp.run(*example_args)

def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
"""
Map nodes to their outputs while keeping structures and other values the same.
"""
return torch.fx.graph.map_arg(args, lambda n: n.meta["res"])

named_modules = dict(gm.named_modules())
ir = GraphIR()
name2retvals: Dict[str, List[str]] = {}
for i_node, node in enumerate(gm.graph.nodes):
for node in gm.graph.nodes:
node = cast(fx.node.Node, node)
if node.op == "placeholder":
iexpr = InstExpr(Input(dim=len(node.meta["res"].shape)), [])
input_node = Input(dim=len(node.meta["res"].shape))
input_node.abs_tensor = AbsTensor(
shape=list(node.meta["res"].shape),
dtype=DType.from_torch(node.meta["res"].dtype),
)
iexpr = InstExpr(input_node, [])
else:
args_flatten, args_treespec = pytree.tree_flatten(node.args)
kwargs_flatten, kwargs_treespec = pytree.tree_flatten(node.kwargs)
args_flatten, _ = pytree.tree_flatten(node.args)
kwargs_flatten, _ = pytree.tree_flatten(node.kwargs)
input_nodes = [
a
for a in (args_flatten + kwargs_flatten)
Expand All @@ -67,8 +65,8 @@ def load_args(args: Union[List, Dict[str, Any]]) -> Union[List, Dict[str, Any]]:
pytree.tree_flatten(node.meta["res"])[0],
)
)
nodes2empty = (
lambda n: ConcreteOp.empty if isinstance(n, fx.node.Node) else n
nodes2empty = lambda n: (
ConcreteOp.empty if isinstance(n, fx.node.Node) else n
)
args_wo_nodes = pytree.tree_map(nodes2empty, node.args)
kwargs_wo_nodes = pytree.tree_map(nodes2empty, node.kwargs)
Expand Down
2 changes: 1 addition & 1 deletion requirements/sys/torch.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TODO(@ganler): make other platform/device distribution also work.
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--index-url https://download.pytorch.org/whl/nightly/cpu
--pre
torch
4 changes: 2 additions & 2 deletions tests/torch/test_dump_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_onnx_load_dump(tmp_path):
# check oracle
compare_two_oracle(oracle, loaded_testcase.oracle)

loaded_model = loaded_testcase.model.torch_model
loaded_model = loaded_testcase.model
loaded_model.sat_inputs = {k: torch.from_numpy(v) for k, v in oracle.input.items()}
rerun_oracle = loaded_model.make_oracle()
compare_two_oracle(oracle, rerun_oracle)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_bug_report_load_dump(tmp_path):
# check oracle
compare_two_oracle(oracle, loaded_testcase.oracle)

loaded_model = loaded_testcase.model.torch_model
loaded_model = loaded_testcase.model
loaded_model.sat_inputs = {k: torch.from_numpy(v) for k, v in oracle.input.items()}
rerun_oracle = loaded_model.make_oracle()
compare_two_oracle(oracle, rerun_oracle)
Loading