Skip to content

Commit 13755c8

Browse files
committed
[LoRA] Add transforms to inject/optimize LoRA
1 parent 10b5259 commit 13755c8

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

mlc_llm/core.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,15 +1501,35 @@ def build_model_from_args(args: argparse.Namespace):
15011501

15021502
transform_seq = []
15031503

1504+
if args.lora is not None:
1505+
prefill_name = "prefill" if "prefill" in mod else "prefill_with_embed"
1506+
transform_seq.extend(
1507+
[
1508+
remove_decode_func,
1509+
relax.transform.DeadCodeElimination(),
1510+
# TODO: Shouldn't assume that we know which parameters
1511+
# should be lora-tuned at this stage. Maybe start with
1512+
# everything being lora-ized, then make specialized
1513+
# versions with BindParams for common cases?
1514+
lora_optimization_pipeline(mod[prefill_name].params, args.lora),
1515+
]
1516+
)
1517+
15041518
transform_seq.append(optimize_mod_pipeline(args, model_config))
15051519

1520+
if args.lora is not None:
1521+
transform_seq.append(bundle_lora_params)
1522+
15061523
transform_seq.append(
15071524
tvm.ir.transform.ApplyPassToFunction(
15081525
relax.transform.BundleModelParams("base_params"),
15091526
"(?!transform_params).*",
15101527
)
15111528
)
15121529

1530+
if args.lora is not None:
1531+
transform_seq.append(reorder_lora_params_after_base_model_params)
1532+
15131533
transform_seq.append(
15141534
tvm.ir.transform.ApplyPassToFunction(
15151535
tvm.ir.transform.Sequential(
@@ -1523,6 +1543,13 @@ def build_model_from_args(args: argparse.Namespace):
15231543
)
15241544
)
15251545

1546+
if args.lora is not None:
1547+
# TODO(Lunderberg): Replace this with
1548+
# transform.CheckForSpecialCase once
1549+
# https://github.com/apache/tvm/pull/16457 is fully implemented
1550+
# and landed.
1551+
transform_seq.append(auto_generate_decode_func)
1552+
15261553
mod = tvm.ir.transform.Sequential(transform_seq, name="OptimizeMLCModel")(mod)
15271554

15281555
mod.show(

0 commit comments

Comments
 (0)