44 changes: 34 additions & 10 deletions examples/speculative_decoding/README.md
@@ -43,14 +43,16 @@ pip install -U nvidia-modelopt[hf]
pip install -r requirements.txt
```

We use [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Download by:
### Data Preparation

We use the [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset in this example. Prepare the data by:

```bash
apt-get update && apt-get install -y git-lfs
git lfs install --system
git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
python prepare_input_conversations/add_daring_anteater.py
```

See the [other-datasets](#other-datasets) section for other dataset options and instructions for user-provided data.

## Getting Started: Simplified Workflow

```bash
@@ -71,7 +73,7 @@ For small base models that fit in GPU memory, we can collocate them with draft models for online training:
```bash
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data Daring-Anteater/train.jsonl \
--data input_conversations/daring-anteater.jsonl \
--num_gpu $NUM_GPU \
--num_epochs $NUM_EPOCH \
--eagle_config eagle_config.json
@@ -91,7 +93,7 @@ We support two backends for generating base model hidden states. For better effc
```bash
python collect_hidden_states/compute_hidden_states_trtllm.py \
--model $BASE_MODEL \
--input-file Daring-Anteater/train.jsonl \
--input-file input_conversations/daring-anteater.jsonl \
--output-dir $HIDDEN_STATES_DIR
```

@@ -102,7 +104,7 @@ Alternatively, you can generate the same hidden states with HF:
```bash
python collect_hidden_states/compute_hidden_states_hf.py \
--model $BASE_MODEL \
--input-file Daring-Anteater/train.jsonl \
--input-file input_conversations/daring-anteater.jsonl \
--output-dir $HIDDEN_STATES_DIR
```
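
Conceptually, the HF path amounts to running the base model over each conversation and saving its hidden states for later distillation. The sketch below only illustrates that idea and is not the script's actual implementation; the paths, dump layout, and chat-template handling are assumptions.

```python
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "meta-llama/Llama-3.2-1B-Instruct"  # stand-in for $BASE_MODEL
tok = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.bfloat16, device_map="auto"
)

os.makedirs("hidden_states", exist_ok=True)  # stand-in for $HIDDEN_STATES_DIR
with open("input_conversations/daring-anteater.jsonl") as f:
    for i, line in enumerate(f):
        conv = json.loads(line)["conversations"]
        text = tok.apply_chat_template(conv, tokenize=False)
        ids = tok(text, return_tensors="pt").input_ids.to(model.device)
        with torch.no_grad():
            out = model(ids, output_hidden_states=True)
        # Save the last-layer hidden states as the distillation signal for this sample.
        torch.save(out.hidden_states[-1].cpu(), f"hidden_states/{i}.pt")
```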

@@ -130,7 +132,7 @@ For online training checkpoints, we can run in-framework evaluation on MT-bench:
python ar_validate.py --model_path $ONLINE_CKPT
```

Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
**Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.
Contributor

qq, what's the difference between online training and offline training?

Contributor Author

In online training, we load both the teacher model and the student model into GPU memory and perform the teacher forward plus the student forward/backward in each training step.

In offline training, we run inference with the teacher model first (with HF or TRTLLM), dump the hidden states to disk, and then train the draft model alone on the dumped distillation signals. The resulting checkpoint therefore does not contain teacher modules. :)
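
For readers skimming the thread, here is a schematic sketch of the two flows described above; every name in it is a placeholder for illustration, not an actual ModelOpt API:

```python
import torch

# Online: teacher and draft are both resident on GPU; every training step runs
# a teacher forward plus a draft forward/backward.
def online_step(teacher, draft, batch):
    with torch.no_grad():
        teacher_hidden = teacher(batch)   # teacher forward, no gradients
    loss = draft(batch, teacher_hidden)   # draft forward
    loss.backward()                       # draft backward only

# Offline: teacher inference happens once, ahead of training, and its hidden
# states are dumped to disk as distillation signals.
def dump_teacher_signals(teacher, dataset, out_dir):
    for i, batch in enumerate(dataset):
        with torch.no_grad():
            torch.save(teacher(batch), f"{out_dir}/{i}.pt")

# The draft model is then trained alone from the dumped signals, so the final
# checkpoint contains no teacher modules.
def offline_step(draft, batch, teacher_hidden_from_disk):
    loss = draft(batch, teacher_hidden_from_disk)
    loss.backward()
```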


## Export

@@ -183,6 +185,28 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE

## Advanced Usage

### Other Datasets

In addition to `daring-anteater`, we provide scripts for adding several other commonly used datasets in `prepare_input_conversations`:

```text
prepare_input_conversations/
├── add_daring_anteater.py
├── add_mtbench.py
├── add_sharegpt.py
├── add_ultrachat.py
└── example_make_prompt_dataset.sh
```

To use your own datasets, please preprocess your data into a `.jsonl` file with each line in the format:

```json
{
"conversation_id": <unique id>,
"conversations": [{"role":<user or assistant>, "content":<content>}]
}
```
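
For instance, a minimal sketch of writing your own prompt/response pairs into this format (the file name and the `raw_pairs` list are hypothetical):

```python
import json

# Hypothetical source data: (prompt, response) pairs from your own corpus.
raw_pairs = [
    ("What is speculative decoding?", "It drafts several tokens with a small model ..."),
]

with open("input_conversations/my_dataset.jsonl", "w") as f:
    for idx, (prompt, response) in enumerate(raw_pairs):
        record = {
            "conversation_id": idx,
            "conversations": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
            ],
        }
        f.write(json.dumps(record) + "\n")  # one conversation per line
```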

### Data Synthesis

To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model.
@@ -199,7 +223,7 @@ Note: Add `--quantization=modelopt` flag for quantized models.
Then, we generate conversations with the base model using prompts from Daring-Anteater:

```bash
python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl
python server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl
```

To add a system prompt, use the `--system_prompt <system_prompt_text>` argument.
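
For example (the prompt text here is just an illustration):

```bash
python server_generate.py \
  --data_path input_conversations/daring-anteater.jsonl \
  --output_path synthetic/train.jsonl \
  --system_prompt "You are a helpful assistant."
```
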
@@ -211,7 +235,7 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d
We can optionally use a smaller vocab size for the draft model for faster training and inference. For example, Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequently occurring tokens in our training set:

```bash
python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data Daring-Anteater/train.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```

This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
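
For illustration, applying the mapping at inference time might look like the sketch below (this assumes `d2t.pt` stores the mapping as an index tensor; the offset rule is exactly the one stated above):

```python
import torch

# d2t[i] is the offset that maps draft-vocab token i back to the target vocab:
# target_token = draft_token + d2t[draft_token]
d2t = torch.load("draft_vocab_cache/d2t.pt")

draft_tokens = torch.tensor([17, 256, 1023])  # example draft-vocab token ids
target_tokens = draft_tokens + d2t[draft_tokens]
print(target_tokens)
```
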
4 changes: 2 additions & 2 deletions examples/speculative_decoding/eagle_utils.py
@@ -29,8 +29,6 @@

try:
import wandb

wandb.init()
except ImportError:
wandb = None

@@ -397,6 +395,8 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
class ARValidationCallback(TrainerCallback):
def __init__(self, ar_validate_steps: int = 1000):
self.ar_validate_steps = ar_validate_steps
if wandb:
wandb.init()

def on_step_end(self, args, state, control, **kwargs):
if self.ar_validate_steps <= 0:
2 changes: 1 addition & 1 deletion examples/speculative_decoding/train_eagle3_and_export.sh
@@ -20,7 +20,7 @@ set -eo pipefail
# Set default values for BASE_MODEL, NUM_GPU, and DATA
BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
NUM_GPU=1
DATA=Daring-Anteater/train.jsonl
DATA=input_conversations/daring-anteater.jsonl

# Parse input arguments --base_model, --num_gpu, and --data
while [[ $# -gt 0 ]]; do
30 changes: 12 additions & 18 deletions modelopt/torch/export/plugins/hf_spec_export.py
@@ -48,17 +48,18 @@ def _check_state_dict_keys_match(draft_model: nn.Module, required_items: dict):
raise ValueError(f"State dict keys mismatch!\nMissing in draft model: {required_key}")


def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
def spec_opt_only(model: nn.Module):
"""Check if the model have only speculative decoding optimization."""
opt_modes = getattr(model, "_modelopt_state", None)
return (
isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
)


def export_spec_ckpt_state_dict(model: nn.Module):
"""Only return the state dict of the draft model in official format and ignore the base model."""
# check the model has only speculative decoding
opt_modes = getattr(model, "_modelopt_state", None)
if (
not isinstance(opt_modes, (list, tuple))
or len(opt_modes) != 1
or opt_modes[0][0] != "eagle"
):
# if there's other opts, return as is
return post_state_dict
assert spec_opt_only(model), "Not purely eagle model."

# Check if the state dict keys match
_check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@ -80,16 +81,9 @@ def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
return export_state_dict


def set_config_if_spec_decoding(model: nn.Module, config_data: dict):
def export_spec_ckpt_config(model: nn.Module):
"""Return the config of draft model in official format."""
opt_modes = getattr(model, "_modelopt_state", None)
if (
not isinstance(opt_modes, (list, tuple))
or len(opt_modes) != 1
or opt_modes[0][0] != "eagle"
):
# return as is
return config_data
assert spec_opt_only(model), "Not purely eagle model."

# This is the config keys in official checkpoint.
template_config = {
17 changes: 11 additions & 6 deletions modelopt/torch/export/unified_export_hf.py
@@ -26,6 +26,7 @@

import torch
import torch.nn as nn
from safetensors.torch import save_file

from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
@@ -53,7 +54,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
get_activation_scaling_factor,
@@ -511,18 +512,24 @@ def export_hf_checkpoint(
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)

# NOTE: (hg) Early exit for speculative decoding models
# This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
if spec_opt_only(model):
save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
with open(f"{export_dir}/config.json", "w") as file:
json.dump(export_spec_ckpt_config(model), file, indent=4)
return
Comment on lines +516 to +522

⚠️ Potential issue | 🟠 Major

Add error handling to the early exit path.

The early exit path lacks error handling that exists in the main export flow (lines 520-550). If save_file or the file write operations fail, the error will propagate without proper context or cleanup.

Apply this diff to add error handling:

     export_dir.mkdir(parents=True, exist_ok=True)
 
     # NOTE: (hg) Early exit for speculative decoding models
     # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
     if spec_opt_only(model):
-        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-        with open(f"{export_dir}/config.json", "w") as file:
-            json.dump(export_spec_ckpt_config(model), file, indent=4)
-        return
+        try:
+            save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+            with open(f"{export_dir}/config.json", "w") as file:
+                json.dump(export_spec_ckpt_config(model), file, indent=4)
+            return
+        except Exception as e:
+            warnings.warn(
+                "Cannot export speculative-only model to the model_config. The modelopt-optimized model state_dict"
+                " can be saved with torch.save for further inspection."
+            )
+            raise e


try:
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)

# NOTE: (hg) Should we save hf_quant_config when there's no quantization applied?
# Save hf_quant_config.json for backward compatibility
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
json.dump(hf_quant_config, file, indent=4)

hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)

# Save model
model.save_pretrained(
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
@@ -536,8 +543,6 @@ def export_hf_checkpoint(

config_data["quantization_config"] = hf_quant_config

config_data = set_config_if_spec_decoding(model, config_data)

with open(original_config, "w") as file:
json.dump(config_data, file, indent=4)

1 change: 1 addition & 0 deletions modelopt/torch/speculative/eagle/default_config.py
@@ -47,4 +47,5 @@
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"has_lm_head": False,
"head_dim": 128,
}
1 change: 1 addition & 0 deletions tests/examples/speculative_decoding/test_eagle.py
@@ -27,6 +27,7 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_
"intermediate_size": 64,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"head_dim": 64,
}

# Write the tiny config to a temporary file