diff --git a/tutorial/5 - flash attention2.ipynb b/tutorial/5 - flash attention2.ipynb
new file mode 100644
index 00000000..4861996a
--- /dev/null
+++ b/tutorial/5 - flash attention2.ipynb
@@ -0,0 +1,90 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Building on flash attention1, flash attention2 makes two main changes: the tiled Q moves to the outer loop while the tiled K and V move to the inner loop, and the online softmax and the output O are accumulated differently, with the division by the running softmax sum deferred until after the inner loop."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "torch.manual_seed(456)\n",
+    "\n",
+    "N, d = 16, 8\n",
+    "Q_mat = torch.rand((N, d))\n",
+    "K_mat = torch.rand((N, d))\n",
+    "V_mat = torch.rand((N, d))\n",
+    "\n",
+    "expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)\n",
+    "expected_attention = expected_softmax @ V_mat\n",
+    "\n",
+    "# tile sizes for the blocked matmul: no operand larger than one tile is assumed to fit in SRAM\n",
+    "Br = 4\n",
+    "Bc = d\n",
+    "\n",
+    "# variables outside the for loop live in global memory (HBM)\n",
+    "# they are the only ones bigger than what the SRAM can store\n",
+    "O = torch.zeros((N, d))\n",
+    "\n",
+    "# unlike flash attention1, the running statistics li and mi are local to each Q block,\n",
+    "# so they are initialized inside the outer loop and need no synchronization across thread blocks\n",
+    "for block_start_Br in range(0, N, Br):\n",
+    "    block_end_Br = block_start_Br + Br\n",
+    "    # line 4, load a block of Q_mat from HBM\n",
+    "    Qi = Q_mat[block_start_Br:block_end_Br, :]\n",
+    "    # line 5, initialize Oi, li and mi\n",
+    "    Oi = torch.zeros((Br, d))  # shape Br x d\n",
+    "    li = torch.zeros((Br, 1))  # shape Br x 1\n",
+    "    mi = torch.full((Br, 1), -torch.inf)  # shape Br x 1\n",
+    "\n",
+    "    for block_start_Bc in range(0, N, Bc):\n",
+    "        block_end_Bc = block_start_Bc + Bc\n",
+    "\n",
+    "        # line 7, load a block of K and V from HBM\n",
+    "        Kj = K_mat[block_start_Bc:block_end_Bc, :]\n",
+    "        Vj = V_mat[block_start_Bc:block_end_Bc, :]\n",
+    "\n",
+    "        # line 8, compute Qi @ Kj.T at the tile level\n",
+    "        Sij = Qi @ Kj.T\n",
+    "\n",
+    "        # line 9, running max of each row over the current block and all blocks already visited\n",
+    "        mi_new = torch.max(torch.column_stack([mi, torch.max(Sij, dim=1).values[:, None]]), dim=1).values[:, None]\n",
+    "\n",
+    "        # line 9, softmax numerator as if this block were the only data (no normalization yet)\n",
+    "        Pij_hat = torch.exp(Sij - mi_new)\n",
+    "\n",
+    "        # line 9, rescale the previous running sum by exp(mi - mi_new) and add this block's contribution\n",
+    "        li = torch.exp(mi - mi_new) * li + torch.sum(Pij_hat, dim=1)[:, None]\n",
+    "\n",
+    "        # line 10, rescale the previous partial output the same way and accumulate Pij_hat @ Vj\n",
+    "        Oi = Oi * torch.exp(mi - mi_new) + Pij_hat @ Vj\n",
+    "\n",
+    "        # update the running max\n",
+    "        mi = mi_new\n",
+    "\n",
+    "    # line 12, normalize by the softmax denominator once, after the last K/V block\n",
+    "    Oi = Oi / li\n",
+    "\n",
+    "    # line 14, write this Q block's output back to HBM\n",
+    "    O[block_start_Br:block_end_Br, :] = Oi\n",
+    "assert torch.allclose(O, expected_attention)"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
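
For reference, here is a minimal sketch (not part of the diff above) that isolates the online-softmax recurrence used in the inner loop, applied to a single row of scores; the block size of 4 and the variable names are arbitrary choices for illustration. It checks that rescaling the running sum by exp(m - m_new) at each block and normalizing only once at the end reproduces the ordinary softmax.

```python
import torch

torch.manual_seed(0)

scores = torch.rand(16)       # one row of S = Q @ K.T
blocks = scores.split(4)      # visit the row four scores at a time

m = torch.tensor(-torch.inf)  # running row max
l = torch.tensor(0.0)         # running (rescaled) softmax denominator
for s in blocks:
    m_new = torch.maximum(m, s.max())
    # rescale the old sum to the new max, then add this block's contribution
    l = torch.exp(m - m_new) * l + torch.exp(s - m_new).sum()
    m = m_new

# a single division at the end, analogous to Oi = Oi / li after the inner loop
assert torch.allclose(torch.exp(scores - m) / l, torch.softmax(scores, dim=0))
```

This recurrence is what lets the notebook defer `Oi = Oi / li` until line 12, instead of renormalizing `Oi` on every inner iteration as flash attention1 does.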