From 422420fedd9c333940afa6767195b96c87cb3360 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 04:59:07 +0000 Subject: [PATCH] Add challenge 95: Decode-Phase Attention (Medium) Single-token-query attention over a full KV cache, the dominant kernel in autoregressive LLM decode steps. Supports Grouped Query Attention (GQA) where multiple query heads share one KV head. Teaches the memory-bandwidth- bound nature of decode-phase workloads, distinct from compute-bound training attention. Co-Authored-By: Claude Sonnet 4.6 --- .../95_decode_phase_attention/challenge.html | 133 +++++++++++++ .../95_decode_phase_attention/challenge.py | 183 ++++++++++++++++++ .../starter/starter.cu | 5 + .../starter/starter.cute.py | 18 ++ .../starter/starter.jax.py | 18 ++ .../starter/starter.mojo | 18 ++ .../starter/starter.pytorch.py | 16 ++ .../starter/starter.triton.py | 18 ++ 8 files changed, 409 insertions(+) create mode 100644 challenges/medium/95_decode_phase_attention/challenge.html create mode 100644 challenges/medium/95_decode_phase_attention/challenge.py create mode 100644 challenges/medium/95_decode_phase_attention/starter/starter.cu create mode 100644 challenges/medium/95_decode_phase_attention/starter/starter.cute.py create mode 100644 challenges/medium/95_decode_phase_attention/starter/starter.jax.py create mode 100644 challenges/medium/95_decode_phase_attention/starter/starter.mojo create mode 100644 challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py create mode 100644 challenges/medium/95_decode_phase_attention/starter/starter.triton.py diff --git a/challenges/medium/95_decode_phase_attention/challenge.html b/challenges/medium/95_decode_phase_attention/challenge.html new file mode 100644 index 00000000..3888786d --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/challenge.html @@ -0,0 +1,133 @@ +

+Implement the attention operation used during the decode phase of autoregressive language model inference. At each decode step a single new token's query vectors attend over all key-value pairs previously stored in the KV cache. Given query tensor Q of shape (batch_size, num_q_heads, head_dim) — one query vector per head with no sequence dimension — and cached key/value tensors K, V each of shape (batch_size, num_kv_heads, cache_len, head_dim), compute the scaled dot-product attention output of shape (batch_size, num_q_heads, head_dim). Grouped Query Attention (GQA) is supported: every group of num_q_heads / num_kv_heads consecutive query heads shares the same key and value head. All tensors use float32.

+ + + + + + Decode-Phase Attention (batch_size=1, num_q_heads=4, num_kv_heads=2) + + + New token (1 query per head) + + Q[0] + + Q[1] + + Q[2] + + Q[3] + group 0 + group 1 + + + KV cache (cache_len positions) + + K,V[0] + + K,V[1] + + K,V[2] + + + K,V[T-1] + KV head 0 (shared by Q[0], Q[1]) + + + K,V[0] + + K,V[1] + + K,V[2] + + + K,V[T-1] + KV head 1 (shared by Q[2], Q[3]) + + + scale = 1 / sqrt(head_dim) + scores[t] = Q · K_cache[t] × scale + weights = softmax(scores) [over all t] + output = Σ weights[t] × V_cache[t] + + + + + + + + + + +

Implementation Requirements

+ + +

Example

+

+ With batch_size = 1, num_q_heads = 2, num_kv_heads = 1, + cache_len = 3, head_dim = 4: +

+

+ Input:
+ \(Q\) (2×4, one row per query head): + \[ + \begin{bmatrix} + 1 & 0 & 0 & 1 \\ + 0 & 1 & 0 & 1 + \end{bmatrix} + \] + \(K\) (3×4, one row per cache position): + \[ + \begin{bmatrix} + 1 & 0 & 1 & 0 \\ + 0 & 1 & 0 & 1 \\ + 1 & 1 & 0 & 0 + \end{bmatrix} + \] + \(V\) (3×4, one row per cache position): + \[ + \begin{bmatrix} + 1 & 2 & 3 & 4 \\ + 5 & 6 & 7 & 8 \\ + 9 & 10 & 11 & 12 + \end{bmatrix} + \] + Both query heads attend to the single KV head (GQA groups = 2). +

+

