Fix: supporting gpt-oss HF eagle #398
Conversation
Walkthrough

Adds an eagle-only speculative-decoding export path gated by `spec_opt_only(model)`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Export as export_hf_checkpoint
    participant Plugin as hf_spec_export
    participant Saver as safetensors.save_file
    participant Default as _export_hf_checkpoint
    Caller->>Export: export_hf_checkpoint(model, out_dir)
    Export->>Plugin: spec_opt_only(model)
    alt eagle-only speculative
        Export->>Plugin: export_spec_ckpt_state_dict(model)
        Plugin-->>Export: draft_state_dict
        Export->>Saver: save_file(draft_state_dict, "out_dir/model.safetensors")
        Export->>Plugin: export_spec_ckpt_config(model)
        Plugin-->>Export: config_json
        Export->>Export: write "out_dir/config.json"
        Export-->>Caller: return (early exit)
    else non-speculative or mixed
        Export->>Default: _export_hf_checkpoint(model, out_dir)
        Default-->>Export: standard artifacts
        Export-->>Caller: return
    end
    note over Export,Plugin: Removed rename/prune and config-adjust hooks from normal path
```
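The gating predicate is the pivot of this diagram. Below is a minimal sketch of what `spec_opt_only` plausibly checks, assuming `_modelopt_state` is an iterable of (mode_name, config) pairs as hinted by the assertion-message suggestion later in this review; this is an illustration, not the actual implementation:

```python
# Illustrative sketch only -- the real spec_opt_only lives in
# modelopt/torch/export/plugins/hf_spec_export.py (lines 51-56).
# Assumption: _modelopt_state is an iterable of (mode_name, config) pairs.
def spec_opt_only(model) -> bool:
    state = getattr(model, "_modelopt_state", None)
    if not state:
        return False
    mode_names = [mode for mode, _config in state]
    return mode_names == ["eagle"]
```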
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ 1 warning · ✅ 2 passed
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #398   +/-   ##
=======================================
  Coverage   73.79%   73.79%
=======================================
  Files         171      171
  Lines       17591    17591
=======================================
+ Hits        12981    12982       +1
+ Misses       4610     4609       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (4)
modelopt/torch/export/unified_export_hf.py (1)

512-518: Consider honoring the save_modelopt_state parameter and adding error handling.

The early-exit path correctly prevents errors with offline training checkpoints, but consider these improvements:

- The `save_modelopt_state` parameter (line 499) is unused in this path. If users request modelopt state preservation, should it be saved separately?
- File write operations lack error handling, unlike the try-except block in the standard export path (lines 520-550).
- Consider using `Path` operations for consistency: `export_dir / "model.safetensors"` instead of f-strings.

Optional refactor to use Path operations:

```diff
 if spec_opt_only(model):
-    save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-    with open(f"{export_dir}/config.json", "w") as file:
+    save_file(export_spec_ckpt_state_dict(model), export_dir / "model.safetensors")
+    with open(export_dir / "config.json", "w") as file:
         json.dump(export_spec_ckpt_config(model), file, indent=4)
     return
```

modelopt/torch/export/plugins/hf_spec_export.py (3)
77-79: Track the TODO for cleaner lm_head.weight handling.

The temporary fix for handling missing `eagle_lm_head.weight` works but should be addressed. The fallback to `model.state_dict()["lm_head.weight"]` could fail if the key doesn't exist in the base model either.

Do you want me to open a new issue to track this technical debt?
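For illustration, a more defensive version of the fallback could look like the sketch below; the key names come from the review text, but the helper itself is hypothetical, not the repository's code:

```python
# Hypothetical helper (not the repository's code): prefer the draft head,
# fall back to the base model's lm_head, and fail with a clear message.
def _get_lm_head_weight(draft_state_dict, model):
    if "eagle_lm_head.weight" in draft_state_dict:
        return draft_state_dict["eagle_lm_head.weight"]
    base_state = model.state_dict()
    if "lm_head.weight" in base_state:
        return base_state["lm_head.weight"]
    raise KeyError(
        "Neither 'eagle_lm_head.weight' nor 'lm_head.weight' found; "
        "cannot export the draft lm_head."
    )
```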
141-141: Fix typo in comment.

Minor typo: "load fron eagle config" should be "load from eagle config".

```diff
-    # First, we try to load fron eagle config.
+    # First, we try to load from eagle config.
```
62-62: Consider a more descriptive assertion message.

The assertion message "Not purely eagle model." could be more helpful for debugging. Consider providing information about what optimization modes were found.

```diff
-    assert spec_opt_only(model), "Not purely eagle model."
+    opt_modes = getattr(model, "_modelopt_state", None)
+    assert spec_opt_only(model), (
+        f"Expected purely eagle model but found optimization modes: {opt_modes}. "
+        "This export path only supports models with a single 'eagle' optimization."
+    )
```
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (4)

- examples/speculative_decoding/README.md (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (2 hunks)
- modelopt/torch/export/unified_export_hf.py (3 hunks)
- modelopt/torch/speculative/eagle/default_config.py (1 hunks)
🧰 Additional context used

🧬 Code graph analysis (1)

modelopt/torch/export/unified_export_hf.py (1)
- modelopt/torch/export/plugins/hf_spec_export.py (3)
  - export_spec_ckpt_config (84-148)
  - export_spec_ckpt_state_dict (59-81)
  - spec_opt_only (51-56)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)

- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
🔇 Additional comments (1)

modelopt/torch/speculative/eagle/default_config.py (1)

50-50: Verify head_dim in eagle default_config.

In modelopt/torch/speculative/eagle/default_config.py (line 50), `head_dim` is set to 64. Confirm that this matches `hidden_size / num_attention_heads` in the same file (or document why it intentionally differs) to avoid silent mis-inference.
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/export/unified_export_hf.py (2)
515-517: Prefer Path operations over f-string concatenation.

Since `export_dir` is already a `Path` object (line 509), using Path operations would be more idiomatic and portable.

Apply this diff:

```diff
-    save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-    with open(f"{export_dir}/config.json", "w") as file:
+    save_file(export_spec_ckpt_state_dict(model), str(export_dir / "model.safetensors"))
+    with open(export_dir / "config.json", "w") as file:
```
512-513: Consider replacing the temporary workaround with a permanent solution.

The comment indicates this is a "temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint". Consider investigating the root cause and implementing a more robust solution that properly handles speculative decoding models within the standard export flow.

Do you want me to help identify the specific errors that occur when passing speculative decoding models to `_export_hf_checkpoint`, or open an issue to track the permanent fix?
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)

- modelopt/torch/export/unified_export_hf.py (3 hunks)
🔇 Additional comments (3)

modelopt/torch/export/unified_export_hf.py (3)

29-29: LGTM!

The `save_file` import from `safetensors.torch` is correctly added to support the new speculative-only export path.

57-57: Verify the new import functions are properly implemented.

Ensure that `export_spec_ckpt_config`, `export_spec_ckpt_state_dict`, and `spec_opt_only` from `.plugins` are correctly implemented, handle edge cases, and have appropriate test coverage.

Run the following script to verify the functions exist and review their implementation:
512-518: Verify whether hf_quant_config.json should be saved for backward compatibility.

The main export path saves `hf_quant_config.json` for backward compatibility (line 524), but the early exit path for speculative-only models does not. Verify whether speculative-only models need this file or if downstream tools will handle its absence correctly.

Run the following script to check if any downstream code expects this file:
```python
    # NOTE: (hg) Early exit for speculative decoding models
    # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
    if spec_opt_only(model):
        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
        with open(f"{export_dir}/config.json", "w") as file:
            json.dump(export_spec_ckpt_config(model), file, indent=4)
        return
```
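For context, calling the public entry point on a speculative-only model is what exercises this branch. A usage sketch, under the assumption that the call signature matches the sequence diagram above (the model variable and output directory are illustrative):

```python
# Illustrative usage; signature per the sequence diagram above.
from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

# eagle_draft_model: a model converted with modelopt's eagle mode (assumed to exist).
export_hf_checkpoint(eagle_draft_model, "exported_eagle")
# For an eagle-only model this writes exactly two files:
#   exported_eagle/model.safetensors
#   exported_eagle/config.json
```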
Add error handling to the early exit path.

The early exit path lacks error handling that exists in the main export flow (lines 520-550). If `save_file` or the file write operations fail, the error will propagate without proper context or cleanup.

Apply this diff to add error handling:
```diff
 export_dir.mkdir(parents=True, exist_ok=True)
 # NOTE: (hg) Early exit for speculative decoding models
 # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
 if spec_opt_only(model):
-    save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-    with open(f"{export_dir}/config.json", "w") as file:
-        json.dump(export_spec_ckpt_config(model), file, indent=4)
-    return
+    try:
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+    except Exception as e:
+        warnings.warn(
+            "Cannot export speculative-only model to the model_config. The modelopt-optimized model state_dict"
+            " can be saved with torch.save for further inspection."
+        )
+        raise e
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    export_dir.mkdir(parents=True, exist_ok=True)
    # NOTE: (hg) Early exit for speculative decoding models
    # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
    if spec_opt_only(model):
        try:
            save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
            with open(f"{export_dir}/config.json", "w") as file:
                json.dump(export_spec_ckpt_config(model), file, indent=4)
            return
        except Exception as e:
            warnings.warn(
                "Cannot export speculative-only model to the model_config. The modelopt-optimized model state_dict"
                " can be saved with torch.save for further inspection."
            )
            raise e
```
🤖 Prompt for AI Agents

In modelopt/torch/export/unified_export_hf.py around lines 512 to 518, the early exit branch that writes the speculative checkpoint lacks the error handling present in the main export path; wrap the calls to save_file(...) and writing config.json in a try/except, and on exception log a clear error message including the exception details, remove any partially written files (model.safetensors and config.json) and the export_dir if appropriate, then re-raise the exception so callers can handle it (mirror the main export flow's cleanup and logging behavior).
Force-pushed from e386c69 to 08bf1d3.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_hf.py (1)
516-522: Add error handling to the early exit path.

The early exit path lacks error handling that exists in the main export flow (lines 524-554). If `save_file` or file write operations fail, errors will propagate without proper context or cleanup.

Apply this diff to add error handling:

```diff
 export_dir.mkdir(parents=True, exist_ok=True)
 # NOTE: (hg) Early exit for speculative decoding models
 # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
 if spec_opt_only(model):
-    save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
-    with open(f"{export_dir}/config.json", "w") as file:
-        json.dump(export_spec_ckpt_config(model), file, indent=4)
-    return
+    try:
+        save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
+        with open(f"{export_dir}/config.json", "w") as file:
+            json.dump(export_spec_ckpt_config(model), file, indent=4)
+        return
+    except Exception as e:
+        warnings.warn(
+            "Cannot export speculative-only model to the model_config. The modelopt-optimized model state_dict"
+            " can be saved with torch.save for further inspection."
+        )
+        raise e
```
🧹 Nitpick comments (1)

examples/speculative_decoding/README.md (1)

118-118: Clarify "in-framework evaluation" terminology.

The note states that "in-framework evaluation is supported only for online training," but the term "in-framework" may not be immediately clear to users. Consider briefly explaining what this means (e.g., evaluation using the training framework vs. serving frameworks) for better clarity.

Consider this alternative wording:

```diff
-**Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.
+**Note**: Direct evaluation using `ar_validate.py` is supported only for online training checkpoints (which include both the base and draft models). For offline training checkpoints (draft model only), please export the model first and evaluate it using serving frameworks like TRT-LLM or SGLang.
```
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (4)

- examples/speculative_decoding/README.md (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (2 hunks)
- modelopt/torch/export/unified_export_hf.py (3 hunks)
- modelopt/torch/speculative/eagle/default_config.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/export/plugins/hf_spec_export.py
🧰 Additional context used

🧬 Code graph analysis (1)

modelopt/torch/export/unified_export_hf.py (1)
- modelopt/torch/export/plugins/hf_spec_export.py (3)
  - export_spec_ckpt_config (84-148)
  - export_spec_ckpt_state_dict (59-81)
  - spec_opt_only (51-56)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)

modelopt/torch/export/unified_export_hf.py (3)

29-29: LGTM!

The import of `save_file` from `safetensors.torch` is appropriate for the new early exit path that writes `model.safetensors`.

57-57: LGTM!

The updated imports align with the new speculative-only export architecture. The functions `spec_opt_only`, `export_spec_ckpt_state_dict`, and `export_spec_ckpt_config` properly handle the draft-only checkpoint export flow.

516-522: Confirm the head_dim addition in the default eagle config.

The entry `"head_dim": 128` exists in modelopt/torch/speculative/eagle/default_config.py.
Actionable comments posted: 0

🧹 Nitpick comments (1)

tests/examples/speculative_decoding/test_eagle.py (1)

30-30: LGTM! Consider verifying consistency with the base model.

The addition of `head_dim: 64` to the EAGLE config aligns with the PR objectives to support models where `hidden_size != num_heads * head_dim`. The value appears reasonable for this ultra-tiny test configuration.

For improved clarity and maintainability, consider:

- Adding a brief comment explaining the choice of `head_dim: 64` and how it relates to the base model's dimensions
- Verifying that this value is consistent with the `tiny_llama_path` model's `hidden_size`

Example:

```diff
 "num_attention_heads": 2,
 "num_key_value_heads": 2,
-"head_dim": 64,
+"head_dim": 64,  # Explicit head_dim to avoid incorrect inference from hidden_size
```
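As a quick illustration of the consistency check this comment asks for — the tiny model's `hidden_size` of 128 is an assumption here, only `head_dim: 64` is confirmed by the diff:

```python
# Assumed tiny-model dimensions; only head_dim = 64 comes from the diff above.
num_attention_heads = 2
head_dim = 64
hidden_size = 128  # assumption for the tiny test model
# With these numbers the explicit head_dim matches what HF would infer anyway:
assert hidden_size // num_attention_heads == head_dim
```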
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)

- tests/examples/speculative_decoding/test_eagle.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
LGTM for the export code.
> Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
>
> **Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.
qq, what's the difference between online training and offline training?
In online training, we load both the teacher model and the student model into GPU DRAM and run the teacher forward plus the student forward/backward in each training step.

In offline training, we run teacher inference first (with HF or TRT-LLM), dump the hidden states to disk, and then train the draft model alone on the dumped distillation signals. The resulting checkpoint therefore does not contain teacher modules. : )
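A compact sketch of the two flows described above; all function and keyword names here are hypothetical, for illustration only:

```python
import torch

# Online: teacher and draft share a training step, both resident in GPU memory.
def online_step(teacher, draft, batch, optimizer):
    with torch.no_grad():
        teacher_out = teacher(**batch, output_hidden_states=True)
    loss = draft(**batch, teacher_hidden=teacher_out.hidden_states[-1]).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Offline: dump teacher hidden states once, then train the draft alone later.
@torch.no_grad()
def dump_teacher_signals(teacher, dataloader, out_dir):
    teacher.eval()
    for i, batch in enumerate(dataloader):
        out = teacher(**batch, output_hidden_states=True)
        torch.save(out.hidden_states[-1].cpu(), f"{out_dir}/hidden_{i}.pt")

def offline_step(draft, batch, hidden_path, optimizer):
    teacher_hidden = torch.load(hidden_path)  # distillation signal from disk
    loss = draft(**batch, teacher_hidden=teacher_hidden).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```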
Signed-off-by: h-guo18 <[email protected]>
Force-pushed from c4d51eb to 80c6cad.
What does this PR do?
Type of change: Bug fix
Overview:

This PR contains two minor fixes to support gpt-oss eagle training:

1. Added `head_dim` to the default eagle config to prevent Llama from inferring the head dim as `hidden_size / num_heads`. That inference yields the wrong head dim for models like GPT-OSS, where `hidden_size != num_heads * head_dim`.
2. Added an early exit for speculative-only checkpoints before `_export_hf_checkpoint`, which triggers an error for offline training checkpoints.

Other changes:
- `wandb.init()` in eagle_utils.py;

Usage
Not changed.
Testing
Tested with gpt-oss-120b: offline training, export, and evaluation of the exported checkpoint on spec-bench.
Summary by CodeRabbit

- New Features
- Documentation
- Refactor
- Config
- Tests
- Chores