Skip to content

Commit 01c7a33

Browse files
committed
- Simplify graph cache and usage of cudnn.
- Fix failures on H100
1 parent b8eaafd commit 01c7a33

File tree

1 file changed

+71
-62
lines changed

1 file changed

+71
-62
lines changed

cudnn_att.cpp

Lines changed: 71 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -60,38 +60,35 @@ static void checkCudnnFE(fe::error_object e, const char *file, int line) {
6060
}
6161
#define checkCudnnFE(err) checkCudnnFE(err, __FILE__, __LINE__)
6262

63-
using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
64-
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
65-
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
66-
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
67-
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
68-
std::shared_ptr<fe::graph::Tensor_attributes>, // O
69-
std::shared_ptr<fe::graph::Tensor_attributes> // Stats
70-
>;
71-
72-
using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
73-
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
74-
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
75-
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
76-
std::shared_ptr<fe::graph::Tensor_attributes>, // O
77-
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
78-
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
79-
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
80-
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
81-
std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
82-
std::shared_ptr<fe::graph::Tensor_attributes> // dV
83-
>;
63+
enum UIDs {
64+
Q_UID,
65+
K_UID,
66+
V_UID,
67+
Attn_scale_UID,
68+
O_UID,
69+
Stats_UID,
70+
dO_UID,
71+
dQ_UID,
72+
dK_UID,
73+
dV_UID
74+
};
8475

8576
// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
86-
using cache_type_fwd = std::unordered_map<std::size_t, graph_tensors_fwd>;
87-
using cache_type_bwd = std::unordered_map<std::size_t, graph_tensors_bwd>;
77+
using cache_type_fwd = std::map<std::tuple<int,int,int,int, int>, std::shared_ptr<fe::graph::Graph>>;
78+
using cache_type_bwd = std::map<std::tuple<int,int,int,int>, std::shared_ptr<fe::graph::Graph>>;
8879

8980
// Loosely based on cuDNN frontend samples functions and massively simplified
90-
template <typename... Args>
91-
auto lookup_cache_or_build_graph_fwd(Args... args) {
81+
auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_only) {
82+
9283
static cache_type_fwd user_maintained_cache_fwd;
93-
auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...);
9484

85+
auto key = std::make_tuple(B, H, T, HS, is_inference_only);
86+
87+
auto it = user_maintained_cache_fwd.find(key);
88+
if (it != user_maintained_cache_fwd.end()) {
89+
return it->second;
90+
}
91+
9592
auto graph = std::make_shared<fe::graph::Graph>();
9693
graph->set_io_data_type(CUDNN_16BIT)
9794
.set_intermediate_data_type(fe::DataType_t::FLOAT)
@@ -100,16 +97,20 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
10097
// QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
10198
auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
10299
.set_dim({B, H, T, HS})
100+
.set_uid(Q_UID)
103101
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
104102
auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
105103
.set_dim({B, H, T, HS})
104+
.set_uid(K_UID)
106105
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
107106
auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
108107
.set_dim({B, H, T, HS})
108+
.set_uid(V_UID)
109109
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
110110
auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
111111
.set_dim({1, 1, 1, 1})
112112
.set_stride({1, 1, 1, 1})
113+
.set_uid(Attn_scale_UID)
113114
.set_is_pass_by_value(true)
114115
.set_data_type(fe::DataType_t::FLOAT));
115116

@@ -122,38 +123,47 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
122123
auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);
123124

124125
// Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
125-
O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1});
126+
O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(O_UID);
126127

127128
assert(stats == nullptr || is_inference_only == false);
128129
if (is_inference_only == false) {
129130
stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
130131
.set_dim({B, H, T, 1})
131-
.set_stride({H * T, T, 1, 1});
132+
.set_stride({H * T, T, 1, 1})
133+
.set_uid(Stats_UID);
132134
}
133135

134136
checkCudnnFE(graph->validate());
135-
auto key = graph->key();
136-
auto it = user_maintained_cache_fwd.find(key);
137-
if (it != user_maintained_cache_fwd.end()) {
138-
return it->second;
139-
}
140137

141138
// Build the operation graph and execution part (this is the VERY SLOW PART)
142139
checkCudnnFE(graph->build_operation_graph(cudnn_handle));
143140
auto plans = graph->create_execution_plans({fe::HeurMode_t::A});
144141
checkCudnnFE(graph->check_support(cudnn_handle));
145142
checkCudnnFE(graph->build_plans(cudnn_handle));
146-
assert(graph->get_workspace_size() <= cudnn_workspace_size); // fwd shouldn't need workspace
143+
// Reallocate the workspace if the required size is greater than the current workspace
144+
// In H100 this may be around 16B
145+
if (graph->get_workspace_size() > cudnn_workspace_size) {
146+
if (cudnn_workspace_size > 0) {
147+
cudaCheck(cudaFree(cudnn_workspace));
148+
}
149+
cudnn_workspace_size = graph->get_workspace_size();
150+
cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
151+
}
147152

148-
auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats);
149-
user_maintained_cache_fwd.insert({key, tuple});
150-
return tuple;
153+
user_maintained_cache_fwd.insert({key, graph});
154+
155+
return graph;
151156
}
152157

