Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions mlx_lm/models/gated_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import mlx.core as mx
import mlx.nn as nn

from .recurrent_profile import profile_recurrent_call, recurrent_profile_enabled


@partial(mx.compile, shapeless=True)
def compute_g(A_log, a, dt_bias):
Expand Down Expand Up @@ -278,6 +280,27 @@ def gated_delta_update(
Hv, Dv = v.shape[-2:]
state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)

if not use_kernel or mx.default_device() != mx.gpu or not mx.metal.is_available():
use_metal = use_kernel and mx.default_device() == mx.gpu and mx.metal.is_available()
path = "metal" if use_metal else "ops"
if not recurrent_profile_enabled():
if use_metal:
return gated_delta_kernel(q, k, v, g, beta, state, mask)
return gated_delta_ops(q, k, v, g, beta, state, mask)
return gated_delta_kernel(q, k, v, g, beta, state, mask)
metadata = {
"B": q.shape[0],
"T": q.shape[1],
"Hk": q.shape[2],
"Dk": q.shape[3],
"Hv": v.shape[2],
"Dv": v.shape[3],
"vectorized_gating": g.ndim == 4,
"has_mask": mask is not None,
}
return profile_recurrent_call(
op="gated_delta",
path=path,
metadata=metadata,
fn=lambda: gated_delta_kernel(q, k, v, g, beta, state, mask)
if use_metal
else gated_delta_ops(q, k, v, g, beta, state, mask),
)
56 changes: 37 additions & 19 deletions mlx_lm/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .activations import swiglu
from .base import BaseModelArgs
from .cache import ArraysCache
from .recurrent_profile import profile_recurrent_call, recurrent_profile_enabled


@dataclass
Expand Down Expand Up @@ -122,26 +123,43 @@ def ssm_step(self, x, A, state=None):
return y, new_state

def _process_sequence(self, x, conv_cache, state_cache):
def run():
B, T, D = x.shape
xz = self.in_proj(x)
x_inner, z = xz.split(indices_or_sections=2, axis=-1)
K = self.conv_kernel_size
if conv_cache is not None:
x_full = mx.concatenate([conv_cache, x_inner], axis=1)
else:
x_full = mx.pad(x_inner, [(0, 0), (K - 1, 0), (0, 0)])
conv_out = self.conv1d(x_full)
new_conv_cache = x_full[:, -(K - 1) :, :]
x_inner = nn.silu(conv_out)
A = -mx.exp(self.A_log)
current_state = state_cache
y = []
for t in range(T):
y_t, current_state = self.ssm_step(x_inner[:, t], A, current_state)
y.append(y_t)
y = mx.stack(y, axis=1)
z = self.out_proj(swiglu(z, y))
return z, (new_conv_cache, current_state)

if not recurrent_profile_enabled():
return run()
B, T, D = x.shape
xz = self.in_proj(x)
x, z = xz.split(indices_or_sections=2, axis=-1)
K = self.conv_kernel_size
if conv_cache is not None:
x_full = mx.concatenate([conv_cache, x], axis=1)
else:
x_full = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
conv_out = self.conv1d(x_full)
new_conv_cache = x_full[:, -(K - 1) :, :]
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
current_state = state_cache
y = []
for t in range(T):
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
y.append(y_t)
y = mx.stack(y, axis=1)
z = self.out_proj(swiglu(z, y))
return z, (new_conv_cache, current_state)
return profile_recurrent_call(
op="mamba",
path="python_loop",
metadata={
"B": B,
"T": T,
"D": D,
"has_conv_cache": conv_cache is not None,
"has_state_cache": state_cache is not None,
},
fn=run,
)

def __call__(self, x, cache):
if cache is None:
Expand Down
56 changes: 36 additions & 20 deletions mlx_lm/models/recurrent_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import ArraysCache, RotatingKVCache
from .recurrent_profile import profile_recurrent_call, recurrent_profile_enabled


@dataclass
Expand Down Expand Up @@ -51,29 +52,44 @@ def rnn_scan(x, a, h0):
assert a.shape == x.shape[-a.ndim :]
assert a.dtype == x.dtype

if x.shape[1] == 1:
# Using scan in sampling mode.
if h0 is None:
return x, x[:, 0]
def run():
if x.shape[1] == 1:
# Using scan in sampling mode.
if h0 is None:
return x, x[:, 0]

else:
y = a * h0[:, None] + x
return y, y[:, -1]
else:
y = a * h0[:, None] + x
return y, y[:, -1]

else:
# Using scan in linear mode.
if h0 is not None:
h_t = h0
else:
B, _, D = x.shape
h_t = mx.zeros((B, D), dtype=x.dtype)

y = mx.zeros_like(x)
for t in range(x.shape[1]):
h_t = a[:, t] * h_t + x[:, t]
y[:, t] = h_t

return y, h_t
# Using scan in linear mode.
if h0 is not None:
h_t = h0
else:
B, _, D = x.shape
h_t = mx.zeros((B, D), dtype=x.dtype)

