Skip to content

Commit

Permalink
- Simplify graph cache and usage of cudnn.
Browse files Browse the repository at this point in the history
- Fix failures in H100
  • Loading branch information
Anerudhan committed May 9, 2024
1 parent b8eaafd commit 01c7a33
Showing 1 changed file with 71 additions and 62 deletions.
133 changes: 71 additions & 62 deletions cudnn_att.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,38 +60,35 @@ static void checkCudnnFE(fe::error_object e, const char *file, int line) {
}
#define checkCudnnFE(err) checkCudnnFE(err, __FILE__, __LINE__)

using graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes> // Stats
>;

using graph_tensors_bwd = std::tuple<std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // dO
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale,
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
std::shared_ptr<fe::graph::Tensor_attributes>, // dK,
std::shared_ptr<fe::graph::Tensor_attributes> // dV
>;
enum UIDs {
Q_UID,
K_UID,
V_UID,
Attn_scale_UID,
O_UID,
Stats_UID,
dO_UID,
dQ_UID,
dK_UID,
dV_UID
};

// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
using cache_type_fwd = std::unordered_map<std::size_t, graph_tensors_fwd>;
using cache_type_bwd = std::unordered_map<std::size_t, graph_tensors_bwd>;
using cache_type_fwd = std::map<std::tuple<int,int,int,int, int>, std::shared_ptr<fe::graph::Graph>>;
using cache_type_bwd = std::map<std::tuple<int,int,int,int>, std::shared_ptr<fe::graph::Graph>>;

// Loosely based on cuDNN frontend samples functions and massively simplified
template <typename... Args>
auto lookup_cache_or_build_graph_fwd(Args... args) {
auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_only) {

static cache_type_fwd user_maintained_cache_fwd;
auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...);

auto key = std::make_tuple(B, H, T, HS, is_inference_only);

auto it = user_maintained_cache_fwd.find(key);
if (it != user_maintained_cache_fwd.end()) {
return it->second;
}

auto graph = std::make_shared<fe::graph::Graph>();
graph->set_io_data_type(CUDNN_16BIT)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
Expand All @@ -100,16 +97,20 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
// QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
.set_dim({B, H, T, HS})
.set_uid(Q_UID)
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
.set_dim({B, H, T, HS})
.set_uid(K_UID)
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
.set_dim({B, H, T, HS})
.set_uid(V_UID)
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_uid(Attn_scale_UID)
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));

Expand All @@ -122,38 +123,47 @@ auto lookup_cache_or_build_graph_fwd(Args... args) {
auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);

// Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1});
O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(O_UID);

assert(stats == nullptr || is_inference_only == false);
if (is_inference_only == false) {
stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({B, H, T, 1})
.set_stride({H * T, T, 1, 1});
.set_stride({H * T, T, 1, 1})
.set_uid(Stats_UID);
}

checkCudnnFE(graph->validate());
auto key = graph->key();
auto it = user_maintained_cache_fwd.find(key);
if (it != user_maintained_cache_fwd.end()) {
return it->second;
}

// Build the operation graph and execution part (this is the VERY SLOW PART)
checkCudnnFE(graph->build_operation_graph(cudnn_handle));
auto plans = graph->create_execution_plans({fe::HeurMode_t::A});
checkCudnnFE(graph->check_support(cudnn_handle));
checkCudnnFE(graph->build_plans(cudnn_handle));
assert(graph->get_workspace_size() <= cudnn_workspace_size); // fwd shouldn't need workspace
// Reallocate the workspace if the required size is greater than the current workspace
// In H100 this may be around 16B
if (graph->get_workspace_size() > cudnn_workspace_size) {
if (cudnn_workspace_size > 0) {
cudaCheck(cudaFree(cudnn_workspace));
}
cudnn_workspace_size = graph->get_workspace_size();
cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
}

auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats);
user_maintained_cache_fwd.insert({key, tuple});
return tuple;
user_maintained_cache_fwd.insert({key, graph});

return graph;
}

