Implementations of multi-token attention in CUDA and Triton.
- [Paper](https://arxiv.org/abs/2504.00927)
- [Original implementation](https://github.com/facebookresearch/RAM/tree/main/projects/mta)
In this variant, there is only a convolution applied to the attention scores before the softmax (with causal masking):

$$P = \text{softmax}\left(\text{Conv2D}(\text{scale} \cdot QK^T,\; W)\right)$$

And the output:

$$O = P V$$
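To make the variant concrete, here is a minimal pure-Python sketch (not the repo's kernel): raw scores, a zero-padded 2D convolution over the (query, key) plane, causal masking, a row-wise softmax, and the value product. The masking convention used here (zeroing future entries before the convolution, re-masking with $-\infty$ after) and the kernel anchoring are assumptions for this sketch.

```python
import math

def mta_forward(Q, K, V, W, scale):
    """Single-head causal attention with a pre-softmax 2D convolution.

    Q, K, V: N x d row-major lists; W: cq x ck kernel, assumed anchored so
    that score (i, j) only sees past offsets (i - a, j - b)."""
    N, d = len(Q), len(Q[0])
    cq, ck = len(W), len(W[0])
    NEG = float("-inf")
    # Raw scores S = scale * Q K^T; future positions zeroed before the conv
    S = [[scale * sum(Q[i][t] * K[j][t] for t in range(d)) if j <= i else 0.0
          for j in range(N)] for i in range(N)]
    # Zero-padded 2D convolution over the (query, key) plane
    C = [[sum(W[a][b] * S[i - a][j - b]
              for a in range(cq) for b in range(ck)
              if i - a >= 0 and j - b >= 0)
          for j in range(N)] for i in range(N)]
    # Re-apply the causal mask, then a numerically stable softmax per row
    O = []
    for i in range(N):
        row = [C[i][j] if j <= i else NEG for j in range(N)]
        m = max(row[: i + 1])
        e = [math.exp(x - m) if x > NEG else 0.0 for x in row]
        s = sum(e)
        p = [x / s for x in e]
        O.append([sum(p[j] * V[j][t] for j in range(N)) for t in range(d)])
    return O
```

With a $1 \times 1$ identity kernel `W = [[1.0]]` this reduces to standard causal attention, which makes it easy to sanity-check against a plain softmax.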
The implementation fuses all of these operations into a single CUDA kernel, using the main ideas of tiling and online softmax from Flash Attention.
Because of the convolution, adjacent tiles overlap, so some values of the $QK^T$ matrix are recomputed.
The kernel still achieves a speedup over the PyTorch implementation:

Because of the unavoidable extra computation introduced by the convolution, it is slower than standard Flash Attention in the training workload:
However, during inference, memory bandwidth is the bottleneck, and the increased arithmetic intensity doesn't affect performance as much:
The forward kernel:

- Require: matrices $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, convolution kernel $W$ (with query/key paddings $P_q$, $P_k$), and tile size $B$.
- Initialize $O=(0)_{N \times d}$, $L=(0)_N$, $M=(-\infty)_N$ in HBM.
- Define effective block sizes (strides): $B_r = B - P_q$, $B_c = B - 2P_k$.
- Calculate the number of blocks: $T_r = \lceil N/B_r \rceil$, $T_c = \lceil N/B_c \rceil$.
- Parallelize for $i=1$ to $T_r$ do:
  - // Define indices for the loaded Q tile (size $B$)
  - $I_s = (i-1) \cdot B_r - P_q$, $I_e = I_s + B$.
  - // Define indices for the effective output region (size $B_r$)
  - $I_s^{\text{eff}} = (i-1) \cdot B_r$, $I_e^{\text{eff}} = I_s^{\text{eff}} + B_r$.
  - Load $W$ from HBM to SRAM.
  - Initialize local statistics in SRAM (size $B$): $\ell_i = (1)_B$, $m_i = (-\infty)_B$.
  - for $j=1$ to $T_c$ do:
    - // Define indices for the loaded K/V tile (size $B$)
    - $J_s = (j-1) \cdot B_c - P_k$, $J_e = J_s + B$.
    - // Load overlapping tiles
    - Load $Q[I_s:I_e]$, $K[J_s:J_e]$, $V[J_s:J_e]$ from HBM to SRAM (handle boundary indices).
    - // Compute attention scores on the full tile
    - On chip, $S_{ij} = \text{scale} \cdot (Q[I_s:I_e] K[J_s:J_e]^T) \in \mathbb{R}^{B \times B}$.
    - // Convolution and local statistics
    - On chip, $P_{ij}^{\text{raw}} = \text{Conv2D}(S_{ij}, W) \in \mathbb{R}^{B \times B}$ (apply causal masking).
    - On chip, compute local stats: $\tilde{m}_{ij} = \text{rowmax}(P_{ij}^{\text{raw}})$, $\tilde{P}_{ij} = \exp(P_{ij}^{\text{raw}} - \tilde{m}_{ij})$, $\tilde{\ell}_{ij} = \text{rowsum}(\tilde{P}_{ij})$.
    - // Online softmax update
    - $m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij})$.
    - Calculate rescaling factors: $\alpha = e^{m_i - m_i^{\text{new}}}$, $\beta = e^{\tilde{m}_{ij} - m_i^{\text{new}}}$.
    - Update $\ell_i \leftarrow \alpha \cdot \ell_i + \beta \cdot \tilde{\ell}_{ij}$.
    - // Update output (read-modify-write HBM)
    - Load the full tile $O[I_s:I_e]$ from HBM to SRAM.
    - On chip, $O[I_s:I_e] \leftarrow \alpha \cdot O[I_s:I_e] + (\beta \cdot \tilde{P}_{ij}) V[J_s:J_e]$.
    - Write back only the effective region $O[I_s^{\text{eff}}:I_e^{\text{eff}}]$ to HBM (corresponds to the slice $[P_q:B]$ of the SRAM buffer).
    - $m_i \leftarrow m_i^{\text{new}}$.
  - end for
  - // Finalization (load, normalize, write the effective region)
  - Load the accumulated $O[I_s^{\text{eff}}:I_e^{\text{eff}}]$ from HBM.
  - Normalize using the corresponding slice of the local statistics (held in SRAM): $O[I_s^{\text{eff}}:I_e^{\text{eff}}] \leftarrow O[I_s^{\text{eff}}:I_e^{\text{eff}}] / \ell_i[P_q:B]$.
  - Write the normalized $O[I_s^{\text{eff}}:I_e^{\text{eff}}]$ to HBM.
  - Write $m_i[P_q:B]$ to $M[I_s^{\text{eff}}:I_e^{\text{eff}}]$ and $\ell_i[P_q:B]$ to $L[I_s^{\text{eff}}:I_e^{\text{eff}}]$ in HBM.
- end Parallelize
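The online-softmax bookkeeping above (rescale the running sum and output by $\alpha$, fold in the new tile with $\beta$, normalize once at the end) can be tested in isolation. A minimal single-row, $d=1$ sketch, folding $\beta$ directly into the tile exponentials (equivalent to the $\beta \cdot \tilde{P}_{ij}$ form, since $\beta \cdot e^{s - \tilde{m}} = e^{s - m^{\text{new}}}$):

```python
import math

def online_softmax_row(scores, values, tile):
    """One attention row with scalar values: stream (score, value) pairs in
    tiles, keeping only the running max m, sum l, and unnormalized output o."""
    m, l, o = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), tile):
        s = scores[start:start + tile]
        v = values[start:start + tile]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)           # rescales the old statistics
        p = [math.exp(x - m_new) for x in s]  # beta folded into each term
        l = alpha * l + sum(p)
        o = alpha * o + sum(pi * vi for pi, vi in zip(p, v))
        m = m_new
    return o / l  # finalization: divide by the accumulated sum

# The streamed result matches a direct softmax-weighted average
scores = [0.5, -1.0, 2.0, 0.3, 1.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
mx = max(scores)
e = [math.exp(x - mx) for x in scores]
direct = sum(ei * vi for ei, vi in zip(e, values)) / sum(e)
assert abs(online_softmax_row(scores, values, 2) - direct) < 1e-12
```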
The backward kernel for $dQ$ and $dW$:

- Require: matrices $Q, K, V, dO \in \mathbb{R}^{N \times d}$, the statistics $L, D \in \mathbb{R}^N$ from the forward pass, and the kernel $W$ in HBM.
- Initialize $dQ=(0)_{N \times d}$, $dW=(0)$ in HBM.
- Define effective block sizes (strides): $B_r = B - 2P_q$, $B_c = B - 4P_k$.
- Calculate the number of blocks: $T_r = \lceil N/B_r \rceil$, $T_c = \lceil N/B_c \rceil$.
- Parallelize for $i=1$ to $T_r$ do:
  - // Define indices for the loaded Q/dO tile (size $B$)
  - $I_s = (i-1) \cdot B_r - 2P_q$, $I_e = I_s + B$.
  - // Define indices for the effective output region (size $B_r$)
  - $I_s^{\text{eff}} = (i-1) \cdot B_r$, $I_e^{\text{eff}} = I_s^{\text{eff}} + B_r$.
  - Load $W$ from HBM to SRAM. Initialize a local $dW_i=(0)$ accumulator in SRAM.
  - Load $dO[I_s:I_e]$, $L[I_s:I_e]$, $D[I_s:I_e]$ from HBM to SRAM.
  - for $j=1$ to $T_c$ do:
    - // Define indices for the loaded K/V tile (size $B$)
    - $J_s = (j-1) \cdot B_c - 2P_k$, $J_e = J_s + B$.
    - // --- Forward computation reprise ---
    - Load $Q[I_s:I_e]$, $K[J_s:J_e]$ from HBM to SRAM.
    - On chip, $S_{ij} = \text{scale} \cdot (Q[I_s:I_e] K[J_s:J_e]^T) \in \mathbb{R}^{B \times B}$.
    - $P_{ij}^{\text{raw}} = \text{Conv2D}(S_{ij}, W)$ (apply causal masking).
    - $P_{ij} = \exp(P_{ij}^{\text{raw}} - L[I_s:I_e])$.
    - // --- Backward computation ---
    - // Compute dP
    - Load $V[J_s:J_e]$ from HBM to SRAM.
    - On chip, $dP_{ij} = dO[I_s:I_e] V[J_s:J_e]^T$.
    - // Compute dS (backprop through the softmax)
    - On chip, $dS_{ij}^{\text{conv}} = P_{ij} \odot (dP_{ij} - D[I_s:I_e])$.
    - // Compute dW (accumulate the kernel gradient)
    - On chip, $dW_i \leftarrow dW_i + \text{GradW}(S_{ij}, dS_{ij}^{\text{conv}})$.
    - // Compute $d(QK^T)$ (backprop through the convolution)
    - On chip, $d(QK^T)_{ij} = \text{TransposedConv2D}(dS_{ij}^{\text{conv}}, W)$.
    - // Compute dQ (read-modify-write HBM)
    - Load $K[J_s:J_e]$ from HBM to SRAM (if overwritten by $V$).
    - Load $dQ[I_s:I_e]$ from HBM into the SRAM accumulator $dQ_{\text{acc}}$ (if $j>1$; else initialize to 0).
    - On chip, $dQ_{\text{acc}} \leftarrow dQ_{\text{acc}} + d(QK^T)_{ij} K[J_s:J_e]$.
    - Write back the effective region $dQ[I_s^{\text{eff}}:I_e^{\text{eff}}]$ to HBM (unscaled).
  - end for
  - // Finalization
  - // Apply the scale factor to dQ (read-modify-write HBM)
  - Load $dQ[I_s^{\text{eff}}:I_e^{\text{eff}}]$ from HBM.
  - $dQ[I_s^{\text{eff}}:I_e^{\text{eff}}] \leftarrow dQ[I_s^{\text{eff}}:I_e^{\text{eff}}] \cdot \text{scale}$.
  - Write $dQ[I_s^{\text{eff}}:I_e^{\text{eff}}]$ back to HBM.
  - // Atomically update the global dW
  - $\text{AtomicAdd}(dW, dW_i)$.
- end Parallelize
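The softmax-backward step $dS^{\text{conv}} = P \odot (dP - D)$ relies on the standard Flash-Attention identity $D = \text{rowsum}(dO \odot O) = \sum_k p_k \, dP_k$, which avoids materializing the full softmax Jacobian. A small pure-Python check of that identity against central finite differences (names here are illustrative, not from the repo):

```python
import math

def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    t = sum(e)
    return [x / t for x in e]

def softmax_backward(s, g):
    """Given the upstream gradient g = dL/dp, return dL/ds using the kernel's
    identity dS = P * (dP - D), where D = sum_k p_k * g_k."""
    p = softmax(s)
    D = sum(pk * gk for pk, gk in zip(p, g))
    return [pk * (gk - D) for pk, gk in zip(p, g)]

# Finite-difference check on a small row, for L = <g, softmax(s)>
s = [0.2, -1.3, 0.7]
g = [1.0, -2.0, 0.5]  # arbitrary upstream gradient dL/dp
analytic = softmax_backward(s, g)
eps = 1e-6
for k in range(len(s)):
    sp = list(s); sp[k] += eps
    sm = list(s); sm[k] -= eps
    Lp = sum(gi * pi for gi, pi in zip(g, softmax(sp)))
    Lm = sum(gi * pi for gi, pi in zip(g, softmax(sm)))
    numeric = (Lp - Lm) / (2 * eps)
    assert abs(numeric - analytic[k]) < 1e-6
```

A useful side property: the entries of each row of $dS$ sum to zero, since $\sum_k p_k(g_k - D) = D - D = 0$.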
The backward kernel for $dK$ and $dV$:

- Require: matrices $Q, K, V, dO \in \mathbb{R}^{N \times d}$, the statistics $L, D \in \mathbb{R}^N$, and the kernel $W$ in HBM.
- Initialize $dK=(0)_{N \times d}$, $dV=(0)_{N \times d}$ in HBM.
- Define effective block sizes (strides): $B_r = B - 2P_q$, $B_c = B - 4P_k$.
- Calculate the number of blocks: $T_r = \lceil N/B_r \rceil$, $T_c = \lceil N/B_c \rceil$.
- Parallelize for $j=1$ to $T_c$ do:
  - // Define indices for the loaded K/V tile (size $B$)
  - $J_s = (j-1) \cdot B_c - 2P_k$, $J_e = J_s + B$.
  - // Define indices for the effective output region (size $B_c$)
  - $J_s^{\text{eff}} = (j-1) \cdot B_c$, $J_e^{\text{eff}} = J_s^{\text{eff}} + B_c$.
  - Load $W$ from HBM to SRAM.
  - for $i=1$ to $T_r$ do:
    - // Define indices for the loaded Q/dO tile (size $B$)
    - $I_s = (i-1) \cdot B_r - 2P_q$, $I_e = I_s + B$.
    - // --- Forward computation reprise ---
    - Load $Q[I_s:I_e]$, $K[J_s:J_e]$, $L[I_s:I_e]$, $D[I_s:I_e]$ from HBM to SRAM.
    - On chip, $S_{ij} = \text{scale} \cdot (Q[I_s:I_e] K[J_s:J_e]^T) \in \mathbb{R}^{B \times B}$.
    - $P_{ij}^{\text{raw}} = \text{Conv2D}(S_{ij}, W)$ (apply causal masking).
    - $P_{ij} = \exp(P_{ij}^{\text{raw}} - L[I_s:I_e])$.
    - // --- Backward computation (interleaved dV and dK) ---
    - // Compute dP
    - Load $dO[I_s:I_e]$, $V[J_s:J_e]$ from HBM to SRAM (overwrites $Q$, $K$).
    - On chip, $dP_{ij} = dO[I_s:I_e] V[J_s:J_e]^T$.
    - // Update dV (read-modify-write HBM)
    - Load $dV[J_s:J_e]$ from HBM into the SRAM accumulator $dV_{\text{acc}}$ (if $i>1$; else initialize to 0).
    - On chip, transpose $P_{ij} \rightarrow P_{ij}^T$.
    - On chip, $dV_{\text{acc}} \leftarrow dV_{\text{acc}} + P_{ij}^T dO[I_s:I_e]$.
    - Write back the effective region $dV[J_s^{\text{eff}}:J_e^{\text{eff}}]$ to HBM.
    - // Compute dS (backprop through the softmax)
    - On chip, $dS_{ij}^{\text{conv}} = P_{ij} \odot (dP_{ij} - D[I_s:I_e])$.
    - // Compute $d(QK^T)$ (backprop through the convolution)
    - On chip, $d(QK^T)_{ij} = \text{TransposedConv2D}(dS_{ij}^{\text{conv}}, W)$.
    - // Update dK (read-modify-write HBM)
    - Load $Q[I_s:I_e]$ from HBM to SRAM (it was overwritten by $dO$).
    - Load $dK[J_s:J_e]$ from HBM into the SRAM accumulator $dK_{\text{acc}}$ (if $i>1$; else initialize to 0).
    - On chip, transpose $d(QK^T)_{ij} \rightarrow d(QK^T)_{ij}^T$.
    - On chip, $dK_{\text{acc}} \leftarrow dK_{\text{acc}} + d(QK^T)_{ij}^T Q[I_s:I_e]$.
    - Write back the effective region $dK[J_s^{\text{eff}}:J_e^{\text{eff}}]$ to HBM (unscaled).
  - end for
  - // Finalization
  - // Apply the scale factor to dK (read-modify-write HBM). dV is already finalized.
  - Load $dK[J_s^{\text{eff}}:J_e^{\text{eff}}]$ from HBM.
  - $dK[J_s^{\text{eff}}:J_e^{\text{eff}}] \leftarrow dK[J_s^{\text{eff}}:J_e^{\text{eff}}] \cdot \text{scale}$.
  - Write $dK[J_s^{\text{eff}}:J_e^{\text{eff}}]$ back to HBM.
- end Parallelize
The inference (decoding) kernel:

- Require: the query matrix $Q \in \mathbb{R}^{16 \times d}$ for the last 16 tokens, the cached $K$, $V$, and the kernel $W$ in HBM.
- // Part 1: partial attention computation (fwd_inference_kernel)
- // This kernel computes partial attention outputs by splitting the key/value matrices along the sequence-length dimension ($N_k$).
- // Each thread block processes one split for a specific batch and head.
- Parallelize for each (batch, head, k_split) triplet, where $s=1, \dots, N_{\text{splits}}$ do:
  - Initialize the partial output $O_s=(0)_{1 \times d}$, partial max $M_s=-\infty$, and partial sum $L_s=0$ in SRAM.
  - Load the entire query tile $Q \in \mathbb{R}^{16 \times d}$ and the convolution kernel $W$ from HBM to SRAM.
  - Define the number of K-tiles for this split: $T_k = \lceil N_k / (N_{\text{splits}} \cdot B_k) \rceil$.
  - for each K-tile $j=1, \dots, T_k$ do:
    - // Define indices for the current key/value tile of size $B_k$.
    - Load the K tile from HBM to SRAM.
    - On chip, compute attention scores: $S_j = \text{scale} \cdot (QK^T) \in \mathbb{R}^{16 \times B_k}$.
    - On chip, apply the 2D convolution to the scores: $S_{j, \text{conv}} = \text{Conv2D}(S_j, W)$.
    - Extract the scores corresponding to the last query token: $s_{j, \text{lastQ}} \in \mathbb{R}^{1 \times B_k}$.
    - // --- Online softmax update ---
    - Find the maximum value of the current tile's scores: $m_j = \max(s_{j, \text{lastQ}})$.
    - Store the previous max for the split: $M_{\text{old}} = M_s$.
    - Update the max for the split: $M_s = \max(M_s, m_j)$.
    - Compute correction factors based on the max update: $\alpha = \exp(M_{\text{old}} - M_s)$ and $\beta = \exp(m_j - M_s)$.
    - Rescale the previous partial output and sum: $O_s \leftarrow O_s \cdot \alpha$ and $L_s \leftarrow L_s \cdot \alpha$.
    - Compute softmax probabilities for the current tile: $P_j = \exp(s_{j, \text{lastQ}} - M_s)$.
    - Update the partial sum: $L_s \leftarrow L_s + \sum(P_j)$.
    - Load the corresponding V tile from HBM to SRAM.
    - Update the partial output: $O_s \leftarrow O_s + P_j V$.
  - end for
  - Write the final partial results for this split ($O_s, M_s, L_s$) to HBM.
- end Parallelize
- // Part 2: combine K-splits (combine_splits_kernel)
- // If $N_{\text{splits}} > 1$, this kernel is launched to combine the partial results into a final, correct output.
- // Each thread block combines the splits for a single (batch, head) pair.
- if $N_{\text{splits}} > 1$ then
  - Parallelize for each (batch, head) pair do:
    - Load all partial results $\{O_s, M_s, L_s\}_{s=1}^{N_{\text{splits}}}$ for the current (batch, head) from HBM to SRAM.
    - Initialize the combined results with the first split's values: $O_{\text{final}} = O_1$, $M_{\text{final}} = M_1$, $L_{\text{final}} = L_1$.
    - for each subsequent split $s=2, \dots, N_{\text{splits}}$ do:
      - Store the previous combined max: $M_{\text{old}} = M_{\text{final}}$.
      - Find the new true maximum across the combined and current splits: $M_{\text{final}} = \max(M_{\text{final}}, M_s)$.
      - Compute correction factors: $\alpha = \exp(M_{\text{old}} - M_{\text{final}})$ and $\beta = \exp(M_s - M_{\text{final}})$.
      - Combine the output vectors with the proper scaling: $O_{\text{final}} \leftarrow O_{\text{final}} \cdot \alpha + O_s \cdot \beta$.
      - Combine the sum values: $L_{\text{final}} \leftarrow L_{\text{final}} \cdot \alpha + L_s \cdot \beta$.
    - end for
    - Normalize the final output vector: $O_{\text{final}} \leftarrow O_{\text{final}} / L_{\text{final}}$.
    - Write the final $O_{\text{final}}$ back to its designated position in HBM.
  - end Parallelize
- end if
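The split-combine recurrence in Part 2 can be exercised on its own: each split contributes an unnormalized output $O_s$, max $M_s$, and sum $L_s$, and merging them with the $\alpha/\beta$ correction reproduces a single-pass softmax. A scalar ($d = 1$) sketch:

```python
import math

def partial(scores, values):
    """Partial softmax state (O_s, M_s, L_s) for one split, with d = 1."""
    M = max(scores)
    p = [math.exp(s - M) for s in scores]
    return sum(pi * vi for pi, vi in zip(p, values)), M, sum(p)

def combine(parts):
    """Merge per-split (O_s, M_s, L_s) triples as in combine_splits_kernel."""
    O, M, L = parts[0]
    for Os, Ms, Ls in parts[1:]:
        M_new = max(M, Ms)
        a, b = math.exp(M - M_new), math.exp(Ms - M_new)
        O = O * a + Os * b
        L = L * a + Ls * b
        M = M_new
    return O / L  # final normalization

scores = [0.1, 1.5, -0.4, 2.2, 0.9, -1.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
merged = combine([partial(scores[:3], values[:3]),
                  partial(scores[3:], values[3:])])
# Reference: a plain softmax-weighted average over the whole sequence
m = max(scores)
e = [math.exp(s - m) for s in scores]
ref = sum(ei * vi for ei, vi in zip(e, values)) / sum(e)
assert abs(merged - ref) < 1e-9
```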


