Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions mlx_lm/models/switch_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,104 @@ def __init__(
num_experts: int,
activation=SwiGLU(),
bias: bool = False,
fuse_gate_up: bool = False,
):
super().__init__()

self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
self.activation = activation
self.fuse_gate_up = fuse_gate_up
self._fused_gate_up_cache = None

def _can_fuse_gate_up(self):
if not self.fuse_gate_up or self.training:
return False
if type(self.up_proj) is not type(self.gate_proj):
return False
if not isinstance(self.up_proj, SwitchLinear):
return False
if self.up_proj.input_dims != self.gate_proj.input_dims:
return False
if self.up_proj.output_dims != self.gate_proj.output_dims:
return False
if self.up_proj.num_experts != self.gate_proj.num_experts:
return False
if ("bias" in self.up_proj) != ("bias" in self.gate_proj):
return False
if isinstance(self.up_proj, QuantizedSwitchLinear):
if self.up_proj.group_size != self.gate_proj.group_size:
return False
if self.up_proj.bits != self.gate_proj.bits:
return False
if self.up_proj.mode != self.gate_proj.mode:
return False
if (self.up_proj.get("biases") is None) != (
self.gate_proj.get("biases") is None
):
return False
return True

def _fused_gate_up_params(self):
up = self.up_proj
gate = self.gate_proj
key = (
type(up),
id(up["weight"]),
id(gate["weight"]),
up["weight"].shape,
gate["weight"].shape,
)
if self._fused_gate_up_cache is not None:
cached_key, params = self._fused_gate_up_cache
if cached_key == key:
return params

weight = mx.concatenate([up["weight"], gate["weight"]], axis=1)
bias = None
if "bias" in up:
bias = mx.concatenate([up["bias"], gate["bias"]], axis=1)
if isinstance(up, QuantizedSwitchLinear):
scales = mx.concatenate([up["scales"], gate["scales"]], axis=1)
up_biases = up.get("biases")
gate_biases = gate.get("biases")
biases = None
if up_biases is not None:
biases = mx.concatenate([up_biases, gate_biases], axis=1)
params = (weight, scales, biases, bias)
else:
params = (weight, bias)
self._fused_gate_up_cache = (key, params)
return params

def _fused_gate_up(self, x, indices, sorted_indices=False):
hidden_dims = self.up_proj.output_dims
if isinstance(self.up_proj, QuantizedSwitchLinear):
weight, scales, biases, bias = self._fused_gate_up_params()
x = mx.gather_qmm(
x,
weight,
scales,
biases,
rhs_indices=indices,
transpose=True,
group_size=self.up_proj.group_size,
bits=self.up_proj.bits,
mode=self.up_proj.mode,
sorted_indices=sorted_indices,
)
else:
weight, bias = self._fused_gate_up_params()
x = mx.gather_mm(
x,
weight.swapaxes(-1, -2),
rhs_indices=indices,
sorted_indices=sorted_indices,
)
if bias is not None:
x = x + mx.expand_dims(bias[indices], -2)
return x[..., :hidden_dims], x[..., hidden_dims:]

def __call__(self, x, indices) -> mx.array:
x = mx.expand_dims(x, (-2, -3))
Expand All @@ -185,8 +276,11 @@ def __call__(self, x, indices) -> mx.array:
x, idx, inv_order = _gather_sort(x, indices)
if self.training:
idx = mx.stop_gradient(idx)
x_up = self.up_proj(x, idx, sorted_indices=do_sort)
x_gate = self.gate_proj(x, idx, sorted_indices=do_sort)
if self._can_fuse_gate_up():
x_up, x_gate = self._fused_gate_up(x, idx, sorted_indices=do_sort)
else:
x_up = self.up_proj(x, idx, sorted_indices=do_sort)
x_gate = self.gate_proj(x, idx, sorted_indices=do_sort)
x = self.down_proj(
self.activation(x_up, x_gate),
idx,
Expand Down
89 changes: 89 additions & 0 deletions tests/test_switch_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_flatten

from mlx_lm.models.switch_layers import SwitchGLU


def assert_allclose(testcase, left, right, rtol=1e-5, atol=1e-5):
testcase.assertTrue(bool(mx.allclose(left, right, rtol=rtol, atol=atol)))


class TestSwitchGLUFusion(unittest.TestCase):
def test_fused_gate_up_matches_unfused(self):
mx.random.seed(0)
layer = SwitchGLU(16, 32, 4, bias=True)
layer.eval()
x = mx.random.normal((8, 16))
indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3], [2, 0], [3, 1]])

layer.fuse_gate_up = False
expected = layer(x, indices)
layer.fuse_gate_up = True
actual = layer(x, indices)
mx.eval(expected, actual)

assert_allclose(self, actual, expected)

def test_fused_gate_up_matches_unfused_sorted_path(self):
mx.random.seed(1)
layer = SwitchGLU(16, 32, 4, bias=True)
layer.eval()
x = mx.random.normal((64, 16))
indices = mx.array([[i % 4, (i + 1) % 4] for i in range(64)])

layer.fuse_gate_up = False
expected = layer(x, indices)
layer.fuse_gate_up = True
actual = layer(x, indices)
mx.eval(expected, actual)

assert_allclose(self, actual, expected)

def test_quantized_gate_up_fusion_falls_back_without_building_cache(self):
mx.random.seed(2)
layer = SwitchGLU(64, 64, 4, bias=True)
layer.gate_proj = layer.gate_proj.to_quantized(group_size=32, bits=4)
layer.up_proj = layer.up_proj.to_quantized(group_size=32, bits=4)
layer.eval()
x = mx.random.normal((8, 64))
indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3], [2, 0], [3, 1]])

layer.fuse_gate_up = False
expected = layer(x, indices)
layer.fuse_gate_up = True
actual = layer(x, indices)
mx.eval(expected, actual)

assert_allclose(self, actual, expected, rtol=1e-4, atol=1e-4)
self.assertIsNone(layer._fused_gate_up_cache)

def test_training_mode_falls_back_without_building_fused_cache(self):
layer = SwitchGLU(16, 32, 4, fuse_gate_up=True)
layer.train()
x = mx.random.normal((4, 16))
indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0]])

layer(x, indices)

self.assertIsNone(layer._fused_gate_up_cache)

def test_fused_cache_does_not_add_parameters(self):
mx.random.seed(3)
layer = SwitchGLU(16, 32, 4, bias=True, fuse_gate_up=True)
layer.eval()
before = [key for key, _ in tree_flatten(layer.parameters())]
x = mx.random.normal((4, 16))
indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0]])

layer(x, indices)
after = [key for key, _ in tree_flatten(layer.parameters())]

self.assertEqual(after, before)


if __name__ == "__main__":
unittest.main()