template <typename... Args>
auto lookup_cache_or_build_graph_bwd(Args... args) {
auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
static cache_type_bwd user_maintained_cache_bwd;
auto [B, NH, T, HS] = std::make_tuple(args...);

auto key = std::make_tuple(B, NH, T, HS);

auto it = user_maintained_cache_bwd.find(key);
if (it != user_maintained_cache_bwd.end()) {
return it->second;
}

auto graph = std::make_shared<fe::graph::Graph>();
graph->set_io_data_type(CUDNN_16BIT)
Expand All @@ -164,28 +174,35 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
// must come from inp (which means we also need to convert THAT to FP16)
auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
.set_dim({B, NH, T, HS})
.set_uid(Q_UID)
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
.set_dim({B, NH, T, HS})
.set_uid(K_UID)
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
.set_dim({B, NH, T, HS})
.set_uid(V_UID)
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
auto O = graph->tensor(fe::graph::Tensor_attributes().set_name("O")
.set_dim({B, NH, T, HS})
.set_uid(O_UID)
.set_stride({NH * HS * T, HS, NH * HS, 1}));
auto dO = graph->tensor(fe::graph::Tensor_attributes().set_name("dO")
.set_dim({B, NH, T, HS})
.set_uid(dO_UID)
.set_stride({NH * HS * T, HS, NH * HS, 1}));

auto stats = graph->tensor(fe::graph::Tensor_attributes().set_name("stats")
.set_dim({B, NH, T, 1})
.set_uid(Stats_UID)
.set_stride({NH * T, T, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_uid(Attn_scale_UID)
.set_data_type(fe::DataType_t::FLOAT));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes().set_name("flash_attention_backward")
.set_causal_mask(true)
Expand All @@ -194,16 +211,11 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
// Create the graph operation and get the output tensors back
auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);

dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1});
dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dQ_UID);
dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dK_UID);
dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dV_UID);

checkCudnnFE(graph->validate());
auto key = graph->key();
auto it = user_maintained_cache_bwd.find(key);
if (it != user_maintained_cache_bwd.end()) {
return it->second;
}

// Build the operation graph and execution part (this is the VERY SLOW PART)
checkCudnnFE(graph->build_operation_graph(cudnn_handle));
Expand All @@ -221,9 +233,8 @@ auto lookup_cache_or_build_graph_bwd(Args... args) {
cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));
}

auto tuple = std::make_tuple(graph, Q, K, V, O, dO, stats, attn_scale, dQ, dK, dV);
user_maintained_cache_bwd.insert({key, tuple});
return tuple;
user_maintained_cache_bwd.insert({key, graph});
return graph;
}

void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
Expand All @@ -235,8 +246,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
bool is_inference_only = (stats == nullptr);

// Get graph and tensors from cache (or generate it on first use)
auto [graph, Q, K, V, attn_scale, O, softmax_stats] =
lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);

// Prepare all the tensor pointers for executing the graph
void* devPtrQ = inp;
Expand All @@ -246,12 +256,12 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
void* devPtrO = out;

// Build variant pack
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}};
std::unordered_map<int64_t , void*> variant_pack = {
{Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {Attn_scale_UID, &attn_scale_cpu}, {O_UID, devPtrO}};

// Add the stats tensor unless we are only doing inference (only needed for backward pass)
if (is_inference_only == false) {
variant_pack[softmax_stats] = stats;
variant_pack[Stats_UID] = stats;
}

// Execute graph
Expand All @@ -266,8 +276,7 @@ void attention_backward_cudnn(floatX* dqkvr,
int HS = C / NH; // number of features per head

// Get graph and tensors from cache (or generate it on first use)
auto [graph, Q, K, V, O, dO, Stats, attn_scale, dQ, dK, dV] =
lookup_cache_or_build_graph_bwd(B, NH, T, HS);
auto graph = lookup_cache_or_build_graph_bwd(B, NH, T, HS);

// Prepare all the tensor pointers for executing the graph
void* devPtrQ = qkvr;
Expand All @@ -283,10 +292,10 @@ void attention_backward_cudnn(floatX* dqkvr,
void* devPtrdV = (dqkvr + 2 * NH * HS);

// Build variant pack that links each tensor to its data pointer
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
{Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {O, devPtrO}, {dO, devPtrdO}, {Stats, devPtrStats},
{dQ, devPtrdQ}, {dK, devPtrdK}, {dV, devPtrdV},
{attn_scale, &attn_scale_cpu}};
std::unordered_map<int64_t, void*> variant_pack = {
{Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {O_UID, devPtrO}, {dO_UID, devPtrdO}, {Stats_UID, devPtrStats},
{dQ_UID, devPtrdQ}, {dK_UID, devPtrdK}, {dV_UID, devPtrdV},
{Attn_scale_UID, &attn_scale_cpu}};

// Execute graph
checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));
Expand Down

0 comments on commit 01c7a33

Please sign in to comment.