diff --git a/challenges/medium/96_int8_kv_cache_attention/challenge.html b/challenges/medium/96_int8_kv_cache_attention/challenge.html new file mode 100644 index 0000000..a7a9566 --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/challenge.html @@ -0,0 +1,131 @@ +

+Implement decode-phase multi-head attention where the key and value caches are stored as +int8 with per-token scale factors. This memory layout halves KV-cache bandwidth +versus float32 and is used in production LLM serving systems such as TensorRT-LLM +and vLLM. Given a query tensor Q for a single new token, int8 key cache +K_int8, int8 value cache V_int8, and per-token scales +k_scale and v_scale, dequantize the caches and compute scaled +dot-product attention to produce output. All non-integer tensors use +float32. +

+ + + + INT8 KV-Cache Attention — single token decode + + + + Q (fp32) + + + + K_int8 (int8) + + k_scale (fp32) + + + + V_int8 (int8) + + v_scale (fp32) + + + + + × + × + + + + K_float + + V_float + + + + + + + scores + + + softmax + + + + output + + + K[h,s,:] = K_int8[h,s,:] × k_scale[h,s] + V[h,s,:] = V_int8[h,s,:] × v_scale[h,s] + scores[h,s] = Q[h,:]·K[h,s,:] / √head_dim + w[h,:] = softmax(scores[h,:]) + out[h,:] = Σ_s w[h,s] · V[h,s,:] + + + + + + + + +

Implementation Requirements

+ + +

Example

+

+ With num_heads = 1, seq_len = 3, head_dim = 4: +

+

+ Input:
+ \(Q\) (1×4): + \[ + \begin{bmatrix} 1 & 1 & 1 & 1 \end{bmatrix} + \] + \(K\_int8\) (1×3×4): + \[ + \begin{bmatrix} 10 & 0 & 0 & 0 \\ 0 & 10 & 0 & 0 \\ 0 & 0 & 10 & 0 \end{bmatrix} + \] + \(k\_scale\) (1×3): \(\begin{bmatrix} 0.1 & 0.1 & 0.1 \end{bmatrix}\) +  ⇒  + \(K\_float\) (1×3×4): + \[ + \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix} + \] + \(V\_int8\) (1×3×4): + \[ + \begin{bmatrix} 10 & 20 & 30 & 40 \\ 50 & 60 & 70 & 80 \\ 90 & 100 & 110 & 120 \end{bmatrix} + \] + \(v\_scale\) (1×3): \(\begin{bmatrix} 0.1 & 0.1 & 0.1 \end{bmatrix}\) +  ⇒  + \(V\_float\) (1×3×4): + \[ + \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} + \] +

+

+ Scores = \(Q \cdot K\_float^T / \sqrt{4}\) = \(\begin{bmatrix} 0.5 & 0.5 & 0.5 \end{bmatrix}\), + so softmax weights = \(\begin{bmatrix} 1/3 & 1/3 & 1/3 \end{bmatrix}\). +

+

+ Output (1×4): + \[ + \begin{bmatrix} 5.00 & 6.00 & 7.00 & 8.00 \end{bmatrix} + \] +

+ +

Constraints

