diff --git a/mlx_lm/models/switch_layers.py b/mlx_lm/models/switch_layers.py index 1fe5d917e..e18986a94 100644 --- a/mlx_lm/models/switch_layers.py +++ b/mlx_lm/models/switch_layers.py @@ -165,6 +165,7 @@ def __init__( num_experts: int, activation=SwiGLU(), bias: bool = False, + fuse_gate_up: bool = False, ): super().__init__() @@ -172,6 +173,96 @@ def __init__( self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) self.activation = activation + self.fuse_gate_up = fuse_gate_up + self._fused_gate_up_cache = None + + def _can_fuse_gate_up(self): + if not self.fuse_gate_up or self.training: + return False + if type(self.up_proj) is not type(self.gate_proj): + return False + if not isinstance(self.up_proj, SwitchLinear): + return False + if self.up_proj.input_dims != self.gate_proj.input_dims: + return False + if self.up_proj.output_dims != self.gate_proj.output_dims: + return False + if self.up_proj.num_experts != self.gate_proj.num_experts: + return False + if ("bias" in self.up_proj) != ("bias" in self.gate_proj): + return False + if isinstance(self.up_proj, QuantizedSwitchLinear): + if self.up_proj.group_size != self.gate_proj.group_size: + return False + if self.up_proj.bits != self.gate_proj.bits: + return False + if self.up_proj.mode != self.gate_proj.mode: + return False + if (self.up_proj.get("biases") is None) != ( + self.gate_proj.get("biases") is None + ): + return False + return True + + def _fused_gate_up_params(self): + up = self.up_proj + gate = self.gate_proj + key = ( + type(up), + id(up["weight"]), + id(gate["weight"]), + up["weight"].shape, + gate["weight"].shape, + ) + if self._fused_gate_up_cache is not None: + cached_key, params = self._fused_gate_up_cache + if cached_key == key: + return params + + weight = mx.concatenate([up["weight"], gate["weight"]], axis=1) + bias = None + if "bias" in up: + bias = mx.concatenate([up["bias"], gate["bias"]], axis=1) + if isinstance(up, QuantizedSwitchLinear): + scales = mx.concatenate([up["scales"], gate["scales"]], axis=1) + up_biases = up.get("biases") + gate_biases = gate.get("biases") + biases = None + if up_biases is not None: + biases = mx.concatenate([up_biases, gate_biases], axis=1) + params = (weight, scales, biases, bias) + else: + params = (weight, bias) + self._fused_gate_up_cache = (key, params) + return params + + def _fused_gate_up(self, x, indices, sorted_indices=False): + hidden_dims = self.up_proj.output_dims + if isinstance(self.up_proj, QuantizedSwitchLinear): + weight, scales, biases, bias = self._fused_gate_up_params() + x = mx.gather_qmm( + x, + weight, + scales, + biases, + rhs_indices=indices, + transpose=True, + group_size=self.up_proj.group_size, + bits=self.up_proj.bits, + mode=self.up_proj.mode, + sorted_indices=sorted_indices, + ) + else: + weight, bias = self._fused_gate_up_params() + x = mx.gather_mm( + x, + weight.swapaxes(-1, -2), + rhs_indices=indices, + sorted_indices=sorted_indices, + ) + if bias is not None: + x = x + mx.expand_dims(bias[indices], -2) + return x[..., :hidden_dims], x[..., hidden_dims:] def __call__(self, x, indices) -> mx.array: x = mx.expand_dims(x, (-2, -3)) @@ -185,8 +276,11 @@ def __call__(self, x, indices) -> mx.array: x, idx, inv_order = _gather_sort(x, indices) if self.training: idx = mx.stop_gradient(idx) - x_up = self.up_proj(x, idx, sorted_indices=do_sort) - x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) + if self._can_fuse_gate_up(): + x_up, x_gate = self._fused_gate_up(x, idx, sorted_indices=do_sort) + else: + x_up = self.up_proj(x, idx, sorted_indices=do_sort) + x_gate = self.gate_proj(x, idx, sorted_indices=do_sort) x = self.down_proj( self.activation(x_up, x_gate), idx, diff --git a/tests/test_switch_layers.py b/tests/test_switch_layers.py new file mode 100644 index 000000000..53c6db3f7 --- /dev/null +++ b/tests/test_switch_layers.py @@ -0,0 +1,89 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +import mlx.core as mx +from mlx.utils import tree_flatten + +from mlx_lm.models.switch_layers import SwitchGLU + + +def assert_allclose(testcase, left, right, rtol=1e-5, atol=1e-5): + testcase.assertTrue(bool(mx.allclose(left, right, rtol=rtol, atol=atol))) + + +class TestSwitchGLUFusion(unittest.TestCase): + def test_fused_gate_up_matches_unfused(self): + mx.random.seed(0) + layer = SwitchGLU(16, 32, 4, bias=True) + layer.eval() + x = mx.random.normal((8, 16)) + indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3], [2, 0], [3, 1]]) + + layer.fuse_gate_up = False + expected = layer(x, indices) + layer.fuse_gate_up = True + actual = layer(x, indices) + mx.eval(expected, actual) + + assert_allclose(self, actual, expected) + + def test_fused_gate_up_matches_unfused_sorted_path(self): + mx.random.seed(1) + layer = SwitchGLU(16, 32, 4, bias=True) + layer.eval() + x = mx.random.normal((64, 16)) + indices = mx.array([[i % 4, (i + 1) % 4] for i in range(64)]) + + layer.fuse_gate_up = False + expected = layer(x, indices) + layer.fuse_gate_up = True + actual = layer(x, indices) + mx.eval(expected, actual) + + assert_allclose(self, actual, expected) + + def test_quantized_gate_up_fusion_falls_back_without_building_cache(self): + mx.random.seed(2) + layer = SwitchGLU(64, 64, 4, bias=True) + layer.gate_proj = layer.gate_proj.to_quantized(group_size=32, bits=4) + layer.up_proj = layer.up_proj.to_quantized(group_size=32, bits=4) + layer.eval() + x = mx.random.normal((8, 64)) + indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3], [2, 0], [3, 1]]) + + layer.fuse_gate_up = False + expected = layer(x, indices) + layer.fuse_gate_up = True + actual = layer(x, indices) + mx.eval(expected, actual) + + assert_allclose(self, actual, expected, rtol=1e-4, atol=1e-4) + self.assertIsNone(layer._fused_gate_up_cache) + + def test_training_mode_falls_back_without_building_fused_cache(self): + layer = SwitchGLU(16, 32, 4, fuse_gate_up=True) + layer.train() + x = mx.random.normal((4, 16)) + indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0]]) + + layer(x, indices) + + self.assertIsNone(layer._fused_gate_up_cache) + + def test_fused_cache_does_not_add_parameters(self): + mx.random.seed(3) + layer = SwitchGLU(16, 32, 4, bias=True, fuse_gate_up=True) + layer.eval() + before = [key for key, _ in tree_flatten(layer.parameters())] + x = mx.random.normal((4, 16)) + indices = mx.array([[0, 1], [1, 2], [2, 3], [3, 0]]) + + layer(x, indices) + after = [key for key, _ in tree_flatten(layer.parameters())] + + self.assertEqual(after, before) + + +if __name__ == "__main__": + unittest.main()