Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/peft/tuners/oft/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

import packaging.version

from peft.config import PeftConfig
from peft.utils import PeftType

Expand Down Expand Up @@ -193,4 +196,18 @@ def check_kwargs(cls, **kwargs):
"with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. "
"Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights."
)
if kwargs.get("use_cayley_neumann", False):
peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version
# remove commit hash, if present
peft_version = peft_version.partition("@")[0]
parsed_version = packaging.version.Version(peft_version)
min_version = packaging.version.Version("0.18.0")
# note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version
if parsed_version < min_version:
msg = (
"The cayley-neumann parameterization has been slightly changed to be more numerically stable in "
"PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, "
"downgrade PEFT to version 0.17.0 to use the old parameterization."
)
warnings.warn(msg)
return super().check_kwargs(**kwargs)
20 changes: 16 additions & 4 deletions src/peft/tuners/oft/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def __init__(
self.use_cayley_neumann = use_cayley_neumann
self.num_cayley_neumann_terms = num_cayley_neumann_terms
# Create indices for upper triangle (excluding diagonal)
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
rows, cols = torch.triu_indices(block_size, block_size, 1)
self.register_buffer("rows", rows, persistent=False)
self.register_buffer("cols", cols, persistent=False)

def _pytorch_skew_symmetric(self, vec, block_size):
batch_size = vec.shape[0]
Expand Down Expand Up @@ -139,9 +141,11 @@ def _cayley_batch(
R.add_(Q_squared, alpha=2.0)

Q_power = Q_squared
for i in range(3, num_neumann_terms):
for _ in range(3, num_neumann_terms - 1):
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power, alpha=2.0)
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power)
else:
id_mat = (
torch.eye(Q_skew.shape[-1], device=Q_skew.device)
Expand Down Expand Up @@ -621,9 +625,13 @@ def unmerge(self) -> None:
if active_adapter in self.oft_R.keys():
oft_mat = self.get_delta_weight(active_adapter)

previous_dtype = oft_mat.dtype
if previous_dtype != torch.float32:
oft_mat = oft_mat.to(torch.float32)

orig_weights = self.get_base_layer().weight.data
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)

base_layer.weight.data = orig_weights.to(orig_dtype)
Expand Down Expand Up @@ -855,13 +863,17 @@ def unmerge(self) -> None:
if active_adapter in self.oft_R.keys():
oft_mat = self.get_delta_weight(active_adapter)

previous_dtype = oft_mat.dtype
if previous_dtype != torch.float32:
oft_mat = oft_mat.to(torch.float32)

orig_weights = self.get_base_layer().weight.data.clone()
orig_weights = orig_weights.view(
self.out_features,
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
)
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights.view(
self.out_features,
Expand Down
Loading