diff --git a/challenges/medium/94_ssm_selective_scan/challenge.html b/challenges/medium/94_ssm_selective_scan/challenge.html new file mode 100644 index 00000000..36eb7fb1 --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/challenge.html @@ -0,0 +1,134 @@ +

+ Implement the forward pass of a State Space Model (SSM) selective scan, the core operation in + Mamba-style sequence models. Given an input sequence u, time-step parameters + delta, state-transition matrix A, input projection B, + output projection C, and skip-connection weights skip, compute the + output sequence y in float32. +

+ + + + + + + h₀ + + + h₁ + + + h₂ + + + h₃ + + + + + + Ā + Ā + Ā + + + + + + + B̄u₀ + B̄u₁ + B̄u₂ + B̄u₃ + + + + + + + y₀ + y₁ + y₂ + y₃ + + + + + + + + + + + + + + + + + + +

Implementation Requirements

+

+ Implement the function solve(u, delta, A, B, C, skip, y, batch, seq_len, d_model, d_state) + with the signature unchanged. Do not use external libraries beyond the allowed framework. + Write the result into the pre-allocated output tensor y. +

+

+ For each batch b, position t, and channel d, the computation is: +

+

+ \[ + \bar{A}_{b,t,d,n} = \exp(\Delta_{b,t,d} \cdot A_{d,n}) + \] + \[ + \bar{B}_{b,t,d,n} = \Delta_{b,t,d} \cdot B_{b,t,n} + \] + \[ + h_{b,t,d,n} = \bar{A}_{b,t,d,n} \cdot h_{b,t-1,d,n} + \bar{B}_{b,t,d,n} \cdot u_{b,t,d} + \] + \[ + y_{b,t,d} = \sum_{n} C_{b,t,n} \cdot h_{b,t,d,n} + \text{skip}_d \cdot u_{b,t,d} + \] +

+

+ The initial hidden state \(h_{b,-1,d,n} = 0\) for all \(b, d, n\). + All channels d are independent: they share the same B and C + projections but have separate state-transition rows in A. +

+ +

Example

+
+Input:
+  u     = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]]  shape (1,4,2)
+  delta = [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]]  shape (1,4,2)
+  A     = [[-0.5, -1.0], [-0.5, -1.0]]                         shape (2,2)
+  B     = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]]  shape (1,4,2)
+  C     = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]]  shape (1,4,2)
+  skip  = [0.0, 0.0]                                            shape (2,)
+  batch=1, seq_len=4, d_model=2, d_state=2
+
+Derivation (delta=1 everywhere, so A_bar_dn = exp(A_dn) and B_bar_dn = B_tn):
+  A_bar[d=0] = [exp(-0.5), exp(-1.0)] ≈ [0.607, 0.368]
+  A_bar[d=1] = [exp(-0.5), exp(-1.0)] ≈ [0.607, 0.368]
+
+  Hidden state h has shape (d_model=2, d_state=2) per batch element; initial h = zeros.
+  t=0: h = [[1.000, 0.000], [0.000, 0.000]]  →  y[0,0] = [1.000, 0.000]
+  t=1: h = [[0.607, 0.000], [0.000, 1.000]]  →  y[0,1] = [0.000, 1.000]
+  t=2: h = [[1.368, 1.000], [1.000, 1.368]]  →  y[0,2] = [2.368, 2.368]
+  t=3: h = [[0.830, 0.368], [0.607, 0.503]]  →  y[0,3] = [0.599, 0.555]
+
+Output:
+  y = [[[1.000, 0.000], [0.000, 1.000], [2.368, 2.368], [0.599, 0.555]]]
+
+ +

Constraints

