def _process_sequence(self, x, conv_cache, state_cache):
    """Run the block over a whole sequence, optionally emitting a profile record.

    Args:
        x: Input that unpacks as ``(B, T, D)``.
        conv_cache: Earlier conv inputs that are concatenated in front of
            the projected input along axis 1, or ``None`` to left-pad with
            ``conv_kernel_size - 1`` zero steps instead.
        state_cache: Initial value handed to ``ssm_step`` as the state
            (may be ``None``).

    Returns:
        ``(out, (new_conv_cache, state))`` where ``out`` is the result of
        ``out_proj``, ``new_conv_cache`` holds the trailing
        ``conv_kernel_size - 1`` steps of the conv input, and ``state`` is
        the state returned by the last ``ssm_step`` call.
    """
    # Unpacked once: T drives the step loop, B/T/D feed the profile metadata.
    B, T, D = x.shape

    def run():
        xz = self.in_proj(x)
        x_inner, z = xz.split(indices_or_sections=2, axis=-1)
        K = self.conv_kernel_size
        if conv_cache is not None:
            x_full = mx.concatenate([conv_cache, x_inner], axis=1)
        else:
            x_full = mx.pad(x_inner, [(0, 0), (K - 1, 0), (0, 0)])
        conv_out = self.conv1d(x_full)
        # Keep the last K - 1 steps. The start index is computed from the
        # front because ``-(K - 1)`` is ``-0`` when K == 1, which would keep
        # the entire sequence instead of nothing.
        cache_start = max(x_full.shape[1] - (K - 1), 0)
        new_conv_cache = x_full[:, cache_start:, :]
        x_inner = nn.silu(conv_out)
        A = -mx.exp(self.A_log)
        current_state = state_cache
        outputs = []
        # Sequential recurrence: one ssm_step per time step.
        for t in range(T):
            y_t, current_state = self.ssm_step(x_inner[:, t], A, current_state)
            outputs.append(y_t)
        y = mx.stack(outputs, axis=1)
        out = self.out_proj(swiglu(z, y))
        return out, (new_conv_cache, current_state)

    if not recurrent_profile_enabled():
        return run()
    return profile_recurrent_call(
        op="mamba",
        path="python_loop",
        metadata={
            "B": B,
            "T": T,
            "D": D,
            "has_conv_cache": conv_cache is not None,
            "has_state_cache": state_cache is not None,
        },
        fn=run,
    )
a.shape == x.shape[-a.ndim :] assert a.dtype == x.dtype - if x.shape[1] == 1: - # Using scan in sampling mode. - if h0 is None: - return x, x[:, 0] + def run(): + if x.shape[1] == 1: + # Using scan in sampling mode. + if h0 is None: + return x, x[:, 0] - else: - y = a * h0[:, None] + x - return y, y[:, -1] + else: + y = a * h0[:, None] + x + return y, y[:, -1] - else: - # Using scan in linear mode. - if h0 is not None: - h_t = h0 else: - B, _, D = x.shape - h_t = mx.zeros((B, D), dtype=x.dtype) - - y = mx.zeros_like(x) - for t in range(x.shape[1]): - h_t = a[:, t] * h_t + x[:, t] - y[:, t] = h_t - - return y, h_t + # Using scan in linear mode. + if h0 is not None: + h_t = h0 + else: + B, _, D = x.shape + h_t = mx.zeros((B, D), dtype=x.dtype) + + y = mx.zeros_like(x) + for t in range(x.shape[1]): + h_t = a[:, t] * h_t + x[:, t] + y[:, t] = h_t + + return y, h_t + + if not recurrent_profile_enabled(): + return run() + return profile_recurrent_call( + op="recurrent_gemma_rnn", + path="step" if x.shape[1] == 1 else "python_loop", + metadata={ + "B": x.shape[0], + "T": x.shape[1], + "D": x.shape[2], + "has_state": h0 is not None, + }, + fn=run, + ) class Conv1d(nn.Module): diff --git a/mlx_lm/models/recurrent_profile.py b/mlx_lm/models/recurrent_profile.py new file mode 100644 index 000000000..59cab7614 --- /dev/null +++ b/mlx_lm/models/recurrent_profile.py @@ -0,0 +1,56 @@ +# Copyright © 2024 Apple Inc. 
_PROFILE_ENV_VAR = "MLX_LM_PROFILE_RECURRENT"
# Settings that leave profiling off; compared after strip() and lower().
_PROFILE_DISABLED_VALUES = frozenset({"", "0", "false", "no", "off"})


