diff --git a/docs/ftorch.rst b/docs/ftorch.rst index 2b71aa6..65344c2 100644 --- a/docs/ftorch.rst +++ b/docs/ftorch.rst @@ -220,7 +220,7 @@ The output for the H2 molecule with the def2-QZVP basis set should look like thi [3] Preparing input dictionary [4] Running model inference [5] Computing XC energy = sum(exc * grid_weights) - -> E_xc = -6.23580775902096E-0 + -> E_xc = -6.26480858787919E-01 The ``get_exc_vxc`` procedure computes the exchange-correlation energy and potential, which we can then access from the returned dictionary. The potential terms are stored under the same keys as the input features and can be extracted as tensors. @@ -271,35 +271,35 @@ In the output we can see the computed exchange-correlation energy as well as the [3] Preparing input dictionary [4] Running model inference [5] Computing XC energy = sum(exc * grid_weights) - -> E_xc = -6.23580775902096E-01 + -> E_xc = -6.26480858787919E-01 [6] Extracting vxc components [7] Gradient means (dexc/dx) - -> mean(dexc/d_density) = -5.22209689934698E-03 - -> mean(dexc/d_grad) = -5.63822012469897E-12 - -> mean(dexc/d_kin) = -5.74746008322732E-04 - -> mean(dexc/d_grid_coords) = -2.86873728034485E-14 - -> mean(dexc/d_grid_weights) = -3.47603697347615E-02 - -> mean(dexc/d_coarse_0_atomic_coords) = 2.81365752618854E-10 + -> mean(dexc/d_density) = -3.61391813525080E-03 + -> mean(dexc/d_grad) = -2.32959324675360E-13 + -> mean(dexc/d_kin) = -8.83645473913423E-04 + -> mean(dexc/d_grid_coords) = 4.28210805205588E-15 + -> mean(dexc/d_grid_weights) = -3.07264776522541E-02 + -> mean(dexc/d_coarse_0_atomic_coords) = -4.19989157107543E-11 [8] Accessing tensor data as Fortran arrays -> exc: shape = (19616) - [ -1.40859407629398E-11 -5.02191290297704E-14 -4.82271929094491E-19 ...] + [ -1.70762973238251E-01 -1.70762851641779E-01 -1.70761433376289E-01 ...] -> dexc/d_density: shape = (19616, 2) - [[ -1.41238001293684E-02 -7.07236361510544E-03 -1.22373184625014E-03 ...] - [ -1.41238001293684E-02 -7.07236361510544E-03 -1.22373184625014E-03 ...]] + [[ -6.91094160750184E-15 -2.68068796422506E-12 -8.75755568871490E-11 ...] + [ -6.91094160750184E-15 -2.68068796422506E-12 -8.75755568871490E-11 ...]] -> dexc/d_grad: shape = (19616, 3, 2) - [[[ 2.54671362468241E-15 3.18318623003652E-19 2.04989634182288E-27 ...] - [ 2.54671362468241E-15 3.18318623003652E-19 2.04989634182288E-27 ...]] + [[[ -3.53224861366702E-20 -1.25883699330863E-16 -1.50302693127561E-14 ...] + [ 2.45810484635441E-32 9.53375368531974E-30 3.11056983360779E-28 ...]] [[ ... ]]] -> dexc/d_kin: shape = (19616, 2) - [[ -2.80759310169539E-07 -2.05221067626510E-09 -6.12585093687525E-14 ...] - [ -2.80759310169539E-07 -2.05221067626510E-09 -6.12585093687525E-14 ...]] + [[ 8.31256652855356E-16 3.22438658440735E-13 1.05346827546379E-11 ...] + [ 8.31256652855356E-16 3.22438658440735E-13 1.05346827546379E-11 ...]] -> dexc/d_grid_coords: shape = (3, 19616) - [[ 0.00000000000000E+00 0.00000000000000E+00 0.00000000000000E+00 ...] - [ 0.00000000000000E+00 0.00000000000000E+00 0.00000000000000E+00 ...]] + [[ -1.10244039021109E-20 -3.46793732833479E-29 1.44417263443964E-21 ...] + [ -3.92941562081365E-17 -1.23603204621552E-25 5.14711125999810E-18 ...]] -> dexc/d_grid_weights: shape = (19616) - [ -1.40859407629398E-11 -5.02191290297704E-14 -4.82271929094491E-19 ...] + [ -1.70762973238251E-01 -1.70762851641779E-01 -1.70761433376289E-01 ...] -> dexc/d_coarse_0_atomic_coords: shape = (3, 2) - [[ -4.21657770765388E-10 1.95481841380307E-10 -4.14780685154337E-03] + [[ 5.94466622059813E-12 -3.78104062375730E-12 4.36839713298186E-04] [ ...]] Summary diff --git a/examples/fortran/ftorch_integration/src/skala_ftorch.cxx b/examples/fortran/ftorch_integration/src/skala_ftorch.cxx index 54c4e06..f482647 100644 --- a/examples/fortran/ftorch_integration/src/skala_ftorch.cxx +++ b/examples/fortran/ftorch_integration/src/skala_ftorch.cxx @@ -16,7 +16,7 @@ typedef enum SkalaFeature { Feature_Kin = 3, Feature_GridCoords = 4, Feature_GridWeights = 5, - Feature_Coarse0AtomicCoords = 6 + Feature_Coarse0AtomicCoords = 6, Feature_AtomicGridWeights = 7, Feature_AtomicGridSizes = 8, Feature_AtomicGridSizeBoundShape = 9 @@ -169,13 +169,16 @@ skala_model_get_exc_and_vxc(torch_jit_script_module_t module, skala_dict_t input c10::Dict features_with_grad; for (const auto& entry : *dict) { std::string key = entry.key(); + bool requires_grad = (key == "density" || key == "grad" || key == "kin" || key == "coarse_0_atomic_coords" || key == "grid_coords" || key == "grid_weights"); const auto& values = entry.value(); std::vector tensors; for (const auto &value : values) { - auto tensor_with_grad = value.clone().requires_grad_(true); + auto tensor_with_grad = value.clone().requires_grad_(requires_grad); tensors.push_back(tensor_with_grad); - input_tensors.push_back(tensor_with_grad); - tensor_keys.push_back(key); + if (requires_grad) { + input_tensors.push_back(tensor_with_grad); + tensor_keys.push_back(key); + } } features_with_grad.insert(key, torch::concat(tensors)); } diff --git a/examples/fortran/ftorch_integration/src/skala_ftorch.f90 b/examples/fortran/ftorch_integration/src/skala_ftorch.f90 index 8aa4ec4..3eaa2fd 100644 --- a/examples/fortran/ftorch_integration/src/skala_ftorch.f90 +++ b/examples/fortran/ftorch_integration/src/skala_ftorch.f90 @@ -349,6 +349,12 @@ subroutine skala_feature_key(feature, key) key = "grid_weights" case (skala_feature%coarse_0_atomic_coords) key = "coarse_0_atomic_coords" + case (skala_feature%atomic_grid_weights) + key = "atomic_grid_weights" + case (skala_feature%atomic_grid_sizes) + key = "atomic_grid_sizes" + case (skala_feature%atomic_grid_size_bound_shape) + key = "atomic_grid_size_bound_shape" end select end subroutine skala_feature_key