|
 
 import torch
 
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
-from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict, swap_shift_scale_for_linear_weight, decomposite_weight_matric_with_rank
+from invokeai.backend.patches.layers.utils import (
+    any_lora_layer_from_state_dict,
+    decomposite_weight_matric_with_rank,
+    swap_shift_scale_for_linear_weight,
+)
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 
@@ -39,46 +43,47 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
 
     return all_keys_in_peft_format and (transformer_keys_present or base_model_keys_present)
 
+
 def approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRALayer:
-    '''Approximate given diffusers AdaLN loRA layer in our Flux model'''
+    """Approximate given diffusers AdaLN loRA layer in our Flux model"""
 
-    if not "lora_up.weight" in state_dict:
+    if "lora_up.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_up")
-
-    if not "lora_down.weight" in state_dict:
+
+    if "lora_down.weight" not in state_dict:
         raise ValueError(f"Unsupported lora format: {state_dict.keys()}, missing lora_down")
-
-    up = state_dict.pop('lora_up.weight')
-    down = state_dict.pop('lora_down.weight')
 
-    # layer-patcher upcast things to f32, 
+    up = state_dict.pop("lora_up.weight")
+    down = state_dict.pop("lora_down.weight")
+
+    # layer-patcher upcast things to f32,
     # we want to maintain a better precison for this one
     dtype = torch.float32
 
     device = up.device
     up_shape = up.shape
     down_shape = down.shape
-
+
     # desired low rank
     rank = up_shape[1]
 
     # up scaling for more precise
     up = up.to(torch.float32)
     down = down.to(torch.float32)
 
-    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1) 
+    weight = up.reshape(up_shape[0], -1) @ down.reshape(down_shape[0], -1)
 
     # swap to our linear format
     swapped = swap_shift_scale_for_linear_weight(weight)
 
     _up, _down = decomposite_weight_matric_with_rank(swapped, rank)
 
-    assert(_up.shape == up_shape)
-    assert(_down.shape == down_shape)
+    assert _up.shape == up_shape
+    assert _down.shape == down_shape
 
     # down scaling to original dtype, device
-    state_dict['lora_up.weight'] = _up.to(dtype).to(device=device)
-    state_dict['lora_down.weight'] = _down.to(dtype).to(device=device)
+    state_dict["lora_up.weight"] = _up.to(dtype).to(device=device)
+    state_dict["lora_down.weight"] = _down.to(dtype).to(device=device)
 
     return LoRALayer.from_state_dict_values(state_dict)
 
@@ -148,7 +153,7 @@ def add_adaLN_lora_layer_if_present(src_key: str, dst_key: str) -> None:
         src_layer_dict = grouped_state_dict.pop(src_key)
         values = get_lora_layer_values(src_layer_dict)
         layers[dst_key] = approximate_flux_adaLN_lora_layer_from_diffusers_state_dict(values)
-
+
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
         src_weight_shapes: list[tuple[int, int]],
@@ -291,8 +296,8 @@ def add_qkv_lora_layer_if_present(
     # Final layer.
     add_lora_layer_if_present("proj_out", "final_layer.linear")
     add_adaLN_lora_layer_if_present(
-        'norm_out.linear',
-        'final_layer.adaLN_modulation.1',
+        "norm_out.linear",
+        "final_layer.adaLN_modulation.1",
    )
 
     # Assert that all keys were processed.
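
Side note on the adaLN approximation in the diff above: it rebuilds the dense weight from the LoRA up/down factors, swaps the shift/scale halves via swap_shift_scale_for_linear_weight, and then re-factorizes the swapped weight at the original rank. The standalone sketch below illustrates that re-factorization step with a plain truncated SVD; it is only an illustration under that assumption, not the actual decomposite_weight_matric_with_rank implementation, and the tensor shapes are hypothetical.

import torch


def lowrank_refactor(weight: torch.Tensor, rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Truncated SVD: keep the top-`rank` singular triplets and fold the singular
    # values into the "up" factor so that up @ down approximates `weight`.
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    up = u[:, :rank] * s[:rank]  # (out_features, rank)
    down = vh[:rank, :]          # (rank, in_features)
    return up, down


# Hypothetical shapes: a rank-8 re-factorization of a 128x64 swapped weight.
up, down = lowrank_refactor(torch.randn(128, 64), rank=8)
assert up.shape == (128, 8) and down.shape == (8, 64)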
|
|