-
Notifications
You must be signed in to change notification settings - Fork 175
Fix: supporting gpt-oss HF eagle #398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7ca3dbf
85017cb
1d1972f
fb0f020
a9882fe
53b9466
e66985c
80c6cad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -26,6 +26,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||
import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||
from safetensors.torch import save_file | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.quantization import set_quantizer_by_cfg_context | ||||||||||||||||||||||||||||||||||||||||||||||||
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer | ||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -53,7 +54,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
QUANTIZATION_W4A8_AWQ, | ||||||||||||||||||||||||||||||||||||||||||||||||
QUANTIZATION_W4A8_NVFP4_FP8, | ||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||
from .plugins import rename_and_prune_if_spec_decoding, set_config_if_spec_decoding | ||||||||||||||||||||||||||||||||||||||||||||||||
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only | ||||||||||||||||||||||||||||||||||||||||||||||||
from .quant_utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||
fuse_prequant_layernorm, | ||||||||||||||||||||||||||||||||||||||||||||||||
get_activation_scaling_factor, | ||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -511,18 +512,24 @@ def export_hf_checkpoint( | |||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||
export_dir = Path(export_dir) | ||||||||||||||||||||||||||||||||||||||||||||||||
export_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
# NOTE: (hg) Early exit for speculative decoding models | ||||||||||||||||||||||||||||||||||||||||||||||||
# This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint | ||||||||||||||||||||||||||||||||||||||||||||||||
if spec_opt_only(model): | ||||||||||||||||||||||||||||||||||||||||||||||||
save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors") | ||||||||||||||||||||||||||||||||||||||||||||||||
with open(f"{export_dir}/config.json", "w") as file: | ||||||||||||||||||||||||||||||||||||||||||||||||
json.dump(export_spec_ckpt_config(model), file, indent=4) | ||||||||||||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+516
to
+522
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling to the early exit path. The early exit path lacks error handling that exists in the main export flow (lines 520-550). If Apply this diff to add error handling: export_dir.mkdir(parents=True, exist_ok=True)
# NOTE: (hg) Early exit for speculative decoding models
# This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
if spec_opt_only(model):
- save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
- with open(f"{export_dir}/config.json", "w") as file:
- json.dump(export_spec_ckpt_config(model), file, indent=4)
- return
+ try:
+ save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+ with open(f"{export_dir}/config.json", "w") as file:
+ json.dump(export_spec_ckpt_config(model), file, indent=4)
+ return
+ except Exception as e:
+ warnings.warn(
+ "Cannot export speculative-only model to the model_config. The modelopt-optimized model state_dict"
+ " can be saved with torch.save for further inspection."
+ )
+ raise e 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
# NOTE: (hg) Should we save hf_quant_config when there's no quantization applied? | ||||||||||||||||||||||||||||||||||||||||||||||||
# Save hf_quant_config.json for backward compatibility | ||||||||||||||||||||||||||||||||||||||||||||||||
with open(f"{export_dir}/hf_quant_config.json", "w") as file: | ||||||||||||||||||||||||||||||||||||||||||||||||
json.dump(hf_quant_config, file, indent=4) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
hf_quant_config = convert_hf_quant_config_format(hf_quant_config) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
# Save model | ||||||||||||||||||||||||||||||||||||||||||||||||
model.save_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||||||
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state | ||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -536,8 +543,6 @@ def export_hf_checkpoint( | |||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
config_data["quantization_config"] = hf_quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
config_data = set_config_if_spec_decoding(model, config_data) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
with open(original_config, "w") as file: | ||||||||||||||||||||||||||||||||||||||||||||||||
json.dump(config_data, file, indent=4) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,4 +47,5 @@ | |
"use_mtp_layernorm": False, | ||
"parallel_draft_step": 1, | ||
"has_lm_head": False, | ||
"head_dim": 128, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
qq, what's the difference between online training and offline training?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In online training we load both the teacher model and student model into GPU DRAM and perform teacher forward + student forward/backward in a training step.
In offline training, we inference the teacher model first (with HF or TRTLLM), dump hidden states to disks, then train the draft model only with dumped distillation signals. Then the checkpoint does not contains teacher modules. : )