133 changes: 133 additions & 0 deletions challenges/medium/95_decode_phase_attention/challenge.html
@@ -0,0 +1,133 @@
<p>
Implement the attention operation used during the <strong>decode phase</strong> of autoregressive
language model inference. At each decode step a single new token's query vectors attend over all
key-value pairs previously stored in the KV cache. Given query tensor <code>Q</code> of shape
<code>(batch_size, num_q_heads, head_dim)</code> — one query vector per head with no sequence
dimension — and cached key/value tensors <code>K</code>, <code>V</code> each of shape
<code>(batch_size, num_kv_heads, cache_len, head_dim)</code>, compute the scaled dot-product
attention output. Grouped Query Attention (GQA) is supported: every group of
<code>num_q_heads / num_kv_heads</code> consecutive query heads shares the same key and value head.
All tensors use <code>float32</code>.
</p>

<svg width="700" height="290" viewBox="0 0 700 290" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="700" height="290" fill="#222" rx="10"/>

<!-- Title -->
<text x="350" y="26" fill="#ccc" font-family="monospace" font-size="13" text-anchor="middle">Decode-Phase Attention (batch_size=1, num_q_heads=4, num_kv_heads=2)</text>

<!-- Left: query heads -->
<text x="75" y="56" fill="#aaa" font-family="monospace" font-size="11" text-anchor="middle">New token (1 query per head)</text>
<rect x="15" y="65" width="55" height="30" fill="#2563eb" rx="4"/>
<text x="42" y="85" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">Q[0]</text>
<rect x="80" y="65" width="55" height="30" fill="#2563eb" rx="4"/>
<text x="107" y="85" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">Q[1]</text>
<rect x="15" y="105" width="55" height="30" fill="#7c3aed" rx="4"/>
<text x="42" y="125" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">Q[2]</text>
<rect x="80" y="105" width="55" height="30" fill="#7c3aed" rx="4"/>
<text x="107" y="125" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">Q[3]</text>
<text x="42" y="150" fill="#60a5fa" font-family="monospace" font-size="10" text-anchor="middle">group 0</text>
<text x="107" y="150" fill="#c4b5fd" font-family="monospace" font-size="10" text-anchor="middle">group 1</text>

<!-- Right: KV cache -->
<text x="450" y="56" fill="#aaa" font-family="monospace" font-size="11" text-anchor="middle">KV cache (cache_len positions)</text>
<rect x="255" y="65" width="52" height="30" fill="#1d4ed8" rx="4"/>
<text x="281" y="85" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[0]</text>
<rect x="317" y="65" width="52" height="30" fill="#1d4ed8" rx="4"/>
<text x="343" y="85" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[1]</text>
<rect x="379" y="65" width="52" height="30" fill="#1d4ed8" rx="4"/>
<text x="405" y="85" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[2]</text>
<text x="453" y="85" fill="#888" font-family="monospace" font-size="16" text-anchor="middle">&#8230;</text>
<rect x="470" y="65" width="66" height="30" fill="#1d4ed8" rx="4"/>
<text x="503" y="85" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[T-1]</text>
<text x="390" y="112" fill="#60a5fa" font-family="monospace" font-size="10" text-anchor="middle">KV head 0 (shared by Q[0], Q[1])</text>

<rect x="255" y="120" width="52" height="30" fill="#5b21b6" rx="4"/>
<text x="281" y="140" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[0]</text>
<rect x="317" y="120" width="52" height="30" fill="#5b21b6" rx="4"/>
<text x="343" y="140" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[1]</text>
<rect x="379" y="120" width="52" height="30" fill="#5b21b6" rx="4"/>
<text x="405" y="140" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[2]</text>
<text x="453" y="140" fill="#888" font-family="monospace" font-size="16" text-anchor="middle">&#8230;</text>
<rect x="470" y="120" width="66" height="30" fill="#5b21b6" rx="4"/>
<text x="503" y="140" fill="#fff" font-family="monospace" font-size="10" text-anchor="middle">K,V[T-1]</text>
<text x="390" y="165" fill="#c4b5fd" font-family="monospace" font-size="10" text-anchor="middle">KV head 1 (shared by Q[2], Q[3])</text>

<!-- Attention formula -->
<text x="350" y="198" fill="#4ade80" font-family="monospace" font-size="12" text-anchor="middle">scale = 1 / sqrt(head_dim)</text>
<text x="350" y="218" fill="#4ade80" font-family="monospace" font-size="12" text-anchor="middle">scores[t] = Q &#183; K_cache[t] &#215; scale</text>
<text x="350" y="238" fill="#4ade80" font-family="monospace" font-size="12" text-anchor="middle">weights = softmax(scores) [over all t]</text>
<text x="350" y="258" fill="#4ade80" font-family="monospace" font-size="12" text-anchor="middle">output = &#931; weights[t] &#215; V_cache[t]</text>

<defs>
<marker id="arr2" markerWidth="5" markerHeight="5" refX="3" refY="2.5" orient="auto">
<path d="M0,0 L0,5 L5,2.5 z" fill="#555"/>
</marker>
</defs>
<line x1="136" y1="80" x2="250" y2="80" stroke="#60a5fa" stroke-width="1.5" stroke-dasharray="4,3" marker-end="url(#arr2)"/>
<line x1="136" y1="120" x2="250" y2="135" stroke="#c4b5fd" stroke-width="1.5" stroke-dasharray="4,3" marker-end="url(#arr2)"/>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(Q, K, V, output, batch_size, num_q_heads, num_kv_heads, cache_len, head_dim)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>output</code> buffer.</li>
<li><code>num_q_heads</code> is always divisible by <code>num_kv_heads</code>; every group of <code>num_q_heads / num_kv_heads</code> consecutive query heads shares the same KV head.</li>
<li>Use scaled dot-product attention with scale factor <code>1 / sqrt(head_dim)</code> and softmax over the cache dimension.</li>
</ul>
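<p>
The operation above can be sketched in plain Python. This is an illustrative naive reference
(nested lists, no GPU), not a submission — real solutions will want a parallel reduction over the
cache dimension — but it pins down the indexing, the GQA head mapping, and the numerically stable
softmax:
</p>

```python
import math

def decode_attention(Q, K, V, batch_size, num_q_heads, num_kv_heads, cache_len, head_dim):
    # Q: [b][hq][d]; K, V: [b][hkv][t][d] as nested lists. Returns [b][hq][d].
    scale = 1.0 / math.sqrt(head_dim)
    group = num_q_heads // num_kv_heads
    out = [[[0.0] * head_dim for _ in range(num_q_heads)] for _ in range(batch_size)]
    for b in range(batch_size):
        for h in range(num_q_heads):
            kv = h // group  # KV head shared by this group of consecutive query heads
            scores = [
                sum(Q[b][h][d] * K[b][kv][t][d] for d in range(head_dim)) * scale
                for t in range(cache_len)
            ]
            m = max(scores)  # subtract the max before exp for numerical stability
            exps = [math.exp(s - m) for s in scores]
            z = sum(exps)
            for t in range(cache_len):
                w = exps[t] / z  # softmax weight for cache position t
                for d in range(head_dim):
                    out[b][h][d] += w * V[b][kv][t][d]
    return out
```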

<h2>Example</h2>
<p>
With <code>batch_size</code> = 1, <code>num_q_heads</code> = 2, <code>num_kv_heads</code> = 1,
<code>cache_len</code> = 3, <code>head_dim</code> = 4:
</p>
<p>
<strong>Input:</strong><br>
\(Q\) (2&times;4, one row per query head):
\[
\begin{bmatrix}
1 & 0 & 0 & 1 \\
0 & 1 & 0 & 1
\end{bmatrix}
\]
\(K\) (3&times;4, one row per cache position):
\[
\begin{bmatrix}
1 & 0 & 1 & 0 \\
0 & 1 & 0 & 1 \\
1 & 1 & 0 & 0
\end{bmatrix}
\]
\(V\) (3&times;4, one row per cache position):
\[
\begin{bmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
9 & 10 & 11 & 12
\end{bmatrix}
\]
Both query heads attend to the single KV head (GQA group size = <code>num_q_heads / num_kv_heads</code> = 2).
</p>
<p>
<strong>Output</strong> (2&times;4, values rounded to 2 decimal places):<br>
\[
\begin{bmatrix}
5.00 & 6.00 & 7.00 & 8.00 \\
5.48 & 6.48 & 7.48 & 8.48
\end{bmatrix}
\]
Head 0's scaled scores are all equal (0.5, 0.5, 0.5), so the softmax weights are uniform and the output is the mean of the value rows.
Head 1's scaled scores are (0.0, 1.0, 0.5), so the softmax concentrates most of the weight on cache position 1.
</p>
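<p>
The head-1 numbers can be checked directly with a few lines of Python:
</p>

```python
import math

scores = [0.0, 1.0, 0.5]  # head 1's dot products with K, already scaled by 1/sqrt(4)
z = sum(math.exp(s) for s in scores)
weights = [math.exp(s) / z for s in scores]  # softmax over the 3 cache positions
V_rows = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
out = [sum(w * v[d] for w, v in zip(weights, V_rows)) for d in range(4)]
print([round(x, 2) for x in out])  # -> [5.48, 6.48, 7.48, 8.48]
```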

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>batch_size</code> &le; 16</li>
<li>1 &le; <code>num_kv_heads</code> &le; <code>num_q_heads</code> &le; 64</li>
<li><code>num_q_heads</code> is divisible by <code>num_kv_heads</code></li>
<li>1 &le; <code>cache_len</code> &le; 65,536</li>
<li>8 &le; <code>head_dim</code> &le; 256; <code>head_dim</code> is a multiple of 8</li>
<li>All tensor values are <code>float32</code></li>
<li>Performance is measured with <code>batch_size</code> = 4, <code>num_q_heads</code> = 32, <code>num_kv_heads</code> = 8, <code>cache_len</code> = 16,384, <code>head_dim</code> = 128</li>
</ul>
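<p>
A rough working-set estimate for the performance configuration (counting only the dominant
<code>float32</code> K/V cache reads, ignoring Q, output, and any intermediates) suggests the
kernel is memory-bandwidth bound — every decode step must stream the entire cache:
</p>

```python
# Bytes read per decode step for the performance configuration (float32 = 4 bytes).
b, hq, hkv, t, d = 4, 32, 8, 16384, 128
kv_bytes = 2 * b * hkv * t * d * 4  # K and V caches dominate the traffic
q_bytes = b * hq * d * 4            # the queries are tiny by comparison
print(kv_bytes / 2**20, q_bytes)    # -> 512.0 65536  (512 MiB of K+V, 64 KiB of Q)
```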
183 changes: 183 additions & 0 deletions challenges/medium/95_decode_phase_attention/challenge.py
@@ -0,0 +1,183 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="Decode-Phase Attention",
atol=1e-04,
rtol=1e-04,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
batch_size: int,
num_q_heads: int,
num_kv_heads: int,
cache_len: int,
head_dim: int,
):
assert Q.shape == (batch_size, num_q_heads, head_dim)
assert K.shape == (batch_size, num_kv_heads, cache_len, head_dim)
assert V.shape == (batch_size, num_kv_heads, cache_len, head_dim)
assert output.shape == (batch_size, num_q_heads, head_dim)
assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
assert Q.device.type == "cuda"
assert K.device.type == "cuda"
assert V.device.type == "cuda"
assert output.device.type == "cuda"
assert num_q_heads % num_kv_heads == 0

scale = 1.0 / math.sqrt(head_dim)
num_groups = num_q_heads // num_kv_heads

# Expand K and V from (B, Hkv, T, D) to (B, Hq, T, D)
K_exp = K.repeat_interleave(num_groups, dim=1)
V_exp = V.repeat_interleave(num_groups, dim=1)

# scores: (B, Hq, T) = Q(B, Hq, 1, D) @ K^T(B, Hq, D, T) -> squeeze
Q_unsq = Q.unsqueeze(2) # (B, Hq, 1, D)
scores = torch.matmul(Q_unsq, K_exp.transpose(2, 3)).squeeze(2) * scale

# Softmax over the cache dimension
weights = torch.softmax(scores, dim=-1) # (B, Hq, T)

# Weighted sum of values: (B, Hq, T) x (B, Hq, T, D) -> (B, Hq, D)
out = torch.matmul(weights.unsqueeze(2), V_exp).squeeze(2)
output.copy_(out)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
"K": (ctypes.POINTER(ctypes.c_float), "in"),
"V": (ctypes.POINTER(ctypes.c_float), "in"),
"output": (ctypes.POINTER(ctypes.c_float), "out"),
"batch_size": (ctypes.c_int, "in"),
"num_q_heads": (ctypes.c_int, "in"),
"num_kv_heads": (ctypes.c_int, "in"),
"cache_len": (ctypes.c_int, "in"),
"head_dim": (ctypes.c_int, "in"),
}

def _make_test_case(
self,
batch_size,
num_q_heads,
num_kv_heads,
cache_len,
head_dim,
zero_inputs=False,
):
dtype = torch.float32
device = "cuda"
if zero_inputs:
Q = torch.zeros(batch_size, num_q_heads, head_dim, device=device, dtype=dtype)
K = torch.zeros(
batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
)
V = torch.zeros(
batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
)
else:
Q = torch.randn(batch_size, num_q_heads, head_dim, device=device, dtype=dtype)
K = torch.randn(
batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
)
V = torch.randn(
batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
)
output = torch.zeros(batch_size, num_q_heads, head_dim, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"batch_size": batch_size,
"num_q_heads": num_q_heads,
"num_kv_heads": num_kv_heads,
"cache_len": cache_len,
"head_dim": head_dim,
}

def generate_example_test(self) -> Dict[str, Any]:
dtype = torch.float32
device = "cuda"
Q = torch.tensor(
[[[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]]],
device=device,
dtype=dtype,
)
K = torch.tensor(
[[[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]]]],
device=device,
dtype=dtype,
)
V = torch.tensor(
[[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]],
device=device,
dtype=dtype,
)
output = torch.zeros(1, 2, 4, device=device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"batch_size": 1,
"num_q_heads": 2,
"num_kv_heads": 1,
"cache_len": 3,
"head_dim": 4,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
torch.manual_seed(42)
tests = []

# Edge case: single head, single cache position
tests.append(self._make_test_case(1, 1, 1, 1, 8))

# Edge case: zero inputs (softmax of uniform zeros → uniform weights)
tests.append(self._make_test_case(1, 2, 1, 4, 8, zero_inputs=True))

# MQA (num_kv_heads=1): all query heads share one KV head
tests.append(self._make_test_case(2, 4, 1, 16, 16))

# GQA with groups=2, short cache
tests.append(self._make_test_case(2, 4, 2, 2, 8))

# MHA equivalent (num_kv_heads == num_q_heads)
tests.append(self._make_test_case(1, 4, 4, 16, 32))

# Power-of-2 cache length
tests.append(self._make_test_case(2, 8, 2, 64, 32))

# Power-of-2 larger cache
tests.append(self._make_test_case(2, 8, 2, 256, 64))

# Non-power-of-2 cache length
tests.append(self._make_test_case(2, 4, 2, 30, 32))

# Non-power-of-2, larger
tests.append(self._make_test_case(4, 4, 2, 100, 32))

# Realistic small inference: LLaMA-3 8B style heads
tests.append(self._make_test_case(2, 32, 8, 1024, 128))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
torch.manual_seed(0)
# LLaMA-3 8B: 32 Q heads, 8 KV heads, head_dim=128, long context
return self._make_test_case(4, 32, 8, 16384, 128)
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int batch_size,
int num_q_heads, int num_kv_heads, int cache_len, int head_dim) {}
@@ -0,0 +1,18 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K: cute.Tensor,
V: cute.Tensor,
output: cute.Tensor,
batch_size: cute.Int32,
num_q_heads: cute.Int32,
num_kv_heads: cute.Int32,
cache_len: cute.Int32,
head_dim: cute.Int32,
):
pass
@@ -0,0 +1,18 @@
import jax
import jax.numpy as jnp


# Q, K, V are tensors on GPU
@jax.jit
def solve(
Q: jax.Array,
K: jax.Array,
V: jax.Array,
batch_size: int,
num_q_heads: int,
num_kv_heads: int,
cache_len: int,
head_dim: int,
) -> jax.Array:
# return output tensor directly
pass
@@ -0,0 +1,18 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# Q, K, V, output are device pointers
@export
def solve(
Q: UnsafePointer[Float32, MutExternalOrigin],
K: UnsafePointer[Float32, MutExternalOrigin],
V: UnsafePointer[Float32, MutExternalOrigin],
output: UnsafePointer[Float32, MutExternalOrigin],
batch_size: Int32,
num_q_heads: Int32,
num_kv_heads: Int32,
cache_len: Int32,
head_dim: Int32,
) raises:
pass
@@ -0,0 +1,16 @@
import torch


# Q, K, V, output are tensors on the GPU
def solve(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
output: torch.Tensor,
batch_size: int,
num_q_heads: int,
num_kv_heads: int,
cache_len: int,
head_dim: int,
):
pass