@@ -60,38 +60,35 @@ static void checkCudnnFE(fe::error_object e, const char *file, int line) {
6060}
// Route every checkCudnnFE(err) call through the error-reporting helper above,
// automatically supplying the call site's file and line number.
// NOTE: the '(' must immediately follow the macro name — with a space in
// between this becomes an object-like macro and expands incorrectly.
#define checkCudnnFE(err) checkCudnnFE(err, __FILE__, __LINE__)
6262
63- using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
64- std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
65- std::shared_ptr<fe::graph::Tensor_attributes>, // K,
66- std::shared_ptr<fe::graph::Tensor_attributes>, // V,
67- std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
68- std::shared_ptr<fe::graph::Tensor_attributes>, // O
69- std::shared_ptr<fe::graph::Tensor_attributes> // Stats
70- >;
71-
72- using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
73- std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
74- std::shared_ptr<fe::graph::Tensor_attributes>, // K,
75- std::shared_ptr<fe::graph::Tensor_attributes>, // V,
76- std::shared_ptr<fe::graph::Tensor_attributes>, // O
77- std::shared_ptr<fe::graph::Tensor_attributes>, // dO
78- std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
79- std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
80- std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
81- std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
82- std::shared_ptr<fe::graph::Tensor_attributes> // dV
83- >;
// Stable integer IDs for every tensor used in the attention graphs.
// They are assigned via Tensor_attributes::set_uid() at graph-build time and
// later used as keys in the variant pack that binds device pointers at
// execution time. Forward uses Q/K/V/Attn_scale/O/Stats; backward additionally
// uses dO/dQ/dK/dV. Values are the implicit 0..9 sequence.
enum UIDs {
    Q_UID,          // query tensor
    K_UID,          // key tensor
    V_UID,          // value tensor
    Attn_scale_UID, // attention-scale scalar (passed by value)
    O_UID,          // attention output tensor
    Stats_UID,      // softmax stats (only produced when training)
    dO_UID,         // gradient w.r.t. the output
    dQ_UID,         // gradient w.r.t. the query
    dK_UID,         // gradient w.r.t. the key
    dV_UID          // gradient w.r.t. the value
};
8475
8576// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
86- using cache_type_fwd = std::unordered_map <std::size_t , graph_tensors_fwd >;
87- using cache_type_bwd = std::unordered_map <std::size_t , graph_tensors_bwd >;
77+ using cache_type_fwd = std::map <std::tuple< int , int , int , int , int >, std::shared_ptr<fe::graph::Graph> >;
78+ using cache_type_bwd = std::map <std::tuple< int , int , int , int >, std::shared_ptr<fe::graph::Graph> >;
8879
8980// Loosely based on cuDNN frontend samples functions and massively simplified
90- template < typename ... Args>
91- auto lookup_cache_or_build_graph_fwd (Args... args) {
81+ auto lookup_cache_or_build_graph_fwd ( int B, int H, int T, int HS, int is_inference_only) {
82+
9283 static cache_type_fwd user_maintained_cache_fwd;
93- auto [B, H, T, HS, is_inference_only] = std::make_tuple (args...);
9484
85+ auto key = std::make_tuple (B, H, T, HS, is_inference_only);
86+
87+ auto it = user_maintained_cache_fwd.find (key);
88+ if (it != user_maintained_cache_fwd.end ()) {
89+ return it->second ;
90+ }
91+
9592 auto graph = std::make_shared<fe::graph::Graph>();
9693 graph->set_io_data_type (CUDNN_16BIT)
9794 .set_intermediate_data_type (fe::DataType_t::FLOAT)
@@ -100,16 +97,20 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
10097 // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
10198 auto Q = graph->tensor (fe::graph::Tensor_attributes ().set_name (" Q" )
10299 .set_dim ({B, H, T, HS})
100+ .set_uid (Q_UID)
103101 .set_stride ({3 * H * HS * T, HS, 3 * H * HS, 1 }));
104102 auto K = graph->tensor (fe::graph::Tensor_attributes ().set_name (" K" )
105103 .set_dim ({B, H, T, HS})
104+ .set_uid (K_UID)
106105 .set_stride ({3 * H * HS * T, HS, 3 * H * HS, 1 }));
107106 auto V = graph->tensor (fe::graph::Tensor_attributes ().set_name (" V" )
108107 .set_dim ({B, H, T, HS})
108+ .set_uid (V_UID)
109109 .set_stride ({3 * H * HS * T, HS, 3 * H * HS, 1 }));
110110 auto attn_scale = graph->tensor (fe::graph::Tensor_attributes ().set_name (" attn_scale" )
111111 .set_dim ({1 , 1 , 1 , 1 })
112112 .set_stride ({1 , 1 , 1 , 1 })
113+ .set_uid (Attn_scale_UID)
113114 .set_is_pass_by_value (true )
114115 .set_data_type (fe::DataType_t::FLOAT));
115116
@@ -122,38 +123,47 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
122123 auto [O, stats] = graph->sdpa (Q, K, V, sdpa_options);
123124
124125 // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
125- O->set_output (true ).set_dim ({B, H, T, HS}).set_stride ({H * HS * T, HS, H * HS, 1 });
126+ O->set_output (true ).set_dim ({B, H, T, HS}).set_stride ({H * HS * T, HS, H * HS, 1 }). set_uid (O_UID) ;
126127
127128 assert (stats == nullptr || is_inference_only == false );
128129 if (is_inference_only == false ) {
129130 stats->set_output (true ).set_data_type (fe::DataType_t::FLOAT)
130131 .set_dim ({B, H, T, 1 })
131- .set_stride ({H * T, T, 1 , 1 });
132+ .set_stride ({H * T, T, 1 , 1 })
133+ .set_uid (Stats_UID);
132134 }
133135
134136 checkCudnnFE (graph->validate ());
135- auto key = graph->key ();
136- auto it = user_maintained_cache_fwd.find (key);
137- if (it != user_maintained_cache_fwd.end ()) {
138- return it->second ;
139- }
140137
141138 // Build the operation graph and execution part (this is the VERY SLOW PART)
142139 checkCudnnFE (graph->build_operation_graph (cudnn_handle));
143140 auto plans = graph->create_execution_plans ({fe::HeurMode_t::A});
144141 checkCudnnFE (graph->check_support (cudnn_handle));
145142 checkCudnnFE (graph->build_plans (cudnn_handle));
146- assert (graph->get_workspace_size () <= cudnn_workspace_size); // fwd shouldn't need workspace
143+ // Reallocate the workspace if the required size is greater than the current workspace
144+ // In H100 this may be around 16B
145+ if (graph->get_workspace_size () > cudnn_workspace_size) {
146+ if (cudnn_workspace_size > 0 ) {
147+ cudaCheck (cudaFree (cudnn_workspace));
148+ }
149+ cudnn_workspace_size = graph->get_workspace_size ();
150+ cudaCheck (cudaMalloc (&cudnn_workspace, cudnn_workspace_size));
151+ }
147152
148- auto tuple = std::make_tuple (graph, Q, K, V, attn_scale, O, stats );
149- user_maintained_cache_fwd. insert ({key, tuple});
150- return tuple ;
153+ user_maintained_cache_fwd. insert ({key, graph} );
154+
155+ return graph ;
151156}
152157
153- template <typename ... Args>
154- auto lookup_cache_or_build_graph_bwd (Args... args) {
158+ auto lookup_cache_or_build_graph_bwd (int B, int NH, int T, int HS) {
155159 static cache_type_bwd user_maintained_cache_bwd;
156- auto [B, NH, T, HS] = std::make_tuple (args...);
160+
161+ auto key = std::make_tuple (B, NH, T, HS);
162+
163+ auto it = user_maintained_cache_bwd.find (key);
164+ if (it != user_maintained_cache_bwd.end ()) {
165+ return it->second ;
166+ }
157167
158168 auto graph = std::make_shared<fe::graph::Graph>();
159169 graph->set_io_data_type (CUDNN_16BIT)
@@ -164,28 +174,35 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
164174 // must come from inp (which means we also need to convert THAT to FP16)
165175 auto Q = graph->tensor (fe::graph::Tensor_attributes ().set_name (" Q" )
166176 .set_dim ({B, NH, T, HS})
177+ .set_uid (Q_UID)
167178 .set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }));
168179 auto K = graph->tensor (fe::graph::Tensor_attributes ().set_name (" K" )
169180 .set_dim ({B, NH, T, HS})
181+ .set_uid (K_UID)
170182 .set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }));
171183 auto V = graph->tensor (fe::graph::Tensor_attributes ().set_name (" V" )
172184 .set_dim ({B, NH, T, HS})
185+ .set_uid (V_UID)
173186 .set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }));
174187 auto O = graph->tensor (fe::graph::Tensor_attributes ().set_name (" O" )
175188 .set_dim ({B, NH, T, HS})
189+ .set_uid (O_UID)
176190 .set_stride ({NH * HS * T, HS, NH * HS, 1 }));
177191 auto dO = graph->tensor (fe::graph::Tensor_attributes ().set_name (" dO" )
178192 .set_dim ({B, NH, T, HS})
193+ .set_uid (dO_UID)
179194 .set_stride ({NH * HS * T, HS, NH * HS, 1 }));
180195
181196 auto stats = graph->tensor (fe::graph::Tensor_attributes ().set_name (" stats" )
182197 .set_dim ({B, NH, T, 1 })
198+ .set_uid (Stats_UID)
183199 .set_stride ({NH * T, T, 1 , 1 })
184200 .set_data_type (fe::DataType_t::FLOAT));
185201 auto attn_scale = graph->tensor (fe::graph::Tensor_attributes ().set_name (" attn_scale" )
186202 .set_dim ({1 , 1 , 1 , 1 })
187203 .set_stride ({1 , 1 , 1 , 1 })
188204 .set_is_pass_by_value (true )
205+ .set_uid (Attn_scale_UID)
189206 .set_data_type (fe::DataType_t::FLOAT));
190207 auto sdpa_backward_options = fe::graph::SDPA_backward_attributes ().set_name (" flash_attention_backward" )
191208 .set_causal_mask (true )
@@ -194,16 +211,11 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
194211 // Create the graph operation and get the output tensors back
195212 auto [dQ, dK, dV] = graph->sdpa_backward (Q, K, V, O, dO, stats, sdpa_backward_options);
196213
197- dQ->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 });
198- dK->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 });
199- dV->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 });
214+ dQ->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }). set_uid (dQ_UID) ;
215+ dK->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }). set_uid (dK_UID) ;
216+ dV->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }). set_uid (dV_UID) ;
200217
201218 checkCudnnFE (graph->validate ());
202- auto key = graph->key ();
203- auto it = user_maintained_cache_bwd.find (key);
204- if (it != user_maintained_cache_bwd.end ()) {
205- return it->second ;
206- }
207219
208220 // Build the operation graph and execution part (this is the VERY SLOW PART)
209221 checkCudnnFE (graph->build_operation_graph (cudnn_handle));
@@ -221,9 +233,8 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
221233 cudaCheck (cudaMalloc (&cudnn_workspace, cudnn_workspace_size));
222234 }
223235
224- auto tuple = std::make_tuple (graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
225- user_maintained_cache_bwd.insert ({key, tuple});
226- return tuple;
236+ user_maintained_cache_bwd.insert ({key, graph});
237+ return graph;
227238}
228239
229240void attention_forward_cudnn (floatX* out, // output: (B, T, NH, HS)
@@ -235,8 +246,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
235246 bool is_inference_only = (stats == nullptr );
236247
237248 // Get graph and tensors from cache (or generate it on first use)
238- auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
239- lookup_cache_or_build_graph_fwd (B, NH, T, HS, is_inference_only);
249+ auto graph = lookup_cache_or_build_graph_fwd (B, NH, T, HS, is_inference_only);
240250
241251 // Prepare all the tensor pointers for executing the graph
242252 void * devPtrQ = inp;
@@ -246,12 +256,12 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
246256 void * devPtrO = out;
247257
248258 // Build variant pack
249- std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes> , void *> variant_pack = {
250- {Q , devPtrQ}, {K , devPtrK}, {V , devPtrV}, {attn_scale , &attn_scale_cpu}, {O , devPtrO}};
259+ std::unordered_map<int64_t , void *> variant_pack = {
260+ {Q_UID , devPtrQ}, {K_UID , devPtrK}, {V_UID , devPtrV}, {Attn_scale_UID , &attn_scale_cpu}, {O_UID , devPtrO}};
251261
252262 // Add the stats tensor unless we are only doing inference (only needed for backward pass)
253263 if (is_inference_only == false ) {
254- variant_pack[softmax_stats ] = stats;
264+ variant_pack[Stats_UID ] = stats;
255265 }
256266
257267 // Execute graph
@@ -266,8 +276,7 @@ void attention_backward_cudnn(floatX* dqkvr,
266276 int HS = C / NH; // number of features per head
267277
268278 // Get graph and tensors from cache (or generate it on first use)
269- auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
270- lookup_cache_or_build_graph_bwd (B, NH, T, HS);
279+ auto graph = lookup_cache_or_build_graph_bwd (B, NH, T, HS);
271280
272281 // Prepare all the tensor pointers for executing the graph
273282 void * devPtrQ = qkvr;
@@ -283,10 +292,10 @@ void attention_backward_cudnn(floatX* dqkvr,
283292 void * devPtrdV = (dqkvr + 2 * NH * HS);
284293
285294 // Build variant pack that links each tensor to its data pointer
286- std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes> , void *> variant_pack = {
287- {Q , devPtrQ}, {K , devPtrK}, {V , devPtrV}, {O , devPtrO}, {dO , devPtrdO}, {Stats , devPtrStats},
288- {dQ , devPtrdQ}, {dK , devPtrdK}, {dV , devPtrdV},
289- {attn_scale , &attn_scale_cpu}};
295+ std::unordered_map<int64_t , void *> variant_pack = {
296+ {Q_UID , devPtrQ}, {K_UID , devPtrK}, {V_UID , devPtrV}, {O_UID , devPtrO}, {dO_UID , devPtrdO}, {Stats_UID , devPtrStats},
297+ {dQ_UID , devPtrdQ}, {dK_UID , devPtrdK}, {dV_UID , devPtrdV},
298+ {Attn_scale_UID , &attn_scale_cpu}};
290299
291300 // Execute graph
292301 checkCudnnFE (graph->execute (cudnn_handle, variant_pack, cudnn_workspace));
0 commit comments