@@ -1501,15 +1501,35 @@ def build_model_from_args(args: argparse.Namespace):
 
     transform_seq = []
 
+    if args.lora is not None:
+        prefill_name = "prefill" if "prefill" in mod else "prefill_with_embed"
+        transform_seq.extend(
+            [
+                remove_decode_func,
+                relax.transform.DeadCodeElimination(),
+                # TODO: Shouldn't assume that we know which parameters
+                # should be lora-tuned at this stage.  Maybe start with
+                # everything being lora-ized, then make specialized
+                # versions with BindParams for common cases?
+                lora_optimization_pipeline(mod[prefill_name].params, args.lora),
+            ]
+        )
+
     transform_seq.append(optimize_mod_pipeline(args, model_config))
 
+    if args.lora is not None:
+        transform_seq.append(bundle_lora_params)
+
     transform_seq.append(
         tvm.ir.transform.ApplyPassToFunction(
             relax.transform.BundleModelParams("base_params"),
             "(?!transform_params).*",
         )
     )
 
+    if args.lora is not None:
+        transform_seq.append(reorder_lora_params_after_base_model_params)
+
     transform_seq.append(
         tvm.ir.transform.ApplyPassToFunction(
             tvm.ir.transform.Sequential(
@@ -1523,6 +1543,13 @@ def build_model_from_args(args: argparse.Namespace):
         )
     )
 
+    if args.lora is not None:
+        # TODO(Lunderberg): Replace this with
+        # transform.CheckForSpecialCase once
+        # https://github.com/apache/tvm/pull/16457 is fully implemented
+        # and landed.
+        transform_seq.append(auto_generate_decode_func)
+
     mod = tvm.ir.transform.Sequential(transform_seq, name="OptimizeMLCModel")(mod)
 
     mod.show(
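
For context on the regex-scoped pass used above: `tvm.ir.transform.ApplyPassToFunction` restricts `BundleModelParams("base_params")` to functions whose name does not begin with "transform_params" (the `(?!transform_params).*` negative lookahead). A minimal standalone sketch of the same pattern, assuming `mod` is an already-constructed relax.IRModule:

import tvm
from tvm import relax

# Sketch only: assumes `mod` is an existing relax.IRModule containing the
# model's entry functions plus a "transform_params" function.
# BundleModelParams packs each matched function's weight parameters into a
# single tuple argument named "base_params"; the negative-lookahead regex
# keeps the pass away from "transform_params".
bundle_base = tvm.ir.transform.ApplyPassToFunction(
    relax.transform.BundleModelParams("base_params"),
    "(?!transform_params).*",
)
mod = bundle_base(mod)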