diff --git a/src/xc_integrator/integrator_util/exx_screening.cxx b/src/xc_integrator/integrator_util/exx_screening.cxx index 5c7efcd13..6bed9ec84 100644 --- a/src/xc_integrator/integrator_util/exx_screening.cxx +++ b/src/xc_integrator/integrator_util/exx_screening.cxx @@ -253,7 +253,6 @@ void exx_ek_screening( const auto nshells = basis.nshells(); const size_t ntasks = std::distance(task_begin, task_end); - const size_t task_batch_size = 10000; // Setup EXX EK Screening memory on the device device_data.reset_allocations(); @@ -268,39 +267,25 @@ void exx_ek_screening( - auto task_batch_begin = task_begin; - while(task_batch_begin != task_end) { - - size_t nleft = std::distance(task_batch_begin, task_end); - exx_detail::host_task_iterator task_batch_end; - if(nleft > task_batch_size) - task_batch_end = task_batch_begin + task_batch_size; - else - task_batch_end = task_end; - - device_data.zero_exx_ek_screening_intermediates(); - - // Loop over tasks and form basis-related buffers - auto task_it = task_batch_begin; - while( task_it != task_batch_end ) { - - // Determine next task patch, send relevant data (EXX_EK only) - task_it = device_data.generate_buffers( enabled_terms, basis_map, task_it, - task_batch_end ); + auto task_it = task_begin; + while (task_it != task_end) { - // Evaluate collocation - lwd->eval_collocation( &device_data ); + device_data.zero_exx_ek_screening_intermediates(); + auto task_batch_begin = task_it; - // Evaluate EXX EK Screening Basis Statistics - lwd->eval_exx_ek_screening_bfn_stats( &device_data ); + // Determine next task patch, send relevant data (EXX_EK only) + task_it = device_data.generate_buffers(enabled_terms, basis_map, task_it, + task_end); - } + // Evaluate collocation + lwd->eval_collocation(&device_data); + // Evaluate EXX EK Screening Basis Statistics + lwd->eval_exx_ek_screening_bfn_stats(&device_data); - lwd->exx_ek_shellpair_collision( eps_E, eps_K, &device_data, task_batch_begin, - task_batch_end, shpairs ); - task_batch_begin = task_batch_end; + lwd->exx_ek_shellpair_collision(eps_E, eps_K, &device_data, task_batch_begin, + task_it, shpairs); } //GAUXC_CUDA_ERROR("End Sync", cudaDeviceSynchronize()); diff --git a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu index 86799ad25..2172fc066 100644 --- a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu +++ b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu @@ -199,8 +199,8 @@ __global__ void exx_ek_shellpair_collision_shared_kernel( int LD_coll, uint32_t* rc_collisions, int LD_rc, - uint32_t* counts, - uint32_t* rc_counts + uint64_t* counts, + uint64_t* rc_counts ) { extern __shared__ uint32_t s_rc_collisions[]; @@ -253,13 +253,13 @@ __global__ void exx_ek_shellpair_collision_shared_kernel( // TODO use thread block level reduction before writing to global memory - uint32_t count = 0; + unsigned long long count = 0; for(int ij = threadIdx.x; ij < LD_coll; ij+=blockDim.x) count += __popc(collisions[i_task * LD_coll + ij]); - atomicAdd(&(counts[i_task]), count); + atomicAdd((unsigned long long *)&(counts[i_task]), count); count = 0; for(int ij = threadIdx.x; ij < LD_rc; ij+=blockDim.x) count += __popc(rc_collisions[i_task * LD_rc + ij]); - atomicAdd(&(rc_counts[i_task]), count); + atomicAdd((unsigned long long *)&(rc_counts[i_task]), count); __syncthreads(); } @@ -289,7 +289,7 @@ __global__ void print_coll(size_t ntasks, size_t nshells, uint32_t* collisions, } } -__global__ void print_counts(size_t ntasks, uint32_t* counts) { +__global__ void print_counts(size_t ntasks, uint64_t* counts) { for(auto i_task = 0 ; i_task < ntasks; ++i_task) { @@ -308,7 +308,7 @@ __global__ void bitvector_to_position_list_shellpair( size_t nsp, size_t LD_bit, const uint32_t* collisions, - const uint32_t* counts, + const uint64_t* counts, uint32_t* position_list ) { @@ -370,9 +370,9 @@ __global__ void bitvector_to_position_list_shells( size_t nshells, size_t LD_bit, const uint32_t* collisions, - const uint32_t* counts, + const uint64_t* counts, const int32_t* shell_size, - uint32_t* position_list, + uint32_t* position_list, size_t* nbe_list ) { constexpr auto warp_size = cuda::warp_size; @@ -500,8 +500,8 @@ void exx_ek_shellpair_collision( using dur_t = std::chrono::duration; cudaStream_t stream = queue.queue_as(); - std::vector counts_host (ntasks); - std::vector rc_counts_host (ntasks); + std::vector counts_host (ntasks); + std::vector rc_counts_host (ntasks); const size_t nshell_pairs = shpairs.npairs(); const size_t LD_coll = util::div_ceil(nshell_pairs, 32); @@ -533,9 +533,9 @@ void exx_ek_shellpair_collision( buffer_adaptor full_stack(dyn_stack, dyn_size); auto collisions = full_stack.aligned_alloc(ntasks * LD_coll); - auto counts = full_stack.aligned_alloc(ntasks); + auto counts = full_stack.aligned_alloc(ntasks); auto rc_collisions = full_stack.aligned_alloc(ntasks * LD_rc); - auto rc_counts = full_stack.aligned_alloc(ntasks); + auto rc_counts = full_stack.aligned_alloc(ntasks); auto sp_check_st = hrt_t::now(); util::cuda_set_zero_async( ntasks * LD_coll,collisions.ptr, stream, "Zero Coll"); @@ -641,8 +641,8 @@ void exx_ek_shellpair_collision( auto scan_en = hrt_t::now(); dur_t scan_dur = scan_en - scan_st; - uint32_t total_sp_count = counts_host[ntasks-1]; - uint32_t total_s_count = rc_counts_host[ntasks-1]; + uint64_t total_sp_count = counts_host[ntasks-1]; + uint64_t total_s_count = rc_counts_host[ntasks-1]; //size_t global_sp_count = total_sp_count; //MPI_Allreduce(MPI_IN_PLACE, &global_sp_count, 1, MPI_UINT64_T, MPI_SUM, diff --git a/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx b/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx index 7818a5a83..3d0c6656a 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx +++ b/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx @@ -58,7 +58,7 @@ size_t Scheme1DataBase::get_static_mem_requirement() { nsp * sizeof(int32_t) + // nprim_pairs nsp * sizeof(shell_pair*) + // shell_pair pointer nsp * 3 * sizeof(double) + // X_AB, Y_AB, Z_AB - 1024 * 1024; // additional memory for alignment padding + 4 * 1024 * 1024; // additional memory for alignment padding return size; } diff --git a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx index 2e043842f..5aa7af758 100644 --- a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx +++ b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx @@ -53,7 +53,7 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, const size_t ncut_cou = submat_cut_cou.size(); const size_t nblock_cou = submat_block_cou.size(); - return base_size + + size_t mem_req = base_size + // Collocation + Derivatives reqt.task_bfn_size ( nbe_bfn, npts ) * sizeof(double) + reqt.task_bfn_grad_size( nbe_bfn, npts ) * sizeof(double) + @@ -87,9 +87,27 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, // Map from packed to unpacked indices reqt.task_bfn_shell_indirection_size( nbe_bfn ) * sizeof(int32_t) + - + // Memory associated with task indirection: valid for both AoS and SoA reqt.task_indirection_size() * sizeof(XCDeviceTask); + + // Collision reuses dynamic memory after bfn_stats completes, so the + // per-task requirement is max(mem_coll + mem_bfn_stats, mem_collision). + if (reqt.task_exx_collision) { + const size_t nshells = global_dims.nshells; + const size_t nshell_pairs = global_dims.nshell_pairs; + const size_t nbf = global_dims.nbf; + size_t coll_scratch = + reqt.task_exx_coll_bitvec_size(nshell_pairs, nshells) * sizeof(uint32_t) + + 2 * sizeof(uint64_t) + // counts + rc_counts + std::max( + reqt.task_exx_coll_fmax_size(nshells, nbf) * sizeof(double), + reqt.task_exx_coll_position_size(nshell_pairs, nshells) * sizeof(uint32_t) + + sizeof(size_t) + ); + return std::max(mem_req, coll_scratch); + } + return mem_req; } diff --git a/src/xc_integrator/xc_data/device/xc_device_data.hpp b/src/xc_integrator/xc_data/device/xc_device_data.hpp index 781e23729..38bcf6cde 100644 --- a/src/xc_integrator/xc_data/device/xc_device_data.hpp +++ b/src/xc_integrator/xc_data/device/xc_device_data.hpp @@ -376,6 +376,7 @@ struct required_term_storage { bool task_gmat = false; bool task_nbe_scr = false; bool task_bfn_shell_indirection = false; + bool task_exx_collision = false; inline size_t task_bfn_size(size_t nbe, size_t npts) { @@ -507,6 +508,20 @@ struct required_term_storage { return PRDVL(task_to_shell_pair_cou, num_subtasks); } + inline size_t task_exx_coll_bitvec_size(size_t nshell_pairs, size_t nshells) { + // collision + rc_collision bitvectors + const size_t LD_coll = (nshell_pairs + 31) / 32; + const size_t LD_rc = (nshells + 31) / 32; + return PRDVL(task_exx_collision, LD_coll + LD_rc); + } + inline size_t task_exx_coll_fmax_size(size_t nshells, size_t nbf) { + // fmax scratch (fmax_bfn + fmax_shell) + return PRDVL(task_exx_collision, nbf + nshells); + } + inline size_t task_exx_coll_position_size(size_t nshell_pairs, size_t nshells) { + // position lists (shell-pairs + shells) + return PRDVL(task_exx_collision, nshell_pairs + nshells); + } inline explicit required_term_storage(integrator_term_tracker tracker) { @@ -638,6 +653,7 @@ struct required_term_storage { task_shell_offs_bfn = true; task_bfn_shell_indirection = true; shell_to_task_bfn = true; + task_exx_collision = true; } }