diff --git a/cudnn_att.cpp b/cudnn_att.cpp
index fd9760b1a..04b1a92ec 100644
--- a/cudnn_att.cpp
+++ b/cudnn_att.cpp
@@ -60,38 +60,35 @@ static void checkCudnnFE(fe::error_object e, const char *file, int line) {
 }
 #define checkCudnnFE(err) checkCudnnFE(err, __FILE__, __LINE__)
 
-using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Q,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // K,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // V,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Attn_scale,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // O
-                                     std::shared_ptr<fe::graph::Tensor_attributes>   // Stats
->;
-
-using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Q,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // K,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // V,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // O
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // dO
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Stats
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Attn_scale,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // dQ,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // dK,
-                                     std::shared_ptr<fe::graph::Tensor_attributes>   // dV
->;
+enum UIDs {
+    Q_UID,
+    K_UID,
+    V_UID,
+    Attn_scale_UID,
+    O_UID,
+    Stats_UID,
+    dO_UID,
+    dQ_UID,
+    dK_UID,
+    dV_UID
+};
 
 // Need a cache because graph->build_operation_graph() is slow but everything else seems fast
-using cache_type_fwd = std::unordered_map<std::size_t, graph_tensors_fwd>;
-using cache_type_bwd = std::unordered_map<std::size_t, graph_tensors_bwd>;
+using cache_type_fwd = std::map<std::tuple<int, int, int, int, int>, std::shared_ptr<fe::graph::Graph>>;
+using cache_type_bwd = std::map<std::tuple<int, int, int, int>, std::shared_ptr<fe::graph::Graph>>;
 
 // Loosely based on cuDNN frontend samples functions and massively simplified
-template <typename... Args>
-auto lookup_cache_or_build_graph_fwd(Args... args) {
+auto lookup_cache_or_build_graph_fwd(int B, int H, int T, int HS, int is_inference_only) {
     static cache_type_fwd user_maintained_cache_fwd;
-    auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...);
+
+    auto key = std::make_tuple(B, H, T, HS, is_inference_only);
+
+    auto it = user_maintained_cache_fwd.find(key);
+    if (it != user_maintained_cache_fwd.end()) {
+        return it->second;
+    }
 
     auto graph = std::make_shared<fe::graph::Graph>();
     graph->set_io_data_type(CUDNN_16BIT)
          .set_intermediate_data_type(fe::DataType_t::FLOAT)
          .set_compute_data_type(fe::DataType_t::FLOAT);
@@ -100,16 +97,20 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
     // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
     auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
                                .set_dim({B, H, T, HS})
+                               .set_uid(Q_UID)
                                .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
     auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
                                .set_dim({B, H, T, HS})
+                               .set_uid(K_UID)
                                .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
     auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
                                .set_dim({B, H, T, HS})
+                               .set_uid(V_UID)
                                .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
     auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
                                .set_dim({1, 1, 1, 1})
                                .set_stride({1, 1, 1, 1})
+                               .set_uid(Attn_scale_UID)
                                .set_is_pass_by_value(true)
                                .set_data_type(fe::DataType_t::FLOAT));
@@ -122,38 +123,47 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
     auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);
 
     // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
-    O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1});
+    O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(O_UID);
     assert(stats == nullptr || is_inference_only == false);
     if (is_inference_only == false) {
         stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
                                .set_dim({B, H, T, 1})
-                               .set_stride({H * T, T, 1, 1});
+                               .set_stride({H * T, T, 1, 1})
+                               .set_uid(Stats_UID);
     }
 
     checkCudnnFE(graph->validate());
-    auto key = graph->key();
-    auto it = user_maintained_cache_fwd.find(key);
-    if (it != user_maintained_cache_fwd.end()) {
-        return it->second;
-    }
 
     // Build the operation graph and execution part (this is the VERY SLOW PART)
     checkCudnnFE(graph->build_operation_graph(cudnn_handle));
     auto plans = graph->create_execution_plans({fe::HeurMode_t::A});
     checkCudnnFE(graph->check_support(cudnn_handle));
     checkCudnnFE(graph->build_plans(cudnn_handle));
-    assert(graph->get_workspace_size() <= cudnn_workspace_size); // fwd shouldn't need workspace
+    // Reallocate the workspace if the required size is greater than the current workspace size
+    // (on H100 this may be around 16 bytes)
+    if (graph->get_workspace_size() > cudnn_workspace_size) {
+        if (cudnn_workspace_size > 0) {
+            cudaCheck(cudaFree(cudnn_workspace));
+        }
+        cudnn_workspace_size = graph->get_workspace_size();
+        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
+    }
 
-    auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats);
-    user_maintained_cache_fwd.insert({key, tuple});
-    return tuple;
+    user_maintained_cache_fwd.insert({key, graph});
+
+    return graph;
 }
 
-template <typename... Args>
-auto lookup_cache_or_build_graph_bwd(Args... args) {
+auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
     static cache_type_bwd user_maintained_cache_bwd;
-    auto [B, NH, T, HS] = std::make_tuple(args...);
+
+    auto key = std::make_tuple(B, NH, T, HS);
+
+    auto it = user_maintained_cache_bwd.find(key);
+    if (it != user_maintained_cache_bwd.end()) {
+        return it->second;
+    }
 
     auto graph = std::make_shared<fe::graph::Graph>();
     graph->set_io_data_type(CUDNN_16BIT)
@@ -164,28 +174,35 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
     // must come from inp (which means we also need to convert THAT to FP16)
     auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
                                .set_dim({B, NH, T, HS})
+                               .set_uid(Q_UID)
                                .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
     auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
                                .set_dim({B, NH, T, HS})
+                               .set_uid(K_UID)
                                .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
     auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
                                .set_dim({B, NH, T, HS})
+                               .set_uid(V_UID)
                                .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
     auto O = graph->tensor(fe::graph::Tensor_attributes().set_name("O")
                                .set_dim({B, NH, T, HS})
+                               .set_uid(O_UID)
                                .set_stride({NH * HS * T, HS, NH * HS, 1}));
     auto dO = graph->tensor(fe::graph::Tensor_attributes().set_name("dO")
                                .set_dim({B, NH, T, HS})
+                               .set_uid(dO_UID)
                                .set_stride({NH * HS * T, HS, NH * HS, 1}));
     auto stats = graph->tensor(fe::graph::Tensor_attributes().set_name("stats")
                                .set_dim({B, NH, T, 1})
+                               .set_uid(Stats_UID)
                                .set_stride({NH * T, T, 1, 1})
                                .set_data_type(fe::DataType_t::FLOAT));
     auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
                                .set_dim({1, 1, 1, 1})
                                .set_stride({1, 1, 1, 1})
                                .set_is_pass_by_value(true)
+                               .set_uid(Attn_scale_UID)
                                .set_data_type(fe::DataType_t::FLOAT));
     auto sdpa_backward_options = fe::graph::SDPA_backward_attributes().set_name("flash_attention_backward")
                                .set_causal_mask(true)
@@ -194,16 +211,11 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
 
     // Create the graph operation and get the output tensors back
     auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);
-    dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
-    dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
-    dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
+    dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dQ_UID);
+    dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dK_UID);
+    dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dV_UID);
 
     checkCudnnFE(graph->validate());
-    auto key = graph->key();
-    auto it = user_maintained_cache_bwd.find(key);
-    if (it != user_maintained_cache_bwd.end()) {
-        return it->second;
-    }
 
     // Build the operation graph and execution part (this is the VERY SLOW PART)
     checkCudnnFE(graph->build_operation_graph(cudnn_handle));
@@ -221,9 +233,8 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
         cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
     }
 
-    auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
-    user_maintained_cache_bwd.insert({key, tuple});
-    return tuple;
+    user_maintained_cache_bwd.insert({key, graph});
+    return graph;
 }
 
 void attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)
@@ -235,8 +246,7 @@ void attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)
     bool is_inference_only = (stats == nullptr);
 
     // Get graph and tensors from cache (or generate it on first use)
-    auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
-        lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
+    auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
 
     // Prepare all the tensor pointers for executing the graph
     void* devPtrQ = inp;
@@ -246,12 +256,12 @@ void attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)
     void* devPtrO = out;
 
     // Build variant pack
-    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
-        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}};
+    std::unordered_map<int64_t, void*> variant_pack = {
+        {Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {Attn_scale_UID, &attn_scale_cpu}, {O_UID, devPtrO}};
 
     // Add the stats tensor unless we are only doing inference (only needed for backward pass)
     if (is_inference_only == false) {
-        variant_pack[softmax_stats] = stats;
+        variant_pack[Stats_UID] = stats;
     }
 
     // Execute graph
@@ -266,8 +276,7 @@ void attention_backward_cudnn(floatX* dqkvr,
     int HS = C / NH; // number of features per head
 
     // Get graph and tensors from cache (or generate it on first use)
-    auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
-        lookup_cache_or_build_graph_bwd(B, NH, T, HS);
+    auto graph = lookup_cache_or_build_graph_bwd(B, NH, T, HS);
 
     // Prepare all the tensor pointers for executing the graph
     void* devPtrQ = qkvr;
@@ -283,10 +292,10 @@ void attention_backward_cudnn(floatX* dqkvr,
     void* devPtrdV = (dqkvr + 2 * NH * HS);
 
     // Build variant pack that links each tensor to its data pointer
-    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
-        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {O, devPtrO}, {dO, devPtrdO}, {Stats, devPtrStats},
-        {dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV},
-        {attn_scale, &attn_scale_cpu}};
+    std::unordered_map<int64_t, void*> variant_pack = {
+        {Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {O_UID, devPtrO}, {dO_UID, devPtrdO}, {Stats_UID, devPtrStats},
+        {dQ_UID, devPtrdQ}, {dK_UID, devPtrdK}, {dV_UID, devPtrdV},
+        {Attn_scale_UID, &attn_scale_cpu}};
 
     // Execute graph
     checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));
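Reviewer note: the core of this change is keying the graph cache on the problem shape (B, H, T, HS, is_inference_only) rather than on graph->key(), so a cache hit now returns before any tensor construction or graph validation runs at all. Below is a minimal standalone sketch of that pattern; the Graph struct is a stand-in for fe::graph::Graph, not the real cuDNN frontend class.

    #include <cstdio>
    #include <map>
    #include <memory>
    #include <tuple>

    struct Graph {
        // Stand-in for the expensive cuDNN frontend build steps
        // (build_operation_graph, check_support, build_plans).
        void build() { std::printf("building graph (slow path)\n"); }
    };

    // Key mirrors the diff: B, H, T, HS, is_inference_only.
    using shape_key = std::tuple<int, int, int, int, int>;

    std::shared_ptr<Graph> lookup_or_build(int B, int H, int T, int HS, int is_inference_only) {
        static std::map<shape_key, std::shared_ptr<Graph>> cache;
        auto key = std::make_tuple(B, H, T, HS, is_inference_only);
        auto it = cache.find(key);
        if (it != cache.end()) {
            return it->second;   // fast path: no tensor setup, no validation, no build
        }
        auto graph = std::make_shared<Graph>();
        graph->build();          // slow path, runs once per distinct shape
        cache.insert({key, graph});
        return graph;
    }

    int main() {
        lookup_or_build(4, 12, 1024, 64, 0); // builds
        lookup_or_build(4, 12, 1024, 64, 0); // cache hit, returns immediately
        lookup_or_build(4, 12, 2048, 64, 0); // different T, builds again
    }

std::map works here where std::unordered_map would not, because std::tuple provides operator< out of the box but no std::hash specialization; with only a handful of distinct shapes per run, the O(log n) lookup is negligible next to the build cost.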
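The other half of the change is the switch to integer UIDs: once every tensor carries set_uid(...), the variant pack maps plain int64_t UIDs to device pointers, which is why the tuples of shared_ptr tensor handles could be deleted from the cache. A sketch of that binding pattern with a stand-in execute() (illustrative names, not the frontend API):

    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>

    // Stand-in UIDs mirroring the enum in the diff; any stable integers work,
    // as long as each tensor in a given graph gets a distinct value.
    enum UIDs : int64_t { Q_UID, K_UID, V_UID, O_UID };

    // Stand-in for fe::graph::Graph::execute(handle, variant_pack, workspace):
    // the graph resolves each UID back to the tensor registered with it.
    void execute(const std::unordered_map<int64_t, void*>& variant_pack) {
        for (const auto& [uid, ptr] : variant_pack) {
            std::printf("tensor uid %lld bound to %p\n", (long long)uid, ptr);
        }
    }

    int main() {
        float q[8], k[8], v[8], o[8];
        // Callers bind data pointers by UID; no tensor handles needed.
        execute({{Q_UID, q}, {K_UID, k}, {V_UID, v}, {O_UID, o}});
    }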