Skip to content

Commit

Permalink
Fused attention for single query (#1497)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Oct 18, 2024
1 parent 9dd72cd commit 50d8bed
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 747 deletions.
49 changes: 49 additions & 0 deletions benchmarks/python/sdpa_vector_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 1024
H = 32
H_k = 32 // 4
D = 128


def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)


def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)


def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)


def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)


if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()
5 changes: 3 additions & 2 deletions mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
steel/defines.h steel/gemm/transforms.h steel/utils.h)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)

set(STEEL_HEADERS
steel/defines.h
Expand Down
Loading

0 comments on commit 50d8bed

Please sign in to comment.