diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index ecee4aedf1..3ecc3e7e9b 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -116,6 +116,8 @@ title: VeRA - local: package_reference/fourierft title: FourierFT + - local: package_reference/gralora + title: GraLoRA - local: package_reference/vblora title: VB-LoRA - local: package_reference/hra diff --git a/docs/source/package_reference/gralora.md b/docs/source/package_reference/gralora.md new file mode 100644 index 0000000000..3d499756c1 --- /dev/null +++ b/docs/source/package_reference/gralora.md @@ -0,0 +1,32 @@ +# GraLoRA + +[**Granular Low-Rank Adaptation (GraLoRA)**](https://huggingface.co/papers/2505.20355) is a PEFT method designed to enhance the **expressivity** of low-rank adaptation while improving **robustness to outlier** activations, based on insights from well-known issues in quantization. + +![GraLoRA Overview](https://github.com/SqueezeBits/GraLoRA/raw/main/figure/gralora_overview.png) + +Unlike standard LoRA, which applies a single low-rank adapter across the entire feature space, GraLoRA introduces a structured and fine-grained adaptation scheme. It divides the adaptation space into a grid of $𝑘^2$ smaller, independent adapter pairs, each responsible for a localized subset of the input and output dimensions. As a result, each adapter operates on a subspace that is $k$ times smaller in both dimensions than the original LoRA adapter. + +This granular decomposition enables spatially localized and context-aware updates, effectively increasing representational capacity without additional parameters or computational cost. By isolating the influence of extreme activations within smaller subspaces, GraLoRA mitigates gradient distortion and preserves inter-channel balance during adaptation. + +--- + +The abstract from the paper is: + +*Low-Rank Adaptation (LoRA) is a popular method for parameter-efficient fine- +tuning (PEFT) of generative models, valued for its simplicity and effectiveness. +Despite recent enhancements, LoRA still suffers from a fundamental limitation: +overfitting when the bottleneck is widened. It performs best at ranks 32–64, yet its +accuracy stagnates or declines at higher ranks, still falling short of full fine-tuning +(FFT) performance. We identify the root cause as LoRA’s structural bottleneck, +which introduces gradient entanglement to the unrelated input channels and distorts +gradient propagation. To address this, we introduce a novel structure, Granular +Low-Rank Adaptation (GraLoRA) that partitions weight matrices into sub-blocks, +each with its own low-rank adapter. With negligible computational or storage cost, +GraLoRA overcomes LoRA’s limitations, effectively increases the representational +capacity, and more closely approximates FFT behavior. Experiments on code +generation, commonsense reasoning, mathematical reasoning, general language +understanding, and image generation benchmarks show that GraLoRA consistently +outperforms LoRA and other baselines, achieving up to +8.5% absolute gain in +Pass@1 on HumanEval+. 
These improvements hold across model sizes and rank +settings, making GraLoRA a scalable and robust solution for PEFT.* + diff --git a/examples/gralora_finetuning/README.md b/examples/gralora_finetuning/README.md new file mode 100644 index 0000000000..616141e1c8 --- /dev/null +++ b/examples/gralora_finetuning/README.md @@ -0,0 +1,73 @@ +# GraLoRA: Granular Low-Rank Adaptation + +![GraLoRA Overview](https://github.com/SqueezeBits/GraLoRA/raw/main/figure/gralora_overview.png) + +## Introduction +[**Granular Low-Rank Adaptation (GraLoRA)**](https://huggingface.co/papers/2505.20355) is a PEFT method designed to enhance the **expressivity** of low-rank adaptation while improving **robustness to outlier** activations, based on insights from well-known issues in quantization. + +GraLoRA introduces a structured and fine-grained adaptation scheme. It divides the adaptation space into a grid of $𝑘^2$ smaller, independent adapter pairs, each responsible for a localized subset of the input and output dimensions. + +## Quick start + +Compared to a standard PEFT training procedure with LoRA, you only need to swap your `LoraConfig` for a `GraloraConfig`. + +```python +import torch +from datasets import load_dataset +from transformers import AutoTokenizer, AutoModelForCausalLM +from trl import SFTTrainer, SFTConfig +from peft import GraloraConfig + +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", dtype=torch.bfloat16, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B") +dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") +gralora_config = GraloraConfig() + +trainer = SFTTrainer( + model=model, + train_dataset=dataset, + processing_class=tokenizer, + peft_config=gralora_config, + args=SFTConfig( + max_length=2048, + dataset_text_field="text", + per_device_train_batch_size=2, + ), +) +trainer.train() +trainer.model.save_pretrained("gralora-llama-3.2-3b") +``` + +Run the finetuning script with: +```sh +python examples/gralora_finetuning/gralora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco +``` + +## Use the model on 🤗 +You can load and use the model like any other 🤗 Transformers model. +```python +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3-8B", dtype=torch.bfloat16, device_map="auto" +) +peft_model = PeftModel.from_pretrained(model, "gralora-llama-3-8b") +``` + +## Additional Notes +While `gralora_k` defaults to 2, you can increase this value to create more fine-grained adapters. A `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher; see the configuration sketch below.
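+
+The following sketch shows how these settings fit together. The base model and hyperparameter values are illustrative assumptions, not recommendations from the paper: a total rank of 64 split as `r=56` plus `hybrid_r=8`, so `gralora_k=4` applies. Note that `r` must be divisible by `gralora_k`, and the input/output features of every targeted layer must be divisible by `gralora_k` as well.
+
+```python
+from transformers import AutoModelForCausalLM
+from peft import GraloraConfig, get_peft_model
+
+# Illustrative configuration: total rank r + hybrid_r = 64, hence gralora_k=4
+config = GraloraConfig(
+    r=56,                  # rank handled by the k x k grid of sub-adapters
+    hybrid_r=8,            # extra rank handled by a vanilla LoRA pair (Hybrid GraLoRA)
+    gralora_k=4,           # 4 x 4 grid of independent adapter blocks per layer
+    gralora_alpha=64,      # scaling is gralora_alpha / (r + hybrid_r)
+    gralora_dropout=0.05,
+    target_modules=["q_proj", "v_proj"],
+)
+
+base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+model = get_peft_model(base_model, config)
+model.print_trainable_parameters()
+```
+
+Because the adapter's parameter count depends only on `r + hybrid_r`, changing `gralora_k` trades granularity against per-block rank without changing the adapter size.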
+ +## Citation +``` +@misc{jung2025graloragranularlowrankadaptation, + title={GraLoRA: Granular Low-Rank Adaptation for Parameter-Efficient Fine-Tuning}, + author={Yeonjoon Jung and Daehyun Ahn and Hyungjun Kim and Taesu Kim and Eunhyeok Park}, + year={2025}, + eprint={2505.20355}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2505.20355}, +} +``` diff --git a/examples/gralora_finetuning/gralora_finetuning.py b/examples/gralora_finetuning/gralora_finetuning.py new file mode 100644 index 0000000000..e02ab5705b --- /dev/null +++ b/examples/gralora_finetuning/gralora_finetuning.py @@ -0,0 +1,190 @@ +# This script is based on examples/dora_finetuning/dora_finetuning.py +import os + +import torch +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + DataCollatorForLanguageModeling, + Trainer, + TrainingArguments, +) + +from peft import GraloraConfig, get_peft_model, prepare_model_for_kbit_training + + +def train_model( + base_model: str, + data_path: str, + output_dir: str, + batch_size: int, + num_epochs: int, + learning_rate: float, + cutoff_len: int, + val_set_size: int, + eval_step: int, + save_step: int, + device: str, + gralora_r: int, + gralora_alpha: int, + gralora_dropout: float, + gralora_target_modules: str, + gralora_k: int, + hybrid_r: int, + hub_model_id: str, + push_to_hub: bool, +): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + hf_token = os.getenv("HF_TOKEN") + + # Setup device + if device == "auto": + device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" + else: + device = torch.device(device) + print(f"Using device: {device}") + + # load tokenizer + tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token) + + model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token) + # GraLoRA config for the PEFT model + gralora_config = GraloraConfig( + r=gralora_r, # Rank of matrix + gralora_alpha=gralora_alpha, + target_modules=( + gralora_target_modules.split(",") + if gralora_target_modules + else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + ), + gralora_dropout=gralora_dropout, + gralora_k=gralora_k, + hybrid_r=hybrid_r, + bias="none", + ) + + # get the peft model with GraLoRA config + model = get_peft_model(model, gralora_config) + + model.to(device) # MODEL TO GPU/CUDA + tokenizer.pad_token = tokenizer.eos_token + + # Load the dataset + dataset = load_dataset(data_path) + + def tokenize_function(examples): + inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len) + inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task + return inputs + + # Tokenize the dataset and prepare for training + tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) + + # Data collator to dynamically pad the batched examples + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) + + # Define training arguments + training_args = TrainingArguments( + output_dir=output_dir, + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_steps=100, + weight_decay=0.01, + logging_steps=eval_step, + save_steps=save_step, + save_total_limit=2, + push_to_hub=push_to_hub, + hub_model_id=hub_model_id, + gradient_accumulation_steps=16, + fp16=True, + learning_rate=learning_rate, + 
hub_token=hf_token, + ) + + # Clear device cache to free memory + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.xpu.is_available(): + torch.xpu.empty_cache() + + # Initialize the Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_datasets["train"], + eval_dataset=tokenized_datasets["test"], + data_collator=data_collator, + ) + + # Start model training + trainer.train() + + # Save and push the trained model and tokenizer + if push_to_hub: + # Push the main model to the hub + trainer.push_to_hub(commit_message="Fine-tuned model") + + # Save the model and tokenizer locally + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Fine-tune LLaMA with GraLoRA and PEFT") + parser.add_argument("--base_model", type=str, default="meta-llama/Llama-3.2-3B", help="Base model path or name") + parser.add_argument( + "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name" + ) + parser.add_argument( + "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model" + ) + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization") + parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size") + parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval") + parser.add_argument("--save_step", type=int, default=100, help="Save step interval") + parser.add_argument("--device", type=str, default="auto", help="Device to use for training") + parser.add_argument("--gralora_r", type=int, default=8, help="LoRA rank") + parser.add_argument("--gralora_alpha", type=int, default=16, help="LoRA alpha") + parser.add_argument("--gralora_dropout", type=float, default=0.05, help="LoRA dropout rate") + parser.add_argument( + "--gralora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA" + ) + parser.add_argument("--gralora_k", type=int, default=2, help="GraLoRA k") + parser.add_argument("--hybrid_r", type=int, default=0, help="Hybrid rank") + parser.add_argument( + "--hub_model_id", + type=str, + default="path/to/repo", + help="Repository name to push the model on the Hugging Face Hub", + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub") + args = parser.parse_args() + train_model( + base_model=args.base_model, + data_path=args.data_path, + output_dir=args.output_dir, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + learning_rate=args.learning_rate, + cutoff_len=args.cutoff_len, + val_set_size=args.val_set_size, + eval_step=args.eval_step, + save_step=args.save_step, + device=args.device, + gralora_r=args.gralora_r, + gralora_alpha=args.gralora_alpha, + gralora_dropout=args.gralora_dropout, + gralora_target_modules=args.gralora_target_modules, + gralora_k=args.gralora_k, + hybrid_r=args.hybrid_r, + hub_model_id=args.hub_model_id, + push_to_hub=args.push_to_hub, + ) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index f8fdd48ff0..9a89b19554 100644 --- a/src/peft/__init__.py +++ 
b/src/peft/__init__.py @@ -64,6 +64,8 @@ EvaConfig, FourierFTConfig, FourierFTModel, + GraloraConfig, + GraloraModel, HRAConfig, HRAModel, IA3Config, @@ -163,6 +165,8 @@ "EvaConfig", "FourierFTConfig", "FourierFTModel", + "GraloraConfig", + "GraloraModel", "HRAConfig", "HRAModel", "IA3Config", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 3bf53d7da9..364bbb8fb2 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -20,6 +20,7 @@ from .cpt import CPTConfig, CPTEmbedding from .delora import DeloraConfig, DeloraModel from .fourierft import FourierFTConfig, FourierFTModel +from .gralora import GraloraConfig, GraloraModel from .hra import HRAConfig, HRAModel from .ia3 import IA3Config, IA3Model from .ln_tuning import LNTuningConfig, LNTuningModel @@ -74,6 +75,8 @@ "EvaConfig", "FourierFTConfig", "FourierFTModel", + "GraloraConfig", + "GraloraModel", "HRAConfig", "HRAModel", "IA3Config", diff --git a/src/peft/tuners/gralora/__init__.py b/src/peft/tuners/gralora/__init__.py new file mode 100644 index 0000000000..830e0a477c --- /dev/null +++ b/src/peft/tuners/gralora/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from peft.utils import register_peft_method + +from .config import GraloraConfig +from .layer import GraloraLayer +from .model import GraloraModel + + +__all__ = ["GraloraConfig", "GraloraLayer", "GraloraModel"] + +register_peft_method(name="gralora", config_cls=GraloraConfig, model_cls=GraloraModel) diff --git a/src/peft/tuners/gralora/config.py b/src/peft/tuners/gralora/config.py new file mode 100644 index 0000000000..1458bca3e2 --- /dev/null +++ b/src/peft/tuners/gralora/config.py @@ -0,0 +1,183 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class GraloraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`GraloraModel`]. + + Args: + r (`int`): + GraLoRA attention dimension determines the rank of the GraLoRA adapter. The total parameter count of the + GraLoRA adapter is same as LoRA with same rank r, while the expressivitiy is multiplied by gralora_k. + hybrid_r (`int`): + Hybrid GraLoRA rank determines the rank allocated to vanilla LoRA method when using Hybrid GraLoRA method. 
+ Hybrid GraLoRA, a combination of GraLoRA and vanilla LoRA, becomes available when hybrid_r > 0. The + parameter count of the GraLoRA adapter is determined by r + hybrid_r. + target_modules (`Union[List[str], str]`): + List of module names or regex expression of the module names to replace with GraLoRA. For example, ['q', + 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. This can also be a wildcard 'all-linear' + which matches all linear/Conv1D (if the model is a PreTrainedModel, the output layer excluded). If not + specified, modules will be chosen according to the model architecture. If the architecture is not known, + an error will be raised -- in this case, you should specify the target modules manually. To avoid + targeting any modules (because you want to apply `target_parameters`), set `target_modules=[]`. + gralora_alpha (`int`): GraLoRA alpha. + GraLoRA alpha is the scaling factor for the GraLoRA adapter. The scale becomes gralora_alpha / (r + hybrid_r). + gralora_dropout (`float`): + GraLoRA dropout is the dropout probability for the GraLoRA adapter. It is used to prevent overfitting and + improve the generalization of the GraLoRA adapter. + gralora_k (`int`): + GraLoRA k determines the number of subblocks in the GraLoRA adapter. The rank r must be divisible by + gralora_k for the GraLoRA adapter to be valid. The total parameter count is preserved regardless of + gralora_k. The effective rank of the GraLoRA adapter is multiplied by gralora_k, while the rank of each + subblock is divided by gralora_k. gralora_k=2 is recommended for rank 32 or lower, and gralora_k=4 is + recommended for rank 64 or higher. + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type for gralora. Can be 'none', 'all' or 'gralora_only'. If 'all' or 'gralora_only', the + corresponding biases will be updated during training. Be aware that this means that, even when disabling + the adapters, the model will not produce the same output as the base model would have without adaptation. + init_weights (`bool`): + Whether to initialize the weights of the GraLoRA layers with their default initialization. Don't change + this setting, except if you know exactly what you're doing. + layers_to_transform (`Union[List[int], int]`): + The layer indexes to transform. If this argument is specified, PEFT will transform only the layer indexes + that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at + this index. This only works when target_modules is a list of str. + layers_pattern (`Optional[Union[List[str], str]]`): + The layer pattern name, used only if `layers_to_transform` is different from None and if the layer pattern is + not in the common layers pattern. This only works when target_modules is a list of str. This should target + the `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`. + """ + + r: int = field( + default=32, + metadata={ + "help": ( + "GraLoRA attention dimension determines the rank of the GraLoRA adapter. " + "The total parameter count of the GraLoRA adapter is the same as LoRA with the same rank r, while the expressivity is multiplied by gralora_k." + ) + }, + ) + hybrid_r: int = field( + default=0, + metadata={ + "help": ( + "hybrid_r is the rank allocated to vanilla LoRA method when using Hybrid GraLoRA method. 
" + "Hybrid GraLoRA, a combination of GraLoRA and vanilla LoRA, becomes available when hybrid_r > 0. " + "r + hybrid_r determines the parameter count of the GraLoRA adapter." + ) + }, + ) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with LoRA. " + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. " + "This can also be a wildcard 'all-linear' which matches all linear/Conv1D " + "(if the model is a PreTrainedModel, the output layer excluded). " + "If not specified, modules will be chosen according to the model architecture, If the architecture is " + "not known, an error will be raised -- in this case, you should specify the target modules manually. " + "To avoid targeting any modules (because you want to apply `target_parameters`), set " + "`target_modules=[]`." + ) + }, + ) + gralora_alpha: int = field( + default=64, + metadata={ + "help": ( + "gralora alpha is the scaling factor for the GraLoRA adapter. " + "Scale becomes gralora_alpha / (r + hybrid_r). " + ) + }, + ) + gralora_dropout: float = field(default=0.0, metadata={"help": "gralora dropout"}) + gralora_k: int = field( + default=2, + metadata={ + "help": ( + "gralora_k determines the number of subblocks in the GraLoRA adapter. " + "The rank r must be divisible by gralora_k for the GraLoRA adapter to be valid. " + "The total parameter count is preserved regardles of gralora_k. " + "The entire rank of the GraLoRA adapter is increased by gralora_k, while the rank of each subblock is reduced by gralora_k. " + "gralora_k=2 is recommended for rank 32 or lower, and gralora_k=4 is recommended for rank 64 or higher. " + ) + }, + ) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: str = field( + default="none", metadata={"help": "Bias type for gralora. Can be 'none', 'all' or 'gralora_only'"} + ) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": ( + "List of modules apart from gralora layers to be set as trainable and saved in the final checkpoint. For" + " example, in Sequence Classification or Token Classification tasks, the final layer" + " `classifier/score` are randomly initialized and as such need to be trainable and saved." + ) + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the GraLoRA layers with their default initialization. " + "Don't change this setting, except if you know exactly what you're doing." + ) + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": ( + "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. " + "If a single integer is passed, PEFT will transform only the layer at this index. " + "This only works when target_modules is a list of str." + ) + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern. " + "This only works when target_modules is a list of str. This should target the `nn.ModuleList` of the " + "model, which is often called `'layers'` or `'h'`." 
+ ) + }, + ) + + def __post_init__(self): + super().__post_init__() + self.peft_type = PeftType.GRALORA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + if self.r % self.gralora_k != 0: + raise ValueError(f"r should be divisible by gralora_k, but got {self.r} and {self.gralora_k}") diff --git a/src/peft/tuners/gralora/layer.py b/src/peft/tuners/gralora/layer.py new file mode 100644 index 0000000000..d6f78665f0 --- /dev/null +++ b/src/peft/tuners/gralora/layer.py @@ -0,0 +1,394 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils.other import transpose + + +class GraloraLayer(BaseTunerLayer): + # List all names of layers that may contain adapter weight + adapter_layer_names = ("gralora_A", "gralora_B", "gralora_A_general", "gralora_B_general") + other_param_names = ("r", "hybrid_r", "gralora_alpha", "scaling", "gralora_dropout") + + def __init__(self, base_layer: nn.Module, **kwargs): + self.base_layer = base_layer + self.r = {} + self.gralora_alpha = {} + self.gralora_k = {} + self.hybrid_r = {} + self.scaling = {} + self.gralora_dropout = nn.ModuleDict({}) + + self.gralora_A = nn.ParameterDict({}) + self.gralora_B = nn.ParameterDict({}) + self.gralora_A_general = nn.ModuleDict({}) + self.gralora_B_general = nn.ModuleDict({}) + + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + else: + raise NotImplementedError(f"Unsupported layer type {type(base_layer)}") + + self.in_features = in_features + self.out_features = out_features + self.kwargs = kwargs + + def update_layer( + self, + adapter_name, + module_name, + r, + gralora_alpha, + gralora_dropout, + gralora_k: int = 2, + hybrid_r: int = 0, + init_weights: bool = True, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + elif hybrid_r < 0: + raise ValueError(f"`hybrid_r` should be a non-negative integer value but the value passed is {hybrid_r}") + + self.r[adapter_name] = r + self.gralora_alpha[adapter_name] = gralora_alpha + self.gralora_k[adapter_name] = gralora_k + self.hybrid_r[adapter_name] = hybrid_r + + if gralora_dropout > 0.0: + gralora_dropout_layer = nn.Dropout(p=gralora_dropout) + else: + gralora_dropout_layer = nn.Identity() + + self.gralora_dropout.update(nn.ModuleDict({adapter_name: gralora_dropout_layer})) + + # Actual trainable parameters + if self.in_features % 
gralora_k != 0: + raise ValueError( + f"in_features should be divisible by gralora_k, but got {self.in_features} and {gralora_k}" + ) + if self.out_features % gralora_k != 0: + raise ValueError( + f"out_features should be divisible by gralora_k, but got {self.out_features} and {gralora_k}" + ) + subblock_in_features = self.in_features // gralora_k + subblock_out_features = self.out_features // gralora_k + + # gralora_r is the rank allocated to GraLoRA method; hybrid_r is the rank allocated to vanilla LoRA + gralora_r = r + + gralora_A = [] + gralora_B = [] + for _ in range(gralora_k): + new_A = nn.Parameter(torch.empty(gralora_r, subblock_in_features)) + new_B = nn.Parameter(torch.empty(subblock_out_features, gralora_r)) + if init_weights: + # Initialize to identity: A is random, B is zero + nn.init.kaiming_uniform_(new_A, a=math.sqrt(5)) + nn.init.zeros_(new_B) + else: + # Initialize to random: both A and B are random (for testing) + nn.init.kaiming_uniform_(new_A, a=math.sqrt(5)) + nn.init.kaiming_uniform_(new_B, a=math.sqrt(5)) + gralora_A.append(new_A) + gralora_B.append(new_B) + # stack A and B and transpose to get the final shape + gralora_A = torch.stack(tuple(gralora_A), dim=0) # [N, gralora_r, in_features//N] + gralora_A = gralora_A.transpose(1, 2).contiguous() # [N, in_features//N, gralora_r] + + gralora_B = torch.stack(tuple(gralora_B), dim=0) # [N, out_features//N, gralora_r] + gralora_B = gralora_B.transpose(1, 2).contiguous() # [N, gralora_r, out_features//N] + + if hybrid_r > 0: + general_gralora_A = nn.Linear(self.in_features, hybrid_r, bias=False) + general_gralora_B = nn.Linear(hybrid_r, self.out_features, bias=False) + if init_weights: + # Initialize to identity: A is random, B is zero + nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5)) + nn.init.zeros_(general_gralora_B.weight) + else: + # Initialize to random: both A and B are random (for testing) + nn.init.kaiming_uniform_(general_gralora_A.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(general_gralora_B.weight, a=math.sqrt(5)) + else: + general_gralora_A = nn.Identity() + general_gralora_B = nn.Identity() + + self.gralora_A[adapter_name] = gralora_A + self.gralora_B[adapter_name] = gralora_B + self.gralora_A_general[adapter_name] = general_gralora_A + self.gralora_B_general[adapter_name] = general_gralora_B + + self.module_name = module_name + + self.scaling[adapter_name] = gralora_alpha / (gralora_r + hybrid_r) + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + +class Linear(nn.Linear, GraloraLayer): + # Gralora implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + module_name, + r: int = 0, + gralora_alpha: int = 1, + gralora_dropout: float = 0.0, + gralora_k: int = 2, + hybrid_r: int = 0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_weights: bool = True, + **kwargs, + ) -> None: + # this gets the init from nn.Linear's super perspective, i.e. 
nn.Module.__init__, which should always be called + super(nn.Linear, self).__init__() + GraloraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, module_name, r, gralora_alpha, gralora_dropout, gralora_k, hybrid_r, init_weights + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + from peft.tuners.tuners_utils import check_adapters_to_merge + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.gralora_A.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter) + orig_weights += delta_weight + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + delta_weight = self.get_delta_weight(active_adapter) + base_layer.weight.data += delta_weight + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.gralora_A.keys(): + delta_weight = self.get_delta_weight(active_adapter) + self.get_base_layer().weight.data -= delta_weight + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for GraLoRA adapter. + + GraLoRA applies block-wise low-rank adaptation with information exchange. This method computes the equivalent + weight matrix that would be added to the base weight during merge. 
+ + Args: + adapter (str): The name of the adapter + + Returns: + torch.Tensor: The delta weight matrix with shape [out_features, in_features] + """ + gralora_A = self.gralora_A[adapter] # [N, in_features//N, rank] + gralora_B = self.gralora_B[adapter] # [N, rank, out_features//N] + gralora_A_general = self.gralora_A_general[adapter] + gralora_B_general = self.gralora_B_general[adapter] + + device = gralora_A.device + dtype = gralora_A.dtype + + gralora_k = self.gralora_k[adapter] + hybrid_r = self.hybrid_r[adapter] + r = self.r[adapter] + + # Handle CPU fp16/bf16 casting + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + if cast_to_fp32: + gralora_A = gralora_A.float() + gralora_B = gralora_B.float() + + # Get dimensions + in_features = self.in_features + out_features = self.out_features + gralora_rank = r + subblock_gralora_rank = gralora_rank // gralora_k + + # scatter gralora_A to get the scattered weight matrix + l_indices = torch.arange(in_features, device=device) + n_indices = l_indices // (in_features // gralora_k) + i_indices = l_indices % (in_features // gralora_k) + gralora_A_scattered = torch.zeros( + in_features, gralora_k, gralora_rank, device=device, dtype=torch.float32 if cast_to_fp32 else dtype + ) + gralora_A_scattered.scatter_( + 1, + n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank), + gralora_A[n_indices, i_indices, :].unsqueeze(1), + ) + + # compute the delta weight + delta_weight = ( + torch.einsum( + "ikr, kro -> iko", + gralora_A_scattered.view(in_features, gralora_k, gralora_k, subblock_gralora_rank) + .permute(0, 2, 1, 3) + .reshape(in_features, gralora_k, gralora_rank), + gralora_B, + ) + .reshape(in_features, out_features) + .T + ) + + # Add hybrid LoRA component if present + if hybrid_r > 0: + weight_A_general = gralora_A_general.weight # [hybrid_r, in_features] + weight_B_general = gralora_B_general.weight # [out_features, hybrid_r] + + if cast_to_fp32: + weight_A_general = weight_A_general.float() + weight_B_general = weight_B_general.float() + + # Compute delta for hybrid part: [out_features, hybrid_r] @ [hybrid_r, in_features] + delta_weight += weight_B_general @ weight_A_general + + # Apply scaling and transpose if needed + delta_weight = transpose(delta_weight, self.fan_in_fan_out) * self.scaling[adapter] + + # Cast back if needed + if cast_to_fp32: + delta_weight = delta_weight.to(dtype=dtype) + + return delta_weight + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + + # Handle 2D input: [batch, features] -> [batch, 1, features] + # This is common for MLPs and other non-sequence models + x_is_2d = x.ndim == 2 + if x_is_2d: + x = x.unsqueeze(1) # [B, F] -> [B, 1, F] + + for active_adapter in self.active_adapters: + if active_adapter not in self.gralora_A.keys(): + continue + gralora_A = self.gralora_A[active_adapter] + gralora_B = self.gralora_B[active_adapter] + + gralora_A_general = self.gralora_A_general[active_adapter] + gralora_B_general = self.gralora_B_general[active_adapter] + + r = self.r[active_adapter] + gralora_rank = r + gralora_k = self.gralora_k[active_adapter] + hybrid_r = self.hybrid_r[active_adapter] + + dropout = self.gralora_dropout[active_adapter] + scaling = 
self.scaling[active_adapter] + + gralora_dtype = gralora_A.dtype + + B, L, in_features = x.shape + N = gralora_k + subblock_gralora_rank = gralora_rank // N + + output = torch.einsum( + "bljr, jro -> bljo", + torch.einsum( + "blni, nir -> blnr", + dropout(x.to(gralora_dtype)).view(B, L, N, in_features // N), + gralora_A, + ) + .view(B, L, N, N, subblock_gralora_rank) + .permute(0, 1, 3, 2, 4) + .reshape(B, L, N, N * subblock_gralora_rank), + gralora_B, + ).reshape(B, L, -1) + + # Squeeze back to 2D if input was 2D + if x_is_2d: + output = output.squeeze(1) # [B, 1, F] -> [B, F] + + result += scaling * output.to(torch_result_dtype) + if hybrid_r > 0: + hybrid_output = gralora_B_general(gralora_A_general(dropout(x.to(gralora_dtype)))) + if x_is_2d: + hybrid_output = hybrid_output.squeeze(1) + result += scaling * hybrid_output.to(torch_result_dtype) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "gralora." + rep diff --git a/src/peft/tuners/gralora/model.py b/src/peft/tuners/gralora/model.py new file mode 100644 index 0000000000..23a25d4c9c --- /dev/null +++ b/src/peft/tuners/gralora/model.py @@ -0,0 +1,142 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, +) + +from .layer import GraloraLayer, Linear + + +class GraloraModel(BaseTuner): + """ + Creates a Granular Low-Rank Adaptation (GraLoRA) model from a pretrained transformers model. + + Args: + model ([`~transformers.PreTrainedModel`]): The model to be adapted. + config ([`GraloraConfig`]): The configuration of the GraLoRA model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The GraLoRA model. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import GraloraConfig, get_peft_model + + >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> config = GraloraConfig(r=128) + >>> model = get_peft_model(base_model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`GraloraConfig`]): The configuration of the GraLoRA model.
+ """ + + # The unique prefix for GraLoRA method + prefix: str = "gralora_" + # The class of tuner layer for GraLoRA method + tuner_layer_cls = GraloraLayer + + target_module_mapping = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING + + def _create_and_replace( + self, + gralora_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + r = gralora_config.r + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "r": r, + "gralora_alpha": gralora_config.gralora_alpha, + "gralora_dropout": gralora_config.gralora_dropout, + "gralora_k": gralora_config.gralora_k, + "fan_in_fan_out": gralora_config.fan_in_fan_out, + "hybrid_r": gralora_config.hybrid_r, + "init_weights": gralora_config.init_weights, + } + kwargs["bias"] = bias + + if isinstance(target, Linear): + target.update_layer( + adapter_name, + current_key, + r, + gralora_config.gralora_alpha, + gralora_config.gralora_dropout, + gralora_config.gralora_k, + gralora_config.hybrid_r, + gralora_config.init_weights, + ) + else: + new_module = self._create_new_module(gralora_config, adapter_name, target, current_key, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _create_new_module(gralora_config, adapter_name, target, module_name, **kwargs): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = gralora_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + kwargs["is_target_conv_1d_layer"] = True + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = gralora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`." 
+ ) + new_module = Linear( + target, + adapter_name, + module_name, + **kwargs, + ) + + return new_module diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 8f55a8f2b8..ddac0c8c70 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -48,6 +48,7 @@ class PeftType(str, enum.Enum): - WAVEFT - OSF - DELORA + - GRALORA """ PROMPT_TUNING = "PROMPT_TUNING" @@ -80,6 +81,7 @@ class PeftType(str, enum.Enum): WAVEFT = "WAVEFT" OSF = "OSF" DELORA = "DELORA" + GRALORA = "GRALORA" class TaskType(str, enum.Enum): diff --git a/tests/test_config.py b/tests/test_config.py index 9277d3bb68..5cb7523d84 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -27,6 +27,7 @@ BoneConfig, C3AConfig, FourierFTConfig, + GraloraConfig, HRAConfig, IA3Config, LNTuningConfig, @@ -64,6 +65,7 @@ (BoneConfig, {}), (C3AConfig, {}), (FourierFTConfig, {}), + (GraloraConfig, {}), (HRAConfig, {}), (IA3Config, {}), (LNTuningConfig, {}), diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index ed83db98cb..8d9820a195 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -38,6 +38,7 @@ C3AConfig, DeloraConfig, FourierFTConfig, + GraloraConfig, HRAConfig, IA3Config, LNTuningConfig, @@ -666,6 +667,37 @@ "init_weights": True, }, ), + ########### + # GraLoRA # + ########### + ("Vanilla MLP 1 GraLoRA", "MLP", GraloraConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 3 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin1"]}), + ("Vanilla MLP 4 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0", "lin1"]}), + ( + "Vanilla MLP 5 GraLoRA", + "MLP", + GraloraConfig, + {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, + ), + ( + "Vanilla MLP 6 GraLoRA", + "MLP", + GraloraConfig, + {"target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"]}, + ), + ( + "Vanilla MLP 7 Hybrid GraLoRA", + "MLP", + GraloraConfig, + {"target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"], "hybrid_r": 4}, + ), + ( + "Embedding + transformers Conv1D 1 GraLoRA", + "EmbConv1D", + GraloraConfig, + {"target_modules": ["conv1d"], "gralora_k": 1}, + ), ########## # VBLoRA # ########## @@ -979,6 +1011,20 @@ {"n_frequency": 10, "target_modules": ["lin0"]}, {"n_frequency": 10, "target_modules": ["lin1"]}, ), + ( + "GraLoRA Same", + "gralora", + GraloraConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin0"], "init_weights": False}, + ), + ( + "GraLoRA Different", + "gralora", + GraloraConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin1"], "init_weights": False}, + ), ( "SHiRA Same", "shira", @@ -1165,6 +1211,7 @@ VeraConfig: "vera_lambda_", RandLoraConfig: "randlora_", FourierFTConfig: "fourierft_", + GraloraConfig: "gralora_", C3AConfig: "c3a_", HRAConfig: "hra_", ShiraConfig: "shira_", @@ -3405,6 +3452,24 @@ def test_dora_save_and_load_remapping(self): for k in state_dict: assert torch.allclose(state_dict[k], state_dict_loaded[k]) + def test_gralora_and_hybrid_gralora_parameter_count(self): + # Here we test the parameter count of GraLoRA is preserved + # when rank r + hybrid_r is the same regardless of the value of gralora_k. 
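+        # GraLoRA stores k blocks of shape [in/k, r] for A and [r, out/k] for B, i.e. r * (in + out)
+        # parameters per adapted layer, plus hybrid_r * (in + out) for the optional vanilla LoRA part,
+        # so the total depends only on r + hybrid_r and not on gralora_k.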
+ model1 = MLP() + config1 = GraloraConfig(target_modules=["lin0"], r=12, gralora_k=2, hybrid_r=0) + model1 = get_peft_model(model1, config1) + model2 = MLP() + config2 = GraloraConfig(target_modules=["lin0"], r=10, gralora_k=2, hybrid_r=2) + model2 = get_peft_model(model2, config2) + model3 = MLP() + config3 = GraloraConfig(target_modules=["lin0"], r=10, gralora_k=5, hybrid_r=2) + model3 = get_peft_model(model3, config3) + trainable_params1, all_params1 = model1.get_nb_trainable_parameters() + trainable_params2, all_params2 = model2.get_nb_trainable_parameters() + trainable_params3, all_params3 = model3.get_nb_trainable_parameters() + assert trainable_params1 == trainable_params2 == trainable_params3 + assert all_params1 == all_params2 == all_params3 + @pytest.mark.parametrize("with_forward_call", [False, True]) def test_mha_gradients_set_correctly(self, with_forward_call): # check for this bug: https://github.com/huggingface/peft/issues/761#issuecomment-1893804738 diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 5b23fa74e2..acb0d9c7d2 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -34,6 +34,7 @@ CPTConfig, DeloraConfig, FourierFTConfig, + GraloraConfig, HRAConfig, IA3Config, LoraConfig, @@ -137,6 +138,30 @@ "target_modules": None, }, ), + ( + GraloraConfig, + { + "task_type": "CAUSAL_LM", + "r": 8, + "gralora_alpha": 16, + "target_modules": None, + "gralora_dropout": 0.05, + "gralora_k": 2, + "hybrid_r": 0, + }, + ), + ( + GraloraConfig, + { + "task_type": "CAUSAL_LM", + "r": 16, + "gralora_alpha": 32, + "target_modules": None, + "gralora_dropout": 0.05, + "gralora_k": 4, + "hybrid_r": 4, + }, + ), ( HRAConfig, { diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index c4e38f934b..42b12e66e0 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -24,6 +24,7 @@ C3AConfig, DeloraConfig, FourierFTConfig, + GraloraConfig, HRAConfig, IA3Config, LoraConfig, @@ -100,6 +101,13 @@ "task_type": "SEQ_2_SEQ_LM", }, ), + ( + GraloraConfig, + { + "target_modules": None, + "task_type": "SEQ_2_SEQ_LM", + }, + ), ( HRAConfig, { diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index a5377827f4..6bfd254ec4 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -22,6 +22,7 @@ C3AConfig, DeloraConfig, FourierFTConfig, + GraloraConfig, HRAConfig, IA3Config, LoraConfig, @@ -98,6 +99,13 @@ "target_modules": None, }, ), + ( + GraloraConfig, + { + "task_type": "FEATURE_EXTRACTION", + "target_modules": None, + }, + ), ( HRAConfig, { diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 3475247cd8..cbc41e5671 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -38,6 +38,7 @@ C3AConfig, DeloraConfig, EvaConfig, + GraloraConfig, IA3Config, LoftQConfig, LoKrConfig, @@ -2157,6 +2158,56 @@ def test_init_weights_false_shifts_output(self, data): assert not torch.allclose(y_base, y_peft, atol=1e-6, rtol=1e-6) +class TestGraLoRAInitialization: + """Basic sanity tests for the GraLoRA tuner.""" + + torch_device = infer_device() + + def get_model(self, bias=True): + class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 30, bias=bias) + self.lin1 = nn.Linear(30, 2, bias=bias) + + def forward(self, X): + X = self.lin0(X) + X = self.lin1(X) + return X + + return 
MLP(bias=bias).to(self.torch_device).eval() + + @pytest.fixture + def data(self): + torch.manual_seed(0) + return torch.randn(4, 10, device=self.torch_device) + + def test_gralora_with_incompatible_gralora_k_and_r_raises(self): + model = self.get_model() + r = 6 + gralora_k = 4 + # msg = f"r should be divisible by gralora_k, but got {config.r} and {config.gralora_k}" + msg = f"r should be divisible by gralora_k, but got {r} and {gralora_k}" + with pytest.raises(ValueError, match=re.escape(msg)): + GraloraConfig(target_modules=["lin0"], r=r, gralora_k=gralora_k) + + def test_gralora_with_incompatible_gralora_k_and_in_features_raises(self): + model = self.get_model() + config = GraloraConfig(target_modules=["lin0"], r=6, gralora_k=3) + msg = f"in_features should be divisible by gralora_k, but got {model.lin0.in_features} and {config.gralora_k}" + with pytest.raises(ValueError, match=re.escape(msg)): + get_peft_model(model, config) + + def test_gralora_with_incompatible_gralora_k_and_out_features_raises(self): + model = self.get_model() + config = GraloraConfig(target_modules=["lin1"], r=6, gralora_k=3) + msg = ( + f"out_features should be divisible by gralora_k, but got {model.lin1.out_features} and {config.gralora_k}" + ) + with pytest.raises(ValueError, match=re.escape(msg)): + get_peft_model(model, config) + + class TestNoInfiniteRecursionDeepspeed: # see #1892 for details classes = [ diff --git a/tests/test_seq_classifier.py b/tests/test_seq_classifier.py index 03869c3a7a..bee83a879a 100644 --- a/tests/test_seq_classifier.py +++ b/tests/test_seq_classifier.py @@ -22,6 +22,7 @@ C3AConfig, DeloraConfig, FourierFTConfig, + GraloraConfig, HRAConfig, IA3Config, LoraConfig, @@ -99,6 +100,13 @@ "target_modules": None, }, ), + ( + GraloraConfig, + { + "task_type": "SEQ_CLS", + "target_modules": None, + }, + ), ( HRAConfig, {