y = mx.zeros_like(x)
for t in range(x.shape[1]):
h_t = a[:, t] * h_t + x[:, t]
y[:, t] = h_t

return y, h_t

if not recurrent_profile_enabled():
return run()
return profile_recurrent_call(
op="recurrent_gemma_rnn",
path="step" if x.shape[1] == 1 else "python_loop",
metadata={
"B": x.shape[0],
"T": x.shape[1],
"D": x.shape[2],
"has_state": h0 is not None,
},
fn=run,
)


class Conv1d(nn.Module):
Expand Down
56 changes: 56 additions & 0 deletions mlx_lm/models/recurrent_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright © 2024 Apple Inc.

import json
import os
import sys
import time
from typing import Any, Callable

import mlx.core as mx


def recurrent_profile_enabled():
value = os.environ.get("MLX_LM_PROFILE_RECURRENT")
return value is not None and value.lower() not in {"", "0", "false", "no"}


def _arrays(value):
if isinstance(value, mx.array):
yield value
elif isinstance(value, (list, tuple)):
for item in value:
yield from _arrays(item)
elif isinstance(value, dict):
for item in value.values():
yield from _arrays(item)


def profile_recurrent_call(
*,
op: str,
path: str,
metadata: dict[str, Any],
fn: Callable[[], Any],
):
if not recurrent_profile_enabled():
return fn()
start = time.perf_counter()
result = fn()
arrays = list(_arrays(result))
if arrays:
mx.eval(*arrays)
elapsed_ms = (time.perf_counter() - start) * 1000
print(
json.dumps(
{
"event": "recurrent_profile",
"op": op,
"path": path,
"elapsed_ms": elapsed_ms,
**metadata,
}
),
file=sys.stderr,
flush=True,
)
return result
52 changes: 48 additions & 4 deletions mlx_lm/models/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import mlx.core as mx
import mlx.nn as nn

from .recurrent_profile import profile_recurrent_call, recurrent_profile_enabled


@mx.compile
def compute_dt(dt, dt_bias, time_step_limit):
Expand Down Expand Up @@ -228,12 +230,26 @@ def ssm_update(
lengths: Optional[mx.array] = None,
):
seq_len = hidden_states.shape[1]
if (
use_metal_step = not (
seq_len > 1
or state is None
or mx.default_device() != mx.gpu
or not mx.metal.is_available()
):
)
path = "metal_step" if use_metal_step else "ssm_attn"
if not recurrent_profile_enabled():
if use_metal_step:
return ssm_update_kernel(
hidden_states,
A_log,
B,
C,
D,
dt,
dt_bias,
state,
time_step_limit,
)
return ssm_attn(
hidden_states,
A_log,
Expand All @@ -247,8 +263,21 @@ def ssm_update(
mask=mask,
lengths=lengths,
)
else:
return ssm_update_kernel(
metadata = {
"B": hidden_states.shape[0],
"T": seq_len,
"H": hidden_states.shape[2],
"D": hidden_states.shape[3],
"state_dim": A_log.shape[-1],
"has_state": state is not None,
"has_mask": mask is not None,
"has_lengths": lengths is not None,
}
return profile_recurrent_call(
op="ssm",
path=path,
metadata=metadata,
fn=lambda: ssm_update_kernel(
hidden_states,
A_log,
B,
Expand All @@ -259,3 +288,18 @@ def ssm_update(
state,
time_step_limit,
)
if use_metal_step
else ssm_attn(
hidden_states,
A_log,
B,
C,
D,
dt,
dt_bias,
state,
time_step_limit,
mask=mask,
lengths=lengths,
),
)
46 changes: 46 additions & 0 deletions tests/test_recurrent_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright © 2024 Apple Inc.

import json
import unittest
from unittest.mock import patch

import mlx.core as mx

from mlx_lm.models.recurrent_profile import profile_recurrent_call


class TestRecurrentProfile(unittest.TestCase):
def test_disabled_profile_returns_result_without_output(self):
with patch.dict("os.environ", {}, clear=True), patch("sys.stderr") as stderr:
result = profile_recurrent_call(
op="toy",
path="path",
metadata={"T": 1},
fn=lambda: mx.array([1]),
)
mx.eval(result)

self.assertTrue(bool(mx.array_equal(result, mx.array([1]))))
stderr.write.assert_not_called()

def test_enabled_profile_emits_json(self):
with patch.dict("os.environ", {"MLX_LM_PROFILE_RECURRENT": "1"}):
with patch("sys.stderr") as stderr:
result = profile_recurrent_call(
op="toy",
path="path",
metadata={"T": 2},
fn=lambda: mx.array([2]),
)
mx.eval(result)

payload = json.loads(stderr.write.call_args_list[0].args[0])
self.assertEqual(payload["event"], "recurrent_profile")
self.assertEqual(payload["op"], "toy")
self.assertEqual(payload["path"], "path")
self.assertEqual(payload["T"], 2)
self.assertIn("elapsed_ms", payload)


if __name__ == "__main__":
unittest.main()