# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
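
# Export entry point for torchchat: builds (and optionally quantizes) a model,
# then compiles it to an ExecuTorch .pte program and/or an AOT Inductor .dso
# shared library.
#
# Sketch of an invocation (flag names inferred from the argparse attributes
# used below; treat the paths as placeholders):
#   python export.py --checkpoint-path model.pth --output-pte-path model.pte
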
import argparse
import os

import torch

from build.builder import (
    _initialize_model,
    _initialize_tokenizer,
    _set_gguf_kwargs,
    _unset_gguf_kwargs,
    BuilderArgs,
    TokenizerArgs,
)
from build.utils import set_backend, set_precision
from cli import add_arguments_for_verb, arg_init, check_args
from export_util.export_aoti import export_model as export_model_aoti
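
# ExecuTorch is an optional dependency; remember why the import failed so the
# error can be reported later if a .pte export is actually requested.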
try:
    executorch_export_available = True
    from export_util.export_et import export_model as export_model_et
except Exception as e:
    executorch_exception = f"ET EXPORT EXCEPTION: {e}"
    executorch_export_available = False

default_device = "cpu"


def main(args):
    builder_args = BuilderArgs.from_args(args)
    quantize = args.quantize

    print(f"Using device={builder_args.device}")
    set_precision(builder_args.precision)
    set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
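
    # dso_path/pte_path on BuilderArgs name already-compiled artifacts for the
    # builder to load; clear them so the model is rebuilt from source weights.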
    builder_args.dso_path = None
    builder_args.pte_path = None
    builder_args.setup_caches = True

    output_pte_path = args.output_pte_path
    output_dso_path = args.output_dso_path
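
    # Normalize the device for export: the ExecuTorch target is chosen by the
    # export recipe (not the device flag), and MPS export is unsupported.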
    if output_pte_path and builder_args.device != "cpu":
        print(
            f"Warning! The ExecuTorch export target is controlled by the export recipe, not the device setting. Ignoring device={builder_args.device}."
        )
        builder_args.device = "cpu"
    elif "mps" in builder_args.device:
        print("Warning! Device MPS is not supported for export. Exporting for device CPU.")
        builder_args.device = "cpu"

    # TODO: clean this up
    # This mess is because ET does not support _weight_int4pack_mm right now
    if not builder_args.gguf_path:
        # The tokenizer is needed for quantization, so load it here;
        # fall back to exporting without one if it cannot be loaded.
        try:
            tokenizer_args = TokenizerArgs.from_args(args)
            tokenizer = _initialize_tokenizer(tokenizer_args)
        except Exception:
            tokenizer = None
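
        # Quantization, when requested, is applied while the model is built.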
        model = _initialize_model(
            builder_args,
            quantize,
            tokenizer,
        )
        model_to_pte = model
        model_to_dso = model
    else:
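        # GGUF models are instantiated separately per backend, with
        # backend-specific kwargs set while each copy is built.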
        if output_pte_path:
            _set_gguf_kwargs(builder_args, is_et=True, context="export")
            model_to_pte = _initialize_model(
                builder_args,
                quantize,
            )
            _unset_gguf_kwargs(builder_args)

        if output_dso_path:
            _set_gguf_kwargs(builder_args, is_et=False, context="export")
            model_to_dso = _initialize_model(
                builder_args,
                quantize,
            )
            _unset_gguf_kwargs(builder_args)
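
    # Run the requested export(s) with autograd disabled: export never needs
    # gradients, and no_grad skips building the autograd graph.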
    with torch.no_grad():
        if output_pte_path:
            output_pte_path = str(os.path.abspath(output_pte_path))
            if executorch_export_available:
                print(f"Exporting model using ExecuTorch to {output_pte_path}")
                export_model_et(
                    model_to_pte, builder_args.device, args.output_pte_path, args
                )
            else:
                print(
                    "Export with ExecuTorch requested, but ExecuTorch could not be loaded"
                )
                print(executorch_exception)
        if output_dso_path:
            output_dso_path = str(os.path.abspath(output_dso_path))
            print(f"Exporting model using AOT Inductor to {output_dso_path}")
            export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="torchchat export CLI")
    add_arguments_for_verb(parser, "export")
    args = parser.parse_args()
    check_args(args, "export")
    args = arg_init(args)
    main(args)