90 changes: 90 additions & 0 deletions tutorial/5 - flash attention2.ipynb
@@ -0,0 +1,90 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the flash attention1, the developer has implemented flash attention2, whose main changes are to put the tiled Q in the outer loop and the tiled K and V in the inner loop, and to change the way of calculating online softmax and O. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(456)\n",
"\n",
"N,d=16,8\n",
"Q_mat=torch.rand((N,d))\n",
"K_mat=torch.rand((N,d))\n",
"V_mat=torch.rand((N,d))\n",
"\n",
"expected_softmax=torch.softmax(Q_mat@K_mat.T,dim=1)\n",
"expected_attention=expected_softmax@V_mat\n",
"\n",
"# tile size for matmul, no op bigger than this size can be stored in SRAM\n",
"Br=4\n",
"Bc=d\n",
"\n",
"# variables outside the for loop represent the global memory\n",
"# they are the only ones bigger than what the SRAM can store\n",
"O=torch.zeros((N,d))\n",
"\n",
"# For the 2 variables below, they may be removed in a serially executed code (in particular the outter for loop)\n",
"# They are needed in parallelized execution where each thread block need to sync its findings with the others\n",
"for block_start_Br in range(0,N,Br):\n",
" block_end_Br=block_start_Br+Br\n",
" # line 4, load a block of Q_mat from HBM\n",
" Qi=Q_mat[block_start_Br:block_end_Br,:]\n",
" # line 5, initialize Oi, li and mi.\n",
" Oi=torch.zeros((Br,d)) # shape Br x d\n",
" li=torch.zeros((Br,1)) # shape Br x 1\n",
" mi=torch.full((Br,1),-torch.inf) # shape Br x 1\n",
"\n",
" for block_start_Bc in range(0,N,Bc):\n",
" block_end_Bc=block_start_Bc+Bc\n",
"\n",
" # line 7, load a block from matmul input tensor\n",
" Kj=K_mat[block_start_Bc:block_end_Bc,:]\n",
" Vj=V_mat[block_start_Bc:block_end_Bc,:]\n",
"\n",
" # line 8,QKt at the tile level\n",
" [email protected]\n",
"\n",
" # line 9, find max of each row regarding the current block and the previous ones we have already visited\n",
" mi_new=torch.max(torch.column_stack([mi,torch.max(Sij,dim=1).values[:,None]]),dim=1).values[:,None]\n",
" \n",
" # line 9,compute the softmax numerator like if we only had the data from this block (and nothing before or after)\n",
" Pij_hat=torch.exp(Sij-mi_new)\n",
"\n",
" # line 9,adjusting factor (see online softmax computation above) leveraging the rule of exponentiation\n",
" li=torch.exp(mi-mi_new)*li+torch.sum(Pij_hat,dim=1)[:,None]\n",
" \n",
" # line 10\n",
" Oi=Oi*torch.exp(mi-mi_new)+Pij_hat@Vj\n",
"\n",
" # update the mi\n",
" mi=mi_new\n",
" \n",
" # line 12\n",
" Oi=Oi/li\n",
"\n",
" # line 14\n",
" O[block_start_Br:block_end_Br,:]=Oi\n",
"assert torch.allclose(O,expected_attention)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}