
Commit 3897ff2

introduce --pipeline
--pipeline has 3 options: simple, full, and accelerated. 'simple' will run either MLX train or Linux train. 'full' will run the CPU/MPS-optimized version of full fine-tuning. 'accelerated' will shell out to the library code for larger GPU support. This conforms well with SDG --pipeline, and it also allows us to maintain the SFTTrainer and MLX paths while also supporting our own training loop.

Signed-off-by: Charlie Doern <[email protected]>
1 parent b3165dc commit 3897ff2
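
For orientation, here is a minimal, self-contained sketch of the dispatch the commit message describes. It is illustrative only: the stub bodies, the --device check, and everything other than the --pipeline choices are assumptions, not code from this repository.

    # Illustrative sketch: a click command exposing the same --pipeline choices and
    # routing each one to a stand-in for the corresponding training path.
    import click

    @click.command()
    @click.option(
        "--pipeline",
        type=click.Choice(["simple", "full", "accelerated"]),
        default="accelerated",
    )
    @click.option("--device", default="cpu")
    def train(pipeline: str, device: str) -> None:
        # "accelerated" (or a high-fidelity device) shells out to the training library.
        if pipeline == "accelerated" or device in ("cuda", "hpu"):
            click.echo("accelerated: distributed GPU training via the library code")
        # "full" runs the CPU/MPS-optimized fine-tuning loop.
        elif pipeline == "full":
            click.echo("full: CPU/MPS optimized full fine-tuning")
        # "simple" keeps the SFTTrainer (Linux) and MLX (macOS) paths.
        else:
            click.echo("simple: SFTTrainer on Linux, MLX LoRA on macOS")

    if __name__ == "__main__":
        train()

Invoked as, for example, `python sketch.py --pipeline full --device cpu`, the command simply echoes which training path would be taken.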

File tree

1 file changed (+161 / -10 lines)

src/instructlab/model/train.py

+161 -10
@@ -281,11 +281,6 @@ def clickpath_setup(is_dir: bool) -> click.Path:
     config_sections=ADDITIONAL_ARGUMENTS,
     required=True,  # default from config
 )
-@click.option(
-    "--legacy",
-    is_flag=True,
-    help="if true, enables the legacy linux training code path from release 0.17.0 and prior.",
-)
 @click.option(
     "--strategy",
     type=click.Choice(
@@ -361,6 +356,14 @@ def clickpath_setup(is_dir: bool) -> click.Path:
     is_flag=True,
     help="By default, checkpoints are saved at the end of each training epoch. This option disables this behavior.",
 )
+@click.option(
+    "--pipeline",
+    type=click.Choice(["simple", "full", "accelerated"]),
+    default="accelerated",
+    help="Model fidelity pipeline for training: 'simple' uses SFTTrainer on Linux or MLX on MacOS, producing low-fidelity models quickly for rapid prototyping. "
+    "'full' employs CPU and MPS optimized InstructLab fine-tuning, generating medium-fidelity models over a longer period. "
+    "'accelerated' utilizes GPU acceleration and distributed training, yielding high-fidelity models but requiring more time. Choose based on your hardware, time constraints, and desired model quality.",
+)
 @click.pass_context
 @clickext.display_params
 def train(
@@ -377,7 +380,6 @@ def train(
     num_epochs,
     device: str,
     four_bit_quant: bool,
-    legacy,
     strategy: str | None,
     phased_base_dir: pathlib.Path,
     phased_phase1_data: pathlib.Path | None,
@@ -391,13 +393,21 @@ def train(
     phased_mt_bench_judge: pathlib.Path | None,
     skip_user_confirm: bool,
     enable_serving_output: bool,
+    pipeline: str,
     **kwargs,
 ):
     """
     Takes synthetic data generated locally with `ilab data generate` and the previous model and learns a new model using the MLX API.
     On success, writes newly learned model to {model_dir}/mlx_model, which is where `chatmlx` will look for a model.
     """
     torch.set_autocast_enabled(False)
+
+    if (
+        pipeline in ("full", "simple")
+        and strategy == SupportedTrainingStrategies.LAB_MULTIPHASE.value
+    ):
+        ctx.fail("Multi Phase training is only supported with `--pipeline accelerated`")
+
     if not input_dir:
         # By default, generate output-dir is used as train input-dir
         input_dir = ctx.obj.config.generate.output_dir
@@ -537,12 +547,153 @@ def get_files(directory: str, pattern: str) -> list[str]:
     )
 
     # we can use train_args locally to run lower fidelity training
-    if is_high_fidelity(device):
+    if is_high_fidelity(device) or pipeline == "accelerated":
         run_training(train_args=train_args, torch_args=torch_args, device=device)
-    else:
+    elif not is_high_fidelity(device) or pipeline == "full":
+        # if on CPU or MPS, execute full train, which is based
+        # off of the structure of the training repo, just with different optimizers, model sizes, and special data gradient accumulation to get it
+        # to fit on most consumer laptops
         full_train.train(train_args, device)
-
-
+    elif pipeline == "simple":
+        if utils.is_macos_with_m_chip() and not strategy:
+            # Local
+            from ..mlx_explore.gguf_convert_to_mlx import load
+            from ..mlx_explore.utils import fetch_tokenizer_from_hub
+            from ..train.lora_mlx.convert import convert_between_mlx_and_pytorch
+            from ..train.lora_mlx.lora import load_and_train
+            from ..train.lora_mlx.make_data import make_data
+
+            if not skip_preprocessing:
+                try:
+                    make_data(data_dir=data_path)
+                except FileNotFoundError as exc:
+                    click.secho(
+                        f"Could not read from data directory: {exc}",
+                        fg="red",
+                    )
+                    raise click.exceptions.Exit(1)
+
+            # NOTE we can skip this if we have a way to ship MLX
+            # PyTorch safetensors to MLX safetensors
+            model_dir_local = model_path.replace("/", "-")
+            model_dir_local = f"{ckpt_output_dir}/{model_dir_local}"
+            model_dir_mlx = f"{model_dir_local}-mlx"
+            model_dir_mlx_quantized = f"{model_dir_local}-mlx-q"
+
+            if skip_quantize:
+                dest_model_dir = model_dir_mlx
+                quantize_arg = False
+            else:
+                dest_model_dir = model_dir_mlx_quantized
+                quantize_arg = True
+
+            if tokenizer_dir is not None and gguf_model_path is not None:
+                if not local:
+                    tokenizer_dir_local = tokenizer_dir.replace("/", "-")
+                    fetch_tokenizer_from_hub(tokenizer_dir, tokenizer_dir_local)
+
+                # no need to pass quantize_arg for now, script automatically detects if quantization is necessary based on whether gguf model is quantized or not
+                load(
+                    gguf=gguf_model_path,
+                    repo=tokenizer_dir,
+                    mlx_path=dest_model_dir,
+                )
+
+                for filename in os.listdir(model_dir_local):
+                    shutil.copy(
+                        os.path.join(model_dir_local, filename),
+                        os.path.join(dest_model_dir, filename),
+                    )
+                shutil.rmtree(model_dir_local, ignore_errors=True)
+
+            else:
+                # Downloading PyTorch SafeTensor and Converting to MLX SafeTensor
+                convert_between_mlx_and_pytorch(
+                    hf_path=model_path,
+                    mlx_path=dest_model_dir,
+                    quantize=quantize_arg,
+                    local=local,
+                )
+
+            adapter_file_path = f"{dest_model_dir}/adapters.npz"
+
+            # train the model with LoRA
+            load_and_train(
+                model=dest_model_dir,
+                train=True,
+                data=data_path,
+                adapter_file=adapter_file_path,
+                iters=iters,
+                save_every=10,
+                steps_per_eval=10,
+            )
+        else:
+            # Local
+            from ..llamacpp.llamacpp_convert_to_gguf import convert_llama_to_gguf
+            from ..train.linux_train import linux_train
+
+            training_results_dir = linux_train(
+                ctx=ctx,
+                train_file=train_file,
+                test_file=test_file,
+                model_name=model_path,
+                num_epochs=num_epochs,
+                train_device=device,
+                four_bit_quant=four_bit_quant,
+            )
+
+            final_results_dir = training_results_dir / "final"
+            if final_results_dir.exists():
+                shutil.rmtree(final_results_dir)
+            final_results_dir.mkdir()
+
+            gguf_models_dir = Path(DEFAULTS.CHECKPOINTS_DIR)
+            gguf_models_dir.mkdir(exist_ok=True)
+            gguf_models_file = gguf_models_dir / "ggml-model-f16.gguf"
+
+            # Remove previously trained model, it's taking up space we may need in the next step
+            gguf_models_file.unlink(missing_ok=True)
+
+            # TODO: Figure out what to do when there are multiple checkpoint dirs.
+            # Right now it's just copying files from the first one numerically, not necessarily the best one
+            for fpath in (
+                "checkpoint-*/added_tokens.json",
+                "checkpoint-*/special_tokens_map.json",
+                "checkpoint-*/tokenizer.json",
+                "checkpoint-*/tokenizer.model",
+                "checkpoint-*/tokenizer_config.json",
+                "merged_model/config.json",
+                "merged_model/generation_config.json",
+            ):
+                file_ = next(training_results_dir.glob(fpath))
+                shutil.copy(file_, final_results_dir)
+                logger.info(f"Copied {file_} to {final_results_dir}")
+
+            for file in training_results_dir.glob("merged_model/*.safetensors"):
+                shutil.move(file, final_results_dir)
+                logger.info(f"Moved {file} to {final_results_dir}")
+
+            if four_bit_quant:
+                logger.info(
+                    "SKIPPING CONVERSION to gguf. This is unsupported with --4-bit-quant. "
+                    + "See https://github.com/instructlab/instructlab/issues/579."
+                )
+                return
+
+            gguf_file_path = convert_llama_to_gguf(
+                model=final_results_dir, pad_vocab=True
+            )
+
+            # Remove safetensors files to save space, we're done with them here
+            # and the huggingface lib has them cached
+            for file in final_results_dir.glob("*.safetensors"):
+                file.unlink()
+
+            shutil.move(gguf_file_path, gguf_models_file)
+            logger.info(f"Saved trained model to {gguf_models_file}")
+
+
+# chooses which type of training to run depending on the device provided
 def is_high_fidelity(device):
     return device == "cuda" or device == "hpu"
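
One behavioral note from the diff above: the guard added to train() rejects multi-phase training unless the accelerated pipeline is selected. A standalone sketch of that check follows; the enum value used here is an assumption for illustration, not the repository's actual SupportedTrainingStrategies definition.

    # Sketch of the multi-phase guard added in this commit; the enum value is assumed.
    import enum

    class SupportedTrainingStrategies(enum.Enum):
        LAB_MULTIPHASE = "lab-multiphase"

    def check_pipeline(pipeline: str, strategy: str | None) -> None:
        # Mirrors the ctx.fail(...) branch: only "accelerated" may be combined
        # with the multi-phase strategy.
        if (
            pipeline in ("full", "simple")
            and strategy == SupportedTrainingStrategies.LAB_MULTIPHASE.value
        ):
            raise SystemExit(
                "Multi Phase training is only supported with `--pipeline accelerated`"
            )

    check_pipeline("accelerated", "lab-multiphase")  # allowed
    check_pipeline("simple", None)                   # allowed
    # check_pipeline("simple", "lab-multiphase")     # would exit with the message above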
