
Conversation

jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Sep 18, 2025

What does this PR do?

Type of change: new feature

Overview:

The PEFT module in TensorRT Model Optimizer provides an implementation of LoRA for parameter-efficient fine-tuning of large models. It adds trainable low-rank decomposition matrices to existing model layers, significantly reducing the number of trainable parameters while maintaining model performance. The module is optimized for Megatron-based models and supports multiple adapters that can be dynamically enabled, disabled, or switched during inference. Combined with ModelOpt quantization, it will also make QLoRA training straightforward in the future.
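
For context, the underlying LoRA idea can be sketched in plain PyTorch as below. This is a conceptual illustration only, not this module's implementation, which wraps Megatron parallel layers through a LoRAModule registry:

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Conceptual sketch: frozen base linear plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 32, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)                        # base weights stay frozen
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)   # A matrix
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)  # B matrix
        nn.init.kaiming_uniform_(self.lora_a.weight)                  # like kaiming_init
        nn.init.zeros_(self.lora_b.weight)                            # like zero_init
        self.scale = scale

    def forward(self, x):
        # y = W x + scale * B(A(x)); with B = 0 the wrapped layer matches the base layer
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))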

Usage

import modelopt.torch.peft as mtpf
import modelopt.torch.quantization as mtq
from modelopt.torch.peft.config import kaiming_init, zero_init

# Define LoRA configuration
lora_config = {
    "adapter_type": "lora",
    "adapter_name": "my_adapter",
    "adapter_cfg": {
        "*": {  # Apply to all layers
            "rank": 32,  # LoRA rank
            "scale": 1.0,  # Scaling factor
            "lora_a_init": kaiming_init,  # A matrix initialization
            "lora_b_init": zero_init,  # B matrix initialization
            "enable": True
        }
    }
}

# Apply LoRA to your model
mtpf.update_model(model, lora_config)

# Use the model with LoRA adapter
output = model(input_data)

# Disable the adapter (use original model)
mtpf.disable_adapters(model)
output_original = model(input_data)

# Re-enable the adapter
mtpf.enable_adapters(model)
output_lora = model(input_data)
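
A quick sanity check after update_model (plain PyTorch, assuming the conversion freezes the base model per the PR's freeze_base_model behavior) is to confirm that only adapter parameters remain trainable:

# Sketch: list trainable parameters after applying LoRA; only adapter weights
# (e.g. lora_a / lora_b tensors) should still require gradients.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable tensors, e.g. {trainable[:3]}")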

Advanced Usage - Quantization with Multiple Adapters

# Add first adapter for task A
task_a_config = {
    "adapter_type": "lora",
    "adapter_name": "task_a",
    "adapter_cfg": {
        "*": {"rank": 16, "scale": 1.0, "enable": True}
    }
}
mtpf.update_model(model, task_a_config)
mtq.quantize(model, FP8_CFG, forward_call)

# Add a second adapter for task B (needed for the switching below; same pattern as task A)
task_b_config = {
    "adapter_type": "lora",
    "adapter_name": "task_b",
    "adapter_cfg": {
        "*": {"rank": 16, "scale": 1.0, "enable": True}
    }
}
mtpf.update_model(model, task_b_config)

# Switch between adapters
mtpf.disable_adapters(model, adapters_to_disable=["task_a"])
mtpf.enable_adapters(model, adapters_to_enable=["task_b"])
output_task_b = model(input_data)

mtpf.disable_adapters(model, adapters_to_disable=["task_b"])
mtpf.enable_adapters(model, adapters_to_enable=["task_a"])
output_task_a = model(input_data)
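
As a related usage note, the new convert API also exposes is_peft_model; assuming it is reachable from the package namespace like the helpers above, adapter toggling can be guarded on conversion status:

# Sketch: only toggle adapters on models that have already been converted to PEFT.
if mtpf.is_peft_model(model):
    mtpf.enable_adapters(model, adapters_to_enable=["task_a"])
else:
    mtpf.update_model(model, task_a_config)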

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features

    • PEFT/LoRA support: convert/restore flows, add/manage adapters, enable/disable/query adapters, gradient controls, with Megatron and quantized integrations.
  • Configuration & Runtime

    • Validated PEFT configuration schema with sensible defaults, per-adapter settings, and mode descriptors for workflows.
  • Tests

    • GPU Megatron PEFT tests covering multi-adapter scenarios, adapter toggling, quantization interactions, and distributed checkpointing.
  • Chores

    • Added CODEOWNERS, editor settings, and changelog entry for LoRA mode support.

@jingyu-ml jingyu-ml self-assigned this Sep 18, 2025
@jingyu-ml jingyu-ml requested review from a team as code owners September 18, 2025 21:29
@jingyu-ml jingyu-ml marked this pull request as draft September 18, 2025 21:29

coderabbitai bot commented Sep 18, 2025

Warning

Rate limit exceeded

@jingyu-ml has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 2 minutes and 58 seconds before requesting another review.


📥 Commits

Reviewing files that changed from the base of the PR and between 82dc269 and fad4982.

📒 Files selected for processing (3)
  • modelopt/torch/peft/conversion.py (1 hunks)
  • modelopt/torch/peft/convert.py (1 hunks)
  • modelopt/torch/utils/network.py (3 hunks)

Walkthrough

Adds a PEFT/LoRA subsystem: package initializers, configuration types with validation, LoRA layer and registry, conversion and adapter-management utilities, plugin registry with Megatron integration and guarded plugin imports, and GPU tests for Megatron LoRA behavior.

Changes

  • Package initializers (modelopt/torch/peft/__init__.py, modelopt/torch/peft/lora/__init__.py, modelopt/torch/peft/lora/plugins/__init__.py): New package __init__ files (license/docstrings); re-export submodules and conditionally import/re-export optional plugin symbols.
  • Config (modelopt/torch/peft/config.py): New PEFT config types and validators: PEFTAttributeConfig, PEFTConfig, ExportPEFTConfig; coercion of adapter entries, defaults, and __all__ exports.
  • Conversion & state (modelopt/torch/peft/conversion.py): New conversion and adapter-management utilities: convert_to_peft_model, restore_peft_model, replace_lora_module/replace, add_adapter, pattern matching and traversal helpers, gradient/weight-freeze helpers, update_peft_metadata; export stubs for export-related funcs.
  • User adapter API (modelopt/torch/peft/convert.py): New user-facing API: update_model, is_peft_model, enable_adapters, disable_adapters, pattern-based adapter toggles, Megatron-core compatibility checks; exported via __all__.
  • LoRA core layer (modelopt/torch/peft/lora/layer.py): Adds LoRAModule base class and LoRAModuleRegistry: adapter registration/storage, abstract update_layer_lora, forward composition applying enabled adapters (a conceptual sketch of this composition follows this list).
  • Plugin registry (modelopt/torch/peft/custom.py): Adds CUSTOM_MODEL_PLUGINS set and register_custom_model_plugins_on_the_fly(model) to run pre-replacement callbacks.
  • Megatron plugins (modelopt/torch/peft/lora/plugins/megatron.py): Megatron-Core plugin with pre-replacement hook registration, Megatron-parallel LoRA adapter classes (column/row), optional quantized variants, device/dtype alignment, sharded-state handling, and guarded imports.
  • Mode integration (modelopt/torch/peft/mode.py): Adds PEFTModeDescriptor, ExportPEFTModeDescriptor, and PEFTModeRegistry, wiring convert/restore/update entrypoints into the mode registry.
  • Tests (GPU Megatron) (tests/gpu/torch/peft/test_megatron_peft.py): New GPU tests and helpers: LoRA configs, GPT model provider, single/dual-adapter flows, enable/disable assertions, adapter/module checks, multiprocessing harness and distributed checkpoint helpers.
  • Metadata / owners (.github/CODEOWNERS): Adds CODEOWNERS entry mapping modelopt/torch/peft to @NVIDIA/modelopt-torch-peft-codeowners.
  • Editor settings (.vscode/settings.json): Adds workspace Python env settings (editor-only).
  • Changelog (CHANGELOG.rst): New entry noting LoRA mode support for MCore via update_model.
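
To make the "forward composition applying enabled adapters" in the LoRA core layer entry above concrete, here is a minimal sketch of the composition pattern; it is illustrative only and omits the tuple-output and device/dtype handling the real LoRAModule performs:

# Illustrative sketch of how enabled adapters augment the base layer output.
# `adapters` mirrors the per-adapter dict structure discussed later in this review:
# {name: {"lora_a": nn.Module, "lora_b": nn.Module, "scale": float, "enable": bool}}
def apply_lora_adapters(base_output, x, adapters):
    result = base_output
    for adapter in adapters.values():
        if not adapter["enable"]:
            continue  # disabled adapters contribute nothing
        result = result + adapter["scale"] * adapter["lora_b"](adapter["lora_a"](x))
    return result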

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant API as update_model()
  participant Conv as convert_to_peft_model()
  participant Plugins as register_custom_model_plugins_on_the_fly()
  participant Replace as replace_lora_module()
  participant Registry as LoRAModuleRegistry
  participant Add as add_adapter()
  participant State as update_peft_metadata()

  User->>API: call update_model(model, config)
  alt model not yet PEFT
    API->>Conv: convert_to_peft_model(model, config)
    Conv->>Plugins: run pre-replace plugins(model)
    Plugins-->>Conv: plugins complete
    Conv->>Replace: replace_lora_module(model, config, registry)
    Replace->>Registry: register/wrap eligible modules
    Replace-->>Conv: model mutated with LoRA modules
    Conv->>Add: add_adapter(model, config)
    Add->>State: update_peft_metadata(model, config)
  else model already PEFT
    API->>Add: add_adapter(model, config)
  end
  API-->>User: return updated model
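Read as rough Python pseudocode, the diagram above corresponds to the following control flow (names are taken from the diagram; this is a reading aid, not the actual implementation):

# Reading aid for the conversion sequence above; not the real code paths.
def update_model_flow(model, config):
    if not is_peft_model(model):
        register_custom_model_plugins_on_the_fly(model)         # pre-replacement plugin hooks
        replace_lora_module(model, config, LoRAModuleRegistry)  # wrap eligible modules
        add_adapter(model, config)
        update_peft_metadata(model, config)
    else:
        add_adapter(model, config)                               # model already converted
    return model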
sequenceDiagram
  autonumber
  actor Loader
  participant Restore as restore_peft_model()
  participant Registry as LoRAModuleRegistry
  participant Model as model
  participant Apply as apply_peft_state()

  Loader->>Restore: restore_peft_model(model, config, metadata)
  Restore->>Registry: ensure LoRA modules registered
  Restore->>Model: reconstruct/replace modules if needed
  Restore->>Apply: apply per-module peft_state / extra_state
  alt peft_state present
    Apply-->>Restore: peft_state applied
  else
    Apply-->>Restore: best-effort extra_state applied
  end
  Restore-->>Loader: restored model

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Poem

I nibble at tensors under the LED,
I stitch small LoRAs, quiet as a thread.
Plugins hop in, align each tiny part,
Megatron hums while adapters start.
Metadata snug — a rabbit's happy art. 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 13.79%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The pull request title accurately summarizes the primary addition of PEFT mode support for Megatron-LM within ModelOPT and directly reflects the core change introduced by the diff. It is concise, specific, and relevant to the modifications made across the new PEFT and Megatron integration modules.



copy-pr-bot bot commented Sep 18, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


codecov bot commented Sep 18, 2025

Codecov Report

❌ Patch coverage is 8.33333% with 22 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.70%. Comparing base (adcb1a1) to head (fad4982).

Files with missing lines Patch % Lines
modelopt/torch/utils/network.py 8.33% 22 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #342      +/-   ##
==========================================
- Coverage   73.79%   73.70%   -0.09%     
==========================================
  Files         171      171              
  Lines       17591    17615      +24     
==========================================
+ Hits        12981    12984       +3     
- Misses       4610     4631      +21     

☔ View full report in Codecov by Sentry.

@jingyu-ml jingyu-ml changed the title from "Jingyux/megatron lora" to "[1/N] ModelOPT PEFT mode support for the megatron-lm" Sep 20, 2025
@jingyu-ml jingyu-ml marked this pull request as ready for review September 20, 2025 07:39

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 13

♻️ Duplicate comments (1)
modelopt/torch/peft/convert.py (1)

205-219: Another instance of private attribute access

Similar to the previous comment, this function also directly accesses _lora_adapters.

🧹 Nitpick comments (24)
modelopt/torch/peft/lora/__init__.py (1)

1-3: Prefer lazy submodule imports to reduce import-time cost and avoid cycles.

Importing layer and tp_layer at package import can be heavy and risks circulars. Expose them lazily via __getattr__.

-"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
-
-from . import layer, tp_layer
+"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
+
+from importlib import import_module as _import_module
+
+__all__ = ["layer", "tp_layer"]
+
+def __getattr__(name):
+    if name in __all__:
+        mod = _import_module(f".{name}", __name__)
+        globals()[name] = mod
+        return mod
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
modelopt/torch/peft/__init__.py (1)

16-16: Docstring mismatch with package purpose.

This is PEFT/LoRA, not distillation. Update the docstring.

-"""Distillation API subpackage for torch."""
+"""PEFT/LoRA API for torch."""
modelopt/torch/peft/custom.py (2)

22-29: Run plugins deterministically and guard against concurrent mutation.

Take a snapshot, sort by module/qualname, then invoke.

 def register_custom_model_plugins_on_the_fly(model):
     """Registers custom PEFT/LoRA plugins on the fly.
 
     This is called before LoRAModule replacement to allow plugins
     to configure the model (e.g., for distributed checkpointing).
     """
-    for callback in CUSTOM_MODEL_PLUGINS:
-        callback(model)
+    # Snapshot to avoid RuntimeError if mutated during iteration
+    callbacks = sorted(
+        tuple(CUSTOM_MODEL_PLUGINS),
+        key=lambda f: (getattr(f, "__module__", ""), getattr(f, "__qualname__", getattr(f, "__name__", ""))),
+    )
+    for callback in callbacks:
+        callback(model)

16-23: Add light typing for clarity.

Optional: add minimal typing to document the callback contract.

-"""Custom PEFT/LoRA plugins registry."""
+"""Custom PEFT/LoRA plugins registry."""
+from typing import Callable, Iterable  # lightweight, no runtime deps
modelopt/torch/peft/plugins/megatron.py (1)

31-34: Optional: export public symbols for introspection.

Expose MEGATRON_AVAILABLE and the hook via __all__ for discoverability.

-__all__ = []
+__all__ = ["MEGATRON_AVAILABLE", "megatron_replace_lora_module_hook"]
modelopt/torch/peft/mode.py (2)

1-13: Missing module docstring

Add a module-level docstring to document the purpose and functionality of this mode registry module.

+"""PEFT mode definitions and registry for parameter-efficient fine-tuning."""
+
 from modelopt.torch.opt.config import ModeloptBaseConfig

17-46: Missing class docstrings

Both mode descriptor classes lack docstrings explaining their purpose and usage.

 @PEFTModeRegistry.register_mode
 class PEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for PEFT/LoRA model conversion."""
+
     @property
     def name(self) -> str:
+        """Return the mode identifier string."""
         return "peft"
 @PEFTModeRegistry.register_mode
 class ExportPEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for exporting PEFT/LoRA models."""
 
     @property
     def name(self) -> str:
-        """Returns the value (str representation) of the mode."""
+        """Return the mode identifier string."""
         return "export_peft"
modelopt/torch/peft/config.py (3)

88-90: Incorrect error message for scale validation

The error message is missing "a" before "positive number".

         if v <= 0:
-            raise ValueError("scale must be positive number")
+            raise ValueError("scale must be a positive number")
         return v

99-117: Pickling validation may be too restrictive

The pickling requirement for initialization functions might be overly restrictive and prevent legitimate use cases like closures or partial functions. Consider documenting this requirement more prominently or providing alternative initialization strategies.

Consider adding a class method to create common initialization patterns that are guaranteed to be pickleable:

@classmethod
def create_normal_init(cls, mean=0.0, std=0.02):
    """Create a pickleable normal initialization function."""
    def normal_init(weight):
        return init.normal_(weight, mean=mean, std=std)
    return normal_init

162-170: Broad exception handling masks specific validation errors

Catching all exceptions and re-raising with a generic message loses valuable debugging information.

                 try:
                     validated_cfg[key] = PEFTAttributeConfig(**value)
-                except Exception as e:
-                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+                except (TypeError, ValueError) as e:
+                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}") from e
modelopt/torch/peft/lora/layer.py (2)

62-62: Typo in error message

The error message has a double period.

-            raise ValueError(f"adapter_name: {adapter_name} is already exist..")
+            raise ValueError(f"adapter_name: {adapter_name} already exists.")

193-230: Forward method has performance implications

The forward method iterates through all adapters on every forward pass, which could impact performance with many adapters. Consider caching active adapters.

Consider maintaining a list of active adapters to avoid checking the enable flag on every forward pass:

def _update_active_adapters(self):
    """Cache list of active adapters for efficient forward pass."""
    self._active_adapters = [
        (adapter["lora_a"], adapter["lora_b"], adapter["scale"])
        for adapter in self._lora_adapters.values()
        if adapter["enable"]
    ]

def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
    output = super().forward(x, *args, **kwargs)
    
    if isinstance(output, tuple):
        result = output[0]
        other_outputs = output[1:]
    else:
        result = output
        other_outputs = ()
    
    # Use cached active adapters
    for lora_a, lora_b, scale in getattr(self, '_active_adapters', []):
        lora_a_output = lora_a(x)
        if isinstance(lora_a_output, tuple):
            lora_a_output = lora_a_output[0]
        lora_b_output = lora_b(lora_a_output)
        if isinstance(lora_b_output, tuple):
            lora_b_output = lora_b_output[0]
        result = result + scale * lora_b_output
    
    return (result, *other_outputs) if other_outputs else result
modelopt/torch/peft/convert.py (3)

90-90: Unclear assertion message

The assertion message uses a non-standard abbreviation "MO-PEFT" without explanation.

-    assert is_peft_model(model), "It's not a MO-PEFT model"
+    assert is_peft_model(model), "Model has not been converted to PEFT/LoRA format"

101-102: Inconsistent error message formatting

The error message uses different formats for "pattern" vs "adapter pattern".

-                pattern_type = "pattern" if allow_callable else "adapter pattern"
-                raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}")
+                pattern_type = "layer pattern" if allow_callable else "adapter pattern"
+                raise TypeError(f"Unsupported {pattern_type} type: {type(pattern).__name__}")

111-119: Direct access to private attribute _lora_adapters

The function directly accesses the private _lora_adapters attribute, violating encapsulation. Consider adding a public method or property for adapter access.

Add a public interface in LoRAModule:

# In LoRAModule class
def get_adapter(self, adapter_name: str) -> dict[str, Any] | None:
    """Get adapter configuration by name."""
    return self._lora_adapters.get(adapter_name)

def set_adapter_state(self, adapter_name: str, enable: bool) -> None:
    """Set the enable state of an adapter."""
    if adapter_name in self._lora_adapters:
        self._lora_adapters[adapter_name]["enable"] = enable

Then update this function:

-            for adapter_name, adapter_dict in module._lora_adapters.items():
+            for adapter_name in module.adapter_names:
                 if adapter_patterns is not None:
                     if not matches_any_pattern(
                         adapter_name, adapter_patterns, allow_callable=False
                     ):
                         continue
 
-                adapter_dict["enable"] = enable_state
+                module.set_adapter_state(adapter_name, enable_state)
tests/gpu/torch/peft/test_megatron_peft.py (3)

93-121: Test helper lacks error handling

The model provider function doesn't handle potential errors during model creation, which could make test failures harder to debug.

 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
-
+    try:
         if meta_device:
             with torch.device("meta"):
                 gpt_model = get_mcore_gpt_model(
                     tensor_model_parallel_size=tp_size,
                     num_layers=4,
                     ffn_hidden_size=None,
                     num_attention_heads=4,
                     activation_func="squared_relu",
                     transformer_impl="local",
                     hidden_size=hidden_size,
                     vocab_size=vocab_size,
                     use_cpu_initialization=meta_device,
                 )
         else:
             gpt_model = get_mcore_gpt_model(
                 tensor_model_parallel_size=tp_size,
                 num_layers=4,
                 ffn_hidden_size=None,
                 num_attention_heads=4,
                 activation_func="squared_relu",
                 transformer_impl="local",
                 hidden_size=hidden_size,
                 vocab_size=vocab_size,
             ).cuda()
         return gpt_model.eval()
+    except Exception as e:
+        pytest.fail(f"Failed to create GPT model: {e}")

134-139: Conditional assertion logic could be clearer

The conditional logic for checking output equality based on config type is not immediately clear. Consider adding a comment explaining why DEFAULT_LORA_CFG_TEST should produce identical outputs.

     assert lora_output.shape == original_output.shape
+    # DEFAULT_LORA_CFG_TEST uses zero initialization for LoRA B, so initial output should match original
     if lora_config == DEFAULT_LORA_CFG_TEST:
         assert torch.allclose(lora_output, original_output, rtol=1e-5), (
             f"{lora_output}, {original_output}"
         )
     else:
         assert not torch.allclose(lora_output, original_output, rtol=1e-5)
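
The expected equivalence here follows directly from the zero-initialized LoRA B matrix; a self-contained illustration (plain PyTorch, independent of this PR's modules):

# With B initialized to zero, y = W x + scale * B(A(x)) reduces to y = W x,
# so the LoRA-wrapped layer reproduces the base layer exactly at step 0.
import torch
import torch.nn as nn

base = nn.Linear(16, 16, bias=False)
lora_a = nn.Linear(16, 4, bias=False)
lora_b = nn.Linear(4, 16, bias=False)
nn.init.zeros_(lora_b.weight)

x = torch.randn(2, 16)
assert torch.allclose(base(x) + 1.0 * lora_b(lora_a(x)), base(x))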

161-174: Large number of commented test cases

Most test parameterizations are commented out. This suggests either incomplete implementation or test instability.

If these tests are not ready, consider:

  1. Removing them entirely and tracking in an issue
  2. Using pytest.skip with a reason
  3. Adding a TODO comment explaining why they're disabled
 @pytest.mark.parametrize(
     "lora_config",
     [
         DEFAULT_LORA_CFG_TEST,
-        # DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
-        # SMALL_RANK_LORA_CFG,
-        # LARGE_SCALE_LORA_CFG,
-        # SELECTIVE_LAYER_LORA_CFG,
+        pytest.param(DEFAULT_LORA_CFG_RANDOM_INIT_TEST, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SMALL_RANK_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(LARGE_SCALE_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SELECTIVE_LAYER_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
     ],
 )
modelopt/torch/peft/conversion.py (3)

104-116: Return the possibly converted root module from replace_lora_module (and use it)

Safer if a root replacement is ever registered; also removes confusion around local reassignment.

 def replace_lora_module(
     model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry
 ):
     """Recursively replace the module with LoRA module."""
@@
-    if type(model) in registry:
-        model = registry.convert(model)
-    _replace_lora_module(model, version=version, registry=registry)
+    if type(model) in registry:
+        model = registry.convert(model)
+    _replace_lora_module(model, version=version, registry=registry)
+    return model

And in convert_to_peft_model:

-    replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
+    model = replace_lora_module(
+        model, version=ModeloptStateManager(model).state_version, config=config
+    )

41-42: Remove stale TODO

The replacement is already performed.

-    # TODO: Replace to LoRA module

141-149: Docstring example uses an invalid signature

update_layer_lora takes a config object; adjust the example to avoid confusion.

-        ...         module.update_layer_lora("custom_adapter", rank=32)
+        ...         from modelopt.torch.peft.config import PEFTAttributeConfig
+        ...         module.update_layer_lora("custom_adapter", PEFTAttributeConfig(rank=32))
modelopt/torch/peft/lora/tp_layer.py (3)

37-48: _get_init_methods is unused; either use or remove

Prefer using it to ensure defaults when attr_config initializers are None (e.g., after metadata restore).

Example use (apply similarly in both update_layer_lora methods):

-        lora_a = ColumnParallelLinear(
+        lora_a_init, lora_b_init = self._get_init_methods(attr_config.lora_a_init, attr_config.lora_b_init)
+        lora_a = ColumnParallelLinear(
             self.input_size,
             attr_config.rank,
             config=self.config,
             bias=False,
             gather_output=True,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_a_init,
             disable_grad_reduce=getattr(self.config, "sequence_parallel", False),
         )
@@
-        lora_b = ColumnParallelLinear(
+        lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,  # Keep output distributed like base layer
-            init_method=attr_config.lora_a_init,
+            init_method=lora_b_init,
         )

80-87: Micro: combine device/dtype moves into a single .to call

Slight cleanup; avoids two passes.

-        if device is not None:
-            lora_a = lora_a.to(device)
-            lora_b = lora_b.to(device)
-        if dtype is not None:
-            lora_a = lora_a.to(dtype)
-            lora_b = lora_b.to(dtype)
+        if device is not None or dtype is not None:
+            lora_a = lora_a.to(device=device, dtype=dtype)
+            lora_b = lora_b.to(device=device, dtype=dtype)

26-28: Remove unused defaults

DEFAULT_LORA_RANK and DEFAULT_SCALE are unused.

-DEFAULT_LORA_RANK = 64
-DEFAULT_SCALE = 1.0
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b895dc5 and d9a79c1.

📒 Files selected for processing (12)
  • modelopt/torch/peft/__init__.py (1 hunks)
  • modelopt/torch/peft/config.py (1 hunks)
  • modelopt/torch/peft/conversion.py (1 hunks)
  • modelopt/torch/peft/convert.py (1 hunks)
  • modelopt/torch/peft/custom.py (1 hunks)
  • modelopt/torch/peft/lora/__init__.py (1 hunks)
  • modelopt/torch/peft/lora/layer.py (1 hunks)
  • modelopt/torch/peft/lora/tp_layer.py (1 hunks)
  • modelopt/torch/peft/mode.py (1 hunks)
  • modelopt/torch/peft/plugins/__init__.py (1 hunks)
  • modelopt/torch/peft/plugins/megatron.py (1 hunks)
  • tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
modelopt/torch/peft/plugins/megatron.py (1)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/peft/lora/layer.py (3)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • _DMRegistryCls (917-1124)
  • config (1265-1278)
modelopt/torch/peft/config.py (1)
  • PEFTAttributeConfig (40-117)
modelopt/torch/peft/conversion.py (1)
  • peft_state (96-101)
modelopt/torch/peft/__init__.py (1)
modelopt/torch/peft/mode.py (2)
  • convert (31-32)
  • convert (66-68)
tests/gpu/torch/peft/test_megatron_peft.py (6)
tests/_test_utils/import_helper.py (1)
  • skip_if_no_megatron (46-77)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • get_mcore_gpt_model (133-208)
  • initialize_for_megatron (385-393)
modelopt/torch/peft/config.py (2)
  • kaiming_init (30-32)
  • zero_init (35-37)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (20-230)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/peft/convert.py (3)
  • update_model (39-63)
  • disable_adapters (121-149)
  • enable_adapters (152-180)
modelopt/torch/peft/convert.py (4)
modelopt/torch/opt/conversion.py (1)
  • apply_mode (342-429)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (124-170)
modelopt/torch/peft/conversion.py (1)
  • add_adapter (163-192)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (20-230)
modelopt/torch/peft/config.py (1)
modelopt/torch/opt/config.py (2)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
modelopt/torch/peft/conversion.py (5)
modelopt/torch/opt/conversion.py (7)
  • ApplyModeError (314-315)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
  • is_converted (102-127)
  • _last_metadata (220-222)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (124-170)
modelopt/torch/peft/lora/layer.py (4)
  • LoRAModule (20-230)
  • set_from_peft_state (152-165)
  • get_peft_state (91-131)
  • update_layer_lora (72-89)
modelopt/torch/peft/custom.py (1)
  • register_custom_model_plugins_on_the_fly (22-29)
modelopt/torch/peft/mode.py (4)
modelopt/torch/opt/config.py (1)
  • ModeloptBaseConfig (59-147)
modelopt/torch/opt/mode.py (2)
  • ModeDescriptor (56-259)
  • _ModeRegistryCls (267-344)
modelopt/torch/peft/config.py (2)
  • PEFTConfig (124-170)
  • ExportPEFTConfig (173-174)
modelopt/torch/peft/conversion.py (5)
  • convert_to_peft_model (36-48)
  • restore_peft_model (51-55)
  • update_peft_metadata (91-93)
  • export_peft_model (118-119)
  • restore_export_peft_model (122-123)
modelopt/torch/peft/lora/tp_layer.py (4)
modelopt/torch/peft/config.py (1)
  • PEFTAttributeConfig (40-117)
modelopt/torch/peft/lora/layer.py (4)
  • LoRAModule (20-230)
  • _register_adapter (39-69)
  • update_layer_lora (72-89)
  • _setup (30-32)
modelopt/torch/quantization/plugins/megatron.py (2)
  • _MegatronColumnParallelLinear (296-318)
  • _MegatronRowParallelLinear (322-354)
modelopt/torch/quantization/conversion.py (1)
  • register (325-366)
🪛 GitHub Actions: Code Quality
modelopt/torch/peft/mode.py

[error] 1-1: D100 Missing docstring in public module


[error] 22-22: D101 Missing docstring in public class


[error] 24-24: D102 Missing docstring in public method


[error] 28-28: D102 Missing docstring in public method


[error] 32-32: D102 Missing docstring in public method


[error] 36-36: D102 Missing docstring in public method


[error] 40-40: D102 Missing docstring in public method


[error] 44-44: D102 Missing docstring in public method


[error] 53-53: D101 Missing docstring in public class

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
🔇 Additional comments (2)
modelopt/torch/peft/plugins/__init__.py (1)

20-21: LGTM: guarded optional import.

Import guarding with suppress(ImportError) is appropriate; it triggers registration only when Megatron is present.

modelopt/torch/peft/__init__.py (1)

18-25: Import-order OK — no action required. mode.py does from .conversion import convert_to_peft_model, ..., export_peft_model and its properties return those functions (modelopt/torch/peft/mode.py, ~lines 11–13, 31–33, 66–69).

@realAsma
Contributor

@jingyu-ml

Transformers peft already seems to have a backend for Megatron Core - https://huggingface.co/docs/peft/v0.17.0/en/package_reference/lora#peft.LoraConfig.megatron_core. Have we tested peft MCore backend? Does it work? What does this PR add in addition to peft's MCore backend?

@cjluo-nv cjluo-nv requested a review from meenchen September 22, 2025 18:45
Signed-off-by: Jingyu Xin <[email protected]>
Signed-off-by: Jingyu Xin <[email protected]>
Contributor

@Edwardf0t1 Edwardf0t1 left a comment


Approve to unblock. Great work @jingyu-ml - please make sure the sample API usages are aligned with the design doc we have.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (1)
tests/gpu/torch/peft/test_megatron_peft.py (1)

18-18: Critical: Add apex_or_te_required=True to skip decorator.

Line 18's skip_if_no_megatron() call is missing apex_or_te_required=True. Since _gpt_model_provider uses transformer_impl="local" (lines 164, 176), which requires Apex, environments lacking Apex will hit an assertion instead of gracefully skipping these GPU tests.

Apply this diff:

-skip_if_no_megatron()
+skip_if_no_megatron(apex_or_te_required=True)
🧹 Nitpick comments (1)
tests/gpu/torch/peft/test_megatron_peft.py (1)

153-180: Consider reducing duplication in _gpt_model_provider.

The meta_device and non-meta_device branches duplicate most of the get_mcore_gpt_model arguments. Consider extracting common kwargs to reduce duplication.

Example refactor:

 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
-
+    
+    kwargs = {
+        "tensor_model_parallel_size": tp_size,
+        "num_layers": 4,
+        "ffn_hidden_size": None,
+        "num_attention_heads": 4,
+        "activation_func": "squared_relu",
+        "transformer_impl": "local",
+        "hidden_size": hidden_size,
+        "vocab_size": vocab_size,
+        "use_cpu_initialization": meta_device,
+    }
+    
     if meta_device:
         with torch.device("meta"):
-            gpt_model = get_mcore_gpt_model(
-                tensor_model_parallel_size=tp_size,
-                num_layers=4,
-                ffn_hidden_size=None,
-                num_attention_heads=4,
-                activation_func="squared_relu",
-                transformer_impl="local",
-                hidden_size=hidden_size,
-                vocab_size=vocab_size,
-                use_cpu_initialization=meta_device,
-            )
+            gpt_model = get_mcore_gpt_model(**kwargs)
     else:
-        gpt_model = get_mcore_gpt_model(
-            tensor_model_parallel_size=tp_size,
-            num_layers=4,
-            ffn_hidden_size=None,
-            num_attention_heads=4,
-            activation_func="squared_relu",
-            transformer_impl="local",
-            hidden_size=hidden_size,
-            vocab_size=vocab_size,
-        ).cuda()
+        gpt_model = get_mcore_gpt_model(**kwargs).cuda()
     return gpt_model.eval()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 03807e2 and 9b96fea.

📒 Files selected for processing (1)
  • tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/peft/test_megatron_peft.py (8)
tests/_test_utils/import_helper.py (1)
  • skip_if_no_megatron (46-77)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • get_mcore_gpt_model (133-208)
  • initialize_for_megatron (385-393)
modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (2)
  • restore_sharded_modelopt_state (207-250)
  • save_sharded_modelopt_state (127-173)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (37-142)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/peft/convert.py (3)
  • update_model (47-74)
  • disable_adapters (133-161)
  • enable_adapters (164-192)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/model_quant.py (1)
  • disable_quantizer (453-455)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs

def restore_peft_model(
    model: ModelLikeModule, config: PEFTConfig, metadata: MetadataDict
) -> nn.Module:
    model, _ = convert_to_peft_model(model, config)
Collaborator

@ChenhanYu ChenhanYu Oct 2, 2025


Restore is calling convert_to_peft_model, which will freeze the base model while resuming modelopt_state and before we create the optimizer.
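
For readers following along, the ordering matters because optimizers are typically built over the parameters that are trainable at construction time; a generic PyTorch illustration (not this PR's code):

# If restore freezes base weights before the optimizer exists, an optimizer
# built afterwards over requires_grad parameters will only cover the LoRA weights.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)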


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
tests/gpu/torch/peft/test_megatron_peft.py (1)

18-18: Make Megatron tests skip when Apex/TE is missing

get_mcore_gpt_model(... transformer_impl="local") asserts Apex is present. Without apex_or_te_required=True, environments lacking Apex/TE crash instead of skipping the suite. Restore the guard so these GPU tests self-skip gracefully.

-skip_if_no_megatron()
+skip_if_no_megatron(apex_or_te_required=True)
🧹 Nitpick comments (3)
tests/gpu/torch/peft/test_megatron_peft.py (1)

572-779: Polish the quantizer expectation comments

The repeated comment “teh quantizer” both misspells “the” and in some cases states the wrong expectation (e.g., LoRA-after-quantize should not inherit quantizers, LoRA-before-quantize should). Please fix the spelling and align each comment with what the assertions actually check—for example:

-            # Check if the lora have teh quantizer, they should not have them.
+            # Check that the LoRA adapters do not carry quantizers.

and

-            # Check if the lora have teh quantizer, they should not have them.
+            # Check that the LoRA adapters keep their quantizers after quantization.

Update the other duplicated comments (lines 572, 624, 704, 769) accordingly to avoid confusion.

modelopt/torch/peft/conversion.py (2)

58-62: Metadata parameter is unused.

The metadata parameter is accepted but not used in the restore logic. Based on past review discussions, this appears intentional (per-module state was removed in favor of model-level restore via convert_to_peft_model). However, if the metadata will never be used, consider either:

  1. Removing the parameter from the signature (if backward compatibility is not a concern)
  2. Adding a docstring note explaining why metadata is currently unused but kept for future extensibility

87-93: Remove unused version parameter.

The version parameter is passed through the recursive call but never used in the function body. Remove it to simplify the signature, or add a comment explaining its planned future use.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9b96fea and 5030b43.

📒 Files selected for processing (2)
  • modelopt/torch/peft/conversion.py (1 hunks)
  • tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/peft/conversion.py
🧬 Code graph analysis (2)
tests/gpu/torch/peft/test_megatron_peft.py (7)
tests/_test_utils/import_helper.py (1)
  • skip_if_no_megatron (46-77)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • get_mcore_gpt_model (133-208)
  • initialize_for_megatron (385-393)
modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (2)
  • restore_sharded_modelopt_state (207-250)
  • save_sharded_modelopt_state (127-173)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (37-142)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/peft/convert.py (3)
  • update_model (47-74)
  • disable_adapters (133-161)
  • enable_adapters (164-192)
modelopt/torch/quantization/model_quant.py (1)
  • disable_quantizer (453-455)
modelopt/torch/peft/conversion.py (4)
modelopt/torch/opt/conversion.py (4)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (95-155)
modelopt/torch/peft/lora/layer.py (2)
  • LoRAModule (37-142)
  • update_layer_lora (89-103)
modelopt/torch/peft/custom.py (1)
  • register_custom_model_plugins_on_the_fly (22-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/peft/conversion.py (4)

38-55: LGTM! Base model freezing correctly placed.

The function correctly freezes all base model parameters (line 44-46) before replacing modules with LoRA modules, then adds adapters and configures LoRA-specific gradients. This sequence ensures that all non-LoRA parameters (including embeddings, layer norms, etc.) are frozen when freeze_base_model=True.


177-196: LGTM! Base parameter filtering correctly excludes all LoRA parameters.

The function correctly collects LoRA parameter IDs from the entire model (line 180, layer_patterns=None) to ensure no LoRA parameters are affected when setting base requires_grad. Then it selectively applies the setting to base parameters based on layer_patterns (lines 192-196). This two-stage approach is correct.


297-313: LGTM! Clear separation of base and LoRA parameter gradient control.

The docstring clearly explains that this function only affects LoRA parameters, addressing past review concerns. The base model parameter gradients are correctly handled separately in convert_to_peft_model (lines 44-46) before LoRA module replacement.


65-77: Add docstring & confirm top-level conversion

  • Add a docstring to replace_lora_module explaining each parameter (model, version, config, registry) and that it mutates the model in-place.
  • Verify whether registry.convert(model) returns a new nn.Module; if it does, have replace_lora_module return that converted module (or update callers to use the returned value).

)


class _QuantLoRAMegatronColumnParallelLinear(


The quantization part seems better placed in torch/quantization/plugins, since the HF peft quant is there.

Comment on lines +64 to +65
    if not is_megatron_core_model(model):
        raise ValueError("PEFT mode currently supports Megatron-Core models only.")
Contributor


Why don't we create a LoRAModuleRegistry just like QuantModuleRegistry and register the supported modules to it? This way we can avoid hardcoding like this.

Comment on lines +74 to +90
    if type(model) in registry:
        model = registry.convert(model)
    _replace_lora_module(model, version=version, registry=registry)


def export_peft_model(model: nn.Module, config):
    raise NotImplementedError("Exporting a peft model is not supported yet.")


def restore_export_peft_model(model: nn.Module, config, metadata: MetadataDict):
    raise NotImplementedError("Restoring a peft & exported model is not supported yet.")


def _replace_lora_module(model: nn.Module, version=None, registry=LoRAModuleRegistry):
    for name, child in model.named_children():
        if type(child) in registry:
            lora_module = registry.convert(child)

@meenchen meenchen Oct 2, 2025


At line 88, we can use model.named_modules() instead of model.named_children(), so lines 74 and 75 can be removed.

    for name, module in model.named_modules():
        if isinstance(module, LoRAModule):
            # Collect all matching adapter settings and merge them
            # Later patterns override earlier ones

@meenchen meenchen Oct 2, 2025


We should add this note to the config description as well.

        (lora_a is a regular nn.Linear and is not sharded)
        - lora_b weight: sharded at dim 0
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
Contributor


Can we do for each lora_b:
sharded_state_dict.update(lora_b.sharded_state_dict(..., ..))

@realAsma realAsma self-requested a review October 2, 2025 22:48
@jingyu-ml jingyu-ml requested a review from Fridah-nv October 4, 2025 00:06
@jingyu-ml jingyu-ml requested a review from a team as a code owner October 4, 2025 00:41
8 participants