⚡️ Speed up method LoRALayer.fuse_weights by 8%
#138
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 8% (0.08x) speedup for
LoRALayer.fuse_weightsininvokeai/backend/patches/layers/lora_layer.py⏱️ Runtime :
2.17 milliseconds→2.00 milliseconds(best of108runs)📝 Explanation and details
The optimized code achieves an 8% speedup by replacing Python-level loops with vectorized PyTorch operations and adding strategic optimizations for common cases.
Key optimizations applied:
Vectorized operations: The original code used Python
forloops with repeated in-place additions (fused_lora = fused_lora + torch.mm(...)). The optimized version replaces this withtorch.stack([torch.mm(...) for ...]).sum(dim=0), which batches the matrix multiplications and performs the summation in PyTorch's optimized backend rather than Python.Single-chunk optimization: Added fast-path checks for
num_chunks == 1cases, directly usingtorch.mm(up, down)instead of unnecessary chunking operations. This optimization shows significant benefits in test cases with equal ranks.Shape caching: Pre-computes tensor shapes (
up_rows, up_cols = up.shape) to avoid repeated attribute lookups during computation.Performance characteristics from tests:
Why this works: PyTorch's
torch.stackand.sum()operations are implemented in optimized C++/CUDA code and can leverage SIMD instructions and better memory access patterns compared to Python loops with repeated tensor additions. The single-chunk optimization eliminates unnecessary overhead for the common case where chunking isn't actually needed.The optimization is most beneficial for LoRA layers with equal or similar ranks, which appear to be common based on the test results showing frequent single-chunk scenarios.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import pytest # used for our unit tests
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
function to test
--- LoRALayerBase ---
class LoRALayerBase:
def init(self, alpha: float | None, bias: torch.Tensor | None):
self._alpha = alpha
self.bias = bias
from invokeai.backend.patches.layers.lora_layer import LoRALayer
unit tests
-------- BASIC TEST CASES --------
def test_fuse_weights_basic_equal_ranks():
# up: [4, 2], down: [2, 3] -> ranks equal (2)
up = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
down = torch.tensor([[9., 10., 11.], [12., 13., 14.]])
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 49.2μs -> 39.1μs (25.9% faster)
# Should be torch.mm(up, down)
expected = torch.mm(up, down)
def test_fuse_weights_basic_rank_diff_gt_1():
# up: [4, 2], down: [4, 3] -> down.shape[0]/up.shape[1]=2>1
up = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
down = torch.tensor([
[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.],
[10., 11., 12.]
])
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 42.7μs -> 66.2μs (35.5% slower)
# down.chunk(2, dim=0) -> two [2,3] chunks
expected = torch.mm(up, down[:2]) + torch.mm(up, down[2:])
def test_fuse_weights_basic_rank_diff_lt_1():
# up: [6, 4], down: [2, 3] -> up.shape[1]/down.shape[0]=2>1, so rank_diff>1 but in else branch
up = torch.arange(24, dtype=torch.float32).reshape(6,4)
down = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output
# up.chunk(2, dim=0) -> two [3,4] chunks
w_up = up.chunk(2, dim=0)
expected = torch.mm(w_up[0], down) + torch.mm(w_up[1], down)
def test_fuse_weights_basic_dtype_and_device():
# Check that dtype and device are preserved
up = torch.ones((2,2), dtype=torch.float64, device="cpu")
down = torch.ones((2,2), dtype=torch.float64, device="cpu")
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 44.6μs -> 37.2μs (20.1% faster)
-------- EDGE TEST CASES --------
def test_fuse_weights_edge_zero_matrix():
# up and down are zero matrices
up = torch.zeros((3, 2))
down = torch.zeros((2, 4))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 35.7μs -> 27.8μs (28.5% faster)
expected = torch.zeros((3,4))
def test_fuse_weights_edge_negative_values():
# up and down have negative values
up = torch.tensor([[-1., -2.], [-3., -4.]])
down = torch.tensor([[-5., -6.], [-7., -8.]])
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 36.1μs -> 27.4μs (31.7% faster)
expected = torch.mm(up, down)
def test_fuse_weights_edge_single_element():
# up: [1,1], down: [1,1]
up = torch.tensor([[2.]])
down = torch.tensor([[3.]])
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 33.2μs -> 24.1μs (37.8% faster)
expected = torch.mm(up, down)
def test_fuse_weights_edge_ones_and_identity():
# up: identity, down: ones
up = torch.eye(2)
down = torch.ones((2,2))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 32.1μs -> 25.2μs (27.1% faster)
expected = torch.mm(up, down)
def test_fuse_weights_edge_rank_diff_not_integer():
# up: [2,3], down: [5,2] -> down.shape[0]/up.shape[1]=5/3=1.666... not integer, chunk will fail
up = torch.randn(2,3)
down = torch.randn(5,2)
lora = LoRALayer(up, None, down, None, None)
with pytest.raises(RuntimeError):
lora.fuse_weights(up, down) # 72.4μs -> 71.5μs (1.34% faster)
def test_fuse_weights_edge_up_chunk_not_integer():
# up: [5,4], down: [2,2] -> up.shape[1]/down.shape[0]=4/2=2, ok; but if not integer, chunk fails
up = torch.randn(5,5)
down = torch.randn(2,2)
lora = LoRALayer(up, None, down, None, None)
with pytest.raises(RuntimeError):
lora.fuse_weights(up, down) # 66.7μs -> 68.4μs (2.50% slower)
def test_fuse_weights_edge_empty_up():
# up: [0,2], down: [2,3]
up = torch.empty((0,2))
down = torch.randn(2,3)
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 31.8μs -> 25.0μs (27.5% faster)
expected = torch.zeros((0,3), dtype=down.dtype)
def test_fuse_weights_edge_empty_down():
# up: [2,0], down: [0,3]
up = torch.empty((2,0))
down = torch.empty((0,3))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output
expected = torch.zeros((2,3), dtype=down.dtype)
def test_fuse_weights_edge_large_rank_diff():
# up: [2,1], down: [8,3] -> rank_diff=8/1=8, chunk(8)
up = torch.ones((2,1))
down = torch.ones((8,3))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 70.7μs -> 79.2μs (10.7% slower)
# Each chunk is [1,3], so 8 chunks, each mm is [[1,1,1],[1,1,1]]
expected = torch.zeros((2,3))
for i in range(8):
expected += torch.mm(up, down[i:i+1])
-------- LARGE SCALE TEST CASES --------
def test_fuse_weights_large_scale_equal_ranks():
# up: [100,50], down: [50,100], ranks equal
up = torch.ones((100,50))
down = torch.ones((50,100))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 48.8μs -> 41.3μs (18.4% faster)
# Each element should be 50 (sum of 50 ones)
expected = torch.full((100,100), 50.0)
def test_fuse_weights_large_scale_rank_diff_gt_1():
# up: [128,32], down: [64,128], rank_diff=2
up = torch.ones((128,32))
down = torch.ones((64,128))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 71.2μs -> 98.2μs (27.5% slower)
# down.chunk(2, dim=0): two [32,128] chunks
expected = torch.mm(up, down[:32]) + torch.mm(up, down[32:])
def test_fuse_weights_large_scale_rank_diff_lt_1():
# up: [128,64], down: [32,128], up.shape[1]/down.shape[0]=2
up = torch.ones((128,64))
down = torch.ones((32,128))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output
# up.chunk(2, dim=0): two [64,64] chunks
w_up = up.chunk(2, dim=0)
expected = torch.mm(w_up[0], down) + torch.mm(w_up[1], down)
def test_fuse_weights_large_scale_random():
# up: [50,25], down: [25,50], random values
torch.manual_seed(42)
up = torch.randn(50,25)
down = torch.randn(25,50)
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 53.2μs -> 43.2μs (23.0% faster)
expected = torch.mm(up, down)
def test_fuse_weights_large_scale_dtype():
# up: [100,20], down: [20,100], float16
up = torch.ones((100,20), dtype=torch.float16)
down = torch.ones((20,100), dtype=torch.float16)
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 96.4μs -> 83.1μs (15.9% faster)
expected = torch.full((100,100), 20.0, dtype=torch.float16)
def test_fuse_weights_large_scale_max_size():
# up: [200,200], down: [200,200] -> 2002002*4 bytes = 320KB, safe
up = torch.ones((200,200))
down = torch.ones((200,200))
lora = LoRALayer(up, None, down, None, None)
codeflash_output = lora.fuse_weights(up, down); result = codeflash_output # 177μs -> 161μs (10.0% faster)
expected = torch.full((200,200), 200.0)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
function to test
---- lora_layer_base.py ----
class BaseLayerPatch:
pass
class LoRALayerBase(BaseLayerPatch):
def init(self, alpha: float | None, bias: torch.Tensor | None):
self._alpha = alpha
self.bias = bias
from invokeai.backend.patches.layers.lora_layer import LoRALayer
unit tests
-------- BASIC TEST CASES --------
def test_basic_equal_ranks():
# up: [4, 2], down: [2, 3], up.shape[1] == down.shape[0]
up = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
down = torch.tensor([[1., 0., 2.], [0., 1., 3.]])
layer = LoRALayer(up, None, down, None, None)
# Since ranks are equal, fuse_weights should just do torch.mm(up, down)
expected = torch.mm(up, down)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 42.9μs -> 31.2μs (37.6% faster)
def test_basic_rank_diff_down_larger():
# up: [2, 2], down: [4, 3], down.shape[0] / up.shape[1] == 2
up = torch.tensor([[1., 2.], [3., 4.]])
down = torch.tensor([[1., 0., 2.], [0., 1., 3.], [2., 1., 0.], [1., 2., 1.]])
layer = LoRALayer(up, None, down, None, None)
# down chunked into 2 chunks along dim=0: each [2,3]
chunk1 = down[:2]
chunk2 = down[2:]
expected = torch.mm(up, chunk1) + torch.mm(up, chunk2)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 28.8μs -> 52.0μs (44.7% slower)
def test_basic_dtype_and_device():
# Test that dtype and device are preserved
up = torch.ones((2, 2), dtype=torch.float64, device='cpu')
down = torch.ones((2, 2), dtype=torch.float64, device='cpu')
layer = LoRALayer(up, None, down, None, None)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 44.7μs -> 37.1μs (20.5% faster)
expected = torch.mm(up, down)
-------- EDGE TEST CASES --------
def test_edge_empty_up():
# up: [0, 2], down: [2, 3] -- empty up
up = torch.empty((0, 2))
down = torch.ones((2, 3))
layer = LoRALayer(up, None, down, None, None)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 27.6μs -> 22.7μs (21.8% faster)
def test_edge_empty_down():
# up: [2, 2], down: [0, 3] -- empty down
up = torch.ones((2, 2))
down = torch.empty((0, 3))
layer = LoRALayer(up, None, down, None, None)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output
def test_edge_one_element():
# up: [1, 1], down: [1, 1]
up = torch.tensor([[2.]])
down = torch.tensor([[3.]])
layer = LoRALayer(up, None, down, None, None)
expected = torch.mm(up, down)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 40.4μs -> 29.9μs (35.2% faster)
def test_edge_non_divisible_chunk():
# up: [3, 2], down: [5, 3] -- down.shape[0] / up.shape[1] == 2.5 (not divisible)
up = torch.ones((3, 2))
down = torch.ones((5, 3))
layer = LoRALayer(up, None, down, None, None)
# Should raise error due to chunking non-divisible
with pytest.raises(RuntimeError):
layer.fuse_weights(up, down) # 75.0μs -> 77.1μs (2.80% slower)
def test_edge_non_divisible_chunk_up():
# up: [4, 5], down: [2, 3] -- up.shape[1] / down.shape[0] == 2.5 (not divisible)
up = torch.ones((4, 5))
down = torch.ones((2, 3))
layer = LoRALayer(up, None, down, None, None)
# Should raise error due to chunking non-divisible
with pytest.raises(RuntimeError):
layer.fuse_weights(up, down) # 67.2μs -> 65.4μs (2.69% faster)
def test_edge_mismatched_shapes():
# up: [2, 3], down: [4, 2] -- up.shape[1] != down.shape[0], not divisible
up = torch.ones((2, 3))
down = torch.ones((4, 2))
layer = LoRALayer(up, None, down, None, None)
# Should raise error due to chunking non-divisible
with pytest.raises(RuntimeError):
layer.fuse_weights(up, down) # 59.4μs -> 56.3μs (5.52% faster)
def test_edge_negative_values():
# up and down contain negative values
up = torch.tensor([[1., -2.], [-3., 4.]])
down = torch.tensor([[-1., 2.], [3., -4.]])
layer = LoRALayer(up, None, down, None, None)
expected = torch.mm(up, down)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 33.8μs -> 23.1μs (46.1% faster)
def test_edge_large_rank_diff():
# up: [2, 1], down: [8, 3], down.shape[0] / up.shape[1] == 8
up = torch.ones((2, 1))
down = torch.ones((8, 3))
layer = LoRALayer(up, None, down, None, None)
# down chunked into 8 chunks along dim=0: each [1, 3]
expected = torch.zeros((2, 3))
for i in range(8):
expected += torch.mm(up, down[i:i+1])
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 43.4μs -> 55.2μs (21.4% slower)
-------- LARGE SCALE TEST CASES --------
def test_large_scale_equal_ranks():
# up: [256, 128], down: [128, 256]
up = torch.ones((256, 128))
down = torch.ones((128, 256))
layer = LoRALayer(up, None, down, None, None)
expected = torch.mm(up, down)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 161μs -> 139μs (15.6% faster)
def test_large_scale_rank_diff_down_larger():
# up: [64, 32], down: [128, 128], down.shape[0] / up.shape[1] == 4
up = torch.ones((64, 32))
down = torch.ones((128, 128))
layer = LoRALayer(up, None, down, None, None)
# down chunked into 4 chunks along dim=0: each [32, 128]
expected = torch.zeros((64, 128))
for i in range(4):
expected += torch.mm(up, down[i*32:(i+1)*32])
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 53.5μs -> 77.8μs (31.2% slower)
def test_large_scale_random_values():
# up: [100, 50], down: [50, 100]
torch.manual_seed(42)
up = torch.randn((100, 50))
down = torch.randn((50, 100))
layer = LoRALayer(up, None, down, None, None)
expected = torch.mm(up, down)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 53.4μs -> 40.2μs (32.7% faster)
def test_large_scale_max_tensor_size():
# up: [500, 100], down: [100, 500] -- total size ~200k floats < 100MB
up = torch.ones((500, 100))
down = torch.ones((100, 500))
layer = LoRALayer(up, None, down, None, None)
expected = torch.mm(up, down)
codeflash_output = layer.fuse_weights(up, down); result = codeflash_output # 473μs -> 377μs (25.3% faster)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes
git checkout codeflash/optimize-LoRALayer.fuse_weights-mhvk86o9and push.