+ diff --git a/challenges/medium/96_int8_kv_cache_attention/challenge.py b/challenges/medium/96_int8_kv_cache_attention/challenge.py new file mode 100644 index 0000000..f2b0a27 --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/challenge.py @@ -0,0 +1,155 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="INT8 KV-Cache Attention", + atol=1e-03, + rtol=1e-03, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K_int8: torch.Tensor, + V_int8: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, + ): + assert Q.shape == (num_heads, head_dim) + assert K_int8.shape == (num_heads, seq_len, head_dim) + assert V_int8.shape == (num_heads, seq_len, head_dim) + assert k_scale.shape == (num_heads, seq_len) + assert v_scale.shape == (num_heads, seq_len) + assert output.shape == (num_heads, head_dim) + assert Q.dtype == torch.float32 + assert K_int8.dtype == torch.int8 + assert V_int8.dtype == torch.int8 + assert k_scale.dtype == torch.float32 + assert v_scale.dtype == torch.float32 + assert output.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K_int8.device.type == "cuda" + assert V_int8.device.type == "cuda" + assert k_scale.device.type == "cuda" + assert v_scale.device.type == "cuda" + assert output.device.type == "cuda" + + # Dequantize: K_float[h, s, d] = K_int8[h, s, d] * k_scale[h, s] + K_float = K_int8.float() * k_scale.unsqueeze(-1) # [num_heads, seq_len, head_dim] + V_float = V_int8.float() * v_scale.unsqueeze(-1) # [num_heads, seq_len, head_dim] + + # Scaled dot-product attention: Q [num_heads, head_dim] attends to all seq_len positions + scale = 1.0 / math.sqrt(head_dim) + # scores: [num_heads, 1, seq_len] + scores = torch.bmm(Q.unsqueeze(1), 
K_float.transpose(1, 2)) * scale + weights = torch.softmax(scores, dim=-1) # [num_heads, 1, seq_len] + + # Weighted sum of V: [num_heads, 1, seq_len] @ [num_heads, seq_len, head_dim] + out = torch.bmm(weights, V_float) # [num_heads, 1, head_dim] + output.copy_(out.squeeze(1)) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K_int8": (ctypes.POINTER(ctypes.c_int8), "in"), + "V_int8": (ctypes.POINTER(ctypes.c_int8), "in"), + "k_scale": (ctypes.POINTER(ctypes.c_float), "in"), + "v_scale": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "num_heads": (ctypes.c_int, "in"), + "seq_len": (ctypes.c_int, "in"), + "head_dim": (ctypes.c_int, "in"), + } + + def _make_test_case(self, num_heads, seq_len, head_dim, zero_q=False, seed=None): + device = "cuda" + if seed is not None: + torch.manual_seed(seed) + if zero_q: + Q = torch.zeros(num_heads, head_dim, dtype=torch.float32, device=device) + else: + Q = torch.randn(num_heads, head_dim, dtype=torch.float32, device=device) + K_int8 = torch.randint( + -128, 128, (num_heads, seq_len, head_dim), dtype=torch.int8, device=device + ) + V_int8 = torch.randint( + -128, 128, (num_heads, seq_len, head_dim), dtype=torch.int8, device=device + ) + k_scale = torch.rand(num_heads, seq_len, dtype=torch.float32, device=device) * 0.1 + 0.01 + v_scale = torch.rand(num_heads, seq_len, dtype=torch.float32, device=device) * 0.1 + 0.01 + output = torch.empty(num_heads, head_dim, dtype=torch.float32, device=device) + return { + "Q": Q, + "K_int8": K_int8, + "V_int8": V_int8, + "k_scale": k_scale, + "v_scale": v_scale, + "output": output, + "num_heads": num_heads, + "seq_len": seq_len, + "head_dim": head_dim, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + num_heads, seq_len, head_dim = 1, 3, 4 + Q = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32, device=device) + K_int8 = torch.tensor( + [[[10, 0, 0, 
0], [0, 10, 0, 0], [0, 0, 10, 0]]], dtype=torch.int8, device=device + ) + V_int8 = torch.tensor( + [[[10, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]], + dtype=torch.int8, + device=device, + ) + k_scale = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float32, device=device) + v_scale = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float32, device=device) + output = torch.empty(num_heads, head_dim, dtype=torch.float32, device=device) + return { + "Q": Q, + "K_int8": K_int8, + "V_int8": V_int8, + "k_scale": k_scale, + "v_scale": v_scale, + "output": output, + "num_heads": num_heads, + "seq_len": seq_len, + "head_dim": head_dim, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + tests = [] + # Edge: single key in cache + tests.append(self._make_test_case(1, 1, 8, seed=0)) + # Edge: two keys + tests.append(self._make_test_case(1, 2, 8, seed=1)) + # Edge: four keys, two heads + tests.append(self._make_test_case(2, 4, 8, seed=2)) + # Zero query (uniform softmax weights) + tests.append(self._make_test_case(1, 8, 16, zero_q=True, seed=3)) + # Power-of-2 seq_len + tests.append(self._make_test_case(4, 16, 64, seed=4)) + tests.append(self._make_test_case(8, 64, 64, seed=5)) + # Non-power-of-2 + tests.append(self._make_test_case(2, 30, 64, seed=6)) + tests.append(self._make_test_case(4, 100, 64, seed=7)) + # Realistic sizes + tests.append(self._make_test_case(16, 512, 64, seed=8)) + tests.append(self._make_test_case(32, 256, 128, seed=9)) + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + return self._make_test_case(32, 8192, 128, seed=42) diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.cu b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cu new file mode 100644 index 0000000..a9f7fdb --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cu @@ -0,0 +1,6 @@ +#include <cuda_runtime.h> + +// Q, K_int8, V_int8, k_scale, v_scale, output are device pointers +extern "C" void solve(const
float* Q, const int8_t* K_int8, const int8_t* V_int8, + const float* k_scale, const float* v_scale, float* output, int num_heads, + int seq_len, int head_dim) {} diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.cute.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cute.py new file mode 100644 index 0000000..dd7e82e --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.cute.py @@ -0,0 +1,18 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K_int8: cute.Tensor, + V_int8: cute.Tensor, + k_scale: cute.Tensor, + v_scale: cute.Tensor, + output: cute.Tensor, + num_heads: cute.Int32, + seq_len: cute.Int32, + head_dim: cute.Int32, +): + pass diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.jax.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.jax.py new file mode 100644 index 0000000..74b609c --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.jax.py @@ -0,0 +1,18 @@ +import jax +import jax.numpy as jnp + + +# Q, K_int8, V_int8, k_scale, v_scale are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, + K_int8: jax.Array, + V_int8: jax.Array, + k_scale: jax.Array, + v_scale: jax.Array, + num_heads: int, + seq_len: int, + head_dim: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.mojo b/challenges/medium/96_int8_kv_cache_attention/starter/starter.mojo new file mode 100644 index 0000000..84f8074 --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.mojo @@ -0,0 +1,18 @@ +from std.gpu.host import DeviceContext +from std.memory import UnsafePointer + + +# Q, K_int8, V_int8, k_scale, v_scale, output are device pointers +@export +def solve( + Q: UnsafePointer[Float32, MutExternalOrigin], + K_int8: UnsafePointer[Int8, 
MutExternalOrigin], + V_int8: UnsafePointer[Int8, MutExternalOrigin], + k_scale: UnsafePointer[Float32, MutExternalOrigin], + v_scale: UnsafePointer[Float32, MutExternalOrigin], + output: UnsafePointer[Float32, MutExternalOrigin], + num_heads: Int32, + seq_len: Int32, + head_dim: Int32, +) raises: + pass diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.pytorch.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.pytorch.py new file mode 100644 index 0000000..a27270b --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.pytorch.py @@ -0,0 +1,16 @@ +import torch + + +# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K_int8: torch.Tensor, + V_int8: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, +): + pass diff --git a/challenges/medium/96_int8_kv_cache_attention/starter/starter.triton.py b/challenges/medium/96_int8_kv_cache_attention/starter/starter.triton.py new file mode 100644 index 0000000..d09e48f --- /dev/null +++ b/challenges/medium/96_int8_kv_cache_attention/starter/starter.triton.py @@ -0,0 +1,18 @@ +import torch +import triton +import triton.language as tl + + +# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU +def solve( + Q: torch.Tensor, + K_int8: torch.Tensor, + V_int8: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + output: torch.Tensor, + num_heads: int, + seq_len: int, + head_dim: int, +): + pass