Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions cmake/gauxc-onedft.cmake
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
find_package(nlohmann_json)
if( NOT nlohmann_json_FOUND )
if(NOT TARGET nlohmann_json::nlohmann_json)
find_package(nlohmann_json)
if( NOT nlohmann_json_FOUND )

message( STATUS "Could Not Find nlohmann_json... Building" )
message( STATUS "NLOHMANN_JSON URL = ${GAUXC_NLOHMANN_JSON_URL}" )
message( STATUS "Could Not Find nlohmann_json... Building" )
message( STATUS "NLOHMANN_JSON URL = ${GAUXC_NLOHMANN_JSON_URL}" )

FetchContent_Declare(
nlohmann_json
URL ${GAUXC_NLOHMANN_JSON_URL}
URL_HASH SHA256=${GAUXC_NLOHMANN_JSON_SHA256}
DOWNLOAD_EXTRACT_TIMESTAMP ON
)
FetchContent_Declare(
nlohmann_json
URL ${GAUXC_NLOHMANN_JSON_URL}
URL_HASH SHA256=${GAUXC_NLOHMANN_JSON_SHA256}
DOWNLOAD_EXTRACT_TIMESTAMP ON
)

FetchContent_GetProperties( nlohmann_json )
if( NOT nlohmann_json_POPULATED )
FetchContent_Populate( nlohmann_json )
endif()
FetchContent_GetProperties( nlohmann_json )
if( NOT nlohmann_json_POPULATED )
FetchContent_Populate( nlohmann_json )
endif()

add_library( nlohmann_json::nlohmann_json INTERFACE IMPORTED )
set_target_properties( nlohmann_json::nlohmann_json PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${nlohmann_json_SOURCE_DIR}/include
)
add_library( nlohmann_json::nlohmann_json INTERFACE IMPORTED )
set_target_properties( nlohmann_json::nlohmann_json PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${nlohmann_json_SOURCE_DIR}/include
)
endif()
endif()

