Commit 4fb96ab

chore: ruff fix
1 parent a88a388 commit 4fb96ab

5 files changed: +1075 -1059 lines changed


invokeai/backend/patches/layers/utils.py

Lines changed: 3 additions & 3 deletions
@@ -35,14 +35,14 @@ def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseL
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}")



-
 def swap_shift_scale_for_linear_weight(weight: torch.Tensor) -> torch.Tensor:
     """Swap shift/scale for given linear layer back and forth"""
     # In SD3 and Flux implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
     # while in diffusers it split into scale, shift. This will flip them around
-    chunk1, chunk2 = weight.chunk(2, dim=0)
+    chunk1, chunk2 = weight.chunk(2, dim=0)
     return torch.cat([chunk2, chunk1], dim=0)

+
 def decomposite_weight_matric_with_rank(
     delta: torch.Tensor,
     rank: int,
@@ -56,7 +56,7 @@ def decomposite_weight_matric_with_rank(
     S_r = S[:rank]
     V_r = V[:, :rank]

-    S_sqrt = torch.sqrt(S_r + epsilon) # regularization
+    S_sqrt = torch.sqrt(S_r + epsilon)  # regularization

     up = torch.matmul(U_r, torch.diag(S_sqrt))
     down = torch.matmul(torch.diag(S_sqrt), V_r.T)
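For reference, the hunk above shows only the tail of decomposite_weight_matric_with_rank. Below is a minimal sketch of the kind of truncated-SVD factorization it appears to perform; the SVD call, the full_matrices choice, and the default epsilon here are assumptions, not the repository's exact code:

```python
import torch


def decompose_with_rank_sketch(delta: torch.Tensor, rank: int, epsilon: float = 1e-8):
    """Hypothetical stand-in: factor `delta` into low-rank `up @ down` via truncated SVD."""
    # torch.linalg.svd returns U, S, Vh; transpose Vh so the indexing matches V_r = V[:, :rank] above.
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    V = Vh.T

    U_r = U[:, :rank]
    S_r = S[:rank]
    V_r = V[:, :rank]

    S_sqrt = torch.sqrt(S_r + epsilon)  # small regularization term, as in the diff
    up = torch.matmul(U_r, torch.diag(S_sqrt))
    down = torch.matmul(torch.diag(S_sqrt), V_r.T)
    return up, down
```

Splitting sqrt(S) across both factors keeps up and down at comparable magnitude, which is the usual convention for LoRA up/down weights.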

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 24 additions & 19 deletions
@@ -2,10 +2,14 @@

 import torch

-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
+from invokeai.backend.patches.layers.utils import (
+    any_lora_layer_from_state_dict,
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

@@ -39,46 +43,47 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te

     return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)

+
 def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
-    '''Approximate given diffusers AdaLN loRA layer in our Flux model'''
+    """Approximate given diffusers AdaLN loRA layer in our Flux model"""

-    if not "lora_up.weight" in state_dict:
+    if "lora_up.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
-
-    if not "lora_down.weight" in state_dict:
+
+    if "lora_down.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
-
-    up = state_dict.pop('lora_up.weight')
-    down = state_dict.pop('lora_down.weight')

-    # layer-patcher upcast things to f32,
+    up = state_dict.pop("lora_up.weight")
+    down = state_dict.pop("lora_down.weight")
+
+    # layer-patcher upcast things to f32,
     # we want to maintain a better precison for this one
     dtype = torch.float32

     device = up.device
     up_shape = up.shape
     down_shape = down.shape
-
+
     # desired low rank
     rank = up_shape[1]

     # up scaling for more precise
     up = up.to(torch.float32)
     down = down.to(torch.float32)

-    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
+    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)

     # swap to our linear format
     swapped = swap_shift_scale_for_linear_weight(weight)

     _up, _down = decomposite_weight_matric_with_rank(swapped, rank)

-    assert(_up.shape == up_shape)
-    assert(_down.shape == down_shape)
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape

     # down scaling to original dtype, device
-    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
-    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+    state_dict["lora_up.weight"] = _up.to(dtype).to(device=device)
+    state_dict["lora_down.weight"] = _down.to(dtype).to(device=device)

     return LoRALayer.from_state_dict_values(state_dict)

@@ -148,7 +153,7 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         src_layer_dict = grouped_state_dict.pop(src_key)
         values = get_lora_layer_values(src_layer_dict)
         layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)
-
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -291,8 +296,8 @@ def add_qkv_lora_layer_if_present(
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
     add_adaLN_lora_layer_if_present(
-        'norm_out.linear',
-        'final_layer.adaLN_modulation.1',
+        "norm_out.linear",
+        "final_layer.adaLN_modulation.1",
     )

     # Assert that all keys were processed.
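The approximation routine above composes the diffusers up/down product, swaps the shift/scale halves into the Flux layout, and re-factors the swapped matrix at the original rank. The following toy round trip exercises the same two helpers from utils.py; the shapes (a 2*dim x dim AdaLN projection with rank 2) are chosen purely for illustration:

```python
import torch

from invokeai.backend.patches.layers.utils import (
    decomposite_weight_matric_with_rank,
    swap_shift_scale_for_linear_weight,
)

# Illustrative shapes only: an AdaLN projection of width 8 with LoRA rank 2.
rank, dim = 2, 8
up = torch.randn(2 * dim, rank).double()
down = torch.randn(rank, dim).double()

# Same steps as the converter: compose the delta, flip shift/scale, re-factor at the original rank.
weight = up @ down
swapped = swap_shift_scale_for_linear_weight(weight)
new_up, new_down = decomposite_weight_matric_with_rank(swapped, rank)

assert new_up.shape == up.shape
assert new_down.shape == down.shape
# Swapping the re-composed product back recovers the original delta (up to SVD/epsilon error).
assert torch.allclose(swap_shift_scale_for_linear_weight(new_up @ new_down), weight)
```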

tests/backend/patches/layers/test_layer_utils.py

Lines changed: 11 additions & 9 deletions
@@ -1,6 +1,9 @@
 import torch

-from invokeai.backend.patches.layers.utils import decomposite_weight_matric_with_rank, swap_shift_scale_for_linear_weight
+from invokeai.backend.patches.layers.utils import (
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)


 def test_swap_shift_scale_for_linear_weight():
@@ -9,38 +12,37 @@ def test_swap_shift_scale_for_linear_weight():
     expected = torch.Tensor([2, 1])

     swapped = swap_shift_scale_for_linear_weight(original)
-    assert(torch.allclose(expected, swapped))
+    assert torch.allclose(expected, swapped)

-    size= (3, 4)
+    size = (3, 4)
     first = torch.randn(size)
     second = torch.randn(size)

     original = torch.concat([first, second])
     expected = torch.concat([second, first])

     swapped = swap_shift_scale_for_linear_weight(original)
-    assert(torch.allclose(expected, swapped))
+    assert torch.allclose(expected, swapped)

     # call this twice will reconstruct the original
     reconstructed = swap_shift_scale_for_linear_weight(swapped)
-    assert(torch.allclose(reconstructed, original))
+    assert torch.allclose(reconstructed, original)
+

 def test_decomposite_weight_matric_with_rank():
     """Test that decompsition of given matrix into 2 low rank matrices work"""
     input_dim = 1024
     output_dim = 1024
     rank = 8  # Low rank

-
     A = torch.randn(input_dim, rank).double()
     B = torch.randn(rank, output_dim).double()
     W0 = A @ B

     C, D = decomposite_weight_matric_with_rank(W0, rank)
     R = C @ D

-    assert(C.shape == A.shape)
-    assert(D.shape == B.shape)
+    assert C.shape == A.shape
+    assert D.shape == B.shape

     assert torch.allclose(W0, R)
-
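The final allclose in test_decomposite_weight_matric_with_rank passes because W0 is built from rank-8 factors, so a rank-8 truncated decomposition can reproduce it essentially exactly; for a generic full-rank matrix the helper can only return the best rank-8 approximation. A small illustration of that distinction (the matrix sizes here are arbitrary):

```python
import torch

from invokeai.backend.patches.layers.utils import decomposite_weight_matric_with_rank

torch.manual_seed(0)

# Rank-limited input, as in the test above: reconstruction is numerically exact.
A = torch.randn(64, 8).double()
B = torch.randn(8, 64).double()
low_rank = A @ B
up, down = decomposite_weight_matric_with_rank(low_rank, 8)
assert torch.allclose(low_rank, up @ down)

# Full-rank input: a rank-8 factorization is only the best rank-8 approximation, so the same check fails.
full_rank = torch.randn(64, 64).double()
up, down = decomposite_weight_matric_with_rank(full_rank, 8)
assert not torch.allclose(full_rank, up @ down)
```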
