
Commit cfee537

Export a lora model (#12875)

Add a LoRA linear definition. Pull the linears out of attention and allow a custom linear (e.g. a LoRA linear) to be passed in; if none is given, construct the linear as before (current behaviour).

ghstack-source-id: 298411928
@exported-using-ghexport
Differential Revision: [D73953776](https://our.internmc.facebook.com/intern/diff/D73953776/)
1 parent 9483c72 commit cfee537
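
The linear-injection pattern described above is not itself visible in this diff (it lands in the attention and LoRA modules). A minimal sketch of the idea, with every name here (`LoRALinear`, `Attention`, the constructor signature) assumed for illustration rather than taken from the ExecuTorch sources:

```python
from typing import Optional

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical LoRA linear: frozen base projection plus a low-rank update."""

    def __init__(self, in_dim: int, out_dim: int, r: int, lora_alpha: int):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.lora_a = nn.Linear(in_dim, r, bias=False)
        self.lora_b = nn.Linear(r, out_dim, bias=False)
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = Wx + (alpha / r) * B(A(x))
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


class Attention(nn.Module):
    """Hypothetical attention block that accepts an optional custom q projection."""

    def __init__(self, dim: int, n_heads: int, head_dim: int, q_proj: Optional[nn.Module] = None):
        super().__init__()
        # Use the injected linear if one is passed in; otherwise construct a
        # plain nn.Linear, which is the previous behaviour.
        self.q_proj = q_proj if q_proj is not None else nn.Linear(dim, n_heads * head_dim, bias=False)


# Swapping in a LoRA projection requires no change to the attention code:
# Attention(dim=2048, n_heads=32, head_dim=64,
#           q_proj=LoRALinear(2048, 32 * 64, r=8, lora_alpha=16))
```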

File tree

5 files changed: +48, -5 lines


backends/xnnpack/operators/node_visitor.py

Lines changed: 4 additions & 3 deletions
@@ -622,9 +622,10 @@ def get_serialized_buffer_index(
         )
 
         external_tag = tensor.meta.get("delegate_constant_tag", None)
-        logging.info(
-            f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
-        )
+        if external_tag is not None:
+            logging.info(
+                f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
+            )
         self._named_data_store.add_named_data(
             named_key,
             bytes(array),

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 0 deletions
@@ -239,6 +239,18 @@ def build_args_parser() -> argparse.ArgumentParser:
        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
    )
 
+    parser.add_argument(
+        "--adapter_checkpoint",
+        required=False,
+        help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        required=False,
+        help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
+    )
+
    parser.add_argument(
        "--use_qnn_sha",
        action="store_true",
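
The two flags are only meaningful together, as the help text notes and as `examples/models/llama/model.py` below asserts. The plumbing from the parsed arguments to the export config is not shown in this diff; a minimal sketch of how the values could land on the `BaseConfig` fields added in `extension/llm/export/config/llm_config.py` (the dataclass here is trimmed to those two fields, and `apply_adapter_args` is an illustrative helper, not ExecuTorch code):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseConfig:
    # Trimmed to the two fields this commit adds; the real BaseConfig lives in
    # extension/llm/export/config/llm_config.py.
    adapter_checkpoint: Optional[str] = None
    adapter_config: Optional[str] = None


def apply_adapter_args(
    base: BaseConfig,
    adapter_checkpoint: Optional[str],
    adapter_config: Optional[str],
) -> BaseConfig:
    # Both paths or neither, mirroring the assertion added in model.py.
    if (adapter_checkpoint is None) != (adapter_config is None):
        raise ValueError(
            "adapter_checkpoint and adapter_config must be provided together"
        )
    base.adapter_checkpoint = adapter_checkpoint
    base.adapter_config = adapter_config
    return base
```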

examples/models/llama/model.py

Lines changed: 22 additions & 0 deletions
@@ -46,6 +46,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
        checkpoint_dir = self.llm_config.base.checkpoint_dir
        params_path = self.llm_config.base.params
 
+        # Adapter checkpoint and config.
+        adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
+        adapter_config_path = self.llm_config.base.adapter_config
+        assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
+            adapter_checkpoint_path is not None and adapter_config_path is not None
+        ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
+
        self.use_kv_cache = self.llm_config.model.use_kv_cache
        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
        self.generate_full_logits = self.llm_config.debug.generate_full_logits
@@ -129,6 +136,20 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
        with open(params_path, "r") as f:
            params = json.loads(f.read())
 
+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        if adapter_checkpoint_path:
+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            from torchtune.models import convert_weights
+
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            with open(adapter_config_path, "r") as f:
+                adapter_config = json.loads(f.read())
+            checkpoint.update(adapter_checkpoint)
+
        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
@@ -153,6 +174,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
+            **adapter_config,
        )
 
        if model_args.use_scaled_rope:
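
Read top to bottom, the adapter handling added above is: load `adapter.pt`, rename its torchtune-style keys to the Meta/llama convention with `convert_weights.tune_to_meta`, merge the adapter weights into the base checkpoint, and read `adapter_config.json` so its fields can be splatted into `ModelArgs`. A self-contained sketch of that sequence (the function wrapper is mine; it assumes `torchtune` is installed):

```python
import json
from typing import Any, Dict, Tuple

import torch


def load_and_merge_adapter(
    checkpoint: Dict[str, torch.Tensor],
    adapter_checkpoint_path: str,
    adapter_config_path: str,
    device: str = "cpu",
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    # Load the torchtune adapter weights; mmap avoids pulling the whole file
    # into memory up front.
    adapter_checkpoint = torch.load(
        adapter_checkpoint_path, map_location=device, mmap=True
    )

    # Rename torchtune-style parameter names to the Meta/llama naming used by
    # the base checkpoint.
    from torchtune.models import convert_weights

    adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)

    # Merge the adapter weights over the base weights and read the LoRA
    # hyperparameters from the adapter config.
    checkpoint.update(adapter_checkpoint)
    with open(adapter_config_path, "r") as f:
        adapter_config = json.loads(f.read())

    return checkpoint, adapter_config
```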

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ class ModelArgs:
     lora_args: Optional[dict] = None
 
     # LoRA arguments to set up a LoRA inference model.
-    # These arguments come directly from a torchtune LoRA config.
+    # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
     # Eg. q_proj, k_proj, v_proj, output_proj
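
Because `model.py` splats the parsed config into the model arguments (`**adapter_config`), the keys in `adapter_config.json` have to match these `ModelArgs` field names. An illustrative example of that contract; the exact key set comes from torchtune, so treat the values and the `target_modules` key below as assumptions rather than a spec:

```python
# Illustrative adapter_config.json contents (keys assumed to mirror the
# ModelArgs fields above: LoRA rank, alpha, and the projections to adapt).
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "k_proj", "v_proj", "output_proj"],
}

# model.py then constructs the model arguments as
#     ModelArgs(**params, **adapter_config)
# so, for this adapter, ModelArgs.r == 8 and ModelArgs.lora_alpha == 16.
```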

extension/llm/export/config/llm_config.py

Lines changed: 9 additions & 1 deletion
@@ -73,10 +73,16 @@ class BaseConfig:
            if it is a Llama model or the weights will be downloaded from HuggingFace
            if it is a non-Llama model.
        checkpoint_dir: Path to directory containing sharded checkpoint files.
+        adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
+            the model has trained LoRA adapters. Must provide
+            adapter_config.json.
+        adapter_config: Path to the adapter_config.json file from torchtune.
+            Used if the model has trained LoRA adapters. Must provide adapter.pt.
        tokenizer_path: Path to the tokenizer file.
        metadata: Json string containing metadata information.
            e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
-        use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+        use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled
+            if set to 0.
        fairseq2: For legacy internal use cases, this is safe to ignore.
        preq_mode: Legacy option to specify how prequantized weights are loaded.
            Going forward, ExecuTorch supports loading weights prequantized through
@@ -90,6 +96,8 @@ class BaseConfig:
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None
+    adapter_checkpoint: Optional[str] = None
+    adapter_config: Optional[str] = None
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: int = 0