# store and restore CMAKE_CUDA_ARCHITECTURES if Torch clobbers it
# store and restore CMAKE_CUDA_ARCHITECTURES and CMAKE_CUDA_FLAGS if Torch clobbers them
set(_PREV_CUDA_ARCHS "${CMAKE_CUDA_ARCHITECTURES}")
set(_PREV_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
find_package(Torch REQUIRED)
if(CMAKE_CUDA_ARCHITECTURES STREQUAL "OFF")
set(CMAKE_CUDA_ARCHITECTURES "${_PREV_CUDA_ARCHS}" CACHE STRING "Restore CUDA archs after Torch override" FORCE)
message(WARNING "Torch set CMAKE_CUDA_ARCHITECTURES to OFF. Restored previous value: ${CMAKE_CUDA_ARCHITECTURES}")
# Restore CMAKE_CUDA_ARCHITECTURES (Torch may set it to OFF)
if(NOT "${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "${_PREV_CUDA_ARCHS}")
set(CMAKE_CUDA_ARCHITECTURES "${_PREV_CUDA_ARCHS}" CACHE STRING "" FORCE)
message(WARNING "Torch changed CMAKE_CUDA_ARCHITECTURES. Restored previous value: ${CMAKE_CUDA_ARCHITECTURES}")
endif()
# Strip Torch-injected -gencode flags from CMAKE_CUDA_FLAGS (PyTorch issue #71379)
string(REGEX REPLACE " -gencode [^ ]+" "" _cleaned_cuda_flags "${CMAKE_CUDA_FLAGS}")
if(NOT "${_cleaned_cuda_flags}" STREQUAL "${CMAKE_CUDA_FLAGS}")
set(CMAKE_CUDA_FLAGS "${_cleaned_cuda_flags}" CACHE STRING "" FORCE)
message(WARNING "Stripped Torch-injected -gencode flags from CMAKE_CUDA_FLAGS")
endif()
list(REMOVE_ITEM TORCH_LIBRARIES torch::nvtoolsext)
message(STATUS "Torch libraries without nvtoolsext: ${TORCH_LIBRARIES}")
1 change: 1 addition & 0 deletions include/gauxc/gauxc_config.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#cmakedefine GAUXC_HAS_CUTLASS
#cmakedefine GAUXC_HAS_GAU2GRID
#cmakedefine GAUXC_HAS_HDF5
#cmakedefine GAUXC_HAS_ONEDFT
#cmakedefine GAUXC_USE_FAST_RSQRT

#ifdef GAUXC_HAS_ONEDFT
Expand Down
2 changes: 2 additions & 0 deletions include/gauxc/xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class XCIntegrator {

exc_grad_type eval_exc_grad( const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );
exc_grad_type eval_exc_grad( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} );
exc_grad_type eval_exc_grad_onedft( const MatrixType&, const MatrixType&,
const IntegratorSettingsXC& = IntegratorSettingsXC{} );

exx_type eval_exx ( const MatrixType&,
const IntegratorSettingsEXX& = IntegratorSettingsEXX{} );
Expand Down
7 changes: 7 additions & 0 deletions include/gauxc/xc_integrator/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ typename XCIntegrator<MatrixType>::exc_grad_type
return pimpl_->eval_exc_grad(Ps, Pz, ks_settings);
};

template <typename MatrixType>
typename XCIntegrator<MatrixType>::exc_grad_type
XCIntegrator<MatrixType>::eval_exc_grad_onedft( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_exc_grad_onedft(Ps, Pz, ks_settings);
};

template <typename MatrixType>
typename XCIntegrator<MatrixType>::exx_type
XCIntegrator<MatrixType>::eval_exx( const MatrixType& P,
Expand Down
14 changes: 14 additions & 0 deletions include/gauxc/xc_integrator/replicated/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,20 @@ typename ReplicatedXCIntegrator<MatrixType>::exc_grad_type

}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::exc_grad_type
ReplicatedXCIntegrator<MatrixType>::eval_exc_grad_onedft_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {

if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();

std::vector<value_type> EXC_GRAD( 3*pimpl_->load_balancer().molecule().natoms() );
pimpl_->eval_exc_grad_onedft( Ps.rows(), Ps.cols(), Ps.data(), Ps.rows(), Pz.data(), Pz.rows(),
EXC_GRAD.data(), ks_settings );

return EXC_GRAD;

}

template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::exx_type
ReplicatedXCIntegrator<MatrixType>::eval_exx_( const MatrixType& P, const IntegratorSettingsEXX& settings ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class ReplicatedXCIntegratorImpl {
value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exc_grad_( int64_t m, int64_t n, const value_type* P, int64_t ldps,
const value_type* Pz, int64_t lpdz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings ) = 0;
virtual void eval_exx_( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* K, int64_t ldk,
const IntegratorSettingsEXX& settings ) = 0;
Expand Down Expand Up @@ -169,6 +171,8 @@ class ReplicatedXCIntegratorImpl {
value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings );
void eval_exc_grad( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings );
void eval_exc_grad_onedft( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz, value_type* EXC_GRAD, const IntegratorSettingsXC& ks_settings );

void eval_exx( int64_t m, int64_t n, const value_type* P,
int64_t ldp, value_type* K, int64_t ldk,
Expand Down
1 change: 1 addition & 0 deletions include/gauxc/xc_integrator/replicated_xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl<MatrixType> {
exc_vxc_type_uks eval_exc_vxc_onedft_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
exc_grad_type eval_exc_grad_( const MatrixType&, const IntegratorSettingsXC& ) override;
exc_grad_type eval_exc_grad_( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
exc_grad_type eval_exc_grad_onedft_( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
exx_type eval_exx_ ( const MatrixType&, const IntegratorSettingsEXX& ) override;
fxc_contraction_type_rks eval_fxc_contraction_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override;
fxc_contraction_type_uks eval_fxc_contraction_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC&) override;
Expand Down
11 changes: 11 additions & 0 deletions include/gauxc/xc_integrator/xc_integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class XCIntegratorImpl {
virtual exc_vxc_type_uks eval_exc_vxc_onedft_ ( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_grad_type eval_exc_grad_( const MatrixType& P, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_grad_type eval_exc_grad_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exc_grad_type eval_exc_grad_onedft_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0;
virtual exx_type eval_exx_ ( const MatrixType& P,
const IntegratorSettingsEXX& settings ) = 0;
virtual fxc_contraction_type_rks eval_fxc_contraction_ ( const MatrixType& P,
Expand Down Expand Up @@ -152,6 +153,16 @@ class XCIntegratorImpl {
return eval_exc_grad_(Ps, Pz, ks_settings);
}

/** Integrate EXC gradient for OneDFT models
*
* @param[in] Ps The total density matrix
* @param[in] Pz The magnetization density matrix
* @returns EXC gradient
*/
exc_grad_type eval_exc_grad_onedft( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) {
return eval_exc_grad_onedft_(Ps, Pz, ks_settings);
}

/** Integrate Exact Exchange for RHF
*
* @param[in] P The alpha density matrix
Expand Down
2 changes: 2 additions & 0 deletions include/gauxc/xc_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ struct XCTask {
std::vector<double> vdden_z_eval_a;
std::vector<double> vdden_z_eval_b;
std::vector<double> vtau;
// energy density per grid point (for gradient weight derivative)
std::vector<double> eps;
};
features feat;

Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ if( GAUXC_HAS_MPI )
endif()

if ( GAUXC_HAS_ONEDFT )
target_link_libraries( gauxc PUBLIC "${TORCH_LIBRARIES}")
target_link_libraries( gauxc PUBLIC "${TORCH_LIBRARIES}" nlohmann_json::nlohmann_json)
endif()

add_subdirectory( runtime_environment )
Expand Down
4 changes: 4 additions & 0 deletions src/molecular_weights/device/device_molecular_weights.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ void DeviceMolecularWeights::modify_weights( LoadBalancer& lb ) const {
};
std::stable_sort(task_begin, task_end, task_comparator );

// Save raw quadrature weights before partition modifies them in-place.
// These are needed by OneDFT models that use atomic_grid_weights.
for( auto it = task_begin; it != task_end; ++it ) it->raw_weights = it->weights;

const auto& mol = lb.molecule();
const auto natoms = mol.natoms();
const auto& meta = lb.molmeta();
Expand Down
4 changes: 4 additions & 0 deletions src/molecular_weights/host/host_molecular_weights.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ void HostMolecularWeights::modify_weights( LoadBalancer& lb ) const {
};
std::stable_sort( tasks.begin(), tasks.end(), task_comparator );

// Save raw quadrature weights before partition modifies them in-place.
// These are needed by OneDFT models that use atomic_grid_weights.
for( auto& task : tasks ) task.raw_weights = task.weights;

// Modify the weights
const auto& mol = lb.molecule();
const auto& meta = lb.molmeta();
Expand Down
Loading