diff --git a/challenges/medium/96_int8_kv_cache_attention/challenge.html b/challenges/medium/96_int8_kv_cache_attention/challenge.html
new file mode 100644
index 0000000..a7a9566
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/challenge.html
@@ -0,0 +1,131 @@
+
+Implement decode-phase multi-head attention where the key and value caches are stored as
+int8 with per-token scale factors. This memory layout cuts KV-cache bandwidth to roughly
+a quarter of float32 and is used in production LLM serving systems such as TensorRT-LLM
+and vLLM. Given a query tensor Q for a single new token, int8 key cache
+K_int8, int8 value cache V_int8, and per-token scales
+k_scale and v_scale, dequantize the caches and compute scaled
+dot-product attention to produce output. All non-integer tensors use
+float32.
+
+
+
+
+ INT8 KV-Cache Attention — single token decode
+
+
+
+ Q (fp32)
+
+
+
+ K_int8 (int8)
+
+ k_scale (fp32)
+
+
+
+ V_int8 (int8)
+
+ v_scale (fp32)
+
+
+
+
+ ×
+ ×
+
+
+
+ K_float
+
+ V_float
+
+
+
+
+
+
+ scores
+
+
+ softmax
+
+
+
+ output
+
+
+ K[h,s,:] = K_int8[h,s,:] × k_scale[h,s]
+ V[h,s,:] = V_int8[h,s,:] × v_scale[h,s]
+ scores[h,s] = Q[h,:]·K[h,s,:] / √head_dim
+ w[h,:] = softmax(scores[h,:])
+ out[h,:] = Σ_s w[h,s] · V[h,s,:]
+
+
+
+
+
+
+
+
+Implementation Requirements
+
+ Implement the function solve(Q, K_int8, V_int8, k_scale, v_scale, output, num_heads, seq_len, head_dim).
+ Do not change the function signature or use external libraries beyond the standard GPU frameworks.
+ Write the result into the provided output buffer.
+ Dequantize using per-token scales: K_float[h, s, d] = K_int8[h, s, d] × k_scale[h, s] (and analogously for V).
+ Use scaled dot-product attention with scale factor 1 / sqrt(head_dim) and a softmax over the sequence dimension.
+
+
+Example
+
+ With num_heads = 1, seq_len = 3, head_dim = 4:
+
+
+ Input:
+ \(Q\) (1×4):
+ \[
+ \begin{bmatrix} 1 & 1 & 1 & 1 \end{bmatrix}
+ \]
+ \(K\_int8\) (1×3×4):
+ \[
+ \begin{bmatrix} 10 & 0 & 0 & 0 \\ 0 & 10 & 0 & 0 \\ 0 & 0 & 10 & 0 \end{bmatrix}
+ \]
+ \(k\_scale\) (1×3): \(\begin{bmatrix} 0.1 & 0.1 & 0.1 \end{bmatrix}\)
+ ⇒
+ \(K\_float\) (1×3×4):
+ \[
+ \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix}
+ \]
+ \(V\_int8\) (1×3×4):
+ \[
+ \begin{bmatrix} 10 & 20 & 30 & 40 \\ 50 & 60 & 70 & 80 \\ 90 & 100 & 110 & 120 \end{bmatrix}
+ \]
+ \(v\_scale\) (1×3): \(\begin{bmatrix} 0.1 & 0.1 & 0.1 \end{bmatrix}\)
+ ⇒
+ \(V\_float\) (1×3×4):
+ \[
+ \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}
+ \]
+
+
+ Scores = \(Q \cdot K\_float^T / \sqrt{4}\) = \(\begin{bmatrix} 0.5 & 0.5 & 0.5 \end{bmatrix}\),
+ so softmax weights = \(\begin{bmatrix} 1/3 & 1/3 & 1/3 \end{bmatrix}\).
+
+
+ Output (1×4):
+ \[
+ \begin{bmatrix} 5.00 & 6.00 & 7.00 & 8.00 \end{bmatrix}
+ \]
+
+
+Constraints
+
+ 1 ≤ num_heads ≤ 64
+ 1 ≤ seq_len ≤ 32,768
+ 8 ≤ head_dim ≤ 256; head_dim is a multiple of 8
+ K_int8 and V_int8 values are in \([-128, 127]\)
+ All scale values are positive float32
+ Performance is measured with num_heads = 32, seq_len = 8,192, head_dim = 128
+
diff --git a/challenges/medium/96_int8_kv_cache_attention/challenge.py b/challenges/medium/96_int8_kv_cache_attention/challenge.py
new file mode 100644
index 0000000..f2b0a27
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/challenge.py
@@ -0,0 +1,155 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+    def __init__(self):
+        super().__init__(
+            name="INT8 KV-Cache Attention",
+            atol=1e-03,  # absolute tolerance when comparing a solution against reference_impl
+            rtol=1e-03,  # relative tolerance
+            num_gpus=1,
+            access_tier="free",
+        )
+
+    def reference_impl(
+        self,
+        Q: torch.Tensor,
+        K_int8: torch.Tensor,
+        V_int8: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        output: torch.Tensor,
+        num_heads: int,
+        seq_len: int,
+        head_dim: int,
+    ) -> None:
+        assert Q.shape == (num_heads, head_dim)  # single decode token per head
+        assert K_int8.shape == (num_heads, seq_len, head_dim)
+        assert V_int8.shape == (num_heads, seq_len, head_dim)
+        assert k_scale.shape == (num_heads, seq_len)  # one scale per (head, cached token)
+        assert v_scale.shape == (num_heads, seq_len)
+        assert output.shape == (num_heads, head_dim)
+        assert Q.dtype == torch.float32
+        assert K_int8.dtype == torch.int8
+        assert V_int8.dtype == torch.int8
+        assert k_scale.dtype == torch.float32
+        assert v_scale.dtype == torch.float32
+        assert output.dtype == torch.float32
+        assert Q.device.type == "cuda"
+        assert K_int8.device.type == "cuda"
+        assert V_int8.device.type == "cuda"
+        assert k_scale.device.type == "cuda"
+        assert v_scale.device.type == "cuda"
+        assert output.device.type == "cuda"
+
+        # Dequantize: K_float[h, s, d] = K_int8[h, s, d] * k_scale[h, s]
+        K_float = K_int8.float() * k_scale.unsqueeze(-1)  # [num_heads, seq_len, head_dim]
+        V_float = V_int8.float() * v_scale.unsqueeze(-1)  # [num_heads, seq_len, head_dim]
+
+        # Scaled dot-product attention: Q [num_heads, head_dim] attends to all seq_len positions
+        scale = 1.0 / math.sqrt(head_dim)  # standard 1/sqrt(d_k) attention scaling
+        # scores: [num_heads, 1, seq_len]
+        scores = torch.bmm(Q.unsqueeze(1), K_float.transpose(1, 2)) * scale
+        weights = torch.softmax(scores, dim=-1)  # [num_heads, 1, seq_len]
+
+        # Weighted sum of V: [num_heads, 1, seq_len] @ [num_heads, seq_len, head_dim]
+        out = torch.bmm(weights, V_float)  # [num_heads, 1, head_dim]
+        output.copy_(out.squeeze(1))  # write in place into the caller-provided buffer
+
+    def get_solve_signature(self) -> Dict[str, tuple]:
+        # ctypes signature used to call a compiled `solve` entry point; order matches solve(...)
+        return {
+            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+            "K_int8": (ctypes.POINTER(ctypes.c_int8), "in"),
+            "V_int8": (ctypes.POINTER(ctypes.c_int8), "in"),
+            "k_scale": (ctypes.POINTER(ctypes.c_float), "in"),
+            "v_scale": (ctypes.POINTER(ctypes.c_float), "in"),
+            "output": (ctypes.POINTER(ctypes.c_float), "out"),
+            "num_heads": (ctypes.c_int, "in"),
+            "seq_len": (ctypes.c_int, "in"),
+            "head_dim": (ctypes.c_int, "in"),
+        }
+
+    def _make_test_case(self, num_heads: int, seq_len: int, head_dim: int, zero_q=False, seed=None) -> Dict[str, Any]:
+        device = "cuda"
+        if seed is not None:
+            torch.manual_seed(seed)  # deterministic inputs for reproducible tests
+        if zero_q:
+            Q = torch.zeros(num_heads, head_dim, dtype=torch.float32, device=device)
+        else:
+            Q = torch.randn(num_heads, head_dim, dtype=torch.float32, device=device)
+        K_int8 = torch.randint(
+            -128, 128, (num_heads, seq_len, head_dim), dtype=torch.int8, device=device
+        )  # randint high bound is exclusive: values cover the full int8 range [-128, 127]
+        V_int8 = torch.randint(
+            -128, 128, (num_heads, seq_len, head_dim), dtype=torch.int8, device=device
+        )
+        k_scale = torch.rand(num_heads, seq_len, dtype=torch.float32, device=device) * 0.1 + 0.01  # positive scales in [0.01, 0.11)
+        v_scale = torch.rand(num_heads, seq_len, dtype=torch.float32, device=device) * 0.1 + 0.01
+        output = torch.empty(num_heads, head_dim, dtype=torch.float32, device=device)
+        return {
+            "Q": Q,
+            "K_int8": K_int8,
+            "V_int8": V_int8,
+            "k_scale": k_scale,
+            "v_scale": v_scale,
+            "output": output,
+            "num_heads": num_heads,
+            "seq_len": seq_len,
+            "head_dim": head_dim,
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        # Mirrors the worked example in challenge.html: expected output is [5, 6, 7, 8]
+        device = "cuda"
+        num_heads, seq_len, head_dim = 1, 3, 4
+        Q = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32, device=device)
+        K_int8 = torch.tensor(
+            [[[10, 0, 0, 0], [0, 10, 0, 0], [0, 0, 10, 0]]], dtype=torch.int8, device=device
+        )
+        V_int8 = torch.tensor(
+            [[[10, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]],
+            dtype=torch.int8,
+            device=device,
+        )
+        k_scale = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float32, device=device)
+        v_scale = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float32, device=device)
+        output = torch.empty(num_heads, head_dim, dtype=torch.float32, device=device)
+        return {
+            "Q": Q,
+            "K_int8": K_int8,
+            "V_int8": V_int8,
+            "k_scale": k_scale,
+            "v_scale": v_scale,
+            "output": output,
+            "num_heads": num_heads,
+            "seq_len": seq_len,
+            "head_dim": head_dim,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        tests = []
+        # Edge: single key in cache
+        tests.append(self._make_test_case(1, 1, 8, seed=0))
+        # Edge: two keys
+        tests.append(self._make_test_case(1, 2, 8, seed=1))
+        # Edge: four keys, two heads
+        tests.append(self._make_test_case(2, 4, 8, seed=2))
+        # Zero query (uniform softmax weights)
+        tests.append(self._make_test_case(1, 8, 16, zero_q=True, seed=3))
+        # Power-of-2 seq_len
+        tests.append(self._make_test_case(4, 16, 64, seed=4))
+        tests.append(self._make_test_case(8, 64, 64, seed=5))
+        # Non-power-of-2
+        tests.append(self._make_test_case(2, 30, 64, seed=6))
+        tests.append(self._make_test_case(4, 100, 64, seed=7))
+        # Realistic sizes
+        tests.append(self._make_test_case(16, 512, 64, seed=8))
+        tests.append(self._make_test_case(32, 256, 128, seed=9))
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:
+        return self._make_test_case(32, 8192, 128, seed=42)  # matches the published perf config
diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.cu b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cu
new file mode 100644
index 0000000..a9f7fdb
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cu
@@ -0,0 +1,6 @@
+#include <cuda_runtime.h>
+#include <cstdint>  // int8_t
+// Q, K_int8, V_int8, k_scale, v_scale, output are device pointers
+extern "C" void solve(const float* Q, const int8_t* K_int8, const int8_t* V_int8,
+                      const float* k_scale, const float* v_scale, float* output, int num_heads,
+                      int seq_len, int head_dim) {}
diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.cute.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cute.py
new file mode 100644
index 0000000..dd7e82e
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cute.py
@@ -0,0 +1,18 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU
+@cute.jit
+def solve(
+    Q: cute.Tensor,
+    K_int8: cute.Tensor,
+    V_int8: cute.Tensor,
+    k_scale: cute.Tensor,
+    v_scale: cute.Tensor,
+    output: cute.Tensor,
+    num_heads: cute.Int32,
+    seq_len: cute.Int32,
+    head_dim: cute.Int32,
+):
+    pass  # TODO: implement INT8 KV-cache attention and write the result into `output`
diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.jax.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.jax.py
new file mode 100644
index 0000000..74b609c
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.jax.py
@@ -0,0 +1,18 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K_int8, V_int8, k_scale, v_scale are tensors on GPU
+@jax.jit
+def solve(
+    Q: jax.Array,
+    K_int8: jax.Array,
+    V_int8: jax.Array,
+    k_scale: jax.Array,
+    v_scale: jax.Array,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+) -> jax.Array:
+    # return output tensor directly
+    pass  # TODO: compute and return the (num_heads, head_dim) attention output
diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.mojo b/challenges/medium/96_int8_kv_cache_attention/starter/starter.mojo
new file mode 100644
index 0000000..84f8074
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.mojo
@@ -0,0 +1,18 @@
+from gpu.host import DeviceContext
+from memory import UnsafePointer
+
+
+# Q, K_int8, V_int8, k_scale, v_scale, output are device pointers
+@export
+def solve(
+    Q: UnsafePointer[Float32, MutExternalOrigin],
+    K_int8: UnsafePointer[Int8, MutExternalOrigin],
+    V_int8: UnsafePointer[Int8, MutExternalOrigin],
+    k_scale: UnsafePointer[Float32, MutExternalOrigin],
+    v_scale: UnsafePointer[Float32, MutExternalOrigin],
+    output: UnsafePointer[Float32, MutExternalOrigin],
+    num_heads: Int32,
+    seq_len: Int32,
+    head_dim: Int32,
+) raises:
+    pass
diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.pytorch.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.pytorch.py
new file mode 100644
index 0000000..a27270b
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.pytorch.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K_int8: torch.Tensor,
+    V_int8: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    output: torch.Tensor,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+):
+    pass  # TODO: implement INT8 KV-cache attention and write the result into `output`
diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.triton.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.triton.py
new file mode 100644
index 0000000..d09e48f
--- /dev/null
+++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.triton.py
@@ -0,0 +1,18 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K_int8: torch.Tensor,
+    V_int8: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    output: torch.Tensor,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+):
+    pass  # TODO: launch a Triton kernel that writes the attention result into `output`