+ Output (2×4, values rounded to 2 decimal places):
+ \[ + \begin{bmatrix} + 5.00 & 6.00 & 7.00 & 8.00 \\ + 5.48 & 6.48 & 7.48 & 8.48 + \end{bmatrix} + \] + Head 0 receives equal scores (0.5, 0.5, 0.5) → uniform weights → mean of value rows. + Head 1 receives scores (0.0, 1.0, 0.5) → softmax concentrates weight on position 1. +

+ +

Constraints

+ diff --git a/challenges/medium/95_decode_phase_attention/challenge.py b/challenges/medium/95_decode_phase_attention/challenge.py new file mode 100644 index 00000000..df7f34a4 --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/challenge.py @@ -0,0 +1,183 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="Decode-Phase Attention", + atol=1e-04, + rtol=1e-04, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + batch_size: int, + num_q_heads: int, + num_kv_heads: int, + cache_len: int, + head_dim: int, + ): + assert Q.shape == (batch_size, num_q_heads, head_dim) + assert K.shape == (batch_size, num_kv_heads, cache_len, head_dim) + assert V.shape == (batch_size, num_kv_heads, cache_len, head_dim) + assert output.shape == (batch_size, num_q_heads, head_dim) + assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K.device.type == "cuda" + assert V.device.type == "cuda" + assert output.device.type == "cuda" + assert num_q_heads % num_kv_heads == 0 + + scale = 1.0 / math.sqrt(head_dim) + num_groups = num_q_heads // num_kv_heads + + # Expand K and V from (B, Hkv, T, D) to (B, Hq, T, D) + K_exp = K.repeat_interleave(num_groups, dim=1) + V_exp = V.repeat_interleave(num_groups, dim=1) + + # scores: (B, Hq, T) = Q(B, Hq, 1, D) @ K^T(B, Hq, D, T) -> squeeze + Q_unsq = Q.unsqueeze(2) # (B, Hq, 1, D) + scores = torch.matmul(Q_unsq, K_exp.transpose(2, 3)).squeeze(2) * scale + + # Softmax over the cache dimension + weights = torch.softmax(scores, dim=-1) # (B, Hq, T) + + # Weighted sum of values: (B, Hq, T) x (B, Hq, T, D) -> (B, Hq, D) + out = torch.matmul(weights.unsqueeze(2), V_exp).squeeze(2) + output.copy_(out) + + def get_solve_signature(self) 
-> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K": (ctypes.POINTER(ctypes.c_float), "in"), + "V": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "batch_size": (ctypes.c_int, "in"), + "num_q_heads": (ctypes.c_int, "in"), + "num_kv_heads": (ctypes.c_int, "in"), + "cache_len": (ctypes.c_int, "in"), + "head_dim": (ctypes.c_int, "in"), + } + + def _make_test_case( + self, + batch_size, + num_q_heads, + num_kv_heads, + cache_len, + head_dim, + zero_inputs=False, + ): + dtype = torch.float32 + device = "cuda" + if zero_inputs: + Q = torch.zeros(batch_size, num_q_heads, head_dim, device=device, dtype=dtype) + K = torch.zeros( + batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype + ) + V = torch.zeros( + batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype + ) + else: + Q = torch.randn(batch_size, num_q_heads, head_dim, device=device, dtype=dtype) + K = torch.randn( + batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype + ) + V = torch.randn( + batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype + ) + output = torch.zeros(batch_size, num_q_heads, head_dim, device=device, dtype=dtype) + return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "batch_size": batch_size, + "num_q_heads": num_q_heads, + "num_kv_heads": num_kv_heads, + "cache_len": cache_len, + "head_dim": head_dim, + } + + def generate_example_test(self) -> Dict[str, Any]: + dtype = torch.float32 + device = "cuda" + Q = torch.tensor( + [[[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]]], + device=device, + dtype=dtype, + ) + K = torch.tensor( + [[[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]]]], + device=device, + dtype=dtype, + ) + V = torch.tensor( + [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]], + device=device, + dtype=dtype, + ) + output = torch.zeros(1, 2, 4, device=device, dtype=dtype) + 
return { + "Q": Q, + "K": K, + "V": V, + "output": output, + "batch_size": 1, + "num_q_heads": 2, + "num_kv_heads": 1, + "cache_len": 3, + "head_dim": 4, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge case: single head, single cache position + tests.append(self._make_test_case(1, 1, 1, 1, 8)) + + # Edge case: zero inputs (softmax of uniform zeros → uniform weights) + tests.append(self._make_test_case(1, 2, 1, 4, 8, zero_inputs=True)) + + # MQA (num_kv_heads=1): all query heads share one KV head + tests.append(self._make_test_case(2, 4, 1, 16, 16)) + + # GQA with groups=2, short cache + tests.append(self._make_test_case(2, 4, 2, 2, 8)) + + # MHA equivalent (num_kv_heads == num_q_heads) + tests.append(self._make_test_case(1, 4, 4, 16, 32)) + + # Power-of-2 cache length + tests.append(self._make_test_case(2, 8, 2, 64, 32)) + + # Power-of-2 larger cache + tests.append(self._make_test_case(2, 8, 2, 256, 64)) + + # Non-power-of-2 cache length + tests.append(self._make_test_case(2, 4, 2, 30, 32)) + + # Non-power-of-2, larger + tests.append(self._make_test_case(4, 4, 2, 100, 32)) + + # Realistic small inference: LLaMA-3 8B style heads + tests.append(self._make_test_case(2, 32, 8, 1024, 128)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # LLaMA-3 8B: 32 Q heads, 8 KV heads, head_dim=128, long context + return self._make_test_case(4, 32, 8, 16384, 128) diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.cu b/challenges/medium/95_decode_phase_attention/starter/starter.cu new file mode 100644 index 00000000..c0e80d5d --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// Q, K, V, output are device pointers +extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int batch_size, + int num_q_heads, int num_kv_heads, int cache_len, int head_dim) {} 
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.cute.py b/challenges/medium/95_decode_phase_attention/starter/starter.cute.py new file mode 100644 index 00000000..3cd2cb5c --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/starter/starter.cute.py @@ -0,0 +1,18 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K, V, output are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + output: cute.Tensor, + batch_size: cute.Int32, + num_q_heads: cute.Int32, + num_kv_heads: cute.Int32, + cache_len: cute.Int32, + head_dim: cute.Int32, +): + pass diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.jax.py b/challenges/medium/95_decode_phase_attention/starter/starter.jax.py new file mode 100644 index 00000000..7970bce1 --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/starter/starter.jax.py @@ -0,0 +1,18 @@ +import jax +import jax.numpy as jnp + + +# Q, K, V are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, + K: jax.Array, + V: jax.Array, + batch_size: int, + num_q_heads: int, + num_kv_heads: int, + cache_len: int, + head_dim: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.mojo b/challenges/medium/95_decode_phase_attention/starter/starter.mojo new file mode 100644 index 00000000..8a52b3b7 --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/starter/starter.mojo @@ -0,0 +1,18 @@ +from std.gpu.host import DeviceContext +from std.memory import UnsafePointer + + +# Q, K, V, output are device pointers +@export +def solve( + Q: UnsafePointer[Float32, MutExternalOrigin], + K: UnsafePointer[Float32, MutExternalOrigin], + V: UnsafePointer[Float32, MutExternalOrigin], + output: UnsafePointer[Float32, MutExternalOrigin], + batch_size: Int32, + num_q_heads: Int32, + num_kv_heads: Int32, + cache_len: Int32, + head_dim: Int32, +) raises: + pass diff --git 
a/challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py b/challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py new file mode 100644 index 00000000..8d9b5b28 --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py @@ -0,0 +1,16 @@ +import torch + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + batch_size: int, + num_q_heads: int, + num_kv_heads: int, + cache_len: int, + head_dim: int, +): + pass diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.triton.py b/challenges/medium/95_decode_phase_attention/starter/starter.triton.py new file mode 100644 index 00000000..52038ffc --- /dev/null +++ b/challenges/medium/95_decode_phase_attention/starter/starter.triton.py @@ -0,0 +1,18 @@ +import torch +import triton +import triton.language as tl + + +# Q, K, V, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + output: torch.Tensor, + batch_size: int, + num_q_heads: int, + num_kv_heads: int, + cache_len: int, + head_dim: int, +): + pass