@@ -281,11 +281,6 @@ def clickpath_setup(is_dir: bool) -> click.Path:
     config_sections=ADDITIONAL_ARGUMENTS,
     required=True,  # default from config
 )
-@click.option(
-    "--legacy",
-    is_flag=True,
-    help="if true, enables the legacy linux training code path from release 0.17.0 and prior.",
-)
 @click.option(
     "--strategy",
     type=click.Choice(
@@ -361,6 +356,14 @@ def clickpath_setup(is_dir: bool) -> click.Path:
     is_flag=True,
     help="By default, checkpoints are saved at the end of each training epoch. This option disables this behavior.",
 )
+@click.option(
+    "--pipeline",
+    type=click.Choice(["simple", "full", "accelerated"]),
+    default="accelerated",
+    help="Model fidelity pipeline for training: 'simple' uses SFTTrainer on Linux or MLX on macOS, producing low-fidelity models quickly for rapid prototyping. "
+    "'full' employs CPU- and MPS-optimized InstructLab fine-tuning, generating medium-fidelity models over a longer period. "
+    "'accelerated' utilizes GPU acceleration and distributed training, yielding high-fidelity models but requiring more time. Choose based on your hardware, time constraints, and desired model quality.",
+)
 @click.pass_context
 @clickext.display_params
 def train(
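
The new `--pipeline` flag is an ordinary `click.Choice` option, so invalid values are rejected before `train()` ever runs. A minimal, self-contained sketch of that behavior (the `demo` command and the echo are illustrative, not part of the PR):

```python
import click

@click.command()
@click.option(
    "--pipeline",
    type=click.Choice(["simple", "full", "accelerated"]),
    default="accelerated",
)
def demo(pipeline):
    # `demo.py --pipeline full` prints "full"; omitting the flag prints
    # "accelerated"; `demo.py --pipeline fast` exits with a usage error.
    click.echo(pipeline)

if __name__ == "__main__":
    demo()
```
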
@@ -377,7 +380,6 @@ def train(
     num_epochs,
     device: str,
     four_bit_quant: bool,
-    legacy,
     strategy: str | None,
     phased_base_dir: pathlib.Path,
     phased_phase1_data: pathlib.Path | None,
@@ -391,13 +393,21 @@ def train(
     phased_mt_bench_judge: pathlib.Path | None,
     skip_user_confirm: bool,
     enable_serving_output: bool,
+    pipeline: str,
     **kwargs,
 ):
     """
     Takes synthetic data generated locally with `ilab data generate` and the previous model and learns a new model using the MLX API.
     On success, writes newly learned model to {model_dir}/mlx_model, which is where `chatmlx` will look for a model.
     """
     torch.set_autocast_enabled(False)
+
+    if (
+        pipeline in ("full", "simple")
+        and strategy == SupportedTrainingStrategies.LAB_MULTIPHASE.value
+    ):
+        ctx.fail("Multi-phase training is only supported with `--pipeline accelerated`")
+
     if not input_dir:
         # By default, generate output-dir is used as train input-dir
         input_dir = ctx.obj.config.generate.output_dir
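
The guard above relies on `click.Context.fail()`, which raises a `UsageError` that prints the message along with the usage text and exits with status 2. A minimal sketch of the pattern, with hypothetical option values standing in for `SupportedTrainingStrategies`:

```python
import click

@click.command()
@click.option("--pipeline", type=click.Choice(["simple", "full", "accelerated"]), default="accelerated")
@click.option("--strategy", type=click.Choice(["lab-multiphase", "lab-skills-only"]), default=None)
@click.pass_context
def demo(ctx, pipeline, strategy):
    # Reject unsupported option combinations before doing any work.
    if pipeline in ("full", "simple") and strategy == "lab-multiphase":
        ctx.fail("Multi-phase training is only supported with `--pipeline accelerated`")
    click.echo(f"pipeline={pipeline} strategy={strategy}")

if __name__ == "__main__":
    demo()
```
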
@@ -537,12 +547,153 @@ def get_files(directory: str, pattern: str) -> list[str]:
     )

     # we can use train_args locally to run lower fidelity training
-    if is_high_fidelity(device):
+    if is_high_fidelity(device) or pipeline == "accelerated":
         run_training(train_args=train_args, torch_args=torch_args, device=device)
-    else:
+    elif pipeline == "full":
+        # if on CPU or MPS, execute the full train, which is based off of the
+        # structure of the training repo, just with different optimizers, model
+        # sizes, and special data gradient accumulation to get it to fit on
+        # most consumer laptops
         full_train.train(train_args, device)
-
-
+    elif pipeline == "simple":
+        if utils.is_macos_with_m_chip() and not strategy:
+            # Local
+            from ..mlx_explore.gguf_convert_to_mlx import load
+            from ..mlx_explore.utils import fetch_tokenizer_from_hub
+            from ..train.lora_mlx.convert import convert_between_mlx_and_pytorch
+            from ..train.lora_mlx.lora import load_and_train
+            from ..train.lora_mlx.make_data import make_data
+
+            if not skip_preprocessing:
+                try:
+                    make_data(data_dir=data_path)
+                except FileNotFoundError as exc:
+                    click.secho(
+                        f"Could not read from data directory: {exc}",
+                        fg="red",
+                    )
+                    raise click.exceptions.Exit(1)
+
+            # NOTE: we can skip this if we have a way to ship MLX
+            # PyTorch safetensors to MLX safetensors
+            model_dir_local = model_path.replace("/", "-")
+            model_dir_local = f"{ckpt_output_dir}/{model_dir_local}"
+            model_dir_mlx = f"{model_dir_local}-mlx"
+            model_dir_mlx_quantized = f"{model_dir_local}-mlx-q"
+
+            if skip_quantize:
+                dest_model_dir = model_dir_mlx
+                quantize_arg = False
+            else:
+                dest_model_dir = model_dir_mlx_quantized
+                quantize_arg = True
+
+            if tokenizer_dir is not None and gguf_model_path is not None:
+                if not local:
+                    tokenizer_dir_local = tokenizer_dir.replace("/", "-")
+                    fetch_tokenizer_from_hub(tokenizer_dir, tokenizer_dir_local)
+
+                # No need to pass quantize_arg for now; the script automatically
+                # detects whether quantization is necessary based on whether the
+                # gguf model is quantized or not.
+                load(
+                    gguf=gguf_model_path,
+                    repo=tokenizer_dir,
+                    mlx_path=dest_model_dir,
+                )
+
+                for filename in os.listdir(model_dir_local):
+                    shutil.copy(
+                        os.path.join(model_dir_local, filename),
+                        os.path.join(dest_model_dir, filename),
+                    )
+                shutil.rmtree(model_dir_local, ignore_errors=True)
+
+            else:
+                # Download the PyTorch safetensors and convert to MLX safetensors
+                convert_between_mlx_and_pytorch(
+                    hf_path=model_path,
+                    mlx_path=dest_model_dir,
+                    quantize=quantize_arg,
+                    local=local,
+                )
+
+            adapter_file_path = f"{dest_model_dir}/adapters.npz"
+
+            # train the model with LoRA
+            load_and_train(
+                model=dest_model_dir,
+                train=True,
+                data=data_path,
+                adapter_file=adapter_file_path,
+                iters=iters,
+                save_every=10,
+                steps_per_eval=10,
+            )
+ else :
631
+ # Local
632
+ from ..llamacpp .llamacpp_convert_to_gguf import convert_llama_to_gguf
633
+ from ..train .linux_train import linux_train
634
+
635
+ training_results_dir = linux_train (
636
+ ctx = ctx ,
637
+ train_file = train_file ,
638
+ test_file = test_file ,
639
+ model_name = model_path ,
640
+ num_epochs = num_epochs ,
641
+ train_device = device ,
642
+ four_bit_quant = four_bit_quant ,
643
+ )
644
+
645
+ final_results_dir = training_results_dir / "final"
646
+ if final_results_dir .exists ():
647
+ shutil .rmtree (final_results_dir )
648
+ final_results_dir .mkdir ()
649
+
650
+ gguf_models_dir = Path (DEFAULTS .CHECKPOINTS_DIR )
651
+ gguf_models_dir .mkdir (exist_ok = True )
652
+ gguf_models_file = gguf_models_dir / "ggml-model-f16.gguf"
653
+
654
+ # Remove previously trained model, its taking up space we may need in the next step
655
+ gguf_models_file .unlink (missing_ok = True )
656
+
657
+ # TODO: Figure out what to do when there are multiple checkpoint dirs.
658
+ # Right now it's just copying files from the first one numerically not necessarily the best one
659
+ for fpath in (
660
+ "checkpoint-*/added_tokens.json" ,
661
+ "checkpoint-*/special_tokens_map.json" ,
662
+ "checkpoint-*/tokenizer.json" ,
663
+ "checkpoint-*/tokenizer.model" ,
664
+ "checkpoint-*/tokenizer_config.json" ,
665
+ "merged_model/config.json" ,
666
+ "merged_model/generation_config.json" ,
667
+ ):
668
+ file_ = next (training_results_dir .glob (fpath ))
669
+ shutil .copy (file_ , final_results_dir )
670
+ logger .info (f"Copied { file_ } to { final_results_dir } " )
671
+
672
+ for file in training_results_dir .glob ("merged_model/*.safetensors" ):
673
+ shutil .move (file , final_results_dir )
674
+ logger .info (f"Moved { file } to { final_results_dir } " )
675
+
676
+ if four_bit_quant :
677
+ logger .info (
678
+ "SKIPPING CONVERSION to gguf. This is unsupported with --4-bit-quant. "
679
+ + "See https://github.com/instructlab/instructlab/issues/579."
680
+ )
681
+ return
682
+
683
+ gguf_file_path = convert_llama_to_gguf (
684
+ model = final_results_dir , pad_vocab = True
685
+ )
686
+
687
+ # Remove safetensors files to save space, were done with them here
688
+ # and the huggingface lib has them cached
689
+ for file in final_results_dir .glob ("*.safetensors" ):
690
+ file .unlink ()
691
+
692
+ shutil .move (gguf_file_path , gguf_models_file )
693
+ logger .info (f"Save trained model to { gguf_models_file } " )
694
+
695
+
696
+ # chooses which type of training to run depending on the device provided
546
697
def is_high_fidelity (device ):
547
698
return device == "cuda" or device == "hpu"
548
699
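
Taken together, the last hunk replaces the old device-only branch with a three-way dispatch on `pipeline` (note the diff's original `elif not is_high_fidelity(device) or pipeline == "full"` would have made the `simple` branch unreachable, which the corrected condition above avoids). A standalone sketch of the resulting control flow, with a string return standing in for the calls to `run_training`, `full_train.train`, and the simple path:

```python
def is_high_fidelity(device: str) -> bool:
    return device in ("cuda", "hpu")

def select_pipeline(pipeline: str, device: str) -> str:
    # Mirrors the dispatch in the diff: a high-fidelity device or an explicit
    # `--pipeline accelerated` wins; otherwise fall through to full or simple.
    if is_high_fidelity(device) or pipeline == "accelerated":
        return "accelerated"
    if pipeline == "full":
        return "full"
    return "simple"

assert select_pipeline("accelerated", "cpu") == "accelerated"
assert select_pipeline("full", "cpu") == "full"
assert select_pipeline("simple", "cpu") == "simple"
assert select_pipeline("simple", "cuda") == "accelerated"  # device takes precedence
```
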