
Commit a3d48e4

Simplify and improve CUDA graphs through use of indirect copy pointers
Previously there was complexity in the CUDA graphs implementation due to frequently changing parameters to the copy kernels associated with the K and V cache pointers. This patch simplifies the implementation by using indirection: the copy kernels reference the cache pointers indirectly, so their parameters no longer change from token to token and frequent graph updates are avoided.
1 parent 0fd93cd commit a3d48e4
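The core of the change: the CUDA copy kernels no longer take the K/V cache destination pointers directly as kernel arguments. Instead they receive a fixed pointer to a device-resident pointer table and read the current destination through it, so the kernel parameters captured in the graph stay constant across tokens. A minimal sketch of that pattern (illustrative CUDA only, not the kernels from this commit; copy_to_cache and update_cache_ptrs are invented names):

    #include <cuda_runtime.h>
    #include <cstddef>

    // The table's address is fixed once allocated, so kernel parameters captured in a
    // CUDA graph remain valid; only the table's contents are refreshed each token.
    __global__ void copy_to_cache(const char * src, char ** cache_ptrs, int slot, size_t n) {
        char * dst = cache_ptrs[slot];          // indirection: fetch the live destination pointer
        size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = src[i];
        }
    }

    // Host side: overwrite the table entries before replaying the graph; this is a plain
    // async memcpy and does not require any cudaGraphExecUpdate.
    static void update_cache_ptrs(char ** dev_table, const char * const * host_ptrs,
                                  size_t count, cudaStream_t stream) {
        cudaMemcpyAsync(dev_table, host_ptrs, count * sizeof(char *),
                        cudaMemcpyHostToDevice, stream);
    }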


5 files changed: +118 -135 lines


ggml/include/ggml-backend.h

Lines changed: 3 additions & 0 deletions
@@ -232,6 +232,9 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // Copy K and V cache pointers to backend
+    GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size);
+    GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size);
 
 #ifdef __cplusplus
 }
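The new entry points take an array of host-visible cache pointers plus a size. A hedged usage sketch, assuming size is the number of pointers and that the caller (e.g. the KV-cache code in llama.cpp) collects each layer's current cache data pointer; the helper push_kv_cache_ptrs and the tensor vectors are illustrative, not part of this commit:

    #include <vector>
    #include "ggml.h"
    #include "ggml-backend.h"

    // Illustrative only: gather per-layer K/V cache pointers and hand them to the backend,
    // which is expected to copy them into device-side tables that the copy kernels read through.
    static void push_kv_cache_ptrs(const std::vector<ggml_tensor *> & k_cache,
                                   const std::vector<ggml_tensor *> & v_cache) {
        std::vector<const char *> k_ptrs, v_ptrs;
        for (ggml_tensor * t : k_cache) {
            k_ptrs.push_back((const char *) t->data);   // current address of this layer's K cache
        }
        for (ggml_tensor * t : v_cache) {
            v_ptrs.push_back((const char *) t->data);
        }
        ggml_backend_copy_k_cache_ptrs(k_ptrs.data(), k_ptrs.size());
        ggml_backend_copy_v_cache_ptrs(v_ptrs.data(), v_ptrs.size());
    }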

ggml/src/ggml-cuda.cu

Lines changed: 14 additions & 69 deletions
@@ -2479,9 +2479,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2527,7 +2524,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         }
 
         // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
 
@@ -2554,16 +2550,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 #endif
             }
 
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                }
-            }
-
             if (!use_cuda_graph) {
                 break;
             }
@@ -2653,64 +2639,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
     }
 
-    // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-
     if (cuda_graph_update_required) {
-        // Extract nodes from graph
-        // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-        // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-            // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                    if (stat == cudaErrorInvalidDeviceFunction) {
-                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                        // We don't need to update blas nodes, so clear error and move on.
-                        cudaGetLastError();
-                    } else {
-                        GGML_ASSERT(stat == cudaSuccess);
-                    }
-                }
-            }
-        }
-    }
-
-    // One of the arguments to the copy kernel is updated for each token, hence we need to
-    // replace that argument with the updated value in the CUDA graph
-    if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-        int k = 0;
-        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-            }
-        }
-    }
-
-    // Update graph executable
-    cudaGraphExecUpdateResultInfo result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-    if (stat == cudaErrorGraphExecUpdateFailure) {
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-        GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
+            GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
 #endif
-        // The pre-existing graph exec cannot be updated due to violated constraints
-        // so instead clear error and re-instantiate
-        cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-    } else {
-        GGML_ASSERT(stat == cudaSuccess);
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
     }
     // Launch graph
     CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
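With the parameter-patching loop removed, the only per-graph maintenance left in this function is cudaGraphExecUpdate, and it runs only when cuda_graph_update_required is set. The implementation behind ggml_backend_copy_k_cache_ptrs / ggml_backend_copy_v_cache_ptrs presumably lives in the changed files not shown in this excerpt; a speculative sketch of what such a copy could look like (a lazily grown device pointer table refreshed with cudaMemcpy; all names here are invented):

    #include <cuda_runtime.h>
    #include <cstddef>

    // Speculative sketch only, not this commit's actual code.
    static char ** g_dev_k_cache_ptrs = nullptr;   // device-side table of K cache pointers
    static size_t  g_dev_k_cache_cap  = 0;

    extern "C" void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
        if (size > g_dev_k_cache_cap) {
            // Grow the table; its address changes only on a (rare) resize.
            cudaFree(g_dev_k_cache_ptrs);
            cudaMalloc((void **) &g_dev_k_cache_ptrs, size * sizeof(char *));
            g_dev_k_cache_cap = size;
        }
        // Overwrite the entries in place; graph-captured kernels read through the table,
        // so their parameters never change and no graph update is triggered.
        cudaMemcpy(g_dev_k_cache_ptrs, host_cache_ptrs, size * sizeof(char *), cudaMemcpyHostToDevice);
    }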

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 4 deletions
@@ -583,15 +583,11 @@ struct ggml_cuda_graph {
     }
     cudaGraph_t graph = nullptr;
     cudaGraphExec_t instance = nullptr;
-    size_t num_nodes = 0;
-    std::vector<cudaGraphNode_t> nodes;
-    std::vector<cudaKernelNodeParams> params;
     bool disable_due_to_gpu_arch = false;
     bool disable_due_to_too_many_updates = false;
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    std::vector<char **> updated_kernel_arg;
 #endif
 };
