[1/N] ModelOPT PEFT mode support for the megatron-lm #342
base: main
Conversation
Walkthrough
Adds a PEFT/LoRA subsystem: package initializers, configuration types with validation, LoRA layer and registry, conversion and adapter-management utilities, plugin registry with Megatron integration and guarded plugin imports, and GPU tests for Megatron LoRA behavior.
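For orientation, the core computation the new LoRA modules add on top of a frozen base layer is a scaled low-rank update; a minimal, framework-agnostic sketch (the class and defaults here are illustrative, not the PR's Megatron-specific implementation):

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Illustrative only: frozen base layer plus a scaled low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # base weights stay frozen, as in convert_to_peft_model
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # zero init makes the adapter a no-op at first
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))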
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant API as update_model()
participant Conv as convert_to_peft_model()
participant Plugins as register_custom_model_plugins_on_the_fly()
participant Replace as replace_lora_module()
participant Registry as LoRAModuleRegistry
participant Add as add_adapter()
participant State as update_peft_metadata()
User->>API: call update_model(model, config)
alt model not yet PEFT
API->>Conv: convert_to_peft_model(model, config)
Conv->>Plugins: run pre-replace plugins(model)
Plugins-->>Conv: plugins complete
Conv->>Replace: replace_lora_module(model, config, registry)
Replace->>Registry: register/wrap eligible modules
Replace-->>Conv: model mutated with LoRA modules
Conv->>Add: add_adapter(model, config)
Add->>State: update_peft_metadata(model, config)
else model already PEFT
API->>Add: add_adapter(model, config)
end
API-->>User: return updated model
sequenceDiagram
autonumber
actor Loader
participant Restore as restore_peft_model()
participant Registry as LoRAModuleRegistry
participant Model as model
participant Apply as apply_peft_state()
Loader->>Restore: restore_peft_model(model, config, metadata)
Restore->>Registry: ensure LoRA modules registered
Restore->>Model: reconstruct/replace modules if needed
Restore->>Apply: apply per-module peft_state / extra_state
alt peft_state present
Apply-->>Restore: peft_state applied
else
Apply-->>Restore: best-effort extra_state applied
end
Restore-->>Loader: restored model
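In practice the restore path in the second diagram is exercised through ModelOpt's save/restore entry points; a rough usage sketch, assuming the generic modelopt.torch.opt APIs cover the new PEFT state that this PR wires into the mode registry:

import modelopt.torch.opt as mto

# After converting the model and training the adapters:
mto.save(model, "megatron_gpt_lora.pth")  # persists modelopt state (PEFT mode + config) and weights

# Later: rebuild the same base architecture, then replay the PEFT conversion and load weights.
model = mto.restore(fresh_model, "megatron_gpt_lora.pth")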
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #342 +/- ##
==========================================
- Coverage 73.79% 73.70% -0.09%
==========================================
Files 171 171
Lines 17591 17615 +24
==========================================
+ Hits 12981 12984 +3
- Misses 4610 4631 +21
☔ View full report in Codecov by Sentry.
Actionable comments posted: 13
♻️ Duplicate comments (1)
modelopt/torch/peft/convert.py (1)
205-219: Another instance of private attribute access. Similar to the previous comment, this function also directly accesses _lora_adapters.
🧹 Nitpick comments (24)
modelopt/torch/peft/lora/__init__.py (1)
1-3: Prefer lazy submodule imports to reduce import-time cost and avoid cycles. Importing layer and tp_layer at package import can be heavy and risks circular imports. Expose them lazily via __getattr__.

-"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
-
-from . import layer, tp_layer
+"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
+
+from importlib import import_module as _import_module
+
+__all__ = ["layer", "tp_layer"]
+
+def __getattr__(name):
+    if name in __all__:
+        mod = _import_module(f".{name}", __name__)
+        globals()[name] = mod
+        return mod
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

modelopt/torch/peft/__init__.py (1)
16-16: Docstring mismatch with package purpose. This is PEFT/LoRA, not distillation. Update the docstring.

-"""Distillation API subpackage for torch."""
+"""PEFT/LoRA API for torch."""
22-29
: Run plugins deterministically and guard against concurrent mutation.Take a snapshot, sort by module/qualname, then invoke.
def register_custom_model_plugins_on_the_fly(model): """Registers custom PEFT/LoRA plugins on the fly. This is called before LoRAModule replacement to allow plugins to configure the model (e.g., for distributed checkpointing). """ - for callback in CUSTOM_MODEL_PLUGINS: - callback(model) + # Snapshot to avoid RuntimeError if mutated during iteration + callbacks = sorted( + tuple(CUSTOM_MODEL_PLUGINS), + key=lambda f: (getattr(f, "__module__", ""), getattr(f, "__qualname__", getattr(f, "__name__", ""))), + ) + for callback in callbacks: + callback(model)
16-23: Add light typing for clarity. Optional: add minimal typing to document the callback contract.

-"""Custom PEFT/LoRA plugins registry."""
+"""Custom PEFT/LoRA plugins registry."""
+from typing import Callable, Iterable  # lightweight, no runtime deps

modelopt/torch/peft/plugins/megatron.py (1)
31-34: Optional: export public symbols for introspection. Expose MEGATRON_AVAILABLE and the hook via __all__ for discoverability.

-__all__ = []
+__all__ = ["MEGATRON_AVAILABLE", "megatron_replace_lora_module_hook"]

modelopt/torch/peft/mode.py (2)
1-13: Missing module docstring. Add a module-level docstring to document the purpose and functionality of this mode registry module.

+"""PEFT mode definitions and registry for parameter-efficient fine-tuning."""
+
 from modelopt.torch.opt.config import ModeloptBaseConfig
17-46: Missing class docstrings. Both mode descriptor classes lack docstrings explaining their purpose and usage.

 @PEFTModeRegistry.register_mode
 class PEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for PEFT/LoRA model conversion."""
+
     @property
     def name(self) -> str:
+        """Return the mode identifier string."""
         return "peft"

 @PEFTModeRegistry.register_mode
 class ExportPEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for exporting PEFT/LoRA models."""
     @property
     def name(self) -> str:
-        """Returns the value (str representation) of the mode."""
+        """Return the mode identifier string."""
         return "export_peft"
88-90: Incorrect error message for scale validation. The error message is missing "a" before "positive number".

 if v <= 0:
-    raise ValueError("scale must be positive number")
+    raise ValueError("scale must be a positive number")
 return v
99-117: Pickling validation may be too restrictive. The pickling requirement for initialization functions might be overly restrictive and prevent legitimate use cases like closures or partial functions. Consider documenting this requirement more prominently or providing alternative initialization strategies.
Consider adding a class method to create common initialization patterns that are guaranteed to be pickleable:

@classmethod
def create_normal_init(cls, mean=0.0, std=0.02):
    """Create a pickleable normal initialization function."""
    def normal_init(weight):
        return init.normal_(weight, mean=mean, std=std)
    return normal_init
162-170: Broad exception handling masks specific validation errors. Catching all exceptions and re-raising with a generic message loses valuable debugging information.

 try:
     validated_cfg[key] = PEFTAttributeConfig(**value)
-except Exception as e:
-    raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+except (TypeError, ValueError) as e:
+    raise ValueError(f"Invalid adapter configuration for '{key}': {e}") from e
62-62: Typo in error message. The error message has incorrect grammar and a double period.

-raise ValueError(f"adapter_name: {adapter_name} is already exist..")
+raise ValueError(f"adapter_name: {adapter_name} already exists.")
193-230: Forward method has performance implications. The forward method iterates through all adapters on every forward pass, which could impact performance with many adapters. Consider caching active adapters.
Consider maintaining a list of active adapters to avoid checking the enable flag on every forward pass:
def _update_active_adapters(self):
    """Cache list of active adapters for efficient forward pass."""
    self._active_adapters = [
        (adapter["lora_a"], adapter["lora_b"], adapter["scale"])
        for adapter in self._lora_adapters.values()
        if adapter["enable"]
    ]

def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
    output = super().forward(x, *args, **kwargs)
    if isinstance(output, tuple):
        result = output[0]
        other_outputs = output[1:]
    else:
        result = output
        other_outputs = ()

    # Use cached active adapters
    for lora_a, lora_b, scale in getattr(self, '_active_adapters', []):
        lora_a_output = lora_a(x)
        if isinstance(lora_a_output, tuple):
            lora_a_output = lora_a_output[0]
        lora_b_output = lora_b(lora_a_output)
        if isinstance(lora_b_output, tuple):
            lora_b_output = lora_b_output[0]
        result = result + scale * lora_b_output

    return (result, *other_outputs) if other_outputs else result
90-90: Unclear assertion message. The assertion message uses a non-standard abbreviation "MO-PEFT" without explanation.

-assert is_peft_model(model), "It's not a MO-PEFT model"
+assert is_peft_model(model), "Model has not been converted to PEFT/LoRA format"
101-102: Inconsistent error message formatting. The error message uses different formats for "pattern" vs "adapter pattern".

-pattern_type = "pattern" if allow_callable else "adapter pattern"
-raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}")
+pattern_type = "layer pattern" if allow_callable else "adapter pattern"
+raise TypeError(f"Unsupported {pattern_type} type: {type(pattern).__name__}")
111-119: Direct access to private attribute _lora_adapters. The function directly accesses the private _lora_adapters attribute, violating encapsulation. Consider adding a public method or property for adapter access.
Add a public interface in LoRAModule:

# In LoRAModule class
def get_adapter(self, adapter_name: str) -> dict[str, Any] | None:
    """Get adapter configuration by name."""
    return self._lora_adapters.get(adapter_name)

def set_adapter_state(self, adapter_name: str, enable: bool) -> None:
    """Set the enable state of an adapter."""
    if adapter_name in self._lora_adapters:
        self._lora_adapters[adapter_name]["enable"] = enable

Then update this function:

-for adapter_name, adapter_dict in module._lora_adapters.items():
+for adapter_name in module.adapter_names:
     if adapter_patterns is not None:
         if not matches_any_pattern(
             adapter_name, adapter_patterns, allow_callable=False
         ):
             continue
-    adapter_dict["enable"] = enable_state
+    module.set_adapter_state(adapter_name, enable_state)
93-121: Test helper lacks error handling. The model provider function doesn't handle potential errors during model creation, which could make test failures harder to debug.
def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" - + try: if meta_device: with torch.device("meta"): gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=4, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, vocab_size=vocab_size, use_cpu_initialization=meta_device, ) else: gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=4, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, vocab_size=vocab_size, ).cuda() return gpt_model.eval() + except Exception as e: + pytest.fail(f"Failed to create GPT model: {e}")
134-139: Conditional assertion logic could be clearer. The conditional logic for checking output equality based on config type is not immediately clear. Consider adding a comment explaining why DEFAULT_LORA_CFG_TEST should produce identical outputs.

 assert lora_output.shape == original_output.shape
+# DEFAULT_LORA_CFG_TEST uses zero initialization for LoRA B, so initial output should match original
 if lora_config == DEFAULT_LORA_CFG_TEST:
     assert torch.allclose(lora_output, original_output, rtol=1e-5), (
         f"{lora_output}, {original_output}"
     )
 else:
     assert not torch.allclose(lora_output, original_output, rtol=1e-5)
161-174: Large number of commented test cases. Most test parameterizations are commented out. This suggests either incomplete implementation or test instability.
If these tests are not ready, consider:
- Removing them entirely and tracking in an issue
- Using pytest.skip with a reason
- Adding a TODO comment explaining why they're disabled

 @pytest.mark.parametrize(
     "lora_config",
     [
         DEFAULT_LORA_CFG_TEST,
-        # DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
-        # SMALL_RANK_LORA_CFG,
-        # LARGE_SCALE_LORA_CFG,
-        # SELECTIVE_LAYER_LORA_CFG,
+        pytest.param(DEFAULT_LORA_CFG_RANDOM_INIT_TEST, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SMALL_RANK_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(LARGE_SCALE_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SELECTIVE_LAYER_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
     ],
 )
104-116
: Return the possibly converted root module from replace_lora_module (and use it)Safer if a root replacement is ever registered; also removes confusion around local reassignment.
def replace_lora_module( model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry ): """Recursively replace the module with LoRA module.""" @@ - if type(model) in registry: - model = registry.convert(model) - _replace_lora_module(model, version=version, registry=registry) + if type(model) in registry: + model = registry.convert(model) + _replace_lora_module(model, version=version, registry=registry) + return modelAnd in convert_to_peft_model:
- replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config) + model = replace_lora_module( + model, version=ModeloptStateManager(model).state_version, config=config + )
41-42
: Remove stale TODOThe replacement is already performed.
- # TODO: Replace to LoRA module
141-149: Docstring example uses an invalid signature. update_layer_lora takes a config object; adjust the example to avoid confusion.

-    ...     module.update_layer_lora("custom_adapter", rank=32)
+    ...     from modelopt.torch.peft.config import PEFTAttributeConfig
+    ...     module.update_layer_lora("custom_adapter", PEFTAttributeConfig(rank=32))
37-48
: _get_init_methods is unused; either use or removePrefer using it to ensure defaults when attr_config initializers are None (e.g., after metadata restore).
Example use (apply similarly in both update_layer_lora methods):
- lora_a = ColumnParallelLinear( + lora_a_init, lora_b_init = self._get_init_methods(attr_config.lora_a_init, attr_config.lora_b_init) + lora_a = ColumnParallelLinear( self.input_size, attr_config.rank, config=self.config, bias=False, gather_output=True, - init_method=attr_config.lora_a_init, + init_method=lora_a_init, disable_grad_reduce=getattr(self.config, "sequence_parallel", False), ) @@ - lora_b = ColumnParallelLinear( + lora_b = ColumnParallelLinear( attr_config.rank, self.output_size, config=self.config, bias=False, gather_output=False, # Keep output distributed like base layer - init_method=attr_config.lora_a_init, + init_method=lora_b_init, )
80-87: Micro: combine device/dtype moves into a single .to call. Slight cleanup; avoids two passes.

-if device is not None:
-    lora_a = lora_a.to(device)
-    lora_b = lora_b.to(device)
-if dtype is not None:
-    lora_a = lora_a.to(dtype)
-    lora_b = lora_b.to(dtype)
+if device is not None or dtype is not None:
+    lora_a = lora_a.to(device=device, dtype=dtype)
+    lora_b = lora_b.to(device=device, dtype=dtype)
26-28: Remove unused defaults. DEFAULT_LORA_RANK and DEFAULT_SCALE are unused.

-DEFAULT_LORA_RANK = 64
-DEFAULT_SCALE = 1.0
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
modelopt/torch/peft/__init__.py (1 hunks)
modelopt/torch/peft/config.py (1 hunks)
modelopt/torch/peft/conversion.py (1 hunks)
modelopt/torch/peft/convert.py (1 hunks)
modelopt/torch/peft/custom.py (1 hunks)
modelopt/torch/peft/lora/__init__.py (1 hunks)
modelopt/torch/peft/lora/layer.py (1 hunks)
modelopt/torch/peft/lora/tp_layer.py (1 hunks)
modelopt/torch/peft/mode.py (1 hunks)
modelopt/torch/peft/plugins/__init__.py (1 hunks)
modelopt/torch/peft/plugins/megatron.py (1 hunks)
tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
modelopt/torch/peft/plugins/megatron.py (1)
modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)
modelopt/torch/peft/lora/layer.py (3)
modelopt/torch/opt/dynamic.py (3)
DynamicModule
(338-914)_DMRegistryCls
(917-1124)config
(1265-1278)modelopt/torch/peft/config.py (1)
PEFTAttributeConfig
(40-117)modelopt/torch/peft/conversion.py (1)
peft_state
(96-101)
modelopt/torch/peft/__init__.py (1)
modelopt/torch/peft/mode.py (2)
convert
(31-32)convert
(66-68)
tests/gpu/torch/peft/test_megatron_peft.py (6)
tests/_test_utils/import_helper.py (1)
skip_if_no_megatron
(46-77)tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
get_mcore_gpt_model
(133-208)initialize_for_megatron
(385-393)modelopt/torch/peft/config.py (2)
kaiming_init
(30-32)zero_init
(35-37)modelopt/torch/peft/lora/layer.py (1)
LoRAModule
(20-230)modelopt/torch/utils/plugins/megatron_generate.py (1)
megatron_prefill
(41-130)modelopt/torch/peft/convert.py (3)
update_model
(39-63)disable_adapters
(121-149)enable_adapters
(152-180)
modelopt/torch/peft/convert.py (4)
modelopt/torch/opt/conversion.py (1)
apply_mode
(342-429)modelopt/torch/peft/config.py (1)
PEFTConfig
(124-170)modelopt/torch/peft/conversion.py (1)
add_adapter
(163-192)modelopt/torch/peft/lora/layer.py (1)
LoRAModule
(20-230)
modelopt/torch/peft/config.py (1)
modelopt/torch/opt/config.py (2)
ModeloptBaseConfig
(59-147)ModeloptField
(50-53)
modelopt/torch/peft/conversion.py (5)
modelopt/torch/opt/conversion.py (7)
ApplyModeError
(314-315)ModelLikeModule
(318-330)ModeloptStateManager
(63-311)init_modellike
(326-330)state_version
(135-137)is_converted
(102-127)_last_metadata
(220-222)modelopt/torch/utils/network.py (1)
get_unwrapped_name
(599-612)modelopt/torch/peft/config.py (1)
PEFTConfig
(124-170)modelopt/torch/peft/lora/layer.py (4)
LoRAModule
(20-230)set_from_peft_state
(152-165)get_peft_state
(91-131)update_layer_lora
(72-89)modelopt/torch/peft/custom.py (1)
register_custom_model_plugins_on_the_fly
(22-29)
modelopt/torch/peft/mode.py (4)
modelopt/torch/opt/config.py (1)
ModeloptBaseConfig
(59-147)modelopt/torch/opt/mode.py (2)
ModeDescriptor
(56-259)_ModeRegistryCls
(267-344)modelopt/torch/peft/config.py (2)
PEFTConfig
(124-170)ExportPEFTConfig
(173-174)modelopt/torch/peft/conversion.py (5)
convert_to_peft_model
(36-48)restore_peft_model
(51-55)update_peft_metadata
(91-93)export_peft_model
(118-119)restore_export_peft_model
(122-123)
modelopt/torch/peft/lora/tp_layer.py (4)
modelopt/torch/peft/config.py (1)
PEFTAttributeConfig
(40-117)modelopt/torch/peft/lora/layer.py (4)
LoRAModule
(20-230)_register_adapter
(39-69)update_layer_lora
(72-89)_setup
(30-32)modelopt/torch/quantization/plugins/megatron.py (2)
_MegatronColumnParallelLinear
(296-318)_MegatronRowParallelLinear
(322-354)modelopt/torch/quantization/conversion.py (1)
register
(325-366)
🪛 GitHub Actions: Code Quality
modelopt/torch/peft/mode.py
[error] 1-1: D100 Missing docstring in public module
[error] 22-22: D101 Missing docstring in public class
[error] 24-24: D102 Missing docstring in public method
[error] 28-28: D102 Missing docstring in public method
[error] 32-32: D102 Missing docstring in public method
[error] 36-36: D102 Missing docstring in public method
[error] 40-40: D102 Missing docstring in public method
[error] 44-44: D102 Missing docstring in public method
[error] 53-53: D101 Missing docstring in public class
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: wait-checks / wait
- GitHub Check: linux
🔇 Additional comments (2)
modelopt/torch/peft/plugins/__init__.py (1)
20-21: LGTM: guarded optional import. Import guarding with suppress(ImportError) is appropriate; it triggers registration only when Megatron is present.
modelopt/torch/peft/__init__.py (1)
18-25: Import-order OK — no action required. mode.py does "from .conversion import convert_to_peft_model, ..., export_peft_model" and its properties return those functions (modelopt/torch/peft/mode.py, ~lines 11-13, 31-33, 66-69).
Approve to unblock. Great work @jingyu-ml - please make sure the sample API usages are aligned with the design doc we have.
Actionable comments posted: 5
♻️ Duplicate comments (1)
tests/gpu/torch/peft/test_megatron_peft.py (1)
18-18: Critical: Add apex_or_te_required=True to the skip decorator. Line 18's skip_if_no_megatron() call is missing apex_or_te_required=True. Since _gpt_model_provider uses transformer_impl="local" (lines 164, 176), which requires Apex, environments lacking Apex will hit an assertion instead of gracefully skipping these GPU tests.
Apply this diff:

-skip_if_no_megatron()
+skip_if_no_megatron(apex_or_te_required=True)
🧹 Nitpick comments (1)
tests/gpu/torch/peft/test_megatron_peft.py (1)
153-180: Consider reducing duplication in _gpt_model_provider. The meta_device and non-meta_device branches duplicate most of the get_mcore_gpt_model arguments. Consider extracting common kwargs to reduce duplication.
Example refactor:
def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" - + + kwargs = { + "tensor_model_parallel_size": tp_size, + "num_layers": 4, + "ffn_hidden_size": None, + "num_attention_heads": 4, + "activation_func": "squared_relu", + "transformer_impl": "local", + "hidden_size": hidden_size, + "vocab_size": vocab_size, + "use_cpu_initialization": meta_device, + } + if meta_device: with torch.device("meta"): - gpt_model = get_mcore_gpt_model( - tensor_model_parallel_size=tp_size, - num_layers=4, - ffn_hidden_size=None, - num_attention_heads=4, - activation_func="squared_relu", - transformer_impl="local", - hidden_size=hidden_size, - vocab_size=vocab_size, - use_cpu_initialization=meta_device, - ) + gpt_model = get_mcore_gpt_model(**kwargs) else: - gpt_model = get_mcore_gpt_model( - tensor_model_parallel_size=tp_size, - num_layers=4, - ffn_hidden_size=None, - num_attention_heads=4, - activation_func="squared_relu", - transformer_impl="local", - hidden_size=hidden_size, - vocab_size=vocab_size, - ).cuda() + gpt_model = get_mcore_gpt_model(**kwargs).cuda() return gpt_model.eval()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/peft/test_megatron_peft.py (8)
tests/_test_utils/import_helper.py (1)
skip_if_no_megatron
(46-77)tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
get_mcore_gpt_model
(133-208)initialize_for_megatron
(385-393)modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (2)
restore_sharded_modelopt_state
(207-250)save_sharded_modelopt_state
(127-173)modelopt/torch/peft/lora/layer.py (1)
LoRAModule
(37-142)modelopt/torch/utils/plugins/megatron_generate.py (1)
megatron_prefill
(41-130)modelopt/torch/peft/convert.py (3)
update_model
(47-74)disable_adapters
(133-161)enable_adapters
(164-192)modelopt/torch/trace/symbols.py (1)
named_modules
(444-447)modelopt/torch/quantization/model_quant.py (1)
disable_quantizer
(453-455)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
def restore_peft_model(
    model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
) -> nn.Module:
    model, _ = convert_to_peft_model(model, config)
Restore is calling convert_to_peft_model, which will freeze the base model while resuming modelopt_state and before we create the optimizer.
Actionable comments posted: 3
♻️ Duplicate comments (1)
tests/gpu/torch/peft/test_megatron_peft.py (1)
18-18: Make Megatron tests skip when Apex/TE is missing. get_mcore_gpt_model(... transformer_impl="local") asserts Apex is present. Without apex_or_te_required=True, environments lacking Apex/TE crash instead of skipping the suite. Restore the guard so these GPU tests self-skip gracefully.

-skip_if_no_megatron()
+skip_if_no_megatron(apex_or_te_required=True)
🧹 Nitpick comments (3)
tests/gpu/torch/peft/test_megatron_peft.py (1)
572-779: Polish the quantizer expectation comments. The repeated comment "teh quantizer" both misspells "the" and in some cases states the wrong expectation (e.g., LoRA-after-quantize should not inherit quantizers, LoRA-before-quantize should). Please fix the spelling and align each comment with what the assertions actually check. For example:

-# Check if the lora have teh quantizer, they should not have them.
+# Check that the LoRA adapters do not carry quantizers.

and

-# Check if the lora have teh quantizer, they should not have them.
+# Check that the LoRA adapters keep their quantizers after quantization.

Update the other duplicated comments (lines 572, 624, 704, 769) accordingly to avoid confusion.
modelopt/torch/peft/conversion.py (2)
58-62: Metadata parameter is unused. The metadata parameter is accepted but not used in the restore logic. Based on past review discussions, this appears intentional (per-module state was removed in favor of model-level restore via convert_to_peft_model). However, if the metadata will never be used, consider either:
- Removing the parameter from the signature (if backward compatibility is not a concern)
- Adding a docstring note explaining why metadata is currently unused but kept for future extensibility
87-93: Remove unused version parameter. The version parameter is passed through the recursive call but never used in the function body. Remove it to simplify the signature, or add a comment explaining its planned future use.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/peft/conversion.py (1 hunks)
tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/peft/conversion.py
🧬 Code graph analysis (2)
tests/gpu/torch/peft/test_megatron_peft.py (7)
tests/_test_utils/import_helper.py (1)
skip_if_no_megatron
(46-77)tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
get_mcore_gpt_model
(133-208)initialize_for_megatron
(385-393)modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (2)
restore_sharded_modelopt_state
(207-250)save_sharded_modelopt_state
(127-173)modelopt/torch/peft/lora/layer.py (1)
LoRAModule
(37-142)modelopt/torch/utils/plugins/megatron_generate.py (1)
megatron_prefill
(41-130)modelopt/torch/peft/convert.py (3)
update_model
(47-74)disable_adapters
(133-161)enable_adapters
(164-192)modelopt/torch/quantization/model_quant.py (1)
disable_quantizer
(453-455)
modelopt/torch/peft/conversion.py (4)
modelopt/torch/opt/conversion.py (4)
ModelLikeModule
(318-330)ModeloptStateManager
(63-311)init_modellike
(326-330)state_version
(135-137)modelopt/torch/peft/config.py (1)
PEFTConfig
(95-155)modelopt/torch/peft/lora/layer.py (2)
LoRAModule
(37-142)update_layer_lora
(89-103)modelopt/torch/peft/custom.py (1)
register_custom_model_plugins_on_the_fly
(22-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/peft/conversion.py (4)
38-55: LGTM! Base model freezing correctly placed. The function correctly freezes all base model parameters (line 44-46) before replacing modules with LoRA modules, then adds adapters and configures LoRA-specific gradients. This sequence ensures that all non-LoRA parameters (including embeddings, layer norms, etc.) are frozen when freeze_base_model=True.
177-196: LGTM! Base parameter filtering correctly excludes all LoRA parameters. The function correctly collects LoRA parameter IDs from the entire model (line 180, layer_patterns=None) to ensure no LoRA parameters are affected when setting base requires_grad. Then it selectively applies the setting to base parameters based on layer_patterns (lines 192-196). This two-stage approach is correct.
297-313: LGTM! Clear separation of base and LoRA parameter gradient control. The docstring clearly explains that this function only affects LoRA parameters, addressing past review concerns. The base model parameter gradients are correctly handled separately in convert_to_peft_model (lines 44-46) before LoRA module replacement.
65-77: Add docstring & confirm top-level conversion
- Add a docstring to replace_lora_module explaining each parameter (model, version, config, registry) and that it mutates the model in-place.
- Verify whether registry.convert(model) returns a new nn.Module; if it does, have replace_lora_module return that converted module (or update callers to use the returned value).
)


class _QuantLoRAMegatronColumnParallelLinear(
The quantization part seems better placed in torch/quantization/plugins, since the HF PEFT quant support is there.
    if not is_megatron_core_model(model):
        raise ValueError("PEFT mode currently supports Megatron-Core models only.")
Why don't we create a LoRAModuleRegistry, just like QuantModuleRegistry, and register the supported modules to it? This way we can avoid hardcoding like this.
@QuantModuleRegistry.register(
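A rough sketch of the registry-driven check being suggested, assuming the LoRAModuleRegistry introduced elsewhere in this PR supports membership tests by module type (as the quoted conversion code implies; the helper name and import are assumptions):

# Assumed import: from modelopt.torch.peft.lora.layer import LoRAModuleRegistry
def has_lora_supported_modules(model) -> bool:
    # Derive support from what is registered instead of a hardcoded framework check.
    return any(type(module) in LoRAModuleRegistry for module in model.modules())

# Hypothetical replacement for the hardcoded guard above:
# if not has_lora_supported_modules(model):
#     raise ValueError("No LoRA-convertible modules found in this model.")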
    if type(model) in registry:
        model = registry.convert(model)
    _replace_lora_module(model, version=version, registry=registry)


def export_peft_model(model: nn.Module, config):
    raise NotImplementedError("Exporting a peft model is not supported yet.")


def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict):
    raise NotImplementedError("Restoring a peft & exported model is not supported yet.")


def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry):
    for name, child in model.named_children():
        if type(child) in registry:
            lora_module = registry.convert(child)
At line 88, we can use model.named_modules() instead of model.named_children(), so lines 74 and 75 can be removed.
    for name, module in model.named_modules():
        if isinstance(module, LoRAModule):
            # Collect all matching adapter settings and merge them
            # Later patterns override earlier ones
We should add this note to the config description as well.
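To illustrate the merge order that note would describe, a hypothetical adapter config where both entries match a QKV projection and the later, more specific pattern wins for the overlapping fields (the key names here are assumptions, not the finalized schema):

# Hypothetical config keys, for illustration only.
adapter_cfg = {
    "*": {"rank": 64, "scale": 1.0},                           # matched first: default for all eligible layers
    "*.self_attention.linear_qkv": {"rank": 8, "scale": 0.5},  # matched later: overrides rank/scale for QKV
}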
      (lora_a is a regular nn.Linear and is not sharded)
    - lora_b weight: sharded at dim 0
    """
    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
Can we do, for each lora_b: sharded_state_dict.update(lora_b.sharded_state_dict(..., ..))?
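A sketch of what that could look like, assuming each adapter's lora_b is itself a Megatron parallel linear exposing sharded_state_dict and the adapter dict layout matches the _lora_adapters structure used in the forward pass (the prefix naming is a guess):

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
    for adapter_name, adapter in self._lora_adapters.items():
        # Let each lora_b describe its own sharding instead of hand-writing offsets here.
        sharded_state_dict.update(
            adapter["lora_b"].sharded_state_dict(
                prefix=f"{prefix}lora_b_{adapter_name}.",
                sharded_offsets=sharded_offsets,
                metadata=metadata,
            )
        )
    return sharded_state_dict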
What does this PR do?
Type of change: new feature
Overview:
The PEFT module in TensorRT Model Optimizer provides an implementation of LoRA for parameter-efficient fine-tuning of large models. It allows you to add trainable low-rank decomposition matrices to existing model layers, significantly reducing the number of trainable parameters while maintaining model performance. It is particularly optimized for Megatron-based models and supports multiple adapters that can be dynamically enabled, disabled, or switched during inference. Combined with ModelOpt's quantization, this will also make QLoRA training straightforward in the future.
Usage
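The original usage snippet did not survive extraction; below is a minimal sketch of the call pattern implied by the new convert/adapter APIs (the import surface and config keys are assumptions, not the finalized public interface):

import modelopt.torch.peft as mtpeft  # assumed import surface

# `model` is an existing Megatron-Core model (see the GPU tests in this PR).
lora_cfg = {
    "adapter_name": "default",                      # hypothetical key
    "adapter_cfg": {"*linear_qkv*": {"rank": 32}},  # hypothetical pattern -> LoRA attribute mapping
}

model = mtpeft.update_model(model, lora_cfg)  # replaces eligible layers with LoRA modules, adds the adapter
mtpeft.disable_adapters(model)                # base-model behavior
mtpeft.enable_adapters(model)                 # adapters active again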
Advanced Usage - Quantization
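The advanced snippet is likewise missing here; a hedged sketch of how quantization could be combined with LoRA for QLoRA-style training (the ordering mirrors the quantize-then-add-LoRA flow exercised in the tests, but treat the details as assumptions):

import modelopt.torch.quantization as mtq
import modelopt.torch.peft as mtpeft  # assumed import surface

# Quantize the frozen base model first; `calibrate` is a user-provided calibration forward loop.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate)

# Then attach trainable LoRA adapters on top of the quantized base (`lora_cfg` as in the Usage sketch).
model = mtpeft.update_model(model, lora_cfg)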
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Configuration & Runtime
Tests
Chores