@@ -60,38 +60,35 @@ static void checkCudnnFE(fe::error_object e, const char *file, int line) {
60
60
}
61
61
#define checkCudnnFE (err ) checkCudnnFE(err, __FILE__, __LINE__)
62
62
63
- using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
64
- std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
65
- std::shared_ptr<fe::graph::Tensor_attributes>, // K,
66
- std::shared_ptr<fe::graph::Tensor_attributes>, // V,
67
- std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
68
- std::shared_ptr<fe::graph::Tensor_attributes>, // O
69
- std::shared_ptr<fe::graph::Tensor_attributes> // Stats
70
- >;
71
-
72
- using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
73
- std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
74
- std::shared_ptr<fe::graph::Tensor_attributes>, // K,
75
- std::shared_ptr<fe::graph::Tensor_attributes>, // V,
76
- std::shared_ptr<fe::graph::Tensor_attributes>, // O
77
- std::shared_ptr<fe::graph::Tensor_attributes>, // dO
78
- std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
79
- std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
80
- std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
81
- std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
82
- std::shared_ptr<fe::graph::Tensor_attributes> // dV
83
- >;
63
+ enum UIDs {
64
+ Q_UID,
65
+ K_UID,
66
+ V_UID,
67
+ Attn_scale_UID,
68
+ O_UID,
69
+ Stats_UID,
70
+ dO_UID,
71
+ dQ_UID,
72
+ dK_UID,
73
+ dV_UID
74
+ };
84
75
85
76
// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
86
- using cache_type_fwd = std::unordered_map <std::size_t , graph_tensors_fwd >;
87
- using cache_type_bwd = std::unordered_map <std::size_t , graph_tensors_bwd >;
77
+ using cache_type_fwd = std::map <std::tuple< int , int , int , int , int >, std::shared_ptr<fe::graph::Graph> >;
78
+ using cache_type_bwd = std::map <std::tuple< int , int , int , int >, std::shared_ptr<fe::graph::Graph> >;
88
79
89
80
// Loosely based on cuDNN frontend samples functions and massively simplified
90
- template < typename ... Args>
91
- auto lookup_cache_or_build_graph_fwd (Args... args) {
81
+ auto lookup_cache_or_build_graph_fwd ( int B, int H, int T, int HS, int is_inference_only) {
82
+
92
83
static cache_type_fwd user_maintained_cache_fwd;
93
- auto [B, H, T, HS, is_inference_only] = std::make_tuple (args...);
94
84
85
+ auto key = std::make_tuple (B, H, T, HS, is_inference_only);
86
+
87
+ auto it = user_maintained_cache_fwd.find (key);
88
+ if (it != user_maintained_cache_fwd.end ()) {
89
+ return it->second ;
90
+ }
91
+
95
92
auto graph = std::make_shared<fe::graph::Graph>();
96
93
graph->set_io_data_type (CUDNN_16BIT)
97
94
.set_intermediate_data_type (fe::DataType_t::FLOAT)
@@ -100,16 +97,20 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
100
97
// QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
101
98
auto Q = graph->tensor (fe::graph::Tensor_attributes ().set_name (" Q" )
102
99
.set_dim ({B, H, T, HS})
100
+ .set_uid (Q_UID)
103
101
.set_stride ({3 * H * HS * T, HS, 3 * H * HS, 1 }));
104
102
auto K = graph->tensor (fe::graph::Tensor_attributes ().set_name (" K" )
105
103
.set_dim ({B, H, T, HS})
104
+ .set_uid (K_UID)
106
105
.set_stride ({3 * H * HS * T, HS, 3 * H * HS, 1 }));
107
106
auto V = graph->tensor (fe::graph::Tensor_attributes ().set_name (" V" )
108
107
.set_dim ({B, H, T, HS})
108
+ .set_uid (V_UID)
109
109
.set_stride ({3 * H * HS * T, HS, 3 * H * HS, 1 }));
110
110
auto attn_scale = graph->tensor (fe::graph::Tensor_attributes ().set_name (" attn_scale" )
111
111
.set_dim ({1 , 1 , 1 , 1 })
112
112
.set_stride ({1 , 1 , 1 , 1 })
113
+ .set_uid (Attn_scale_UID)
113
114
.set_is_pass_by_value (true )
114
115
.set_data_type (fe::DataType_t::FLOAT));
115
116
@@ -122,38 +123,47 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
122
123
auto [O, stats] = graph->sdpa (Q, K, V, sdpa_options);
123
124
124
125
// Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
125
- O->set_output (true ).set_dim ({B, H, T, HS}).set_stride ({H * HS * T, HS, H * HS, 1 });
126
+ O->set_output (true ).set_dim ({B, H, T, HS}).set_stride ({H * HS * T, HS, H * HS, 1 }). set_uid (O_UID) ;
126
127
127
128
assert (stats == nullptr || is_inference_only == false );
128
129
if (is_inference_only == false ) {
129
130
stats->set_output (true ).set_data_type (fe::DataType_t::FLOAT)
130
131
.set_dim ({B, H, T, 1 })
131
- .set_stride ({H * T, T, 1 , 1 });
132
+ .set_stride ({H * T, T, 1 , 1 })
133
+ .set_uid (Stats_UID);
132
134
}
133
135
134
136
checkCudnnFE (graph->validate ());
135
- auto key = graph->key ();
136
- auto it = user_maintained_cache_fwd.find (key);
137
- if (it != user_maintained_cache_fwd.end ()) {
138
- return it->second ;
139
- }
140
137
141
138
// Build the operation graph and execution part (this is the VERY SLOW PART)
142
139
checkCudnnFE (graph->build_operation_graph (cudnn_handle));
143
140
auto plans = graph->create_execution_plans ({fe::HeurMode_t::A});
144
141
checkCudnnFE (graph->check_support (cudnn_handle));
145
142
checkCudnnFE (graph->build_plans (cudnn_handle));
146
- assert (graph->get_workspace_size () <= cudnn_workspace_size); // fwd shouldn't need workspace
143
+ // Reallocate the workspace if the required size is greater than the current workspace
144
+ // In H100 this may be around 16B
145
+ if (graph->get_workspace_size () > cudnn_workspace_size) {
146
+ if (cudnn_workspace_size > 0 ) {
147
+ cudaCheck (cudaFree (cudnn_workspace));
148
+ }
149
+ cudnn_workspace_size = graph->get_workspace_size ();
150
+ cudaCheck (cudaMalloc (&cudnn_workspace, cudnn_workspace_size));
151
+ }
147
152
148
- auto tuple = std::make_tuple (graph, Q, K, V, attn_scale, O, stats );
149
- user_maintained_cache_fwd. insert ({key, tuple});
150
- return tuple ;
153
+ user_maintained_cache_fwd. insert ({key, graph} );
154
+
155
+ return graph ;
151
156
}
152
157
153
- template <typename ... Args>
154
- auto lookup_cache_or_build_graph_bwd (Args... args) {
158
+ auto lookup_cache_or_build_graph_bwd (int B, int NH, int T, int HS) {
155
159
static cache_type_bwd user_maintained_cache_bwd;
156
- auto [B, NH, T, HS] = std::make_tuple (args...);
160
+
161
+ auto key = std::make_tuple (B, NH, T, HS);
162
+
163
+ auto it = user_maintained_cache_bwd.find (key);
164
+ if (it != user_maintained_cache_bwd.end ()) {
165
+ return it->second ;
166
+ }
157
167
158
168
auto graph = std::make_shared<fe::graph::Graph>();
159
169
graph->set_io_data_type (CUDNN_16BIT)
@@ -164,28 +174,35 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
164
174
// must come from inp (which means we also need to convert THAT to FP16)
165
175
auto Q = graph->tensor (fe::graph::Tensor_attributes ().set_name (" Q" )
166
176
.set_dim ({B, NH, T, HS})
177
+ .set_uid (Q_UID)
167
178
.set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }));
168
179
auto K = graph->tensor (fe::graph::Tensor_attributes ().set_name (" K" )
169
180
.set_dim ({B, NH, T, HS})
181
+ .set_uid (K_UID)
170
182
.set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }));
171
183
auto V = graph->tensor (fe::graph::Tensor_attributes ().set_name (" V" )
172
184
.set_dim ({B, NH, T, HS})
185
+ .set_uid (V_UID)
173
186
.set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }));
174
187
auto O = graph->tensor (fe::graph::Tensor_attributes ().set_name (" O" )
175
188
.set_dim ({B, NH, T, HS})
189
+ .set_uid (O_UID)
176
190
.set_stride ({NH * HS * T, HS, NH * HS, 1 }));
177
191
auto dO = graph->tensor (fe::graph::Tensor_attributes ().set_name (" dO" )
178
192
.set_dim ({B, NH, T, HS})
193
+ .set_uid (dO_UID)
179
194
.set_stride ({NH * HS * T, HS, NH * HS, 1 }));
180
195
181
196
auto stats = graph->tensor (fe::graph::Tensor_attributes ().set_name (" stats" )
182
197
.set_dim ({B, NH, T, 1 })
198
+ .set_uid (Stats_UID)
183
199
.set_stride ({NH * T, T, 1 , 1 })
184
200
.set_data_type (fe::DataType_t::FLOAT));
185
201
auto attn_scale = graph->tensor (fe::graph::Tensor_attributes ().set_name (" attn_scale" )
186
202
.set_dim ({1 , 1 , 1 , 1 })
187
203
.set_stride ({1 , 1 , 1 , 1 })
188
204
.set_is_pass_by_value (true )
205
+ .set_uid (Attn_scale_UID)
189
206
.set_data_type (fe::DataType_t::FLOAT));
190
207
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes ().set_name (" flash_attention_backward" )
191
208
.set_causal_mask (true )
@@ -194,16 +211,11 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
194
211
// Create the graph operation and get the output tensors back
195
212
auto [dQ, dK, dV] = graph->sdpa_backward (Q, K, V, O, dO, stats, sdpa_backward_options);
196
213
197
- dQ->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 });
198
- dK->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 });
199
- dV->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 });
214
+ dQ->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }). set_uid (dQ_UID) ;
215
+ dK->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }). set_uid (dK_UID) ;
216
+ dV->set_output (true ).set_dim ({B, NH, T, HS}).set_stride ({3 * NH * HS * T, HS, 3 * NH * HS, 1 }). set_uid (dV_UID) ;
200
217
201
218
checkCudnnFE (graph->validate ());
202
- auto key = graph->key ();
203
- auto it = user_maintained_cache_bwd.find (key);
204
- if (it != user_maintained_cache_bwd.end ()) {
205
- return it->second ;
206
- }
207
219
208
220
// Build the operation graph and execution part (this is the VERY SLOW PART)
209
221
checkCudnnFE (graph->build_operation_graph (cudnn_handle));
@@ -221,9 +233,8 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
221
233
cudaCheck (cudaMalloc (&cudnn_workspace, cudnn_workspace_size));
222
234
}
223
235
224
- auto tuple = std::make_tuple (graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
225
- user_maintained_cache_bwd.insert ({key, tuple});
226
- return tuple;
236
+ user_maintained_cache_bwd.insert ({key, graph});
237
+ return graph;
227
238
}
228
239
229
240
void attention_forward_cudnn (floatX* out, // output: (B, T, NH, HS)
@@ -235,8 +246,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
235
246
bool is_inference_only = (stats == nullptr );
236
247
237
248
// Get graph and tensors from cache (or generate it on first use)
238
- auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
239
- lookup_cache_or_build_graph_fwd (B, NH, T, HS, is_inference_only);
249
+ auto graph = lookup_cache_or_build_graph_fwd (B, NH, T, HS, is_inference_only);
240
250
241
251
// Prepare all the tensor pointers for executing the graph
242
252
void * devPtrQ = inp;
@@ -246,12 +256,12 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
246
256
void * devPtrO = out;
247
257
248
258
// Build variant pack
249
- std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes> , void *> variant_pack = {
250
- {Q , devPtrQ}, {K , devPtrK}, {V , devPtrV}, {attn_scale , &attn_scale_cpu}, {O , devPtrO}};
259
+ std::unordered_map<int64_t , void *> variant_pack = {
260
+ {Q_UID , devPtrQ}, {K_UID , devPtrK}, {V_UID , devPtrV}, {Attn_scale_UID , &attn_scale_cpu}, {O_UID , devPtrO}};
251
261
252
262
// Add the stats tensor unless we are only doing inference (only needed for backward pass)
253
263
if (is_inference_only == false ) {
254
- variant_pack[softmax_stats ] = stats;
264
+ variant_pack[Stats_UID ] = stats;
255
265
}
256
266
257
267
// Execute graph
@@ -266,8 +276,7 @@ void attention_backward_cudnn(floatX* dqkvr,
266
276
int HS = C / NH; // number of features per head
267
277
268
278
// Get graph and tensors from cache (or generate it on first use)
269
- auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
270
- lookup_cache_or_build_graph_bwd (B, NH, T, HS);
279
+ auto graph = lookup_cache_or_build_graph_bwd (B, NH, T, HS);
271
280
272
281
// Prepare all the tensor pointers for executing the graph
273
282
void * devPtrQ = qkvr;
@@ -283,10 +292,10 @@ void attention_backward_cudnn(floatX* dqkvr,
283
292
void * devPtrdV = (dqkvr + 2 * NH * HS);
284
293
285
294
// Build variant pack that links each tensor to its data pointer
286
- std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes> , void *> variant_pack = {
287
- {Q , devPtrQ}, {K , devPtrK}, {V , devPtrV}, {O , devPtrO}, {dO , devPtrdO}, {Stats , devPtrStats},
288
- {dQ , devPtrdQ}, {dK , devPtrdK}, {dV , devPtrdV},
289
- {attn_scale , &attn_scale_cpu}};
295
+ std::unordered_map<int64_t , void *> variant_pack = {
296
+ {Q_UID , devPtrQ}, {K_UID , devPtrK}, {V_UID , devPtrV}, {O_UID , devPtrO}, {dO_UID , devPtrdO}, {Stats_UID , devPtrStats},
297
+ {dQ_UID , devPtrdQ}, {dK_UID , devPtrdK}, {dV_UID , devPtrdV},
298
+ {Attn_scale_UID , &attn_scale_cpu}};
290
299
291
300
// Execute graph
292
301
checkCudnnFE (graph->execute (cudnn_handle, variant_pack, cudnn_workspace));
0 commit comments