diff --git a/cmake/gauxc-onedft.cmake b/cmake/gauxc-onedft.cmake index 6e7998490..269afe812 100644 --- a/cmake/gauxc-onedft.cmake +++ b/cmake/gauxc-onedft.cmake @@ -1,33 +1,43 @@ -find_package(nlohmann_json) -if( NOT nlohmann_json_FOUND ) +if(NOT TARGET nlohmann_json::nlohmann_json) + find_package(nlohmann_json) + if( NOT nlohmann_json_FOUND ) - message( STATUS "Could Not Find nlohmann_json... Building" ) - message( STATUS "NLOHMANN_JSON URL = ${GAUXC_NLOHMANN_JSON_URL}" ) + message( STATUS "Could Not Find nlohmann_json... Building" ) + message( STATUS "NLOHMANN_JSON URL = ${GAUXC_NLOHMANN_JSON_URL}" ) - FetchContent_Declare( - nlohmann_json - URL ${GAUXC_NLOHMANN_JSON_URL} - URL_HASH SHA256=${GAUXC_NLOHMANN_JSON_SHA256} - DOWNLOAD_EXTRACT_TIMESTAMP ON - ) + FetchContent_Declare( + nlohmann_json + URL ${GAUXC_NLOHMANN_JSON_URL} + URL_HASH SHA256=${GAUXC_NLOHMANN_JSON_SHA256} + DOWNLOAD_EXTRACT_TIMESTAMP ON + ) - FetchContent_GetProperties( nlohmann_json ) - if( NOT nlohmann_json_POPULATED ) - FetchContent_Populate( nlohmann_json ) - endif() + FetchContent_GetProperties( nlohmann_json ) + if( NOT nlohmann_json_POPULATED ) + FetchContent_Populate( nlohmann_json ) + endif() - add_library( nlohmann_json::nlohmann_json INTERFACE IMPORTED ) - set_target_properties( nlohmann_json::nlohmann_json PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES ${nlohmann_json_SOURCE_DIR}/include - ) + add_library( nlohmann_json::nlohmann_json INTERFACE IMPORTED ) + set_target_properties( nlohmann_json::nlohmann_json PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES ${nlohmann_json_SOURCE_DIR}/include + ) + endif() endif() -# store and restore CMAKE_CUDA_ARCHITECTURES if Torch clobbers it +# store and restore CMAKE_CUDA_ARCHITECTURES and CMAKE_CUDA_FLAGS if Torch clobbers them set(_PREV_CUDA_ARCHS "${CMAKE_CUDA_ARCHITECTURES}") +set(_PREV_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") find_package(Torch REQUIRED) -if(CMAKE_CUDA_ARCHITECTURES STREQUAL "OFF") - set(CMAKE_CUDA_ARCHITECTURES "${_PREV_CUDA_ARCHS}" CACHE STRING "Restore CUDA archs after Torch override" FORCE) - message(WARNING "Torch set CMAKE_CUDA_ARCHITECTURES to OFF. Restored previous value: ${CMAKE_CUDA_ARCHITECTURES}") +# Restore CMAKE_CUDA_ARCHITECTURES (Torch may set it to OFF) +if(NOT "${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "${_PREV_CUDA_ARCHS}") + set(CMAKE_CUDA_ARCHITECTURES "${_PREV_CUDA_ARCHS}" CACHE STRING "" FORCE) + message(WARNING "Torch changed CMAKE_CUDA_ARCHITECTURES. Restored previous value: ${CMAKE_CUDA_ARCHITECTURES}")
+endif() +# Strip Torch-injected -gencode flags from CMAKE_CUDA_FLAGS (PyTorch issue #71379) +string(REGEX REPLACE " -gencode [^ ]+" "" _cleaned_cuda_flags "${CMAKE_CUDA_FLAGS}") +if(NOT "${_cleaned_cuda_flags}" STREQUAL "${CMAKE_CUDA_FLAGS}") + set(CMAKE_CUDA_FLAGS "${_cleaned_cuda_flags}" CACHE STRING "" FORCE) + message(WARNING "Stripped Torch-injected -gencode flags from CMAKE_CUDA_FLAGS") endif() list(REMOVE_ITEM TORCH_LIBRARIES torch::nvtoolsext) message(STATUS "Torch libraries without nvtoolsext: ${TORCH_LIBRARIES}") diff --git a/include/gauxc/gauxc_config.hpp.in b/include/gauxc/gauxc_config.hpp.in index d05a360e8..6bfba814c 100644 --- a/include/gauxc/gauxc_config.hpp.in +++ b/include/gauxc/gauxc_config.hpp.in @@ -21,6 +21,7 @@ #cmakedefine GAUXC_HAS_CUTLASS #cmakedefine GAUXC_HAS_GAU2GRID #cmakedefine GAUXC_HAS_HDF5 +#cmakedefine GAUXC_HAS_ONEDFT #cmakedefine GAUXC_USE_FAST_RSQRT #ifdef GAUXC_HAS_ONEDFT diff --git a/include/gauxc/xc_integrator.hpp b/include/gauxc/xc_integrator.hpp index 798ffb515..8b3ebc664 100644 --- a/include/gauxc/xc_integrator.hpp +++ b/include/gauxc/xc_integrator.hpp @@ -78,6 +78,8 @@ class XCIntegrator { exc_grad_type eval_exc_grad( const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} ); exc_grad_type eval_exc_grad( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} ); + exc_grad_type eval_exc_grad_onedft( const MatrixType&, const MatrixType&, + const IntegratorSettingsXC& = IntegratorSettingsXC{} ); exx_type eval_exx ( const MatrixType&, const IntegratorSettingsEXX& = IntegratorSettingsEXX{} ); diff --git a/include/gauxc/xc_integrator/impl.hpp b/include/gauxc/xc_integrator/impl.hpp index 02ceeac6b..a7f437c0e 100644 --- a/include/gauxc/xc_integrator/impl.hpp +++ b/include/gauxc/xc_integrator/impl.hpp @@ -99,6 +99,13 @@ typename XCIntegrator<MatrixType>::exc_grad_type return pimpl_->eval_exc_grad(Ps, Pz, ks_settings); }; +template <typename MatrixType> +typename XCIntegrator<MatrixType>::exc_grad_type + XCIntegrator<MatrixType>::eval_exc_grad_onedft( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) { + if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED(); + return pimpl_->eval_exc_grad_onedft(Ps, Pz, ks_settings); +}; + template <typename MatrixType> typename XCIntegrator<MatrixType>::exx_type XCIntegrator<MatrixType>::eval_exx( const MatrixType& P, diff --git a/include/gauxc/xc_integrator/replicated/impl.hpp b/include/gauxc/xc_integrator/replicated/impl.hpp index 09caefc0e..3303c15b0 100644 --- a/include/gauxc/xc_integrator/replicated/impl.hpp +++ b/include/gauxc/xc_integrator/replicated/impl.hpp @@ -205,6 +205,20 @@ typename ReplicatedXCIntegrator<MatrixType>::exc_grad_type } +template <typename MatrixType> +typename ReplicatedXCIntegrator<MatrixType>::exc_grad_type + ReplicatedXCIntegrator<MatrixType>::eval_exc_grad_onedft_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) { + + if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED(); + + std::vector<value_type> EXC_GRAD( 3*pimpl_->load_balancer().molecule().natoms() ); + pimpl_->eval_exc_grad_onedft( Ps.rows(), Ps.cols(), Ps.data(), Ps.rows(), Pz.data(), Pz.rows(), + EXC_GRAD.data(), ks_settings ); + + return EXC_GRAD; + +} + template <typename MatrixType> typename ReplicatedXCIntegrator<MatrixType>::exx_type ReplicatedXCIntegrator<MatrixType>::eval_exx_( const MatrixType& P, const IntegratorSettingsEXX& settings ) { diff --git a/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp b/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp index fa0e3763c..2ca3ec12b 100644 ---
a/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp +++ b/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp @@ -88,6 +88,8 @@ class ReplicatedXCIntegratorImpl { value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) = 0; virtual void eval_exc_grad_( int64_t m, int64_t n, const value_type* P, int64_t ldps, const value_type* Pz, int64_t lpdz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) = 0; + virtual void eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) = 0; virtual void eval_exx_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ) = 0; @@ -169,6 +171,8 @@ class ReplicatedXCIntegratorImpl { value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ); void eval_exc_grad( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ); + void eval_exc_grad_onedft( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ); void eval_exx( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* K, int64_t ldk, diff --git a/include/gauxc/xc_integrator/replicated_xc_integrator.hpp b/include/gauxc/xc_integrator/replicated_xc_integrator.hpp index a702b3b34..0925e4ee6 100644 --- a/include/gauxc/xc_integrator/replicated_xc_integrator.hpp +++ b/include/gauxc/xc_integrator/replicated_xc_integrator.hpp @@ -57,6 +57,7 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl { exc_vxc_type_uks eval_exc_vxc_onedft_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override; exc_grad_type eval_exc_grad_( const MatrixType&, const IntegratorSettingsXC& ) override; exc_grad_type eval_exc_grad_( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override; + exc_grad_type eval_exc_grad_onedft_( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override; exx_type eval_exx_ ( const MatrixType&, const IntegratorSettingsEXX& ) override; fxc_contraction_type_rks eval_fxc_contraction_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override; fxc_contraction_type_uks eval_fxc_contraction_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC&) override; diff --git a/include/gauxc/xc_integrator/xc_integrator_impl.hpp b/include/gauxc/xc_integrator/xc_integrator_impl.hpp index 300117fee..a9143f04b 100644 --- a/include/gauxc/xc_integrator/xc_integrator_impl.hpp +++ b/include/gauxc/xc_integrator/xc_integrator_impl.hpp @@ -49,6 +49,7 @@ class XCIntegratorImpl { virtual exc_vxc_type_uks eval_exc_vxc_onedft_ ( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0; virtual exc_grad_type eval_exc_grad_( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) = 0; virtual exc_grad_type eval_exc_grad_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0; + virtual exc_grad_type eval_exc_grad_onedft_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0; virtual exx_type eval_exx_ ( const MatrixType& P, const IntegratorSettingsEXX& settings ) = 0; virtual fxc_contraction_type_rks eval_fxc_contraction_ ( 
const MatrixType& P, @@ -152,6 +153,16 @@ class XCIntegratorImpl { return eval_exc_grad_(Ps, Pz, ks_settings); } + /** Integrate EXC gradient for OneDFT models + * + * @param[in] Ps The total density matrix + * @param[in] Pz The magnetization density matrix + * @returns EXC gradient + */ + exc_grad_type eval_exc_grad_onedft( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) { + return eval_exc_grad_onedft_(Ps, Pz, ks_settings); + } + /** Integrate Exact Exchange for RHF * * @param[in] P The alpha density matrix diff --git a/include/gauxc/xc_task.hpp b/include/gauxc/xc_task.hpp index 0b574ae71..4e0f5cb57 100644 --- a/include/gauxc/xc_task.hpp +++ b/include/gauxc/xc_task.hpp @@ -70,6 +70,8 @@ struct XCTask { std::vector vdden_z_eval_a; std::vector vdden_z_eval_b; std::vector vtau; + // energy density per grid point (for gradient weight derivative) + std::vector eps; }; features feat; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c789ff9f6..c55423896 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -119,7 +119,7 @@ if( GAUXC_HAS_MPI ) endif() if ( GAUXC_HAS_ONEDFT ) - target_link_libraries( gauxc PUBLIC "${TORCH_LIBRARIES}") + target_link_libraries( gauxc PUBLIC "${TORCH_LIBRARIES}" nlohmann_json::nlohmann_json) endif() add_subdirectory( runtime_environment ) diff --git a/src/molecular_weights/device/device_molecular_weights.cxx b/src/molecular_weights/device/device_molecular_weights.cxx index c5bcce5cd..91ba89682 100644 --- a/src/molecular_weights/device/device_molecular_weights.cxx +++ b/src/molecular_weights/device/device_molecular_weights.cxx @@ -43,6 +43,10 @@ void DeviceMolecularWeights::modify_weights( LoadBalancer& lb ) const { }; std::stable_sort(task_begin, task_end, task_comparator ); + // Save raw quadrature weights before partition modifies them in-place. + // These are needed by OneDFT models that use atomic_grid_weights. + for( auto it = task_begin; it != task_end; ++it ) it->raw_weights = it->weights; + const auto& mol = lb.molecule(); const auto natoms = mol.natoms(); const auto& meta = lb.molmeta(); diff --git a/src/molecular_weights/host/host_molecular_weights.cxx b/src/molecular_weights/host/host_molecular_weights.cxx index e722d22b8..ea370354d 100644 --- a/src/molecular_weights/host/host_molecular_weights.cxx +++ b/src/molecular_weights/host/host_molecular_weights.cxx @@ -31,6 +31,10 @@ void HostMolecularWeights::modify_weights( LoadBalancer& lb ) const { }; std::stable_sort( tasks.begin(), tasks.end(), task_comparator ); + // Save raw quadrature weights before partition modifies them in-place. + // These are needed by OneDFT models that use atomic_grid_weights. 
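The two molecular-weights hunks above snapshot the unpartitioned quadrature weights immediately before the partitioning pass overwrites them in place. A minimal host-side sketch of the intent, using a hypothetical pared-down task type (the real XCTask carries points, screening data, and the new feat block as well):

#include <vector>

// Hypothetical stand-in for GauXC's XCTask; only the two weight buffers matter here.
struct TaskSketch {
  std::vector<double> weights;      // overwritten in place by weight partitioning
  std::vector<double> raw_weights;  // snapshot of the per-atom quadrature weights
};

// Once modify_weights() scales weights[i] by the partition function, the raw
// quadrature weights are unrecoverable -- hence the cheap O(npts) copy beforehand.
void snapshot_raw_weights(std::vector<TaskSketch>& tasks) {
  for (auto& t : tasks) t.raw_weights = t.weights;
}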
+ for( auto& task : tasks ) task.raw_weights = task.weights; + // Modify the weights const auto& mol = lb.molecule(); const auto& meta = lb.molmeta(); diff --git a/src/xc_integrator/integrator_util/onedft_util.cxx b/src/xc_integrator/integrator_util/onedft_util.cxx index f77655874..77563cc31 100644 --- a/src/xc_integrator/integrator_util/onedft_util.cxx +++ b/src/xc_integrator/integrator_util/onedft_util.cxx @@ -4,6 +4,7 @@ #include #endif #include +#include #include #include namespace GauXC { @@ -50,6 +51,8 @@ std::string map_model(const std::string& model, torch::DeviceType device) { return model_path + "/tpss.fun"; } else if (model == "LDA") { return model_path + "/lda.fun"; + } else if (model == "SKALA") { + GAUXC_GENERIC_EXCEPTION("To use the Skala functional, specify a local checkpoint path."); } else { GAUXC_GENERIC_EXCEPTION("Model " + model + " not found in " + model_path); } @@ -57,13 +60,24 @@ std::tuple<torch::jit::Method, std::vector<std::string>> load_model(const std::string filename, torch::DeviceType device) -{ +{ + // Cache loaded models to avoid re-reading from disk on every call. + // Key: (resolved_path, device_type) + using CacheKey = std::pair<std::string, torch::DeviceType>; + using CacheVal = std::tuple<torch::jit::script::Module, torch::jit::Method, std::vector<std::string>>; + static std::map<CacheKey, CacheVal> cache; + + std::string model = map_model(filename, device); + CacheKey key{model, device}; + auto it = cache.find(key); + if (it != cache.end()) { + return std::make_tuple(std::get<1>(it->second), std::get<2>(it->second)); + } + torch::jit::script::Module mod; torch::jit::ExtraFilesMap extra_files{{"features", ""}, {"protocol_version", ""}}; std::vector<std::string> keys; - std::string model = map_model(filename, device); try { - // Deserialize the ScriptModule from a file using torch::jit::load(). mod = torch::jit::load(model, device, extra_files); } catch (const c10::Error& e) { @@ -76,7 +90,6 @@ load_model(const std::string filename, torch::DeviceType device) } auto features = json::parse(extra_files.at("features")); - // check if features is array if (!features.is_array()) { GAUXC_GENERIC_EXCEPTION("features is not an array"); } @@ -87,7 +100,9 @@ load_model(const std::string filename, torch::DeviceType device) keys.push_back(feature.get<std::string>()); } - return std::make_tuple(mod.get_method("get_exc_density"), keys); + auto method = mod.get_method("get_exc_density"); + cache.emplace(key, CacheVal{std::move(mod), method, keys}); + return std::make_tuple(method, keys); } at::Tensor @@ -101,12 +116,13 @@ get_exc(torch::jit::Method exc_func, FeatureDict features) { int mpi_scatter_onedft_outputs(const FeatureDict features_dict, // only exist in rank 0 const int world_rank, const int world_size, std::vector<int> recvcounts, std::vector<int> displs, + const std::vector<int64_t>& atom_reorder_inv_perm, std::vector<double>& den_eval, std::vector<double>& dden_eval, std::vector<double>& tau) { // store data std::vector<double> recv_den_eval, recv_dden_eval, recv_tau; - int total_npts; - bool is_gga, is_mgga; + int total_npts = 0; + bool is_gga = false, is_mgga = false; if (world_rank == 0) { total_npts = features_dict.at(feat_map.at(ONEDFT_FEATURE::DEN)).size(1); is_gga = (features_dict.find(feat_map.at(ONEDFT_FEATURE::DDEN)) != features_dict.end()); @@ -157,6 +173,13 @@ int mpi_scatter_onedft_outputs(const FeatureDict features_dict, // only exist in std::memcpy(recv_tau_b, tau_grad_tensor.data_ptr<double>() + total_npts, total_npts * sizeof(double)); } } + + // Apply inverse atom-reorder: convert atom-ordered gradients back to rank-ordered + // so each rank receives the correct values after Scatterv.
+ if (world_rank == 0 && !atom_reorder_inv_perm.empty()) { + reorder_to_rank_order(recv_den_eval, recv_dden_eval, recv_tau, + atom_reorder_inv_perm, total_npts, is_gga, is_mgga); + } if (world_size == 1) { // If only one rank, no need to scatter @@ -231,7 +254,7 @@ int mpi_gather_onedft_inputs_gpu(std::vector<double>& den_eval, std::vector<double>& den_eval, std::vector<double>& den_eval, std::vector<double>& dden_eval, @@ -388,7 +412,261 @@ int mpi_gather_onedft_inputs(std::vector<double>& den_eval, std::vector<double>& } return total_npts_sum; #endif + return 0; +} + +AtomReorderResult mpi_gather_and_reorder( + std::vector<double>& den_eval, + std::vector<double>& dden_eval, + std::vector<double>& tau, + std::vector<double>& grid_coords, + std::vector<double>& grid_weights, + const std::vector<int64_t>& local_atomic_grid_sizes, + int total_npts, int natoms, + const RuntimeEnvironment& rt, + std::vector<int>& sendcounts, + std::vector<int>& displs) { + + AtomReorderResult result; + result.global_atomic_grid_sizes = local_atomic_grid_sizes; + int world_rank = rt.comm_rank(); + + GAUXC_MPI_CODE( + total_npts = mpi_gather_onedft_inputs(den_eval, dden_eval, tau, grid_coords, + grid_weights, total_npts, world_rank, rt.comm_size(), sendcounts, displs); + ); + + GAUXC_MPI_CODE( + if (rt.comm_size() > 1) { + int world_size = rt.comm_size(); + + // Gather per-rank per-atom sizes to rank 0 + std::vector<int64_t> all_rank_atom_sizes(world_rank == 0 ? world_size * natoms : 0); + MPI_Gather(local_atomic_grid_sizes.data(), natoms, MPI_INT64_T, + all_rank_atom_sizes.data(), natoms, MPI_INT64_T, + 0, rt.comm()); + + if (world_rank == 0) { + // Compute global atom sizes by summing across ranks + result.global_atomic_grid_sizes.assign(natoms, 0); + for (int r = 0; r < world_size; ++r) + for (int a = 0; a < natoms; ++a) + result.global_atomic_grid_sizes[a] += all_rank_atom_sizes[r * natoms + a]; + + // Build permutation and reorder all arrays to atom-order + auto [perm, inv_perm] = build_atom_reorder_perm( + all_rank_atom_sizes, sendcounts, displs, natoms, world_size); + reorder_to_atom_order(grid_weights, den_eval, grid_coords, + dden_eval, tau, perm, total_npts); + result.inv_perm = std::move(inv_perm); + } + } + ); + + result.total_npts = total_npts; + return result; +} + +std::pair<std::vector<int64_t>, std::vector<int64_t>> +build_atom_reorder_perm(const std::vector<int64_t>& all_rank_atom_sizes, + const std::vector<int>& sendcounts, + const std::vector<int>& displs, + int natoms, int world_size) { + int64_t total_npts = 0; + for (int r = 0; r < world_size; ++r) total_npts += sendcounts[r]; + + std::vector<int64_t> perm(total_npts); + std::vector<int64_t> inv_perm(total_npts); + + // Precompute per-rank per-atom offsets within each rank's chunk + // src_off[r][a] = displs[r] + sum of all_rank_atom_sizes[r*natoms + a'] for a' < a + std::vector<std::vector<int64_t>> src_off(world_size, std::vector<int64_t>(natoms)); + for (int r = 0; r < world_size; ++r) { + int64_t off = displs[r]; + for (int a = 0; a < natoms; ++a) { + src_off[r][a] = off; + off += all_rank_atom_sizes[r * natoms + a]; + } + } + + // Precompute global atom offsets (destination start for each atom) + std::vector<int64_t> global_atom_off(natoms); + { + int64_t off = 0; + for (int a = 0; a < natoms; ++a) { + global_atom_off[a] = off; + for (int r = 0; r < world_size; ++r) + off += all_rank_atom_sizes[r * natoms + a]; + } + } + + // Build perm: for each atom, concatenate contributions from all ranks in rank order + // dst_cursor tracks the next write position for each atom + std::vector<int64_t> dst_cursor = global_atom_off; + for (int a = 0; a < natoms; ++a) { + for (int r = 0; r < world_size; ++r) { + int64_t count = all_rank_atom_sizes[r * natoms + a]; + int64_t src = src_off[r][a]; + for (int64_t k = 0; k < count; ++k) { + perm[src + k] = dst_cursor[a] + k; + } + dst_cursor[a] += count; + } + } + + // Build inverse: inv_perm[perm[i]] = i + for (int64_t i = 0; i < total_npts; ++i) { + inv_perm[perm[i]] = i; + } + + return {std::move(perm), std::move(inv_perm)}; } void apply_strided_permutation(const double* src, double* dst, + const std::vector<int64_t>& perm, + int64_t npts, int stride) { + for (int64_t i = 0; i < npts; ++i) { + int64_t j = perm[i]; + std::copy(src + i * stride, src + (i + 1) * stride, dst + j * stride); + } +} + +// --- Paired forward/inverse reorder helpers --- + +void reorder_to_atom_order( + std::vector<double>& grid_weights, + std::vector<double>& den_eval, + std::vector<double>& grid_coords, + std::vector<double>& dden_eval, + std::vector<double>& tau, + const std::vector<int64_t>& perm, + int64_t total_npts) { + auto reorder_vec = [&](std::vector<double>& vec, int stride) { + if (vec.empty()) return; + std::vector<double> tmp(vec.size()); + apply_strided_permutation(vec.data(), tmp.data(), perm, total_npts, stride); + vec = std::move(tmp); + }; + reorder_vec(grid_weights, 1); + reorder_vec(den_eval, 2); // interleaved [alpha, beta] per point + reorder_vec(grid_coords, 3); // interleaved [x, y, z] per point + reorder_vec(dden_eval, 6); // [dXa, dYa, dZa, dXb, dYb, dZb] per point + reorder_vec(tau, 2); // interleaved [alpha, beta] per point +} + +void reorder_to_rank_order( + std::vector<double>& recv_den_eval, + std::vector<double>& recv_dden_eval, + std::vector<double>& recv_tau, + const std::vector<int64_t>& inv_perm, + int64_t total_npts, + bool is_gga, bool is_mgga) { + // Gradient data is channel-first: each channel has total_npts contiguous values. + // Apply inv_perm (stride 1) to each channel independently. + auto inv_reorder_channel = [&](double* channel) { + std::vector<double> tmp(total_npts); + apply_strided_permutation(channel, tmp.data(), inv_perm, total_npts, 1); + std::memcpy(channel, tmp.data(), total_npts * sizeof(double)); + }; + // den_eval: [alpha(npts) | beta(npts)] + inv_reorder_channel(recv_den_eval.data()); + inv_reorder_channel(recv_den_eval.data() + total_npts); + // dden_eval: [dXa(npts) | dYa | dZa | dXb | dYb | dZb] + if (is_gga || is_mgga) { + for (int c = 0; c < 6; ++c) + inv_reorder_channel(recv_dden_eval.data() + c * total_npts); + } + // tau: [alpha(npts) | beta(npts)] + if (is_mgga) { + inv_reorder_channel(recv_tau.data()); + inv_reorder_channel(recv_tau.data() + total_npts); + } +} + +void reorder_to_atom_order_channel_first( + std::vector<double>& grid_weights, + std::vector<double>& den_eval, + std::vector<double>& grid_coords, + std::vector<double>& dden_eval, + std::vector<double>& tau, + const std::vector<int64_t>& perm, + int64_t total_npts, + bool is_gga, bool is_mgga) { + // Helper: permute a vector of nchannels × total_npts values (stride 1 per channel) + auto reorder_channels = [&](std::vector<double>& vec, int nchannels) { + if (vec.empty()) return; + std::vector<double> tmp(vec.size()); + for (int c = 0; c < nchannels; ++c) + apply_strided_permutation(vec.data() + c * total_npts, + tmp.data() + c * total_npts, perm, total_npts, 1); + vec = std::move(tmp); + }; + // Helper: permute a strided vector (e.g.
stride 3 for coords) + auto reorder_strided = [&](std::vector& vec, int stride) { + if (vec.empty()) return; + std::vector tmp(vec.size()); + apply_strided_permutation(vec.data(), tmp.data(), perm, total_npts, stride); + vec = std::move(tmp); + }; + + reorder_strided(grid_weights, 1); + reorder_strided(grid_coords, 3); + reorder_channels(den_eval, 2); + if (is_gga || is_mgga) reorder_channels(dden_eval, 6); + if (is_mgga) reorder_channels(tau, 2); +} + +AtomReorderResult mpi_gather_and_reorder_gpu( + std::vector& den_eval, + std::vector& dden_eval, + std::vector& tau, + std::vector& grid_coords, + std::vector& grid_weights, + const std::vector& local_atomic_grid_sizes, + int total_npts, int natoms, + const RuntimeEnvironment& rt, + std::vector& sendcounts, + std::vector& displs) { + + AtomReorderResult result; + result.global_atomic_grid_sizes = local_atomic_grid_sizes; + int world_rank = rt.comm_rank(); + + bool is_gga = !dden_eval.empty(); + bool is_mgga = !tau.empty(); + + GAUXC_MPI_CODE( + total_npts = mpi_gather_onedft_inputs_gpu(den_eval, dden_eval, tau, grid_coords, + grid_weights, total_npts, world_rank, rt.comm_size(), sendcounts, displs); + ); + + GAUXC_MPI_CODE( + if (rt.comm_size() > 1) { + int world_size = rt.comm_size(); + + std::vector all_rank_atom_sizes(world_rank == 0 ? world_size * natoms : 0); + MPI_Gather(local_atomic_grid_sizes.data(), natoms, MPI_INT64_T, + all_rank_atom_sizes.data(), natoms, MPI_INT64_T, + 0, rt.comm()); + + if (world_rank == 0) { + result.global_atomic_grid_sizes.assign(natoms, 0); + for (int r = 0; r < world_size; ++r) + for (int a = 0; a < natoms; ++a) + result.global_atomic_grid_sizes[a] += all_rank_atom_sizes[r * natoms + a]; + + auto [perm, inv_perm] = build_atom_reorder_perm( + all_rank_atom_sizes, sendcounts, displs, natoms, world_size); + reorder_to_atom_order_channel_first(grid_weights, den_eval, grid_coords, + dden_eval, tau, perm, total_npts, + is_gga, is_mgga); + result.inv_perm = std::move(inv_perm); + } + } + ); + + result.total_npts = total_npts; + return result; +} } // namespace GauXC diff --git a/src/xc_integrator/integrator_util/onedft_util.hpp b/src/xc_integrator/integrator_util/onedft_util.hpp index 6cf27e128..f88a1427a 100644 --- a/src/xc_integrator/integrator_util/onedft_util.hpp +++ b/src/xc_integrator/integrator_util/onedft_util.hpp @@ -3,6 +3,7 @@ #include #include #include +#include using json = nlohmann::json; using IValueList = std::vector; @@ -16,7 +17,7 @@ namespace GauXC { // TODO add laplacian ? 
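The strided permutation implemented above moves whole per-point records, never individual components. A small self-contained sketch of the same loop as apply_strided_permutation, with stride 3 so that xyz triples travel together:

#include <cstdint>
#include <vector>

// Same loop shape as apply_strided_permutation above: point i's `stride`
// contiguous values land at slot perm[i]. With stride 3, triples move intact.
int main() {
  std::vector<std::int64_t> perm = {2, 0, 1};  // point 0 -> slot 2, 1 -> 0, 2 -> 1
  std::vector<double> src = {0,0,0, 1,1,1, 2,2,2};
  std::vector<double> dst(src.size());
  for (std::int64_t i = 0; i < 3; ++i)
    for (int s = 0; s < 3; ++s)
      dst[perm[i]*3 + s] = src[i*3 + s];
  // dst == {1,1,1, 2,2,2, 0,0,0}; no xyz triple is ever split across slots.
  return 0;
}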
void print_memory_stats(size_t device_id); - enum ONEDFT_FEATURE { DEN, DDEN, TAU, POINTS, WEIGHTS, COORDS }; + enum ONEDFT_FEATURE { DEN, DDEN, TAU, POINTS, WEIGHTS, COORDS, ATOMIC_GRID_WEIGHTS, ATOMIC_GRID_SIZES, ATOMIC_GRID_SIZE_BOUND_SHAPE }; // Mapping enums to string values inline const std::map<ONEDFT_FEATURE, std::string> feat_map = { @@ -25,7 +26,10 @@ namespace GauXC { {TAU, "kin"}, {POINTS, "grid_coords"}, {WEIGHTS, "grid_weights"}, - {COORDS, "coarse_0_atomic_coords"} + {COORDS, "coarse_0_atomic_coords"}, + {ATOMIC_GRID_WEIGHTS, "atomic_grid_weights"}, + {ATOMIC_GRID_SIZES, "atomic_grid_sizes"}, + {ATOMIC_GRID_SIZE_BOUND_SHAPE, "atomic_grid_size_bound_shape"} }; inline const std::map<std::string, ONEDFT_FEATURE> reverse_feat_map = { @@ -34,12 +38,16 @@ namespace GauXC { {"kin", TAU}, {"grid_coords", POINTS}, {"grid_weights", WEIGHTS}, - {"coarse_0_atomic_coords", COORDS} + {"coarse_0_atomic_coords", COORDS}, + {"atomic_grid_weights", ATOMIC_GRID_WEIGHTS}, + {"atomic_grid_sizes", ATOMIC_GRID_SIZES}, + {"atomic_grid_size_bound_shape", ATOMIC_GRID_SIZE_BOUND_SHAPE} }; int mpi_scatter_onedft_outputs(const FeatureDict features_dict, const int world_rank, const int world_size, std::vector<int> recvcounts, std::vector<int> displs, + const std::vector<int64_t>& atom_reorder_inv_perm, std::vector<double>& den_eval, std::vector<double>& dden_eval, std::vector<double>& tau); @@ -48,6 +56,28 @@ int mpi_gather_onedft_inputs(std::vector<double>& den_eval, std::vector<double>& std::vector<double>& grid_weights, const int total_npts, const int world_rank, const int world_size, std::vector<int>& sendcounts, std::vector<int>& displs); + + // Result of MPI gather + atom-reorder pipeline + struct AtomReorderResult { + std::vector<int64_t> global_atomic_grid_sizes; + std::vector<int64_t> inv_perm; + int total_npts; + }; + + // Gather local features from all ranks to rank 0, then reorder from + // rank-order to atom-order. Encapsulates MPI_Gather of atom sizes, + // mpi_gather_onedft_inputs, build_atom_reorder_perm, and reorder_to_atom_order. + AtomReorderResult mpi_gather_and_reorder( + std::vector<double>& den_eval, + std::vector<double>& dden_eval, + std::vector<double>& tau, + std::vector<double>& grid_coords, + std::vector<double>& grid_weights, + const std::vector<int64_t>& local_atomic_grid_sizes, + int total_npts, int natoms, + const RuntimeEnvironment& rt, + std::vector<int>& sendcounts, + std::vector<int>& displs); int mpi_gather_onedft_inputs_gpu(std::vector<double>& den_eval, std::vector<double>& dden_eval, std::vector<double>& tau, std::vector<double>& grid_coords, @@ -61,4 +91,75 @@ int mpi_gather_onedft_inputs_gpu(std::vector<double>& den_eval, + std::pair<std::vector<int64_t>, std::vector<int64_t>> + build_atom_reorder_perm(const std::vector<int64_t>& all_rank_atom_sizes, + const std::vector<int>& sendcounts, + const std::vector<int>& displs, + int natoms, int world_size); + + // Apply a point-level permutation to a strided array. + // For each point i, copies stride elements from src[i*stride..] to dst[perm[i]*stride..]. + void apply_strided_permutation(const double* src, double* dst, + const std::vector<int64_t>& perm, + int64_t npts, int stride); + + // --- Paired forward/inverse reorder helpers --- + // These two functions form a symmetric pair: forward reorders gathered MPI data + // from rank-order to atom-order, inverse reverses atom-ordered gradients back + // to rank-order before Scatterv. + + // Forward: reorder interleaved flat arrays from rank-order to atom-order. + // Applies perm with strides matching the interleaved data layout: + // grid_weights(1), den_eval(2), grid_coords(3), dden_eval(6), tau(2).
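The host and GPU paths disagree only on memory layout: the host path stores features point-major (interleaved), the GPU path channel-major (each channel is npts contiguous values). A sketch of the two layouts and a conversion between them; the function name is illustrative, not part of the patch:

#include <cstdint>
#include <vector>

// Host path (interleaved, point-major):    [a0, b0, a1, b1, ...]
// GPU path (channel-first, channel-major): [a0, a1, ... | b0, b1, ...]
std::vector<double> interleaved_to_channel_first(const std::vector<double>& in,
                                                 std::int64_t npts, int nch) {
  std::vector<double> out(in.size());
  for (std::int64_t i = 0; i < npts; ++i)
    for (int c = 0; c < nch; ++c)
      out[c * npts + i] = in[i * nch + c];  // channel c of point i
  return out;
}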
+ void reorder_to_atom_order( + std::vector& grid_weights, + std::vector& den_eval, + std::vector& grid_coords, + std::vector& dden_eval, + std::vector& tau, + const std::vector& perm, + int64_t total_npts); + + // Inverse: reorder channel-first gradient arrays from atom-order to rank-order. + // Gradient data uses channel-first layout (each channel has total_npts contiguous + // values with stride 1): den[2*npts], dden[6*npts], tau[2*npts]. + void reorder_to_rank_order( + std::vector& recv_den_eval, + std::vector& recv_dden_eval, + std::vector& recv_tau, + const std::vector& inv_perm, + int64_t total_npts, + bool is_gga, bool is_mgga); + + // Forward (channel-first): reorder GPU-layout flat arrays from rank-order to atom-order. + // GPU data uses channel-first layout for den/dden/tau (each channel has npts contiguous + // values), but stride 3 for grid_coords and stride 1 for grid_weights. + void reorder_to_atom_order_channel_first( + std::vector& grid_weights, + std::vector& den_eval, + std::vector& grid_coords, + std::vector& dden_eval, + std::vector& tau, + const std::vector& perm, + int64_t total_npts, + bool is_gga, bool is_mgga); + + // GPU variant of mpi_gather_and_reorder: gathers GPU-layout (channel-first) data, + // builds atom-order permutation, and reorders to atom-order on rank 0. + AtomReorderResult mpi_gather_and_reorder_gpu( + std::vector& den_eval, + std::vector& dden_eval, + std::vector& tau, + std::vector& grid_coords, + std::vector& grid_weights, + const std::vector& local_atomic_grid_sizes, + int total_npts, int natoms, + const RuntimeEnvironment& rt, + std::vector& sendcounts, + std::vector& displs); } // namespace GauXC \ No newline at end of file diff --git a/src/xc_integrator/local_work_driver/device/common/onedft_exc_grad.hpp b/src/xc_integrator/local_work_driver/device/common/onedft_exc_grad.hpp new file mode 100644 index 000000000..476ac0e1f --- /dev/null +++ b/src/xc_integrator/local_work_driver/device/common/onedft_exc_grad.hpp @@ -0,0 +1,27 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). + * + * (c) 2024-2025, Microsoft Corporation + * + * All rights reserved. + * + * See LICENSE.txt for details + */ +#pragma once +#include "device/xc_device_task.hpp" +#include "device/device_queue.hpp" + +namespace GauXC { + +// Transform OneDFT per-direction Vxc (stored in gamma_pp/vgamma_pp etc.) +// into the standard vrho/vgamma/dden format expected by inc_exc_grad kernels. +// After this call, dden_sx/sy/sz contain vds_x/y/z and vgamma_pp/pm/mm = 1/0/1. 
+void transform_onedft_vxc_for_grad( + size_t ntasks, + int32_t max_npts, + XCDeviceTask* tasks_device, + device_queue queue ); + +} diff --git a/src/xc_integrator/local_work_driver/device/cuda/CMakeLists.txt b/src/xc_integrator/local_work_driver/device/cuda/CMakeLists.txt index e557037c3..842087000 100644 --- a/src/xc_integrator/local_work_driver/device/cuda/CMakeLists.txt +++ b/src/xc_integrator/local_work_driver/device/cuda/CMakeLists.txt @@ -30,6 +30,7 @@ target_sources(gauxc PRIVATE kernels/increment_exc_grad.cu kernels/exx_ek_screening_bfn_stats.cu kernels/onedft.cu + kernels/onedft_exc_grad.cu ) # Check if CMAKE_CUDA_ARCHITECTURES is set diff --git a/src/xc_integrator/local_work_driver/device/cuda/kernels/onedft_exc_grad.cu b/src/xc_integrator/local_work_driver/device/cuda/kernels/onedft_exc_grad.cu new file mode 100644 index 000000000..9c87d1dce --- /dev/null +++ b/src/xc_integrator/local_work_driver/device/cuda/kernels/onedft_exc_grad.cu @@ -0,0 +1,82 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). + * + * (c) 2024-2025, Microsoft Corporation + * + * All rights reserved. + * + * See LICENSE.txt for details + */ +#include "device/common/onedft_exc_grad.hpp" +#include +#include "device_specific/cuda_util.hpp" +#include "device_specific/cuda_device_constants.hpp" + +namespace GauXC { + +// Transform OneDFT per-component Vxc to the standard format expected by +// inc_exc_grad_gga/mgga kernels. In OneDFT: +// gamma_pp = dden_x_grad_a (per-direction derivative, alpha) +// vgamma_pp = dden_x_grad_b (per-direction derivative, beta) +// gamma_pm = dden_y_grad_a +// vgamma_pm = dden_y_grad_b +// gamma_mm = dden_z_grad_a +// vgamma_mm = dden_z_grad_b +// +// The standard kernel reads dden_sx = (a+b)/2 (s-spin), dden_zx = (a-b)/2 (z-spin) +// and vgamma_pp/pm/mm as scalar coupling coefficients. +// Setting vgamma_pp=1, vgamma_pm=0, vgamma_mm=1 makes the standard kernel +// reproduce the OneDFT gradient formula exactly. 
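The (s, z) change of basis performed by the kernel below is lossless, which is what lets the stock UKS kernel recover the per-spin, per-direction derivatives from the transformed buffers; the unit vgamma coefficients then act as pass-through couplings, per the comment above. A tiny host-side check of the identity a = s + z, b = s - z:

#include <cassert>
#include <cmath>

// Round-trip check for the s/z transform: s = (a+b)/2, z = (a-b)/2
// implies a = s+z and b = s-z (exact in real arithmetic).
int main() {
  const double dx_a = 0.37, dx_b = -1.25;  // arbitrary per-direction values
  const double s = 0.5 * (dx_a + dx_b);
  const double z = 0.5 * (dx_a - dx_b);
  assert(std::abs((s + z) - dx_a) < 1e-15);
  assert(std::abs((s - z) - dx_b) < 1e-15);
  return 0;
}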
+__global__ void transform_onedft_vxc_for_grad_kernel( + uint32_t ntasks, + XCDeviceTask* __restrict__ tasks_device ) { + + const int batch_idx = blockIdx.z; + if( batch_idx >= ntasks ) return; + + auto& task = tasks_device[batch_idx]; + const auto npts = task.npts; + + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if( tid >= npts ) return; + + // Read per-direction OneDFT derivatives (alpha/beta) + const double dx_a = task.gamma_pp[tid]; + const double dx_b = task.vgamma_pp[tid]; + const double dy_a = task.gamma_pm[tid]; + const double dy_b = task.vgamma_pm[tid]; + const double dz_a = task.gamma_mm[tid]; + const double dz_b = task.vgamma_mm[tid]; + + // Convert to total (s) and magnetization (z) form + task.dden_sx[tid] = 0.5 * (dx_a + dx_b); + task.dden_sy[tid] = 0.5 * (dy_a + dy_b); + task.dden_sz[tid] = 0.5 * (dz_a + dz_b); + task.dden_zx[tid] = 0.5 * (dx_a - dx_b); + task.dden_zy[tid] = 0.5 * (dy_a - dy_b); + task.dden_zz[tid] = 0.5 * (dz_a - dz_b); + + // Set vgamma coefficients so the standard kernel reproduces the OneDFT formula + task.vgamma_pp[tid] = 1.0; + task.vgamma_pm[tid] = 0.0; + task.vgamma_mm[tid] = 1.0; +} + +void transform_onedft_vxc_for_grad( + size_t ntasks, + int32_t max_npts, + XCDeviceTask* tasks_device, + device_queue queue ) { + + cudaStream_t stream = queue.queue_as(); + + dim3 threads(256); + dim3 blocks( util::div_ceil((uint32_t)max_npts, threads.x), 1, ntasks ); + + transform_onedft_vxc_for_grad_kernel<<>>( + ntasks, tasks_device ); +} + +} // namespace GauXC diff --git a/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx b/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx index 384d6bb2b..04a0eed9b 100644 --- a/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx +++ b/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx @@ -140,6 +140,7 @@ FWD_TO_PIMPL(inc_exx_k) FWD_TO_PIMPL_KS_SCHEME_BOOL(inc_exc_grad_lda) FWD_TO_PIMPL_KS_SCHEME_BOOL(inc_exc_grad_gga) FWD_TO_PIMPL_KS_SCHEME_BOOL_BOOL(inc_exc_grad_mgga) +FWD_TO_PIMPL(transform_onedft_vxc_for_grad) FWD_TO_PIMPL_DEN_ID(symmetrize_vxc) FWD_TO_PIMPL_DEN_ID(symmetrize_fxc) // Added FXC function diff --git a/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp b/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp index f44773853..50928fa7b 100644 --- a/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp +++ b/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp @@ -117,6 +117,7 @@ class LocalDeviceWorkDriver : public LocalWorkDriver { void inc_exc_grad_lda( XCDeviceData*, integrator_ks_scheme, bool ); void inc_exc_grad_gga( XCDeviceData*, integrator_ks_scheme, bool ); void inc_exc_grad_mgga( XCDeviceData*, integrator_ks_scheme , bool, bool ); + void transform_onedft_vxc_for_grad( XCDeviceData* ); void inc_exx_k( XCDeviceData* ); void eval_exx_ek_screening_bfn_stats( XCDeviceData* ); diff --git a/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp b/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp index f20a73f6d..94bec303e 100644 --- a/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp +++ b/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp @@ -74,6 +74,7 @@ struct LocalDeviceWorkDriverPIMPL { virtual void inc_exc_grad_lda( XCDeviceData*, integrator_ks_scheme, bool ) = 0; virtual void inc_exc_grad_gga( XCDeviceData*, 
integrator_ks_scheme, bool ) = 0; virtual void inc_exc_grad_mgga( XCDeviceData*, integrator_ks_scheme , bool, bool ) = 0; + virtual void transform_onedft_vxc_for_grad( XCDeviceData* ) = 0; virtual void inc_exx_k( XCDeviceData* ) = 0; virtual void symmetrize_vxc( XCDeviceData*, density_id ) = 0; virtual void symmetrize_fxc( XCDeviceData*, density_id ) = 0; diff --git a/src/xc_integrator/local_work_driver/device/scheme1_base.cxx b/src/xc_integrator/local_work_driver/device/scheme1_base.cxx index 233ff5af4..a858d1b97 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_base.cxx +++ b/src/xc_integrator/local_work_driver/device/scheme1_base.cxx @@ -21,6 +21,7 @@ #include "device/common/inc_potential.hpp" #include "device/common/symmetrize_mat.hpp" #include "device/common/increment_exc_grad.hpp" +#include "device/common/onedft_exc_grad.hpp" #include "device/common/exx_ek_screening.hpp" #include "buffer_adaptor.hpp" @@ -1963,6 +1964,28 @@ void AoSScheme1Base::inc_exc_grad_mgga( XCDeviceData* _data, integrator_ks_schem #endif } +void AoSScheme1Base::transform_onedft_vxc_for_grad( XCDeviceData* _data ) { +#ifdef GAUXC_HAS_HIP + GAUXC_GENERIC_EXCEPTION("OneDFT Grad Transform NYI for HIP Backends"); +#else + auto* data = dynamic_cast(_data); + if( !data ) GAUXC_BAD_LWD_DATA_CAST(); + + if( not data->device_backend_ ) GAUXC_UNINITIALIZED_DEVICE_BACKEND(); + + auto& tasks = data->host_device_tasks; + const auto ntasks = tasks.size(); + size_t npts_max = 0; + for( auto& task : tasks ) { + npts_max = std::max( npts_max, task.npts ); + } + + GauXC::transform_onedft_vxc_for_grad( ntasks, npts_max, + data->aos_stack.device_tasks, + data->device_backend_->queue() ); +#endif +} + void AoSScheme1Base::eval_exx_fmat( XCDeviceData* _data ) { #ifndef GAUXC_ENABLE_EXX diff --git a/src/xc_integrator/local_work_driver/device/scheme1_base.hpp b/src/xc_integrator/local_work_driver/device/scheme1_base.hpp index 5abac35d0..fb108de0e 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_base.hpp +++ b/src/xc_integrator/local_work_driver/device/scheme1_base.hpp @@ -63,6 +63,7 @@ struct AoSScheme1Base : public detail::LocalDeviceWorkDriverPIMPL { void inc_exc_grad_lda( XCDeviceData*, integrator_ks_scheme, bool ) override final; void inc_exc_grad_gga( XCDeviceData*, integrator_ks_scheme, bool ) override final; void inc_exc_grad_mgga( XCDeviceData*, integrator_ks_scheme , bool, bool ) override final; + void transform_onedft_vxc_for_grad( XCDeviceData* ) override final; void symmetrize_vxc( XCDeviceData* , density_id) override final; void symmetrize_fxc( XCDeviceData* , density_id) override final; void symmetrize_exx_k( XCDeviceData* ) override final; diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp index 38b5f2067..1287835e9 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp @@ -113,6 +113,10 @@ class IncoreReplicatedXCDeviceIntegrator : value_type* VXCz, int64_t ldvxcz, value_type* EXC, const IntegratorSettingsXC& settings ) override; + void eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, + const IntegratorSettingsXC& settings ) override; + void integrate_den_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, value_type *N_EL, host_task_iterator 
task_begin, host_task_iterator task_end, diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp index bc183e711..387f0bd4a 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp @@ -13,7 +13,6 @@ #include "device/common/device_blas.hpp" #include "integrator_util/onedft_util.hpp" #include -#include #include "device/cuda/cuda_backend.hpp" #include // for size_t @@ -22,7 +21,10 @@ namespace GauXC::detail { FeatureDict prepare_onedft_features( const size_t natoms, const size_t total_npts, const size_t ndm, const at::TensorOptions options, const std::vector feature_keys, double* den_eval, double* dden_eval, double* tau, double* grid_coords, - double* grid_weights, double* coords ); + double* grid_weights, double* coords, + const std::vector& atomic_grid_sizes = {}, + int64_t max_grid_size = 0, + double* raw_grid_weights = nullptr ); size_t save_static_data_onedft_features (XCDeviceData* _data, const integrator_term_tracker enabled_terms, size_t offset); @@ -81,7 +83,11 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, } // Get Tasks - auto& tasks = this->load_balancer_->get_tasks(); + auto& tasks = this->load_balancer_->get_tasks(); + // Sort tasks by atom index so that grid points are grouped by atom. + // build_atom_reorder_perm assumes this contiguous-by-atom layout. + std::stable_sort(tasks.begin(), tasks.end(), + [](const auto& a, const auto& b) { return a.iParent < b.iParent; }); size_t total_npts = std::accumulate( tasks.begin(), tasks.end(), 0ul, [](const auto& a, const auto& b) { return a + b.npts; } ); @@ -169,18 +175,39 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, std::vector grid_weights, grid_coords, den_eval, dden_eval, tau; std::vector displs(world_size), recvcounts(world_size); - // run onedft model on thread 0 + // Collect raw (pre-partition) quadrature weights from tasks + std::vector raw_grid_weights; + raw_grid_weights.reserve(total_npts); + for (const auto& task : tasks) { + raw_grid_weights.insert(raw_grid_weights.end(), task.raw_weights.begin(), task.raw_weights.end()); + } + + // Compute per-atom grid sizes from tasks + std::vector atomic_grid_sizes_vec(natoms, 0); + for (const auto& task : tasks) { + if (task.iParent >= 0 && task.iParent < (int)natoms) { + atomic_grid_sizes_vec[task.iParent] += task.npts; + } + } + + // run onedft model on rank 0 FeatureDict features_dict; + std::vector atom_reorder_inv_perm; + std::vector global_atomic_grid_sizes_vec = atomic_grid_sizes_vec; if ( world_size == 1 ) { // keep everything on device + int64_t max_grid_size = *std::max_element( + atomic_grid_sizes_vec.begin(), atomic_grid_sizes_vec.end()); auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCUDA); features_dict = prepare_onedft_features( natoms, total_npts, ndm, options, feature_keys, device_data_ptr->den_eval_device_data(), device_data_ptr->dden_eval_device_data(), device_data_ptr->tau_device_data(), device_data_ptr->grid_coords_device_data(), device_data_ptr->grid_weights_device_data(), - device_data_ptr->coords_device_data() + device_data_ptr->coords_device_data(), + atomic_grid_sizes_vec, max_grid_size, + raw_grid_weights.data() ); - } else { // copy to host and then back to device + } else { // copy to host, gather, reorder, then back to device grid_weights.resize(total_npts); 
grid_coords.resize(total_npts * 3); den_eval.resize(total_npts * ndm); @@ -194,7 +221,8 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, (is_gga || is_mgga) ? dden_eval.data() : nullptr, is_mgga ? tau.data() : nullptr, grid_coords.data(), grid_weights.data() ); - // Build host coords from molecule (avoids passing GPU pointer with CPU TensorOptions) + + // Build host coords from molecule (avoids device-to-host copy of coords_device_data) std::vector host_coords(natoms * 3); for (size_t i = 0; i < natoms; i++) { host_coords[3*i] = mol[i].x; @@ -202,14 +230,42 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, host_coords[3*i+2] = mol[i].z; } - int total_npts_sum = mpi_gather_onedft_inputs_gpu(den_eval, dden_eval, tau, grid_coords, grid_weights, - total_npts, world_rank, world_size, recvcounts, displs); + auto reorder_result = mpi_gather_and_reorder_gpu( + den_eval, dden_eval, tau, grid_coords, grid_weights, + atomic_grid_sizes_vec, total_npts, natoms, rt, recvcounts, displs); + int total_npts_sum = reorder_result.total_npts; + atom_reorder_inv_perm = std::move(reorder_result.inv_perm); + global_atomic_grid_sizes_vec = std::move(reorder_result.global_atomic_grid_sizes); + + // Gather and reorder raw_grid_weights using the same MPI layout + GAUXC_MPI_CODE( + if (world_size > 1) { + int local_npts = (int)total_npts; + std::vector recv_raw(world_rank == 0 ? total_npts_sum : 0); + MPI_Gatherv(raw_grid_weights.data(), local_npts, MPI_DOUBLE, + recv_raw.data(), recvcounts.data(), displs.data(), + MPI_DOUBLE, 0, rt.comm()); + if (world_rank == 0) { + raw_grid_weights = std::move(recv_raw); + std::vector perm(total_npts_sum); + for (int64_t j = 0; j < total_npts_sum; j++) perm[atom_reorder_inv_perm[j]] = j; + std::vector tmp(total_npts_sum); + for (int64_t i = 0; i < total_npts_sum; i++) tmp[perm[i]] = raw_grid_weights[i]; + raw_grid_weights = std::move(tmp); + } + } + ) + if (world_rank == 0) { + int64_t max_grid_size = *std::max_element( + global_atomic_grid_sizes_vec.begin(), global_atomic_grid_sizes_vec.end()); auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); features_dict = prepare_onedft_features( natoms, total_npts_sum, ndm, options, feature_keys, den_eval.data(), dden_eval.data(), tau.data(), grid_coords.data(), grid_weights.data(), - host_coords.data() + host_coords.data(), + global_atomic_grid_sizes_vec, max_grid_size, + raw_grid_weights.data() ); } } @@ -218,9 +274,7 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, auto exc = (exc_on_grid * features_dict.at(feat_map.at(ONEDFT_FEATURE::WEIGHTS))).sum(); // if do_vxc exc.backward(); - c10::cuda::CUDACachingAllocator::emptyCache(); EXC[0] = exc.item(); - // std::cout << "EXC: " << EXC[0] << std::endl; } else { EXC[0] = 0.0; } @@ -242,7 +296,8 @@ eval_exc_vxc_onedft_( int64_t m, int64_t n, den_grad, dden_grad, tau_grad ); } else { total_npts = mpi_scatter_onedft_outputs(features_dict, rt.comm_rank(), rt.comm_size(), - recvcounts, displs, den_eval, dden_eval, tau); + recvcounts, displs, atom_reorder_inv_perm, + den_eval, dden_eval, tau); device_data_ptr->send_static_data_onedft_results( total_npts, ndm, EXC, den_eval.data(), dden_eval.data(), tau.data()); } @@ -565,7 +620,10 @@ size_t save_static_data_onedft_features(XCDeviceData* _data, const integrator_te FeatureDict prepare_onedft_features( const size_t natoms, const size_t total_npts, const size_t ndm, const at::TensorOptions options, const std::vector feature_keys, double* den_eval, double* dden_eval, double* tau, double* grid_coords, - double* grid_weights, 
double* coords ) { + double* grid_weights, double* coords, + const std::vector& atomic_grid_sizes, + int64_t max_grid_size, + double* raw_grid_weights ) { auto device = torch::Device(torch::kCUDA, 0); FeatureDict featmap; for (const auto& key : feature_keys) { @@ -607,10 +665,490 @@ FeatureDict prepare_onedft_features( const size_t natoms, const size_t total_npt featmap.insert(key, tensor); break; } + case ONEDFT_FEATURE::ATOMIC_GRID_WEIGHTS: { + // Use raw (pre-partition) quadrature weights if available + double* w_ptr = raw_grid_weights ? raw_grid_weights : grid_weights; + auto w_opts = raw_grid_weights + ? torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU) + : options; + auto flat_tensor = torch::from_blob(w_ptr, {total_npts}, w_opts); + auto tensor = flat_tensor.view({total_npts}).to(device); + featmap.insert(key, tensor); + break; + } + case ONEDFT_FEATURE::ATOMIC_GRID_SIZES: { + auto sizes_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU); + auto tensor = torch::from_blob(const_cast(atomic_grid_sizes.data()), + {static_cast(natoms)}, sizes_options).clone().to(device); + featmap.insert(key, tensor); + break; + } + case ONEDFT_FEATURE::ATOMIC_GRID_SIZE_BOUND_SHAPE: { + auto sizes_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU); + auto tensor = torch::zeros({max_grid_size, 0}, sizes_options).to(device); + featmap.insert(key, tensor); + break; + } default: GAUXC_GENERIC_EXCEPTION("Feature Key Not Implemented: " + key); } } return featmap; } +} // namespace GauXC::detail (prepare_onedft_features) + +// ============================================================================ +// OneDFT EXC Gradient — Device Implementation +// ============================================================================ + +namespace GauXC::detail { + +template +void IncoreReplicatedXCDeviceIntegrator:: +eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, + const IntegratorSettingsXC& settings ) { + + const auto& basis = this->load_balancer_->basis(); + const int64_t nbf = basis.nbf(); + if( m != n ) GAUXC_GENERIC_EXCEPTION("P Must Be Square"); + if( m != nbf ) GAUXC_GENERIC_EXCEPTION("P Must Have Same Dimension as Basis"); + if( ldps < nbf) GAUXC_GENERIC_EXCEPTION("Invalid LDPS"); + if( ldpz && ldpz < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDPZ"); + + const bool is_uks = (Pz != nullptr); + if (not is_uks) { + GAUXC_GENERIC_EXCEPTION("RKS OneDFT gradient Not Yet Implemented"); + } + + auto& tasks = this->load_balancer_->get_tasks(); + // Sort tasks by atom so grid points are contiguous per atom + std::stable_sort(tasks.begin(), tasks.end(), + [](const auto& a, const auto& b) { return a.iParent < b.iParent; }); + size_t total_npts = std::accumulate( tasks.begin(), tasks.end(), 0ul, + [](const auto& a, const auto& b) { return a + b.npts; } ); + + auto* lwd = dynamic_cast(this->local_work_driver_.get()); + auto rt = detail::as_device_runtime(this->load_balancer_->runtime()); + auto device_data_ptr = lwd->create_device_data(rt); + + int32_t world_rank = rt.comm_rank(); + int32_t world_size = rt.comm_size(); + + // Load model + OneDFTSettings onedft_settings; + if( auto* tmp = dynamic_cast(&settings) ) { + onedft_settings = *tmp; + } + const auto model_path = onedft_settings.model; + torch::DeviceType torch_device = torch::kCPU; + auto [exc_func, feature_keys] = load_model(model_path, torch_device); + + // Determine feature requirements + bool is_gga = false; 
+ bool is_mgga = false; + for (const auto& key : feature_keys) { + if ( not valueExists(key) ) GAUXC_GENERIC_EXCEPTION("Feature Key Required Not Implemented: " + key); + if (key == feat_map.at(ONEDFT_FEATURE::TAU)) is_mgga = true; + if (key == feat_map.at(ONEDFT_FEATURE::DDEN)) is_gga = true; + } + if (is_mgga) is_gga = false; + + const auto& mol = this->load_balancer_->molecule(); + const auto natoms = mol.natoms(); + const auto nshells = basis.nshells(); + + // Phase 1: Pre-work — compute density features on device + integrator_term_tracker enabled_terms; + enabled_terms.exc_vxc = true; + enabled_terms.onedft = true; + enabled_terms.exc_grad = true; // Also allocate exc_grad_device + if (is_uks) enabled_terms.ks_scheme = UKS; + else enabled_terms.ks_scheme = RKS; + + if (is_mgga) enabled_terms.xc_approx = integrator_xc_approx::MGGA_TAU; + else if (is_gga) enabled_terms.xc_approx = integrator_xc_approx::GGA; + else enabled_terms.xc_approx = integrator_xc_approx::LDA; + + device_data_ptr->reset_allocations(); + device_data_ptr->allocate_static_data_onedft( nbf, nshells, natoms, total_npts, enabled_terms ); + device_data_ptr->send_static_data_onedft( mol, Ps, ldps, Pz, ldpz, nullptr, 0, nullptr, 0, basis ); + device_data_ptr->zero_exc_grad_integrands(); + + // Pre-OneDFT density computation on device + integrator_term_tracker pre_terms = enabled_terms; + pre_terms.exc_grad = false; // pre-work doesn't need gradient buffers + this->timer_.time_op("XCIntegrator.LocalWork_PreOneDFT", [&](){ + pre_onedft_local_work_( basis, Ps, ldps, Pz, ldpz, nullptr, 0, nullptr, 0, + tasks.begin(), tasks.end(), *device_data_ptr, pre_terms ); + }); + + // Collect raw quadrature weights and atomic grid sizes (for model features) + std::vector raw_grid_weights; + raw_grid_weights.reserve(total_npts); + for (const auto& task : tasks) { + raw_grid_weights.insert(raw_grid_weights.end(), task.raw_weights.begin(), task.raw_weights.end()); + } + std::vector atomic_grid_sizes_vec(natoms, 0); + for (const auto& task : tasks) { + if (task.iParent >= 0 && task.iParent < (int)natoms) { + atomic_grid_sizes_vec[task.iParent] += task.npts; + } + } + + size_t ndm = 2; // UKS + + // Phase 2: Retrieve features and run model + // For single rank: features stay on device → just build tensors from device pointers + // For multi rank: retrieve to host, gather, reorder + FeatureDict features_dict; + std::vector atom_reorder_inv_perm; + std::vector den_eval, dden_eval, tau; + std::vector displs(world_size), recvcounts(world_size); + + if (world_size == 1) { + // Keep on device for model inference + int64_t max_grid_size = *std::max_element( + atomic_grid_sizes_vec.begin(), atomic_grid_sizes_vec.end()); + // For gradient, we need CPU tensors since we need requires_grad on points/coords + // Retrieve features to host + std::vector grid_weights(total_npts), grid_coords(total_npts * 3); + den_eval.resize(total_npts * ndm); + if (is_gga || is_mgga) dden_eval.resize(total_npts * ndm * 3); + if (is_mgga) tau.resize(total_npts * ndm); + + device_data_ptr->retrieve_onedft_features( total_npts, 2, den_eval.data(), + (is_gga || is_mgga) ? dden_eval.data() : nullptr, + is_mgga ? 
tau.data() : nullptr, + grid_coords.data(), grid_weights.data() ); + rt.device_backend()->master_queue_synchronize(); + + std::vector host_coords(natoms * 3); + for (size_t i = 0; i < (size_t)natoms; i++) { + host_coords[3*i] = mol[i].x; + host_coords[3*i+1] = mol[i].y; + host_coords[3*i+2] = mol[i].z; + } + + auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); + features_dict = prepare_onedft_features( + natoms, total_npts, ndm, options, feature_keys, den_eval.data(), + dden_eval.data(), tau.data(), grid_coords.data(), grid_weights.data(), + host_coords.data(), atomic_grid_sizes_vec, max_grid_size, + raw_grid_weights.data() ); + } else { + // Multi-rank: retrieve features, gather, reorder + std::vector grid_weights(total_npts), grid_coords(total_npts * 3); + den_eval.resize(total_npts * ndm); + if (is_gga || is_mgga) dden_eval.resize(total_npts * ndm * 3); + if (is_mgga) tau.resize(total_npts * ndm); + + device_data_ptr->retrieve_onedft_features( total_npts, 2, den_eval.data(), + (is_gga || is_mgga) ? dden_eval.data() : nullptr, + is_mgga ? tau.data() : nullptr, + grid_coords.data(), grid_weights.data() ); + rt.device_backend()->master_queue_synchronize(); + + std::vector host_coords(natoms * 3); + for (size_t i = 0; i < (size_t)natoms; i++) { + host_coords[3*i] = mol[i].x; + host_coords[3*i+1] = mol[i].y; + host_coords[3*i+2] = mol[i].z; + } + + auto reorder_result = mpi_gather_and_reorder_gpu( + den_eval, dden_eval, tau, grid_coords, grid_weights, + atomic_grid_sizes_vec, total_npts, natoms, rt, recvcounts, displs); + atom_reorder_inv_perm = std::move(reorder_result.inv_perm); + + if (world_rank == 0) { + int64_t max_grid_size = *std::max_element( + reorder_result.global_atomic_grid_sizes.begin(), + reorder_result.global_atomic_grid_sizes.end()); + auto options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); + features_dict = prepare_onedft_features( + natoms, reorder_result.total_npts, ndm, options, feature_keys, + den_eval.data(), dden_eval.data(), tau.data(), + grid_coords.data(), grid_weights.data(), host_coords.data(), + reorder_result.global_atomic_grid_sizes, max_grid_size, + raw_grid_weights.data() ); + } + } + + // Phase 3: Forward + backward with requires_grad on POINTS and COORDS + std::vector eps_on_grid_global; + std::vector points_grad_global; + std::vector coords_grad_global; + + if (world_rank == 0) { + // Enable requires_grad on points and coords + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::POINTS)) != features_dict.end()) { + features_dict.at(feat_map.at(ONEDFT_FEATURE::POINTS)).requires_grad_(true); + } + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::COORDS)) != features_dict.end()) { + features_dict.at(feat_map.at(ONEDFT_FEATURE::COORDS)).requires_grad_(true); + } + + auto exc_on_grid = get_exc(exc_func, features_dict); + if (exc_on_grid.isnan().any().item()) { + GAUXC_GENERIC_EXCEPTION("exc_on_grid has NaN"); + } + auto exc = (exc_on_grid * features_dict.at(feat_map.at(ONEDFT_FEATURE::WEIGHTS))).sum(); + exc.backward(); + + // Extract eps_on_grid for weight derivative + int total_npts_model = exc_on_grid.size(0); + at::Tensor eps_cpu = exc_on_grid.detach().cpu().contiguous(); + eps_on_grid_global.resize(total_npts_model); + std::memcpy(eps_on_grid_global.data(), eps_cpu.data_ptr(), + total_npts_model * sizeof(double)); + + // Extract points.grad() + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::POINTS)) != features_dict.end()) { + auto pg = 
features_dict.at(feat_map.at(ONEDFT_FEATURE::POINTS)).grad(); + if (pg.defined()) { + at::Tensor pg_cpu = pg.cpu().contiguous(); + points_grad_global.resize(total_npts_model * 3); + std::memcpy(points_grad_global.data(), pg_cpu.data_ptr(), + total_npts_model * 3 * sizeof(double)); + } + } + + // Extract coords.grad() + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::COORDS)) != features_dict.end()) { + auto cg = features_dict.at(feat_map.at(ONEDFT_FEATURE::COORDS)).grad(); + if (cg.defined()) { + at::Tensor cg_cpu = cg.cpu().contiguous(); + coords_grad_global.resize(natoms * 3); + std::memcpy(coords_grad_global.data(), cg_cpu.data_ptr(), + natoms * 3 * sizeof(double)); + } + } + + // Reorder eps_on_grid and points_grad from atom-order back to rank-order + if (!atom_reorder_inv_perm.empty()) { + std::vector tmp(total_npts_model); + for (int64_t i = 0; i < total_npts_model; i++) { + tmp[atom_reorder_inv_perm[i]] = eps_on_grid_global[i]; + } + eps_on_grid_global = std::move(tmp); + + if (!points_grad_global.empty()) { + std::vector tmp3(total_npts_model * 3); + for (int64_t i = 0; i < total_npts_model; i++) { + int64_t j = atom_reorder_inv_perm[i]; + tmp3[j*3+0] = points_grad_global[i*3+0]; + tmp3[j*3+1] = points_grad_global[i*3+1]; + tmp3[j*3+2] = points_grad_global[i*3+2]; + } + points_grad_global = std::move(tmp3); + } + } + } + + // Phase 4: Send OneDFT Vxc outputs back to device + if (world_size == 1) { + double* den_grad = features_dict.at(feat_map.at(ONEDFT_FEATURE::DEN)).grad().data_ptr(); + double* dden_grad = nullptr; + double* tau_grad = nullptr; + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::DDEN)) != features_dict.end()) { + dden_grad = features_dict.at(feat_map.at(ONEDFT_FEATURE::DDEN)).grad().data_ptr(); + } + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::TAU)) != features_dict.end()) { + tau_grad = features_dict.at(feat_map.at(ONEDFT_FEATURE::TAU)).grad().data_ptr(); + } + double exc_val = 0.0; // Not needed for gradient, but API requires it + device_data_ptr->send_static_data_onedft_results( total_npts, ndm, &exc_val, + den_grad, dden_grad, tau_grad ); + } else { + total_npts = mpi_scatter_onedft_outputs(features_dict, rt.comm_rank(), rt.comm_size(), + recvcounts, displs, atom_reorder_inv_perm, + den_eval, dden_eval, tau); + double exc_val = 0.0; + device_data_ptr->send_static_data_onedft_results( total_npts, ndm, &exc_val, + den_eval.data(), dden_eval.data(), tau.data()); + } + + // Scatter eps_on_grid to local tasks + std::vector eps_on_grid_local; + if (world_size == 1) { + eps_on_grid_local = std::move(eps_on_grid_global); + } else { + GAUXC_GENERIC_EXCEPTION("OneDFT gradient with MPI not yet implemented"); + } + + // Zero out EXC_GRAD on host + for (int i = 0; i < 3*natoms; ++i) EXC_GRAD[i] = 0.0; + + // Phase 5: Pulay gradient + weight derivative on device + // Use the standard gradient flow: for each batch, load OneDFT Vxc, + // evaluate collocation (hessian), X-matrix, transform, and inc_exc_grad + + auto& lb_state = this->load_balancer_->state(); + if( not lb_state.modified_weights_are_stored ) { + GAUXC_GENERIC_EXCEPTION("Weights Have Not Been Modified"); + } + XCWeightAlg& weight_alg = lb_state.weight_alg; + + BasisSetMap basis_map(basis, mol); + const auto& meta = this->load_balancer_->molmeta(); + device_data_ptr->populate_submat_maps( nbf, tasks.begin(), tasks.end(), basis_map ); + + // Sort tasks by workload for load balancing + auto task_comparator = []( const XCTask& a, const XCTask& b ) { + return (a.points.size() * a.bfn_screening.nbe) > 
(b.points.size() * b.bfn_screening.nbe); + }; + + // First: distribute eps_on_grid to tasks (in iParent order, before sorting by workload) + // Tasks are currently in iParent order (from stable_sort above) + { + size_t offset = 0; + for (auto& task : tasks) { + int64_t npts = task.points.size(); + task.feat.eps.resize(npts); + std::copy(eps_on_grid_local.data() + offset, + eps_on_grid_local.data() + offset + npts, + task.feat.eps.begin()); + offset += npts; + } + } + + // Now sort by workload + std::sort( tasks.begin(), tasks.end(), task_comparator ); + + // Build concatenated eps_w array (eps_on_grid * weights) in task order after sorting + // This will be used for the weight derivative + std::vector eps_w_all; + eps_w_all.reserve(total_npts); + for (const auto& task : tasks) { + for (size_t ipt = 0; ipt < task.points.size(); ++ipt) { + eps_w_all.push_back(task.feat.eps[ipt] * task.weights[ipt]); + } + } + + // Set up gradient computation terms + integrator_term_tracker grad_terms; + grad_terms.exc_grad = true; + grad_terms.weights = true; + grad_terms.onedft = true; + grad_terms.ks_scheme = UKS; + if (is_mgga) grad_terms.xc_approx = integrator_xc_approx::MGGA_TAU; + else if (is_gga) grad_terms.xc_approx = integrator_xc_approx::GGA; + else grad_terms.xc_approx = integrator_xc_approx::LDA; + + // Weight derivative needs: RAB, coords, dist scratch + device_data_ptr->allocate_static_data_weights( natoms ); + device_data_ptr->send_static_data_weights( mol, meta ); + + this->timer_.time_op("XCIntegrator.LocalWork_OneDFTGrad", [&](){ + auto task_it = tasks.begin(); + size_t vxc_offset = 0; + size_t eps_offset = 0; + while( task_it != tasks.end() ) { + + auto batch_begin = task_it; + + // Generate batch buffers (allocates collocation, xmat, hessian, etc.) 
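+        // (As implied by the loop structure: generate_buffers packs tasks into
+        // the device memory pool until it is full and returns an iterator one
+        // past the last packed task, so each pass of this while-loop processes
+        // one device-sized batch.)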
+ task_it = device_data_ptr->generate_buffers( grad_terms, basis_map, + task_it, tasks.end() ); + + // Load OneDFT Vxc for this batch + vxc_offset = send_buffer_onedft_outputs( device_data_ptr.get(), grad_terms, vxc_offset ); + + // Evaluate collocation (hessian for GGA/MGGA, gradient for LDA) + if (is_gga || is_mgga) lwd->eval_collocation_hessian( device_data_ptr.get() ); + else lwd->eval_collocation_gradient( device_data_ptr.get() ); + + // Evaluate X-matrices and save for gradient + const double xmat_fac = 1.0; // UKS: factor 1 + const bool need_xmat_grad = is_gga || is_mgga; + auto do_xmat = [&](density_id den_id) { + lwd->eval_xmat( xmat_fac, device_data_ptr.get(), need_xmat_grad, den_id ); + lwd->save_xmat( device_data_ptr.get(), need_xmat_grad, den_id ); + }; + do_xmat(DEN_S); + do_xmat(DEN_Z); + + // For GGA/MGGA: transform OneDFT per-direction Vxc to standard format + if (is_gga || is_mgga) { + lwd->transform_onedft_vxc_for_grad( device_data_ptr.get() ); + } + + // Increment EXC gradient using standard kernels + const bool with_weight_derivatives = true; + if (is_mgga) lwd->inc_exc_grad_mgga( device_data_ptr.get(), UKS, false, with_weight_derivatives ); + else if (is_gga) lwd->inc_exc_grad_gga( device_data_ptr.get(), UKS, with_weight_derivatives ); + else lwd->inc_exc_grad_lda( device_data_ptr.get(), UKS, with_weight_derivatives ); + + // Weight derivative: load eps*w values for this batch into eps_eval, + // set den_s=1.0 and den_z=0.0 so kernel computes eps*(1.0) = eps*w + { + size_t batch_npts = 0; + for (auto it = batch_begin; it != task_it; ++it) + batch_npts += it->points.size(); + + auto* data = dynamic_cast(device_data_ptr.get()); + auto* backend = dynamic_cast(data->device_backend_); + auto base_stack = data->base_stack; + + // Copy eps*w for this batch to eps_eval_device + backend->copy_async( batch_npts, eps_w_all.data() + eps_offset, + base_stack.eps_eval_device, "Copy OneDFT eps*w" ); + // Set den_s = 1.0 and den_z = 0.0 so kernel's (den_s += den_z; eps *= den_s) is identity + std::vector ones(batch_npts, 1.0); + backend->copy_async( batch_npts, ones.data(), + base_stack.den_s_eval_device, "Set den_s=1 for weight deriv" ); + if (base_stack.den_z_eval_device) { + backend->set_zero( batch_npts, base_stack.den_z_eval_device, "Zero den_z for weight deriv" ); + } + backend->master_queue_synchronize(); + + eps_offset += batch_npts; + } + + lwd->eval_weight_1st_deriv_contracted( device_data_ptr.get(), weight_alg ); + + } // batch loop + }); + + rt.device_backend()->master_queue_synchronize(); + + // Retrieve gradient from device + double N_EL; + device_data_ptr->retrieve_exc_grad_integrands( EXC_GRAD, &N_EL ); + rt.device_backend()->master_queue_synchronize(); + + // Phase 6: Add autograd forces (points -> parent atoms, coords -> direct) + if (!points_grad_global.empty() && world_rank == 0) { + size_t pg_offset = 0; + // Iterate in iParent order — re-sort tasks + std::stable_sort(tasks.begin(), tasks.end(), + [](const auto& a, const auto& b) { return a.iParent < b.iParent; }); + for (const auto& task : tasks) { + int iParent = task.iParent; + for (size_t ipt = 0; ipt < task.points.size(); ++ipt) { + EXC_GRAD[3*iParent + 0] += points_grad_global[(pg_offset + ipt)*3 + 0]; + EXC_GRAD[3*iParent + 1] += points_grad_global[(pg_offset + ipt)*3 + 1]; + EXC_GRAD[3*iParent + 2] += points_grad_global[(pg_offset + ipt)*3 + 2]; + } + pg_offset += task.points.size(); + } + } + + if (!coords_grad_global.empty() && world_rank == 0) { + for (int a = 0; a < natoms; ++a) { + EXC_GRAD[3*a + 0] 
+= coords_grad_global[3*a + 0]; + EXC_GRAD[3*a + 1] += coords_grad_global[3*a + 1]; + EXC_GRAD[3*a + 2] += coords_grad_global[3*a + 2]; + } + } + + // Phase 7: Allreduce + this->timer_.time_op("XCIntegrator.Allreduce", [&](){ + if( not this->reduction_driver_->takes_host_memory() ) + GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions"); + this->reduction_driver_->allreduce_inplace( EXC_GRAD, 3*natoms, ReductionOp::Sum ); + }); +} + } // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp index ab87877be..4e7b85caf 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp @@ -79,12 +79,18 @@ class ReferenceReplicatedXCHostIntegrator : void eval_exc_vxc_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, const value_type* Pz, int64_t ldpz, value_type* VXCs, int64_t ldvxcs, value_type* VXCz, int64_t ldvxcz, value_type* EXC, const IntegratorSettingsXC& ks_settings ) override; + void eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) override; #else void eval_exc_vxc_onedft_( int64_t, int64_t, const value_type*, int64_t, const value_type*, int64_t, value_type*, int64_t, value_type*, int64_t, value_type*, const IntegratorSettingsXC& ) override { throw std::runtime_error("OneDFT support not compiled"); } + void eval_exc_grad_onedft_( int64_t, int64_t, const value_type*, int64_t, + const value_type*, int64_t, value_type*, const IntegratorSettingsXC& ) override { + throw std::runtime_error("OneDFT support not compiled"); + } #endif /// RKS EXC Gradient @@ -172,6 +178,11 @@ class ReferenceReplicatedXCHostIntegrator : value_type* VXCs, int64_t ldvxcs, value_type* VXCz, int64_t ldvxcz, const bool is_gga, const bool is_mgga, const bool needs_laplacian); + + void exc_grad_local_work_onedft_( const value_type* Ps, int64_t ldps, const value_type* Pz, int64_t ldpz, + value_type* EXC_GRAD, + const std::vector& eps_on_grid, + const bool is_gga, const bool is_mgga); #endif diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp index 498f852d6..21e4a359c 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_onedft.hpp @@ -21,10 +21,11 @@ namespace detail { FeatureDict prepare_onedft_features(const int ndm, std::vector& tasks, const Molecule& mol, const std::vector feature_keys, const RuntimeEnvironment& rt, std::vector& sendcounts, - std::vector& displs); + std::vector& displs, std::vector& atom_reorder_inv_perm); void send_buffer_onedft_outputs(const int ndm, const FeatureDict features_dict, std::vector& tasks, - const RuntimeEnvironment& rt, std::vector sendcounts, std::vector displs); + const RuntimeEnvironment& rt, std::vector sendcounts, std::vector displs, + const std::vector& atom_reorder_inv_perm); void interleave_data(const double* a, const double* b, const size_t n, double* out); @@ -81,11 +82,7 @@ void ReferenceReplicatedXCHostIntegrator:: } // Get Tasks auto& tasks = this->load_balancer_->get_tasks(); -#ifdef GAUXC_HAS_DEVICE - auto 
rt = detail::as_device_runtime(this->load_balancer_->runtime()); -#else auto rt = this->load_balancer_->runtime(); -#endif int32_t world_rank = rt.comm_rank(); // Temporary electron count to judge integrator accuracy value_type N_EL; @@ -118,8 +115,9 @@ void ReferenceReplicatedXCHostIntegrator:: }); std::vector sendcounts(rt.comm_size(), 0); std::vector displs(rt.comm_size(), 0); + std::vector atom_reorder_inv_perm; FeatureDict features_dict = prepare_onedft_features(2/*ndm*/, tasks, this->load_balancer_->molecule(), feature_keys, rt, - sendcounts, displs); + sendcounts, displs, atom_reorder_inv_perm); if (world_rank == 0) { auto exc_on_grid = get_exc(exc_func, features_dict); // check is_nan @@ -134,7 +132,7 @@ void ReferenceReplicatedXCHostIntegrator:: } // TODO: stop here if only exc - send_buffer_onedft_outputs(2/*ndm*/, features_dict, tasks, rt, sendcounts, displs); + send_buffer_onedft_outputs(2/*ndm*/, features_dict, tasks, rt, sendcounts, displs, atom_reorder_inv_perm); this->timer_.time_op("XCIntegrator.LocalWork2", [&](){ post_onedft_local_work_( basis, Ps, ldps, Pz, ldpz, VXCs, n, VXCz, n, is_gga, is_mgga, false /*needs_laplacian*/); @@ -622,7 +620,6 @@ void eval_zmat_mgga_vxc_uks(size_t npts, size_t nbf, auto* bf_x_col = dbasis_x_eval + ioff; auto* bf_y_col = dbasis_y_eval + ioff; auto* bf_z_col = dbasis_z_eval + ioff; - auto* lbf_col = lbasis_eval + ioff; const double factp = 0.5 * vdden_eval_a[i]; const double factm = 0.5 * vdden_eval_b[i]; @@ -646,6 +643,7 @@ void eval_zmat_mgga_vxc_uks(size_t npts, size_t nbf, GauXC::blas::axpy( nbf, z_factm, bf_z_col, 1, zz_col, 1 ); if (vlapl_a != nullptr) { + auto* lbf_col = lbasis_eval + ioff; const auto lfactp = vlapl_a[i]; const auto lfactm = vlapl_b[i]; blas::axpy( nbf, 0.5*(lfactp + lfactm), lbf_col, 1, zs_col, 1); @@ -672,12 +670,18 @@ void interleave_data(const double* a, const double* b, const size_t n, double* r FeatureDict prepare_onedft_features(const int ndm, std::vector& tasks, const Molecule& mol, const std::vector feature_keys, const RuntimeEnvironment& rt, - std::vector& sendcounts, std::vector& displs) { - std::vector den_eval, dden_eval, tau, grid_coords, grid_weights; + std::vector& sendcounts, std::vector& displs, + std::vector& atom_reorder_inv_perm) { + std::vector den_eval, dden_eval, tau, grid_coords, grid_weights, raw_grid_weights; + // Sort tasks by atom index so that grid points are grouped by atom. + // build_atom_reorder_perm assumes this contiguous-by-atom layout. 
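+  // Example (hypothetical sizes): with 2 ranks and 2 atoms, rank 0 holding
+  // {a0: 2 pts, a1: 1 pt} and rank 1 holding {a0: 1 pt, a1: 1 pt}, the
+  // gathered rank-ordered stream [r0a0 r0a0 r0a1 | r1a0 r1a1] becomes the
+  // atom-ordered stream [r0a0 r0a0 r1a0 | r0a1 r1a1] under the forward perm.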
+ std::stable_sort(tasks.begin(), tasks.end(), + [](const auto& a, const auto& b) { return a.iParent < b.iParent; }); int total_npts = std::accumulate( tasks.begin(), tasks.end(), 0, [](const auto& a, const auto& b) { return a + b.npts; } ); grid_coords.reserve(total_npts * 3); grid_weights.reserve(total_npts); + raw_grid_weights.reserve(total_npts); den_eval.reserve(total_npts * ndm); dden_eval.resize(total_npts * 6); // 2 values per point, 3 components tau.reserve(total_npts * ndm); @@ -690,6 +694,7 @@ FeatureDict prepare_onedft_features(const int ndm, std::vector& tasks, c grid_coords.push_back(point[2]); } std::copy(task.weights.begin(), task.weights.end(), std::back_inserter(grid_weights)); + std::copy(task.raw_weights.begin(), task.raw_weights.end(), std::back_inserter(raw_grid_weights)); std::copy(task.feat.den_eval.begin(), task.feat.den_eval.end(), std::back_inserter(den_eval)); if (task.feat.dden_x_eval.size() != 0){ @@ -708,15 +713,49 @@ FeatureDict prepare_onedft_features(const int ndm, std::vector& tasks, c offset += task.points.size(); std::copy(task.feat.tau.begin(), task.feat.tau.end(), std::back_inserter(tau)); } - - int world_rank = rt.comm_rank(); + + // Compute per-atom grid sizes from local tasks + int natoms = mol.size(); + std::vector atomic_grid_sizes_vec(natoms, 0); + for (const auto& task : tasks) { + if (task.iParent >= 0 && task.iParent < natoms) { + atomic_grid_sizes_vec[task.iParent] += task.npts; + } + } + + // MPI gather all data to rank 0 and reorder from rank-order to atom-order + int world_rank = rt.comm_rank(); + int local_npts = total_npts; // save before gather overwrites + auto reorder_result = mpi_gather_and_reorder( + den_eval, dden_eval, tau, grid_coords, grid_weights, + atomic_grid_sizes_vec, total_npts, natoms, rt, sendcounts, displs); + total_npts = reorder_result.total_npts; + atom_reorder_inv_perm = std::move(reorder_result.inv_perm); + auto& global_atomic_grid_sizes_vec = reorder_result.global_atomic_grid_sizes; + + // Gather and reorder raw_grid_weights using the same MPI layout GAUXC_MPI_CODE( - total_npts = mpi_gather_onedft_inputs(den_eval, dden_eval, tau, grid_coords, grid_weights, total_npts, - world_rank, rt.comm_size(), sendcounts, displs); - ); + if (rt.comm_size() > 1) { + std::vector recv_raw(world_rank == 0 ? 
total_npts : 0); + MPI_Gatherv(raw_grid_weights.data(), local_npts, MPI_DOUBLE, + recv_raw.data(), sendcounts.data(), displs.data(), + MPI_DOUBLE, 0, rt.comm()); + if (world_rank == 0) { + raw_grid_weights = std::move(recv_raw); + // Reconstruct forward perm from inv_perm and reorder + std::vector perm(total_npts); + for (int64_t j = 0; j < total_npts; j++) perm[atom_reorder_inv_perm[j]] = j; + std::vector tmp(total_npts); + for (int64_t i = 0; i < total_npts; i++) tmp[perm[i]] = raw_grid_weights[i]; + raw_grid_weights = std::move(tmp); + } + } + ) + FeatureDict featmap; if (world_rank == 0) { - int natoms = mol.size(); + int64_t max_grid_size = *std::max_element( + global_atomic_grid_sizes_vec.begin(), global_atomic_grid_sizes_vec.end()); std::vector coarse_0_atomic_coords (natoms*3); for (int i = 0; i < natoms; i++) { coarse_0_atomic_coords[3*i] = mol[i].x; @@ -759,6 +798,21 @@ FeatureDict prepare_onedft_features(const int ndm, std::vector& tasks, c tensor = flat_tensor.clone(); break; } + case ONEDFT_FEATURE::ATOMIC_GRID_WEIGHTS: { + auto flat_tensor = torch::from_blob(raw_grid_weights.data(), {total_npts}, options); + tensor = flat_tensor.clone(); + break; + } + case ONEDFT_FEATURE::ATOMIC_GRID_SIZES: { + auto sizes_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU); + tensor = torch::from_blob(global_atomic_grid_sizes_vec.data(), {natoms}, sizes_options).clone(); + break; + } + case ONEDFT_FEATURE::ATOMIC_GRID_SIZE_BOUND_SHAPE: { + auto sizes_options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU); + tensor = torch::zeros({max_grid_size, 0}, sizes_options); + break; + } default: GAUXC_GENERIC_EXCEPTION("Feature Key Not Implemented: " + key); } @@ -772,11 +826,13 @@ FeatureDict prepare_onedft_features(const int ndm, std::vector& tasks, c } void send_buffer_onedft_outputs(const int ndm, const FeatureDict features_dict, std::vector& tasks, - const RuntimeEnvironment& rt, std::vector sendcounts, std::vector displs) { + const RuntimeEnvironment& rt, std::vector sendcounts, std::vector displs, + const std::vector& atom_reorder_inv_perm) { std::vector den_eval, dden_eval, tau; auto total_npts = mpi_scatter_onedft_outputs(features_dict, rt.comm_rank(), rt.comm_size(), - sendcounts, displs, den_eval, dden_eval, tau); + sendcounts, displs, atom_reorder_inv_perm, + den_eval, dden_eval, tau); size_t offset = 0; for (auto&task : tasks) { @@ -837,5 +893,571 @@ void send_buffer_onedft_outputs(const int ndm, const FeatureDict features_dict, // } +// ============================================================================ +// OneDFT EXC Gradient +// ============================================================================ + +template +void ReferenceReplicatedXCHostIntegrator:: + eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + value_type* EXC_GRAD, const IntegratorSettingsXC& settings ) { + + const auto& basis = this->load_balancer_->basis(); + const int64_t nbf = basis.nbf(); + if( m != n ) GAUXC_GENERIC_EXCEPTION("P Must Be Square"); + if( m != nbf ) GAUXC_GENERIC_EXCEPTION("P Must Have Same Dimension as Basis"); + if( ldps < nbf) GAUXC_GENERIC_EXCEPTION("Invalid LDPS"); + if( ldpz && ldpz < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDPZ"); + + const bool is_uks = (Pz != nullptr); + if (not is_uks) { + GAUXC_GENERIC_EXCEPTION("RKS OneDFT gradient Not Yet Implemented"); + } + + // Get Tasks + auto& tasks = this->load_balancer_->get_tasks(); + auto rt = this->load_balancer_->runtime(); 
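+  // Assembled below, the OneDFT gradient roughly decomposes per atom A as
+  //   ∂Exc/∂R_A ≈ Σ_g (∂w_g/∂R_A) ε_g              (grid-weight derivative)
+  //             + Pulay terms from basis functions centered on A
+  //             + autograd contributions via the POINTS and COORDS features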
+ int32_t world_rank = rt.comm_rank(); + + // Load model + OneDFTSettings onedft_settings; + if( auto* tmp = dynamic_cast(&settings) ) { + onedft_settings = *tmp; + } + const auto model_path = onedft_settings.model; + torch::DeviceType device = torch::kCPU; + auto [exc_func, feature_keys] = load_model(model_path, device); + + // Determine feature requirements + bool is_gga = false; + bool is_mgga = false; + for (const auto& key : feature_keys) { + if ( not valueExists(key) ) GAUXC_GENERIC_EXCEPTION("Feature Key Required Not Implemented: " + key); + if (key == feat_map.at(ONEDFT_FEATURE::TAU)) is_mgga = true; + if (key == feat_map.at(ONEDFT_FEATURE::DDEN)) is_gga = true; + } + if (is_mgga) is_gga = false; + + value_type N_EL; + + // Step 1: Pre-work (basis eval, density computation) + this->timer_.time_op("XCIntegrator.LocalWork", [&](){ + pre_onedft_local_work_( basis, Ps, ldps, Pz, ldpz, &N_EL, is_gga, is_mgga, false); + }); + + // Step 2: Gather features and build torch tensors + std::vector sendcounts(rt.comm_size(), 0); + std::vector displs(rt.comm_size(), 0); + std::vector atom_reorder_inv_perm; + FeatureDict features_dict = prepare_onedft_features(2/*ndm*/, tasks, + this->load_balancer_->molecule(), feature_keys, rt, + sendcounts, displs, atom_reorder_inv_perm); + + // Step 3: Forward + backward with grad on points and coords + std::vector eps_on_grid_global; // exc_on_grid values for weight derivative + std::vector points_grad_global; // [total_npts * 3] + std::vector coords_grad_global; // [natoms * 3] + const int natoms = this->load_balancer_->molecule().natoms(); + + if (world_rank == 0) { + // Enable requires_grad on points and coords tensors + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::POINTS)) != features_dict.end()) { + features_dict.at(feat_map.at(ONEDFT_FEATURE::POINTS)).requires_grad_(true); + } + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::COORDS)) != features_dict.end()) { + features_dict.at(feat_map.at(ONEDFT_FEATURE::COORDS)).requires_grad_(true); + } + + auto exc_on_grid = get_exc(exc_func, features_dict); + if (exc_on_grid.isnan().any().item()) { + GAUXC_GENERIC_EXCEPTION("exc_on_grid has NaN"); + } + auto exc = (exc_on_grid * features_dict.at(feat_map.at(ONEDFT_FEATURE::WEIGHTS))).sum(); + exc.backward(); + + // Extract eps_on_grid for weight derivative term + int total_npts = exc_on_grid.size(0); + at::Tensor eps_cpu = exc_on_grid.detach().cpu().contiguous(); + eps_on_grid_global.resize(total_npts); + std::memcpy(eps_on_grid_global.data(), eps_cpu.data_ptr(), total_npts * sizeof(double)); + + // Extract points.grad() -> per-grid-point forces + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::POINTS)) != features_dict.end()) { + auto pg = features_dict.at(feat_map.at(ONEDFT_FEATURE::POINTS)).grad(); + if (pg.defined()) { + at::Tensor pg_cpu = pg.cpu().contiguous(); + points_grad_global.resize(total_npts * 3); + std::memcpy(points_grad_global.data(), pg_cpu.data_ptr(), total_npts * 3 * sizeof(double)); + } + } + + // Extract coords.grad() -> per-atom forces + if (features_dict.find(feat_map.at(ONEDFT_FEATURE::COORDS)) != features_dict.end()) { + auto cg = features_dict.at(feat_map.at(ONEDFT_FEATURE::COORDS)).grad(); + if (cg.defined()) { + at::Tensor cg_cpu = cg.cpu().contiguous(); + coords_grad_global.resize(natoms * 3); + std::memcpy(coords_grad_global.data(), cg_cpu.data_ptr(), natoms * 3 * sizeof(double)); + } + } + + // Reorder eps_on_grid from atom-order back to rank-order for scatter + if (!atom_reorder_inv_perm.empty()) { + std::vector 
tmp(total_npts);
+      for (int64_t i = 0; i < total_npts; i++) {
+        tmp[atom_reorder_inv_perm[i]] = eps_on_grid_global[i];
+      }
+      eps_on_grid_global = std::move(tmp);
+
+      // Also reorder points_grad
+      if (!points_grad_global.empty()) {
+        std::vector<double> tmp3(total_npts * 3);
+        for (int64_t i = 0; i < total_npts; i++) {
+          int64_t j = atom_reorder_inv_perm[i];
+          tmp3[j*3+0] = points_grad_global[i*3+0];
+          tmp3[j*3+1] = points_grad_global[i*3+1];
+          tmp3[j*3+2] = points_grad_global[i*3+2];
+        }
+        points_grad_global = std::move(tmp3);
+      }
+    }
+  }
+
+  // Step 4: Scatter Vxc back to tasks
+  send_buffer_onedft_outputs(2/*ndm*/, features_dict, tasks, rt, sendcounts, displs, atom_reorder_inv_perm);
+
+  // Scatter eps_on_grid to local tasks. For a single rank this is just a move;
+  // the MPI path would need an MPI_Scatterv and is not yet wired up.
+  std::vector<double> eps_on_grid_local;
+  if (rt.comm_size() == 1) {
+    eps_on_grid_local = std::move(eps_on_grid_global);
+  } else {
+    // TODO: MPI scatter of eps_on_grid
+    GAUXC_GENERIC_EXCEPTION("OneDFT gradient with MPI not yet implemented");
+  }
+
+  // Zero out EXC_GRAD
+  for (int i = 0; i < 3*natoms; ++i) EXC_GRAD[i] = 0.0;
+
+  // Step 5: Add autograd forces BEFORE the Pulay term (which re-sorts tasks!)
+  // points.grad gives ∂E/∂r_g. Since grid points move rigidly with their
+  // parent atom, the force on atom A is Σ_{g∈A} points_grad[g].
+  // NOTE: Must be done while tasks are still in iParent-sorted order
+  // (matching the points_grad_global layout). exc_grad_local_work_onedft_
+  // re-sorts tasks by workload, breaking the correspondence.
+  if (!points_grad_global.empty() && world_rank == 0) {
+    size_t offset = 0;
+    for (const auto& task : tasks) {
+      int iParent = task.iParent;
+      for (size_t ipt = 0; ipt < task.points.size(); ++ipt) {
+        EXC_GRAD[3*iParent + 0] += points_grad_global[(offset + ipt)*3 + 0];
+        EXC_GRAD[3*iParent + 1] += points_grad_global[(offset + ipt)*3 + 1];
+        EXC_GRAD[3*iParent + 2] += points_grad_global[(offset + ipt)*3 + 2];
+      }
+      offset += task.points.size();
+    }
+  }
+
+  // coords.grad gives ∂E/∂R_A directly (no task-order dependence)
+  if (!coords_grad_global.empty() && world_rank == 0) {
+    for (int a = 0; a < natoms; ++a) {
+      EXC_GRAD[3*a + 0] += coords_grad_global[3*a + 0];
+      EXC_GRAD[3*a + 1] += coords_grad_global[3*a + 1];
+      EXC_GRAD[3*a + 2] += coords_grad_global[3*a + 2];
+    }
+  }
+
+  // Step 6: Pulay + weight derivative term (re-sorts tasks internally!)
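+  // The Pulay term arises because basis functions ride on their atoms: moving
+  // atom A changes ρ(r_g) through every basis function centered on A, so
+  // ∂E/∂R_A picks up Vxc-weighted collocation-gradient contractions in
+  // addition to the explicit grid terms above.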
+  this->timer_.time_op("XCIntegrator.LocalWork2", [&](){
+    exc_grad_local_work_onedft_( Ps, ldps, Pz, ldpz, EXC_GRAD, eps_on_grid_local, is_gga, is_mgga);
+  });
+
+  // Step 7: Allreduce
+  this->timer_.time_op("XCIntegrator.Allreduce", [&](){
+    if( not this->reduction_driver_->takes_host_memory() )
+      GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions");
+    this->reduction_driver_->allreduce_inplace( EXC_GRAD, 3*natoms, ReductionOp::Sum );
+  });
+}
+
+
+// Pulay + weight derivative local work using OneDFT Vxc format
+template <typename ValueType>
+void ReferenceReplicatedXCHostIntegrator<ValueType>::
+  exc_grad_local_work_onedft_( const value_type* Ps, int64_t ldps,
+                               const value_type* Pz, int64_t ldpz,
+                               value_type* EXC_GRAD,
+                               const std::vector<double>& eps_on_grid,
+                               const bool is_gga, const bool is_mgga) {
+
+  const bool is_uks = Pz != nullptr;
+
+  auto* lwd = dynamic_cast<LocalHostWorkDriver*>(this->local_work_driver_.get());
+
+  const auto& basis   = this->load_balancer_->basis();
+  const auto& mol     = this->load_balancer_->molecule();
+  const auto& molmeta = this->load_balancer_->molmeta();
+
+  // Weight derivative settings
+  auto& lb_state = this->load_balancer_->state();
+  if( not lb_state.modified_weights_are_stored ) {
+    GAUXC_GENERIC_EXCEPTION("Weights Have Not Been Modified");
+  }
+  XCWeightAlg& weight_alg = lb_state.weight_alg;
+
+  BasisSetMap basis_map(basis, mol);
+  const int32_t nbf    = basis.nbf();
+  const int32_t natoms = mol.natoms();
+
+  auto& tasks = this->load_balancer_->get_tasks();
+  const size_t ntasks = tasks.size();
+
+  // Workload comparator for the Pulay loop. It is applied only after
+  // eps_on_grid has been distributed to per-task storage below.
+  auto task_comparator = []( const XCTask& a, const XCTask& b ) {
+    return (a.points.size() * a.bfn_screening.nbe) > (b.points.size() * b.bfn_screening.nbe);
+  };
+
+  // eps_on_grid arrives in iParent-sorted task order (the ordering used by
+  // send_buffer_onedft_outputs), so it must be stashed in per-task storage
+  // before any re-sort by workload breaks that correspondence.
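+  // Example (hypothetical sizes): tasks in iParent order with npts = {4, 2, 3}
+  // consume eps_on_grid[0..3], [4..5], and [6..8] respectively; after the
+  // workload sort each task still carries its own slice in task.feat.eps.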
+ + // Build task -> eps mapping from the eps_on_grid vector (in iParent-sorted order) + // First, re-sort back to iParent order to match eps_on_grid + std::stable_sort( tasks.begin(), tasks.end(), + [](const auto& a, const auto& b) { return a.iParent < b.iParent; }); + + // Distribute eps_on_grid to per-task storage + { + size_t offset = 0; + for (auto& task : tasks) { + int64_t npts = task.points.size(); + task.feat.eps.resize(npts); + std::copy(eps_on_grid.data() + offset, + eps_on_grid.data() + offset + npts, + task.feat.eps.begin()); + offset += npts; + } + } + + // Now sort by workload for the Pulay loop + std::sort( tasks.begin(), tasks.end(), task_comparator ); + + #pragma omp parallel + { + + XCHostData host_data; + + #pragma omp for schedule(dynamic) + for( size_t iT = 0; iT < ntasks; ++iT ) { + + auto& task = tasks[iT]; + const int32_t npts = task.points.size(); + const int32_t nbe = task.bfn_screening.nbe; + const int32_t nshells = task.bfn_screening.shell_list.size(); + + const auto* points = task.points.data()->data(); + const auto* weights = task.weights.data(); + const int32_t* shell_list = task.bfn_screening.shell_list.data(); + + // Allocate memory for basis evaluation (up to hessian for GGA/MGGA) + if (is_gga || is_mgga) { + host_data.basis_eval.resize(10 * npts * nbe); // B, dB_xyz, d2B_6 + host_data.zmat.resize(4 * 2 * npts * nbe); // xN, xN_xyz, xZ, xZ_xyz + } else { + host_data.basis_eval.resize(4 * npts * nbe); // B, dB_xyz + host_data.zmat.resize(2 * npts * nbe); // xN, xZ + } + host_data.nbe_scr.resize(nbe * nbe); + host_data.eps.resize(npts); + + auto* basis_eval = host_data.basis_eval.data(); + auto* nbe_scr = host_data.nbe_scr.data(); + auto* eps_buf = host_data.eps.data(); + + auto* dbasis_x_eval = basis_eval + npts * nbe; + auto* dbasis_y_eval = dbasis_x_eval + npts * nbe; + auto* dbasis_z_eval = dbasis_y_eval + npts * nbe; + + value_type* d2basis_xx_eval = nullptr; + value_type* d2basis_xy_eval = nullptr; + value_type* d2basis_xz_eval = nullptr; + value_type* d2basis_yy_eval = nullptr; + value_type* d2basis_yz_eval = nullptr; + value_type* d2basis_zz_eval = nullptr; + + if (is_gga || is_mgga) { + d2basis_xx_eval = dbasis_z_eval + npts * nbe; + d2basis_xy_eval = d2basis_xx_eval + npts * nbe; + d2basis_xz_eval = d2basis_xy_eval + npts * nbe; + d2basis_yy_eval = d2basis_xz_eval + npts * nbe; + d2basis_yz_eval = d2basis_yy_eval + npts * nbe; + d2basis_zz_eval = d2basis_yz_eval + npts * nbe; + } + + // X-matrix pointers + auto* xNmat = host_data.zmat.data(); + value_type* xNmat_x = nullptr; + value_type* xNmat_y = nullptr; + value_type* xNmat_z = nullptr; + value_type* xZmat = nullptr; + value_type* xZmat_x = nullptr; + value_type* xZmat_y = nullptr; + value_type* xZmat_z = nullptr; + + if (is_gga || is_mgga) { + xNmat_x = xNmat + npts*nbe; + xNmat_y = xNmat_x + npts*nbe; + xNmat_z = xNmat_y + npts*nbe; + xZmat = xNmat_z + npts*nbe; + xZmat_x = xZmat + npts*nbe; + xZmat_y = xZmat_x + npts*nbe; + xZmat_z = xZmat_y + npts*nbe; + } else { + xZmat = xNmat + npts*nbe; + } + + // Get submat map + auto [submat_map, foo] = + gen_compressed_submat_map( basis_map, task.bfn_screening.shell_list, nbf, nbf ); + + // Evaluate collocation (gradient + hessian for GGA/MGGA) + if (is_gga || is_mgga) { + lwd->eval_collocation_hessian( npts, nshells, nbe, points, basis, shell_list, + basis_eval, dbasis_x_eval, dbasis_y_eval, dbasis_z_eval, + d2basis_xx_eval, d2basis_xy_eval, d2basis_xz_eval, + d2basis_yy_eval, d2basis_yz_eval, d2basis_zz_eval ); + } else { + 
lwd->eval_collocation_gradient( npts, nshells, nbe, points, basis, shell_list, + basis_eval, dbasis_x_eval, dbasis_y_eval, dbasis_z_eval ); + } + + // Evaluate X-matrices: xN = Ps * B, xZ = Pz * B + const int xmat_len = (is_gga || is_mgga) ? 4 : 1; + lwd->eval_xmat( xmat_len*npts, nbf, nbe, submat_map, 1.0, Ps, ldps, basis_eval, nbe, + xNmat, nbe, nbe_scr ); + if (is_uks) { + lwd->eval_xmat( xmat_len*npts, nbf, nbe, submat_map, 1.0, Pz, ldpz, basis_eval, nbe, + xZmat, nbe, nbe_scr ); + } + + // Read OneDFT Vxc from task.feat (already includes grid weights from autograd) + const value_type* vdden_a = task.feat.vdden_eval_a.data(); + const value_type* vdden_b = task.feat.vdden_eval_b.data(); + const value_type* vdden_x_a = nullptr; + const value_type* vdden_y_a = nullptr; + const value_type* vdden_z_a = nullptr; + const value_type* vdden_x_b = nullptr; + const value_type* vdden_y_b = nullptr; + const value_type* vdden_z_b = nullptr; + const value_type* vtau_data = nullptr; + + if (is_gga || is_mgga) { + vdden_x_a = task.feat.vdden_x_eval_a.data(); + vdden_y_a = task.feat.vdden_y_eval_a.data(); + vdden_z_a = task.feat.vdden_z_eval_a.data(); + vdden_x_b = task.feat.vdden_x_eval_b.data(); + vdden_y_b = task.feat.vdden_y_eval_b.data(); + vdden_z_b = task.feat.vdden_z_eval_b.data(); + } + if (is_mgga) { + vtau_data = task.feat.vtau.data(); + } + + // --- Weight derivative term --- + // eps_contracted[ipt] = exc_on_grid[ipt] * w[ipt] + for (int ipt = 0; ipt < npts; ++ipt) { + eps_buf[ipt] = task.feat.eps[ipt] * weights[ipt]; + } + lwd->eval_weight_1st_deriv_contracted( weight_alg, mol, molmeta, + task, eps_buf, EXC_GRAD); + + // --- Pulay gradient loop --- + // Using OneDFT's native per-component Vxc (weights already included) + size_t bf_off = 0; + for (auto ish = 0; ish < nshells; ++ish) { + const int sh_idx = shell_list[ish]; + const int sh_sz = basis[sh_idx].size(); + const int iAt = basis_map.shell_to_center( sh_idx ); + + // Skip basis functions on the parent atom (handled by weight derivative) + if (iAt == task.iParent) { + bf_off += sh_sz; + continue; + } + + double g_acc_x(0), g_acc_y(0), g_acc_z(0); + + for (int ibf = 0, mu = bf_off; ibf < sh_sz; ++ibf, ++mu) + for (int ipt = 0; ipt < npts; ++ipt) { + + const int32_t mu_i = mu + ipt*nbe; + + // OneDFT Vxc: vdden_a = w * ∂ε/∂ρ_α, vdden_b = w * ∂ε/∂ρ_β + // vrho_s = vdden_a + vdden_b (total density derivative, weighted) + // vrho_z = vdden_a - vdden_b (magnetization derivative, weighted) + const double vrho_s = vdden_a[ipt] + vdden_b[ipt]; + const double vrho_z = vdden_a[ipt] - vdden_b[ipt]; + + const double xN = xNmat[mu_i]; + const double xZ = is_uks ? 
xZmat[mu_i] : 0.0; + + const double dbx = dbasis_x_eval[mu_i]; + const double dby = dbasis_y_eval[mu_i]; + const double dbz = dbasis_z_eval[mu_i]; + + // LDA contribution (no separate weight multiplication — already in Vxc) + g_acc_x += 0.5 * vrho_s * xN * dbx; + g_acc_y += 0.5 * vrho_s * xN * dby; + g_acc_z += 0.5 * vrho_s * xN * dbz; + + if (is_uks) { + g_acc_x += 0.5 * vrho_z * xZ * dbx; + g_acc_y += 0.5 * vrho_z * xZ * dby; + g_acc_z += 0.5 * vrho_z * xZ * dbz; + } + + if (is_gga || is_mgga) { + // GGA contribution using OneDFT per-component derivatives + // vdden_d_a = w * ∂ε/∂(∂ρ_α/∂d) + // Force = Σ_d vdden_d_s * (d2B_{cd} * xN + dBc * xN_d) + z terms + const double vds_x = 0.5 * (vdden_x_a[ipt] + vdden_x_b[ipt]); + const double vds_y = 0.5 * (vdden_y_a[ipt] + vdden_y_b[ipt]); + const double vds_z = 0.5 * (vdden_z_a[ipt] + vdden_z_b[ipt]); + const double vdz_x = 0.5 * (vdden_x_a[ipt] - vdden_x_b[ipt]); + const double vdz_y = 0.5 * (vdden_y_a[ipt] - vdden_y_b[ipt]); + const double vdz_z = 0.5 * (vdden_z_a[ipt] - vdden_z_b[ipt]); + + const double xNx = xNmat_x[mu_i]; + const double xNy = xNmat_y[mu_i]; + const double xNz = xNmat_z[mu_i]; + const double xZx = is_uks ? xZmat_x[mu_i] : 0.0; + const double xZy = is_uks ? xZmat_y[mu_i] : 0.0; + const double xZz = is_uks ? xZmat_z[mu_i] : 0.0; + + const double d2bxx = d2basis_xx_eval[mu_i]; + const double d2bxy = d2basis_xy_eval[mu_i]; + const double d2bxz = d2basis_xz_eval[mu_i]; + const double d2byy = d2basis_yy_eval[mu_i]; + const double d2byz = d2basis_yz_eval[mu_i]; + const double d2bzz = d2basis_zz_eval[mu_i]; + + // s (total) contribution: Σ_d vds_d * (d2B_{c,d} * xN + dBc * xN_d) + // x-component of force: + const double d2_xN_x = d2bxx * xN + dbx * xNx; + const double d2_xN_y = d2bxy * xN + dbx * xNy; + const double d2_xN_z = d2bxz * xN + dbx * xNz; + g_acc_x += vds_x * d2_xN_x + vds_y * d2_xN_y + vds_z * d2_xN_z; + + // y-component: + const double d2_yN_x = d2bxy * xN + dby * xNx; + const double d2_yN_y = d2byy * xN + dby * xNy; + const double d2_yN_z = d2byz * xN + dby * xNz; + g_acc_y += vds_x * d2_yN_x + vds_y * d2_yN_y + vds_z * d2_yN_z; + + // z-component: + const double d2_zN_x = d2bxz * xN + dbz * xNx; + const double d2_zN_y = d2byz * xN + dbz * xNy; + const double d2_zN_z = d2bzz * xN + dbz * xNz; + g_acc_z += vds_x * d2_zN_x + vds_y * d2_zN_y + vds_z * d2_zN_z; + + if (is_uks) { + // z (magnetization) contribution: Σ_d vdz_d * (d2B_{c,d} * xZ + dBc * xZ_d) + const double d2_xZ_x = d2bxx * xZ + dbx * xZx; + const double d2_xZ_y = d2bxy * xZ + dbx * xZy; + const double d2_xZ_z = d2bxz * xZ + dbx * xZz; + g_acc_x += vdz_x * d2_xZ_x + vdz_y * d2_xZ_y + vdz_z * d2_xZ_z; + + const double d2_yZ_x = d2bxy * xZ + dby * xZx; + const double d2_yZ_y = d2byy * xZ + dby * xZy; + const double d2_yZ_z = d2byz * xZ + dby * xZz; + g_acc_y += vdz_x * d2_yZ_x + vdz_y * d2_yZ_y + vdz_z * d2_yZ_z; + + const double d2_zZ_x = d2bxz * xZ + dbz * xZx; + const double d2_zZ_y = d2byz * xZ + dbz * xZy; + const double d2_zZ_z = d2bzz * xZ + dbz * xZz; + g_acc_z += vdz_x * d2_zZ_x + vdz_y * d2_zZ_y + vdz_z * d2_zZ_z; + } + } + + if (is_mgga) { + // MGGA τ contribution + // vtau is interleaved [α₀, β₀, α₁, β₁, ...] 
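+                // e.g. for grid point ipt the α value sits at vtau_data[2*ipt]
+                // and the β value at vtau_data[2*ipt + 1]; the ± combinations
+                // below form the total (n) and magnetization (z) channels.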
+ const double vtaup = 0.5 * vtau_data[2*ipt]; // α, already weighted + const double vtaum = 0.5 * vtau_data[2*ipt + 1]; // β, already weighted + const double vtaun = vtaup + vtaum; + const double vtauz = vtaup - vtaum; + + const double xNx = xNmat_x[mu_i]; + const double xNy = xNmat_y[mu_i]; + const double xNz = xNmat_z[mu_i]; + + const double d2bxx = d2basis_xx_eval[mu_i]; + const double d2bxy = d2basis_xy_eval[mu_i]; + const double d2bxz = d2basis_xz_eval[mu_i]; + const double d2byy = d2basis_yy_eval[mu_i]; + const double d2byz = d2basis_yz_eval[mu_i]; + const double d2bzz = d2basis_zz_eval[mu_i]; + + auto d2_term_x = d2bxx * xNx + d2bxy * xNy + d2bxz * xNz; + auto d2_term_y = d2bxy * xNx + d2byy * xNy + d2byz * xNz; + auto d2_term_z = d2bxz * xNx + d2byz * xNy + d2bzz * xNz; + + g_acc_x += 0.5 * vtaun * d2_term_x; + g_acc_y += 0.5 * vtaun * d2_term_y; + g_acc_z += 0.5 * vtaun * d2_term_z; + + if (is_uks) { + const double xZx = xZmat_x[mu_i]; + const double xZy = xZmat_y[mu_i]; + const double xZz = xZmat_z[mu_i]; + + d2_term_x = d2bxx * xZx + d2bxy * xZy + d2bxz * xZz; + d2_term_y = d2bxy * xZx + d2byy * xZy + d2byz * xZz; + d2_term_z = d2bxz * xZx + d2byz * xZy + d2bzz * xZz; + + g_acc_x += 0.5 * vtauz * d2_term_x; + g_acc_y += 0.5 * vtauz * d2_term_y; + g_acc_z += 0.5 * vtauz * d2_term_z; + } + } + + } // end loop over bfns + grid points + + #pragma omp atomic + EXC_GRAD[3*iAt + 0] += -2 * g_acc_x; + #pragma omp atomic + EXC_GRAD[3*iAt + 1] += -2 * g_acc_y; + #pragma omp atomic + EXC_GRAD[3*iAt + 2] += -2 * g_acc_z; + + // Weight derivative counterpart for non-parent atoms + #pragma omp atomic + EXC_GRAD[3*task.iParent + 0] -= -2 * g_acc_x; + #pragma omp atomic + EXC_GRAD[3*task.iParent + 1] -= -2 * g_acc_y; + #pragma omp atomic + EXC_GRAD[3*task.iParent + 2] -= -2 * g_acc_z; + + bf_off += sh_sz; + + } // end loop over shells + + } // end loop over tasks + + } // end OpenMP region +} + } // namespace detail } // namespace GauXC diff --git a/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx b/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx index 3cb0d74ce..550028a09 100644 --- a/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx +++ b/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx @@ -157,6 +157,16 @@ void ReplicatedXCIntegratorImpl:: } + +template +void ReplicatedXCIntegratorImpl:: + eval_exc_grad_onedft( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) { + + eval_exc_grad_onedft_(m,n,Ps,ldps,Pz,ldpz,EXC_GRAD, ks_settings); + +} + template void ReplicatedXCIntegratorImpl:: eval_exx( int64_t m, int64_t n, const value_type* P, diff --git a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp index 41419b85e..40e9512f7 100644 --- a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp @@ -95,6 +95,9 @@ class ShellBatchedReplicatedXCIntegrator : /// UKS EXC Gradient void eval_exc_grad_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& settings ) override; + /// OneDFT EXC Gradient + void eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const 
IntegratorSettingsXC& settings ) override; /// sn-LinK void eval_exx_( int64_t m, int64_t n, const value_type* P, diff --git a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp index f329bc025..1b594e317 100644 --- a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp @@ -34,5 +34,14 @@ void ShellBatchedReplicatedXCIntegrator +void ShellBatchedReplicatedXCIntegrator:: + eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& settings ) { + + GAUXC_GENERIC_EXCEPTION("ShellBatched exc_grad_onedft NYI" ); + util::unused(m,n,Ps,ldps,Pz,ldpz,EXC_GRAD); +} + } } diff --git a/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx b/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx index 15c812619..83e4e1119 100644 --- a/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx +++ b/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx @@ -197,6 +197,12 @@ void XCDeviceStackData::allocate_static_data_onedft( int32_t nbf, int32_t nshell static_stack.tau_grad_device = mem.aligned_alloc( 2 * total_npts , csl ); } + // Allocate gradient accumulator for OneDFT gradient computation + if( enabled_terms.exc_grad ) { + static_stack.exc_grad_device = mem.aligned_alloc( 3 * natoms, csl ); + allocated_terms.exc_grad = true; + } + // Get current stack location dynmem_ptr = mem.stack(); dynmem_sz = mem.nleft(); diff --git a/tests/onedft_test.cxx b/tests/onedft_test.cxx index ee3130444..f211297f4 100644 --- a/tests/onedft_test.cxx +++ b/tests/onedft_test.cxx @@ -149,4 +149,177 @@ TEST_CASE( "OneDFT", "[onedft]" ) { SECTION( " HE / def2-qzvp / lda.fun" ) { test_integrator( GAUXC_REF_DATA_PATH "/onedft_he_def2qzvp_lda_uks.hdf5", GAUXC_ONEDFT_MODEL_PATH "/lda.fun" ); } +} + +#include +// Include the OneDFT utility header for reorder helper functions +#include "../../src/xc_integrator/integrator_util/onedft_util.hpp" + +TEST_CASE( "Atom Reorder Permutation", "[onedft][reorder]" ) { + + // Scenario: 2 ranks, 3 atoms + // Rank 0 has: atom0=2pts, atom1=3pts, atom2=1pt (6 pts total) + // Rank 1 has: atom0=1pt, atom1=0pts, atom2=2pts (3 pts total) + // + // Rank-ordered layout (what MPI_Gatherv produces): + // [r0_a0(2), r0_a1(3), r0_a2(1), r1_a0(1), r1_a1(0), r1_a2(2)] + // indices: 0 1 | 2 3 4 | 5 | 6 | | 7 8 + // + // Atom-ordered layout (what we want): + // [a0_r0(2), a0_r1(1), a1_r0(3), a1_r1(0), a2_r0(1), a2_r1(2)] + // indices: 0 1 | 2 | 3 4 5 | | 6 | 7 8 + + int natoms = 3; + int world_size = 2; + // all_rank_atom_sizes: [rank0_atom0, rank0_atom1, rank0_atom2, rank1_atom0, rank1_atom1, rank1_atom2] + std::vector all_rank_atom_sizes = {2, 3, 1, 1, 0, 2}; + std::vector sendcounts = {6, 3}; + std::vector displs = {0, 6}; + + SECTION("Permutation correctness") { + auto [perm, inv_perm] = GauXC::build_atom_reorder_perm( + all_rank_atom_sizes, sendcounts, displs, natoms, world_size); + + REQUIRE(perm.size() == 9); + REQUIRE(inv_perm.size() == 9); + + // Expected mapping: + // rank-ordered idx -> atom-ordered idx + // r0_a0: src 0->dst 0, src 1->dst 1 + // r0_a1: src 2->dst 3, src 3->dst 4, src 4->dst 5 + // r0_a2: src 5->dst 6 + // r1_a0: src 6->dst 2 + // r1_a1: (empty) + // r1_a2: src 7->dst 7, src 8->dst 8 + CHECK(perm[0] == 0); + CHECK(perm[1] == 1); 
+ CHECK(perm[2] == 3); + CHECK(perm[3] == 4); + CHECK(perm[4] == 5); + CHECK(perm[5] == 6); + CHECK(perm[6] == 2); + CHECK(perm[7] == 7); + CHECK(perm[8] == 8); + + // Round-trip: inv_perm[perm[i]] == i + for (int64_t i = 0; i < 9; ++i) { + CHECK(inv_perm[perm[i]] == i); + } + } + + SECTION("Strided permutation with stride 3 (coords)") { + auto [perm, inv_perm] = GauXC::build_atom_reorder_perm( + all_rank_atom_sizes, sendcounts, displs, natoms, world_size); + + // 9 points, stride 3 -> 27 doubles + // Rank-ordered data: point i has values [i*10, i*10+1, i*10+2] + std::vector src(27); + for (int i = 0; i < 9; ++i) { + src[i*3] = i * 10.0; + src[i*3+1] = i * 10.0 + 1.0; + src[i*3+2] = i * 10.0 + 2.0; + } + + std::vector dst(27, -1.0); + GauXC::apply_strided_permutation(src.data(), dst.data(), perm, 9, 3); + + // Verify: dst[perm[i]*3..] should equal src[i*3..] + for (int i = 0; i < 9; ++i) { + int64_t j = perm[i]; + CHECK(dst[j*3] == src[i*3]); + CHECK(dst[j*3+1] == src[i*3+1]); + CHECK(dst[j*3+2] == src[i*3+2]); + } + } + + SECTION("Round-trip: forward then inverse restores original") { + auto [perm, inv_perm] = GauXC::build_atom_reorder_perm( + all_rank_atom_sizes, sendcounts, displs, natoms, world_size); + + // Stride 1 (like grid_weights) + std::vector original(9); + for (int i = 0; i < 9; ++i) original[i] = i * 1.5 + 0.7; + + // Forward: rank-ordered -> atom-ordered + std::vector atom_ordered(9); + GauXC::apply_strided_permutation(original.data(), atom_ordered.data(), perm, 9, 1); + + // Inverse: atom-ordered -> rank-ordered + std::vector restored(9); + GauXC::apply_strided_permutation(atom_ordered.data(), restored.data(), inv_perm, 9, 1); + + for (int i = 0; i < 9; ++i) { + CHECK(restored[i] == Approx(original[i])); + } + } + + SECTION("Single rank is identity permutation") { + // With 1 rank, no reorder is needed + std::vector single_rank_sizes = {2, 3, 1}; + std::vector sc = {6}; + std::vector dp = {0}; + + auto [perm, inv_perm] = GauXC::build_atom_reorder_perm( + single_rank_sizes, sc, dp, 3, 1); + + for (int64_t i = 0; i < 6; ++i) { + CHECK(perm[i] == i); + CHECK(inv_perm[i] == i); + } + } + + SECTION("reorder_to_atom_order / reorder_to_rank_order round-trip") { + auto [perm, inv_perm] = GauXC::build_atom_reorder_perm( + all_rank_atom_sizes, sendcounts, displs, natoms, world_size); + int64_t npts = 9; + + // Create synthetic interleaved data matching the real layouts + std::vector weights(npts), den(npts*2), coords(npts*3), dden(npts*6), tau_v(npts*2); + for (int64_t i = 0; i < npts; ++i) { + weights[i] = i * 0.1; + den[i*2] = i * 1.0; den[i*2+1] = i * 1.0 + 100; + coords[i*3] = i; coords[i*3+1] = i+0.1; coords[i*3+2] = i+0.2; + for (int c = 0; c < 6; ++c) dden[i*6+c] = i * 10.0 + c; + tau_v[i*2] = i * 5.0; tau_v[i*2+1] = i * 5.0 + 50; + } + // Save originals + auto orig_weights = weights, orig_den = den, orig_coords = coords; + auto orig_dden = dden, orig_tau = tau_v; + + // Forward: rank-order → atom-order + GauXC::reorder_to_atom_order(weights, den, coords, dden, tau_v, perm, npts); + + // Verify data actually changed (perm is non-trivial) + bool any_different = false; + for (int64_t i = 0; i < npts && !any_different; ++i) + if (weights[i] != orig_weights[i]) any_different = true; + CHECK(any_different); + + // Now simulate the gradient path: convert interleaved atom-ordered data + // to channel-first layout (as mpi_scatter_onedft_outputs does) + std::vector grad_den(npts*2), grad_dden(npts*6), grad_tau(npts*2); + // Channel-first: [alpha(npts) | beta(npts)] + for (int64_t i = 
0; i < npts; ++i) { + grad_den[i] = den[i*2]; // alpha channel + grad_den[npts+i] = den[i*2+1]; // beta channel + grad_tau[i] = tau_v[i*2]; + grad_tau[npts+i] = tau_v[i*2+1]; + // dden channel-first: [dXa(npts)|dYa|dZa|dXb|dYb|dZb] + for (int c = 0; c < 6; ++c) + grad_dden[c*npts+i] = dden[i*6+c]; + } + + // Inverse: atom-order → rank-order + GauXC::reorder_to_rank_order(grad_den, grad_dden, grad_tau, inv_perm, npts, true, true); + + // Verify round-trip: channel-first rank-ordered should match original interleaved + for (int64_t i = 0; i < npts; ++i) { + CHECK(grad_den[i] == Approx(orig_den[i*2])); + CHECK(grad_den[npts+i] == Approx(orig_den[i*2+1])); + CHECK(grad_tau[i] == Approx(orig_tau[i*2])); + CHECK(grad_tau[npts+i] == Approx(orig_tau[i*2+1])); + for (int c = 0; c < 6; ++c) + CHECK(grad_dden[c*npts+i] == Approx(orig_dden[i*6+c])); + } + } } \ No newline at end of file diff --git a/tests/ref_data/test_mol_c.hdf5 b/tests/ref_data/test_mol_c.hdf5 new file mode 100644 index 000000000..f26235546 Binary files /dev/null and b/tests/ref_data/test_mol_c.hdf5 differ
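Usage sketch for the new entry point (illustrative only, not part of the patch): a minimal UKS driver assuming an already-constructed XCIntegrator. It assumes OneDFTSettings lives in the GauXC namespace alongside IntegratorSettingsXC, that exc_grad_type is a flat std::vector<double>, and that the model path points at a TorchScript file such as the lda.fun used in the tests; the path below is hypothetical.

    #include <gauxc/xc_integrator.hpp>
    #include <vector>

    template <typename MatrixType>
    std::vector<double> onedft_forces( GauXC::XCIntegrator<MatrixType>& integrator,
                                       const MatrixType& Ps, const MatrixType& Pz ) {
      GauXC::OneDFTSettings settings;
      settings.model = "/path/to/lda.fun";  // hypothetical TorchScript model path
      // Returns the EXC gradient as a flat [3*natoms] vector (x,y,z per atom)
      return integrator.eval_exc_grad_onedft( Ps, Pz, settings );
    }

Note that, per the host implementation above, only single-rank execution is currently supported; multi-rank runs throw "OneDFT gradient with MPI not yet implemented".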