From eea68c63fdd124e8c2f7638c2aaab30269c8b7ed Mon Sep 17 00:00:00 2001 From: Vladimir Mironov Date: Thu, 12 Mar 2026 16:49:09 +0100 Subject: [PATCH 1/3] Fix memory issues in CUDA EXX screening --- .../integrator_util/exx_screening.cxx | 41 ++++++------------- .../kernels/exx_ek_screening_bfn_stats.cu | 40 +++++++++--------- .../device/scheme1_data_base.cxx | 2 +- .../xc_data/device/xc_device_aos_data.cxx | 4 ++ .../xc_data/device/xc_device_data.hpp | 8 ++++ 5 files changed, 46 insertions(+), 49 deletions(-) diff --git a/src/xc_integrator/integrator_util/exx_screening.cxx b/src/xc_integrator/integrator_util/exx_screening.cxx index 5c7efcd13..6bed9ec84 100644 --- a/src/xc_integrator/integrator_util/exx_screening.cxx +++ b/src/xc_integrator/integrator_util/exx_screening.cxx @@ -253,7 +253,6 @@ void exx_ek_screening( const auto nshells = basis.nshells(); const size_t ntasks = std::distance(task_begin, task_end); - const size_t task_batch_size = 10000; // Setup EXX EK Screening memory on the device device_data.reset_allocations(); @@ -268,39 +267,25 @@ void exx_ek_screening( - auto task_batch_begin = task_begin; - while(task_batch_begin != task_end) { - - size_t nleft = std::distance(task_batch_begin, task_end); - exx_detail::host_task_iterator task_batch_end; - if(nleft > task_batch_size) - task_batch_end = task_batch_begin + task_batch_size; - else - task_batch_end = task_end; - - device_data.zero_exx_ek_screening_intermediates(); - - // Loop over tasks and form basis-related buffers - auto task_it = task_batch_begin; - while( task_it != task_batch_end ) { - - // Determine next task patch, send relevant data (EXX_EK only) - task_it = device_data.generate_buffers( enabled_terms, basis_map, task_it, - task_batch_end ); + auto task_it = task_begin; + while (task_it != task_end) { - // Evaluate collocation - lwd->eval_collocation( &device_data ); + device_data.zero_exx_ek_screening_intermediates(); + auto task_batch_begin = task_it; - // Evaluate EXX EK Screening Basis 
Statistics - lwd->eval_exx_ek_screening_bfn_stats( &device_data ); + // Determine next task patch, send relevant data (EXX_EK only) + task_it = device_data.generate_buffers(enabled_terms, basis_map, task_it, + task_end); - } + // Evaluate collocation + lwd->eval_collocation(&device_data); + // Evaluate EXX EK Screening Basis Statistics + lwd->eval_exx_ek_screening_bfn_stats(&device_data); - lwd->exx_ek_shellpair_collision( eps_E, eps_K, &device_data, task_batch_begin, - task_batch_end, shpairs ); - task_batch_begin = task_batch_end; + lwd->exx_ek_shellpair_collision(eps_E, eps_K, &device_data, task_batch_begin, + task_it, shpairs); } //GAUXC_CUDA_ERROR("End Sync", cudaDeviceSynchronize()); diff --git a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu index 86799ad25..81cda6952 100644 --- a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu +++ b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu @@ -199,8 +199,8 @@ __global__ void exx_ek_shellpair_collision_shared_kernel( int LD_coll, uint32_t* rc_collisions, int LD_rc, - uint32_t* counts, - uint32_t* rc_counts + uint64_t* counts, + uint64_t* rc_counts ) { extern __shared__ uint32_t s_rc_collisions[]; @@ -253,13 +253,13 @@ __global__ void exx_ek_shellpair_collision_shared_kernel( // TODO use thread block level reduction before writing to global memory - uint32_t count = 0; + unsigned long long count = 0; for(int ij = threadIdx.x; ij < LD_coll; ij+=blockDim.x) count += __popc(collisions[i_task * LD_coll + ij]); - atomicAdd(&(counts[i_task]), count); + atomicAdd((unsigned long long *)&(counts[i_task]), count); count = 0; for(int ij = threadIdx.x; ij < LD_rc; ij+=blockDim.x) count += __popc(rc_collisions[i_task * LD_rc + ij]); - atomicAdd(&(rc_counts[i_task]), count); + atomicAdd((unsigned long long *)&(rc_counts[i_task]), 
count); __syncthreads(); } @@ -289,7 +289,7 @@ __global__ void print_coll(size_t ntasks, size_t nshells, uint32_t* collisions, } } -__global__ void print_counts(size_t ntasks, uint32_t* counts) { +__global__ void print_counts(size_t ntasks, uint64_t* counts) { for(auto i_task = 0 ; i_task < ntasks; ++i_task) { @@ -308,8 +308,8 @@ __global__ void bitvector_to_position_list_shellpair( size_t nsp, size_t LD_bit, const uint32_t* collisions, - const uint32_t* counts, - uint32_t* position_list + const uint64_t* counts, + uint64_t* position_list ) { constexpr auto warp_size = cuda::warp_size; @@ -370,9 +370,9 @@ __global__ void bitvector_to_position_list_shells( size_t nshells, size_t LD_bit, const uint32_t* collisions, - const uint32_t* counts, + const uint64_t* counts, const int32_t* shell_size, - uint32_t* position_list, + uint64_t* position_list, size_t* nbe_list ) { constexpr auto warp_size = cuda::warp_size; @@ -500,8 +500,8 @@ void exx_ek_shellpair_collision( using dur_t = std::chrono::duration; cudaStream_t stream = queue.queue_as(); - std::vector counts_host (ntasks); - std::vector rc_counts_host (ntasks); + std::vector counts_host (ntasks); + std::vector rc_counts_host (ntasks); const size_t nshell_pairs = shpairs.npairs(); const size_t LD_coll = util::div_ceil(nshell_pairs, 32); @@ -533,9 +533,9 @@ void exx_ek_shellpair_collision( buffer_adaptor full_stack(dyn_stack, dyn_size); auto collisions = full_stack.aligned_alloc(ntasks * LD_coll); - auto counts = full_stack.aligned_alloc(ntasks); + auto counts = full_stack.aligned_alloc(ntasks); auto rc_collisions = full_stack.aligned_alloc(ntasks * LD_rc); - auto rc_counts = full_stack.aligned_alloc(ntasks); + auto rc_counts = full_stack.aligned_alloc(ntasks); auto sp_check_st = hrt_t::now(); util::cuda_set_zero_async( ntasks * LD_coll,collisions.ptr, stream, "Zero Coll"); @@ -641,8 +641,8 @@ void exx_ek_shellpair_collision( auto scan_en = hrt_t::now(); dur_t scan_dur = scan_en - scan_st; - uint32_t total_sp_count = 
counts_host[ntasks-1]; - uint32_t total_s_count = rc_counts_host[ntasks-1]; + uint64_t total_sp_count = counts_host[ntasks-1]; + uint64_t total_s_count = rc_counts_host[ntasks-1]; //size_t global_sp_count = total_sp_count; //MPI_Allreduce(MPI_IN_PLACE, &global_sp_count, 1, MPI_UINT64_T, MPI_SUM, @@ -653,8 +653,8 @@ void exx_ek_shellpair_collision( auto bv_st = hrt_t::now(); - auto position_sp_list_device = full_stack.aligned_alloc(total_sp_count); - auto position_s_list_device = full_stack.aligned_alloc(total_s_count); + auto position_sp_list_device = full_stack.aligned_alloc(total_sp_count); + auto position_s_list_device = full_stack.aligned_alloc(total_s_count); auto nbe_list = full_stack.aligned_alloc(ntasks); { dim3 threads(32,32); @@ -668,7 +668,7 @@ void exx_ek_shellpair_collision( ); } - std::vector position_sp_list(total_sp_count); + std::vector position_sp_list(total_sp_count); util::cuda_copy(total_sp_count, position_sp_list.data(), position_sp_list_device.ptr, "Position List ShellPair"); auto bv_en = hrt_t::now(); @@ -676,7 +676,7 @@ void exx_ek_shellpair_collision( auto d2h_st = hrt_t::now(); - std::vector position_s_list(total_s_count); + std::vector position_s_list(total_s_count); std::vector nbe_list_host(ntasks); util::cuda_copy(total_s_count, position_s_list.data(), position_s_list_device.ptr, "Position List Shell"); util::cuda_copy(ntasks, nbe_list_host.data(), nbe_list.ptr, "NBE List"); diff --git a/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx b/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx index 7818a5a83..3d0c6656a 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx +++ b/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx @@ -58,7 +58,7 @@ size_t Scheme1DataBase::get_static_mem_requirement() { nsp * sizeof(int32_t) + // nprim_pairs nsp * sizeof(shell_pair*) + // shell_pair pointer nsp * 3 * sizeof(double) + // X_AB, Y_AB, Z_AB - 1024 * 1024; // additional memory for 
alignment padding + 4 * 1024 * 1024; // additional memory for alignment padding return size; } diff --git a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx index 2e043842f..c7c76d245 100644 --- a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx +++ b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx @@ -52,6 +52,7 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, const size_t nbe_cou = task.cou_screening.nbe; const size_t ncut_cou = submat_cut_cou.size(); const size_t nblock_cou = submat_block_cou.size(); + const size_t nshells = global_dims.nshells; return base_size + // Collocation + Derivatives @@ -88,6 +89,9 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, // Map from packed to unpacked indices reqt.task_bfn_shell_indirection_size( nbe_bfn ) * sizeof(int32_t) + + // Scratch memory to store shell pairs + reqt.task_exx_collision_size( nshells ) * sizeof(int64_t) + + // Memory associated with task indirection: valid for both AoS and SoA reqt.task_indirection_size() * sizeof(XCDeviceTask); } diff --git a/src/xc_integrator/xc_data/device/xc_device_data.hpp b/src/xc_integrator/xc_data/device/xc_device_data.hpp index 781e23729..a2579bf44 100644 --- a/src/xc_integrator/xc_data/device/xc_device_data.hpp +++ b/src/xc_integrator/xc_data/device/xc_device_data.hpp @@ -376,6 +376,7 @@ struct required_term_storage { bool task_gmat = false; bool task_nbe_scr = false; bool task_bfn_shell_indirection = false; + bool task_exx_collision = false; inline size_t task_bfn_size(size_t nbe, size_t npts) { @@ -506,6 +507,12 @@ struct required_term_storage { const size_t num_subtasks = util::div_ceil(npts, subtask_size); return PRDVL(task_to_shell_pair_cou, num_subtasks); } + inline size_t task_exx_collision_size(size_t nshells) { + const size_t nslt = (nshells * (nshells+1)) / 2 + + nshells + ; + return PRDVL(task_exx_collision, nslt); + } @@ -638,6 +645,7 @@ struct 
required_term_storage { task_shell_offs_bfn = true; task_bfn_shell_indirection = true; shell_to_task_bfn = true; + task_exx_collision = true; } } From 68579bf38b29c3c51c56cbc44e060518e0a67309 Mon Sep 17 00:00:00 2001 From: Vladimir Mironov Date: Fri, 17 Apr 2026 15:22:44 +0000 Subject: [PATCH 2/3] Fix collision memory estimates Compute collision scratch requirements as max(mem_coll + mem_bfn_stats, mem_collision) per task. Replace task_exx_collision_size with task_exx_coll_bitvec_size, task_exx_coll_fmax_size, and task_exx_coll_position_size helpers. --- .../xc_data/device/xc_device_aos_data.cxx | 24 +++++++++++++++---- .../xc_data/device/xc_device_data.hpp | 20 +++++++++++----- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx index c7c76d245..5aa7af758 100644 --- a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx +++ b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx @@ -52,9 +52,8 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, const size_t nbe_cou = task.cou_screening.nbe; const size_t ncut_cou = submat_cut_cou.size(); const size_t nblock_cou = submat_block_cou.size(); - const size_t nshells = global_dims.nshells; - return base_size + + size_t mem_req = base_size + // Collocation + Derivatives reqt.task_bfn_size ( nbe_bfn, npts ) * sizeof(double) + reqt.task_bfn_grad_size( nbe_bfn, npts ) * sizeof(double) + @@ -88,12 +87,27 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, // Map from packed to unpacked indices reqt.task_bfn_shell_indirection_size( nbe_bfn ) * sizeof(int32_t) + - - // Scratch memory to store shell pairs - reqt.task_exx_collision_size( nshells ) * sizeof(int64_t) + // Memory associated with task indirection: valid for both AoS and SoA reqt.task_indirection_size() * sizeof(XCDeviceTask); + + // Collision reuses dynamic memory after bfn_stats completes, so the + // 
per-task requirement is max(mem_coll + mem_bfn_stats, mem_collision). + if (reqt.task_exx_collision) { + const size_t nshells = global_dims.nshells; + const size_t nshell_pairs = global_dims.nshell_pairs; + const size_t nbf = global_dims.nbf; + size_t coll_scratch = + reqt.task_exx_coll_bitvec_size(nshell_pairs, nshells) * sizeof(uint32_t) + + 2 * sizeof(uint64_t) + // counts + rc_counts + std::max( + reqt.task_exx_coll_fmax_size(nshells, nbf) * sizeof(double), + reqt.task_exx_coll_position_size(nshell_pairs, nshells) * sizeof(uint32_t) + + sizeof(size_t) + ); + return std::max(mem_req, coll_scratch); + } + return mem_req; } diff --git a/src/xc_integrator/xc_data/device/xc_device_data.hpp b/src/xc_integrator/xc_data/device/xc_device_data.hpp index a2579bf44..38bcf6cde 100644 --- a/src/xc_integrator/xc_data/device/xc_device_data.hpp +++ b/src/xc_integrator/xc_data/device/xc_device_data.hpp @@ -507,13 +507,21 @@ struct required_term_storage { const size_t num_subtasks = util::div_ceil(npts, subtask_size); return PRDVL(task_to_shell_pair_cou, num_subtasks); } - inline size_t task_exx_collision_size(size_t nshells) { - const size_t nslt = (nshells * (nshells+1)) / 2 - + nshells - ; - return PRDVL(task_exx_collision, nslt); - } + inline size_t task_exx_coll_bitvec_size(size_t nshell_pairs, size_t nshells) { + // collision + rc_collision bitvectors + const size_t LD_coll = (nshell_pairs + 31) / 32; + const size_t LD_rc = (nshells + 31) / 32; + return PRDVL(task_exx_collision, LD_coll + LD_rc); + } + inline size_t task_exx_coll_fmax_size(size_t nshells, size_t nbf) { + // fmax scratch (fmax_bfn + fmax_shell) + return PRDVL(task_exx_collision, nbf + nshells); + } + inline size_t task_exx_coll_position_size(size_t nshell_pairs, size_t nshells) { + // position lists (shell-pairs + shells) + return PRDVL(task_exx_collision, nshell_pairs + nshells); + } inline explicit required_term_storage(integrator_term_tracker tracker) { From a0220259c7e436b0edf1417e86e9f462c9d460ea Mon 
Sep 17 00:00:00 2001 From: Vladimir Mironov Date: Fri, 24 Apr 2026 10:49:19 +0000 Subject: [PATCH 3/3] Undo unneeded uint32_t -> uint64_t promotions Only use uint64_t for prefix sum counters --- .../cuda/kernels/exx_ek_screening_bfn_stats.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu index 81cda6952..2172fc066 100644 --- a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu +++ b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu @@ -309,7 +309,7 @@ __global__ void bitvector_to_position_list_shellpair( size_t LD_bit, const uint32_t* collisions, const uint64_t* counts, - uint64_t* position_list + uint32_t* position_list ) { constexpr auto warp_size = cuda::warp_size; @@ -372,7 +372,7 @@ __global__ void bitvector_to_position_list_shells( const uint32_t* collisions, const uint64_t* counts, const int32_t* shell_size, - uint64_t* position_list, + uint32_t* position_list, size_t* nbe_list ) { constexpr auto warp_size = cuda::warp_size; @@ -653,8 +653,8 @@ void exx_ek_shellpair_collision( auto bv_st = hrt_t::now(); - auto position_sp_list_device = full_stack.aligned_alloc(total_sp_count); - auto position_s_list_device = full_stack.aligned_alloc(total_s_count); + auto position_sp_list_device = full_stack.aligned_alloc(total_sp_count); + auto position_s_list_device = full_stack.aligned_alloc(total_s_count); auto nbe_list = full_stack.aligned_alloc(ntasks); { dim3 threads(32,32); @@ -668,7 +668,7 @@ void exx_ek_shellpair_collision( ); } - std::vector position_sp_list(total_sp_count); + std::vector position_sp_list(total_sp_count); util::cuda_copy(total_sp_count, position_sp_list.data(), position_sp_list_device.ptr, "Position List ShellPair"); auto bv_en = hrt_t::now(); @@ -676,7 +676,7 @@ void 
exx_ek_shellpair_collision( auto d2h_st = hrt_t::now(); - std::vector position_s_list(total_s_count); + std::vector position_s_list(total_s_count); std::vector nbe_list_host(ntasks); util::cuda_copy(total_s_count, position_s_list.data(), position_s_list_device.ptr, "Position List Shell"); util::cuda_copy(ntasks, nbe_list_host.data(), nbe_list.ptr, "NBE List");