invokeai/backend/patches/layers/lora_layer.py (30 changes: 20 additions & 10 deletions)
@@ -67,20 +67,30 @@ def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
         the LoRA weights have different ranks. This function handles the fusion of these differently sized
         matrices.
         """
-        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
-        rank_diff = down.shape[0] / up.shape[1]
+        up_rows, up_cols = up.shape
+        down_rows, down_cols = down.shape
+        fused_lora = torch.zeros((up_rows, down_cols), device=down.device, dtype=down.dtype)
+        rank_diff = down_rows / up_cols
+        # optimization: coalesce operation with torch.stack and .sum to avoid python-level looping
 
         if rank_diff > 1:
-            rank_diff = down.shape[0] / up.shape[1]
-            w_down = down.chunk(int(rank_diff), dim=0)
-            for w_down_chunk in w_down:
-                fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
+            # split down along the row dimension and compute mm for each chunk
+            num_chunks = int(rank_diff)
+            w_down = down.chunk(num_chunks, dim=0)
+            # if num_chunks == 1, fallback to simpler matmul
+            if num_chunks == 1:
+                fused_lora = torch.mm(up, down)
+            else:
+                fused_lora = torch.stack([torch.mm(up, w) for w in w_down], dim=0).sum(dim=0)
         else:
-            rank_diff = up.shape[1] / down.shape[0]
-            w_up = up.chunk(int(rank_diff), dim=0)
-            for w_up_chunk in w_up:
-                fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
+            rank_diff = up_cols / down_rows
+            num_chunks = int(rank_diff)
+            w_up = up.chunk(num_chunks, dim=0)
+            if num_chunks == 1:
+                fused_lora = torch.mm(up, down)
+            else:
+                fused_lora = torch.stack([torch.mm(w, down) for w in w_up], dim=0).sum(dim=0)
 
         return fused_lora
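For reference, a minimal standalone sketch (not part of the PR) of the idea behind the new code path: it checks that the torch.stack(...).sum(dim=0) fusion gives the same result as the original chunk-by-chunk accumulation loop. The shapes below (an 8x4 up matrix and a 16x32 down matrix) are made up for illustration and assume the down rank is an integer multiple of the up rank, which is what the chunking relies on.

import torch

torch.manual_seed(0)
up = torch.randn(8, 4)      # (out_features, rank_up)
down = torch.randn(16, 32)  # (rank_down, in_features), rank_down = 4 * rank_up

num_chunks = down.shape[0] // up.shape[1]
w_down = down.chunk(num_chunks, dim=0)  # four (4, 32) chunks

# original approach: accumulate one chunk at a time in a Python loop
looped = torch.zeros((up.shape[0], down.shape[1]))
for w in w_down:
    looped = looped + torch.mm(up, w)

# new approach: stack the per-chunk products and reduce them with a single .sum
stacked = torch.stack([torch.mm(up, w) for w in w_down], dim=0).sum(dim=0)

print(torch.allclose(looped, stacked))  # expected: True (up to float tolerance)

The values are identical either way; the change simply replaces the incremental Python-level accumulation with one batched reduction.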
