131 changes: 131 additions & 0 deletions challenges/medium/96_int8_kv_cache_attention/challenge.html
@@ -0,0 +1,131 @@
<p>
Implement decode-phase multi-head attention where the key and value caches are stored as
<code>int8</code> with per-token scale factors. Because each cache element shrinks from four
bytes to one, this layout cuts KV-cache memory traffic to roughly a quarter of
<code>float32</code> (the per-token scales add only negligible overhead), and it is used in
production LLM serving systems such as TensorRT-LLM and vLLM. Given a query tensor
<code>Q</code> for a single new token, an <code>int8</code> key cache <code>K_int8</code>, an
<code>int8</code> value cache <code>V_int8</code>, and per-token scales <code>k_scale</code>
and <code>v_scale</code>, dequantize the caches and compute scaled dot-product attention to
produce <code>output</code>. All non-integer tensors use <code>float32</code>.
</p>
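The challenge only specifies the dequantization side of the scheme. For intuition, one common way such a cache is produced (an assumption, not part of the task) is symmetric per-token quantization: each token's vector is divided by its own scale and rounded to <code>int8</code>. A minimal NumPy sketch:

```python
import numpy as np

def quantize_per_token(x):
    """Symmetric per-token int8 quantization (illustrative, not required by the task).

    x: [num_heads, seq_len, head_dim] float32
    Returns (x_int8, scale) with x ~= x_int8 * scale[..., None].
    """
    scale = np.abs(x).max(axis=-1) / 127.0      # one scale per (head, token)
    scale = np.maximum(scale, 1e-8)             # avoid division by zero for all-zero tokens
    q = np.round(x / scale[..., None])
    x_int8 = np.clip(q, -128, 127).astype(np.int8)
    return x_int8, scale.astype(np.float32)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8)).astype(np.float32)
x_int8, scale = quantize_per_token(x)
# Dequantization recovers x up to half a quantization step per element
x_deq = x_int8.astype(np.float32) * scale[..., None]
```

The dequantization line at the end is exactly the operation the kernel must perform on the fly.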

<svg width="700" height="270" viewBox="0 0 700 270" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="700" height="270" fill="#222" rx="10"/>
<text x="350" y="26" fill="#ccc" font-family="monospace" font-size="13" text-anchor="middle">INT8 KV-Cache Attention — single token decode</text>

<!-- Q box -->
<rect x="20" y="45" width="80" height="36" fill="#2563eb" rx="4"/>
<text x="60" y="68" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q (fp32)</text>

<!-- K_int8 + scale -->
<rect x="130" y="45" width="100" height="36" fill="#7c3aed" rx="4"/>
<text x="180" y="63" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">K_int8 (int8)</text>
<rect x="130" y="90" width="100" height="28" fill="#5b21b6" rx="4"/>
<text x="180" y="109" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">k_scale (fp32)</text>

<!-- V_int8 + scale -->
<rect x="260" y="45" width="100" height="36" fill="#065f46" rx="4"/>
<text x="310" y="63" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">V_int8 (int8)</text>
<rect x="260" y="90" width="100" height="28" fill="#064e3b" rx="4"/>
<text x="310" y="109" fill="#fff" font-family="monospace" font-size="11" text-anchor="middle">v_scale (fp32)</text>

<!-- dequant arrows -->
<line x1="180" y1="118" x2="180" y2="148" stroke="#a78bfa" stroke-width="1.5" marker-end="url(#arr2)"/>
<line x1="310" y1="118" x2="310" y2="148" stroke="#34d399" stroke-width="1.5" marker-end="url(#arr2)"/>
<text x="180" y="145" fill="#a78bfa" font-family="monospace" font-size="10" text-anchor="middle">×</text>
<text x="310" y="145" fill="#34d399" font-family="monospace" font-size="10" text-anchor="middle">×</text>

<!-- K_float, V_float -->
<rect x="130" y="155" width="100" height="28" fill="#4c1d95" rx="4"/>
<text x="180" y="174" fill="#e9d5ff" font-family="monospace" font-size="11" text-anchor="middle">K_float</text>
<rect x="260" y="155" width="100" height="28" fill="#022c22" rx="4"/>
<text x="310" y="174" fill="#6ee7b7" font-family="monospace" font-size="11" text-anchor="middle">V_float</text>

<!-- attention flow -->
<line x1="60" y1="81" x2="60" y2="210" stroke="#60a5fa" stroke-width="1.5"/>
<line x1="60" y1="210" x2="160" y2="210" stroke="#60a5fa" stroke-width="1.5" marker-end="url(#arr2)"/>
<line x1="180" y1="183" x2="180" y2="210" stroke="#a78bfa" stroke-width="1.5" marker-end="url(#arr2)"/>
<rect x="155" y="205" width="50" height="26" fill="#1e1b4b" rx="4"/>
<text x="180" y="222" fill="#c4b5fd" font-family="monospace" font-size="10" text-anchor="middle">scores</text>
<line x1="205" y1="218" x2="250" y2="218" stroke="#e5e7eb" stroke-width="1.5" marker-end="url(#arr2)"/>
<rect x="250" y="205" width="70" height="26" fill="#1e293b" rx="4"/>
<text x="285" y="222" fill="#93c5fd" font-family="monospace" font-size="10" text-anchor="middle">softmax</text>
<line x1="320" y1="218" x2="360" y2="218" stroke="#e5e7eb" stroke-width="1.5" marker-end="url(#arr2)"/>
<line x1="310" y1="183" x2="370" y2="218" stroke="#34d399" stroke-width="1.5"/>
<rect x="360" y="205" width="60" height="26" fill="#1a2e1a" rx="4"/>
<text x="390" y="222" fill="#86efac" font-family="monospace" font-size="10" text-anchor="middle">output</text>

<!-- formula text -->
<text x="470" y="75" fill="#93c5fd" font-family="monospace" font-size="11">K[h,s,:] = K_int8[h,s,:] × k_scale[h,s]</text>
<text x="470" y="95" fill="#6ee7b7" font-family="monospace" font-size="11">V[h,s,:] = V_int8[h,s,:] × v_scale[h,s]</text>
<text x="470" y="125" fill="#fde68a" font-family="monospace" font-size="11">scores[h,s] = Q[h,:]·K[h,s,:] / √head_dim</text>
<text x="470" y="145" fill="#fde68a" font-family="monospace" font-size="11">w[h,:] = softmax(scores[h,:])</text>
<text x="470" y="165" fill="#fde68a" font-family="monospace" font-size="11">out[h,:] = Σ_s w[h,s] · V[h,s,:]</text>

<defs>
<marker id="arr2" markerWidth="6" markerHeight="6" refX="3" refY="3" orient="auto">
<path d="M0,0 L0,6 L6,3 z" fill="#888"/>
</marker>
</defs>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(Q, K_int8, V_int8, k_scale, v_scale, output, num_heads, seq_len, head_dim)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>output</code> buffer.</li>
<li>Dequantize using per-token scales: <code>K_float[h, s, d] = K_int8[h, s, d] &times; k_scale[h, s]</code> (and analogously for V).</li>
<li>Use scaled dot-product attention with scale factor <code>1 / sqrt(head_dim)</code> and a softmax over the sequence dimension.</li>
</ul>
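The required semantics can be modeled in a few lines of plain NumPy. This is a CPU-side sketch of the math only, not a valid GPU submission:

```python
import numpy as np

def solve_reference(Q, K_int8, V_int8, k_scale, v_scale, head_dim):
    """NumPy model of the required semantics (CPU sketch, not a GPU solution)."""
    # Dequantize: broadcast each per-token scale across head_dim
    K = K_int8.astype(np.float32) * k_scale[..., None]
    V = V_int8.astype(np.float32) * v_scale[..., None]
    # scores[h, s] = Q[h, :] . K[h, s, :] / sqrt(head_dim)
    scores = np.einsum("hd,hsd->hs", Q, K) / np.sqrt(head_dim)
    # Numerically stable softmax over the sequence dimension
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    # out[h, :] = sum_s w[h, s] * V[h, s, :]
    return np.einsum("hs,hsd->hd", w, V)
```

A real kernel should fuse the dequantization into the score and weighted-sum loops rather than materializing <code>K_float</code> and <code>V_float</code>, since avoiding that round trip is the whole point of the int8 layout.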

<h2>Example</h2>
<p>
With <code>num_heads</code> = 1, <code>seq_len</code> = 3, <code>head_dim</code> = 4:
</p>
<p>
<strong>Input:</strong><br>
\(Q\) (1&times;4):
\[
\begin{bmatrix} 1 & 1 & 1 & 1 \end{bmatrix}
\]
\(K\_int8\) (1&times;3&times;4):
\[
\begin{bmatrix} 10 & 0 & 0 & 0 \\ 0 & 10 & 0 & 0 \\ 0 & 0 & 10 & 0 \end{bmatrix}
\]
\(k\_scale\) (1&times;3): \(\begin{bmatrix} 0.1 & 0.1 & 0.1 \end{bmatrix}\)
&nbsp;&rArr;&nbsp;
\(K\_float\) (1&times;3&times;4):
\[
\begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix}
\]
\(V\_int8\) (1&times;3&times;4):
\[
\begin{bmatrix} 10 & 20 & 30 & 40 \\ 50 & 60 & 70 & 80 \\ 90 & 100 & 110 & 120 \end{bmatrix}
\]
\(v\_scale\) (1&times;3): \(\begin{bmatrix} 0.1 & 0.1 & 0.1 \end{bmatrix}\)
&nbsp;&rArr;&nbsp;
\(V\_float\) (1&times;3&times;4):
\[
\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}
\]
</p>
<p>
Scores = \(Q \cdot K\_float^T / \sqrt{4}\) = \(\begin{bmatrix} 0.5 & 0.5 & 0.5 \end{bmatrix}\),
so <em>softmax</em> weights = \(\begin{bmatrix} 1/3 & 1/3 & 1/3 \end{bmatrix}\).
</p>
<p>
<strong>Output</strong> (1&times;4):
\[
\begin{bmatrix} 5.00 & 6.00 & 7.00 & 8.00 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>num_heads</code> &le; 64</li>
<li>1 &le; <code>seq_len</code> &le; 32,768</li>
<li>8 &le; <code>head_dim</code> &le; 256; <code>head_dim</code> is a multiple of 8</li>
<li><code>K_int8</code> and <code>V_int8</code> values are in \([-128, 127]\)</li>
<li>All scale values are positive <code>float32</code></li>
<li>Performance is measured with <code>num_heads</code> = 32, <code>seq_len</code> = 8,192, <code>head_dim</code> = 128</li>
</ul>
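At the performance configuration, the kernel is bandwidth-bound on the cache reads, so it is worth sizing them. A back-of-envelope calculation:

```python
# Memory read per decode step at the performance configuration
num_heads, seq_len, head_dim = 32, 8192, 128

int8_cache = 2 * num_heads * seq_len * head_dim          # K_int8 + V_int8, 1 byte each
scales     = 2 * num_heads * seq_len * 4                 # k_scale + v_scale, float32
fp32_cache = 2 * num_heads * seq_len * head_dim * 4      # hypothetical float32 K + V

print(int8_cache / 2**20)                  # 64.0 MiB of int8 cache
print(scales / 2**20)                      # 2.0 MiB of scales
print(fp32_cache / (int8_cache + scales))  # ~3.88x less data than a float32 cache
```

The ratio falls slightly short of 4x because the per-token scales must also be read.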
155 changes: 155 additions & 0 deletions challenges/medium/96_int8_kv_cache_attention/challenge.py
@@ -0,0 +1,155 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
def __init__(self):
super().__init__(
name="INT8 KV-Cache Attention",
atol=1e-03,
rtol=1e-03,
num_gpus=1,
access_tier="free",
)

def reference_impl(
self,
Q: torch.Tensor,
K_int8: torch.Tensor,
V_int8: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
output: torch.Tensor,
num_heads: int,
seq_len: int,
head_dim: int,
):
assert Q.shape == (num_heads, head_dim)
assert K_int8.shape == (num_heads, seq_len, head_dim)
assert V_int8.shape == (num_heads, seq_len, head_dim)
assert k_scale.shape == (num_heads, seq_len)
assert v_scale.shape == (num_heads, seq_len)
assert output.shape == (num_heads, head_dim)
assert Q.dtype == torch.float32
assert K_int8.dtype == torch.int8
assert V_int8.dtype == torch.int8
assert k_scale.dtype == torch.float32
assert v_scale.dtype == torch.float32
assert output.dtype == torch.float32
assert Q.device.type == "cuda"
assert K_int8.device.type == "cuda"
assert V_int8.device.type == "cuda"
assert k_scale.device.type == "cuda"
assert v_scale.device.type == "cuda"
assert output.device.type == "cuda"

# Dequantize: K_float[h, s, d] = K_int8[h, s, d] * k_scale[h, s]
K_float = K_int8.float() * k_scale.unsqueeze(-1) # [num_heads, seq_len, head_dim]
V_float = V_int8.float() * v_scale.unsqueeze(-1) # [num_heads, seq_len, head_dim]

# Scaled dot-product attention: Q [num_heads, head_dim] attends to all seq_len positions
scale = 1.0 / math.sqrt(head_dim)
# scores: [num_heads, 1, seq_len]
scores = torch.bmm(Q.unsqueeze(1), K_float.transpose(1, 2)) * scale
weights = torch.softmax(scores, dim=-1) # [num_heads, 1, seq_len]

# Weighted sum of V: [num_heads, 1, seq_len] @ [num_heads, seq_len, head_dim]
out = torch.bmm(weights, V_float) # [num_heads, 1, head_dim]
output.copy_(out.squeeze(1))

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
"K_int8": (ctypes.POINTER(ctypes.c_int8), "in"),
"V_int8": (ctypes.POINTER(ctypes.c_int8), "in"),
"k_scale": (ctypes.POINTER(ctypes.c_float), "in"),
"v_scale": (ctypes.POINTER(ctypes.c_float), "in"),
"output": (ctypes.POINTER(ctypes.c_float), "out"),
"num_heads": (ctypes.c_int, "in"),
"seq_len": (ctypes.c_int, "in"),
"head_dim": (ctypes.c_int, "in"),
}

def _make_test_case(self, num_heads, seq_len, head_dim, zero_q=False, seed=None):
device = "cuda"
if seed is not None:
torch.manual_seed(seed)
if zero_q:
Q = torch.zeros(num_heads, head_dim, dtype=torch.float32, device=device)
else:
Q = torch.randn(num_heads, head_dim, dtype=torch.float32, device=device)
K_int8 = torch.randint(
-128, 128, (num_heads, seq_len, head_dim), dtype=torch.int8, device=device
)
V_int8 = torch.randint(
-128, 128, (num_heads, seq_len, head_dim), dtype=torch.int8, device=device
)
k_scale = torch.rand(num_heads, seq_len, dtype=torch.float32, device=device) * 0.1 + 0.01
v_scale = torch.rand(num_heads, seq_len, dtype=torch.float32, device=device) * 0.1 + 0.01
output = torch.empty(num_heads, head_dim, dtype=torch.float32, device=device)
return {
"Q": Q,
"K_int8": K_int8,
"V_int8": V_int8,
"k_scale": k_scale,
"v_scale": v_scale,
"output": output,
"num_heads": num_heads,
"seq_len": seq_len,
"head_dim": head_dim,
}

def generate_example_test(self) -> Dict[str, Any]:
device = "cuda"
num_heads, seq_len, head_dim = 1, 3, 4
Q = torch.tensor([[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32, device=device)
K_int8 = torch.tensor(
[[[10, 0, 0, 0], [0, 10, 0, 0], [0, 0, 10, 0]]], dtype=torch.int8, device=device
)
V_int8 = torch.tensor(
[[[10, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]],
dtype=torch.int8,
device=device,
)
k_scale = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float32, device=device)
v_scale = torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.float32, device=device)
output = torch.empty(num_heads, head_dim, dtype=torch.float32, device=device)
return {
"Q": Q,
"K_int8": K_int8,
"V_int8": V_int8,
"k_scale": k_scale,
"v_scale": v_scale,
"output": output,
"num_heads": num_heads,
"seq_len": seq_len,
"head_dim": head_dim,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
tests = []
# Edge: single key in cache
tests.append(self._make_test_case(1, 1, 8, seed=0))
# Edge: two keys
tests.append(self._make_test_case(1, 2, 8, seed=1))
# Edge: four keys, two heads
tests.append(self._make_test_case(2, 4, 8, seed=2))
# Zero query (uniform softmax weights)
tests.append(self._make_test_case(1, 8, 16, zero_q=True, seed=3))
# Power-of-2 seq_len
tests.append(self._make_test_case(4, 16, 64, seed=4))
tests.append(self._make_test_case(8, 64, 64, seed=5))
# Non-power-of-2
tests.append(self._make_test_case(2, 30, 64, seed=6))
tests.append(self._make_test_case(4, 100, 64, seed=7))
# Realistic sizes
tests.append(self._make_test_case(16, 512, 64, seed=8))
tests.append(self._make_test_case(32, 256, 128, seed=9))
return tests

def generate_performance_test(self) -> Dict[str, Any]:
return self._make_test_case(32, 8192, 128, seed=42)
@@ -0,0 +1,6 @@
#include <cuda_runtime.h>

// Q, K_int8, V_int8, k_scale, v_scale, output are device pointers
extern "C" void solve(const float* Q, const int8_t* K_int8, const int8_t* V_int8,
const float* k_scale, const float* v_scale, float* output, int num_heads,
int seq_len, int head_dim) {}
@@ -0,0 +1,18 @@
import cutlass
import cutlass.cute as cute


# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K_int8: cute.Tensor,
V_int8: cute.Tensor,
k_scale: cute.Tensor,
v_scale: cute.Tensor,
output: cute.Tensor,
num_heads: cute.Int32,
seq_len: cute.Int32,
head_dim: cute.Int32,
):
pass
@@ -0,0 +1,18 @@
import jax
import jax.numpy as jnp


# Q, K_int8, V_int8, k_scale, v_scale are tensors on GPU
@jax.jit
def solve(
Q: jax.Array,
K_int8: jax.Array,
V_int8: jax.Array,
k_scale: jax.Array,
v_scale: jax.Array,
num_heads: int,
seq_len: int,
head_dim: int,
) -> jax.Array:
# return output tensor directly
pass
@@ -0,0 +1,18 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# Q, K_int8, V_int8, k_scale, v_scale, output are device pointers
@export
def solve(
Q: UnsafePointer[Float32, MutExternalOrigin],
K_int8: UnsafePointer[Int8, MutExternalOrigin],
V_int8: UnsafePointer[Int8, MutExternalOrigin],
k_scale: UnsafePointer[Float32, MutExternalOrigin],
v_scale: UnsafePointer[Float32, MutExternalOrigin],
output: UnsafePointer[Float32, MutExternalOrigin],
num_heads: Int32,
seq_len: Int32,
head_dim: Int32,
) raises:
pass
@@ -0,0 +1,16 @@
import torch


# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU
def solve(
Q: torch.Tensor,
K_int8: torch.Tensor,
V_int8: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
output: torch.Tensor,
num_heads: int,
seq_len: int,
head_dim: int,
):
pass
@@ -0,0 +1,18 @@
import torch
import triton
import triton.language as tl


# Q, K_int8, V_int8, k_scale, v_scale, output are tensors on the GPU
def solve(
Q: torch.Tensor,
K_int8: torch.Tensor,
V_int8: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
output: torch.Tensor,
num_heads: int,
seq_len: int,
head_dim: int,
):
pass