Add logging of the layers that were intercepted by LoRA (incl. names + paths). #324

Merged · 1 commit · Jun 26, 2025
23 changes: 22 additions & 1 deletion gemma/gm/nn/_lora.py
@@ -17,7 +17,7 @@
import dataclasses
import functools
from typing import Any

from absl import logging
from flax import linen as nn
from gemma import peft
from gemma.gm.nn import _layers
@@ -27,6 +27,9 @@
import numpy as np


_SUPPORTED_MODULES = (nn.Dense, nn.Einsum, nn.DenseGeneral, _layers.Einsum)


class LoRA(nn.Module):
"""Wrapper around a Gemma model to enable LoRA.

@@ -38,13 +41,15 @@ class LoRA(nn.Module):
    rank: The rank of the LoRA decomposition.
    model: The model to wrap.
    dtype: The dtype to use for the LoRA weights.
    verbose: If `True`, logs diagnostic strings for the LoRA layers.
  """

  _: dataclasses.KW_ONLY

  rank: int
  model: nn.Module
  dtype: jnp.dtype = jnp.bfloat16
  verbose: bool = False

  def __post_init__(self):
    super().__post_init__()
@@ -61,6 +66,7 @@ def __call__(self, *args, **kwargs):
        _replace_by_lora,
        rank=self.rank,
        dtype=self.dtype,
        verbose=self.verbose,
    )
    with peft.ModuleInterceptor(replace_module_fn):
      return self.model(*args, **kwargs)
@@ -81,19 +87,34 @@ def __getattr__(self, name: str) -> Any:
    return getattr(self.model, name)


def _lora_debug_string(module: nn.Module) -> str | None:
  if isinstance(module, _SUPPORTED_MODULES):
    return f'[LoRA] {type(module).__name__} ({module.name}) <- {module.path}'
  else:
    return None


def _replace_by_lora(
    module: nn.Module,
    *,
    rank: int,
    dtype: np.dtype,
    verbose: bool,
) -> nn.Module:
  """Replaces compatible modules by their LoRA version."""
  if verbose:
    debug_str = _lora_debug_string(module)
    if debug_str:
      logging.info(debug_str)

  # TODO(epot): Replace by generic LoRA wrapper ?
  match module:
    case nn.Dense():
      return peft.LoRADense(rank=rank, dtype=dtype, wrapped=module)
    case nn.Einsum():
      return peft.LoRAEinsum(rank=rank, dtype=dtype, wrapped=module)
    case nn.DenseGeneral():
      return peft.LoRADenseGeneral(rank=rank, dtype=dtype, wrapped=module)
    case _layers.Einsum():
      # This hack is required because the FeedForward layer calls two different
      # Einsum modules using `nn.share_scope`, so the two wrappers need a different
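A minimal usage sketch of the new `verbose` flag, assuming the wrapper is exposed as `gm.nn.LoRA` and using a `gm.nn.Gemma3_4B` base model with a `tokens` keyword input (the base-model name and the input signature are assumptions, not part of this diff). With `verbose=True`, each layer intercepted by `peft.ModuleInterceptor` is reported through `absl.logging` in the `[LoRA] <LayerType> (<name>) <- <path>` format produced by `_lora_debug_string`:

```python
# Sketch only: the base model class and input signature below are assumptions.
import jax
import jax.numpy as jnp
from absl import logging

from gemma import gm

logging.set_verbosity(logging.INFO)  # make the INFO-level [LoRA] lines visible

model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(),  # assumed base model; any gm.nn transformer should work
    verbose=True,             # flag added in this PR: log every intercepted layer
)

# Initializing (or applying) the wrapped model runs the module interceptor, so each
# replaced nn.Dense / nn.Einsum / nn.DenseGeneral / _layers.Einsum is logged as
# `[LoRA] <LayerType> (<name>) <- <path>`.
params = model.init(
    jax.random.PRNGKey(0),
    tokens=jnp.zeros((1, 8), dtype=jnp.int32),  # assumed input signature
)
```

Since the flag defaults to `False`, existing callers are unaffected and the logging cost is only paid when explicitly requested.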