153-
template <typename... Args>
154-
auto lookup_cache_or_build_graph_bwd(Args... args) {
158+
auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
155159
static cache_type_bwd user_maintained_cache_bwd;
156-
auto [B, NH, T, HS] = std::make_tuple(args...);
160+
161+
auto key = std::make_tuple(B, NH, T, HS);
162+
163+
auto it = user_maintained_cache_bwd.find(key);
164+
if (it != user_maintained_cache_bwd.end()) {
165+
return it->second;
166+
}
157167

158168
auto graph = std::make_shared<fe::graph::Graph>();
159169
graph->set_io_data_type(CUDNN_16BIT)
@@ -164,28 +174,35 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
164174
// must come from inp (which means we also need to convert THAT to FP16)
165175
auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
166176
.set_dim({B, NH, T, HS})
177+
.set_uid(Q_UID)
167178
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
168179
auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
169180
.set_dim({B, NH, T, HS})
181+
.set_uid(K_UID)
170182
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
171183
auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
172184
.set_dim({B, NH, T, HS})
185+
.set_uid(V_UID)
173186
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
174187
auto O = graph->tensor(fe::graph::Tensor_attributes().set_name("O")
175188
.set_dim({B, NH, T, HS})
189+
.set_uid(O_UID)
176190
.set_stride({NH * HS * T, HS, NH * HS, 1}));
177191
auto dO = graph->tensor(fe::graph::Tensor_attributes().set_name("dO")
178192
.set_dim({B, NH, T, HS})
193+
.set_uid(dO_UID)
179194
.set_stride({NH * HS * T, HS, NH * HS, 1}));
180195

181196
auto stats = graph->tensor(fe::graph::Tensor_attributes().set_name("stats")
182197
.set_dim({B, NH, T, 1})
198+
.set_uid(Stats_UID)
183199
.set_stride({NH * T, T, 1, 1})
184200
.set_data_type(fe::DataType_t::FLOAT));
185201
auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
186202
.set_dim({1, 1, 1, 1})
187203
.set_stride({1, 1, 1, 1})
188204
.set_is_pass_by_value(true)
205+
.set_uid(Attn_scale_UID)
189206
.set_data_type(fe::DataType_t::FLOAT));
190207
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes().set_name("flash_attention_backward")
191208
.set_causal_mask(true)
@@ -194,16 +211,11 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
194211
// Create the graph operation and get the output tensors back
195212
auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);
196213

197-
dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
198-
dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
199-
dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
214+
dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dQ_UID);
215+
dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dK_UID);
216+
dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dV_UID);
200217

201218
checkCudnnFE(graph->validate());
202-
auto key = graph->key();
203-
auto it = user_maintained_cache_bwd.find(key);
204-
if (it != user_maintained_cache_bwd.end()) {
205-
return it->second;
206-
}
207219

208220
// Build the operation graph and execution part (this is the VERY SLOW PART)
209221
checkCudnnFE(graph->build_operation_graph(cudnn_handle));
@@ -221,9 +233,8 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
221233
cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
222234
}
223235

224-
auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
225-
user_maintained_cache_bwd.insert({key, tuple});
226-
return tuple;
236+
user_maintained_cache_bwd.insert({key, graph});
237+
return graph;
227238
}
228239

229240
void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
@@ -235,8 +246,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
235246
bool is_inference_only = (stats == nullptr);
236247

237248
// Get graph and tensors from cache (or generate it on first use)
238-
auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
239-
lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
249+
auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
240250

241251
// Prepare all the tensor pointers for executing the graph
242252
void* devPtrQ = inp;
@@ -246,12 +256,12 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
246256
void* devPtrO = out;
247257

248258
// Build variant pack
249-
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
250-
{Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}};
259+
std::unordered_map<int64_t , void*> variant_pack = {
260+
{Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {Attn_scale_UID, &attn_scale_cpu}, {O_UID, devPtrO}};
251261

252262
// Add the stats tensor unless we are only doing inference (only needed for backward pass)
253263
if (is_inference_only == false) {
254-
variant_pack[softmax_stats] = stats;
264+
variant_pack[Stats_UID] = stats;
255265
}
256266

257267
// Execute graph
@@ -266,8 +276,7 @@ void attention_backward_cudnn(floatX* dqkvr,
266276
int HS = C / NH; // number of features per head
267277

268278
// Get graph and tensors from cache (or generate it on first use)
269-
auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
270-
lookup_cache_or_build_graph_bwd(B, NH, T, HS);
279+
auto graph = lookup_cache_or_build_graph_bwd(B, NH, T, HS);
271280

272281
// Prepare all the tensor pointers for executing the graph
273282
void* devPtrQ = qkvr;
@@ -283,10 +292,10 @@ void attention_backward_cudnn(floatX* dqkvr,
283292
void* devPtrdV = (dqkvr + 2 * NH * HS);
284293

285294
// Build variant pack that links each tensor to its data pointer
286-
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
287-
{Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {O, devPtrO}, {dO, devPtrdO}, {Stats, devPtrStats},
288-
{dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV},
289-
{attn_scale, &attn_scale_cpu}};
295+
std::unordered_map<int64_t, void*> variant_pack = {
296+
{Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {O_UID, devPtrO}, {dO_UID, devPtrdO}, {Stats_UID, devPtrStats},
297+
{dQ_UID, devPtrdQ}, {dK_UID, devPtrdK}, {dV_UID, devPtrdV},
298+
{Attn_scale_UID, &attn_scale_cpu}};
290299

291300
// Execute graph
292301
checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));

0 commit comments

Comments (0)