+ diff --git a/challenges/medium/94_ssm_selective_scan/challenge.py b/challenges/medium/94_ssm_selective_scan/challenge.py new file mode 100644 index 00000000..b15c9305 --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/challenge.py @@ -0,0 +1,201 @@ +import ctypes +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="SSM Selective Scan", + atol=1e-03, + rtol=1e-03, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + skip: torch.Tensor, + y: torch.Tensor, + batch: int, + seq_len: int, + d_model: int, + d_state: int, + ): + assert u.shape == (batch, seq_len, d_model) + assert delta.shape == (batch, seq_len, d_model) + assert A.shape == (d_model, d_state) + assert B.shape == (batch, seq_len, d_state) + assert C.shape == (batch, seq_len, d_state) + assert skip.shape == (d_model,) + assert y.shape == (batch, seq_len, d_model) + assert ( + u.dtype == delta.dtype == A.dtype == B.dtype == C.dtype == skip.dtype == torch.float32 + ) + assert u.device.type == "cuda" + assert delta.device.type == "cuda" + assert A.device.type == "cuda" + assert B.device.type == "cuda" + assert C.device.type == "cuda" + assert skip.device.type == "cuda" + assert y.device.type == "cuda" + + # Hidden state: (batch, d_model, d_state) + h = torch.zeros(batch, d_model, d_state, device=u.device, dtype=u.dtype) + + for t in range(seq_len): + delta_t = delta[:, t, :] # (batch, d_model) + u_t = u[:, t, :] # (batch, d_model) + + # Discretize: A_bar = exp(delta_t * A) + # delta_t: (batch, d_model) -> (batch, d_model, 1) + # A: (d_model, d_state) -> (1, d_model, d_state) + A_bar = torch.exp(delta_t.unsqueeze(-1) * A.unsqueeze(0)) # (batch, d_model, d_state) + + # B_bar = delta_t * B_t + # B[:, t, :]: (batch, d_state) -> (batch, 1, d_state) + B_bar 
= delta_t.unsqueeze(-1) * B[:, t, :].unsqueeze(1) # (batch, d_model, d_state) + + # State update: h = A_bar * h + B_bar * u_t + h = A_bar * h + B_bar * u_t.unsqueeze(-1) # (batch, d_model, d_state) + + # Output: y_t = C_t @ h + skip * u_t + # C[:, t, :]: (batch, d_state) -> einsum with h (batch, d_model, d_state) + C_t = C[:, t, :] # (batch, d_state) + y_t = torch.einsum("bn,bdn->bd", C_t, h) + skip * u_t # (batch, d_model) + y[:, t, :] = y_t + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "u": (ctypes.POINTER(ctypes.c_float), "in"), + "delta": (ctypes.POINTER(ctypes.c_float), "in"), + "A": (ctypes.POINTER(ctypes.c_float), "in"), + "B": (ctypes.POINTER(ctypes.c_float), "in"), + "C": (ctypes.POINTER(ctypes.c_float), "in"), + "skip": (ctypes.POINTER(ctypes.c_float), "in"), + "y": (ctypes.POINTER(ctypes.c_float), "out"), + "batch": (ctypes.c_int, "in"), + "seq_len": (ctypes.c_int, "in"), + "d_model": (ctypes.c_int, "in"), + "d_state": (ctypes.c_int, "in"), + } + + def _make_test_case(self, batch, seq_len, d_model, d_state, zero_u=False, zero_delta=False): + device = "cuda" + dtype = torch.float32 + if zero_u: + u = torch.zeros(batch, seq_len, d_model, device=device, dtype=dtype) + else: + u = torch.randn(batch, seq_len, d_model, device=device, dtype=dtype) + if zero_delta: + delta = torch.zeros(batch, seq_len, d_model, device=device, dtype=dtype) + else: + # delta must be positive + delta = torch.rand(batch, seq_len, d_model, device=device, dtype=dtype) + 0.01 + # A must be negative for stability (eigenvalues < 0) + A = -torch.rand(d_model, d_state, device=device, dtype=dtype) - 0.01 + B = torch.randn(batch, seq_len, d_state, device=device, dtype=dtype) + C = torch.randn(batch, seq_len, d_state, device=device, dtype=dtype) + skip = torch.rand(d_model, device=device, dtype=dtype) + y = torch.empty(batch, seq_len, d_model, device=device, dtype=dtype) + return { + "u": u, + "delta": delta, + "A": A, + "B": B, + "C": C, + "skip": skip, + "y": y, + 
"batch": batch, + "seq_len": seq_len, + "d_model": d_model, + "d_state": d_state, + } + + def generate_example_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + device = "cuda" + dtype = torch.float32 + batch, seq_len, d_model, d_state = 1, 4, 2, 2 + u = torch.tensor( + [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]], + device=device, + dtype=dtype, + ) + delta = torch.tensor( + [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]], + device=device, + dtype=dtype, + ) + A = torch.tensor([[-0.5, -1.0], [-0.5, -1.0]], device=device, dtype=dtype) + B = torch.tensor( + [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]], + device=device, + dtype=dtype, + ) + C = torch.tensor( + [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]], + device=device, + dtype=dtype, + ) + skip = torch.tensor([0.0, 0.0], device=device, dtype=dtype) + y = torch.empty(batch, seq_len, d_model, device=device, dtype=dtype) + return { + "u": u, + "delta": delta, + "A": A, + "B": B, + "C": C, + "skip": skip, + "y": y, + "batch": batch, + "seq_len": seq_len, + "d_model": d_model, + "d_state": d_state, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + torch.manual_seed(42) + tests = [] + + # Edge case: single token + tests.append(self._make_test_case(1, 1, 1, 4)) + + # Edge case: tiny dimensions + tests.append(self._make_test_case(1, 2, 2, 2)) + + # Edge case: zero input (output should be skip * 0 = 0) + tests.append(self._make_test_case(1, 4, 4, 4, zero_u=True)) + + # Edge case: zero delta (A_bar=1, B_bar=0, so state stays zero, output = skip * u) + tests.append(self._make_test_case(2, 4, 4, 4, zero_delta=True)) + + # Power-of-2 lengths + tests.append(self._make_test_case(2, 16, 8, 4)) + tests.append(self._make_test_case(2, 64, 16, 8)) + + # Non-power-of-2 + tests.append(self._make_test_case(2, 30, 12, 4)) + tests.append(self._make_test_case(3, 100, 24, 8)) + + # Typical d_state=16 (common Mamba setting) + tests.append(self._make_test_case(2, 128, 32, 16)) + + # Realistic size 
+ tests.append(self._make_test_case(4, 256, 64, 16)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + # batch=4, seq_len=4096, d_model=512, d_state=16 + # Memory: u+delta+y ~ 3 * 4*4096*512*4 = 96MB; A+B+C+skip small + # Total << 1GB, comfortably fits 5x in 16GB T4 + return self._make_test_case(4, 4096, 512, 16) diff --git a/challenges/medium/94_ssm_selective_scan/starter/starter.cu b/challenges/medium/94_ssm_selective_scan/starter/starter.cu new file mode 100644 index 00000000..561954cb --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/starter/starter.cu @@ -0,0 +1,7 @@ +#include +#include + +// u, delta, A, B, C, skip, y are device pointers +extern "C" void solve(const float* u, const float* delta, const float* A, const float* B, + const float* C, const float* skip, float* y, int batch, int seq_len, + int d_model, int d_state) {} diff --git a/challenges/medium/94_ssm_selective_scan/starter/starter.cute.py b/challenges/medium/94_ssm_selective_scan/starter/starter.cute.py new file mode 100644 index 00000000..7c25eaed --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/starter/starter.cute.py @@ -0,0 +1,20 @@ +import cutlass +import cutlass.cute as cute + + +# u, delta, A, B, C, skip, y are tensors on the GPU +@cute.jit +def solve( + u: cute.Tensor, + delta: cute.Tensor, + A: cute.Tensor, + B: cute.Tensor, + C: cute.Tensor, + skip: cute.Tensor, + y: cute.Tensor, + batch: cute.Uint32, + seq_len: cute.Uint32, + d_model: cute.Uint32, + d_state: cute.Uint32, +): + pass diff --git a/challenges/medium/94_ssm_selective_scan/starter/starter.jax.py b/challenges/medium/94_ssm_selective_scan/starter/starter.jax.py new file mode 100644 index 00000000..76491b88 --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/starter/starter.jax.py @@ -0,0 +1,20 @@ +import jax +import jax.numpy as jnp + + +# u, delta, A, B, C, skip are tensors on GPU +@jax.jit +def solve( + u: jax.Array, + delta: jax.Array, + A: 
jax.Array, + B: jax.Array, + C: jax.Array, + skip: jax.Array, + batch: int, + seq_len: int, + d_model: int, + d_state: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/94_ssm_selective_scan/starter/starter.mojo b/challenges/medium/94_ssm_selective_scan/starter/starter.mojo new file mode 100644 index 00000000..8ce07965 --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/starter/starter.mojo @@ -0,0 +1,20 @@ +from std.gpu.host import DeviceContext +from std.memory import UnsafePointer + + +# u, delta, A, B, C, skip, y are device pointers +@export +def solve( + u: UnsafePointer[Float32, MutExternalOrigin], + delta: UnsafePointer[Float32, MutExternalOrigin], + A: UnsafePointer[Float32, MutExternalOrigin], + B: UnsafePointer[Float32, MutExternalOrigin], + C: UnsafePointer[Float32, MutExternalOrigin], + skip: UnsafePointer[Float32, MutExternalOrigin], + y: UnsafePointer[Float32, MutExternalOrigin], + batch: Int32, + seq_len: Int32, + d_model: Int32, + d_state: Int32, +) raises: + pass diff --git a/challenges/medium/94_ssm_selective_scan/starter/starter.pytorch.py b/challenges/medium/94_ssm_selective_scan/starter/starter.pytorch.py new file mode 100644 index 00000000..ad85b360 --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/starter/starter.pytorch.py @@ -0,0 +1,18 @@ +import torch + + +# u, delta, A, B, C, skip, y are tensors on the GPU +def solve( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + skip: torch.Tensor, + y: torch.Tensor, + batch: int, + seq_len: int, + d_model: int, + d_state: int, +): + pass diff --git a/challenges/medium/94_ssm_selective_scan/starter/starter.triton.py b/challenges/medium/94_ssm_selective_scan/starter/starter.triton.py new file mode 100644 index 00000000..93bff28c --- /dev/null +++ b/challenges/medium/94_ssm_selective_scan/starter/starter.triton.py @@ -0,0 +1,20 @@ +import torch +import triton +import triton.language as tl + 
+ +# u, delta, A, B, C, skip, y are tensors on the GPU +def solve( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + skip: torch.Tensor, + y: torch.Tensor, + batch: int, + seq_len: int, + d_model: int, + d_state: int, +): + pass