diff --git a/challenges/medium/95_decode_phase_attention/challenge.html b/challenges/medium/95_decode_phase_attention/challenge.html
new file mode 100644
index 00000000..3888786d
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/challenge.html
@@ -0,0 +1,133 @@
+
+Implement the attention operation used during the decode phase of autoregressive
+language model inference. At each decode step a single new token's query vectors attend over all
+key-value pairs previously stored in the KV cache. Given query tensor Q of shape
+(batch_size, num_q_heads, head_dim) — one query vector per head with no sequence
+dimension — and cached key/value tensors K, V each of shape
+(batch_size, num_kv_heads, cache_len, head_dim), compute the scaled dot-product
+attention output. Grouped Query Attention (GQA) is supported: every group of
+num_q_heads / num_kv_heads consecutive query heads shares the same key and value head.
+All tensors use float32.
+
+
+
+
+
+
+ Decode-Phase Attention (batch_size=1, num_q_heads=4, num_kv_heads=2)
+
+
+ New token (1 query per head)
+
+ Q[0]
+
+ Q[1]
+
+ Q[2]
+
+ Q[3]
+ group 0
+ group 1
+
+
+ KV cache (cache_len positions)
+
+ K,V[0]
+
+ K,V[1]
+
+ K,V[2]
+ …
+
+ K,V[T-1]
+ KV head 0 (shared by Q[0], Q[1])
+
+
+ K,V[0]
+
+ K,V[1]
+
+ K,V[2]
+ …
+
+ K,V[T-1]
+ KV head 1 (shared by Q[2], Q[3])
+
+
+ scale = 1 / sqrt(head_dim)
+ scores[t] = Q · K_cache[t] × scale
+ weights = softmax(scores) [over all t]
+ output = Σ weights[t] × V_cache[t]
+
+
+
+
+
+
+
+
+
+
+Implementation Requirements
+
+ Implement the function solve(Q, K, V, output, batch_size, num_q_heads, num_kv_heads, cache_len, head_dim).
+ Do not change the function signature or use external libraries beyond the standard GPU frameworks.
+ Write the result into the provided output buffer.
+ num_q_heads is always divisible by num_kv_heads; every group of num_q_heads / num_kv_heads consecutive query heads shares the same KV head.
+ Use scaled dot-product attention with scale factor 1 / sqrt(head_dim) and softmax over the cache dimension.
+
+
+Example
+
+ With batch_size = 1, num_q_heads = 2, num_kv_heads = 1,
+ cache_len = 3, head_dim = 4:
+
+
+ Input:
+ \(Q\) (2×4, one row per query head):
+ \[
+ \begin{bmatrix}
+ 1 & 0 & 0 & 1 \\
+ 0 & 1 & 0 & 1
+ \end{bmatrix}
+ \]
+ \(K\) (3×4, one row per cache position):
+ \[
+ \begin{bmatrix}
+ 1 & 0 & 1 & 0 \\
+ 0 & 1 & 0 & 1 \\
+ 1 & 1 & 0 & 0
+ \end{bmatrix}
+ \]
+ \(V\) (3×4, one row per cache position):
+ \[
+ \begin{bmatrix}
+ 1 & 2 & 3 & 4 \\
+ 5 & 6 & 7 & 8 \\
+ 9 & 10 & 11 & 12
+ \end{bmatrix}
+ \]
+    Both query heads attend to the single KV head (GQA group size = num_q_heads / num_kv_heads = 2).
+
+
+ Output (2×4, values rounded to 2 decimal places):
+ \[
+ \begin{bmatrix}
+ 5.00 & 6.00 & 7.00 & 8.00 \\
+ 5.48 & 6.48 & 7.48 & 8.48
+ \end{bmatrix}
+ \]
+ Head 0 receives equal scores (0.5, 0.5, 0.5) → uniform weights → mean of value rows.
+ Head 1 receives scores (0.0, 1.0, 0.5) → softmax concentrates weight on position 1.
+
+
+Constraints
+
+ 1 ≤ batch_size ≤ 16
+ 1 ≤ num_kv_heads ≤ num_q_heads ≤ 64
+ num_q_heads is divisible by num_kv_heads
+ 1 ≤ cache_len ≤ 65,536
+ 8 ≤ head_dim ≤ 256; head_dim is a multiple of 8
+ All tensor values are float32
+ Performance is measured with batch_size = 4, num_q_heads = 32, num_kv_heads = 8, cache_len = 16,384, head_dim = 128
+
diff --git a/challenges/medium/95_decode_phase_attention/challenge.py b/challenges/medium/95_decode_phase_attention/challenge.py
new file mode 100644
index 00000000..df7f34a4
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/challenge.py
@@ -0,0 +1,183 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+    """Decode-phase (single new token) attention over a KV cache, with GQA.
+
+    Q holds one query vector per head: (batch_size, num_q_heads, head_dim).
+    K and V are the cached keys/values: (batch_size, num_kv_heads, cache_len,
+    head_dim). Every group of num_q_heads // num_kv_heads consecutive query
+    heads shares one KV head. All tensors are float32 on CUDA.
+    """
+
+    def __init__(self):
+        # atol/rtol are the float32 comparison tolerances; presumably the
+        # harness compares solutions against reference_impl — see base class.
+        super().__init__(
+            name="Decode-Phase Attention",
+            atol=1e-04,
+            rtol=1e-04,
+            num_gpus=1,
+            access_tier="free",
+        )
+
+    def reference_impl(
+        self,
+        Q: torch.Tensor,
+        K: torch.Tensor,
+        V: torch.Tensor,
+        output: torch.Tensor,
+        batch_size: int,
+        num_q_heads: int,
+        num_kv_heads: int,
+        cache_len: int,
+        head_dim: int,
+    ):
+        """Reference: scaled dot-product attention of Q over the KV cache.
+
+        scores = Q @ K^T / sqrt(head_dim), softmax over the cache dimension,
+        then a weighted sum of V; the result is copied into `output` in
+        place. GQA is realized with repeat_interleave so that consecutive
+        query heads map to the same KV head.
+        """
+        assert Q.shape == (batch_size, num_q_heads, head_dim)
+        assert K.shape == (batch_size, num_kv_heads, cache_len, head_dim)
+        assert V.shape == (batch_size, num_kv_heads, cache_len, head_dim)
+        assert output.shape == (batch_size, num_q_heads, head_dim)
+        assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
+        assert Q.device.type == "cuda"
+        assert K.device.type == "cuda"
+        assert V.device.type == "cuda"
+        assert output.device.type == "cuda"
+        assert num_q_heads % num_kv_heads == 0
+
+        scale = 1.0 / math.sqrt(head_dim)
+        num_groups = num_q_heads // num_kv_heads
+
+        # Expand K and V from (B, Hkv, T, D) to (B, Hq, T, D)
+        K_exp = K.repeat_interleave(num_groups, dim=1)
+        V_exp = V.repeat_interleave(num_groups, dim=1)
+
+        # scores: (B, Hq, T) = Q(B, Hq, 1, D) @ K^T(B, Hq, D, T) -> squeeze
+        Q_unsq = Q.unsqueeze(2)  # (B, Hq, 1, D)
+        scores = torch.matmul(Q_unsq, K_exp.transpose(2, 3)).squeeze(2) * scale
+
+        # Softmax over the cache dimension
+        weights = torch.softmax(scores, dim=-1)  # (B, Hq, T)
+
+        # Weighted sum of values: (B, Hq, T) x (B, Hq, T, D) -> (B, Hq, D)
+        out = torch.matmul(weights.unsqueeze(2), V_exp).squeeze(2)
+        output.copy_(out)
+
+    def get_solve_signature(self) -> Dict[str, tuple]:
+        """ctypes signature of solve(): float32 pointers plus int scalar dims."""
+        return {
+            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+            "K": (ctypes.POINTER(ctypes.c_float), "in"),
+            "V": (ctypes.POINTER(ctypes.c_float), "in"),
+            "output": (ctypes.POINTER(ctypes.c_float), "out"),
+            "batch_size": (ctypes.c_int, "in"),
+            "num_q_heads": (ctypes.c_int, "in"),
+            "num_kv_heads": (ctypes.c_int, "in"),
+            "cache_len": (ctypes.c_int, "in"),
+            "head_dim": (ctypes.c_int, "in"),
+        }
+
+    def _make_test_case(
+        self,
+        batch_size,
+        num_q_heads,
+        num_kv_heads,
+        cache_len,
+        head_dim,
+        zero_inputs=False,
+    ):
+        """Build one CUDA float32 test case (random or all-zero inputs).
+
+        zero_inputs=True gives all-zero Q/K/V: every score is equal, so the
+        softmax weights are uniform — an edge case for the softmax path.
+        The output buffer is always zero-initialized.
+        """
+        dtype = torch.float32
+        device = "cuda"
+        if zero_inputs:
+            Q = torch.zeros(batch_size, num_q_heads, head_dim, device=device, dtype=dtype)
+            K = torch.zeros(
+                batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
+            )
+            V = torch.zeros(
+                batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
+            )
+        else:
+            Q = torch.randn(batch_size, num_q_heads, head_dim, device=device, dtype=dtype)
+            K = torch.randn(
+                batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
+            )
+            V = torch.randn(
+                batch_size, num_kv_heads, cache_len, head_dim, device=device, dtype=dtype
+            )
+        output = torch.zeros(batch_size, num_q_heads, head_dim, device=device, dtype=dtype)
+        return {
+            "Q": Q,
+            "K": K,
+            "V": V,
+            "output": output,
+            "batch_size": batch_size,
+            "num_q_heads": num_q_heads,
+            "num_kv_heads": num_kv_heads,
+            "cache_len": cache_len,
+            "head_dim": head_dim,
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        """Hand-built case matching the challenge text: B=1, 2 Q heads, 1 KV head."""
+        dtype = torch.float32
+        device = "cuda"
+        Q = torch.tensor(
+            [[[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]]],
+            device=device,
+            dtype=dtype,
+        )
+        K = torch.tensor(
+            [[[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]]]],
+            device=device,
+            dtype=dtype,
+        )
+        V = torch.tensor(
+            [[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]]],
+            device=device,
+            dtype=dtype,
+        )
+        output = torch.zeros(1, 2, 4, device=device, dtype=dtype)
+        return {
+            "Q": Q,
+            "K": K,
+            "V": V,
+            "output": output,
+            "batch_size": 1,
+            "num_q_heads": 2,
+            "num_kv_heads": 1,
+            "cache_len": 3,
+            "head_dim": 4,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        """Functional sweep: edge cases, MQA/GQA/MHA layouts, varied cache lengths."""
+        torch.manual_seed(42)
+        tests = []
+
+        # Edge case: single head, single cache position
+        tests.append(self._make_test_case(1, 1, 1, 1, 8))
+
+        # Edge case: zero inputs (softmax of uniform zeros → uniform weights)
+        tests.append(self._make_test_case(1, 2, 1, 4, 8, zero_inputs=True))
+
+        # MQA (num_kv_heads=1): all query heads share one KV head
+        tests.append(self._make_test_case(2, 4, 1, 16, 16))
+
+        # GQA with groups=2, short cache
+        tests.append(self._make_test_case(2, 4, 2, 2, 8))
+
+        # MHA equivalent (num_kv_heads == num_q_heads)
+        tests.append(self._make_test_case(1, 4, 4, 16, 32))
+
+        # Power-of-2 cache length
+        tests.append(self._make_test_case(2, 8, 2, 64, 32))
+
+        # Power-of-2 larger cache
+        tests.append(self._make_test_case(2, 8, 2, 256, 64))
+
+        # Non-power-of-2 cache length
+        tests.append(self._make_test_case(2, 4, 2, 30, 32))
+
+        # Non-power-of-2, larger
+        tests.append(self._make_test_case(4, 4, 2, 100, 32))
+
+        # Realistic small inference: LLaMA-3 8B style heads
+        tests.append(self._make_test_case(2, 32, 8, 1024, 128))
+
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:
+        """Timed config: LLaMA-3-8B-style heads over a 16K-token cache."""
+        torch.manual_seed(0)
+        # LLaMA-3 8B: 32 Q heads, 8 KV heads, head_dim=128, long context
+        return self._make_test_case(4, 32, 8, 16384, 128)
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.cu b/challenges/medium/95_decode_phase_attention/starter/starter.cu
new file mode 100644
index 00000000..c0e80d5d
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/starter/starter.cu
@@ -0,0 +1,5 @@
+#include <cuda_runtime.h>
+
+// Q, K, V, output are device pointers
+// Starter stub: compute decode-phase attention — for each query head,
+// softmax(Q·K^T / sqrt(head_dim)) over the cache, times V (GQA: consecutive
+// query-head groups share a KV head) — and write the result into `output`.
+extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int batch_size,
+                      int num_q_heads, int num_kv_heads, int cache_len, int head_dim) {}
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.cute.py b/challenges/medium/95_decode_phase_attention/starter/starter.cute.py
new file mode 100644
index 00000000..3cd2cb5c
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/starter/starter.cute.py
@@ -0,0 +1,18 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K, V, output are tensors on the GPU
+@cute.jit
+def solve(
+    Q: cute.Tensor,
+    K: cute.Tensor,
+    V: cute.Tensor,
+    output: cute.Tensor,
+    batch_size: cute.Int32,
+    num_q_heads: cute.Int32,
+    num_kv_heads: cute.Int32,
+    cache_len: cute.Int32,
+    head_dim: cute.Int32,
+):
+    # Starter stub: write decode-phase attention —
+    # softmax(Q·K^T / sqrt(head_dim)) · V with GQA head grouping — into `output`.
+    pass
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.jax.py b/challenges/medium/95_decode_phase_attention/starter/starter.jax.py
new file mode 100644
index 00000000..7970bce1
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/starter/starter.jax.py
@@ -0,0 +1,18 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K, V are tensors on GPU
+@jax.jit
+def solve(
+    Q: jax.Array,
+    K: jax.Array,
+    V: jax.Array,
+    batch_size: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    cache_len: int,
+    head_dim: int,
+) -> jax.Array:
+    # return output tensor directly
+    # Starter stub: unlike the other starters there is no output buffer —
+    # return the (batch_size, num_q_heads, head_dim) attention result.
+    pass
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.mojo b/challenges/medium/95_decode_phase_attention/starter/starter.mojo
new file mode 100644
index 00000000..8a52b3b7
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/starter/starter.mojo
@@ -0,0 +1,18 @@
+from std.gpu.host import DeviceContext
+from std.memory import UnsafePointer
+
+
+# Q, K, V, output are device pointers
+@export
+def solve(
+    Q: UnsafePointer[Float32, MutExternalOrigin],
+    K: UnsafePointer[Float32, MutExternalOrigin],
+    V: UnsafePointer[Float32, MutExternalOrigin],
+    output: UnsafePointer[Float32, MutExternalOrigin],
+    batch_size: Int32,
+    num_q_heads: Int32,
+    num_kv_heads: Int32,
+    cache_len: Int32,
+    head_dim: Int32,
+) raises:
+    # Starter stub: compute decode-phase attention (GQA) and write the
+    # batch_size × num_q_heads × head_dim float32 result through `output`.
+    pass
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py b/challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py
new file mode 100644
index 00000000..8d9b5b28
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/starter/starter.pytorch.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    output: torch.Tensor,
+    batch_size: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    cache_len: int,
+    head_dim: int,
+):
+    # Starter stub: write decode-phase attention —
+    # softmax(Q·K^T / sqrt(head_dim)) · V with GQA head grouping — into `output`.
+    pass
diff --git a/challenges/medium/95_decode_phase_attention/starter/starter.triton.py b/challenges/medium/95_decode_phase_attention/starter/starter.triton.py
new file mode 100644
index 00000000..52038ffc
--- /dev/null
+++ b/challenges/medium/95_decode_phase_attention/starter/starter.triton.py
@@ -0,0 +1,18 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K, V, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    output: torch.Tensor,
+    batch_size: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    cache_len: int,
+    head_dim: int,
+):
+    # Starter stub: launch Triton kernel(s) computing decode-phase attention
+    # (softmax over cache_len, GQA head grouping) and fill `output` in place.
+    pass