def recurrent_profile_enabled():
    """Return True when recurrent-op profiling is requested via the environment.

    Profiling is on when ``MLX_LM_PROFILE_RECURRENT`` is set to anything
    other than an empty/whitespace-only string, ``0``, ``false``, ``no`` or
    ``off``. Surrounding whitespace and letter case are ignored, so values
    such as ``" 0 "`` or ``"OFF"`` do not switch profiling on by accident.
    """
    value = os.environ.get(_PROFILE_ENV_VAR)
    if value is None:
        return False
    return value.strip().lower() not in _PROFILE_DISABLED_VALUES


def _arrays(value):
    """Yield every ``mx.array`` inside ``value``.

    Lists, tuples and dict values are searched recursively; any other kind
    of object is ignored.
    """
    if isinstance(value, mx.array):
        yield value
    elif isinstance(value, (list, tuple)):
        for item in value:
            yield from _arrays(item)
    elif isinstance(value, dict):
        for item in value.values():
            yield from _arrays(item)


def profile_recurrent_call(
    *,
    op: str,
    path: str,
    metadata: dict[str, Any],
    fn: Callable[[], Any],
):
    """Call ``fn`` and, when profiling is enabled, log its wall-clock time.

    With profiling disabled this is just ``fn()``. With it enabled, every
    array found in the result is passed to ``mx.eval`` before the timer
    stops, and one JSON object per call is printed to stderr.

    NOTE(review): MLX is lazily evaluated, so the forced eval presumably
    also runs any still-pending work that produced ``fn``'s inputs, and
    ``elapsed_ms`` may include that upstream cost -- confirm before
    comparing numbers across call sites.

    Args:
        op: Name of the profiled operation, emitted as ``"op"``.
        path: Implementation path that was taken, emitted as ``"path"``.
        metadata: Extra JSON-serialisable fields merged into the record.
            The reserved fields ``event``, ``op``, ``path`` and
            ``elapsed_ms`` take precedence over same-named metadata keys.
        fn: Zero-argument callable that performs the operation.

    Returns:
        Whatever ``fn`` returns.
    """
    if not recurrent_profile_enabled():
        return fn()
    start = time.perf_counter()
    result = fn()
    arrays = list(_arrays(result))
    if arrays:
        mx.eval(*arrays)
    elapsed_ms = (time.perf_counter() - start) * 1000
    record = {
        "event": "recurrent_profile",
        "op": op,
        "path": path,
        "elapsed_ms": elapsed_ms,
    }
    # setdefault keeps the reserved fields above intact if metadata happens
    # to reuse one of their names.
    for key, value in metadata.items():
        record.setdefault(key, value)
    print(json.dumps(record), file=sys.stderr, flush=True)
    return result
class TestRecurrentProfile(unittest.TestCase):
    """Exercises ``profile_recurrent_call`` with profiling switched off and on."""

    def test_disabled_profile_returns_result_without_output(self):
        # clear=True empties the environment for the duration of the call,
        # so the profiling variable is guaranteed to be absent.
        with patch.dict("os.environ", {}, clear=True), patch(
            "sys.stderr"
        ) as fake_stderr:
            out = profile_recurrent_call(
                fn=lambda: mx.array([1]),
                metadata={"T": 1},
                path="path",
                op="toy",
            )
            mx.eval(out)

        expected = mx.array([1])
        self.assertTrue(bool(mx.array_equal(out, expected)))
        fake_stderr.write.assert_not_called()

    def test_enabled_profile_emits_json(self):
        enabled_env = {"MLX_LM_PROFILE_RECURRENT": "1"}
        with patch.dict("os.environ", enabled_env), patch(
            "sys.stderr"
        ) as fake_stderr:
            out = profile_recurrent_call(
                fn=lambda: mx.array([2]),
                metadata={"T": 2},
                path="path",
                op="toy",
            )
            mx.eval(out)

        # The first write issued on the patched stderr carries the JSON text.
        first_write = fake_stderr.write.call_args_list[0]
        payload = json.loads(first_write.args[0])
        expected_fields = (
            ("event", "recurrent_profile"),
            ("op", "toy"),
            ("path", "path"),
            ("T", 2),
        )
        for field, wanted in expected_fields:
            self.assertEqual(payload[field], wanted)
        self.assertIn("elapsed_ms", payload)


if __name__ == "__main__":
    unittest.main()