@@ -2479,9 +2479,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2479
2479
2480
2480
bool use_cuda_graph = true ;
2481
2481
bool cuda_graph_update_required = false ;
2482
- // vector of pointers to CUDA cpy kernels, which are required to identify
2483
- // kernel parameters which need to be updated in the graph for each token
2484
- std::vector<void *> ggml_cuda_cpy_fn_ptrs;
2485
2482
2486
2483
if (cuda_ctx->cuda_graph ->graph == nullptr ) {
2487
2484
if (ggml_cuda_info ().devices [cuda_ctx->device ].cc < CC_AMPERE) {
@@ -2527,7 +2524,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2527
2524
}
2528
2525
2529
2526
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2530
- cuda_ctx->cuda_graph ->updated_kernel_arg .clear ();
2531
2527
for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2532
2528
ggml_tensor * node = cgraph->nodes [i];
2533
2529
@@ -2554,16 +2550,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2554
2550
#endif
2555
2551
}
2556
2552
2557
- if (node->op == GGML_OP_CPY) {
2558
- // store the copy op parameter which changes with each token.
2559
- cuda_ctx->cuda_graph ->updated_kernel_arg .push_back ((char **) &(node->src [1 ]->data ));
2560
- // store a pointer to each copy op CUDA kernel to identify it later
2561
- void * ptr = ggml_cuda_cpy_fn (node->src [0 ], node->src [1 ]);
2562
- if (std::find (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), ptr) == ggml_cuda_cpy_fn_ptrs.end ()) {
2563
- ggml_cuda_cpy_fn_ptrs.push_back (ptr);
2564
- }
2565
- }
2566
-
2567
2553
if (!use_cuda_graph) {
2568
2554
break ;
2569
2555
}
@@ -2653,64 +2639,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2653
2639
CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2654
2640
}
2655
2641
2656
- // Perform update to graph (if required for this token), and change copy parameter (required for every token)
2657
-
2658
2642
if (cuda_graph_update_required) {
2659
- // Extract nodes from graph
2660
- // First call with null argument gets number of nodes in graph
2661
- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2662
- // Subsequent call with non-null argument gets nodes
2663
- cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2664
- cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2665
- if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2666
- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
2667
-
2668
- // Loop over nodes, and extract kernel parameters from each node
2669
- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2670
- cudaGraphNodeType node_type;
2671
- CUDA_CHECK (cudaGraphNodeGetType (cuda_ctx->cuda_graph ->nodes [i], &node_type));
2672
- if (node_type == cudaGraphNodeTypeKernel) {
2673
- cudaError_t stat = cudaGraphKernelNodeGetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]); // Get params using runtime
2674
- if (stat == cudaErrorInvalidDeviceFunction) {
2675
- // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2676
- // We don't need to update blas nodes, so clear error and move on.
2677
- cudaGetLastError ();
2678
- } else {
2679
- GGML_ASSERT (stat == cudaSuccess);
2680
- }
2681
- }
2682
- }
2683
- }
2684
- }
2685
-
2686
- // One of the arguments to the copy kernel is updated for each token, hence we need to
2687
- // replace that argument with the updated value in the CUDA graph
2688
- if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2689
- int k = 0 ;
2690
- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2691
- if (count (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), cuda_ctx->cuda_graph ->params [i].func ) > 0 ) {
2692
- char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph ->updated_kernel_arg .at (k++);
2693
- cuda_ctx->cuda_graph ->params [i].kernelParams [1 ] = updated_kernel_arg_ptr;
2694
- CUDA_CHECK (cudaGraphKernelNodeSetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]));
2695
- }
2696
- }
2697
- }
2698
-
2699
- // Update graph executable
2700
- cudaGraphExecUpdateResultInfo result_info;
2701
- cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2702
- if (stat == cudaErrorGraphExecUpdateFailure) {
2643
+ // Update graph executable
2644
+ cudaGraphExecUpdateResultInfo result_info;
2645
+ cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2646
+ if (stat == cudaErrorGraphExecUpdateFailure) {
2703
2647
#ifndef NDEBUG
2704
- GGML_CUDA_LOG_ERROR (" %s: CUDA graph update failed\n " , __func__);
2648
+ GGML_CUDA_LOG_ERROR (" %s: CUDA graph update failed\n " , __func__);
2705
2649
#endif
2706
- // The pre-existing graph exec cannot be updated due to violated constraints
2707
- // so instead clear error and re-instantiate
2708
- cudaGetLastError ();
2709
- CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2710
- cuda_ctx->cuda_graph ->instance = nullptr ;
2711
- CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2712
- } else {
2713
- GGML_ASSERT (stat == cudaSuccess);
2650
+ // The pre-existing graph exec cannot be updated due to violated constraints
2651
+ // so instead clear error and re-instantiate
2652
+ cudaGetLastError ();
2653
+ CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2654
+ cuda_ctx->cuda_graph ->instance = nullptr ;
2655
+ CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2656
+ } else {
2657
+ GGML_ASSERT (stat == cudaSuccess);
2658
+ }
2714
2659
}
2715
2660
// Launch graph
2716
2661
CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
0 commit comments