Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions docs/ftorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ The output for the H2 molecule with the def2-QZVP basis set should look like thi
[3] Preparing input dictionary
[4] Running model inference
[5] Computing XC energy = sum(exc * grid_weights)
-> E_xc = -6.23580775902096E-0
-> E_xc = -6.26480858787919E-01

The ``get_exc_vxc`` procedure computes the exchange-correlation energy and potential, which we can then access from the returned dictionary.
The potential terms are stored under the same keys as the input features and can be extracted as tensors.
Expand Down Expand Up @@ -271,35 +271,35 @@ In the output we can see the computed exchange-correlation energy as well as the
[3] Preparing input dictionary
[4] Running model inference
[5] Computing XC energy = sum(exc * grid_weights)
-> E_xc = -6.23580775902096E-01
-> E_xc = -6.26480858787919E-01
[6] Extracting vxc components
[7] Gradient means (dexc/dx)
-> mean(dexc/d_density) = -5.22209689934698E-03
-> mean(dexc/d_grad) = -5.63822012469897E-12
-> mean(dexc/d_kin) = -5.74746008322732E-04
-> mean(dexc/d_grid_coords) = -2.86873728034485E-14
-> mean(dexc/d_grid_weights) = -3.47603697347615E-02
-> mean(dexc/d_coarse_0_atomic_coords) = 2.81365752618854E-10
-> mean(dexc/d_density) = -3.61391813525080E-03
-> mean(dexc/d_grad) = -2.32959324675360E-13
-> mean(dexc/d_kin) = -8.83645473913423E-04
-> mean(dexc/d_grid_coords) = 4.28210805205588E-15
-> mean(dexc/d_grid_weights) = -3.07264776522541E-02
-> mean(dexc/d_coarse_0_atomic_coords) = -4.19989157107543E-11
[8] Accessing tensor data as Fortran arrays
-> exc: shape = (19616)
[ -1.40859407629398E-11 -5.02191290297704E-14 -4.82271929094491E-19 ...]
[ -1.70762973238251E-01 -1.70762851641779E-01 -1.70761433376289E-01 ...]
-> dexc/d_density: shape = (19616, 2)
[[ -1.41238001293684E-02 -7.07236361510544E-03 -1.22373184625014E-03 ...]
[ -1.41238001293684E-02 -7.07236361510544E-03 -1.22373184625014E-03 ...]]
[[ -6.91094160750184E-15 -2.68068796422506E-12 -8.75755568871490E-11 ...]
[ -6.91094160750184E-15 -2.68068796422506E-12 -8.75755568871490E-11 ...]]
-> dexc/d_grad: shape = (19616, 3, 2)
[[[ 2.54671362468241E-15 3.18318623003652E-19 2.04989634182288E-27 ...]
[ 2.54671362468241E-15 3.18318623003652E-19 2.04989634182288E-27 ...]]
[[[ -3.53224861366702E-20 -1.25883699330863E-16 -1.50302693127561E-14 ...]
[ 2.45810484635441E-32 9.53375368531974E-30 3.11056983360779E-28 ...]]
[[ ... ]]]
-> dexc/d_kin: shape = (19616, 2)
[[ -2.80759310169539E-07 -2.05221067626510E-09 -6.12585093687525E-14 ...]
[ -2.80759310169539E-07 -2.05221067626510E-09 -6.12585093687525E-14 ...]]
[[ 8.31256652855356E-16 3.22438658440735E-13 1.05346827546379E-11 ...]
[ 8.31256652855356E-16 3.22438658440735E-13 1.05346827546379E-11 ...]]
-> dexc/d_grid_coords: shape = (3, 19616)
[[ 0.00000000000000E+00 0.00000000000000E+00 0.00000000000000E+00 ...]
[ 0.00000000000000E+00 0.00000000000000E+00 0.00000000000000E+00 ...]]
[[ -1.10244039021109E-20 -3.46793732833479E-29 1.44417263443964E-21 ...]
[ -3.92941562081365E-17 -1.23603204621552E-25 5.14711125999810E-18 ...]]
-> dexc/d_grid_weights: shape = (19616)
[ -1.40859407629398E-11 -5.02191290297704E-14 -4.82271929094491E-19 ...]
[ -1.70762973238251E-01 -1.70762851641779E-01 -1.70761433376289E-01 ...]
-> dexc/d_coarse_0_atomic_coords: shape = (3, 2)
[[ -4.21657770765388E-10 1.95481841380307E-10 -4.14780685154337E-03]
[[ 5.94466622059813E-12 -3.78104062375730E-12 4.36839713298186E-04]
[ ...]]

Summary
Expand Down
11 changes: 7 additions & 4 deletions examples/fortran/ftorch_integration/src/skala_ftorch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ typedef enum SkalaFeature {
Feature_Kin = 3,
Feature_GridCoords = 4,
Feature_GridWeights = 5,
Feature_Coarse0AtomicCoords = 6
Feature_Coarse0AtomicCoords = 6,
Feature_AtomicGridWeights = 7,
Feature_AtomicGridSizes = 8,
Feature_AtomicGridSizeBoundShape = 9
Expand Down Expand Up @@ -169,13 +169,16 @@ skala_model_get_exc_and_vxc(torch_jit_script_module_t module, skala_dict_t input
c10::Dict<std::string, at::Tensor> features_with_grad;
for (const auto& entry : *dict) {
std::string key = entry.key();
bool requires_grad = (key == "density" || key == "grad" || key == "kin" || key == "coarse_0_atomic_coords" || key == "grid_coords" || key == "grid_weights");
const auto& values = entry.value();
std::vector<at::Tensor> tensors;
for (const auto &value : values) {
auto tensor_with_grad = value.clone().requires_grad_(true);
auto tensor_with_grad = value.clone().requires_grad_(requires_grad);
tensors.push_back(tensor_with_grad);
input_tensors.push_back(tensor_with_grad);
tensor_keys.push_back(key);
if (requires_grad) {
input_tensors.push_back(tensor_with_grad);
tensor_keys.push_back(key);
}
}
features_with_grad.insert(key, torch::concat(tensors));
}
Expand Down
6 changes: 6 additions & 0 deletions examples/fortran/ftorch_integration/src/skala_ftorch.f90
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,12 @@ subroutine skala_feature_key(feature, key)
key = "grid_weights"
case (skala_feature%coarse_0_atomic_coords)
key = "coarse_0_atomic_coords"
case (skala_feature%atomic_grid_weights)
key = "atomic_grid_weights"
case (skala_feature%atomic_grid_sizes)
key = "atomic_grid_sizes"
case (skala_feature%atomic_grid_size_bound_shape)
key = "atomic_grid_size_bound_shape"
end select
end subroutine skala_feature_key

Expand Down
Loading