From 7ca3dbfea46a669fbed88f29f89a3b61380d3ceb Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 1 Oct 2025 20:33:28 +0000
Subject: [PATCH 1/8] fix for gptoss

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 .../torch/export/plugins/hf_spec_export.py    | 30 ++++++++-----------
 modelopt/torch/export/unified_export_hf.py    | 17 +++++++----
 .../torch/speculative/eagle/default_config.py |  1 +
 3 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py
index fe044828a..9f89cb269 100644
--- a/modelopt/torch/export/plugins/hf_spec_export.py
+++ b/modelopt/torch/export/plugins/hf_spec_export.py
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
         raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")
 
 
-def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
+def spec_opt_only(model: nn.Module):
+    """Check if the model has only the speculative decoding optimization."""
+    opt_modes = getattr(model, "_modelopt_state", None)
+    return (
+        isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
+    )
+
+
+def export_spec_ckpt_state_dict(model: nn.Module):
     """Only return the state dict of the draft model in official format and ignore the base model."""
     # check the model has only speculative decoding
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # if there's other opts, return as is
-        return post_state_dict
+    assert spec_opt_only(model), "Not a pure EAGLE model."
 
     # Check if the state dict keys match
     _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +81,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
     return export_state_dict
 
 
-def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
+def export_spec_ckpt_config(model: nn.Module):
     """Return the config of draft model in official format."""
-    opt_modes = getattr(model, "_modelopt_state", None)
-    if (
-        not isinstance(opt_modes, (list, tuple))
-        or len(opt_modes) != 1
-        or opt_modes[0][0] != "eagle"
-    ):
-        # return as is
-        return config_data
+    assert spec_opt_only(model), "Not a pure EAGLE model."
 
     # This is the config keys in official checkpoint.
     template_config = {
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 2a69831e9..e3cb9dd8e 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -26,6 +26,7 @@
 
 import torch
 import torch.nn as nn
+from safetensors.torch import save_file
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
 from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
-from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
+from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
     get_activation_scaling_factor,
@@ -511,18 +512,24 @@ def export_hf_checkpoint(
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
+
+    # Early exit for speculative decoding models
+    # We do this since some spec models get error in convert_hf_quant_config_format
+    if spec_opt_only(model):
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+
     try:
         post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
-        # NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
         # Save hf_quant_config.json for backward compatibility
         with open(f"{export_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
 
         hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
 
-        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
-
         # Save model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -536,8 +543,6 @@ def export_hf_checkpoint(
 
         config_data["quantization_config"] = hf_quant_config
 
-        config_data = set_config_if_spec_decoding(model, config_data)
-
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)
diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py
index f8c69b2ff..edbc1e360 100644
--- a/modelopt/torch/speculative/eagle/default_config.py
+++ b/modelopt/torch/speculative/eagle/default_config.py
@@ -47,4 +47,5 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "has_lm_head": False,
+    "head_dim": 64,
 }
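For context on the early-exit path added above: export now branches on whether EAGLE is the only modelopt optimization attached to the model. Below is a minimal sketch of the intended call flow, assuming the `mtsp.convert` API from the modelopt speculative decoding examples; the model name and the empty EAGLE config are illustrative placeholders, not taken from this patch.

```python
# Sketch only: how export_hf_checkpoint is expected to reach the new early-exit
# branch. The model name and empty config dict are illustrative assumptions.
import modelopt.torch.speculative as mtsp
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Attaching EAGLE as the sole optimization leaves a single ("eagle", ...) entry
# in model._modelopt_state, so spec_opt_only(model) returns True.
mtsp.convert(model, [("eagle", {})])

# ... draft-model training happens here ...

# With spec_opt_only(model) True, export writes model.safetensors and config.json
# for the draft model only and returns before the quantization export path runs.
export_hf_checkpoint(model, export_dir="exported_eagle_ckpt")
```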
From 85017cb63ab9dfa153ab13f89c532c5e5d8cdd02 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 1 Oct 2025 20:44:09 +0000
Subject: [PATCH 2/8] minor: update readme

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 31133e5df..f110af8ce 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -130,7 +130,7 @@ For online training checkpoints, we can run in-framework evaluation on MT-bench:
 python ar_validate.py --model_path $ONLINE_CKPT
 ```
 
-Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
+**Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.
 
 ## Export
From 1d1972fb24551df206db121b1b2644395d309631 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Fri, 3 Oct 2025 20:42:36 +0000
Subject: [PATCH 3/8] update default head_dim to 128

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 modelopt/torch/speculative/eagle/default_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py
index edbc1e360..1a7f1fddc 100644
--- a/modelopt/torch/speculative/eagle/default_config.py
+++ b/modelopt/torch/speculative/eagle/default_config.py
@@ -47,5 +47,5 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "has_lm_head": False,
-    "head_dim": 64,
+    "head_dim": 128,
 }

From fb0f020ca286ff3dd9bd134f6942c18ba5eb10dc Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Tue, 7 Oct 2025 22:56:08 +0000
Subject: [PATCH 4/8] update comments

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 modelopt/torch/export/unified_export_hf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index e3cb9dd8e..f966ffac6 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -513,8 +513,8 @@ def export_hf_checkpoint(
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
 
-    # Early exit for speculative decoding models
-    # We do this since some spec models get error in convert_hf_quant_config_format
+    # NOTE: (hg) Early exit for speculative decoding models
+    # This is a temp workaround to avoid errors with offline spec ckpts during _export_hf_checkpoint
     if spec_opt_only(model):
         save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
         with open(f"{export_dir}/config.json", "w") as file:
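The early-exit branch that these comments describe writes exactly two files. A small sketch of consuming them follows; the directory name is an illustrative placeholder, and `safetensors` is the dependency the patch itself imports.

```python
# Sketch: reading back the two artifacts written by the early-exit branch
# (model.safetensors + config.json, per the diffs above).
import json
from safetensors.torch import load_file

export_dir = "exported_eagle_ckpt"  # hypothetical path from the earlier sketch

draft_state_dict = load_file(f"{export_dir}/model.safetensors")  # official-format draft weights
with open(f"{export_dir}/config.json") as f:
    draft_config = json.load(f)

# head_dim may or may not be present depending on the template config; hence .get()
print(f"{len(draft_state_dict)} draft tensors, head_dim={draft_config.get('head_dim')}")
```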
From a9882fea4160b7f1086a2ae51d9ccbe8a6380bca Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Tue, 7 Oct 2025 23:15:58 +0000
Subject: [PATCH 5/8] update test_eagle.py to avoid shared mem OOM

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 tests/examples/speculative_decoding/test_eagle.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py
index c81bc9363..2049f77a7 100644
--- a/tests/examples/speculative_decoding/test_eagle.py
+++ b/tests/examples/speculative_decoding/test_eagle.py
@@ -27,6 +27,7 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_
         "intermediate_size": 64,
         "num_attention_heads": 2,
         "num_key_value_heads": 2,
+        "head_dim": 64,
     }
 
     # Write the tiny config to a temporary file

From 53b9466f6fdb2f70f7f427ba88cc5dfe9ca83f3d Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 8 Oct 2025 23:26:33 +0000
Subject: [PATCH 6/8] update data instructions; remove redundant wandb init

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/README.md      | 42 +++++++++++++++----
 examples/speculative_decoding/eagle_utils.py |  2 -
 .../train_eagle3_and_export.sh               |  2 +-
 3 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index f110af8ce..bd7ca696f 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -43,14 +43,16 @@ pip install -U nvidia-modelopt[hf]
 pip install -r requirements.txt
 ```
 
-We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Download by:
+### Data Preparation
+
+We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Prepare data by:
 
 ```bash
-apt-get update && apt-get install -y git-lfs
-git lfs install --system
-git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
+python prepare_input_conversations/add_daring_anteater.py
 ```
 
+See the [custom datasets](#custom-datasets) section for other dataset options and instructions for user-provided data.
+
 ## Getting Started: Simplified Workflow
 
 ```bash
@@ -71,7 +73,7 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ```bash
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
-    --data Daring-Anteater/train.jsonl \
+    --data input_conversations/daring-anteater.jsonl \
     --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
@@ -91,7 +93,7 @@ We support two backends for generating base model hidden states. For better effc
 ```bash
 python collect_hidden_states/compute_hidden_states_trtllm.py \
     --model $BASE_MODEL \
-    --input-file Daring-Anteater/train.jsonl \
+    --input-file input_conversations/daring-anteater.jsonl \
     --output-dir $HIDDEN_STATES_DIR
 ```
 
@@ -102,7 +104,7 @@ Alternatively, you can generate the same hidden states with HF:
 ```bash
 python collect_hidden_states/compute_hidden_states_hf.py \
     --model $BASE_MODEL \
-    --input-file Daring-Anteater/train.jsonl \
+    --input-file input_conversations/daring-anteater.jsonl \
     --output-dir $HIDDEN_STATES_DIR
 ```
 
@@ -183,6 +185,28 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE
 
 ## Advanced Usage
 
+### Custom Datasets
+
+In addition to `daring-anteater`, we provide scripts for adding several other commonly used datasets in `prepare_input_conversations`:
+
+```text
+prepare_input_conversations/
+    ├── add_daring_anteater.py
+    ├── add_mtbench.py
+    ├── add_sharegpt.py
+    ├── add_ultrachat.py
+    └── example_make_prompt_dataset.sh
+```
+
+To use custom datasets, please preprocess your data into a `.jsonl` file with each line in the format:
+
+```json
+{
+    "conversation_id": "<unique conversation id>",
+    "conversations": [{"role": "<user or assistant>", "content": "<message text>"}]
+}
+```
+
 ### Data Synthesis
 
 To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model.
@@ -199,7 +223,7 @@ Note: Add `--quantization=modelopt` flag for quantized models.
 Then, we generate conversations with the base model using prompts from Daring-Anteater:
 
 ```bash
-python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
+python server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl
 ```
 
 To add a system prompt, use the `--system_prompt <system prompt text>` argument.
@@ -211,7 +235,7 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d
 We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
 
 ```bash
-python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
+python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
 ```
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index 452bb4ded..c3573fb77 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -29,8 +29,6 @@
 
 try:
     import wandb
-
-    wandb.init()
 except ImportError:
     wandb = None
 
diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh
index 6af635be0..76e56da38 100755
--- a/examples/speculative_decoding/train_eagle3_and_export.sh
+++ b/examples/speculative_decoding/train_eagle3_and_export.sh
@@ -20,7 +20,7 @@ set -eo pipefail
 # Set default values for BASE_MODEL, NUM_GPU, and DATA
 BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
 NUM_GPU=1
-DATA=Daring-Anteater/train.jsonl
+DATA=input_conversations/daring-anteater.jsonl
 
 # Parse input arguments --base_model, --num_gpu, and --data
 while [[ $# -gt 0 ]]; do
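To make the `.jsonl` contract from the README change above concrete, here is a hedged sketch that writes a user-provided dataset in that layout. The field names follow the README; the sample conversation and output filename are invented.

```python
# Sketch: emitting a custom dataset in the .jsonl layout the README describes.
# One JSON object per line; the conversation content here is made up.
import json
import os

examples = [
    {
        "conversation_id": "user-data-0001",
        "conversations": [
            {"role": "user", "content": "Explain speculative decoding in one sentence."},
            {"role": "assistant", "content": "A draft model proposes tokens that the base model verifies in parallel."},
        ],
    },
]

os.makedirs("input_conversations", exist_ok=True)
with open("input_conversations/my_dataset.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```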
From e66985c8259fb95d14a45714a02d9cb21fcec259 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 8 Oct 2025 23:30:59 +0000
Subject: [PATCH 7/8] minor

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index bd7ca696f..2936b1d5a 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -51,7 +51,7 @@ We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater)
 python prepare_input_conversations/add_daring_anteater.py
 ```
 
-See the [custom datasets](#custom-datasets) section for other dataset options and instructions for user-provided data.
+See the [other datasets](#other-datasets) section for other dataset options and instructions for user-provided data.
 
 ## Getting Started: Simplified Workflow
 
@@ -185,7 +185,7 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE
 
 ## Advanced Usage
 
-### Custom Datasets
+### Other Datasets
 
 In addition to `daring-anteater`, we provide scripts for adding several other commonly used datasets in `prepare_input_conversations`:
 
@@ -198,7 +198,7 @@ prepare_input_conversations/
     └── example_make_prompt_dataset.sh
 ```
 
-To use custom datasets, please preprocess your data into a `.jsonl` file with each line in the format:
+To use your own datasets, please preprocess your data into a `.jsonl` file with each line in the format:

From 80c6cadf9f13fc3ca8af5c8950276f4a5584f547 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 8 Oct 2025 23:36:17 +0000
Subject: [PATCH 8/8] minor

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/eagle_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index c3573fb77..576179dd1 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -395,6 +395,8 @@ class ARValidationCallback(TrainerCallback):
     def __init__(self, ar_validate_steps: int = 1000):
         self.ar_validate_steps = ar_validate_steps
+        if wandb:
+            wandb.init()
 
     def on_step_end(self, args, state, control, **kwargs):
         if self.ar_validate_steps <= 0:
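Taken together, patches 6 and 8 move wandb initialization from import time into the callback, guarded by the optional import. A condensed sketch of the resulting pattern; the class below is a simplified stand-in for the real `TrainerCallback` subclass, not the full implementation.

```python
# Sketch of the final wandb pattern after patches 6 and 8: the module imports
# wandb without side effects, and the callback starts a run lazily.
try:
    import wandb
except ImportError:
    wandb = None  # keeps the rest of the module usable without wandb


class ARValidationCallback:  # simplified stand-in for the TrainerCallback subclass
    def __init__(self, ar_validate_steps: int = 1000):
        self.ar_validate_steps = ar_validate_steps
        if wandb:  # only initialize a run when wandb is actually installed
            wandb.init()
```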