diff --git a/.devcontainer/build_docker.sh b/.devcontainer/build_docker.sh new file mode 100644 index 0000000..d860bae --- /dev/null +++ b/.devcontainer/build_docker.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# run this file from repo root + +# exit if any command fails +set -e +set -o pipefail + +FILE=".devcontainer/skala_dev.Dockerfile" +ENV_VARIANT="${1:-cpu}" +BASE_NAME="skala-dev-${ENV_VARIANT}" +TAG=${BASE_NAME}":"$(date +"%Y%m%dT%H%M%S") + +echo "Building Docker image with tag \"${TAG}\" (ENV_VARIANT=${ENV_VARIANT})" + +# To ignore the cache, use --no-cache +docker build \ + --progress=plain \ + --build-arg ENV_VARIANT="${ENV_VARIANT}" \ + --tag=${TAG} \ + --file=${FILE} \ + . \ + 2>&1 | tee -a "build_${TAG}.log" + +docker tag ${TAG} ${BASE_NAME}":latest" \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..d18f4a5 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,52 @@ +{ + "name": "skala-dev-gpu", + "build": { + "dockerfile": "skala_dev.Dockerfile", + "context": "..", + "args": { + "ENV_VARIANT": "gpu" + } + }, + "postCreateCommand": "", + "containerEnv": { + "PYTHONPATH": "${containerWorkspaceFolder}/src/" + }, + "runArgs": [ + "--shm-size=2gb", + "--network=host" + ], + "forwardPorts": [], + "containerUser": "ubuntu", // or "${localEnv:USER}", + "features": { + "ghcr.io/devcontainers/features/common-utils:2": { + "installZsh": false, + "installOhMyZsh": false, + "upgradePackages": false + // If a user with this UID already exists in the image, comment the following lines and set "containerUser" to that user instead. + // "username": "${localEnv:USER}", + // "userUid": "${localEnv:UID}", + // "userGid": "${localEnv:UID}" + } + }, + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python", + "charliermarsh.ruff", + "ms-python.mypy-type-checker" + ], + "settings": { + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": [ + "${containerWorkspaceFolder}/tests/" + ], + "mypy-type-checker.args": [ + "--config-file=pyproject.toml" + ] + } + } + } +} \ No newline at end of file diff --git a/.devcontainer/skala_dev.Dockerfile b/.devcontainer/skala_dev.Dockerfile new file mode 100644 index 0000000..31e68c5 --- /dev/null +++ b/.devcontainer/skala_dev.Dockerfile @@ -0,0 +1,43 @@ +# syntax=docker/dockerfile:1 +FROM ubuntu:noble + +RUN apt-get update --quiet \ + && apt-get install --yes --quiet --no-install-recommends \ + wget \ + ca-certificates \ + && apt-get clean --yes \ + && rm -rf /var/lib/apt/lists/* + +SHELL [ "/bin/bash", "-c" ] + +# mamba installation (mamba>2.0) +ENV MAMBA_DIR=/opt/miniforge3 +ENV MAMBA_ROOT_PREFIX=${MAMBA_DIR} +ENV PATH=${PATH}:${MAMBA_DIR}/bin + +RUN wget --no-hsts --quiet --output-document=miniforge.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \ + && bash miniforge.sh -b -p ${MAMBA_DIR} \ + && rm miniforge.sh \ + # Remove python3.1 symlink if it exists, as it causes issues with conda + # https://github.com/conda/conda/issues/11423 + && (test -L ${MAMBA_DIR}/bin/python3.1 && unlink ${MAMBA_DIR}/bin/python3.1 || true) \ + && ${MAMBA_DIR}/bin/mamba update -n base --all -y \ + && mamba clean --all --yes \ + # Mamba initialization script + && echo "eval $(mamba shell hook --shell bash)" >> /etc/profile.d/source_mamba.sh \ + # for interactive shells: + && echo "source /etc/profile.d/source_mamba.sh" >> /etc/bash.bashrc + +# for non-interactive, not login shells: +# https://www.solipsys.co.uk/images/BashStartupFiles1.png +ENV BASH_ENV="/etc/profile.d/source_mamba.sh" + +# create environment (ENV_VARIANT: "cpu" or "gpu") +ARG ENV_VARIANT=cpu +COPY ./environment-${ENV_VARIANT}.yml ./environment.yml +# CONDA_OVERRIDE_CUDA lets mamba solve CUDA deps without a GPU present during build +# Cache the mamba package downloads so rebuilds after environment.yml changes are faster. +# The cache mount is not part of the image layer, so mamba clean is unnecessary. +RUN --mount=type=cache,target=${MAMBA_DIR}/pkgs \ + CONDA_OVERRIDE_CUDA="12.0" mamba env create --file environment.yml \ + && rm environment.yml \ No newline at end of file diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index ee613cf..6a68bab 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -47,6 +47,11 @@ jobs: touch docs/_build/html/.nojekyll shell: micromamba-shell {0} + - name: Check external links + run: | + sphinx-build -b linkcheck docs docs/_build/linkcheck + shell: micromamba-shell {0} + - name: Upload Pages artifact uses: actions/upload-pages-artifact@v3 with: diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 19dd7ae..636b83c 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -23,13 +23,13 @@ jobs: - name: Download checkpoint run: >- - hf download microsoft/skala-1.0 skala-1.0.fun --local-dir . + hf download microsoft/skala-1.1 skala-1.1.fun --local-dir . - name: Upload checkpoint uses: actions/upload-artifact@v4 with: name: skala-checkpoint - path: skala-1.0.fun + path: skala-1.1.fun pt-features: runs-on: ubuntu-latest @@ -153,7 +153,7 @@ jobs: run: >- Skala ./gauxc/tests/ref_data/onedft_he_def2qzvp_tpss_uks.hdf5 - --model ./skala-1.0.fun + --model ./skala-1.1.fun shell: micromamba-shell {0} ftorch: @@ -205,6 +205,6 @@ jobs: - name: Run example run: >- Skala - ./skala-1.0.fun + ./skala-1.1.fun ./features shell: micromamba-shell {0} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 51e5e21..556cea2 100644 --- a/.gitignore +++ b/.gitignore @@ -409,3 +409,6 @@ FodyWeavers.xsd # Sphinx documentation build output docs/_build/ + +# Checkpoint files +*.fun \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ad8de94..dbec3fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,24 +2,16 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.8.4 + rev: v0.15.10 hooks: # Run the linter. - id: ruff - args: [--fix, --select, I] - exclude: '^(third_party)/.*' + args: [--config, pyproject.toml, --fix] # Run the formatter. - id: ruff-format - exclude: '^(third_party)/.*' - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.16.1 - hooks: - - id: mypy - exclude: '^(third_party|tests)/.*' - args: - - --strict - - --ignore-missing-imports - - --no-namespace-packages - - --python-version=3.11 - entry: mypy + args: [--config, pyproject.toml] + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.20.0 + # hooks: + # - id: mypy + # args: [--config-file, pyproject.toml] diff --git a/README.md b/README.md index 709ed9e..6480b4b 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,11 @@ [![PyPI](https://img.shields.io/pypi/v/skala?logo=pypi&logoColor=white)](https://pypi.org/project/skala/) [![Paper](https://img.shields.io/badge/arXiv-2506.14665-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2506.14665) -Skala is a neural network-based exchange-correlation functional for density functional theory (DFT), developed by Microsoft Research AI for Science. It leverages deep learning to predict exchange-correlation energies from electron density features, achieving chemical accuracy for atomization energies and strong performance on broad thermochemistry and kinetics benchmarks, all at a computational cost similar to semi-local DFT. +Skala is a neural network-based exchange-correlation functional for density functional theory (DFT), developed by Microsoft Research AI for Science. It uses deep learning to predict exchange-correlation energies from electron density features, achieving chemical accuracy for atomization energies and strong performance on broad thermochemistry and kinetics benchmarks, all at a computational cost similar to semi-local DFT. -Trained on a large, diverse dataset—including coupled cluster atomization energies and public benchmarks—Skala uses scalable message passing and local layers to learn both local and non-local effects. The model has about 276,000 parameters and matches the accuracy of leading hybrid functionals. +Trained on a large, diverse dataset — including coupled-cluster atomization energies and public benchmarks — Skala uses scalable message passing and local layers to learn both local and non-local effects. The model has about 276,000 parameters and matches the accuracy of leading hybrid functionals. + +The recommended neural functional is `skala-1.1`, which uses per-atom packed grids, multiple non-local layers, and symmetric contraction. The legacy `skala-1.0` traced model is still loadable via `load_functional("skala-1.0")`. Learn more about Skala in our [ArXiv paper](https://arxiv.org/abs/2506.14665). @@ -15,8 +17,8 @@ Learn more about Skala in our [ArXiv paper](https://arxiv.org/abs/2506.14665). This repository contains two main components: -1. The Python package `skala`, which is also distributed [on PyPI](https://pypi.org/project/skala/) and contains a PyTorch implementation of the Skala model, its hookups to quantum chemistry packages [PySCF](https://pyscf.org/), [GPU4PySCF](https://pyscf.org/user/gpu.html) and [ASE](https://ase-lib.org/). -2. Example of using Skala applications through LibTorch and GauXC, the following examples are available +1. The Python package `skala`, distributed [on PyPI](https://pypi.org/project/skala/) and on conda-forge. It contains a PyTorch implementation of the Skala model and its bindings to the quantum-chemistry packages [PySCF](https://pyscf.org/), [GPU4PySCF](https://pyscf.org/user/gpu.html), and [ASE](https://ase-lib.org/). +2. Examples of using Skala from compiled code through LibTorch and GauXC: - [Skala in C++ with libtorch](examples/cpp/cpp_integration) - [Skala in Fortran with FTorch](https://microsoft.github.io/skala/ftorch) - [Skala in C++ with GauXC](https://microsoft.github.io/skala/gauxc/cpp-library) @@ -39,18 +41,23 @@ For detailed documentation on using GauXC visit the [Skala integration guide](ht All information below relates to the Python package `skala`. -Install using Pip: +`pip install skala` works out of the box and pulls every dependency from PyPI. +If you don't already have PyTorch installed, install the CPU-only wheel first +to avoid pulling a large CUDA build: ```bash -# Install CPU-only PyTorch (skip if you already have CPU or GPU-enabled PyTorch installed) pip install torch --index-url https://download.pytorch.org/whl/cpu pip install skala ``` -Or using Conda (Mamba): +For a reproducible conda environment, use the provided +[`environment-cpu.yml`](environment-cpu.yml), which pins CPU-only PyTorch and +all runtime dependencies: ```bash -mamba install -c conda-forge skala "pytorch=*=cpu*" +mamba env create -n skala -f environment-cpu.yml +mamba activate skala +pip install skala ``` Run an SCF calculation with Skala for a hydrogen molecule: @@ -63,21 +70,35 @@ mol = gto.M( atom="""H 0 0 0; H 0 0 1.4""", basis="def2-tzvp", ) -ks = SkalaKS(mol, xc="skala") +ks = SkalaKS(mol, xc="skala-1.1") ks.kernel() ``` -## Getting started: GPU4PySCF (GPU) +## Getting started: GPU4PySCF (GPU) -These instructions use Mamba and pip to install CUDA toolkit, Torch, and CuPy. It supports CUDA version 11, 12 or 13. You can find the most recent CUDA version that is supported on your system using `nvidia-smi`. +The GPU install is more involved because `gpu4pyscf` ships CUDA-version-specific +wheels that must match your CUDA toolkit. The recommended path is the provided +[`environment-gpu.yml`](environment-gpu.yml), which pins `pytorch-gpu`, +`cuda-toolkit` 12, `cutensor`, and installs `gpu4pyscf-cuda12x` 1.5 from PyPI: ```bash -cu_version=12 #or 11 or 13 depending on your CUDA version -mamba env create -n skala -f environment-gpu.yml "cuda-version==${cu_version}.*" skala +mamba env create -n skala -f environment-gpu.yml mamba activate skala -pip install --no-deps "gpu4pyscf-cuda${cu_version}x>=1.0,<2" "gpu4pyscf-libxc-cuda${cu_version}x>=0.4,<1" +pip install skala ``` +If you are building inside a container without a GPU attached (e.g., CI or a +Docker image built on a CPU-only host), set `CONDA_OVERRIDE_CUDA` so the solver +proceeds without a device: + +```bash +CONDA_OVERRIDE_CUDA=12.0 mamba env create -n skala -f environment-gpu.yml +``` + +For CUDA 11 or 13, adjust `cuda-toolkit`, `cuda-version`, and the +`gpu4pyscf-cuda{11,13}x` pin in `environment-gpu.yml` accordingly. Check your +driver's maximum supported CUDA version with `nvidia-smi`. + Run an SCF calculation with Skala for a hydrogen molecule on GPU: ```python @@ -88,19 +109,39 @@ mol = gto.M( atom="""H 0 0 0; H 0 0 1.4""", basis="def2-tzvp", ) -ks = SkalaKS(mol, xc="skala") +ks = SkalaKS(mol, xc="skala-1.1") ks.kernel() ``` +## Getting started: ASE calculator + +Skala also provides an [ASE](https://wiki.fysik.dtu.dk/ase/) calculator for energy, force, and geometry optimization workflows: + +```python +from ase.build import molecule +from ase.optimize import LBFGSLineSearch +from skala.ase import Skala + +atoms = molecule("H2O") +atoms.calc = Skala(xc="skala-1.1", basis="def2-tzvp") + +# Single-point energy (eV) +print(atoms.get_potential_energy()) + +# Geometry optimization +opt = LBFGSLineSearch(atoms) +opt.run(fmax=0.01) +``` + ## Documentation and examples -Go to [microsoft.github.io/skala](https://microsoft.github.io/skala) for a more detailed installation guide and further examples of how to use the Skala functional with PySCF, GPU4PySCF and ASE and in [Azure AI Foundry](https://ai.azure.com/catalog/models/Skala). +See [microsoft.github.io/skala](https://microsoft.github.io/skala) for a more detailed installation guide and further examples of how to use the Skala functional with PySCF, GPU4PySCF, and ASE, as well as in [Azure AI Foundry](https://ai.azure.com/catalog/models/Skala). ## Security: loading `.fun` files Skala model files (`.fun`) use TorchScript serialization, which can execute arbitrary code when loaded. **Never load `.fun` files from untrusted sources.** -When loading the official Skala model via `load_functional("skala")`, file integrity is automatically verified against pinned SHA-256 hashes before deserialization. If you load `.fun` files directly with `TracedFunctional.load()`, pass the `expected_hash` parameter to enable verification: +When loading the official Skala models via `load_functional("skala-1.1")` or `load_functional("skala-1.0")`, file integrity is automatically verified against pinned SHA-256 hashes before deserialization. If you load `.fun` files directly with `TracedFunctional.load()`, pass the `expected_hash` parameter to enable verification: ```python TracedFunctional.load("model.fun", expected_hash="") diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..b008765 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,2 @@ +# Artifacts generated by executing notebooks during the Sphinx build. +water_opt.traj diff --git a/docs/_static/bib/data.bib b/docs/_static/bib/data.bib index 3a1772d..1572882 100644 --- a/docs/_static/bib/data.bib +++ b/docs/_static/bib/data.bib @@ -70,7 +70,7 @@ @article{grimme2015 } @article{karton2006, - title={Comment on: ``Estimating the Hartree--Fock limit from finite basis set calculations'' [Jensen F (2005) Theor Chem Acc 113: 267]}, + title={Comment on: ``Estimating the {Hartree--Fock} limit from finite basis set calculations'' [{Jensen} {F} (2005) {Theor} {Chem} {Acc} 113: 267]}, author={Karton, Amir and Martin, Jan ML}, journal={Theoretical Chemistry Accounts}, volume={115}, @@ -91,4 +91,170 @@ @article{karton2009 year={2009}, doi={10.1080/00268970802708959}, publisher = {Taylor \& Francis} +} + +@article{dohm2018, + title={Comprehensive thermochemical benchmark set of realistic closed-shell metal organic reactions}, + author={Dohm, Sebastian and Hansen, Andreas and Steinmetz, Marc and Grimme, Stefan and Checinski, Marek P}, + journal={Journal of chemical theory and computation}, + volume={14}, + number={5}, + pages={2596--2608}, + year={2018}, + doi={10.1021/acs.jctc.7b01183}, + publisher={ACS Publications} +} + +@article{maurer2021, + title={Assessing density functional theory for chemically relevant open-shell transition metal reactions}, + author={Maurer, Leonard R and Bursch, Markus and Grimme, Stefan and Hansen, Andreas}, + journal={Journal of Chemical Theory and Computation}, + volume={17}, + number={10}, + pages={6134--6151}, + year={2021}, + doi={10.1021/acs.jctc.1c00659}, + publisher={ACS Publications} +} + +@article{semidalas2022, + title={The {MOBH35} metal--organic barrier heights reconsidered: Performance of local-orbital coupled cluster approaches in different static correlation regimes}, + author={Semidalas, Emmanouil and Martin, Jan ML}, + journal={Journal of chemical theory and computation}, + volume={18}, + number={2}, + pages={883--898}, + year={2022}, + doi={10.1021/acs.jctc.1c01126}, + publisher={ACS Publications} +} + +@article{neugebauer2023, + title={Toward benchmark-quality ab initio predictions for 3d transition metal electrocatalysts: A comparison of {CCSD(T)} and {PH-AFQMC}}, + author={Neugebauer, Hagen and Vuong, Hung T and Weber, John L and Friesner, Richard A and Shee, James and Hansen, Andreas}, + journal={Journal of Chemical Theory and Computation}, + volume={19}, + number={18}, + pages={6208--6225}, + year={2023}, + doi={10.1021/acs.jctc.3c00617}, + publisher={ACS Publications} +} + +@article{chan2019, + title={The {CUAGAU} set of coupled-cluster reference data for small copper, silver, and gold compounds and assessment of {DFT} methods}, + author={Chan, Bun}, + journal={The Journal of Physical Chemistry A}, + volume={123}, + number={27}, + pages={5781--5788}, + year={2019}, + doi={10.1021/acs.jpca.9b03976}, + publisher={ACS Publications} +} + +@article{chan2023, + title={{DAPD} Set of {Pd}-Containing Diatomic Molecules: Accurate Molecular Properties and the Great Lengths to Obtain Them}, + author={Chan, Bun}, + journal={Journal of Chemical Theory and Computation}, + volume={19}, + number={24}, + pages={9260--9268}, + year={2023}, + doi={10.1021/acs.jctc.3c01060}, + publisher={ACS Publications} +} + +@article{liang2025, + title={Gold-Standard Chemical Database 137 ({GSCDB137}): A diverse set of accurate energy differences for assessing and developing density functionals}, + author={Liang, Jiashu and Head-Gordon, Martin}, + journal={Journal of Chemical Theory and Computation}, + volume={21}, + number={24}, + pages={12601--12621}, + year={2025}, + doi={10.1021/acs.jctc.5c01380}, + publisher={ACS Publications} +} + +@article{karton2011, + title={W4-11: A high-confidence benchmark dataset for computational thermochemistry derived from first-principles {W4} data}, + author={Karton, Amir and Daon, Shauli and Martin, Jan ML}, + journal={Chemical Physics Letters}, + volume={510}, + number={4-6}, + pages={165--178}, + year={2011}, + doi={10.1016/j.cplett.2011.05.007}, + publisher={Elsevier} +} + +@article{karton2025, + title={A highly diverse and accurate database of 3366 total atomization energies calculated at the {CCSD(T)/CBS} level by means of {W1-F12} theory}, + author={Karton, Amir}, + journal={Chemical Physics Letters}, + volume={868}, + pages={142030}, + year={2025}, + doi={10.1016/j.cplett.2025.142030}, + publisher={Elsevier} +} + +@article{gasevic2025, + title={Chemical Space Exploration with Artificial ``Mindless'' Molecules}, + author={Gasevic, Thomas and Müller, Marcel and Schöps, Jonathan and Lanius, Stephanie and Hermann, Jan and Grimme, Stefan and Hansen, Andreas}, + journal={Journal of Chemical Information and Modeling}, + volume={65}, + number={18}, + pages={9576--9587}, + year={2025}, + doi={10.1021/acs.jcim.5c01364}, + publisher={ACS Publications} +} + +@article{prasad2021, + title={{BH9}, a new comprehensive benchmark data set for barrier heights and reaction energies: Assessment of density functional approximations and basis set incompleteness potentials}, + author={Prasad, Viki Kumar and Pei, Zhipeng and Edelmann, Simon and Otero-de-la-Roza, Alberto and DiLabio, Gino A}, + journal={Journal of chemical theory and computation}, + volume={18}, + number={1}, + pages={151--166}, + year={2021}, + doi={10.1021/acs.jctc.1c00694}, + publisher={ACS Publications} +} + +@article{donchev2021, + title={Quantum chemical benchmark databases of gold-standard dimer interaction energies}, + author={Donchev, Alexander G and Taube, Andrew G and Decolvenaere, Elizabeth and Hargus, Cory and McGibbon, Robert T and Law, Ka-Hei and Gregersen, Brent A and Li, Je-Luen and Palmo, Kim and Siva, Karthik and others}, + journal={Scientific data}, + volume={8}, + number={1}, + pages={55}, + year={2021}, + doi={10.1038/s41597-021-00833-x}, + publisher={Nature Publishing Group UK London} +} + +@article{taylor2016, + title={Blind test of density-functional-based methods on intermolecular interaction energies}, + author={Taylor, DeCarlos E and {\'A}ngy{\'a}n, J{\'a}nos G and Galli, Giulia and Zhang, Cui and Gygi, Francois and Hirao, Kimihiko and Song, Jong Won and Rahul, Kar and Anatole von Lilienfeld, O and Podeszwa, Rafa{\l} and others}, + journal={The Journal of chemical physics}, + volume={145}, + number={12}, + year={2016}, + doi={10.1063/1.4961095}, + publisher={AIP Publishing} +} + +@article{smith2016, + title={Revised damping parameters for the {D3} dispersion correction to density functional theory}, + author={Smith, Daniel GA and Burns, Lori A and Patkowski, Konrad and Sherrill, C David}, + journal={The journal of physical chemistry letters}, + volume={7}, + number={12}, + pages={2197--2203}, + year={2016}, + doi={10.1021/acs.jpclett.6b00780}, + publisher={ACS Publications} } \ No newline at end of file diff --git a/docs/ase.ipynb b/docs/ase.ipynb index 94facd0..8aaf4df 100644 --- a/docs/ase.ipynb +++ b/docs/ase.ipynb @@ -20,7 +20,7 @@ "import numpy as np\n", "from ase.build import molecule\n", "from ase.optimize import LBFGSLineSearch as Opt\n", - "from ase.units import Bohr, Hartree\n", + "from ase.units import Hartree\n", "\n", "from skala.ase import Skala" ] @@ -54,7 +54,7 @@ "atoms = molecule(\"H2O\")\n", "\n", "# Set up the Skala calculator with specific parameters\n", - "atoms.calc = Skala(xc=\"skala\", basis=\"def2-svp\", verbose=4)\n", + "atoms.calc = Skala(xc=\"skala-1.1\", basis=\"def2-svp\", verbose=4)\n", "\n", "# Display the calculator parameters\n", "print(\"Calculator parameters:\")\n", @@ -106,7 +106,10 @@ "source": [ "# Update calculator settings\n", "changed_params = atoms.calc.set(\n", - " with_density_fit=True, verbose=0, ks_config={\"conv_tol\": 1e-6}\n", + " with_density_fit=True,\n", + " verbose=0,\n", + " auxbasis=\"def2-universal-jkfit\",\n", + " ks_config={\"conv_tol\": 1e-6},\n", ")\n", "print(changed_params)\n", "print(f\"Changed parameters: {changed_params}\")\n", @@ -140,9 +143,11 @@ "forces = atoms.get_forces() # eV/Å\n", "\n", "print(\"Forces on atoms (eV/Å):\")\n", - "for i, (symbol, force) in enumerate(zip(atoms.get_chemical_symbols(), forces)):\n", + "for i, (symbol, force) in enumerate(\n", + " zip(atoms.get_chemical_symbols(), forces, strict=True)\n", + "):\n", " print(\n", - " f\" Atom {i+1} ({symbol}): [{force[0]:8.4f}, {force[1]:8.4f}, {force[2]:8.4f}]\"\n", + " f\" Atom {i + 1} ({symbol}): [{force[0]:8.4f}, {force[1]:8.4f}, {force[2]:8.4f}]\"\n", " )\n", "\n", "# Calculate maximum force component\n", @@ -235,9 +240,10 @@ "\n", "# Set up calculator with a larger basis set and density fitting\n", "atoms.calc = Skala(\n", - " xc=\"skala\",\n", + " xc=\"skala-1.1\",\n", " basis=\"def2-tzvp\", # Triple-zeta basis set\n", " with_density_fit=True, # Enable density fitting for efficiency\n", + " auxbasis=\"def2-universal-jkfit\",\n", " with_dftd3=True, # Include dispersion correction\n", " verbose=1,\n", ")\n", @@ -251,7 +257,7 @@ "forces = atoms.get_forces()\n", "max_force = np.max(np.abs(forces))\n", "\n", - "print(f\"\\nCalculation results:\")\n", + "print(\"\\nCalculation results:\")\n", "print(f\"Total energy: {energy:.6f} eV\")\n", "print(f\"Maximum force: {max_force:.6f} eV/Å\")" ] diff --git a/docs/conf.py b/docs/conf.py index 6865243..359ee72 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,12 @@ +import os +import sys + import skala +# Ensure CONDA_PREFIX is set so CuPy can locate the CUDA toolkit +# when notebooks are executed during the Sphinx build. +os.environ.setdefault("CONDA_PREFIX", sys.prefix) + project = "Skala" version = skala.__version__ author = "Microsoft Research, AI for Science" @@ -15,6 +22,7 @@ ] nb_execution_timeout = 300 # 5 minutes, set to -1 for no timeout +nb_execution_raise_on_error = True # Fail the build on any notebook execution error nb_merge_streams = True # Merge multiple outputs from the same cell into one box bibtex_bibfiles = [ @@ -26,5 +34,20 @@ html_title = project html_logo = "_static/img/density.png" html_favicon = "_static/img/density.png" +html_theme_options = { + "repository_url": "https://github.com/microsoft/skala", + "repository_branch": "main", + "path_to_docs": "docs", + "use_repository_button": True, +} master_doc = "index" + +suppress_warnings = ["misc.highlighting_failure"] exclude_patterns = ["_build", "jupyter_execute"] + +# DOIs 403 automated requests from linkcheck, so we ignore the `doi.org` prefix +linkcheck_ignore = [ + r"^https://doi\.org/", + # TODO: remove once arXiv v6 is posted. + r"^https://arxiv\.org/abs/2506\.14492v6$", +] diff --git a/docs/ftorch.rst b/docs/ftorch.rst index 36cb838..2b71aa6 100644 --- a/docs/ftorch.rst +++ b/docs/ftorch.rst @@ -182,7 +182,7 @@ To evaluate Skala, we download the model checkpoint from HuggingFace using the ` .. code-block:: shell - hf download microsoft/skala-1.0 skala-1.0.fun --local-dir . + hf download microsoft/skala-1.1 skala-1.1.fun --local-dir . .. note:: @@ -203,14 +203,14 @@ And run the application, passing the path to the Skala model and the feature dir .. code-block:: bash - ./build/Skala skala-1.0.fun features + ./build/Skala skala-1.1.fun features The output for the H2 molecule with the def2-QZVP basis set should look like this: .. code-block:: text - [1] Loading model from skala-1.0.fun - [2] Loading features from features + [1] Loading model from skala-1.1.fun + [2] Loading features from H2-def2qzvp -> Loading coarse_0_atomic_coords -> Loading grad -> Loading grid_coords @@ -254,14 +254,14 @@ Finally, we can run the application again to see the updated output with the com .. code-block:: bash - ./build/Skala skala-1.0.fun features + ./build/Skala skala-1.1.fun features In the output we can see the computed exchange-correlation energy as well as the mean values of the potential components, and the raw tensor data for each component. .. code-block:: text - [1] Loading model from skala-1.0.fun - [2] Loading features from features + [1] Loading model from skala-1.1.fun + [2] Loading features from H2-def2qzvp -> Loading coarse_0_atomic_coords -> Loading grad -> Loading grid_coords diff --git a/docs/gauxc/api/cmake.rst b/docs/gauxc/api/cmake.rst index 3a7d414..7b1ea05 100644 --- a/docs/gauxc/api/cmake.rst +++ b/docs/gauxc/api/cmake.rst @@ -37,7 +37,7 @@ When building with CUDA support via :cmake:variable:`GAUXC_ENABLE_CUDA` (default When building with HDF5 support via :cmake:variable:`GAUXC_ENABLE_HDF5` (default ``on``), the following dependencies are also required: - `HDF5 `__ -- `HighFive `__\ * (version 2.4.0 or higher) +- `HighFive `__\ * (version 2.4.0 or higher) All libraries marked with a * can be automatically fetched by the GauXC build system and do not need to be installed manually. diff --git a/docs/gauxc/c-library.rst b/docs/gauxc/c-library.rst index 17cc32b..ec3cc6a 100644 --- a/docs/gauxc/c-library.rst +++ b/docs/gauxc/c-library.rst @@ -507,8 +507,8 @@ After downloading the model checkpoint we can run our driver again with the new .. code-block:: shell - hf download microsoft/skala-1.0 skala-1.0.fun --local-dir . - ./build/Skala He_def2-svp.h5 --model ./skala-1.0.fun + hf download microsoft/skala-1.1 skala-1.1.fun --local-dir . + ./build/Skala He_def2-svp.h5 --model ./skala-1.1.fun In the output we can see the results for the Skala functional @@ -516,14 +516,14 @@ In the output we can see the results for the Skala functional Configuration -> Input file : He_def2-svp.h5 - -> Model : ./skala-1.0.fun + -> Model : ./skala-1.1.fun -> Grid : fine -> Radial quadrature : muraknowles -> Pruning scheme : robust Results - -> EXC : -1.0712560886 - -> |VXC(a+b)|_F : 1.5002997528 + -> EXC : -1.0646206500 + -> |VXC(a+b)|_F : 1.4893029702 -> |VXC(a-b)|_F : 0.0000000000 Full source code diff --git a/docs/gauxc/cpp-library.rst b/docs/gauxc/cpp-library.rst index 9734205..4935b08 100644 --- a/docs/gauxc/cpp-library.rst +++ b/docs/gauxc/cpp-library.rst @@ -413,8 +413,8 @@ After downloading the model checkpoint we can run our driver again with the new .. code-block:: shell - hf download microsoft/skala-1.0 skala-1.0.fun --local-dir . - ./build/Skala He_def2-svp.h5 --model ./skala-1.0.fun + hf download microsoft/skala-1.1 skala-1.1.fun --local-dir . + ./build/Skala He_def2-svp.h5 --model ./skala-1.1.fun In the output we can see the results for the Skala functional @@ -422,13 +422,13 @@ In the output we can see the results for the Skala functional Configuration -> Input file : He_def2-svp.h5 - -> Model : ./skala-1.0.fun + -> Model : ./skala-1.1.fun -> Grid : fine -> Radial quadrature : muraknowles -> Pruning scheme : robust - EXC = -1.071256087389e+00 Eh - |VXC(a+b)|_F = 1.500299750739e+00 + EXC = -1.064620650033e+00 Eh + |VXC(a+b)|_F = 1.489302970205e+00 |VXC(a-b)|_F = 0.000000000000e+00 Runtime XC = 1.792662281000e+00 s diff --git a/docs/gauxc/fortran-library.rst b/docs/gauxc/fortran-library.rst index 0c20748..f146af0 100644 --- a/docs/gauxc/fortran-library.rst +++ b/docs/gauxc/fortran-library.rst @@ -560,8 +560,8 @@ from the ``huggingface_hub`` package: .. code-block:: shell - hf download microsoft/skala-1.0 skala-1.0.fun --local-dir . - ./build/Skala He_def2-svp.h5 --model ./skala-1.0.fun + hf download microsoft/skala-1.1 skala-1.1.fun --local-dir . + ./build/Skala He_def2-svp.h5 --model ./skala-1.1.fun The output shows results for the Skala functional: @@ -569,14 +569,14 @@ The output shows results for the Skala functional: Configuration -> Input file : He_def2-svp.h5 - -> Model : ./skala-1.0.fun + -> Model : ./skala-1.1.fun -> Grid : fine -> Radial quadrature : muraknowles -> Pruning scheme : robust Results - Exc = -1.0712560874E+00 Eh - |VXC(a+b)|_F = 1.5002997546E+00 + Exc = -1.0646206500E+00 Eh + |VXC(a+b)|_F = 1.4893029702E+00 |VXC(a-b)|_F = 0.0000000000E+00 Runtime XC = 1.5986489670E+00 diff --git a/docs/index.rst b/docs/index.rst index 530a60b..61e49dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,9 +7,6 @@ Overview *Skala* is a deep-learning exchange-correlation functional designed to provide chemical accuracy (less than 1 kcal/mol) for a wide range of chemical systems without using expensive non-local features like exact exchange or hand-crafted density convolutions. The model is trained on a large dataset of highly accurate total atomization energies, thermochemical properties, like ionization potentials and proton affinities, conformer energies, reaction paths, and non-covalent interactions. -*Skala* is still in active development, we are working on improving the model accuracy and also on integrating it into quantum chemistry packages. -Please stay tuned for updates and new releases. - .. admonition:: Learn more :class: important @@ -36,6 +33,7 @@ Please stay tuned for updates and new releases. :caption: References :hidden: - model-card/skala-1.0 + model-card/skala-1.1 + model-card/index Skala preprint Breaking bonds, breaking ground diff --git a/docs/installation.rst b/docs/installation.rst index f143c79..00e570b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -72,18 +72,29 @@ If you prefer to install Skala from the source code, you can clone the repositor mamba activate skala pip install -e . -where `environment-cpu.yml` can be replaced with `environment-gpu.yml` for gpu support (specify CUDA version with `cuda_version=`) with gpu4pyscf, in which case gpu4pyscf needs to be separately installed *after creating the environment* via (for CUDA 12) +where ``environment-cpu.yml`` can be replaced with ``environment-gpu.yml`` for +GPU support via `GPU4PySCF `__. The GPU +environment pins ``cuda-toolkit 12``, ``cuda-version 12``, ``cutensor``, and +installs ``gpu4pyscf-cuda12x 1.5`` from PyPI as part of the environment file — +no separate install step is required: .. code-block:: bash - pip install --no-deps 'gpu4pyscf-cuda12x>=1.0,<2' 'gpu4pyscf-libxc-cuda12x>=0.4,<1' - + mamba env create -n skala -f environment-gpu.yml + mamba activate skala + pip install -e . -or (for CUDA 13) +If you are building inside a container without a GPU attached (for example CI, +or a Docker image built on a CPU-only host), set ``CONDA_OVERRIDE_CUDA`` so the +solver proceeds without a device: .. code-block:: bash - pip install --no-deps 'gpu4pyscf-cuda13x>=1.0,<2' 'gpu4pyscf-libxc-cuda13x>=0.4,<1' + CONDA_OVERRIDE_CUDA=12.0 mamba env create -n skala -f environment-gpu.yml + +For CUDA 11 or 13, adjust ``cuda-toolkit``, ``cuda-version``, and the +``gpu4pyscf-cuda{11,13}x`` pin in ``environment-gpu.yml`` accordingly. Check +your driver's maximum supported CUDA version with ``nvidia-smi``. To install the development dependencies, you can run: diff --git a/docs/model-card/index.rst b/docs/model-card/index.rst new file mode 100644 index 0000000..0e4fb9b --- /dev/null +++ b/docs/model-card/index.rst @@ -0,0 +1,11 @@ +Older models +============ + +The latest version of the Skala model is 1.1, which is described in :doc:`skala-1.1`. +Previous versions of the model, such as 1.0, are still available for reference and can be found here. + +.. toctree:: + :maxdepth: 1 + :caption: Older models + + skala-1.0 \ No newline at end of file diff --git a/docs/model-card/skala-1.0.rst b/docs/model-card/skala-1.0.rst index b47df8c..02b1e16 100644 --- a/docs/model-card/skala-1.0.rst +++ b/docs/model-card/skala-1.0.rst @@ -1,5 +1,5 @@ -Skala model -=========== +Skala 1.0 model +=============== Model details ------------- @@ -85,9 +85,9 @@ The following data is included in our training set: - Four datasets from the `NCI-Atlas collection of non-covalent interactions `__: - `D442x10 `__, dissociation curves for dispersion-bound van der Waals complexes - - `SH250x10 `__, dissociation curves for sigma-hole-bound van der Waals complexes - - `R739x5 `__, compressed van der Waals complexes - - `HB300SPXx10 `__, dissociation curves for hydrogen-bound van der Waals complexes + - `SH250x10 `__, dissociation curves for sigma-hole-bound van der Waals complexes + - `R739x5 `__, compressed van der Waals complexes + - `HB300SPXx10 `__, dissociation curves for hydrogen-bound van der Waals complexes - W4-CC, containing atomization energies of carbon clusters.\ :footcite:`karton2009` diff --git a/docs/model-card/skala-1.1.rst b/docs/model-card/skala-1.1.rst new file mode 100644 index 0000000..5cdbfd3 --- /dev/null +++ b/docs/model-card/skala-1.1.rst @@ -0,0 +1,233 @@ +Skala 1.1 model +=============== + +Model details +------------- + +In pursuit of the universal functional for density functional theory (DFT), the OneDFT team from Microsoft Research AI for Science has developed the Skala-1.1 exchange-correlation functional, as introduced in `Accurate and scalable exchange-correlation with deep learning, Luise et al. 2025 `__. +This approach departs from the traditional route of incorporating increasingly expensive hand-designed non-local features from Jacob's ladder into functional forms to improve their accuracy. +Instead, we employ a deep learning approach with a scalable neural network that uses only inexpensive input features to learn the necessary non-local representations. + +The functional is based on a neural network architecture that takes as input features on a 3D grid describing the electron density and derived meta-generalized-gradient (meta-GGA) quantities. +The architecture performs scalable non-local message-passing on the integration grid via a second, coarser grid, combined with shared local layers that enable representation learning of both local and non-local features. +These representations are then used to predict the exchange-correlation energy in an end-to-end data-driven manner. + +To facilitate this learning, the model is trained on a dataset of unprecedented size, containing highly accurate energy labels from coupled cluster theory. +The largest subset focuses on atomization energies and was generated in collaboration with the University of New England. +This subset is released as part of the Microsoft Research Accurate Chemistry Collection (MSR-ACC, `Accurate Chemistry Collection: Coupled cluster atomization energies for broad chemical space, Ehlert et al. 2025 `__). +To broaden coverage of other types of chemistry, the training dataset is further complemented with in-house generated datasets covering conformers, ionization potentials, electron affinities, proton affinities, noncovalent interactions, distorted equilibrium geometries, and elementary reactions, as well as a small amount of publicly available high-accuracy data. + +We demonstrate that departure from the historical trade-off between accuracy and efficiency is enabled by learning non-local representations of electronic structure directly from data, bypassing the need for increasingly costly hand-engineered features. +The Skala-1.1 functional surpasses state-of-the-art hybrid functionals in accuracy across the main-group chemistry benchmark set GMTKN55, which covers general main-group thermochemistry, kinetics, and noncovalent interactions, with an error of 2.72 kcal/mol, while retaining the lower computational cost characteristic of semi-local DFT. +With this work, we demonstrate the viability of our approach toward the universal density functional across all of chemistry. + +Users of this model are expected to have a basic understanding of the field of quantum chemistry and density functional theory. + +:Developed by: + Chin-Wei Huang, Deniz Gunceler, Derk Kooi, Gregor Simm, Klaas Giesbertz, Giulia Luise, Jan Hermann, Megan Stanley, Paola Gori Giorgi, P. Bernát Szabó, Rianne van den Berg, Sebastian Ehlert, Stefano Battaglia, Stephanie Lanius, Thijs Vogels, Wessel Bruinsma + +:Shared by: + Microsoft Research AI for Science + +:Model type: + Neural Network Density Functional Theory Exchange Correlation Functional + +:License: + MIT + + +Direct intended uses +-------------------- + +#. The Skala-1.1 functional is shared with the research community to facilitate reproduction of the evaluations presented in our paper. +#. Evaluating reaction energy differences by computing the total energy of all compounds in a reaction using a self-consistent field (SCF) calculation with the Skala-1.1 exchange-correlation functional. +#. Evaluating the total energy of a molecule using an SCF calculation with the Skala-1.1 exchange-correlation functional. Note that, as with all density functionals, energy differences are predicted much more reliably than total energies of individual molecules. +#. The SCF implementation provided uses PySCF, which runs the functional on CPU. We also provide a traced version of the Skala-1.1 functional so that other, more optimized open-source SCF codes—including GPU-enabled ones—can integrate it into their pipelines, for instance through GauXC. A compatible fork of GauXC is included in this repository. + +Out-of-scope uses +----------------- + +#. Evaluating the functional with a single pass given a fixed density as input is not the intended way to evaluate the model. The model's predictions should always be made by using it as part of an SCF procedure. +#. We do not include a training pipeline for the Skala-1.1 functional in this code base. + +Risks and limitations +--------------------- + +#. Interpretation of results requires expertise in quantum chemistry. +#. The Skala-1.1 functional is trained on atomization energies, conformers, proton affinities, ionization potentials, electron affinities, elementary reaction pathways, distorted equilibrium geometries, and non-covalent interactions, as well as a small amount of total energies of atoms and transition metal atoms and dimer properties. We have benchmarked performance on W4-17 for atomization energies and on GMTKN55, which covers general main-group thermochemistry, kinetics, and noncovalent interactions, to provide an indication of generalization beyond the training set. We have also evaluated robustness on dipole moment predictions and geometry optimization. +#. The Skala-1.1 functional has been trained on data containing the following elements: H--Xe. It has been tested on data containing H--Xe, Pb, and Bi. +#. Given points 2 and 3 above, this is not a production model. We advise testing the functional further before applying it to your research and welcome any feedback. + +Recommendations +--------------- + +#. In our PySCF-based SCF implementation, the largest system tested contained 180 atoms using the def2-TZVP basis set (:math:`\sim`\ 5000 orbitals) on `Eadsv5 series `__ virtual machines. Larger systems may run out of memory. +#. For implementations optimized for memory, speed, or GPU support, we recommend integrating the functional with other open-source SCF packages, for instance through GauXC. A compatible fork of GauXC is included in this repository. +#. Skala-1.1 will also be available through `Azure AI Foundry `__, where it is coupled with Microsoft's GPU-accelerated `Accelerated DFT `__ application. + + +Training details +---------------- + +Training data +~~~~~~~~~~~~~ + +The following data is included in our training set: + +:MSR-ACC: + 99% of MSR-ACC/TAE25 (~78k reactions) containing atomization energies for up to five non-hydrogen atoms. + This data was generated in collaboration with Prof. Amir Karton, University of New England, with the W1-F12 composite protocol based on CCSD(T) and is released as part of the `Microsoft Research Accurate Chemistry Collection `__ (MSR-ACC). + Additionally the MSR-ACC subsets for larger TAEs (up to 9 non-hydrogen atoms), conformers, ionization potentials, electron affinities, proton affinities, reaction paths, and distorted equilibrium structures were included. + The labels for these data sets are obtained with the W1w method and are part of the currently unpublished subsets of the MSR-ACC. + +:Atomic Data: + Total energies, electron affinities, and ionization potentials (up to triple ionization) for atoms, from H to Ar (excluding Li and Be due to basis-set constraints). + This data was produced in-house with CCSD(T) by extrapolating to the complete basis set limit from quadruple zeta (QZ) and pentuple zeta (5Z) calculations. + The basis sets used for H and He were aug-cc-pV(Q+d)Z and aug-cc-pV(5+d), while for the remaining elements B–Ar the basis sets were aug-cc-pCVQZ and aug-cc-pCV5Z. + All basis sets were obtained from the `Basis Set Exchange (BSE) `__. + Extrapolation of the correlation energy was performed by fitting a :math:`Z^{-3}` expression, while the Hartree–Fock energy was extrapolated using the two-point scheme of :footcite:`karton2006`. + +:Transition metal properties: + Additional data for transition metal atoms and dimers, including ionization potentials, spin splittings, and dissociation energies. + The reference energies were obtained from literature. + +:NCI-Atlas: + Five datasets from the `NCI-Atlas collection of non-covalent interactions `__: + + - `D442x10 `__, dissociation curves for dispersion-bound van der Waals complexes + - `SH250x10 `__, dissociation curves for sigma-hole-bound van der Waals complexes + - `R739x5 `__, compressed van der Waals complexes + - `HB300SPXx10 `__, dissociation curves for hydrogen-bound van der Waals complexes + - `IHB100x10 `__, dissociation curves for ionic hydrogen-bound van der Waals complexes + +:BH9: + Reactions and barrier heights.\ :footcite:`prasad2021` + The data set was filtered for systems with up to ten non-hydrogen atoms. + +:NCIBLIND: + Data set of non-covalent dissociation curves.\ :footcite:`taylor2016` + +:Water2510: + Data set of the potential energy surface of the water dimer.\ :footcite:`smith2016` + The data set was fully relabeled with W1w. + +:DES370k: + Subset with CCSD(T)/dCBS(aug-cc-pVQZ) non-covalent interaction energies.\ :footcite:`donchev2021` + +:MB2061: + Dataset containing decomposition energies of artificial molecules.\ :footcite:`gasevic2025` + +:W4-CC: + Containing atomization energies of carbon clusters.\ :footcite:`karton2009` + +For all training data, input density and derived meta-GGA features were computed from density matrices of converged B3LYP SCF calculations (def2-QZVP and ma-def2-QZVP basis sets) using a modified version of PySCF. + +Training procedure +~~~~~~~~~~~~~~~~~~ + +Preprocessing +^^^^^^^^^^^^^ + +The training datapoints are preprocessed as follows. + +- For each molecule, the density and derived meta-GGA features are computed from the density matrix of a converged B3LYP SCF calculation using a def2-QZVP or ma-def2-QZVP basis set in a modified version of PySCF. +- Density fitting was not applied. +- The density features were evaluated on an atom-centered integration grid of level 1. +- The radial quadrature was performed with Treutler-Ahlrichs, Gauss-Chebyshev, Delley, or Mura-Knowles schemes based on Bragg atomic radii with Treutler-based radii adjustment. +- The space-partitioning was performed with Becke partition and Treutler-Ahlrichs radii adjustment, Stratmann-Scuseria-Frisch (SSF) partition scheme, and Laqua-Kussmann-Ochsenfeld (LKO) partition scheme. +- The angular grid points were pruned using the NWChem scheme. +- No density-based cutoff was applied; all grid points were retained for training. + +Training hyperparameters +^^^^^^^^^^^^^^^^^^^^^^^^ + +The training hyperparameter settings are detailed in the supplementary material of `Accurate and scalable exchange-correlation with deep learning, Luise et al. 2025 `__. +This repository only includes the code to evaluate the provided checkpoints, not the training code. + +Speeds, sizes, times +^^^^^^^^^^^^^^^^^^^^ + +The training of the functional on the dataset described above took approximately 48 hours for 1M steps on an `ND A100 v4 series VM `__ with 8 NVIDIA A100 GPUs (80 GB each), 96 CPU cores, 880 GB RAM, and a 6 TB disk. + +The model checkpoints have :math:`\sim`\ 385k trainable parameters. + +Evaluation +---------- + +Testing data, factors, and metrics +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We have evaluated our functional on several different benchmark sets: + +#. W4-17. A diverse and highly accurate dataset of atomization energies.\ :footcite:`karton2017` +#. Transition metal data sets including MOR41,\ :footcite:`dohm2018` ROST61,\ :footcite:`maurer2021` MOBH35,\ :footcite:`semidalas2022` 3dTMV,\ :footcite:`neugebauer2023` CuAgAu83,\ :footcite:`chan2019` DAPd,\ :footcite:`chan2023` 3d4dIPSS, TMB11, and TMD10.\ :footcite:`liang2025` +#. GMTKN55. A diverse and highly accurate dataset of general main-group thermochemistry, kinetics, and noncovalent interactions.\ :footcite:`goerigk2017` +#. Geometry optimization datasets: (a) CCse21, equilibrium structures, bond lengths, and bond angles;\ :footcite:`piccardo2015` (b) HMGB11, equilibrium structures and bond lengths;\ :footcite:`grimme2015` (c) LMGB35, equilibrium structures and bond lengths;\ :footcite:`grimme2015` and (d) W4-11-GEOM, equilibrium structures, bond lengths, and bond angles.\ :footcite:`karton2011` +#. The dipole benchmark dataset from :footcite:`hait2018`. +#. Conformer search benchmark dataset of 22 molecules spanning 24 to 176 atoms, used for cost-scaling analysis, from :footcite:`grimme2019`. + +These six benchmark types serve to measure different performance aspects of the functional. +Benchmarks 1 and 2 focus on the accuracy of predicted reaction energies. +Benchmark 3 evaluates general main-group thermochemistry, kinetics, and noncovalent interactions. +Benchmark 4 evaluates geometry optimization and convergence to reference equilibrium structures. +Benchmark 5 measures dipole moments, providing a proxy for the quality of the self-consistent electron density produced by the SCF procedure. +Finally, benchmark 6 assesses computational cost scaling with respect to system size. + +The metrics for the different benchmark sets are: + +#. Mean Absolute Error (MAE) in kcal/mol for reactions in W4-17 :math:`MAE = \frac{1}{N} \sum_{r=1}^N |\Delta E_r - \Delta E_r^\theta|`. Here *N* is the number of reactions in W4-17, *r* is the index denoting reactions in W4-17, :math:`\Delta E_r` is the energy difference of reaction r as calculated by a high-accuracy method from the W4 family (CCSDT(Q)/CBS to CCSDTQ56/CBS), and :math:`\Delta E_r^\theta` is the prediction of the reaction energy difference using SCF calculations with our functional. +#. Weighted total mean absolute deviations 2 (WTMAD-2) in kcal/mol for the GMTKN55 benchmark set :math:`\text{WTMAD-2} = \frac1{\sum^{55}_{i=1} N_i} \sum_{i=1}^{55} N_i \frac{56.84\text{ kcal/mol}}{\overline{|\Delta E|}_i} \text{MAE}_i` Here :math:`N_i` is the number of reactions in subset *i*, :math:`\overline{|\Delta E|}_i` is the average energy difference in subset *i* in kcal/mol, and :math:`\text{MAE}_i` is the mean absolute error in kcal/mol for subset *i*. +#. For the geometry benchmark sets that report bond lengths, we measure the absolute error in bond lengths in Angstrom, averaged over the number of bonds and the number of equilibrium structures in the dataset. For the benchmark that also contains bond angles, we report the absolute error of the angles, averaged over the number of bonds and equilibrium structures in the dataset. +#. For the dipole benchmark, we follow the metrics defined in :footcite:`hait2018`. For molecules (indexed by *i*) for which only the reference magnitude of the dipole moment :math:`\mu_i^{\text{ref}} = |{\vec\mu}_i^{\text{ref}}|` is provided, the error is defined as :math:`\text{Error}_i = \frac{\mu_i^\theta - \mu_i^\text{ref}}{\max(\mu_i^\text{ref}, 1D)} \times 100\%`, where :math:`\mu_i^{\theta} = |{\vec\mu}_i^{\theta}|` is the predicted magnitude and *D* denotes the unit of Debye. For molecules for which the reference dipole vector :math:`\vec{\mu}_i^\text{ref}` is also available, we instead compute :math:`\text{Error}_i = \frac{|\vec{\mu}_i^\theta - \vec{\mu}_i^\text{ref}|}{\max(\mu_i^\text{ref}, 1D)} \times 100\%`. The RMSE is then :math:`\text{RMSE} = \sqrt{\frac{1}{N} \sum_{i=1}^N \text{Error}_i^2}`. +#. We fit a power law of the form :math:`C(M) = \left(\frac{n(M)}{A}\right)^k` to the 22 data points of the test set where *C(M)* and *n(M)* are the computational cost and number of atoms of molecule *M*, respectively, and *A* and *k* are fitted parameters. We report the scaling power *k* as the main metric. + +Evaluation results +~~~~~~~~~~~~~~~~~~ + +On W4-17, the Skala-1.1 functional predicts atomization energies at chemical accuracy (:math:`\sim`\ 1 kcal/mol MAE). +On GMTKN55, which covers general main-group thermochemistry, kinetics, and noncovalent interactions, it achieves a WTMAD-2 of 2.72 kcal/mol, surpassing state-of-the-art range-separated hybrid functionals while only requiring runtimes typical of semi-local DFT. + +On the geometry optimization benchmarks, the functional converges to reference equilibrium structures with errors comparable to a range-separated hybrid functional. +On the dipole prediction benchmark, the error in dipole moment predictions is better than that of state-of-the-art range-separated hybrid functionals. + +Finally, the scaling results show that the Skala-1.1 functional exhibits the asymptotic scaling behavior of a meta-GGA functional, with an approximate prefactor of 3 relative to r2SCAN. + +License +------- + +.. dropdown:: MIT License + + .. literalinclude:: ../../LICENSE.txt + :lines: 3- + +Citation +-------- + +When using Skala-1.1 in your research, please reference it including the version number as follows: + + This work uses the Skala-1.1 functional. + +.. code:: bibtex + + @misc{luise2025, + title={Accurate and scalable exchange-correlation with deep learning}, + author={Giulia Luise and Chin-Wei Huang and Thijs Vogels and Derk P. Kooi and Sebastian Ehlert and Stephanie Lanius and Klaas J. H. Giesbertz and Amir Karton and Deniz Gunceler and Megan Stanley and Wessel P. Bruinsma and Lin Huang and Xinran Wei and José Garrido Torres and Abylay Katbashev and Rodrigo Chavez Zavaleta and Bálint Máté and Sékou-Oumar Kaba and Roberto Sordillo and Yingrong Chen and David B. Williams-Young and Christopher M. Bishop and Jan Hermann and Rianne van den Berg and Paola Gori-Giorgi}, + year={2025}, + eprint={2506.14665}, + archivePrefix={arXiv}, + primaryClass={physics.chem-ph}, + url={https://arxiv.org/abs/2506.14665}, + } + +Model card contact +------------------ + +- Rianne van den Berg, `rvandenberg@microsoft.com `_ +- Paola Gori-Giorgi, `pgorigiorgi@microsoft.com `_ +- Jan Hermann, `jan.hermann@microsoft.com `_ +- Sebastian Ehlert, `sehlert@microsoft.com `_ + +References +---------- + +.. footbibliography:: \ No newline at end of file diff --git a/docs/pyscf/gpu4pyscf.rst b/docs/pyscf/gpu4pyscf.rst index 6fde1ab..139ab72 100644 --- a/docs/pyscf/gpu4pyscf.rst +++ b/docs/pyscf/gpu4pyscf.rst @@ -13,7 +13,7 @@ The Skala functional can also be used in GPU4PySCF with an appropriate PyTorch C atom="""H 0 0 0; H 0 0 1.4""", basis="def2-tzvp", ) - ks = SkalaKS(mol, xc="skala") + ks = SkalaKS(mol, xc="skala-1.1") ks.kernel() print(ks.dump_scf_summary()) @@ -22,23 +22,19 @@ The Skala functional can also be used in GPU4PySCF with an appropriate PyTorch C Installation ------------ -Install the latest version of the ``skala`` package from PyPI or conda-forge together with a compatible PyTorch CUDA version. -For pip the default PyTorch installation will be used, which is typically the latest version with CUDA support. +The recommended way to set up a GPU environment is the provided +``environment-gpu.yml``, which pins ``pytorch-gpu``, ``cuda-toolkit 12``, +``cuda-version 12``, ``cutensor``, and installs ``gpu4pyscf-cuda12x 1.5`` from +PyPI as part of the environment file: .. code-block:: bash - pip install skala cupy-cuda12x cutensor-cuda12x + mamba env create -n skala -f environment-gpu.yml + mamba activate skala + pip install skala -For conda-forge, select the pytorch CUDA version that matches your system and CUDA installation. For example, for CUDA 12.8: +For CUDA 11 or 13, adjust ``cuda-toolkit``, ``cuda-version``, and the +``gpu4pyscf-cuda{11,13}x`` pin in ``environment-gpu.yml`` accordingly. -.. code-block:: bash - - mamba install -c conda-forge skala 'cuda-version=12.*' 'pytorch=*=cuda*' cupy cutensor - -In both cases you need to install GPU4PySCF separately from PyPI. - -.. code-block:: bash - - pip install --no-deps "gpu4pyscf-cuda12x>=1.0,<2" "gpu4pyscf-libxc-cuda12x>=0.4,<1" - -We are using ``--no-deps`` to avoid overriding the already installed cupy and cutensor packages in the previous step. \ No newline at end of file +See the :doc:`installation guide ` for more details, including +how to install from conda-forge or inside a container without a GPU attached. \ No newline at end of file diff --git a/docs/pyscf/scf_settings.ipynb b/docs/pyscf/scf_settings.ipynb index e7b11b7..34c2e7e 100644 --- a/docs/pyscf/scf_settings.ipynb +++ b/docs/pyscf/scf_settings.ipynb @@ -72,7 +72,7 @@ } ], "source": [ - "ks = SkalaKS(mol, xc=\"skala\")\n", + "ks = SkalaKS(mol, xc=\"skala-1.1\")\n", "ks.kernel()\n", "\n", "ks.dump_scf_summary()" @@ -129,7 +129,7 @@ " \"diis_start_cycle\": 1,\n", " \"level_shift\": 0.0,\n", "}\n", - "ks = SkalaKS(mol, ks_config=ks_config, xc=\"skala\")" + "ks = SkalaKS(mol, ks_config=ks_config, xc=\"skala-1.1\")" ] }, { @@ -190,4 +190,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/docs/pyscf/singlepoint.ipynb b/docs/pyscf/singlepoint.ipynb index 7a31ce6..8313c37 100644 --- a/docs/pyscf/singlepoint.ipynb +++ b/docs/pyscf/singlepoint.ipynb @@ -29,7 +29,7 @@ "metadata": {}, "source": [ "The Kohn-Sham calculator for the Skala functional is created from a regular PySCF molecule object.\n", - "By specifying the `xc` parameter as `\"skala\"`, the Skala functional is automatically loaded and used for the calculations." + "By specifying the `xc` parameter as `\"skala-1.1\"`, the Skala functional is automatically loaded and used for the calculations." ] }, { @@ -43,7 +43,7 @@ " atom=\"\"\"H 0 0 0; H 0 0 1.4\"\"\",\n", " basis=\"def2-tzvp\",\n", ")\n", - "ks = SkalaKS(mol, xc=\"skala\")\n", + "ks = SkalaKS(mol, xc=\"skala-1.1\")\n", "ks.kernel()\n", "\n", "print(ks.dump_scf_summary())" @@ -72,7 +72,9 @@ " atom=\"\"\"H 0 0 0; H 0 0 1.4\"\"\",\n", " basis=\"def2-tzvp\",\n", ")\n", - "ks = SkalaKS(mol, xc=\"skala\", with_density_fit=True)\n", + "ks = SkalaKS(\n", + " mol, xc=\"skala-1.1\", with_density_fit=True, auxbasis=\"def2-universal-jkfit\"\n", + ")\n", "ks.kernel()\n", "\n", "print(ks.dump_scf_summary())" @@ -99,7 +101,13 @@ " atom=\"\"\"H 0 0 0; H 0 0 1.4\"\"\",\n", " basis=\"def2-tzvp\",\n", ")\n", - "ks = SkalaKS(mol, xc=\"skala\", with_density_fit=True, with_newton=True)\n", + "ks = SkalaKS(\n", + " mol,\n", + " xc=\"skala-1.1\",\n", + " with_density_fit=True,\n", + " auxbasis=\"def2-universal-jkfit\",\n", + " with_newton=True,\n", + ")\n", "ks.kernel()\n", "\n", "print(ks.dump_scf_summary())" diff --git a/environment-cpu.yml b/environment-cpu.yml index c7eb266..9818850 100644 --- a/environment-cpu.yml +++ b/environment-cpu.yml @@ -10,7 +10,7 @@ dependencies: - h5py - numpy - opt_einsum_fx - - pyscf + - pyscf >=2.8,<2.14 - python - pytorch * cpu_* - qcelemental @@ -18,5 +18,8 @@ dependencies: - pre-commit - pytest - pytest-cov + - pytest-randomly + - ruff + - mypy - pip: - huggingface_hub diff --git a/environment-gpu.yml b/environment-gpu.yml index 3d7e43a..30cd58a 100644 --- a/environment-gpu.yml +++ b/environment-gpu.yml @@ -10,17 +10,20 @@ dependencies: - h5py - numpy - opt_einsum_fx - - pyscf + - pyscf >=2.8,<2.14 - python - pytorch-gpu - qcelemental - - cuda-toolkit - - cupy - - cutensor - - cuda-version + - cuda-toolkit >=12,<13 + - cuda-version >=12,<13 + - cutensor >=2 # Testing and development - pre-commit - pytest - pytest-cov + - pytest-randomly + - ruff + - mypy - pip: - huggingface_hub + - gpu4pyscf-cuda12x ==1.5 diff --git a/examples/cpp/cpp_integration/README.md b/examples/cpp/cpp_integration/README.md index 3e00b59..fc0110b 100644 --- a/examples/cpp/cpp_integration/README.md +++ b/examples/cpp/cpp_integration/README.md @@ -42,7 +42,7 @@ python ./prepare_inputs.py --output-dir H2 Finally, run $E_\text{xc}$ and (partial) $V_\text{xc}$ computations with the C++ example: ```bash -./_build/skala_cpp_integration skala-1.0.fun H2 +./_build/skala_cpp_integration skala-1.1.fun H2 ``` **Note:** You are expected to add D3 dispersion correction (using b3lyp settings) to the final energy of Skala. diff --git a/examples/cpp/cpp_integration/download_model.py b/examples/cpp/cpp_integration/download_model.py index d63de85..89574c7 100755 --- a/examples/cpp/cpp_integration/download_model.py +++ b/examples/cpp/cpp_integration/download_model.py @@ -38,7 +38,7 @@ def main() -> None: for huggingface_repo_id, filename in ( - ("microsoft/skala-1.0", "skala-1.0.fun"), + ("microsoft/skala-1.1", "skala-1.1.fun"), ("microsoft/skala-baselines", "ldax.fun"), ): output_path = filename.split("/")[-1] diff --git a/examples/cpp/cpp_integration/prepare_inputs.py b/examples/cpp/cpp_integration/prepare_inputs.py index 26be867..c5d4562 100755 --- a/examples/cpp/cpp_integration/prepare_inputs.py +++ b/examples/cpp/cpp_integration/prepare_inputs.py @@ -8,7 +8,11 @@ from pyscf.dft import gen_grid from skala.functional.traditional import LDA -from skala.pyscf.features import generate_features +from skala.pyscf.features import ( + _ATOMIC_GRID_FEATURES, + DEFAULT_FEATURES_SET, + generate_features, +) def main() -> None: @@ -39,8 +43,10 @@ def main() -> None: dm = get_density_matrix(molecule) grid = gen_grid.Grids(molecule) grid.level = 3 - grid.build() - features = generate_features(molecule, dm, grid) + grid.build(sort_grids=False) + features = generate_features( + molecule, dm, grid, features=DEFAULT_FEATURES_SET | _ATOMIC_GRID_FEATURES + ) # Add a feature called `coarse_0_atomic_coords` containing the atomic coordinates. features["coarse_0_atomic_coords"] = torch.from_numpy(molecule.atom_coords()) diff --git a/examples/fortran/ftorch_integration/app/main.f90 b/examples/fortran/ftorch_integration/app/main.f90 index 60e935e..f6528af 100644 --- a/examples/fortran/ftorch_integration/app/main.f90 +++ b/examples/fortran/ftorch_integration/app/main.f90 @@ -10,7 +10,8 @@ program main type(skala_model) :: model type(skala_dict) :: input, vxc type(torch_tensor) :: exc - type(torch_tensor) :: density, grad, kin, grid_coords, grid_weights, coarse_0_atomic_coords + type(torch_tensor) :: density, grad, kin, grid_coords, grid_weights, coarse_0_atomic_coords, & + & atomic_grid_weights, atomic_grid_sizes, atomic_grid_size_bound_shape type(torch_tensor) :: dexc_ddensity, dexc_dgrad, dexc_dkin, dexc_dgrid_coords, & & dexc_dgrid_weights, dexc_dcoarse_0_atomic_coords, vxc_norm @@ -55,6 +56,15 @@ program main case(skala_feature%coarse_0_atomic_coords) print '(a)', " -> Loading coarse_0_atomic_coords" call skala_tensor_load(coarse_0_atomic_coords, feature_dir//"/coarse_0_atomic_coords.pt") + case(skala_feature%atomic_grid_weights) + print '(a)', " -> Loading atomic_grid_weights" + call skala_tensor_load(atomic_grid_weights, feature_dir//"/atomic_grid_weights.pt") + case(skala_feature%atomic_grid_sizes) + print '(a)', " -> Loading atomic_grid_sizes" + call skala_tensor_load(atomic_grid_sizes, feature_dir//"/atomic_grid_sizes.pt") + case(skala_feature%atomic_grid_size_bound_shape) + print '(a)', " -> Loading atomic_grid_size_bound_shape" + call skala_tensor_load(atomic_grid_size_bound_shape, feature_dir//"/atomic_grid_size_bound_shape.pt") end select end do end block get_features @@ -74,6 +84,12 @@ program main call input%insert(skala_feature%grid_weights, grid_weights) if (model%needs_feature(skala_feature%coarse_0_atomic_coords)) & call input%insert(skala_feature%coarse_0_atomic_coords, coarse_0_atomic_coords) + if (model%needs_feature(skala_feature%atomic_grid_weights)) & + call input%insert(skala_feature%atomic_grid_weights, atomic_grid_weights) + if (model%needs_feature(skala_feature%atomic_grid_sizes)) & + call input%insert(skala_feature%atomic_grid_sizes, atomic_grid_sizes) + if (model%needs_feature(skala_feature%atomic_grid_size_bound_shape)) & + call input%insert(skala_feature%atomic_grid_size_bound_shape, atomic_grid_size_bound_shape) ! Request exc and vxc from the model print '(a)', "[4] Running model inference" diff --git a/examples/fortran/ftorch_integration/src/skala_ftorch.cxx b/examples/fortran/ftorch_integration/src/skala_ftorch.cxx index 00e97e3..54c4e06 100644 --- a/examples/fortran/ftorch_integration/src/skala_ftorch.cxx +++ b/examples/fortran/ftorch_integration/src/skala_ftorch.cxx @@ -17,6 +17,9 @@ typedef enum SkalaFeature { Feature_GridCoords = 4, Feature_GridWeights = 5, Feature_Coarse0AtomicCoords = 6 + Feature_AtomicGridWeights = 7, + Feature_AtomicGridSizes = 8, + Feature_AtomicGridSizeBoundShape = 9 } SkalaFeature; static inline @@ -106,6 +109,12 @@ skala_model_load(const char *filename, feature_keys.insert({feature_key, Feature_GridWeights}); } else if (feature_key == "coarse_0_atomic_coords") { feature_keys.insert({feature_key, Feature_Coarse0AtomicCoords}); + } else if (feature_key == "atomic_grid_weights") { + feature_keys.insert({feature_key, Feature_AtomicGridWeights}); + } else if (feature_key == "atomic_grid_sizes") { + feature_keys.insert({feature_key, Feature_AtomicGridSizes}); + } else if (feature_key == "atomic_grid_size_bound_shape") { + feature_keys.insert({feature_key, Feature_AtomicGridSizeBoundShape}); } pos = end_pos + 1; } diff --git a/examples/fortran/ftorch_integration/src/skala_ftorch.f90 b/examples/fortran/ftorch_integration/src/skala_ftorch.f90 index 66f5776..8aa4ec4 100644 --- a/examples/fortran/ftorch_integration/src/skala_ftorch.f90 +++ b/examples/fortran/ftorch_integration/src/skala_ftorch.f90 @@ -24,7 +24,10 @@ module skala_ftorch integer :: grid_coords = 4 integer :: grid_weights = 5 integer :: coarse_0_atomic_coords = 6 - integer :: max_feature = 6 + integer :: atomic_grid_weights = 7 + integer :: atomic_grid_sizes = 8 + integer :: atomic_grid_size_bound_shape = 9 + integer :: max_feature = 9 end type skala_feature_enum type(skala_feature_enum), parameter :: skala_feature = skala_feature_enum() diff --git a/pyproject.toml b/pyproject.toml index 70f6304..7a37e64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "skala" -version = "1.1.1" +version = "2026.1" description = "Skala Exchange Correlation Functional" authors = [] license-files = ["LICENSE.txt"] @@ -24,7 +24,7 @@ dependencies = [ "huggingface_hub", "numpy", "opt_einsum_fx", - "pyscf", + "pyscf>=2.8,<2.14", "qcelemental", "torch" ] @@ -33,8 +33,10 @@ optional-dependencies.dev = [ "pre-commit", "pytest", "pytest-cov", + "pytest-randomly", ] optional-dependencies.doc = [ + "ipywidgets", "myst-nb", "sphinx", "sphinx-book-theme", @@ -48,22 +50,65 @@ urls.repository = "https://github.com/microsoft/skala" urls.documentation = "https://microsoft.github.io/skala" urls.homepage = "https://aka.ms/dft" +[tool.mypy] +strict = true +ignore_missing_imports = true +no_namespace_packages = true +python_version = "3.11" +exclude = ["third_party/"] +disable_error_code = ["no-any-return"] + +# torch.autograd.Function uses dynamic attributes on FunctionCtx (ctx.save_for_backward pattern) +# and Function.apply() is untyped. These are fundamental to PyTorch's autograd API. +[[tool.mypy.overrides]] +module = "skala.pyscf.features" +disable_error_code = ["no-any-return", "attr-defined", "no-untyped-call"] + [tool.ruff] target-version = "py311" src = ["src"] +exclude = ["third_party"] [tool.ruff.lint] extend-select = [ "B", # flake8-bugbear "D417", # undocumented-param in docstring + "I", # isort "UP", # pyupgrade ] extend-ignore = [ "UP015", # redundant-open-modes + "E741", # variable name l ] +[tool.ruff.lint.per-file-ignores] +# E402: imports must come after CUDA checks and CuPy allocator configuration +"src/skala/gpu4pyscf/__init__.py" = ["E402"] +"src/skala/gpu4pyscf/dft.py" = ["E402"] + [tool.ruff.lint.isort] detect-same-package = true [tool.black] line-length = 100 + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + # PyTorch 2.11 deprecated `torch.jit.load`; Skala's pretrained checkpoints + # are TorchScript archives, so `skala.functional.load` still uses it. + 'ignore:`torch\.jit\.load` is deprecated\. Please switch to `torch\.export`\.:DeprecationWarning', + 'ignore:`torch\.jit\.script` is deprecated\. Please switch to `torch\.compile` or `torch\.export`\.:DeprecationWarning', + 'ignore:`torch\.jit\.save` is deprecated\. Please switch to `torch\.export`\.:DeprecationWarning', + 'ignore:`torch\.jit\.save` is not supported in Python 3\.14+ and may break\. Please switch to `torch\.export`\.:DeprecationWarning', + # PySCF's `SCF.__init__` (pyscf/scf/hf.py, currently line 1733) creates + # `self._chkfile = tempfile.NamedTemporaryFile(...)` as its checkpoint file + # and relies on finalization to close it, which emits ResourceWarning during + # GC. The warning is emitted by CPython's I/O layer (not from the pyscf + # module), so we match on the message. Traced via PYTHONTRACEMALLOC. + "ignore:unclosed file <_io\\.FileIO name='/tmp/tmp.*' mode='rb\\+' closefd=True>:ResourceWarning", + # On recent pytest versions the same leaked tempfile surfaces during GC at + # session teardown via the unraisable-exception hook, which re-emits it as + # `pytest.PytestUnraisableExceptionWarning`. Filter that wrapper too. + "ignore:Exception ignored in. <_io\\.FileIO name='/tmp/tmp.*' mode='rb\\+' closefd=True>:pytest.PytestUnraisableExceptionWarning", +] diff --git a/src/skala/ase/calculator.py b/src/skala/ase/calculator.py index 4b623de..7095b7d 100644 --- a/src/skala/ase/calculator.py +++ b/src/skala/ase/calculator.py @@ -19,7 +19,7 @@ from skala.pyscf.retry import retry_scf -class Skala(Calculator): # type: ignore[misc] +class Skala(Calculator): """ ASE calculator for the Skala exchange-correlation functional. @@ -38,7 +38,7 @@ class Skala(Calculator): # type: ignore[misc] ] default_parameters: dict[str, Any] = { - "xc": "skala", + "xc": "skala-1.1", "basis": None, "with_density_fit": False, "auxbasis": None, @@ -54,8 +54,8 @@ class Skala(Calculator): # type: ignore[misc] _mol: gto.Mole | None = None _ks: grad.rhf.GradientsBase | None = None - def __init__(self, atoms: Atoms | None = None, **kwargs: Any): - super().__init__(atoms=atoms, **kwargs) + def __init__(self, atoms: Atoms | None = None, **kwargs: Any) -> None: + super().__init__(atoms=atoms, **kwargs) # type: ignore[no-untyped-call] def set(self, **kwargs: Any) -> dict[str, Any]: """ @@ -66,12 +66,12 @@ def set(self, **kwargs: Any) -> dict[str, Any]: **kwargs : dict Additional parameters to set for the calculator. """ - changed_parameters: dict[str, Any] = super().set(**kwargs) + changed_parameters: dict[str, Any] = super().set(**kwargs) # type: ignore[no-untyped-call] if "verbose" in changed_parameters: if self._mol is not None: - self._mol.verbose = int(self.parameters.verbose) + self._mol.verbose = int(self.parameters.verbose) # type: ignore if self._ks is not None: - verbose = int(self.parameters.verbose) + verbose = int(self.parameters.verbose) # type: ignore self._ks.verbose = verbose self._ks.base.verbose = verbose @@ -104,7 +104,7 @@ def reset(self) -> None: """ Reset the calculator to its initial state. """ - super().reset() + super().reset() # type: ignore[no-untyped-call] def calculate( self, @@ -129,11 +129,11 @@ def calculate( if system_changes is None: system_changes = all_changes - super().calculate( + super().calculate( # type: ignore[no-untyped-call] atoms=atoms, properties=properties, system_changes=system_changes ) - if not isinstance(basis := self.parameters.basis, str): + if not isinstance(basis := self.parameters.basis, str): # type: ignore raise InputError("Basis set must be specified in the parameters.") if self.atoms is None: @@ -154,25 +154,25 @@ def calculate( atom=atom, basis=basis, unit="Angstrom", - verbose=int(self.parameters.verbose), - charge=_get_charge(self.atoms, self.parameters), - spin=_get_uhf(self.atoms, self.parameters), + verbose=int(self.parameters.verbose), # type: ignore + charge=_get_charge(self.atoms, self.parameters), # type: ignore + spin=_get_uhf(self.atoms, self.parameters), # type: ignore ) self._ks = None else: self._mol = self._mol.set_geom_(atom, inplace=False) if self._ks is None: - if not isinstance(xc_param := self.parameters.xc, (ExcFunctionalBase, str)): + if not isinstance(xc_param := self.parameters.xc, (ExcFunctionalBase, str)): # type: ignore raise InputError("XC functional must be a string or ExcFunctionalBase.") grad_method = SkalaKS( self._mol, xc=xc_param, - with_density_fit=bool(self.parameters.with_density_fit), - auxbasis=self.parameters.auxbasis, - with_newton=bool(self.parameters.with_newton), - with_dftd3=bool(self.parameters.with_dftd3), - ks_config=self.parameters.ks_config, + with_density_fit=bool(self.parameters.with_density_fit), # type: ignore + auxbasis=self.parameters.auxbasis, # type: ignore + with_newton=bool(self.parameters.with_newton), # type: ignore + with_dftd3=bool(self.parameters.with_dftd3), # type: ignore + ks_config=self.parameters.ks_config, # type: ignore ).nuc_grad_method() self._ks = grad_method else: @@ -198,7 +198,7 @@ def _get_charge(atoms: Atoms, parameters: Parameters) -> int: by summing the initial charges of all atoms. """ if parameters.charge is None: - charge = atoms.get_initial_charges().sum() + charge = atoms.get_initial_charges().sum() # type: ignore[no-untyped-call] else: charge = parameters.charge return int(charge) @@ -211,6 +211,6 @@ def _get_uhf(atoms: Atoms, parameters: Parameters) -> int: is calculated by summing the initial magnetic moments of all atoms. """ if parameters.multiplicity is None: - multiplicity = int(atoms.get_initial_magnetic_moments().sum().round()) + multiplicity = int(atoms.get_initial_magnetic_moments().sum().round()) # type: ignore[no-untyped-call] return multiplicity return int(parameters.multiplicity) - 1 diff --git a/src/skala/foundry/client.py b/src/skala/foundry/client.py index 0df7d8a..663b7fa 100644 --- a/src/skala/foundry/client.py +++ b/src/skala/foundry/client.py @@ -3,6 +3,7 @@ import json import logging import time +import urllib.error import urllib.request import uuid from typing import Any diff --git a/src/skala/foundry/schemas.py b/src/skala/foundry/schemas.py index d5414b3..769ca32 100644 --- a/src/skala/foundry/schemas.py +++ b/src/skala/foundry/schemas.py @@ -18,13 +18,13 @@ TaskState: TypeAlias = Literal["succeeded", "failed", "running", "queued", "canceled"] -class SkalaConfig(BaseModel): # type: ignore[misc] +class SkalaConfig(BaseModel): basis: BasisOptions = "def2-qzvp" grid_level: GridLevelOptions = "ultrafine" max_num_scf_steps: int = 100 -class Molecule(BaseModel): # type: ignore[misc] +class Molecule(BaseModel): """Molecule representation based on qcelemental's Molecule model.""" geometry: list[float] = Field( @@ -45,11 +45,15 @@ def from_qcel(cls, molecule: qcel.models.Molecule) -> "Molecule": geometry = molecule.geometry if isinstance(geometry, np.ndarray): geometry = geometry.flatten().tolist() + + molecular_multiplicity = molecule.molecular_multiplicity + assert isinstance(molecular_multiplicity, int) + return cls( geometry=geometry, symbols=molecule.symbols, molecular_charge=molecule.molecular_charge, - molecular_multiplicity=molecule.molecular_multiplicity, + molecular_multiplicity=molecular_multiplicity, ) def to_qcel(self) -> qcel.models.Molecule: @@ -61,12 +65,12 @@ def to_qcel(self) -> qcel.models.Molecule: ) -class SkalaInput(BaseModel): # type: ignore[misc] +class SkalaInput(BaseModel): molecule: Molecule input_config: SkalaConfig = Field(default_factory=SkalaConfig) -class SkalaOutput(BaseModel): # type: ignore[misc] +class SkalaOutput(BaseModel): total_energy: float energy_breakdown: dict[str, float] = Field( default_factory=dict, @@ -78,7 +82,7 @@ class SkalaOutput(BaseModel): # type: ignore[misc] ) -class TaskStatus(BaseModel): # type: ignore[misc] +class TaskStatus(BaseModel): status: TaskState num_tasks_ahead: int exception: str | None = None diff --git a/src/skala/functional/__init__.py b/src/skala/functional/__init__.py index 15dbaf1..ad57507 100644 --- a/src/skala/functional/__init__.py +++ b/src/skala/functional/__init__.py @@ -17,10 +17,12 @@ from skala.functional._hashes import KNOWN_HASHES from skala.functional.base import ExcFunctionalBase from skala.functional.load import TracedFunctional -from skala.functional.traditional import LDA, PBE, SPW92, TPSS +from skala.functional.model import SkalaFunctional +from skala.functional.traditional import LDA, PBE, SPW92, TPSS, SpinScaledXCFunctional __all__ = [ "ExcFunctionalBase", + "SkalaFunctional", "TracedFunctional", "LDA", "PBE", @@ -29,44 +31,55 @@ "load_functional", ] +_SKALA_VERSIONS = { + "skala-1.0": ("skala-1.0.fun", "skala-1.0-cuda.fun"), + "skala-1.1": ("skala-1.1.fun", "skala-1.1-cuda.fun"), +} -def load_functional(name: str, device: torch.device | None = None) -> ExcFunctionalBase: - """ - Load an exchange-correlation functional by name. - - Parameters - ---------- - name : str - Name of the functional. Supported values: - - - "skala": The Skala neural functional - - "lda": Local Density Approximation - - "spw92": SPW92 (LDA with PW92 correlation) - - "pbe": Perdew-Burke-Ernzerhof functional - - "tpss": Tao-Perdew-Staroverov-Scuseria meta-GGA - - Returns - ------- - ExcFunctionalBase - The loaded functional instance. - - Raises - ------ - ValueError - If the functional name is not recognized. - - Example - ------- - >>> func = load_functional("skala") - >>> func.features - ['density', 'kin', 'grad', 'grid_coords', 'grid_weights', 'coarse_0_atomic_coords'] - >>> func = load_functional("lda") - >>> func.features - ['density', 'grid_weights'] + +def load_functional( + name: str, device: torch.device | None = None +) -> ExcFunctionalBase | str: + """Load an exchange-correlation functional by name. + + Args: + name: Name of the functional. Skala-native values: + + - ``"skala-1.1"``: Skala 1.1 neural functional (recommended). + - ``"skala-1.0"``: Skala 1.0 neural functional (legacy, traced only). + - ``"lda"``: Local Density Approximation. + - ``"spw92"``: SPW92 (LDA with PW92 correlation). + - ``"pbe"``: Perdew-Burke-Ernzerhof functional. + - ``"tpss"``: Tao-Perdew-Staroverov-Scuseria meta-GGA. + + Any other string is returned as-is for native PySCF/gpu4pyscf evaluation. + + device: Device to load the functional onto. + + Returns: + An ``ExcFunctionalBase`` instance for Skala-native functionals, or the + name string for PySCF-native functionals. + + Example: + >>> func = load_functional("skala-1.1") + >>> func.features + ['density', 'kin', 'grad', 'grid_coords', 'grid_weights', ... + >>> func = load_functional("lda") + >>> func.features + ['density', 'grid_weights'] + >>> load_functional("b3lyp") + 'b3lyp' """ func_name = name.lower() if func_name == "skala": + raise ValueError( + 'The generic functional name "skala" is no longer supported. ' + 'Please use "skala-1.0" or "skala-1.1".' + ) + + func: SpinScaledXCFunctional + if func_name in _SKALA_VERSIONS: env_path = os.environ.get("SKALA_LOCAL_MODEL_PATH") if env_path is not None: logging.getLogger(__name__).warning( @@ -79,12 +92,14 @@ def load_functional(name: str, device: torch.device | None = None) -> ExcFunctio device_type = ( torch.get_default_device().type if device is None else device.type ) - repo_id = "microsoft/skala-1.0" - filename = "skala-1.0.fun" if device_type == "cpu" else "skala-1.0-cuda.fun" + repo_id = f"microsoft/{func_name}" + cpu_file, cuda_file = _SKALA_VERSIONS[func_name] + filename = cpu_file if device_type == "cpu" else cuda_file path = hf_hub_download(repo_id=repo_id, filename=filename) expected_hash = KNOWN_HASHES.get((repo_id, filename)) - with open(path, "rb") as fd: - return TracedFunctional.load(fd, device=device, expected_hash=expected_hash) + + return TracedFunctional.load(path, device=device, expected_hash=expected_hash) + elif func_name == "lda": func = LDA() elif func_name == "spw92": @@ -94,9 +109,9 @@ def load_functional(name: str, device: torch.device | None = None) -> ExcFunctio elif func_name == "tpss": func = TPSS() else: - raise ValueError( - f"Unknown functional: {name}. Please provide a valid functional name or path to a traced functional file." - ) + return name + if device is not None: func = func.to(device=device) + return func diff --git a/src/skala/functional/_hashes.py b/src/skala/functional/_hashes.py index 9cac6b5..9d984e2 100644 --- a/src/skala/functional/_hashes.py +++ b/src/skala/functional/_hashes.py @@ -16,4 +16,11 @@ ("microsoft/skala-1.0", "skala-1.0-cuda.fun"): ( "0b38e13237cec771fed331664aace42f8c0db8f15caca6a5c563085e61e2b1fd" ), + ("microsoft/skala-1.1", "skala-1.1.fun"): ( + "0c8432ac3f03c8f1276372df9aca5b7ee7f8939d47a8789eb158976e89aa0606" + ), + ( + "microsoft/skala-1.1", + "skala-1.1-cuda.fun", + ): "f77be6002d873c0a2384b6df7850d32bbec519036344ff5fdde9730c6f9a4326", } diff --git a/src/skala/functional/base.py b/src/skala/functional/base.py index ac5c0bc..cd6137b 100644 --- a/src/skala/functional/base.py +++ b/src/skala/functional/base.py @@ -16,7 +16,7 @@ VxcType = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] -class ExcFunctionalBase(nn.Module): # type: ignore[misc] +class ExcFunctionalBase(nn.Module): """ Abstract base class for exchange-correlation functionals. @@ -35,7 +35,7 @@ def get_d3_settings(self) -> str | None: """ return None - def get_exc_density(self, mol: dict[str, torch.Tensor]) -> torch.FloatTensor: + def get_exc_density(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """ Returns the exchange-correlation density for the given molecule. It should return a tensor of shape (G,) where G is the number of grid points @@ -46,7 +46,7 @@ def get_exc_density(self, mol: dict[str, torch.Tensor]) -> torch.FloatTensor: "get_exc_density not implemented for this functional." ) - def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.FloatTensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """ Compute the exchange-correlation energy. @@ -58,7 +58,7 @@ def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.FloatTensor: Returns ------- - torch.FloatTensor + torch.Tensor The total exchange-correlation energy. """ exc_density = self.get_exc_density(mol).double() diff --git a/src/skala/functional/layers.py b/src/skala/functional/layers.py index 19a4214..18624b4 100644 --- a/src/skala/functional/layers.py +++ b/src/skala/functional/layers.py @@ -14,7 +14,7 @@ from torch import nn -class Squasher(nn.Module): # type: ignore[misc] +class Squasher(nn.Module): """ Elementwise squashing function log(|x| + eta). @@ -34,7 +34,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return (x.abs() + self.eta).log() -class LinearSkip(nn.Linear): # type: ignore[misc] +class LinearSkip(nn.Linear): """ Linear layer with skip connection, used to initialize close to identity. @@ -56,9 +56,9 @@ def __init__(self, in_features: int, out_features: int, **kwargs: Any): Additional arguments passed to nn.Linear. """ super().__init__(in_features=in_features, out_features=out_features, **kwargs) - assert ( - in_features == out_features - ), f"Expecting args in_features == out_features, got {in_features} != {out_features}." + assert in_features == out_features, ( + f"Expecting args in_features == out_features, got {in_features} != {out_features}." + ) self._init_weights() def _init_weights(self) -> None: @@ -72,7 +72,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return input + nn.functional.linear(input, self.weight, self.bias) -class ScaledSigmoid(nn.Sigmoid): # type: ignore[misc] +class ScaledSigmoid(nn.Sigmoid): """ Sigmoid activation function with learnable scaling. diff --git a/src/skala/functional/load.py b/src/skala/functional/load.py index ed29714..1edcb9c 100644 --- a/src/skala/functional/load.py +++ b/src/skala/functional/load.py @@ -56,11 +56,11 @@ def get_d3_settings(self) -> str | None: """ return self.expected_d3_settings - def get_exc_density(self, data: dict[str, torch.Tensor]) -> torch.FloatTensor: - return self._traced_model.get_exc_density(data) + def get_exc_density(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + return self._traced_model.get_exc_density(mol) - def get_exc(self, data: dict[str, torch.Tensor]) -> torch.FloatTensor: - return self._traced_model.get_exc(data) + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + return self._traced_model.get_exc(mol) @property def original_name(self) -> str: @@ -77,7 +77,7 @@ def original_name(self) -> str: @classmethod def load( cls, - fp: str | bytes | os.PathLike[str] | IO[bytes], + fp: str | os.PathLike[str] | IO[bytes], device: torch.device | None = None, *, expected_hash: str | None = None, @@ -103,10 +103,11 @@ def load( if expected_hash is not None: # Read the whole file into memory so we can hash it before # passing it to the unsafe torch.jit.load deserializer. - if isinstance(fp, (str, bytes, os.PathLike)): - with open(fp, "rb") as f: + if isinstance(fp, (str, os.PathLike)): + with open(fp, mode="rb") as f: data = f.read() else: + # Assume a file-like object data = fp.read() actual_hash = hashlib.sha256(data).hexdigest() if actual_hash != expected_hash: @@ -117,7 +118,7 @@ def load( ) fp = io.BytesIO(data) - traced_model = torch.jit.load(fp, _extra_files=extra_files, map_location=device) + traced_model = torch.jit.load(fp, _extra_files=extra_files, map_location=device) # type: ignore[no-untyped-call] _metadata = json.loads(extra_files["metadata"].decode("utf-8")) if not isinstance(_metadata, dict): diff --git a/src/skala/functional/model.py b/src/skala/functional/model.py index 72eb356..d5f215e 100644 --- a/src/skala/functional/model.py +++ b/src/skala/functional/model.py @@ -1,6 +1,15 @@ # SPDX-License-Identifier: MIT +""" +Skala neural exchange-correlation functional. + +This module implements the Skala model (``SkalaFunctional``), which uses a +packed per-atom grid layout with multiple non-local equivariant message-passing +layers and symmetric contraction for higher-order body correlations. +""" + import math +from typing import Any, cast import torch from e3nn import o3 @@ -9,174 +18,405 @@ from skala.functional.base import ExcFunctionalBase, enhancement_density_inner_product from skala.functional.layers import ScaledSigmoid -from skala.utils.scatter import scatter_sum +from skala.functional.utils.irreps import Irreps +from skala.functional.utils.pad_ragged import pad_ragged, unpad_ragged +from skala.functional.utils.symmetric_contraction import SymmetricContraction -# 0.32 and 2.32 are the smallest and largest covalent radius estimates -# from Pyykko and Atsumi, Chem. Eur. J. 15, 2009, 188-197 ANGSTROM_TO_BOHR = 1.88973 -MIN_COV_RAD = 0.32 * ANGSTROM_TO_BOHR -MAX_COV_RAD = 2.32 * ANGSTROM_TO_BOHR -def _prepare_features( - mol: dict[str, torch.Tensor], -) -> tuple[torch.Tensor, torch.Tensor]: +def _prepare_features_raw( + mol: dict[str, torch.Tensor], eps: float = 1e-5 +) -> torch.Tensor: + """Compute log-space semi-local features from packed density data. + + Args: + mol: Dictionary of packed molecular features (grid_per_atom, atoms, …). + eps: Small constant for numerical stability in log. + + Returns: + Tensor of shape ``(grid_per_atom, atoms, 7)``. + """ x = torch.cat( [ - mol["density"].T, - (mol["grad"] ** 2).sum(1).T, - mol["kin"].T, - (mol["grad"].sum(0) ** 2).sum(0).view(-1, 1), + mol["density"].permute(1, 2, 0), + (mol["grad"] ** 2).sum(1).permute(1, 2, 0), + mol["kin"].permute(1, 2, 0), + (mol["grad"].sum(0) ** 2).sum(0).unsqueeze(-1), ], - dim=1, + dim=-1, ) + + # Cast to double to work around a TorchScript gradient bug with torch.abs. + # See PR #15759 in the livdft repository for details. x = x.double() - features = torch.log(torch.abs(x) + 1e-5) + return torch.log(torch.abs(x) + eps) - features_ab = features - features_ba = features[:, [1, 0, 3, 2, 5, 4, 6]] - return features_ab, features_ba + +class SemiLocalFeatures(nn.Module): + """Compute semi-local (ab, ba) feature pairs with a pre-buffered permutation index.""" + + _PERM = [1, 0, 3, 2, 5, 4, 6] + _feature_perm: torch.Tensor + + def __init__(self) -> None: + super().__init__() + self.register_buffer( + "_feature_perm", + torch.tensor(self._PERM, dtype=torch.long), + persistent=False, + ) + + def forward( + self, mol: dict[str, torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + features = _prepare_features_raw(mol) + features_ab = features + features_ba = features.index_select(-1, self._feature_perm) + return features_ab, features_ba + + +class ExpRadialScaleModel(nn.Module): + """Learnable radial basis using exponentially-spaced Gaussians. + + Args: + embedding_size: Number of radial basis functions. + """ + + temps: torch.Tensor + + def __init__(self, embedding_size: int = 8) -> None: + super().__init__() + self.embedding_size = embedding_size + min_std = 0.32 * ANGSTROM_TO_BOHR / 2 + max_std = 2.32 * ANGSTROM_TO_BOHR / 2 + self.register_buffer( + "temps", 2 * torch.linspace(min_std, max_std, embedding_size) ** 2 + ) + + def forward(self, dist2: torch.Tensor) -> torch.Tensor: + """Compute radial basis values. + + Args: + dist2: Squared distances, shape ``(…, 1)``. + + Returns: + Radial basis values, shape ``(…, embedding_size)``. + """ + t = self.temps + dim = 3 + return ( + 2 / (dim * t * (math.pi * t) ** (0.5 * dim)) * torch.exp(-dist2 / t) * dist2 + ) class SkalaFunctional(ExcFunctionalBase): + """Skala neural exchange-correlation functional. + + This model operates on per-atom packed grid features and uses multiple + equivariant non-local message-passing layers with symmetric contraction. + + Args: + lmax: Maximum angular momentum order for spherical harmonics. + num_mid_layers: Total number of mid-layers (local + non-local). + num_non_local_layers: How many mid-layers are non-local. + non_local_hidden_nf: Number of channels in the non-local model. + coarse_linear_type: Type of coarse linear layer. + correlation: Correlation order for the symmetric contraction. + """ + features = [ "density", "kin", "grad", "grid_coords", "grid_weights", + "atomic_grid_weights", + "atomic_grid_sizes", "coarse_0_atomic_coords", + "atomic_grid_size_bound_shape", ] def __init__( self, - lmax: int = 3, # max angular momentum order of the spherical harmonics - non_local: bool = True, + lmax: int = 3, + num_mid_layers: int = 3, + num_non_local_layers: int = 2, non_local_hidden_nf: int = 16, - radius_cutoff: float = float("inf"), + coarse_linear_type: str | None = "decomp-identity", + correlation: int = 3, ) -> None: super().__init__() + assert 0 <= num_non_local_layers <= num_mid_layers + self.num_scalar_features = 7 - self.non_local = non_local self.lmax = lmax - + self.num_mid_layers = num_mid_layers + self.num_non_local_layers = num_non_local_layers + self.non_local_hidden_nf = non_local_hidden_nf self.num_feats = 256 + + self.semi_local_features = SemiLocalFeatures() + self.input_model = torch.nn.Sequential( nn.Linear(self.num_scalar_features, self.num_feats), nn.SiLU(), - nn.Linear(self.num_feats, self.num_feats), # layer 1 + nn.Linear(self.num_feats, self.num_feats), nn.SiLU(), ) - if self.non_local: - self.non_local_model = NonLocalModel( - input_nf=self.num_feats, - hidden_nf=non_local_hidden_nf, - lmax=self.lmax, - radius_cutoff=radius_cutoff, + if self.num_non_local_layers > 0: + self.sph_irreps = Irreps.spherical_harmonics(self.lmax, p=1) + self.spherical_harmonics = o3.SphericalHarmonics( + irreps_out=str(self.sph_irreps), + normalize=False, + normalization="norm", + ) + self.non_local_layers = torch.nn.ModuleList( + [ + NonLocalModel( + input_nf=self.num_feats, + hidden_nf=self.non_local_hidden_nf, + lmax=self.lmax, + edge_irreps=self.sph_irreps, + coarse_linear_type=coarse_linear_type, + correlation=correlation, + ) + for _ in range(self.num_non_local_layers) + ] ) - self.num_non_local_contributions = non_local_hidden_nf + self.radial_basis = ExpRadialScaleModel(self.non_local_hidden_nf) else: - self.num_non_local_contributions = 0 + raise NotImplementedError("Non-local model must be enabled.") - # concatenate the non-local contributions to the input layer if non-local is enabled self.output_model = torch.nn.Sequential( - nn.Linear( - self.num_feats + self.num_non_local_contributions, self.num_feats - ), # layer 2 - nn.SiLU(), - nn.Linear(self.num_feats, self.num_feats), # layer 3 - nn.SiLU(), - nn.Linear(self.num_feats, self.num_feats), # layer 4 - nn.SiLU(), + *[ + module + for _ in range(self.num_mid_layers - self.num_non_local_layers) + for module in [ + nn.Linear(self.num_feats, self.num_feats), + nn.SiLU(), + ] + ], nn.Linear(self.num_feats, 1), ScaledSigmoid(scale=2.0), ) - self.reset_parameters() + self._init_weights() + + # Keys introduced after the original checkpoint format (deterministic buffers + # that can be reconstructed from __init__ args). + _RECONSTRUCTABLE_BUFFER_PREFIXES = ("radial_basis.", "semi_local_features.") + + def load_state_dict( # type: ignore # needs mutable dict + self, + state_dict: dict[str, Any], + strict: bool = True, + assign: bool = False, + ) -> dict[str, torch.Tensor]: + """Load state_dict with backward compatibility for older checkpoints.""" + if strict: + current_sd = self.state_dict() + # Add missing reconstructable buffers from the current model. + for key in current_sd: + if key not in state_dict and any( + key.startswith(p) for p in self._RECONSTRUCTABLE_BUFFER_PREFIXES + ): + state_dict[key] = current_sd[key] + # Remove extra reconstructable buffers that the checkpoint carries but + # the model does not expose (e.g. non-persistent buffers). + extra = set(state_dict) - set(current_sd) + for key in extra: + if any( + key.startswith(p) for p in self._RECONSTRUCTABLE_BUFFER_PREFIXES + ): + del state_dict[key] + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + def _init_weights(self) -> None: + for layer in self.input_model: + if isinstance(layer, nn.Linear): + torch.nn.init.xavier_uniform_(layer.weight) + torch.nn.init.zeros_(layer.bias) + + for layer in self.output_model: + if isinstance(layer, nn.Linear): + torch.nn.init.xavier_uniform_(layer.weight) + torch.nn.init.zeros_(layer.bias) + + @property + def dtype(self) -> torch.dtype: + return cast(nn.Linear, self.input_model[0]).weight.dtype + + def pack_features( + self, mol_feats: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Pack flat features into dense (grid_per_atom, atoms, …) layout. + + Args: + mol_feats: Flat molecular features from grid evaluation. + + Returns: + Packed features dictionary. + """ + atomic_grid_sizes = mol_feats["atomic_grid_sizes"] + size_bound = mol_feats["atomic_grid_size_bound_shape"].shape[0] + + packed_mol_feats: dict[str, torch.Tensor] = {} + for key in self.features: + if key == "atomic_grid_weights": + packed_mol_feats[key] = pad_ragged( + mol_feats[key], atomic_grid_sizes, size_bound + ).T # (max_grid_size, num_atoms) + elif key == "grid_weights": + continue + elif key == "grid_coords": + packed_mol_feats[key] = pad_ragged( + mol_feats[key], atomic_grid_sizes, size_bound + ).permute(1, 0, 2) # (max_grid_size, num_atoms, 3) + elif key == "coarse_0_atomic_coords": + packed_mol_feats[key] = mol_feats[key] + elif key == "density": + packed_mol_feats[key] = pad_ragged( + mol_feats[key].T, atomic_grid_sizes, size_bound + ).permute(2, 1, 0) # (2, max_grid_size, num_atoms) + elif key == "grad": + packed_mol_feats[key] = pad_ragged( + mol_feats[key].permute(2, 0, 1), atomic_grid_sizes, size_bound + ).permute(2, 3, 1, 0) # (2, 3, max_grid_size, num_atoms) + elif key == "kin": + packed_mol_feats[key] = pad_ragged( + mol_feats[key].T, atomic_grid_sizes, size_bound + ).permute(2, 1, 0) # (2, max_grid_size, num_atoms) + elif key in ("atomic_grid_sizes", "atomic_grid_size_bound_shape"): + continue + else: + raise ValueError(f"Unexpected key: {key}") + + return packed_mol_feats + + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + exc_density = self._get_exc_density_padded(mol).double() + grid_weights = ( + pad_ragged( + mol["grid_weights"], + mol["atomic_grid_sizes"], + mol["atomic_grid_size_bound_shape"].shape[0], + ) + .T.double() + .reshape(-1) + ) + + return (exc_density * grid_weights).sum() def get_exc_density(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + padded = self._get_exc_density_padded(mol) + sizes = mol["atomic_grid_sizes"] + size_bound = mol["atomic_grid_size_bound_shape"].shape[0] + num_atoms = sizes.shape[0] + total_grid_points = mol["grid_weights"].shape[0] + padded_2d = padded.reshape(size_bound, num_atoms).T + return unpad_ragged(padded_2d, sizes, total_grid_points) + + def _get_exc_density_padded(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + mol = self.pack_features(mol) grid_coords = mol["grid_coords"] - grid_weights = mol["grid_weights"] + atomic_grid_weights = mol["atomic_grid_weights"] coarse_coords = mol["coarse_0_atomic_coords"] - features_ab, features_ba = _prepare_features(mol) + features_ab, features_ba = self.semi_local_features(mol) # Learned symmetrized features - spin_feats = torch.cat([features_ab, features_ba], dim=0) + spin_feats = torch.stack([features_ab, features_ba], dim=0) spin_feats = spin_feats.to(self.dtype) spin_feats = self.input_model(spin_feats) - features = torch.add(*torch.chunk(spin_feats, 2, dim=0)) / 2 + features = spin_feats.sum(0) / 2 # Non-local model - if self.non_local: - h_grid_non_local = self.non_local_model( - features, - grid_coords, - coarse_coords, - grid_weights, + if self.num_non_local_layers > 0: + # grid_coords: (num_fine, num_coarse, 3) + # coarse_coords: (num_coarse, 3) + directions = grid_coords.double() - coarse_coords.double() + distances = (directions**2 + 1e-20).sum(-1) ** 0.5 + directions = directions / distances[:, :, None] + + directions = directions.to(self.dtype) + distances = distances.to(self.dtype) + + distance_ft = self.radial_basis( + distances.unsqueeze(-1) ** 2 + ) # (#fine, #coarse, hidden_nf) + + direction_ft = self.spherical_harmonics( + directions + ) # (num_fine, num_coarse, (lmax+1)^2) + + exp_m1_rho_total = torch.exp(-mol["density"].sum(0).unsqueeze(-1)).to( + self.dtype ) - h_grid_non_local = h_grid_non_local * torch.exp( - -mol["density"].sum(0).view(-1, 1) - ).to(self.dtype) - features = torch.cat([features, h_grid_non_local], dim=-1) + for non_local_layer in self.non_local_layers: + features = non_local_layer( + features, + distance_ft, + direction_ft, + atomic_grid_weights, + exp_m1_rho_total, + ) enhancement_factor = self.output_model(features) return enhancement_density_inner_product( - enhancement_factor=enhancement_factor, density=mol["density"] + enhancement_factor=enhancement_factor.view(-1, 1), + density=mol["density"].reshape(2, -1), ) def reset_parameters(self) -> None: - for layer in self.input_model: - if isinstance(layer, nn.Linear): - torch.nn.init.xavier_uniform_(layer.weight) - torch.nn.init.zeros_(layer.bias) + self._init_weights() - for layer in self.output_model: - if isinstance(layer, nn.Linear): - torch.nn.init.xavier_uniform_(layer.weight) - torch.nn.init.zeros_(layer.bias) - @property - def dtype(self) -> torch.dtype: - return self.input_model[0].weight.dtype +class NonLocalModel(nn.Module): + """Equivariant non-local message-passing layer. + Args: + input_nf: Number of input scalar features. + hidden_nf: Number of hidden channels per irrep. + lmax: Maximum angular momentum. + edge_irreps: Irreps for edge features (spherical harmonics). + coarse_linear_type: Type of O3-equivariant linear on coarse features. + correlation: Correlation order for symmetric contraction. + """ -class NonLocalModel(nn.Module): # type: ignore[misc] def __init__( self, input_nf: int, hidden_nf: int, lmax: int, - radius_cutoff: float = float("inf"), + edge_irreps: Irreps, + coarse_linear_type: str | None = None, + correlation: int = 1, ): super().__init__() self.input_nf = input_nf self.hidden_nf = hidden_nf - self.in_irreps = o3.Irreps(f"{self.hidden_nf}x0e") - self.out_irreps = o3.Irreps(f"{self.hidden_nf}x0e") + self.in_irreps = Irreps(f"{self.hidden_nf}x0e") + self.out_irreps = Irreps(f"{self.hidden_nf}x0e") self.lmax = lmax - self.hidden_irreps = o3.Irreps( + self.hidden_irreps = Irreps( "+".join([f"{hidden_nf}x{i}e" for i in range(self.lmax + 1)]) ) - self.sph_irreps = o3.Irreps.spherical_harmonics(self.lmax, p=1) - self.edge_irreps = self.sph_irreps - self.spherical_harmonics = o3.SphericalHarmonics( - irreps_out=self.sph_irreps, - normalize=False, - normalization="norm", - ) - self.radius_cutoff = radius_cutoff + self.edge_irreps = edge_irreps + self.coarse_linear_type = coarse_linear_type + assert correlation >= 1 + self.correlation = correlation self.pre_down_layer = torch.nn.Sequential( nn.Linear(self.input_nf, self.hidden_nf), torch.nn.SiLU(), ) - torch.nn.init.xavier_uniform_(self.pre_down_layer[0].weight) - torch.nn.init.zeros_(self.pre_down_layer[0].bias) + torch.nn.init.xavier_uniform_(self.pre_down_layer[0].weight) # type: ignore + torch.nn.init.zeros_(self.pre_down_layer[0].bias) # type: ignore self.tp_down = TensorProduct( self.in_irreps, @@ -184,245 +424,413 @@ def __init__( self.hidden_irreps, ) + if coarse_linear_type is not None: + if coarse_linear_type == "decomp": + self.coarse_linear = O3Linear(self.hidden_irreps, self.hidden_irreps) + elif coarse_linear_type == "decomp-identity": + self.coarse_linear = O3Linear(self.hidden_irreps, self.hidden_irreps) + o3_identity_init(self.coarse_linear) + elif coarse_linear_type == "sketch": + self.coarse_linear = O3Linear( + self.hidden_irreps, + (self.hidden_irreps * correlation).sort().irreps.simplify(), + ) + elif coarse_linear_type == "sketch-identity": + self.coarse_linear = O3Linear( + self.hidden_irreps, + (self.hidden_irreps * correlation).sort().irreps.simplify(), + ) + o3_identity_init(self.coarse_linear, out_dim_multiplier=correlation) + else: + raise ValueError(f"Unknown coarse_linear method: {coarse_linear_type}") + + if correlation > 1: + self.symmetric_product = SymmetricContraction( + irreps_in=self.hidden_irreps, + irreps_out=self.hidden_irreps, + correlation=correlation, + ) + self.tp_up = TensorProduct( self.hidden_irreps, self.edge_irreps, self.out_irreps, + x1_contains_r=False, ) self.post_up_layer = torch.nn.Sequential( nn.Linear(self.hidden_nf, self.hidden_nf), torch.nn.SiLU(), ) - torch.nn.init.xavier_uniform_(self.post_up_layer[0].weight) - torch.nn.init.zeros_(self.post_up_layer[0].bias) + torch.nn.init.xavier_uniform_(self.post_up_layer[0].weight) # type: ignore + torch.nn.init.zeros_(self.post_up_layer[0].bias) # type: ignore + + self.concat_layer = torch.nn.Sequential( + nn.Linear(self.input_nf + self.hidden_nf, self.input_nf), + nn.SiLU(), + ) def forward( self, - h: torch.Tensor, # (num_fine, feats) - grid_coords: torch.Tensor, - coarse_coords: torch.Tensor, + h: torch.Tensor, # (num_fine, num_coarse, input_nf) + distance_ft: torch.Tensor, + direction_ft: torch.Tensor, grid_weights: torch.Tensor, + exp_m1_rho_total: torch.Tensor, ) -> torch.Tensor: - h = self.pre_down_layer(h) # (num_fine, hidden_nf) - - directions, distances = vect_cdist(grid_coords, coarse_coords) - directions = directions.to(self.dtype) # (num_fine, num_coarse, 3) - distances = distances.to(self.dtype) # (num_fine, num_coarse) - if self.radius_cutoff != float("inf"): - up_weight = normalization_envelope(distances, self.radius_cutoff) - else: - up_weight = torch.ones_like(distances) - - # Find edges within the radius cutoff. - radius_mask = distances <= self.radius_cutoff - edge_directions = directions[radius_mask] # (num_edges, 3) - edge_distances = distances[radius_mask] # (num_edges,) - up_weight = up_weight[radius_mask] # (num_edges,) - edge_indices = radius_mask.nonzero() # (num_edges, 2) - edge_fine_idx = edge_indices[:, 0] # (num_edges,) - edge_coarse_idx = edge_indices[:, 1] # (num_edges,) - - # For each edge, form a feature vector of size (hidden_nf,) - # based on the distance between the fine and coarse points. - edge_dist_ft = exp_radial_func( - edge_distances, self.hidden_nf - ) # (num_edges, hidden_nf) - if self.radius_cutoff != float("inf"): - # Make the cutoff smooth using a polynomial that goes from 1 at distance 0 to 0 at the - # cutoff distance. - envelope = polynomial_envelope(edge_distances, self.radius_cutoff, 8) - edge_dist_ft *= envelope.unsqueeze(-1) - else: - envelope = torch.ones_like(edge_distances) - - # For each edge, compute a feature vector of size (hidden_nf,) - # based on the direction. - edge_direction_ft = self.spherical_harmonics( - edge_directions - ) # (num_edges, (lmax+1)^2) + features = h # skip connection + h = self.pre_down_layer(h) # Process (fine -> coarse) features on each edge. - edge_h = h[edge_fine_idx] # (num_edges, hidden_nf) - down = self.tp_down( - edge_h, edge_direction_ft - ) # (num_edges, hidden_nf * (lmax+1)^2) - down = self._mul_repeat( - edge_dist_ft, down, self.hidden_irreps - ) # (num_edges, hidden_nf * (lmax+1)^2) + down = self.tp_down(h, direction_ft, distance_weights=distance_ft) # Sum data from incoming edges into each coarse point. - h_coarse = scatter_sum( - down.double() * grid_weights.double()[edge_fine_idx].view(-1, 1), - edge_coarse_idx, - dim=0, - dim_size=coarse_coords.size(0), - ).to(self.dtype) # (num_coarse, hidden_nf * (lmax+1)^2) - - # Process (coarse -> fine) features on each edge. - edge_coarse_ft = h_coarse[edge_coarse_idx] - up = self.tp_up(edge_coarse_ft, edge_direction_ft) - # Compute the normalization factor as the sum of envelope from each coarse point - denom = scatter_sum( - up_weight, - edge_fine_idx, - dim=0, - dim_size=grid_coords.size(0), - )[edge_fine_idx] - up_weight = up_weight / (denom + 0.1) - up = self._mul_repeat( - edge_dist_ft * up_weight.unsqueeze(-1), up, self.out_irreps + h_coarse = torch.einsum("gck,gc->ck", down.double(), grid_weights.double()).to( + self.dtype ) - # Broadcast coarse point information back to fine points. - h_fine = scatter_sum( - up, - edge_fine_idx, - dim=0, - dim_size=grid_coords.size(0), - ) # (num_fine, hidden_nf) + # Correlate the coarse features. + if self.coarse_linear_type is not None: + h_coarse = self.coarse_linear(h_coarse) - # Process the fine points. - h_fine = self.post_up_layer(h_fine) # (num_fine, hidden_nf) + if self.correlation > 1: + h_coarse = self.symmetric_product(h_coarse) - return h_fine + # Process (coarse -> fine) features on each edge. + h_fine = self.tp_up(h_coarse, direction_ft, distance_weights=distance_ft) - @staticmethod - def _mul_repeat( - mul_by: torch.Tensor, edge_attrs: torch.Tensor, irreps: o3.Irreps - ) -> torch.Tensor: - # `edge_attrs` is spherical tensor features - # this function multiplies `edge_attrs` with `mul_by` channels-wise per tensor order - # (repeating over all irreps) - mul_by_shape = mul_by.size()[:-1] - product = torch.cat( - [ - # (..., v, 1) * (..., v, j) -> (..., (v*j)) - ( - mul_by.unsqueeze(-1) - * edge_attrs[..., slices].view(*mul_by_shape, mul, ir.dim) - ).view(*mul_by_shape, -1) - for slices, (mul, ir) in zip(irreps.slices(), irreps, strict=True) - ], - dim=-1, - ) - return product + # Process the fine points. + h_fine = self.post_up_layer(h_fine) + + # Non-linear transform with skip connection + features = torch.cat([features, h_fine * exp_m1_rho_total], dim=-1) + return self.concat_layer(features) @property def dtype(self) -> torch.dtype: - return self.pre_down_layer[0].weight.dtype + return self.pre_down_layer[0].weight.dtype # type: ignore -def vect_cdist(c1: torch.Tensor, c2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - direction = c1[:, None] - c2[None, :] - dist = (direction**2 + 1e-20).sum(-1).sqrt() - return direction / dist[:, :, None], dist +class TensorProduct(nn.Module): + """Equivariant tensor product with learned weights. + Uses batched gather-index operations for efficient ``tp_down`` (all-scalar + input) and ``tp_up`` (all-scalar output) patterns. -def exp_radial_func(dist: torch.Tensor, num_basis: int, dim: int = 3) -> torch.Tensor: + Args: + irreps_in1: Irreps for the first input. + irreps_in2: Irreps for the second input (edge features). + irreps_out: Output irreps. + x1_contains_r: If True, x1 has a spatial (r) dimension. """ - This ensures two standard deviations of the Gaussian kernel would reach - the desired covalent radius value (95% of the Gaussian mass). - """ - min_std = MIN_COV_RAD / 2 - max_std = MAX_COV_RAD / 2 - s = torch.linspace(min_std, max_std, num_basis, device=dist.device) - - temps = 2 * s**2 - x2 = dist[..., None] ** 2 - emb = ( - torch.exp(-x2 / temps) * 2 / dim * x2 / temps / (math.pi * temps) ** (0.5 * dim) - ) - - return emb - - -def polynomial_envelope(r: torch.Tensor, cutoff: float, p: int) -> torch.Tensor: - """ - This smoothly maps the domain r=[0, cutoff] to the range [1, 0] using a polynomial function. - Every r >= cutoff is mapped to 0. - """ - # from DimeNet (https://arxiv.org/abs/2003.03123) - assert p >= 2 - r = r / cutoff - r = torch.clamp(r, 0, 1) - x = r - 1 - x2 = x * x - poly = p * (p + 1) * x2 - 2 * p * x + 2 - return torch.relu(1 - 0.5 * r.pow(p) * poly) - - -def normalization_envelope(r: torch.Tensor, cutoff: float) -> torch.Tensor: - r = r / cutoff - r = torch.clamp(r, 0, 1) - return 1 - torch.where(r < 0.5, 2 * r**2, -2 * r**2 + 4 * r - 1) + _tp_down_xw_idx: torch.Tensor + _tp_down_sph_idx: torch.Tensor + _tp_down_norm: torch.Tensor -class TensorProduct(nn.Module): # type: ignore[misc] - optimize_einsums = True - script_codegen = True + _tp_up_x1_gather: torch.Tensor + _tp_up_instr_idx: torch.Tensor + _tp_up_norm: torch.Tensor def __init__( self, - irreps_in1: o3.Irreps, - irreps_in2: o3.Irreps, - irreps_out: o3.Irreps, + irreps_in1: Irreps, + irreps_in2: Irreps, + irreps_out: Irreps, + x1_contains_r: bool = True, ): super().__init__() self.irreps_in1 = irreps_in1 self.irreps_in2 = irreps_in2 self.irreps_out = irreps_out + self.x1_contains_r = x1_contains_r self.instr = [ (i_1, i_2, i_out) for i_1, (_, ir_1) in enumerate(irreps_in1) for i_2, (_, ir_2) in enumerate(irreps_in2) for i_out, (_, ir_out) in enumerate(irreps_out) - if ir_out in ir_1 * ir_2 + if ir_out in ir_1 * ir_2 # type: ignore # Irrep.__mul__ not in stubs + ] + + self.slices = [irreps_in1.slices(), irreps_in2.slices(), irreps_out.slices()] + + self._setup_batched_tp() + assert self._batched_tp_mode is not None, ( + f"TensorProduct requires batched TP support but got incompatible irreps: " + f"in1={irreps_in1}, in2={irreps_in2}, out={irreps_out}" + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + def num_elements(ins: tuple[int, int, int]) -> int: + return int(self.irreps_in1[ins[0]].mul * self.irreps_in2[ins[1]].mul) + + for idx, ins in enumerate(self.instr): + num_in = sum(num_elements(ins_) for ins_ in self.instr if ins_[2] == ins[2]) + num_out = self.irreps_out[ins[2]].mul + x = (6 / (num_in + num_out)) ** 0.5 + self._batched_W.data[idx].uniform_(-x, x) + + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata: dict[str, Any], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ) -> None: + """Backward-compatible loading: convert old per-instruction weight keys to _batched_W.""" + if self._batched_tp_mode is not None: + old_keys = [f"{prefix}weight_{i1}_{i2}_{io}" for i1, i2, io in self.instr] + has_old = all(k in state_dict for k in old_keys) + has_new = f"{prefix}_batched_W" in state_dict + + if has_old and not has_new: + stacked = torch.stack([state_dict.pop(k).squeeze(1) for k in old_keys]) + state_dict[f"{prefix}_batched_W"] = stacked + + # Drop legacy w3j buffers that are no longer registered. + for i1, i2, io in self.instr: + state_dict.pop(f"{prefix}w3j_{i1}_{i2}_{io}", None) + + # Fill reconstructable buffers from current model state. + current_sd = dict(self.named_buffers()) + for buf_name, buf_val in current_sd.items(): + full_key = f"{prefix}{buf_name}" + if full_key not in state_dict: + state_dict[full_key] = buf_val + + super()._load_from_state_dict( # type: ignore[no-untyped-call] + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _setup_batched_tp(self) -> None: + """Detect batched mode and pre-compute gather indices. + + Supports two patterns: + + - ``tp_down``: all-scalar input (l1=0), single input irrep, v=1 edge + multiplicities. Batches weight matmuls into a single einsum, then + scatter-multiplies with x2. + - ``tp_up``: all-scalar output (l_out=0), single output irrep, v=1 edge + multiplicities. Computes per-instruction dot products, then batches + weight matmuls + sum. + """ + self._batched_tp_mode: str | None = None + + if len(self.instr) <= 1: + return + + # Common checks: all v=1 and all (u, w) same shape. + if not all(self.irreps_in2[i_2].mul == 1 for _, i_2, _ in self.instr): + return + shapes = { + (self.irreps_in1[i_1].mul, self.irreps_out[i_out].mul) + for i_1, _, i_out in self.instr + } + if len(shapes) != 1: + return + u, w = next(iter(shapes)) + + if self.x1_contains_r: + # tp_down: all l1=0, single input irrep group. + i1_values = {i_1 for i_1, _, _ in self.instr} + if len(i1_values) == 1 and self.irreps_in1[next(iter(i1_values))].ir.l == 0: + self._batched_tp_mode = "tp_down" + self._tp_batch_u = u + self._tp_batch_w = w + self._tp_down_x2_info = [ + ( + self.slices[1][i_2].start, + self.slices[1][i_2].stop, + self.irreps_out[i_out].ir.l, + ) + for _, i_2, i_out in self.instr + ] + # Pre-compute gather indices for vectorized tp_down. + xw_idx = [] + sph_idx = [] + norm_vals = [] + for b_idx in range(len(self.instr)): + x2_start, _, l_out = self._tp_down_x2_info[b_idx] + dim_l = 2 * l_out + 1 + nf = 1.0 if l_out == 0 else 1.0 / math.sqrt(dim_l) + for c in range(w): + for m in range(dim_l): + xw_idx.append(b_idx * w + c) + sph_idx.append(x2_start + m) + norm_vals.append(nf) + self.register_buffer( + "_tp_down_xw_idx", + torch.tensor(xw_idx, dtype=torch.long), + ) + self.register_buffer( + "_tp_down_sph_idx", + torch.tensor(sph_idx, dtype=torch.long), + ) + self.register_buffer("_tp_down_norm", torch.tensor(norm_vals)) + else: + # tp_up: all l_out=0, single output irrep group. + i_out_values = {i_out for _, _, i_out in self.instr} + if len(i_out_values) == 1 and all( + self.irreps_out[i_out].ir.l == 0 for _, _, i_out in self.instr + ): + self._batched_tp_mode = "tp_up" + self._tp_batch_u = u + self._tp_batch_w = w + self._tp_up_x1_info = [ + ( + self.slices[0][i_1].start, + self.slices[0][i_1].stop, + self.irreps_in1[i_1].ir.dim, + ) + for i_1, _, _ in self.instr + ] + self._tp_up_x2_info = [ + (self.slices[1][i_2].start, self.slices[1][i_2].stop) + for _, i_2, _ in self.instr + ] + # Pre-compute gather index to rearrange x1 from interleaved + # layout to sph-major layout aligned with x2. + sph_total = self.irreps_in2.dim + x1_gather = torch.zeros(sph_total * u, dtype=torch.long) + instr_idx = torch.zeros(sph_total, dtype=torch.long) + norm_vals_t = torch.zeros(sph_total) + for b_idx in range(len(self.instr)): + x1_start, _, dim_b = self._tp_up_x1_info[b_idx] + x2_start, _ = self._tp_up_x2_info[b_idx] + nf = 1.0 / math.sqrt(dim_b) + for m in range(dim_b): + s = x2_start + m + instr_idx[s] = b_idx + norm_vals_t[s] = nf + for ch in range(u): + x1_gather[s * u + ch] = x1_start + ch * dim_b + m + self.register_buffer("_tp_up_x1_gather", x1_gather) + self.register_buffer("_tp_up_instr_idx", instr_idx) + self.register_buffer("_tp_up_norm", norm_vals_t) + + if self._batched_tp_mode is not None: + self._batched_W = nn.Parameter(torch.empty(len(self.instr), u, w)) + + def _forward_batched_tp_down( + self, + x1: torch.Tensor, + x2: torch.Tensor, + distance_weights: torch.Tensor | None = None, + ) -> torch.Tensor: + """Batched forward for scalar-input tensor product (tp_down).""" + W = self._batched_W # (B, u, w) + xw = torch.einsum("sru, buw -> bsrw", x1, W) # (B, s, r, w) + if distance_weights is not None: + xw = xw * distance_weights # broadcasts over B + # Flatten instruction & channel dims: (B, s, r, w) -> (s, r, B*w) + xw_flat = xw.permute(1, 2, 0, 3).reshape(xw.shape[1], xw.shape[2], -1) + return ( + xw_flat[:, :, self._tp_down_xw_idx] + * x2[:, :, self._tp_down_sph_idx] + * self._tp_down_norm.to(x2.dtype) + ) + + def _forward_batched_tp_up( + self, + x1: torch.Tensor, + x2: torch.Tensor, + distance_weights: torch.Tensor | None = None, + ) -> torch.Tensor: + """Batched forward for scalar-output tensor product (tp_up).""" + W = self._batched_W # (B, u, w) + u = self._tp_batch_u + c = x1.shape[0] + sph = x2.shape[-1] + + # Rearrange x1: (c, hidden) -> (c, sph, u) aligned with x2's sph layout + x1_r = x1[:, self._tp_up_x1_gather].reshape(c, sph, u) + + # Expand W per sph position with normalization: (B, u, w) -> (sph, u, w) + W_sph = W[self._tp_up_instr_idx] * self._tp_up_norm[:, None, None].to(W.dtype) + + # Linear transform on x1: (c, sph, u) @ (sph, u, w) -> (c, sph, w) + x1_W = torch.einsum("csu, suw -> csw", x1_r, W_sph) + + # Inner product with x2: sum_l = + out = torch.einsum("csw, fcs -> fcw", x1_W, x2) + if distance_weights is not None: + out = distance_weights * out + return out + + def forward( + self, + x1: torch.Tensor, + x2: torch.Tensor, + *, + distance_weights: torch.Tensor | None = None, + ) -> torch.Tensor: + assert (len(x1.size()) == 2 and not self.x1_contains_r) or ( + len(x1.size()) == 3 and self.x1_contains_r + ), "x1 must be 2D or 3D" + assert len(x2.size()) == 3, "x2 must be 3D" + if self._batched_tp_mode == "tp_down": + return self._forward_batched_tp_down(x1, x2, distance_weights) + return self._forward_batched_tp_up(x1, x2, distance_weights) + + +class O3Linear(nn.Module): + """Equivariant linear layer operating on irreps.""" + + optimize_einsums = True + script_codegen = False + + def __init__(self, irreps_in: Irreps, irreps_out: Irreps): + super().__init__() + + self.irreps_in = irreps_in + self.irreps_out = irreps_out + + self.instr = [ + (i_in, i_out) + for i_in, (_, ir_in) in enumerate(irreps_in) + for i_out, (_, ir_out) in enumerate(irreps_out) + if ir_in == ir_out ] self.weight_numel = sum( - self.irreps_in1[i_1].mul - * self.irreps_in2[i_2].mul - * self.irreps_out[i_out].mul - for i_1, i_2, i_out in self.instr + self.irreps_in[i_in].mul * self.irreps_out[i_out].mul + for i_in, i_out in self.instr ) - for i_1, i_2, i_out in self.instr: + for i_in, i_out in self.instr: self.register_parameter( - f"weight_{i_1}_{i_2}_{i_out}", + f"weight_{i_in}_{i_out}", nn.Parameter( torch.randn( - self.irreps_in1[i_1].mul, - self.irreps_in2[i_2].mul, + self.irreps_in[i_in].mul, self.irreps_out[i_out].mul, ) ), ) - self.slices = [irreps_in1.slices(), irreps_in2.slices(), irreps_out.slices()] - for i_1, i_2, i_out in self.instr: - w3j = o3.wigner_3j( - irreps_in1[i_1].ir.l, irreps_in2[i_2].ir.l, irreps_out[i_out].ir.l - ).permute(2, 0, 1) # ijk -> kij - self.register_buffer(f"w3j_{i_1}_{i_2}_{i_out}", w3j) - + self.slices = [irreps_in.slices(), irreps_out.slices()] self.reset_parameters() - self._sparse_tp = self.generate_sparse_tp_code() + self._o3_linear = self.generate_o3_linear_code() - def generate_sparse_tp_code(self) -> fx.GraphModule: - graphmod = _sparse_tensor_product_codegen(*self.tp_params) + def generate_o3_linear_code(self) -> fx.GraphModule: + graphmod = _o3_linear_codegen(*self.linear_params) if self.optimize_einsums: - m = 3 - - weight_list = self.weight_list example_inputs = ( - torch.randn(m, self.irreps_in1.dim), - torch.randn(m, self.irreps_in2.dim), - *self.w3j_list, - *weight_list, + torch.randn(3, self.irreps_in.dim), + *self.weight_list, ) graphmod = optimize_einsums_full(graphmod, example_inputs) @@ -432,152 +840,95 @@ def generate_sparse_tp_code(self) -> fx.GraphModule: return graphmod @property - def tp_params( - self, - ) -> tuple[ - list[tuple[int, int, int]], - list[tuple[int, int, int]], - list[tuple[int, int, int]], - list[tuple[int, int, int]], - list[list[tuple[int, int]]], - ]: + def linear_params(self) -> tuple[Any, ...]: return ( self.instr, - convert_irreps(self.irreps_in1), - convert_irreps(self.irreps_in2), + convert_irreps(self.irreps_in), convert_irreps(self.irreps_out), [[(ss.start, ss.stop) for ss in s] for s in self.slices], ) def reset_parameters(self) -> None: - def num_elements(ins: tuple[int, int, int]) -> int: - # assuming uvw connectivity - return int(self.irreps_in1[ins[0]].mul * self.irreps_in2[ins[1]].mul) + def num_elements(ins: tuple[int, int]) -> int: + return int(self.irreps_in[ins[0]].mul) for ins in self.instr: - i_1, i_2, i_out = ins - num_in = sum(num_elements(ins_) for ins_ in self.instr if ins_[2] == ins[2]) - num_out = self.irreps_out[ins[2]].mul + i_in, i_out = ins + num_in = sum(num_elements(ins_) for ins_ in self.instr if ins_[1] == ins[1]) + num_out = self.irreps_out[ins[1]].mul x = (6 / (num_in + num_out)) ** 0.5 - getattr(self, f"weight_{i_1}_{i_2}_{i_out}").data.uniform_(-x, x) + getattr(self, f"weight_{i_in}_{i_out}").data.uniform_(-x, x) @property def weight_list(self) -> list[torch.Tensor]: - return [ - getattr(self, f"weight_{i_1}_{i_2}_{i_out}") - for i_1, i_2, i_out in self.instr - ] - - @property - def w3j_list(self) -> list[torch.Tensor]: - return [ - getattr(self, f"w3j_{i_1}_{i_2}_{i_out}") for i_1, i_2, i_out in self.instr - ] - - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - return self._sparse_tp(x1, x2, *self.w3j_list, *self.weight_list) + return [getattr(self, f"weight_{i_in}_{i_out}") for i_in, i_out in self.instr] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._o3_linear(x, *self.weight_list) + + +def o3_identity_init(linear: O3Linear, out_dim_multiplier: int = 1) -> None: + """Initialize an O3Linear layer as an identity mapping.""" + for ins in linear.instr: + if ( + linear.irreps_in[ins[0]].mul * out_dim_multiplier + != linear.irreps_out[ins[1]].mul + ): + raise ValueError( + f"Input and output irreps must match, got " + f"{linear.irreps_in[ins[0]].mul} * {out_dim_multiplier} " + f"!= {linear.irreps_out[ins[1]].mul}" + ) + getattr(linear, f"weight_{ins[0]}_{ins[1]}").data.copy_( + torch.eye(linear.irreps_in[ins[0]].mul).repeat(1, out_dim_multiplier) + ) # irreps_format: list of (mul, ir.l, ir.dim) -def convert_irreps(irreps: o3.Irreps) -> list[tuple[int, int, int]]: - return [(mul, ir.l, ir.dim) for mul, ir in irreps] - +def convert_irreps(irreps: Irreps) -> list[tuple[int, int, int]]: + return [(mulir.mul, mulir.ir.l, mulir.ir.dim) for mulir in irreps] -def convert_slices(slices: list[slice]) -> list[tuple[int, int]]: - return [(s.start, s.stop) for s in slices] - -def _sparse_tensor_product_codegen( - instr: list[tuple[int, int, int]], - irreps_in1: list[tuple[int, int, int]], - irreps_in2: list[tuple[int, int, int]], +def _o3_linear_codegen( + instr: list[tuple[int, int]], + irreps_in: list[tuple[int, int, int]], irreps_out: list[tuple[int, int, int]], slices: list[list[tuple[int, int]]], ) -> fx.GraphModule: - # x1: m, (u, i) - # x2: m, (v, j) # v is always 1 - graph = fx.Graph() tracer = fx.proxy.GraphAppendingTracer(graph) x1 = fx.Proxy(graph.placeholder("x1", torch.Tensor), tracer=tracer) - x2 = fx.Proxy(graph.placeholder("x2", torch.Tensor), tracer=tracer) - m = x2.size(0) - w3js = [ - fx.Proxy( - graph.placeholder(f"w3j_{i_1}_{i_2}_{i_out}", torch.Tensor), tracer=tracer - ) - for i_1, i_2, i_out in instr - ] weights = [ fx.Proxy( - graph.placeholder(f"weight_{i_1}_{i_2}_{i_out}", torch.Tensor), + graph.placeholder(f"weight_{i_in}_{i_out}", torch.Tensor), tracer=tracer, ) - for i_1, i_2, i_out in instr + for i_in, i_out in instr ] - outs = list() - for (i_1, i_2, i_out), w, w3j in zip(instr, weights, w3js, strict=True): - irrep_in1 = irreps_in1[i_1] - irrep_in2 = irreps_in2[i_2] - irrep_out = irreps_out[i_out] - - l1l2l3 = (irrep_in1[1], irrep_in2[1], irrep_out[1]) - - x1_i = x1[..., slices[0][i_1][0] : slices[0][i_1][1]] - x2_i = x2[..., slices[1][i_2][0] : slices[1][i_2][1]] - - if l1l2l3 == (0, 0, 0): - outs.append(torch.einsum("mu,uvw,mv->mw", x1_i, w, x2_i)) - elif l1l2l3[0] == 0: - outs.append( - torch.einsum( - "mu,uvw,mvj->mwj", x1_i, w, x2_i.view(m, irrep_in2[0], irrep_in2[2]) - ).reshape(m, irrep_out[0] * irrep_out[2]) - / math.sqrt(irrep_out[2]) - ) - elif l1l2l3[1] == 0: - outs.append( - torch.einsum( - "mui,uvw,mv->mwi", x1_i.view(m, irrep_in1[0], irrep_in1[2]), w, x2_i - ).reshape(m, irrep_out[0] * irrep_out[2]) - / math.sqrt(irrep_out[2]) - ) - elif l1l2l3[2] == 0: - outs.append( - torch.einsum( - "mui,uvw,mvi->mw", - x1_i.view(m, irrep_in1[0], irrep_in1[2]), - w, - x2_i.view(m, irrep_in2[0], irrep_in2[2]), - ) - / math.sqrt(irrep_in1[2]) - ) - else: - outs.append( - torch.einsum( - "mui,uvw,mvj,kij->mwk", - x1_i.view(m, irrep_in1[0], irrep_in1[2]), - w, - x2_i.view(m, irrep_in2[0], irrep_in2[2]), - w3j, - ).reshape(m, irrep_out[0] * irrep_out[2]) - ) + outs: list[Any] = list() + for (i_in, i_out), w in zip(instr, weights, strict=True): + x1_i = x1[:, slices[0][i_in][0] : slices[0][i_in][1]] # type: ignore + outs.append( + torch.einsum( + "sui,uv->svi", + x1_i.view(-1, irreps_in[i_in][0], irreps_in[i_in][2]), + w, + ).reshape(-1, irreps_out[i_out][0] * irreps_out[i_out][2]) + ) - out = [ - sum(out for ins, out in zip(instr, outs, strict=False) if ins[2] == i_out) + out: list[Any] = [ + sum(out for ins, out in zip(instr, outs, strict=True) if ins[1] == i_out) for i_out, (mul, *_) in enumerate(irreps_out) if mul > 0 ] if len(out) > 1: - concatenated = torch.cat(out, dim=-1) + concatenated: Any = torch.cat(out, dim=-1) else: concatenated = out[0] graph.output(concatenated.node, torch.Tensor) - graph.lint() - - graphmod = fx.GraphModule(torch.nn.Module(), graph) + graph.lint() # type: ignore[no-untyped-call] - return graphmod + return fx.GraphModule(torch.nn.Module(), graph) diff --git a/src/skala/functional/traditional.py b/src/skala/functional/traditional.py index 9947720..5fc8feb 100644 --- a/src/skala/functional/traditional.py +++ b/src/skala/functional/traditional.py @@ -11,7 +11,7 @@ from collections.abc import Iterator import torch -from torch import Tensor +from torch import Tensor, nn from skala.functional import density from skala.functional.base import ExcFunctionalBase @@ -153,12 +153,12 @@ class PBE(SpinScaledXCFunctional): def __init__(self) -> None: super().__init__() self.lda = SPW92() - self.beta = torch.tensor(0.066725) - self.kappa = torch.tensor(0.804) + self.beta = nn.Parameter(torch.tensor(0.066725), requires_grad=False) + self.kappa = nn.Parameter(torch.tensor(0.804), requires_grad=False) self.mu = self.beta * (math.pi**2 / 3) - def parameters(self, recurse: bool = True) -> Iterator[Tensor]: - super().parameters(recurse) + def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: + yield from super().parameters(recurse) yield from [self.beta, self.kappa] def exchange(self, mol_features: dict[str, Tensor]) -> Tensor: @@ -211,13 +211,13 @@ def __init__(self) -> None: super().__init__() self.lda = SPW92() self.pbe = PBE() - self.c = torch.tensor(1.59096) - self.e = torch.tensor(1.537) - self.b = torch.tensor(0.40) - self.d = torch.tensor(2.8) + self.c = nn.Parameter(torch.tensor(1.59096), requires_grad=False) + self.e = nn.Parameter(torch.tensor(1.537), requires_grad=False) + self.b = nn.Parameter(torch.tensor(0.40), requires_grad=False) + self.d = nn.Parameter(torch.tensor(2.8), requires_grad=False) - def parameters(self, recurse: bool = True) -> Iterator[Tensor]: - super().parameters(recurse) + def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: + yield from super().parameters(recurse) yield from [self.c, self.e, self.b, self.d] def exchange(self, mol_features: dict[str, Tensor]) -> Tensor: diff --git a/src/skala/functional/utils/__init__.py b/src/skala/functional/utils/__init__.py new file mode 100644 index 0000000..548d2d4 --- /dev/null +++ b/src/skala/functional/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: MIT diff --git a/src/skala/functional/utils/cg.py b/src/skala/functional/utils/cg.py new file mode 100644 index 0000000..cb8995d --- /dev/null +++ b/src/skala/functional/utils/cg.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: MIT + +# Based on original MACE code: https://github.com/ACEsuit/mace +# See algorithm 1 in the appendix of https://arxiv.org/pdf/2206.07697 + +"""Clebsch–Gordan coefficient utilities for the symmetric contraction.""" + +from __future__ import annotations + +import torch +from e3nn import o3 + + +def _wigner_nj( + irreps_list: list[o3.Irreps], + filter_irs: list[o3.Irrep] | None = None, + normalization: str = "component", + dtype: torch.dtype | None = None, +) -> list[tuple[o3.Irrep, torch.Tensor]]: + ret: list[tuple[o3.Irrep, torch.Tensor]] = [] + + if len(irreps_list) == 1: + irreps = irreps_list[0] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret.append((ir, e[sl])) + i += ir.dim + return ret + + *irreps_list_left, irreps_right = irreps_list + for ir_left, C_left in _wigner_nj( + irreps_list_left, + normalization=normalization, + filter_irs=filter_irs, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_irs is not None and ir_out not in filter_irs: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, + *(irreps.dim for irreps in irreps_list_left), + ir.dim, + ) + + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irreps_list_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + + ret.append((ir_out, E)) + i += mul * ir.dim + + return sorted(ret, key=lambda x: x[0]) + + +def u_matrix_real( + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: int, + normalization: str = "component", + filter_irs: list[o3.Irrep] | None = None, + dtype: torch.dtype | None = None, +) -> list[torch.Tensor]: + """Compute the real U-matrix for the symmetric contraction. + + Args: + irreps_in: Input irreps (will be repeated ``correlation`` times). + irreps_out: Target output irreps. + correlation: Correlation order. + normalization: Either ``"component"`` or ``"norm"``. + filter_irs: Optional filter on intermediate irreps. + dtype: Tensor dtype. + + Returns: + List of U-matrix tensors, one per output irrep. + """ + irreps_list = [irreps_in] * correlation + + if correlation == 4: + filter_irs = [o3.Irrep(l, (-1) ** l) for l in range(12)] + + wigner_njs = _wigner_nj( + irreps_list, + filter_irs=filter_irs, + normalization=normalization, + dtype=dtype, + ) + + current_ir = wigner_njs[0][0] + + out: list[torch.Tensor] = [] + stack = torch.tensor([]) + + for irrep, base_o3 in wigner_njs: + if irrep in irreps_out and irrep == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + elif irrep in irreps_out and irrep != current_ir: + if len(stack) != 0: + out.append(stack) + stack = base_o3.squeeze().unsqueeze(-1) + current_ir = irrep + else: + current_ir = irrep + + if len(stack) != 0: + out.append(stack) + return out diff --git a/src/skala/functional/utils/irreps.py b/src/skala/functional/utils/irreps.py new file mode 100644 index 0000000..88eddf8 --- /dev/null +++ b/src/skala/functional/utils/irreps.py @@ -0,0 +1,317 @@ +# SPDX-License-Identifier: MIT + +""" +Compile-compatible implementation of e3nn.o3.Irreps. + +This module provides a reimplementation of e3nn's Irrep, MulIr, and Irreps classes +that is fully compatible with torch.compile. The original e3nn implementation +inherits from tuple and raises NotImplementedError for __len__ on Irrep, which +causes issues with torch.compile's graph tracing. + +Key differences from e3nn.o3.Irreps: + - Uses __slots__ instead of inheriting from tuple (for Irrep and MulIr) + - Uses an internal list instead of inheriting from tuple (for Irreps) + - Uses __init__ instead of __new__ for construction + - Omits Wigner D-matrix methods (D_from_angles, D_from_quaternion, etc.) + - Omits some convenience methods (randn, filter, regroup, count, etc.) + +The implemented subset is sufficient for defining irreducible representations +and their direct sums for use in equivariant neural network architectures. + +Example: + >>> from skala.functional.utils.irreps import Irrep, Irreps + >>> Irrep("1o") + 1o + >>> Irreps("16x0e + 8x1o + 4x2e") + 16x0e+8x1o+4x2e + >>> Irreps.spherical_harmonics(3) + 1x0e+1x1o+1x2e+1x3o + +See Also: + e3nn.o3.Irreps: The original implementation from the e3nn library. +""" + +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import NamedTuple, overload + + +class Irrep: + __slots__ = ("_l", "_p") + _l: int + _p: int + + def __init__(self, l: int | str | tuple[int, int] | Irrep, p: int | None = None): + if p is None: + if isinstance(l, Irrep): + self._l = l._l + self._p = l._p + return + if isinstance(l, str): + name = l.strip() + self._l = int(name[:-1]) + assert self._l >= 0 + self._p = {"e": 1, "o": -1, "y": (-1) ** self._l}[name[-1]] + return + if isinstance(l, tuple): + l, p = l + + if not isinstance(l, int) or l < 0: + raise ValueError(f"l must be non-negative integer, got {l}") + if p not in (-1, 1): + raise ValueError(f"parity must be -1 or 1, got {p}") + self._l = l + self._p = p + + @property + def l(self) -> int: # noqa: E743 + return self._l + + @property + def p(self) -> int: + return self._p + + @property + def dim(self) -> int: + return 2 * self._l + 1 + + def __repr__(self) -> str: + p = {1: "e", -1: "o"}[self._p] + return f"{self._l}{p}" + + def __eq__(self, other: object) -> bool: + if isinstance(other, Irrep): + return self._l == other._l and self._p == other._p + if isinstance(other, tuple) and len(other) == 2: + return self._l == other[0] and self._p == other[1] + return False + + def __hash__(self) -> int: + return hash((self._l, self._p)) + + def __mul__(self, other: Irrep) -> Iterator[Irrep]: + other = Irrep(other) + p = self._p * other._p + lmin = abs(self._l - other._l) + lmax = self._l + other._l + for l in range(lmin, lmax + 1): + yield Irrep(l, p) + + def __iter__(self) -> Iterator[int]: + yield self._l + yield self._p + + def __getitem__(self, i: int) -> int: + if i == 0: + return self._l + if i == 1: + return self._p + raise IndexError(i) + + +class MulIr: + __slots__ = ("_mul", "_ir") + _mul: int + _ir: Irrep + + def __init__( + self, + mul: int | MulIr | tuple[int, Irrep | str | tuple[int, int]], + ir: Irrep | str | tuple[int, int] | None = None, + ): + if ir is None: + if isinstance(mul, MulIr): + self._mul = mul._mul + self._ir = mul._ir + return + if isinstance(mul, tuple) and len(mul) == 2: + mul, ir = mul + if not isinstance(mul, int) or mul < 0: + raise ValueError(f"mul must be non-negative integer, got {mul}") + if ir is None: + raise ValueError("ir must be provided") + self._mul = mul + self._ir = Irrep(ir) if not isinstance(ir, Irrep) else ir + + @property + def mul(self) -> int: + return self._mul + + @property + def ir(self) -> Irrep: + return self._ir + + @property + def dim(self) -> int: + return self._mul * self._ir.dim + + def __repr__(self) -> str: + return f"{self._mul}x{self._ir}" + + def __iter__(self) -> Iterator[int | Irrep]: + yield self._mul + yield self._ir + + def __getitem__(self, i: int) -> int | Irrep: + if i == 0: + return self._mul + if i == 1: + return self._ir + raise IndexError(i) + + +class Irreps: + """Compile-compatible direct sum of irreducible representations of O(3).""" + + def __init__( + self, + irreps: str + | Irreps + | Irrep + | Sequence[MulIr | str | Irrep | tuple[int, Irrep | str | tuple[int, int]]] + | None = None, + ): + self._data: list[MulIr] = [] + + if irreps is None: + return + if isinstance(irreps, Irreps): + self._data = list(irreps._data) + return + if isinstance(irreps, Irrep): + self._data = [MulIr(1, irreps)] + return + if isinstance(irreps, str): + if irreps.strip() == "": + return + for mul_ir in irreps.split("+"): + mul_ir = mul_ir.strip() + if "x" in mul_ir: + mul_str, ir_str = mul_ir.split("x") + mul = int(mul_str) + ir = Irrep(ir_str) + else: + mul = 1 + ir = Irrep(mul_ir) + self._data.append(MulIr(mul, ir)) + return + for item in irreps: + if isinstance(item, MulIr): + self._data.append(item) + elif isinstance(item, str): + self._data.append(MulIr(1, Irrep(item))) + elif isinstance(item, Irrep): + self._data.append(MulIr(1, item)) + elif isinstance(item, tuple) and len(item) == 2: + mul, ir_like = item + self._data.append(MulIr(mul, Irrep(ir_like))) + else: + raise ValueError(f"Unable to interpret {item!r} as an irrep") + + @staticmethod + def spherical_harmonics(lmax: int, p: int = -1) -> Irreps: + return Irreps([(1, (l, p**l)) for l in range(lmax + 1)]) + + def slices(self) -> list[slice]: + s = [] + i = 0 + for mul_ir in self._data: + s.append(slice(i, i + mul_ir.dim)) + i += mul_ir.dim + return s + + @property + def dim(self) -> int: + return sum(mul_ir.dim for mul_ir in self._data) + + @property + def num_irreps(self) -> int: + return sum(mul_ir.mul for mul_ir in self._data) + + @property + def ls(self) -> list[int]: + return [mul_ir.ir.l for mul_ir in self._data for _ in range(mul_ir.mul)] + + @property + def lmax(self) -> int: + if len(self._data) == 0: + raise ValueError("Cannot get lmax of empty Irreps") + return max(mul_ir.ir.l for mul_ir in self._data) + + def simplify(self) -> Irreps: + out: list[tuple[int, Irrep]] = [] + for mul_ir in self._data: + mul, ir = mul_ir.mul, mul_ir.ir + if out and out[-1][1] == ir: + out[-1] = (out[-1][0] + mul, ir) + elif mul > 0: + out.append((mul, ir)) + return Irreps(out) + + class _SortResult(NamedTuple): + irreps: Irreps + p: tuple[int, ...] + inv: tuple[int, ...] + + def sort(self) -> Irreps._SortResult: + indexed = [(mul_ir.ir, i, mul_ir.mul) for i, mul_ir in enumerate(self._data)] + indexed.sort(key=lambda x: (x[0].l, x[0].p)) + inv = tuple(i for _, i, _ in indexed) + p_list = [0] * len(inv) + for i, j in enumerate(inv): + p_list[j] = i + p = tuple(p_list) + irreps = Irreps([(mul, ir) for ir, _, mul in indexed]) + return Irreps._SortResult(irreps, p, inv) + + def __len__(self) -> int: + return len(self._data) + + def __iter__(self) -> Iterator[MulIr]: + return iter(self._data) + + @overload + def __getitem__(self, i: int) -> MulIr: ... + + @overload + def __getitem__(self, i: slice) -> Irreps: ... + + def __getitem__(self, i: int | slice) -> MulIr | Irreps: + if isinstance(i, slice): + return Irreps([(m.mul, m.ir) for m in self._data[i]]) + return self._data[i] + + def __contains__(self, ir: Irrep) -> bool: + ir = Irrep(ir) + return any(mul_ir.ir == ir for mul_ir in self._data) + + def __add__(self, other: Irreps) -> Irreps: + other = Irreps(other) + return Irreps( + [(m.mul, m.ir) for m in self._data] + [(m.mul, m.ir) for m in other._data] + ) + + def __mul__(self, n: int) -> Irreps: + if not isinstance(n, int): + raise NotImplementedError("Use o3.TensorProduct for irrep multiplication") + return Irreps([(m.mul, m.ir) for m in self._data] * n) + + def __rmul__(self, n: int) -> Irreps: + return self.__mul__(n) + + def __repr__(self) -> str: + return "+".join(str(mul_ir) for mul_ir in self._data) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Irreps): + return False + if len(self._data) != len(other._data): + return False + return all( + a.mul == b.mul and a.ir == b.ir + for a, b in zip(self._data, other._data, strict=True) + ) + + def __hash__(self) -> int: + return hash(tuple((m.mul, m.ir.l, m.ir.p) for m in self._data)) diff --git a/src/skala/functional/utils/pad_ragged.py b/src/skala/functional/utils/pad_ragged.py new file mode 100644 index 0000000..00bbb62 --- /dev/null +++ b/src/skala/functional/utils/pad_ragged.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: MIT + +import torch + + +def pad_ragged( + data: torch.Tensor, sizes: torch.Tensor, size_bound: int +) -> torch.Tensor: + """Packs variable-length concatenated sequences into a batched tensor with padding. + + You can think of the static parameter `size_bound` as `sizes.max()`, but it is passed separately + as an integer to avoid "data-dependent control flow". We want to avoid the shape of the output + tensor depending on the value of the input tensors. + + Args: + data: Tensor [total, *rest] where total = sum(sizes) + sizes: 1D tensor [batch] of sequence lengths + size_bound: Pad dimension size. If smaller than max(sizes), sequences will be cropped. + + Returns: + Tensor [batch, size_bound, *rest], zero-padded. + + Example: + >>> data = torch.tensor([1, 2, 3, 4, 5]) # two sequences: [1,2] and [3,4,5] + >>> sizes = torch.tensor([2, 3]) + >>> pad_ragged(data, sizes, size_bound=4) + tensor([[1, 2, 0, 0], + [3, 4, 5, 0]]) + """ + if (sizes < 0).any(): + raise ValueError("sizes must contain only non-negative values") + if data.shape[0] != sizes.sum(): + raise ValueError( + f"data length {data.shape[0]} must equal sum of sizes {sizes.sum().item():.0f}" + ) + + batch_size = sizes.shape[0] + rest_shape = data.shape[1:] + + # Fast path: single sequence - just pad directly + if batch_size == 1: + seq_len = data.shape[0] + pad_len = size_bound - seq_len + if pad_len > 0: + padding = torch.zeros( + pad_len, *rest_shape, dtype=data.dtype, device=data.device + ) + return torch.cat([data, padding], dim=0).unsqueeze(0) + return data[:size_bound].unsqueeze(0) + + col_indices = torch.broadcast_to( + torch.arange(size_bound, device=data.device), + (batch_size, size_bound), + ) # [batch_size, size_bound] + + # Compute source indices + ends = sizes.cumsum(0) + starts = ends - sizes + source_indices = starts.unsqueeze(1) + col_indices # [batch_size, size_bound] + clamped_indices = source_indices.clamp( + 0, data.shape[0] - 1 + ) # Don't exceed data size. + + # Gather values from data + gathered = data[clamped_indices.view(-1)].view(batch_size, size_bound, *rest_shape) + + # Zero out invalid positions - expand mask for broadcasting over rest dimensions + mask = col_indices < sizes.unsqueeze(1) # [batch_size, size_bound] + mask_expanded = mask.view(batch_size, size_bound, *([1] * len(rest_shape))) + out = gathered * mask_expanded + + return out + + +def unpad_ragged( + padded: torch.Tensor, sizes: torch.Tensor, total_size: int +) -> torch.Tensor: + """Inverse of pad_ragged: extract variable-length sequences from a padded batch tensor. + + You can think of the static parameter ``total_size`` as ``sizes.sum()``, but it is passed + separately as an integer to avoid data-dependent output shapes, keeping the function compatible + with ``torch.compile(fullgraph=True)`` and ``torch.export``. + + Args: + padded: Tensor [batch, size_bound, *rest] with zero-padding. + sizes: 1D tensor [batch] of true sequence lengths. + total_size: Total number of valid elements, i.e. sum(sizes). Passed as an integer to + keep the output shape static. + + Returns: + Tensor [total_size, *rest], with padding removed. + + Example: + >>> padded = torch.tensor([[1, 2, 0, 0], + ... [3, 4, 5, 0]]) + >>> sizes = torch.tensor([2, 3]) + >>> unpad_ragged(padded, sizes, total_size=5) + tensor([1, 2, 3, 4, 5]) + """ + if (sizes < 0).any(): + raise ValueError("sizes must contain only non-negative values") + + batch_size = padded.shape[0] + size_bound = padded.shape[1] + rest_shape = padded.shape[2:] + + if total_size == 0: + return torch.zeros(0, *rest_shape, dtype=padded.dtype, device=padded.device) + + # Fast path: single sequence - just slice + if batch_size == 1: + return padded[0, :total_size] + + # For each output position, find which batch it belongs to via binary search, + # then gather the corresponding element from the padded tensor. + ends = sizes.cumsum(0) + starts = ends - sizes + positions = torch.arange(total_size, device=padded.device) + batch_id = torch.searchsorted(ends, positions, right=True) + local_col = positions - starts[batch_id] + src_idx = batch_id * size_bound + local_col + + flat_padded = padded.reshape(-1, *rest_shape) + return flat_padded[src_idx] diff --git a/src/skala/functional/utils/symmetric_contraction.py b/src/skala/functional/utils/symmetric_contraction.py new file mode 100644 index 0000000..7dbf2cf --- /dev/null +++ b/src/skala/functional/utils/symmetric_contraction.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: MIT + +# Based on original MACE code: https://github.com/ACEsuit/mace +# See algorithm 1 in the appendix of https://arxiv.org/pdf/2206.07697 + +import opt_einsum_fx +import torch +import torch.fx +from e3nn import o3 + +from skala.functional.utils.cg import u_matrix_real + +ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] + + +def get_alphabet_string(i: int) -> str: + if i == -1: + return "" + return "".join(ALPHABET[:i]) + + +class SymmetricContraction(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: int, + sketch: bool = False, + ) -> None: + super().__init__() + + self.irreps_in = irreps_in + self.irreps_out = irreps_out + self.correlation = correlation + self.sketch = sketch + + hidden_nfs = [mul for mul, _ in irreps_in] + assert len(set(hidden_nfs)) == 1, ( + "All irreps need to have the same number of channels" + ) + hidden_nf = hidden_nfs[0] + + self.contractions = torch.nn.ModuleList( + [ + Contraction( + irreps_in=irreps_in, + irrep_out=o3.Irreps(str(irrep_out.ir)), + correlation=correlation, + hidden_nf=hidden_nf, + ) + for irrep_out in irreps_out + ] + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + x_packed = pack(x, self.irreps_in) + if self.sketch: + xs = torch.chunk(x_packed, self.correlation, dim=1) + else: + xs = (x_packed,) * self.correlation + outs = [contraction(*xs) for contraction in self.contractions] + return torch.cat(outs, dim=-1) + + +class Contraction(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps, + correlation: int, + hidden_nf: int, + ) -> None: + super().__init__() + + self.correlation = correlation + + for nu in range(correlation, 0, -1): + u_matrix = u_matrix_real( + irreps_in=o3.Irreps("+".join(str(irrep.ir) for irrep in irreps_in)), + irreps_out=irrep_out, + correlation=nu, + dtype=torch.get_default_dtype(), + )[-1] + self.register_buffer(f"u_matrix_{nu}", u_matrix) + + self.contractions_weighting = torch.nn.ModuleList() + self.contractions_features = torch.nn.ModuleList() + + self.weights = torch.nn.ParameterList() + + for i in range(correlation, 0, -1): + u_tensor = self.get_u_tensor(i) + dim_nu = u_tensor.shape[-1] + dim_lm = u_tensor.shape[-2] + dim_M = 2 * irrep_out.lmax + 1 + dim_x = 11 + dim_c = hidden_nf + + if i == correlation: + indices = get_alphabet_string(i + min(irrep_out.lmax, 1) - 1) + + def _graph_module_main( + C: torch.Tensor, W: torch.Tensor, x: torch.Tensor + ) -> torch.Tensor: + return torch.einsum( + indices + "ik,kc,bci -> bc" + indices, # noqa: B023 + C, + W, + x, + ) + + graph_module_main = torch.fx.symbolic_trace(_graph_module_main) + + graph_opt_main = opt_einsum_fx.optimize_einsums_full( + model=graph_module_main, + example_inputs=( + torch.randn([dim_M] + [dim_lm] * i + [dim_nu]).squeeze(0), + torch.randn((dim_nu, dim_c)), + torch.randn((dim_x, dim_c, dim_lm)), + ), + ) + assert isinstance(graph_opt_main, torch.fx.GraphModule) + self.graph_opt_main = graph_opt_main + + self.weights_max = torch.nn.Parameter( + torch.randn((dim_nu, dim_c)) / dim_nu + ) + + else: + indices = get_alphabet_string(i + min(irrep_out.lmax, 1)) + + def _graph_module_weighting( + C: torch.Tensor, W: torch.Tensor + ) -> torch.Tensor: + return torch.einsum( + indices + "k,kc->c" + indices, # noqa: B023 + C, + W, + ) + + graph_module_weighting = torch.fx.symbolic_trace( + _graph_module_weighting + ) + graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( + model=graph_module_weighting, + example_inputs=( + torch.randn([dim_M] + [dim_lm] * i + [dim_nu]).squeeze(0), + torch.randn((dim_nu, dim_c)), + ), + ) + assert isinstance(graph_opt_weighting, torch.fx.GraphModule) + self.contractions_weighting.append(graph_opt_weighting) + + indices = get_alphabet_string(i - 1 + min(irrep_out.lmax, 1)) + + def _graph_module_features( + c: torch.Tensor, x: torch.Tensor + ) -> torch.Tensor: + return torch.einsum( + "bc" + indices + "i,bci->bc" + indices, # noqa: B023 + c, + x, + ) + + graph_module_features = torch.fx.symbolic_trace(_graph_module_features) + graph_opt_features = opt_einsum_fx.optimize_einsums_full( + model=graph_module_features, + example_inputs=( + torch.randn([dim_x, dim_c, dim_M] + [dim_lm] * i).squeeze(2), + torch.randn((dim_x, dim_c, dim_lm)), + ), + ) + assert isinstance(graph_opt_features, torch.fx.GraphModule) + self.contractions_features.append(graph_opt_features) + + self.weights.append( + torch.nn.Parameter(torch.randn((dim_nu, dim_c)) / dim_nu) + ) + + def get_u_tensor(self, nu: int) -> torch.Tensor: + return dict(self.named_buffers())[f"u_matrix_{nu}"] + + def forward( + self, + *xs: list[torch.Tensor], + ) -> torch.Tensor: + out = self.graph_opt_main( + self.get_u_tensor(self.correlation), self.weights_max, xs[0] + ) + for i, (x, weights, contract_weights, contract_features) in enumerate( + zip( + xs[1:], + self.weights, + self.contractions_weighting, + self.contractions_features, + strict=False, + ) + ): + c_tensor = contract_weights( + self.get_u_tensor(self.correlation - i - 1), weights + ) + c_tensor = c_tensor + out + out = contract_features(c_tensor, x) + + return out.view(out.shape[0], -1) + + +def pack(x: torch.Tensor, irreps: o3.Irreps) -> torch.Tensor: + return torch.cat( + [ + x[..., slice].view(*x.shape[:-1], mul, ir.dim) + for (mul, ir), slice in zip(irreps, irreps.slices(), strict=False) + ], + dim=-1, + ) diff --git a/src/skala/gauxc/export.py b/src/skala/gauxc/export.py index 8c27dcd..cce923a 100644 --- a/src/skala/gauxc/export.py +++ b/src/skala/gauxc/export.py @@ -80,7 +80,7 @@ def pyscf_to_gauxc_h5( mol.atom_charges(), mol.atom_coords(unit="Bohr"), strict=True ) ], - dtype=MOLECULE_DTYPE, + dtype=MOLECULE_DTYPE, # type: ignore[call-overload] # numpy structured dtype ) basis = np.array( [ @@ -95,7 +95,7 @@ def pyscf_to_gauxc_h5( for func in mol._basis[atom] for prim in range(1, len(func[1])) ], - dtype=BASIS_DTYPE, + dtype=BASIS_DTYPE, # type: ignore[call-overload] # numpy structured dtype ) dm_scalar = dm if dm.ndim == 2 else dm[0] + dm[1] dm_z = np.zeros_like(dm) if dm.ndim == 2 else dm[0] - dm[1] @@ -135,7 +135,7 @@ def norm( aa = K_MINUS_1[2 * l] * SQRT_PI_CUBED / (2**l * gamma ** (l + 1) * np.sqrt(gamma)) coeff = np.asarray(coeff) * normalization_factor normalization_factor = 1.0 / np.sqrt(np.einsum("i,j,ij->", coeff, coeff, aa)) - return (coeff * normalization_factor).tolist() # type: ignore + return (coeff * normalization_factor).tolist() def format_basis( diff --git a/src/skala/gpu4pyscf/__init__.py b/src/skala/gpu4pyscf/__init__.py index ca251f2..eb69ca9 100644 --- a/src/skala/gpu4pyscf/__init__.py +++ b/src/skala/gpu4pyscf/__init__.py @@ -89,7 +89,7 @@ def SkalaKS( >>> from skala.gpu4pyscf import SkalaKS >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp") - >>> ks = SkalaKS(mol, xc="skala") + >>> ks = SkalaKS(mol, xc="skala-1.1") >>> ks = ks.density_fit(auxbasis="def2-svp-jkfit") # Optional: use density fitting >>> ks = ks.set(verbose=0) >>> energy = ks.kernel() @@ -102,6 +102,7 @@ def SkalaKS( """ if isinstance(xc, str): xc = load_functional(xc, device=torch.device("cuda:0")) + assert isinstance(xc, ExcFunctionalBase) if mol.spin == 0: return SkalaRKS( mol, @@ -128,7 +129,7 @@ def SkalaKS( def SkalaRKS( mol: gto.Mole, - xc: ExcFunctionalBase, + xc: ExcFunctionalBase | str, *, with_density_fit: bool = False, with_newton: bool = False, @@ -172,7 +173,7 @@ def SkalaRKS( >>> import torch >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp") - >>> ks = SkalaRKS(mol, xc=load_functional("skala", device=torch.device("cuda:0")), with_density_fit=True)(verbose=0) + >>> ks = SkalaRKS(mol, xc=load_functional("skala-1.1", device=torch.device("cuda:0")), with_density_fit=True)(verbose=0) >>> ks # DOCTEST: Ellipsis >>> energy = ks.kernel() @@ -181,14 +182,12 @@ def SkalaRKS( """ if isinstance(xc, str): xc = load_functional(xc, device=torch.device("cuda:0")) - ks = dft.SkalaRKS(mol, xc) + assert isinstance(xc, ExcFunctionalBase) + ks = dft.SkalaRKS(mol, xc, with_dftd3=with_dftd3) if ks_config is not None: ks = ks(**ks_config) - if not with_dftd3: - ks.with_dftd3 = None - if with_density_fit: ks = ks.density_fit(auxbasis=auxbasis) else: @@ -207,7 +206,7 @@ def SkalaRKS( def SkalaUKS( mol: gto.Mole, - xc: ExcFunctionalBase, + xc: ExcFunctionalBase | str, *, with_density_fit: bool = False, with_newton: bool = False, @@ -251,7 +250,7 @@ def SkalaUKS( >>> import torch >>> >>> mol = gto.M(atom="H", basis="def2-svp", spin=1) - >>> ks = SkalaUKS(mol, xc=load_functional("skala", device=torch.device("cuda:0")), with_density_fit=True, auxbasis="def2-svp-jkfit")(verbose=0) + >>> ks = SkalaUKS(mol, xc=load_functional("skala-1.1", device=torch.device("cuda:0")), with_density_fit=True, auxbasis="def2-svp-jkfit")(verbose=0) >>> ks # DOCTEST: Ellipsis >>> energy = ks.kernel() @@ -260,14 +259,12 @@ def SkalaUKS( """ if isinstance(xc, str): xc = load_functional(xc, device=torch.device("cuda:0")) - ks = dft.SkalaUKS(mol, xc) + assert isinstance(xc, ExcFunctionalBase) + ks = dft.SkalaUKS(mol, xc, with_dftd3=with_dftd3) if ks_config is not None: ks = ks(**ks_config) - if not with_dftd3: - ks.with_dftd3 = None - if with_density_fit: ks = ks.density_fit(auxbasis=auxbasis) else: diff --git a/src/skala/gpu4pyscf/dft.py b/src/skala/gpu4pyscf/dft.py index eca032e..eece484 100644 --- a/src/skala/gpu4pyscf/dft.py +++ b/src/skala/gpu4pyscf/dft.py @@ -15,15 +15,15 @@ >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp", verbose=0) >>> # Create restricted KS calculator ->>> rks = dft.SkalaRKS(mol, xc=load_functional("skala", device=torch.device("cuda:0"))) +>>> rks = dft.SkalaRKS(mol, xc=load_functional("skala-1.1", device=torch.device("cuda:0"))) >>> energy = rks.kernel() >>> print(energy) # DOCTEST: Ellipsis --1.142654... +-1.142903... >>> # Create unrestricted KS calculator ->>> uks = dft.SkalaUKS(mol, xc=load_functional("skala", device=torch.device("cuda:0"))) +>>> uks = dft.SkalaUKS(mol, xc=load_functional("skala-1.1", device=torch.device("cuda:0"))) >>> energy = uks.kernel() >>> print(energy) # DOCTEST: Ellipsis --1.142654... +-1.142903... The `SkalaRKS` and `SkalaUKS` classes can be used in the same way as (GPU4)PySCF's `dft.rks.RKS `__ and @@ -36,7 +36,7 @@ >>> import torch >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp") ->>> ks = dft.SkalaRKS(mol, xc=load_functional("skala", device=torch.device("cuda:0"))) +>>> ks = dft.SkalaRKS(mol, xc=load_functional("skala-1.1", device=torch.device("cuda:0"))) >>> # Apply density fitting >>> ks = ks.density_fit(auxbasis="def2-svp-jkfit") >>> ks # DOCTEST: Ellipsis @@ -60,7 +60,6 @@ from dftd3.pyscf import DFTD3Dispersion from gpu4pyscf import dft from gpu4pyscf.df import df_jk -from pyscf import __version__ as pyscf_version from pyscf import gto # Set the default CuPy memory allocator to avoid memory leak issues @@ -68,6 +67,7 @@ from skala.functional.base import ExcFunctionalBase from skala.gpu4pyscf.gradients import SkalaRKSGradient, SkalaUKSGradient +from skala.pyscf.dft import _build_grids_unsorted, _needs_unsorted_grids from skala.pyscf.numint import SkalaNumInt from skala.pyscf.utils import pyscf_version_newer_than_2_10 @@ -78,13 +78,27 @@ class SkalaRKS(dft.rks.RKS): # type: ignore[misc] with_dftd3: DFTD3Dispersion | None = None """DFT-D3 dispersion correction.""" - def __init__(self, mol: gto.Mole, xc: ExcFunctionalBase): + def __init__( + self, mol: gto.Mole, xc: ExcFunctionalBase, *, with_dftd3: bool = True + ): super().__init__(mol, xc="custom") self._keys.add("with_dftd3") self._numint = SkalaNumInt(xc, device=torch.device("cuda:0")) d3 = xc.get_d3_settings() - self.with_dftd3 = DFTD3Dispersion(mol, d3) if d3 is not None else None + self.with_dftd3 = ( + DFTD3Dispersion(mol, d3) if with_dftd3 and d3 is not None else None + ) + + self._needs_unsorted = _needs_unsorted_grids(xc) + if self._needs_unsorted: + _build_grids_unsorted(self.grids, mol) + + def kernel(self, dm0: Any = None, **kwargs: Any) -> float: + # Ensure grids stay unsorted even if user changed grid settings after __init__ + if self._needs_unsorted and self.grids.coords is None: + _build_grids_unsorted(self.grids, self.mol) + return super().kernel(dm0, **kwargs) def energy_nuc(self) -> float: enuc = float(super().energy_nuc()) @@ -119,7 +133,19 @@ def density_fit( with_df: bool | None = None, only_dfj: bool | None = True, ) -> "SkalaRKS": - ks = df_jk.density_fit(self, auxbasis, with_df, only_dfj) + if pyscf_version_newer_than_2_10() and auxbasis is None: + warnings.warn( + "Using density_fit without specifying auxbasis will lead to different behavior in PySCF >= 2.10.0 compared to PySCF 2.9.0, which was used for benchmarking skala. To reproduce benchmarks, please specify an auxbasis (def2-universal-jkfit for (ma-)def2 basis sets).", + stacklevel=2, + ) + + # We temporarily need to swap out xc for a known functional to satisfy df_jk.density_fit's checks, but we'll swap it back before returning. + try: + real_xc: ExcFunctionalBase | str = self.xc # type: ignore[has-type] + self.xc = "tpss" + ks = df_jk.density_fit(self, auxbasis, with_df, only_dfj) + finally: + ks.xc = real_xc ks.Gradients = lambda: SkalaRKSGradient(ks) ks.nuc_grad_method = ks.Gradients return cast(SkalaRKS, ks) @@ -131,13 +157,27 @@ class SkalaUKS(dft.uks.UKS): # type: ignore[misc] with_dftd3: DFTD3Dispersion | None = None """DFT-D3 dispersion correction.""" - def __init__(self, mol: gto.Mole, xc: ExcFunctionalBase): + def __init__( + self, mol: gto.Mole, xc: ExcFunctionalBase, *, with_dftd3: bool = True + ): super().__init__(mol, xc="custom") self._keys.add("with_dftd3") self._numint = SkalaNumInt(xc, device=torch.device("cuda:0")) d3 = xc.get_d3_settings() - self.with_dftd3 = DFTD3Dispersion(mol, d3) if d3 is not None else None + self.with_dftd3 = ( + DFTD3Dispersion(mol, d3) if with_dftd3 and d3 is not None else None + ) + + self._needs_unsorted = _needs_unsorted_grids(xc) + if self._needs_unsorted: + _build_grids_unsorted(self.grids, mol) + + def kernel(self, dm0: Any = None, **kwargs: Any) -> float: + # Ensure grids stay unsorted even if user changed grid settings after __init__ + if self._needs_unsorted and self.grids.coords is None: + _build_grids_unsorted(self.grids, self.mol) + return super().kernel(dm0, **kwargs) def energy_nuc(self) -> float: enuc = float(super().energy_nuc()) @@ -175,8 +215,16 @@ def density_fit( if pyscf_version_newer_than_2_10() and auxbasis is None: warnings.warn( "Using density_fit without specifying auxbasis will lead to different behavior in PySCF >= 2.10.0 compared to PySCF 2.9.0, which was used for benchmarking skala. To reproduce benchmarks, please specify an auxbasis (def2-universal-jkfit for (ma-)def2 basis sets).", + stacklevel=2, ) - ks = df_jk.density_fit(self, auxbasis, with_df, only_dfj) + + # We temporarily need to swap out xc for a known functional to satisfy df_jk.density_fit's checks, but we'll swap it back before returning. + try: + real_xc: ExcFunctionalBase | str = self.xc # type: ignore[has-type] + self.xc = "tpss" + ks = df_jk.density_fit(self, auxbasis, with_df, only_dfj) + finally: + ks.xc = real_xc ks.Gradients = lambda: SkalaUKSGradient(ks) ks.nuc_grad_method = ks.Gradients return cast(SkalaUKS, ks) diff --git a/src/skala/gpu4pyscf/gradients.py b/src/skala/gpu4pyscf/gradients.py index 83ba10a..3dc1728 100644 --- a/src/skala/gpu4pyscf/gradients.py +++ b/src/skala/gpu4pyscf/gradients.py @@ -42,12 +42,17 @@ def veff_and_expl_nuc_grad( "kin", "grid_coords", "grid_weights", + "atomic_grid_weights", "coarse_0_atomic_coords", } if nuc_grad_feats is None: # generate feature list from functional features nuc_grad_feats = set(functional.features) + # Integer-valued features have no nuclear gradient — always discard them + nuc_grad_feats.discard("atomic_grid_sizes") + nuc_grad_feats.discard("atomic_grid_size_bound_shape") + # check for unsupported features unsupported_feats = {feat for feat in nuc_grad_feats if feat not in SUPPORTED_FEATS} if unsupported_feats != set(): @@ -79,6 +84,11 @@ def veff_and_expl_nuc_grad( mol, rdm1, grid_, set(functional.features), gpu=True ) + # Discard atomic_grid_weights from VJP features: d(atomic_grid_weights)/dR = 0 + # because they are raw quadrature weights that depend only on the radial/angular + # grid rule, not on nuclear positions. They still pass through as other_feats. + nuc_grad_feats.discard("atomic_grid_weights") + # Get required derivatives nuc_feat_names = list(nuc_grad_feats) # ensure specific order nuc_feat_tensors = [mol_feats[feat] for feat in nuc_feat_names] @@ -92,13 +102,27 @@ def exc_feat_func(*nuc_feat_tensors: torch.Tensor) -> torch.Tensor: ) return functional.get_exc(exc_mol_feats) - _, dExc_func = torch.func.vjp(exc_feat_func, *nuc_feat_tensors) - dExc_tuple = dExc_func(torch.tensor(1.0, dtype=rdm1.dtype, device=rdm1.device)) + # torch.func.vjp wraps primals in functional tensors that may not expose + # backing storage, which TorchScript traced models can reject. + diff_nuc_feat_tensors = [ + feat.detach().requires_grad_(True) for feat in nuc_feat_tensors + ] + if len(diff_nuc_feat_tensors) > 0: + exc = exc_feat_func(*diff_nuc_feat_tensors) + dExc_tuple = torch.autograd.grad( + exc, + tuple(diff_nuc_feat_tensors), + create_graph=False, + retain_graph=False, + allow_unused=False, + ) + else: + dExc_tuple = () dExc: dict[str, torch.Tensor] = {} for i in range(len(dExc_tuple)): dExc[nuc_feat_names[i]] = dExc_tuple[i].detach() - LOG.debug("torch.func.vjp done") + LOG.debug("autograd gradients for nuclear features done") nao = rdm1.shape[-1] veff = torch.zeros((2, 3, nao, nao), dtype=rdm1.dtype, device=rdm1.device) @@ -107,7 +131,7 @@ def exc_feat_func(*nuc_feat_tensors: torch.Tensor) -> torch.Tensor: atm_start = 0 for atm_id, (coords, weight, weight1) in enumerate(grids_response_cc(grid)): mask = dft.gen_grid.make_mask(mol, coords) - ao = torch.from_dlpack( + ao = torch.from_dlpack( # type: ignore[attr-defined] dft.numint.eval_ao( mol, coords, @@ -202,11 +226,11 @@ def exc_feat_func(*nuc_feat_tensors: torch.Tensor) -> torch.Tensor: if "grid_coords" in nuc_grad_feats: # also add the explicit grid coordinate dependence - nuc_grad[atm_id] += dExc["grid_coords"][atm_start:atm_end].sum(axis=0) + nuc_grad[atm_id] += dExc["grid_coords"][atm_start:atm_end].sum(dim=0) if "grid_weights" in nuc_grad_feats: Exc_dgw = dExc["grid_weights"][atm_start:atm_end] - nuc_grad += torch.from_dlpack(weight1) @ Exc_dgw + nuc_grad += torch.from_dlpack(weight1) @ Exc_dgw # type: ignore[attr-defined] # add the grid coordinate dependence via the density-like quantities to the nuclear gradient # we get those from the veff block. This tends to largely cancel with the grid_weights derivative, # so that's why we include it here. @@ -280,12 +304,12 @@ def get_veff( self.functional, mol=mol, grid=self.grids, - rdm1=torch.from_dlpack(dm), + rdm1=torch.from_dlpack(dm), # type: ignore[attr-defined] nuc_grad_feats=self.nuc_grad_feats, ) vhfopt = self.base._opt_gpu.get(mol.omega) return cp.from_dlpack( - nuc_grad_from_veff(mol, veff, torch.from_dlpack(dm)) + nuc_grad_from_veff(mol, veff, torch.from_dlpack(dm)) # type: ignore[attr-defined] ) + _jk_energy_per_atom(mol, dm, vhfopt, k_factor=0.0, verbose=self.verbose) def grad_elec( @@ -358,12 +382,12 @@ def get_veff( self.functional, mol=mol, grid=self.grids, - rdm1=torch.from_dlpack(dm), + rdm1=torch.from_dlpack(dm), # type: ignore[attr-defined] nuc_grad_feats=self.nuc_grad_feats, ) vhfopt = self.base._opt_gpu.get(mol.omega) return cp.from_dlpack( - nuc_grad_from_veff(mol, veff, torch.from_dlpack(dm)) + nuc_grad_from_veff(mol, veff, torch.from_dlpack(dm)) # type: ignore[attr-defined] ) + _jk_energy_per_atom(mol, dm, vhfopt, k_factor=0.0, verbose=self.verbose) def grad_elec( diff --git a/src/skala/pyscf/__init__.py b/src/skala/pyscf/__init__.py index 9bfbbc1..f6b7aec 100644 --- a/src/skala/pyscf/__init__.py +++ b/src/skala/pyscf/__init__.py @@ -8,15 +8,10 @@ with neural network-based functionals. """ -try: - import pyscf # noqa: F401 -except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "PySCF is not installed. Please install it with `pip install pyscf` or `conda install pyscf`." - ) from e - from typing import Any +import torch +from pyscf import dft as pyscf_dft from pyscf import gto from skala.functional import ExcFunctionalBase, load_functional @@ -33,6 +28,7 @@ def SkalaKS( auxbasis: str | None = None, ks_config: dict[str, Any] | None = None, soscf_config: dict[str, Any] | None = None, + device: torch.device | None = None, ) -> dft.SkalaRKS | dft.SkalaUKS: """ Create a Kohn-Sham calculator for the Skala functional. @@ -55,6 +51,8 @@ def SkalaKS( Additional configuration options for the Kohn-Sham calculator. Default is None. soscf_config : dict, optional Additional configuration options for the second-order SCF (SOSCF) method. Default is None. + device : torch.device, optional + The device to run the calculations on. Default is None. Returns ------- @@ -68,19 +66,29 @@ def SkalaKS( >>> from skala.pyscf import SkalaKS >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp") - >>> ks = SkalaKS(mol, xc=load_functional("skala")) + >>> ks = SkalaKS(mol, xc=load_functional("skala-1.1")) >>> ks = ks.density_fit(auxbasis="def2-svp-jkfit") # Optional: use density fitting >>> ks = ks.set(verbose=0) >>> energy = ks.kernel() >>> print(energy) # DOCTEST: Ellipsis - -1.142773... + -1.143024... >>> ks = ks.nuc_grad_method() >>> gradient = ks.kernel() >>> print(abs(gradient).mean()) # DOCTEST: Ellipsis - 0.029477... + 0.029415... """ if isinstance(xc, str): xc = load_functional(xc) + if isinstance(xc, str): + return _create_native_pyscf_ks( + mol, + xc, + with_density_fit=with_density_fit, + with_newton=with_newton, + auxbasis=auxbasis, + ks_config=ks_config, + soscf_config=soscf_config, + ) if mol.spin == 0: return SkalaRKS( mol, @@ -91,6 +99,7 @@ def SkalaKS( auxbasis=auxbasis, ks_config=ks_config, soscf_config=soscf_config, + device=device, ) else: return SkalaUKS( @@ -102,6 +111,7 @@ def SkalaKS( auxbasis=auxbasis, ks_config=ks_config, soscf_config=soscf_config, + device=device, ) @@ -115,6 +125,7 @@ def SkalaRKS( auxbasis: str | None = None, ks_config: dict[str, Any] | None = None, soscf_config: dict[str, Any] | None = None, + device: torch.device | None = None, ) -> dft.SkalaRKS: """ Create a restricted Kohn-Sham calculator for the Skala functional. @@ -137,6 +148,8 @@ def SkalaRKS( Additional configuration options for the Kohn-Sham calculator. Default is None. soscf_config : dict, optional Additional configuration options for the second-order SCF (SOSCF) method. Default is None. + device : torch.device, optional + The device to run the calculations on. Default is None. Returns ------- @@ -149,37 +162,36 @@ def SkalaRKS( >>> from skala.pyscf import SkalaRKS >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp") - >>> ks = SkalaRKS(mol, xc="skala", with_density_fit=True, auxbasis="def2-svp-jkfit")(verbose=0) + >>> ks = SkalaRKS(mol, xc="skala-1.1", with_density_fit=True, auxbasis="def2-svp-jkfit")(verbose=0) >>> ks # DOCTEST: Ellipsis >>> energy = ks.kernel() >>> print(energy) # DOCTEST: Ellipsis - -1.142773... + -1.143024... """ if isinstance(xc, str): xc = load_functional(xc) - ks = dft.SkalaRKS(mol, xc) - - if ks_config is not None: - ks = ks(**ks_config) - - if not with_dftd3: - ks.with_dftd3 = None - - if with_density_fit: - ks = ks.density_fit(auxbasis=auxbasis) - else: - if auxbasis is not None: - raise ValueError( - "Auxiliary basis can only be set when density fitting is enabled." - ) - - if with_newton: - ks = ks.newton() - if soscf_config is not None: - ks.__dict__.update(soscf_config) + if isinstance(xc, str): + ks = pyscf_dft.RKS(mol) + ks.xc = xc + return _apply_ks_config( + ks, + with_density_fit=with_density_fit, + with_newton=with_newton, + auxbasis=auxbasis, + ks_config=ks_config, + soscf_config=soscf_config, + ) + ks = dft.SkalaRKS(mol, xc, device=device, with_dftd3=with_dftd3) - return ks + return _apply_ks_config( + ks, + with_density_fit=with_density_fit, + with_newton=with_newton, + auxbasis=auxbasis, + ks_config=ks_config, + soscf_config=soscf_config, + ) def SkalaUKS( @@ -192,6 +204,7 @@ def SkalaUKS( auxbasis: str | None = None, ks_config: dict[str, Any] | None = None, soscf_config: dict[str, Any] | None = None, + device: torch.device | None = None, ) -> dft.SkalaUKS: """ Create an unrestricted Kohn-Sham calculator for the Skala functional. @@ -214,6 +227,8 @@ def SkalaUKS( Additional configuration options for the Kohn-Sham calculator. Default is None. soscf_config : dict, optional Additional configuration options for the second-order SCF (SOSCF) method. Default is None. + device : torch.device, optional + The device to run the calculations on. Default is None. Returns ------- @@ -226,34 +241,82 @@ def SkalaUKS( >>> from skala.pyscf import SkalaUKS >>> >>> mol = gto.M(atom="H", basis="def2-svp", spin=1) - >>> ks = SkalaUKS(mol, xc="skala", with_density_fit=True, auxbasis="def2-svp-jkfit")(verbose=0) + >>> ks = SkalaUKS(mol, xc="skala-1.1", with_density_fit=True, auxbasis="def2-svp-jkfit")(verbose=0) >>> ks # DOCTEST: Ellipsis >>> energy = ks.kernel() >>> print(energy) # DOCTEST: Ellipsis - -0.499031... + -0.499123... """ if isinstance(xc, str): xc = load_functional(xc) - ks = dft.SkalaUKS(mol, xc) + if isinstance(xc, str): + ks = pyscf_dft.UKS(mol) + ks.xc = xc + return _apply_ks_config( + ks, + with_density_fit=with_density_fit, + with_newton=with_newton, + auxbasis=auxbasis, + ks_config=ks_config, + soscf_config=soscf_config, + ) + ks = dft.SkalaUKS(mol, xc, device=device, with_dftd3=with_dftd3) - if ks_config is not None: - ks = ks(**ks_config) + return _apply_ks_config( + ks, + with_density_fit=with_density_fit, + with_newton=with_newton, + auxbasis=auxbasis, + ks_config=ks_config, + soscf_config=soscf_config, + ) - if not with_dftd3: - ks.with_dftd3 = None +def _apply_ks_config( + ks: "dft.SkalaRKS | dft.SkalaUKS", + *, + with_density_fit: bool, + with_newton: bool, + auxbasis: str | None, + ks_config: dict[str, Any] | None, + soscf_config: dict[str, Any] | None, +) -> "dft.SkalaRKS | dft.SkalaUKS": + """Apply common KS configuration (grids, density fitting, Newton, SOSCF).""" + if ks_config is not None: + ks = ks(**ks_config) if with_density_fit: ks = ks.density_fit(auxbasis=auxbasis) - else: - if auxbasis is not None: - raise ValueError( - "Auxiliary basis can only be set when density fitting is enabled." - ) - + elif auxbasis is not None: + raise ValueError( + "Auxiliary basis can only be set when density fitting is enabled." + ) if with_newton: ks = ks.newton() if soscf_config is not None: ks.__dict__.update(soscf_config) - return ks + + +def _create_native_pyscf_ks( + mol: gto.Mole, + xc_name: str, + *, + with_density_fit: bool, + with_newton: bool, + auxbasis: str | None, + ks_config: dict[str, Any] | None, + soscf_config: dict[str, Any] | None, +) -> "pyscf_dft.rks.RKS | pyscf_dft.uks.UKS": + """Create a native PySCF KS calculator for standard functionals.""" + cls = pyscf_dft.RKS if mol.spin == 0 else pyscf_dft.UKS + ks = cls(mol) + ks.xc = xc_name + return _apply_ks_config( + ks, + with_density_fit=with_density_fit, + with_newton=with_newton, + auxbasis=auxbasis, + ks_config=ks_config, + soscf_config=soscf_config, + ) diff --git a/src/skala/pyscf/backend.py b/src/skala/pyscf/backend.py index 23987ef..4ba38e6 100644 --- a/src/skala/pyscf/backend.py +++ b/src/skala/pyscf/backend.py @@ -76,7 +76,7 @@ def from_numpy_or_cupy( if isinstance(x, np.ndarray): x_torch = torch.from_numpy(x) else: - x_torch = torch.from_dlpack(x) + x_torch = torch.from_dlpack(x) # type: ignore[attr-defined] x_torch = x_torch.to(device=device, dtype=dtype) if transpose: return x_torch.transpose(-1, -2) diff --git a/src/skala/pyscf/dft.py b/src/skala/pyscf/dft.py index 2fe2761..071b47f 100644 --- a/src/skala/pyscf/dft.py +++ b/src/skala/pyscf/dft.py @@ -13,17 +13,17 @@ >>> from skala.pyscf import dft >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp", verbose=0) ->>> func = load_functional("skala") +>>> func = load_functional("skala-1.1") >>> # Create restricted KS calculator >>> rks = dft.SkalaRKS(mol, xc=func) >>> energy = rks.kernel() >>> print(energy) # DOCTEST: Ellipsis --1.142654... +-1.142903... >>> # Create unrestricted KS calculator >>> uks = dft.SkalaUKS(mol, xc=func) >>> energy = uks.kernel() >>> print(energy) # DOCTEST: Ellipsis --1.142654... +-1.142903... The `SkalaRKS` and `SkalaUKS` classes can be used in the same way as PySCF's `dft.rks.RKS `__ and @@ -35,7 +35,7 @@ >>> from skala.pyscf import dft >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp") ->>> ks = dft.SkalaRKS(mol, xc=load_functional("skala")) +>>> ks = dft.SkalaRKS(mol, xc=load_functional("skala-1.1")) >>> # Apply density fitting >>> ks = ks.density_fit(auxbasis="def2-svp-jkfit") >>> ks # DOCTEST: Ellipsis @@ -61,11 +61,25 @@ from pyscf.df import df_jk from skala.functional.base import ExcFunctionalBase +from skala.pyscf.features import _ATOMIC_GRID_FEATURES from skala.pyscf.gradients import SkalaRKSGradient, SkalaUKSGradient from skala.pyscf.numint import SkalaNumInt from skala.pyscf.utils import pyscf_version_newer_than_2_10 +def _needs_unsorted_grids(func: ExcFunctionalBase) -> bool: + """Return True when the functional needs per-atom grid ordering.""" + return bool(set(func.features) & _ATOMIC_GRID_FEATURES) + + +def _build_grids_unsorted( + grids: dft.gen_grid.Grids, mol: gto.Mole +) -> dft.gen_grid.Grids: + """Build grids without sorting, preserving per-atom ordering.""" + grids.build(mol, sort_grids=False) + return grids + + class SkalaRKS(dft.rks.RKS): # type: ignore[misc] """Restricted Kohn-Sham method with support for Skala functional.""" @@ -74,13 +88,32 @@ class SkalaRKS(dft.rks.RKS): # type: ignore[misc] with_dftd3: DFTD3Dispersion | None = None """DFT-D3 dispersion correction.""" - def __init__(self, mol: gto.Mole, xc: ExcFunctionalBase): + def __init__( + self, + mol: gto.Mole, + xc: ExcFunctionalBase, + device: torch.device | None = None, + *, + with_dftd3: bool = True, + ): super().__init__(mol, xc="custom") self._keys.add("with_dftd3") - self._numint = SkalaNumInt(xc, device=torch.device("cpu")) + self._numint = SkalaNumInt(xc, device=device or torch.device("cpu")) + self._needs_unsorted = _needs_unsorted_grids(xc) d3 = xc.get_d3_settings() - self.with_dftd3 = DFTD3Dispersion(mol, d3) if d3 is not None else None + self.with_dftd3 = ( + DFTD3Dispersion(mol, d3) if with_dftd3 and d3 is not None else None + ) + + if self._needs_unsorted: + _build_grids_unsorted(self.grids, mol) + + def kernel(self, dm0: np.ndarray | None = None, **kwargs: Any) -> float: + # Ensure grids stay unsorted even if user changed grid settings after __init__ + if self._needs_unsorted and self.grids.coords is None: + _build_grids_unsorted(self.grids, self.mol) + return super().kernel(dm0, **kwargs) def energy_nuc(self) -> float: enuc = float(super().energy_nuc()) @@ -114,6 +147,11 @@ def density_fit( with_df: Any = None, only_dfj: bool = True, ) -> "SkalaRKS": + if pyscf_version_newer_than_2_10() and auxbasis is None: + warnings.warn( + "Using density_fit without specifying auxbasis will lead to different behavior in PySCF >= 2.10.0 compared to PySCF 2.9.0, which was used for benchmarking skala. To reproduce benchmarks, please specify an auxbasis (def2-universal-jkfit for (ma-)def2 basis sets).", + stacklevel=2, + ) xc, self.xc = ( self.xc, "tpss", @@ -133,13 +171,32 @@ class SkalaUKS(dft.uks.UKS): # type: ignore[misc] with_dftd3: DFTD3Dispersion | None = None """DFT-D3 dispersion correction.""" - def __init__(self, mol: gto.Mole, xc: ExcFunctionalBase): + def __init__( + self, + mol: gto.Mole, + xc: ExcFunctionalBase, + device: torch.device | None = None, + *, + with_dftd3: bool = True, + ): super().__init__(mol, xc="custom") self._keys.add("with_dftd3") - self._numint = SkalaNumInt(xc, device=torch.device("cpu")) + self._numint = SkalaNumInt(xc, device=device or torch.device("cpu")) + self._needs_unsorted = _needs_unsorted_grids(xc) d3 = xc.get_d3_settings() - self.with_dftd3 = DFTD3Dispersion(mol, d3) if d3 is not None else None + self.with_dftd3 = ( + DFTD3Dispersion(mol, d3) if with_dftd3 and d3 is not None else None + ) + + if self._needs_unsorted: + _build_grids_unsorted(self.grids, mol) + + def kernel(self, dm0: np.ndarray | None = None, **kwargs: Any) -> float: + # Ensure grids stay unsorted even if user changed grid settings after __init__ + if self._needs_unsorted and self.grids.coords is None: + _build_grids_unsorted(self.grids, self.mol) + return super().kernel(dm0, **kwargs) def energy_nuc(self) -> float: enuc = float(super().energy_nuc()) @@ -176,6 +233,7 @@ def density_fit( if pyscf_version_newer_than_2_10() and auxbasis is None: warnings.warn( "Using density_fit without specifying auxbasis will lead to different behavior in PySCF >= 2.10.0 compared to PySCF 2.9.0, which was used for benchmarking skala. To reproduce benchmarks, please specify an auxbasis (def2-universal-jkfit for (ma-)def2 basis sets).", + stacklevel=2, ) xc, self.xc = ( diff --git a/src/skala/pyscf/features.py b/src/skala/pyscf/features.py index 8aeb4f5..5fa39d8 100644 --- a/src/skala/pyscf/features.py +++ b/src/skala/pyscf/features.py @@ -25,6 +25,13 @@ DEFAULT_FEATURES = ["density", "kin", "grad", "grid_coords", "grid_weights"] DEFAULT_FEATURES_SET = set(DEFAULT_FEATURES) +# Features that require per-atom grid decomposition. +_ATOMIC_GRID_FEATURES = { + "atomic_grid_weights", + "atomic_grid_sizes", + "atomic_grid_size_bound_shape", +} + def maybe_expand_and_divide( feature: torch.Tensor, expand: bool, divisor: float @@ -98,6 +105,31 @@ def generate_features( mol.atom_coords(), device=dm.device, dtype=dm.dtype ) + if features & _ATOMIC_GRID_FEATURES: + atom_grids_tab = grids.gen_atomic_grids( + mol, grids.atom_grid, grids.radi_method, grids.level, grids.prune + ) + sizes = [len(atom_grids_tab[mol.atom_symbol(ia)][1]) for ia in range(mol.natm)] + + if "atomic_grid_sizes" in features: + mol_features["atomic_grid_sizes"] = torch.tensor( + sizes, dtype=torch.long, device=dm.device + ) + + if "atomic_grid_size_bound_shape" in features: + max_size = max(sizes) + mol_features["atomic_grid_size_bound_shape"] = torch.zeros( + max_size, 0, dtype=torch.long, device=dm.device + ) + + if "atomic_grid_weights" in features: + raw_weights = np.concatenate( + [atom_grids_tab[mol.atom_symbol(ia)][1] for ia in range(mol.natm)] + ) + mol_features["atomic_grid_weights"] = from_numpy_or_cupy( + raw_weights, device=dm.device, dtype=dm.dtype + ) + with_mgga_feature = ( "density" in features or "grad" in features @@ -180,7 +212,7 @@ def reduced_vjp(primals: torch.Tensor) -> torch.Tensor: return reduced_vjp -class FeatureFunction(nn.Module, ABC): # type: ignore[misc] +class FeatureFunction(nn.Module, ABC): deriv: int nfeats: int only_linear_feats: bool @@ -376,7 +408,7 @@ def forward(self, dm: torch.Tensor, ao: torch.Tensor) -> torch.Tensor: return features.reshape((*dm.shape[:-2], self.nfeats, -1)) -class ChunkEvalForward(Function): # type: ignore[misc] +class ChunkEvalForward(Function): @staticmethod def setup_context( ctx: FunctionCtx, @@ -441,6 +473,13 @@ def forward( if len(vectors_jvp) > 1 and feature_function.only_linear_feats: return features + # Pre-sort DM and JVP vectors once (sort_idx is constant across blocks) + sort_idx_t = torch.as_tensor(sort_idx, device=dm.device) + dm_sorted = dm[..., sort_idx_t, :][..., sort_idx_t] + vectors_jvp_sorted = [ + v[..., sort_idx_t, :][..., sort_idx_t] for v in vectors_jvp + ] + end = 0 for ao_block, mask, weights, _ in ni.block_loop( *block_loop_args, **block_loop_kwargs @@ -451,9 +490,7 @@ def forward( mask = torch.arange(mol.nao_nr(), device=dm.device) else: mask = torch.from_dlpack(mask) - masked_dm = dm[..., sort_idx, :][..., sort_idx][ - ..., mask[:, None], mask[None, :] - ] + masked_dm = dm_sorted[..., mask[:, None], mask[None, :]] # Apply chain rule for this particular block partial_func = partial_feature_function_over_aos( @@ -462,12 +499,10 @@ def forward( ao_block, device=dm.device, dtype=dm.dtype, transpose=not gpu ), ) - for vector_jvp in vectors_jvp: + for v_sorted in vectors_jvp_sorted: partial_func = partial_jvp_function_over_tangents( partial_func, - vector_jvp[..., sort_idx, :][..., sort_idx][ - ..., mask[:, None], mask[None, :] - ], + v_sorted[..., mask[:, None], mask[None, :]], ) # Compute feature (or its jvp) for this block with masked dm @@ -496,7 +531,7 @@ def jvp(ctx: FunctionCtx, grad_input: torch.Tensor) -> torch.Tensor: @staticmethod def backward( - ctx: FunctionCtx, grad_output: torch.Tensor + ctx: FunctionCtx, *grad_outputs: torch.Tensor ) -> tuple[torch.Tensor | None, ...]: # After one vjp (backward) the signature of the function changes from dm.shape -> (*dm.shape[:-2], nfeats, ngrid) to dm.shape -> dm.shape # therefore we move to a different function that does essentially the same thing, but with the new signature @@ -513,7 +548,7 @@ def backward( ctx.compile_feature_function, ctx.gpu, *ctx.vectors_jvp, - grad_output, + *grad_outputs, ) ] @@ -539,7 +574,7 @@ def backward( ctx.compile_feature_function, ctx.gpu, *ctx.vectors_jvp[:i], - grad_output, + *grad_outputs, *ctx.vectors_jvp[i + 1 :], ) ) @@ -547,7 +582,7 @@ def backward( return tuple(grads) -class ChunkEvalBackward(Function): # type: ignore[misc] +class ChunkEvalBackward(Function): @staticmethod def setup_context( ctx: FunctionCtx, @@ -608,7 +643,15 @@ def forward( if len(vectors) > 1 and feature_function.only_linear_feats: return out - unsort_idx = torch.argsort(torch.tensor(sort_idx)) + # Pre-sort DM and derivative vectors once (sort_idx is constant across blocks) + sort_idx_t = torch.as_tensor(sort_idx, device=dm.device) + unsort_idx = torch.argsort(sort_idx_t) + dm_sorted = dm[..., sort_idx_t, :][..., sort_idx_t] + vectors_sorted = [ + v[..., sort_idx_t, :][..., sort_idx_t] if dt in ("jvp", "vjp") else v + for dt, v in zip(derivative_types, vectors, strict=True) + ] + for ao_block, mask, weights, _ in ni.block_loop( *block_loop_args, **block_loop_kwargs, @@ -629,20 +672,18 @@ def forward( ao_block, device=dm.device, dtype=dm.dtype, transpose=not gpu ), ) - for derivative_type, vector in zip(derivative_types, vectors, strict=True): + for derivative_type, vector, v_sorted in zip( + derivative_types, vectors, vectors_sorted, strict=True + ): if derivative_type == "jvp": partial_func = partial_jvp_function_over_tangents( partial_func, - vector[..., sort_idx, :][..., sort_idx][ - ..., mask[:, None], mask[None, :] - ], + v_sorted[..., mask[:, None], mask[None, :]], ) elif derivative_type == "vjp": partial_func = partial_vjp_function_over_tangents( partial_func, - vector[..., sort_idx, :][..., sort_idx][ - ..., mask[:, None], mask[None, :] - ], + v_sorted[..., mask[:, None], mask[None, :]], ) elif derivative_type == "first_vjp": partial_func = partial_vjp_function_over_tangents( @@ -654,15 +695,11 @@ def forward( ) if compile_feature_function: out[..., mask[:, None], mask[None, :]] += torch.compile(partial_func)( - dm[..., sort_idx, :][..., sort_idx][ - ..., mask[:, None], mask[None, :] - ] + dm_sorted[..., mask[:, None], mask[None, :]] ) else: out[..., mask[:, None], mask[None, :]] += partial_func( - dm[..., sort_idx, :][..., sort_idx][ - ..., mask[:, None], mask[None, :] - ] + dm_sorted[..., mask[:, None], mask[None, :]] ) return out[..., unsort_idx, :][..., unsort_idx] @@ -684,7 +721,7 @@ def jvp(ctx: FunctionCtx, *grad_input: torch.Tensor) -> torch.Tensor: @staticmethod def backward( - ctx: FunctionCtx, grad_output: torch.Tensor + ctx: FunctionCtx, *grad_outputs: torch.Tensor ) -> tuple[torch.Tensor | None, ...]: # Chain rule for the vjp @@ -700,7 +737,7 @@ def backward( ctx.compile_feature_function, ctx.gpu, *ctx.vectors, - grad_output, + *grad_outputs, ) ] # We need to provide None for the gradients of the non-differentiable inputs @@ -725,7 +762,7 @@ def backward( ctx.compile_feature_function, ctx.gpu, *ctx.vectors[:i], - grad_output, + *grad_outputs, *ctx.vectors[i + 1 :], ) ) @@ -740,7 +777,7 @@ def backward( ctx.compile_feature_function, ctx.gpu, *ctx.vectors[:i], - grad_output, + *grad_outputs, *ctx.vectors[i + 1 :], ) ) diff --git a/src/skala/pyscf/gradients.py b/src/skala/pyscf/gradients.py index 5e8c3c1..02c0a97 100644 --- a/src/skala/pyscf/gradients.py +++ b/src/skala/pyscf/gradients.py @@ -39,12 +39,17 @@ def veff_and_expl_nuc_grad( "kin", "grid_coords", "grid_weights", + "atomic_grid_weights", "coarse_0_atomic_coords", } if nuc_grad_feats is None: # generate feature list from functional features nuc_grad_feats = set(functional.features) + # Integer-valued features have no nuclear gradient — always discard them + nuc_grad_feats.discard("atomic_grid_sizes") + nuc_grad_feats.discard("atomic_grid_size_bound_shape") + # check for unsupported features unsupported_feats = {feat for feat in nuc_grad_feats if feat not in SUPPORTED_FEATS} if unsupported_feats != set(): @@ -74,6 +79,11 @@ def veff_and_expl_nuc_grad( grid_.weights = np.concatenate(weight_list) mol_feats = feature.generate_features(mol, rdm1, grid_, set(functional.features)) + # Discard atomic_grid_weights from VJP features: d(atomic_grid_weights)/dR = 0 + # because they are raw quadrature weights that depend only on the radial/angular + # grid rule, not on nuclear positions. They still pass through as other_feats. + nuc_grad_feats.discard("atomic_grid_weights") + # Get required derivatives nuc_feat_names = list(nuc_grad_feats) # ensure specific order nuc_feat_tensors = [mol_feats[feat] for feat in nuc_feat_names] @@ -194,7 +204,7 @@ def exc_feat_func(*nuc_feat_tensors: torch.Tensor) -> torch.Tensor: if "grid_coords" in nuc_grad_feats: # also add the explicit grid coordinate dependence - nuc_grad[atm_id] += dExc["grid_coords"][atm_start:atm_end].sum(axis=0) + nuc_grad[atm_id] += dExc["grid_coords"][atm_start:atm_end].sum(dim=0) if "grid_weights" in nuc_grad_feats: Exc_dgw = dExc["grid_weights"][atm_start:atm_end] diff --git a/src/skala/pyscf/numint.py b/src/skala/pyscf/numint.py index c98a192..110a6b9 100644 --- a/src/skala/pyscf/numint.py +++ b/src/skala/pyscf/numint.py @@ -98,10 +98,12 @@ class SkalaNumInt(PySCFNumInt[Array]): >>> >>> mol = gto.M(atom="H 0 0 0; H 0 0 1", basis="def2-svp", verbose=0) >>> ks = dft.KS(mol) - >>> ks._numint = SkalaNumInt(load_functional("skala")) + >>> ks._numint = SkalaNumInt(load_functional("skala-1.1")) + >>> ks.grids.build(mol, sort_grids=False) # DOCTEST: Ellipsis + >>> energy = ks.kernel() >>> print(energy) # DOCTEST: Ellipsis - -1.142330... + -1.1425799... """ device: torch.device @@ -133,7 +135,7 @@ def from_backend( def to_backend(self, x: Tensor | list[Tensor]) -> Array | list[Array]: if isinstance(x, list): - return [self.to_backend(y) for y in x] + return [self.to_backend(y) for y in x] # type: ignore if self.device.type == "cuda": return to_cupy(x) @@ -157,7 +159,7 @@ def get_rho( max_memory=max_memory, gpu=self.device.type == "cuda", ) - return self.to_backend(mol_features["density"].sum(0)) + return self.to_backend(mol_features["density"].sum(0)) # type: ignore def __call__( self, @@ -210,7 +212,7 @@ def nr_rks( N, E_xc, V_xc = self( mol, grids, xc_code, self.from_backend(dm), max_memory=max_memory ) - return N.sum().item(), E_xc.item(), self.to_backend(V_xc) + return N.sum().item(), E_xc.item(), self.to_backend(V_xc) # type: ignore def nr_uks( self, @@ -225,7 +227,7 @@ def nr_uks( N, E_xc, V_xc = self( mol, grids, xc_code, self.from_backend(dm), max_memory=max_memory ) - return self.to_backend(N), E_xc.item(), self.to_backend(V_xc) + return self.to_backend(N), E_xc.item(), self.to_backend(V_xc) # type: ignore class libxc: __version__ = None @@ -239,7 +241,7 @@ def is_hybrid_xc(xc: str) -> bool: def is_nlc(xc: str) -> bool: return False - def gen_response( + def gen_response( # type: ignore[override] # wider Array type than PySCF base self, mo_coeff: Array | None, mo_occ: Array | None, diff --git a/src/skala/pyscf/retry.py b/src/skala/pyscf/retry.py index c5fda4d..da8d694 100644 --- a/src/skala/pyscf/retry.py +++ b/src/skala/pyscf/retry.py @@ -85,7 +85,12 @@ def post_cycle_callback(self, envs: dict[str, Any]) -> None: self.gradient_norm_per_cycle.append(norm_gorb) if "norm_ddm" not in envs: envs["norm_ddm"] = np.linalg.norm(envs["dm"] - envs["dm_last"]) - self.dm_change_per_cycle.append(envs["norm_ddm"]) + + norm_ddm = envs["norm_ddm"] + assert isinstance(norm_ddm, (float, int)), ( + f"Expected norm_ddm to be a float, got {type(norm_ddm)}" + ) + self.dm_change_per_cycle.append(norm_ddm) if not isinstance(mo_energy, list) and len(mo_energy.shape) == 1: self.homo_lumo_gap_up_per_cycle.append( diff --git a/src/skala/utils/__init__.py b/src/skala/utils/__init__.py index e416a58..57310c0 100644 --- a/src/skala/utils/__init__.py +++ b/src/skala/utils/__init__.py @@ -4,5 +4,5 @@ Utility functions and classes for the Skala package. This module contains various utility functions used throughout -the Skala codebase, including tensor operations and scatter functions. +the Skala codebase. """ diff --git a/src/skala/utils/scatter.py b/src/skala/utils/scatter.py deleted file mode 100644 index b24b1b0..0000000 --- a/src/skala/utils/scatter.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-License-Identifier: MIT -# A copy of useful code from torch scatter -# https://github.com/rusty1s/pytorch_scatter/blob/96aa2e3587123ba4ef31820899d5e62141e9a4c2/torch_scatter/scatter.py - -""" -Scatter operations for PyTorch tensors. - -This module provides scatter operations similar to pytorch_scatter, -specifically scatter_sum for aggregating values at specified indices. -""" - -import torch - - -def scatter_sum( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: None | torch.Tensor = None, - dim_size: None | int = None, -) -> torch.Tensor: - """ - Sum all values from the src tensor at indices specified in the index tensor. - - Parameters - ---------- - src : torch.Tensor - Source tensor containing values to scatter. - index : torch.Tensor - Index tensor specifying where to scatter values. - dim : int, optional - Dimension along which to scatter. Default: -1. - out : torch.Tensor or None, optional - Output tensor. If None, a new tensor is created. - dim_size : int or None, optional - Size of the output tensor along the scatter dimension. - - Returns - ------- - torch.Tensor - Tensor with scattered and summed values. - """ - index = broadcast(index, src, dim) - if out is None: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() == 0: - size[dim] = 0 - else: - size[dim] = int(index.max()) + 1 - out = torch.zeros(size, dtype=src.dtype, device=src.device) - return out.scatter_add_(dim, index, src) - else: - return out.scatter_add_(dim, index, src) - - -def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor: - """ - Broadcast src tensor to match the shape of other tensor along specified dimensions. - - Parameters - ---------- - src : torch.Tensor - Source tensor to broadcast. - other : torch.Tensor - Target tensor whose shape to match. - dim : int - Dimension along which to perform broadcasting. - - Returns - ------- - torch.Tensor - Broadcasted tensor with shape matching other. - """ - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand(other.size()) - return src diff --git a/tests/test_ase.py b/tests/test_ase.py index a397c1c..09de96f 100644 --- a/tests/test_ase.py +++ b/tests/test_ase.py @@ -3,50 +3,55 @@ try: import ase - from ase.build import molecule - from ase.calculators import calculator - - from skala.ase import Skala except ModuleNotFoundError: ase = None +from ase.build import molecule +from ase.calculators import calculator -@pytest.fixture(params=["pbe", "tpss", "skala"]) -def xc(request) -> str: - return request.param +from skala.ase import Skala @pytest.mark.skipif(ase is None, reason="ASE is not installed") +@pytest.mark.parametrize("xc", ["pbe", "tpss", "skala-1.0", "skala-1.1"]) def test_calc(xc: str) -> None: - atoms = molecule("H2O") + atoms = molecule("H2O") # type: ignore[no-untyped-call] atoms.calc = Skala( - xc=xc, basis="def2-svp", with_density_fit=True, auxbasis="def2-svp-jkfit" + xc=xc, + basis="def2-svp", + with_density_fit=True, + auxbasis="def2-svp-jkfit", ) energy = atoms.get_potential_energy() - reference_energy, reference_fnorm, reference_dipm = { + reference_energy, reference_fnorm, reference_dipole_moment = { "pbe": (-2075.4896490374904, 0.6395142802693002, 0.40519674886465107), "tpss": (-2077.88636677525, 0.5863078815838786, 0.40534133865824284), - "skala": (-2076.4586374337177, 1.127975901679744, 0.4173008295594236), + "skala-1.0": (-2076.4586374337177, 1.127975901679744, 0.4173008295594236), + "skala-1.1": (-2076.839069353949, 0.5614649968829959, 0.41354587147386074), }[xc] - assert ( - pytest.approx(energy, rel=1e-3) == reference_energy - ), f"Energy mismatch for {xc}: {energy} vs {reference_energy}" + assert pytest.approx(energy, rel=1e-3) == reference_energy, ( + f"Energy mismatch for {xc}: {energy} vs {reference_energy}" + ) assert ( pytest.approx(np.linalg.norm(np.abs(atoms.get_forces())), rel=1e-3) == reference_fnorm - ), f"Forces norm mismatch for {xc}: {np.linalg.norm(np.abs(atoms.get_forces()))} vs {reference_fnorm}" + ), ( + f"Forces norm mismatch for {xc}: {np.linalg.norm(np.abs(atoms.get_forces()))} vs {reference_fnorm}" + ) assert ( pytest.approx(np.linalg.norm(atoms.get_dipole_moment()), rel=1e-3) - == reference_dipm - ), f"Dipole moment mismatch for {xc}: {np.linalg.norm(atoms.get_dipole_moment())} vs {reference_dipm}" + == reference_dipole_moment + ), ( + f"Dipole moment mismatch for {xc}: {np.linalg.norm(atoms.get_dipole_moment())} vs {reference_dipole_moment}" + ) @pytest.mark.skipif(ase is None, reason="ASE is not installed") def test_missing_basis() -> None: - atoms = molecule("H2O") + atoms = molecule("H2O") # type: ignore[no-untyped-call] atoms.calc = Skala(xc="pbe", with_density_fit=True, auxbasis="def2-svp-jkfit") with pytest.raises( @@ -68,11 +73,11 @@ def test_ks_config() -> None: energy = atoms.get_potential_energy() - assert ( - atoms.calc._ks.base.conv_tol == 1e-6 - ), "KS solver convergence tolerance not set correctly" + assert atoms.calc._ks.base.conv_tol == 1e-6, ( + "KS solver convergence tolerance not set correctly" + ) reference_energy = -2075.4896490374904 - assert ( - pytest.approx(energy, rel=1e-3) == reference_energy - ), f"Energy mismatch with custom KS config: {energy} vs {reference_energy}" + assert pytest.approx(energy, rel=1e-3) == reference_energy, ( + f"Energy mismatch with custom KS config: {energy} vs {reference_energy}" + ) diff --git a/tests/test_enhancement_factor.py b/tests/test_enhancement_factor.py index 3f691b7..90e329c 100644 --- a/tests/test_enhancement_factor.py +++ b/tests/test_enhancement_factor.py @@ -5,7 +5,7 @@ from skala.functional.base import spin_symmetrized_enhancement_factor -def test_spin_symmetrized_enhancement_factor(): +def test_spin_symmetrized_enhancement_factor() -> None: n = 16 dim_ab = 6 dim_agnostic = 2 diff --git a/tests/test_gauxc_export.py b/tests/test_gauxc_export.py index 87d56ff..e19e53d 100644 --- a/tests/test_gauxc_export.py +++ b/tests/test_gauxc_export.py @@ -9,7 +9,7 @@ @pytest.fixture(params=["He", "Li"]) -def mol_name(request) -> str: +def mol_name(request: pytest.FixtureRequest) -> str: return request.param @@ -19,7 +19,7 @@ def basis() -> str: @pytest.fixture(params=["cart", "sph"]) -def cartesian(request) -> bool: +def cartesian(request: pytest.FixtureRequest) -> bool: return request.param == "cart" @@ -64,7 +64,9 @@ def vxc(ks: dft.rks.RKS, dm: np.ndarray) -> np.ndarray: return vxc -def test_write_pyscf(mol: gto.Mole, dm: np.ndarray, exc, vxc) -> None: +def test_write_pyscf( + mol: gto.Mole, dm: np.ndarray, exc: float, vxc: np.ndarray +) -> None: with NamedTemporaryFile(suffix=".h5") as tmp: write_gauxc_h5_from_pyscf(tmp.name, mol, dm, exc, vxc) @@ -74,9 +76,9 @@ def test_write_pyscf(mol: gto.Mole, dm: np.ndarray, exc, vxc) -> None: assert "DENSITY_SCALAR" in h5, "Density (a+b) is missing in h5 export" assert "DENSITY_Z" in h5, "Density (a-b) is missing in h5 export" assert "EXC" in h5, "Exchange-correlation energy is missing in h5 export" - assert ( - "VXC_SCALAR" in h5 - ), "Exchange-correlation potential (a+b) is missing in h5 export" - assert ( - "VXC_Z" in h5 - ), "Exchange-correlation potential (a-b) is missing in h5 export" + assert "VXC_SCALAR" in h5, ( + "Exchange-correlation potential (a+b) is missing in h5 export" + ) + assert "VXC_Z" in h5, ( + "Exchange-correlation potential (a-b) is missing in h5 export" + ) diff --git a/tests/test_gpu4pyscf_classes.py b/tests/test_gpu4pyscf_classes.py index 6ded68a..49aaf2c 100644 --- a/tests/test_gpu4pyscf_classes.py +++ b/tests/test_gpu4pyscf_classes.py @@ -8,13 +8,23 @@ allow_module_level=True, ) +from skala.functional import load_functional +from skala.functional.base import ExcFunctionalBase from skala.gpu4pyscf import SkalaKS from skala.gpu4pyscf.dft import SkalaRKS, SkalaUKS from skala.gpu4pyscf.gradients import SkalaRKSGradient, SkalaUKSGradient +@pytest.fixture(params=["skala-1.0", "skala-1.1"]) +def skala_xc(request: pytest.FixtureRequest) -> ExcFunctionalBase: + """Load the Skala functional under test on GPU.""" + func = load_functional(request.param, device=torch.device("cuda:0")) + assert isinstance(func, ExcFunctionalBase) + return func + + @pytest.fixture(params=["H", "H2"]) -def mol(request) -> gto.Mole: +def mol(request: pytest.FixtureRequest) -> gto.Mole: if request.param == "H": return gto.M(atom="H", basis="sto-3g", spin=1) if request.param == "H2": @@ -23,28 +33,33 @@ def mol(request) -> gto.Mole: @pytest.fixture(params=["dfj", "no df"]) -def with_density_fit(request) -> bool: +def with_density_fit(request: pytest.FixtureRequest) -> bool: return request.param == "dfj" @pytest.fixture(params=["soscf", "scf"]) -def with_newton(request) -> bool: +def with_newton(request: pytest.FixtureRequest) -> bool: return request.param == "soscf" @pytest.fixture(params=["d3", "no d3"]) -def with_dftd3(request) -> bool: +def with_dftd3(request: pytest.FixtureRequest) -> bool: return request.param == "d3" def test_skala_class( - mol: gto.Mole, with_density_fit: bool, with_newton: bool, with_dftd3: bool -): + mol: gto.Mole, + skala_xc: ExcFunctionalBase, + with_density_fit: bool, + with_newton: bool, + with_dftd3: bool, +) -> None: """Test whether classes get correctly preserved.""" ks = SkalaKS( mol, - xc="skala", + xc=skala_xc, with_density_fit=with_density_fit, + auxbasis="def2-universal-jkfit" if with_density_fit else None, with_newton=with_newton, with_dftd3=with_dftd3, ) diff --git a/tests/test_gpu4pyscf_gradients.py b/tests/test_gpu4pyscf_gradients.py index 6575e97..137d999 100644 --- a/tests/test_gpu4pyscf_gradients.py +++ b/tests/test_gpu4pyscf_gradients.py @@ -50,6 +50,7 @@ def num_dif_ridders( d_estimate[0, 0] = (func(x + step) - func(x - step)) / (2 * step) prev_deriv = d_estimate[0, 0] + num_deriv = prev_deriv for iter in range(1, max_tab): step /= step_div d_estimate[iter, 0] = (func(x + step) - func(x - step)) / (2 * step) @@ -93,7 +94,7 @@ def num_grad_ridders( ) -> tuple[torch.Tensor, torch.Tensor]: """Recursively calculates the partial derivative w.r.t. all elements of x over all dimensions.""" - def func_1d_red(xi: torch.Tensor): + def func_1d_red(xi: torch.Tensor) -> torch.Tensor: x_ = x.clone() x_[i] = xi return func(x_) @@ -116,7 +117,7 @@ def func_1d_red(xi: torch.Tensor): @pytest.fixture(params=["HF", "H2O", "H2O+"]) -def mol_name(request) -> gto.Mole: +def mol_name(request: pytest.FixtureRequest) -> str: return request.param @@ -153,19 +154,19 @@ def get_grid_and_rdm1(mol: gto.Mole) -> tuple[dft.Grids, torch.Tensor]: grids=minimal_grid(mol), ) mf.kernel() - rdm1 = torch.from_dlpack(mf.make_rdm1()) + rdm1 = torch.from_dlpack(mf.make_rdm1()) # type: ignore[attr-defined] return mf.grids, rdm1 # maybe_expand_and_divide(rdm1, len(rdm1.shape) == 2, 2) def test_grid_coords_gradient(mol_name: str) -> None: class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["grid_coords"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return mol_feats["grid_coords"].sum() + return mol["grid_coords"].sum() mol = get_mol(mol_name) grid, rdm1 = get_grid_and_rdm1(mol) @@ -186,13 +187,13 @@ def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: def test_coarse_0_atomic_coords_gradient(mol_name: str) -> None: class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["coarse_0_atomic_coords"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return torch.einsum("nx->", mol_feats["coarse_0_atomic_coords"]) + return torch.einsum("nx->", mol["coarse_0_atomic_coords"]) mol = get_mol(mol_name) grid, rdm1 = get_grid_and_rdm1(mol) @@ -209,17 +210,17 @@ def test_grid_weights_gradient(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return mol_feats["grid_weights"].sum() + return mol["grid_weights"].sum() def finite_difference_nuc_grad( weight_sum: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" # mol_.verbose = 2 mol_feats = generate_features( @@ -229,7 +230,7 @@ def finite_difference_nuc_grad( def weight_sum_as_nuc_coords_func(nuc_coords: torch.Tensor) -> torch.Tensor: """Exc wrapper for the finite difference""" mol.set_geom_(nuc_coords.cpu().numpy(), "bohr", symmetry=None) - mol_feats["grid_weights"] = torch.from_dlpack(minimal_grid(mol).weights) + mol_feats["grid_weights"] = torch.from_dlpack(minimal_grid(mol).weights) # type: ignore[attr-defined] return weight_sum.get_exc(mol_feats) @@ -244,7 +245,7 @@ def weight_sum_as_nuc_coords_func(nuc_coords: torch.Tensor) -> torch.Tensor: num_grad, num_err = finite_difference_nuc_grad(exc_test, mol, rdm1) # estimate the minimum expected absolute error eps = ( - exc_test.get_exc({"grid_weights": torch.from_dlpack(grid.weights)}) + exc_test.get_exc({"grid_weights": torch.from_dlpack(grid.weights)}) # type: ignore[attr-defined] * torch.finfo(num_grad.dtype).eps ) @@ -261,17 +262,17 @@ def test_density_veff(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["density", "grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return (mol_feats["density"] @ mol_feats["grid_weights"]).sum() + return (mol["density"] @ mol["grid_weights"]).sum() def finite_difference_nuc_grad( dens_sum: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" grid = minimal_grid(mol) @@ -316,23 +317,23 @@ def test_grad_veff(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["grad", "grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: return ( - (mol_feats["grad"] ** 2 @ mol_feats["grid_weights"]) + (mol["grad"] ** 2 @ mol["grid_weights"]) @ torch.tensor( [1.0, 2.0, 3.0], dtype=torch.float64, - device=mol_feats["grad"].device, + device=mol["grad"].device, ) ).sum() def finite_difference_nuc_grad( grad_func: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" grid = minimal_grid(mol) @@ -376,17 +377,17 @@ def test_kin_veff(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["kin", "grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total kinetic energy number""" - return (mol_feats["kin"] @ mol_feats["grid_weights"]).sum() + return (mol["kin"] @ mol["grid_weights"]).sum() def finite_difference_nuc_grad( kin_func: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" grid = minimal_grid(mol) @@ -427,20 +428,29 @@ def kin_func_as_nuc_coords_func(nuc_coords: torch.Tensor) -> torch.Tensor: def run_scf( - mol: gto.Mole, functional: ExcFunctionalBase, with_dftd3: bool + mol: gto.Mole, + functional: ExcFunctionalBase, + with_dftd3: bool, + *, + grid_level: int = 1, ) -> scf.hf.SCF: print(f"{mol.basis = }") scf = SkalaKS(mol, xc=functional, with_dftd3=with_dftd3) - scf.grids = minimal_grid(mol) + scf.grids.level = grid_level scf.conv_tol = 1e-14 - # scf.verbose = 0 scf.kernel() return scf -@pytest.fixture(params=["pbe"]) -def xc_name(request) -> str: +@pytest.fixture( + params=[ + "pbe", + "skala-1.0", + "skala-1.1", + ] +) +def xc_name(request: pytest.FixtureRequest) -> str: return request.param @@ -472,15 +482,63 @@ def mol_min_bas(molname: str) -> gto.Mole: ], dtype=torch.float64, ), + "HF:skala-1.0": torch.tensor( + [ + [-6.613324869518045e-11, 8.252792463549002e-11, -0.11766455093793571], + [6.61332489653549e-11, -8.252792460791792e-11, 0.11766455093858674], + ], + dtype=torch.float64, + ), + "H2O:skala-1.0": torch.tensor( + [ + [0.047614265311160864, -1.531395405982722e-10, -3.8752502138778113e-10], + [-0.023807132833466138, 1.4666252996608134e-10, -0.1265627683625854], + [-0.02380713247717514, 6.477008760972603e-12, 0.12656276875009254], + ], + dtype=torch.float64, + ), + "H2O+:skala-1.0": torch.tensor( + [ + [0.11016449231802916, -1.4692817466244152e-10, 8.447353527447336e-10], + [-0.05508224759855285, 4.2209029890734686e-10, -0.15564538047513543], + [-0.05508224472116563, -2.751621283378772e-10, 0.15564537963040848], + ], + dtype=torch.float64, + ), + "HF:skala-1.1": torch.tensor( + [ + [-9.093651147681359e-11, 1.8436550945505342e-10, -0.11922130029704636], + [9.093651147684337e-11, -1.8436550945509882e-10, 0.11922130029705125], + ], + dtype=torch.float64, + ), + "H2O:skala-1.1": torch.tensor( + [ + [0.05518685428627901, 1.0112539782960166e-09, 5.727682266064312e-10], + [-0.027593427632945478, -4.887476070628282e-06, -0.12591870031741337], + [-0.027593426653364173, 4.886464816658505e-06, 0.12591869974465286], + ], + dtype=torch.float64, + ), + "H2O+:skala-1.1": torch.tensor( + [ + [0.11201511304824052, -2.33353601498583e-10, 4.162082128268623e-10], + [-0.05600755684684611, -1.4123079854089175e-06, -0.15729176960843216], + [-0.056007556201392195, 1.4125413389947007e-06, 0.15729176919222287], + ], + dtype=torch.float64, + ), } def test_full_grad(mol_name: str, xc_name: str) -> None: # analytical result mol = get_mol(mol_name) - scf = run_scf( - mol, load_functional(xc_name, device=torch.device("cuda:0")), with_dftd3=False - ) + func = load_functional(xc_name, device=torch.device("cuda:0")) + assert isinstance(func, ExcFunctionalBase) + # skala-1.1 uses per-atom packed grids (unsorted) and needs a denser grid + # to avoid NaNs in the SCF. + scf = run_scf(mol, func, with_dftd3=False) if mol.spin == 0: grad = SkalaRKSGradient(scf).kernel() @@ -490,17 +548,10 @@ def test_full_grad(mol_name: str, xc_name: str) -> None: # get reference result ref_grad = FULL_GRAD_REF[mol_name + ":" + xc_name] - # get numerical result - # num_grad, num_err = SkalaRKSGradient(scf).numerical() - # print(f"{ana_grad = }") - # print(f"{num_grad = }") - # print(f"{num_err = }") - # print(f"{ana_grad - num_grad}") - # print(f"{ref_grad = }") assert torch.allclose(ana_grad, ref_grad, atol=1e-4), ( f"Gradients for {mol_name} with {xc_name} do not match reference.\n" - f"Analytic: {ana_grad}\n" + f"Analytic: {ana_grad.tolist()!r}\n" f"Reference: {ref_grad}\n" f"Difference: {ana_grad - ref_grad}" ) diff --git a/tests/test_hash_pinning.py b/tests/test_hash_pinning.py index 9cc48e5..c91bd31 100644 --- a/tests/test_hash_pinning.py +++ b/tests/test_hash_pinning.py @@ -34,7 +34,7 @@ def get_exc(self, data: dict[str, torch.Tensor]) -> torch.Tensor: "protocol_version": json.dumps(2).encode(), } buf = io.BytesIO() - torch.jit.save(scripted, buf, _extra_files=extra_files) + torch.jit.save(scripted, buf, _extra_files=extra_files) # type: ignore[no-untyped-call] return buf.getvalue() diff --git a/tests/test_model.py b/tests/test_model.py index 0e4b23a..35e5eec 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,123 +1,438 @@ # SPDX-License-Identifier: MIT -import numpy as np +"""Snapshot regression tests for SkalaFunctional components. + +Each test captures deterministic numerical output using torch.manual_seed(42) +on CPU. Any change that alters model output will be caught. + +RNG state is forked via an autouse fixture so that seeding inside tests +does not mutate the global RNG visible to other tests in the process. +""" + +import math +from collections.abc import Iterator + import pytest import torch -from pyscf import dft, gto, scf -from skala.functional import load_functional +from skala.functional import ExcFunctionalBase, load_functional from skala.functional.model import ( + ANGSTROM_TO_BOHR, + ExpRadialScaleModel, + NonLocalModel, + O3Linear, + SemiLocalFeatures, SkalaFunctional, - exp_radial_func, + TensorProduct, + _prepare_features_raw, ) -from skala.pyscf.features import generate_features +from skala.functional.utils.irreps import Irreps + + +@pytest.fixture(autouse=True) +def _isolated_rng() -> Iterator[None]: + """Fork the PyTorch RNG so manual_seed calls inside tests don't leak.""" + with torch.random.fork_rng(): + yield + + +def exp_radial_func(dist: torch.Tensor, num_basis: int, dim: int = 3) -> torch.Tensor: + """Legacy standalone version, kept here for snapshot regression testing.""" + min_std = 0.32 * ANGSTROM_TO_BOHR / 2 + max_std = 2.32 * ANGSTROM_TO_BOHR / 2 + s = torch.linspace(min_std, max_std, num_basis, device=dist.device) + temps = 2 * s**2 + x2 = dist[..., None] ** 2 + return ( + torch.exp(-x2 / temps) * 2 / dim * x2 / temps / (math.pi * temps) ** (0.5 * dim) + ) + + +def make_mol( + num_atoms: int, + grid_per_atom: int, + device: str = "cpu", + dtype: torch.dtype = torch.float64, +) -> dict[str, torch.Tensor]: + total_grid = num_atoms * grid_per_atom + return { + "density": torch.randn(2, total_grid, dtype=dtype, device=device), + "grad": torch.randn(2, 3, total_grid, dtype=dtype, device=device), + "kin": torch.randn(2, total_grid, dtype=dtype, device=device), + "grid_coords": torch.randn(total_grid, 3, dtype=dtype, device=device), + "grid_weights": torch.randn(total_grid, dtype=dtype, device=device).abs(), + "atomic_grid_weights": torch.randn( + total_grid, dtype=dtype, device=device + ).abs(), + "atomic_grid_sizes": torch.tensor( + [grid_per_atom] * num_atoms, dtype=torch.int64, device=device + ), + "coarse_0_atomic_coords": torch.randn(num_atoms, 3, dtype=dtype, device=device), + "atomic_grid_size_bound_shape": torch.zeros( + grid_per_atom, 0, dtype=torch.int64, device=device + ), + } + + +def make_mol_variable_grid( + atomic_grid_sizes: list[int], + device: str = "cpu", + dtype: torch.dtype = torch.float64, +) -> dict[str, torch.Tensor]: + """Create a mol dict with variable grid sizes per atom.""" + sizes = torch.tensor(atomic_grid_sizes, dtype=torch.int64, device=device) + num_atoms = len(atomic_grid_sizes) + total_grid = sum(atomic_grid_sizes) + size_bound = max(atomic_grid_sizes) + return { + "density": torch.randn(2, total_grid, dtype=dtype, device=device), + "grad": torch.randn(2, 3, total_grid, dtype=dtype, device=device), + "kin": torch.randn(2, total_grid, dtype=dtype, device=device), + "grid_coords": torch.randn(total_grid, 3, dtype=dtype, device=device), + "grid_weights": torch.randn(total_grid, dtype=dtype, device=device).abs(), + "atomic_grid_weights": torch.randn( + total_grid, dtype=dtype, device=device + ).abs(), + "atomic_grid_sizes": sizes, + "coarse_0_atomic_coords": torch.randn(num_atoms, 3, dtype=dtype, device=device), + "atomic_grid_size_bound_shape": torch.zeros( + size_bound, 0, dtype=torch.int64, device=device + ), + } + -torch.manual_seed(0) +def small_model() -> SkalaFunctional: + return SkalaFunctional( + num_mid_layers=1, + num_non_local_layers=1, + non_local_hidden_nf=3, + correlation=1, + ) -@pytest.fixture(scope="session") -def mol() -> gto.Mole: - mol = gto.M(atom="H 0 0 0; F 0 0 1.1", basis="def2-qzvp", cart=True) - return mol +def test_prepare_features_raw_snapshot() -> None: + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol(4, 10) + packed = model.pack_features(mol) + out = _prepare_features_raw(packed) + + assert out.shape == (10, 4, 7) + torch.testing.assert_close( + out.sum(), + torch.tensor(1.104726756002241e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(2.873065774435380e02, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + expected_first = torch.tensor( + [ + -1.2054195578684468, + -1.2484940336629657, + 0.9278196025135572, + 1.1062339879348806, + -0.00458365814011141, + -4.091211863274872, + 1.9987709853481832, + ], + dtype=torch.float64, + ) + torch.testing.assert_close(out[0, 0, :], expected_first, rtol=1e-5, atol=1e-5) + + +def test_prepare_features_snapshot() -> None: + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol(4, 10) + packed = model.pack_features(mol) + semi_local = SemiLocalFeatures() + features_ab, features_ba = semi_local(packed) + + assert features_ab.shape == (10, 4, 7) + assert features_ba.shape == (10, 4, 7) + torch.testing.assert_close( + features_ab.sum(), + torch.tensor(1.104726756002241e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + features_ba.sum(), + torch.tensor(1.104726756002241e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + # features_ba is the column-swapped version of features_ab + expected_ba = torch.stack( + [features_ab[..., i] for i in [1, 0, 3, 2, 5, 4, 6]], dim=-1 + ) + torch.testing.assert_close(features_ba, expected_ba, rtol=0, atol=0) + + +def test_pack_features_snapshot() -> None: + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol(4, 10) + packed = model.pack_features(mol) + + assert packed["density"].shape == (2, 10, 4) + assert packed["kin"].shape == (2, 10, 4) + assert packed["grad"].shape == (2, 3, 10, 4) + assert packed["grid_coords"].shape == (10, 4, 3) + assert packed["atomic_grid_weights"].shape == (10, 4) + assert packed["coarse_0_atomic_coords"].shape == (4, 3) + + torch.testing.assert_close( + packed["density"].sum(), + torch.tensor(1.020635438470402e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + packed["atomic_grid_weights"].sum(), + torch.tensor(4.032819661608873e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + +def test_exp_radial_func_snapshot() -> None: + torch.manual_seed(42) + dist = torch.randn(5, 3, dtype=torch.float64).abs() + out = exp_radial_func(dist, num_basis=16) -def get_mf_dm(mol: gto.Mole) -> tuple[scf.hf.SCF, np.ndarray]: - ks = dft.KS( - mol, - xc="pbe", - )( - grids=dft.Grids(mol)(level=1, radi_method=dft.radi.treutler).build(), - max_cycle=1, + assert out.shape == (5, 3, 16) + torch.testing.assert_close( + out.sum(), + torch.tensor(6.552108459223353e00, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + # All values should be non-negative (Gaussian-based radial function) + assert (out >= 0).all() + + # ExpRadialScaleModel module must produce identical output + radial_basis = ExpRadialScaleModel(embedding_size=16).double() + out_module = radial_basis(dist.unsqueeze(-1) ** 2) + torch.testing.assert_close(out_module, out, rtol=1e-6, atol=1e-7) + + +def test_tensor_product_snapshot() -> None: + torch.manual_seed(42) + irreps_in1 = Irreps("3x0e") + irreps_in2 = Irreps("1x0e+1x1e") + irreps_out = Irreps("3x0e+3x1e") + tp = TensorProduct(irreps_in1, irreps_in2, irreps_out) + + torch.manual_seed(42) + x1 = torch.randn(5, 4, irreps_in1.dim, dtype=torch.float32) + x2 = torch.randn(5, 4, irreps_in2.dim, dtype=torch.float32) + out = tp(x1, x2) + + assert out.shape == (5, 4, 12) + torch.testing.assert_close( + out.sum(), + torch.tensor(5.987227916717529e00, dtype=torch.float32), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(1.092445983886719e02, dtype=torch.float32), + rtol=1e-5, + atol=1e-5, ) - ks.kernel() - return ks, ks.make_rdm1() -def test_vne3nn_invariance(mol: gto.Mole): - # fix np seed - np.random.seed(0) +def test_o3_linear_snapshot() -> None: + torch.manual_seed(42) + irreps_in = Irreps("3x0e+3x1e") + irreps_out = Irreps("3x0e+3x1e") + linear = O3Linear(irreps_in, irreps_out) - ks, dm = get_mf_dm(mol) + torch.manual_seed(42) + x = torch.randn(5, irreps_in.dim, dtype=torch.float32) + out = linear(x) - model = SkalaFunctional(non_local=True) + assert out.shape == (5, 12) + torch.testing.assert_close( + out.sum(), + torch.tensor(1.546258926391602e01, dtype=torch.float32), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(5.258611297607422e01, dtype=torch.float32), + rtol=1e-5, + atol=1e-5, + ) - dm_torch1 = torch.from_numpy(dm).float() - exc1 = model.get_exc( - generate_features(mol, dm_torch1, ks.grids, set(model.features)) +def test_nonlocal_model_snapshot() -> None: + torch.manual_seed(42) + sph_irreps = Irreps.spherical_harmonics(1, p=1) + nlm = NonLocalModel( + input_nf=256, + hidden_nf=3, + lmax=1, + edge_irreps=sph_irreps, + coarse_linear_type="decomp-identity", + correlation=1, + ).float() + + torch.manual_seed(42) + num_fine, num_coarse = 10, 4 + h = torch.randn(num_fine, num_coarse, 256, dtype=torch.float32) + distance_ft = torch.randn(num_fine, num_coarse, 3, dtype=torch.float32).abs() + direction_ft = torch.randn(num_fine, num_coarse, 4, dtype=torch.float32) + grid_weights = torch.randn(num_fine, num_coarse, dtype=torch.float32).abs() + exp_m1_rho = torch.randn(num_fine, num_coarse, 1, dtype=torch.float32).abs() + out = nlm(h, distance_ft, direction_ft, grid_weights, exp_m1_rho) + + assert out.shape == (10, 4, 256) + torch.testing.assert_close( + out.sum(), + torch.tensor(7.943189086914062e02, dtype=torch.float32), + rtol=1e-4, + atol=1e-1, + ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(2.372190673828125e03, dtype=torch.float32), + rtol=1e-4, + atol=1e-1, ) - # Check that the model is invariant to the rotation of the coordinates - Q = np.linalg.qr(np.random.randn(3, 3)).Q - atom_coords = mol.atom_coords() - mol.set_geom_(atom_coords @ Q, unit="bohr") - assert mol.atom_coords() == pytest.approx(atom_coords @ Q, abs=1e-6) - ks, dm = get_mf_dm(mol) - dm_torch2 = torch.from_numpy(dm).float().requires_grad_(True) - assert not torch.allclose(dm_torch1, dm_torch2) +def test_get_exc_density_snapshot_4atoms() -> None: + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol(4, 10) + out = model.get_exc_density(mol) - exc2 = model.get_exc( - generate_features(mol, dm_torch2, ks.grids, set(model.features)) + assert out.shape == (40,) + torch.testing.assert_close( + out.sum(), + torch.tensor(-3.425105949459996e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(3.425105949459996e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + - assert torch.allclose(exc1, exc2, atol=0.00001) +def test_get_exc_density_snapshot_17atoms() -> None: + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol(17, 10) + out = model.get_exc_density(mol) + assert out.shape == (170,) + torch.testing.assert_close( + out.sum(), + torch.tensor(-1.321246666755034e02, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(1.321246666755034e02, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) -def test_double_precision(mol: gto.Mole): - # this ensures the functional can handle double precision inputs - ks, dm = get_mf_dm(mol) - model = SkalaFunctional(non_local=True) +def test_get_exc_snapshot() -> None: + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol(4, 10) + out = model.get_exc(mol) - model.double() - dm_torch1 = torch.from_numpy(dm).double() + torch.testing.assert_close( + out, + torch.tensor(-2.821835255217404e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) - _ = model.get_exc(generate_features(mol, dm_torch1, ks.grids, set(model.features))) +def test_get_exc_density_variable_grid_sizes() -> None: + """Test that get_exc_density returns the correct unpadded shape and values with variable grid sizes.""" + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + sizes = [5, 10, 8, 3] + mol = make_mol_variable_grid(sizes) + out = model.get_exc_density(mol) -def test_exp_radial_func_normalization(): - N, num_basis = 100000, 16 - xx = torch.linspace(-10, 10, N) - dx = 20 / N + assert out.shape == (sum(sizes),), ( + f"Expected shape ({sum(sizes)},) but got {out.shape}. " + "get_exc_density should return (num_grid_points,), not padded shape." + ) + torch.testing.assert_close( + out.sum(), + torch.tensor(-1.642121553088128e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) + torch.testing.assert_close( + out.abs().sum(), + torch.tensor(1.642121553088128e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) - emb = exp_radial_func(xx, num_basis=num_basis, dim=1) - assert list(emb.shape) == [N, num_basis] - integrals = (emb * dx).sum(0) - assert torch.isclose( - integrals, torch.ones_like(integrals), atol=1e-4, rtol=1e-4 - ).all(), integrals +def test_get_exc_variable_grid_sizes() -> None: + """Test that get_exc still works correctly with variable grid sizes.""" + torch.manual_seed(42) + model = small_model() + torch.manual_seed(42) + mol = make_mol_variable_grid([5, 10, 8, 3]) + out = model.get_exc(mol) + + torch.testing.assert_close( + out, + torch.tensor(-1.121880984584683e01, dtype=torch.float64), + rtol=1e-5, + atol=1e-5, + ) -def test_traced_functional_and_loaded_functional_are_equal(): +def test_traced_functional_and_loaded_functional_are_equal() -> None: # This test ensures that the traced functional and the loaded functional # give the same output for the same input. - traced_model = load_functional("skala") + traced_model = load_functional("skala-1.1") + assert isinstance(traced_model, ExcFunctionalBase) + clean_state_dict = { k.replace("_traced_model.", ""): v for k, v in traced_model.state_dict().items() } - model = SkalaFunctional(lmax=3, radius_cutoff=5.0) + model = SkalaFunctional(lmax=3, num_non_local_layers=3, num_mid_layers=4) model.load_state_dict(clean_state_dict, strict=True) - # Create a dummy load_input - num_grid_points = 10 - grid_coords = torch.randn(num_grid_points, 3) - density = torch.randn(2, num_grid_points) - gradients = torch.randn(2, 3, num_grid_points) - kin = torch.randn(2, num_grid_points) - grid_weights = torch.randn(num_grid_points) - atom_coords = torch.tensor([[0.0, 0.0, 0.0]]) - features_dict = { - "density": density, - "grad": gradients, - "kin": kin, - "grid_weights": grid_weights, - "grid_coords": grid_coords, - "coarse_0_atomic_coords": atom_coords, - } + # Create a dummy input using the same helper as the other tests + torch.manual_seed(42) + features_dict = make_mol(num_atoms=1, grid_per_atom=10) + original_output = model.get_exc(features_dict) traced_model_output = traced_model.get_exc(features_dict) diff --git a/tests/test_pyscf_classes.py b/tests/test_pyscf_classes.py index 474dec0..aec4cac 100644 --- a/tests/test_pyscf_classes.py +++ b/tests/test_pyscf_classes.py @@ -1,13 +1,23 @@ import pytest from pyscf import gto +from skala.functional import load_functional +from skala.functional.base import ExcFunctionalBase from skala.pyscf import SkalaKS from skala.pyscf.dft import SkalaRKS, SkalaUKS from skala.pyscf.gradients import SkalaRKSGradient, SkalaUKSGradient +@pytest.fixture(params=["skala-1.0", "skala-1.1"]) +def skala_xc(request: pytest.FixtureRequest) -> ExcFunctionalBase: + """Load the Skala functional under test.""" + func = load_functional(request.param) + assert isinstance(func, ExcFunctionalBase) + return func + + @pytest.fixture(params=["H", "H2"]) -def mol(request) -> gto.Mole: +def mol(request: pytest.FixtureRequest) -> gto.Mole: if request.param == "H": return gto.M(atom="H", basis="sto-3g", spin=1) if request.param == "H2": @@ -16,28 +26,33 @@ def mol(request) -> gto.Mole: @pytest.fixture(params=["dfj", "no df"]) -def with_density_fit(request) -> bool: +def with_density_fit(request: pytest.FixtureRequest) -> bool: return request.param == "dfj" @pytest.fixture(params=["soscf", "scf"]) -def with_newton(request) -> bool: +def with_newton(request: pytest.FixtureRequest) -> bool: return request.param == "soscf" @pytest.fixture(params=["d3", "no d3"]) -def with_dftd3(request) -> bool: +def with_dftd3(request: pytest.FixtureRequest) -> bool: return request.param == "d3" def test_skala_class( - mol: gto.Mole, with_density_fit: bool, with_newton: bool, with_dftd3: bool -): + mol: gto.Mole, + skala_xc: ExcFunctionalBase, + with_density_fit: bool, + with_newton: bool, + with_dftd3: bool, +) -> None: """Test whether classes get correctly preserved.""" ks = SkalaKS( mol, - xc="skala", + xc=skala_xc, with_density_fit=with_density_fit, + auxbasis="def2-universal-jkfit" if with_density_fit else None, with_newton=with_newton, with_dftd3=with_dftd3, ) diff --git a/tests/test_pyscf_gradients.py b/tests/test_pyscf_gradients.py index 1ecc584..beca985 100644 --- a/tests/test_pyscf_gradients.py +++ b/tests/test_pyscf_gradients.py @@ -41,6 +41,7 @@ def num_dif_ridders( d_estimate[0, 0] = (func(x + step) - func(x - step)) / (2 * step) prev_deriv = d_estimate[0, 0] + num_deriv = prev_deriv for iter in range(1, max_tab): step /= step_div d_estimate[iter, 0] = (func(x + step) - func(x - step)) / (2 * step) @@ -84,7 +85,7 @@ def num_grad_ridders( ) -> tuple[torch.Tensor, torch.Tensor]: """Recursively calculates the partial derivative w.r.t. all elements of x over all dimensions.""" - def func_1d_red(xi: torch.Tensor): + def func_1d_red(xi: torch.Tensor) -> torch.Tensor: x_ = x.clone() x_[i] = xi return func(x_) @@ -107,7 +108,7 @@ def func_1d_red(xi: torch.Tensor): @pytest.fixture(params=["HF", "H2O", "H2O+"]) -def mol_name(request) -> gto.Mole: +def mol_name(request: pytest.FixtureRequest) -> str: return request.param @@ -132,8 +133,10 @@ def get_mol(molname: str) -> gto.Mole: return mol -def minimal_grid(mol: gto.Mole) -> dft.Grids: - return dft.Grids(mol)(level=1, radi_method=dft.radi.treutler).build() +def minimal_grid(mol: gto.Mole, sort_grids: bool = True) -> dft.Grids: + grids = dft.Grids(mol)(level=1, radi_method=dft.radi.treutler) + grids.build(sort_grids=sort_grids) + return grids def get_grid_and_rdm1(mol: gto.Mole) -> tuple[dft.Grids, torch.Tensor]: @@ -150,13 +153,13 @@ def get_grid_and_rdm1(mol: gto.Mole) -> tuple[dft.Grids, torch.Tensor]: def test_grid_coords_gradient(mol_name: str) -> None: class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["grid_coords"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return mol_feats["grid_coords"].sum() + return mol["grid_coords"].sum() mol = get_mol(mol_name) grid, rdm1 = get_grid_and_rdm1(mol) @@ -177,13 +180,13 @@ def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: def test_coarse_0_atomic_coords_gradient(mol_name: str) -> None: class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["coarse_0_atomic_coords"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return torch.einsum("nx->", mol_feats["coarse_0_atomic_coords"]) + return torch.einsum("nx->", mol["coarse_0_atomic_coords"]) mol = get_mol(mol_name) grid, rdm1 = get_grid_and_rdm1(mol) @@ -200,17 +203,17 @@ def test_grid_weights_gradient(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return mol_feats["grid_weights"].sum() + return mol["grid_weights"].sum() def finite_difference_nuc_grad( weight_sum: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" # mol_.verbose = 2 mol_feats = generate_features( @@ -266,17 +269,17 @@ def test_density_veff(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["density", "grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total electron number""" - return (mol_feats["density"] @ mol_feats["grid_weights"]).sum() + return (mol["density"] @ mol["grid_weights"]).sum() def finite_difference_nuc_grad( dens_sum: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" grid = minimal_grid(mol) @@ -319,19 +322,19 @@ def test_grad_veff(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["grad", "grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: return ( - (mol_feats["grad"] ** 2 @ mol_feats["grid_weights"]) + (mol["grad"] ** 2 @ mol["grid_weights"]) @ torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) ).sum() def finite_difference_nuc_grad( grad_func: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" grid = minimal_grid(mol) @@ -373,17 +376,17 @@ def test_kin_veff(mol_name: str) -> None: mol = get_mol(mol_name) class TestFunc(ExcFunctionalBase): - def __init__(self): + def __init__(self) -> None: super().__init__() self.features = ["kin", "grid_weights"] - def get_exc(self, mol_feats: dict[str, torch.Tensor]) -> torch.Tensor: + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: """This actually calculates the total kinetic energy number""" - return (mol_feats["kin"] @ mol_feats["grid_weights"]).sum() + return (mol["kin"] @ mol["grid_weights"]).sum() def finite_difference_nuc_grad( kin_func: ExcFunctionalBase, mol: gto.Mole, rdm1: torch.Tensor - ): + ) -> tuple[torch.Tensor, torch.Tensor]: """Calculates the gradient in Exc w.r.t. nuclear coordinates numerically""" grid = minimal_grid(mol) @@ -422,25 +425,28 @@ def kin_func_as_nuc_coords_func(nuc_coords: torch.Tensor) -> torch.Tensor: def run_scf( - mol: gto.Mole, functional: ExcFunctionalBase, with_dftd3: bool + mol: gto.Mole, + functional: ExcFunctionalBase, + with_dftd3: bool, + *, + grid_level: int = 1, ) -> scf.hf.SCF: print(f"{mol.basis = }") scf = SkalaKS(mol, xc=functional, with_dftd3=with_dftd3) - scf.grids = minimal_grid(mol) + scf.grids.level = grid_level scf.conv_tol = 1e-14 - # scf.verbose = 0 scf.kernel() return scf -@pytest.fixture(params=["pbe"]) -def xc_name(request) -> str: +@pytest.fixture(params=["pbe", "skala-1.0", "skala-1.1"]) +def xc_name(request: pytest.FixtureRequest) -> str: return request.param -def mol_min_bas(molname: str) -> gto.Mole: - molecule = get_mol(molname) +def mol_min_bas(mol_name: str) -> gto.Mole: + molecule = get_mol(mol_name) molecule.basis = "sto-3g" return molecule @@ -467,13 +473,62 @@ def mol_min_bas(molname: str) -> gto.Mole: ], dtype=torch.float64, ), + "HF:skala-1.0": torch.tensor( + [ + [-1.8951600005625355e-10, 2.0983011686819494e-10, -0.11766455110756313], + [1.895159998665323e-10, -2.0983011653002862e-10, 0.11766455110756091], + ], + dtype=torch.float64, + ), + "H2O:skala-1.0": torch.tensor( + [ + [0.04761426020567949, 9.68124090794071e-11, 1.2024662967084874e-09], + [-0.023807130986786884, 1.628796777076799e-10, -0.12656276817486223], + [-0.023807129218868184, -2.596920517571628e-10, 0.126562766972401], + ], + dtype=torch.float64, + ), + "H2O+:skala-1.0": torch.tensor( + [ + [0.11016447311737299, -1.7268193014843002e-09, 6.9612067477191e-10], + [-0.055082237334041384, 6.800820207918395e-10, -0.15564537931499212], + [-0.05508223578332139, 1.0467372935944724e-09, 0.15564537861887162], + ], + dtype=torch.float64, + ), + "HF:skala-1.1": torch.tensor( + [ + [-9.093651147681359e-11, 1.8436550945505342e-10, -0.11922130029704636], + [9.093651147684337e-11, -1.8436550945509882e-10, 0.11922130029705125], + ], + dtype=torch.float64, + ), + "H2O:skala-1.1": torch.tensor( + [ + [0.05518685428627901, 1.0112539782960166e-09, 5.727682266064312e-10], + [-0.027593427632945478, -4.887476070628282e-06, -0.12591870031741337], + [-0.027593426653364173, 4.886464816658505e-06, 0.12591869974465286], + ], + dtype=torch.float64, + ), + "H2O+:skala-1.1": torch.tensor( + [ + [0.11201511304824052, -2.33353601498583e-10, 4.162082128268623e-10], + [-0.05600755684684611, -1.4123079854089175e-06, -0.15729176960843216], + [-0.056007556201392195, 1.4125413389947007e-06, 0.15729176919222287], + ], + dtype=torch.float64, + ), } def test_full_grad(mol_name: str, xc_name: str) -> None: # analytical result mol = get_mol(mol_name) - scf = run_scf(mol, load_functional(xc_name), with_dftd3=False) + func = load_functional(xc_name) + assert isinstance(func, ExcFunctionalBase) + + scf = run_scf(mol, func, with_dftd3=False) if mol.spin == 0: grad = SkalaRKSGradient(scf).kernel() @@ -483,17 +538,106 @@ def test_full_grad(mol_name: str, xc_name: str) -> None: # get reference result ref_grad = FULL_GRAD_REF[mol_name + ":" + xc_name] - # get numerical result - # num_grad, num_err = SkalaRKSGradient(scf).numerical() - # print(f"{ana_grad = }") - # print(f"{num_grad = }") - # print(f"{num_err = }") - # print(f"{ana_grad - num_grad}") - # print(f"{ref_grad = }") assert torch.allclose(ana_grad, ref_grad, atol=1e-4), ( f"Gradients for {mol_name} with {xc_name} do not match reference.\n" - f"Analytic: {ana_grad}\n" + f"Analytic: {ana_grad.tolist()!r}\n" f"Reference: {ref_grad}\n" f"Difference: {ana_grad - ref_grad}" ) + + +def test_atomic_grid_weights_gradient(mol_name: str) -> None: + """atomic_grid_weights are raw quadrature weights independent of nuclear positions. + + d(atomic_grid_weights)/dR = 0, so the contribution to the nuclear gradient must be zero. + """ + + class TestFunc(ExcFunctionalBase): + def __init__(self) -> None: + super().__init__() + self.features = ["atomic_grid_weights"] + + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + return mol["atomic_grid_weights"].sum() + + mol = get_mol(mol_name) + grid, rdm1 = get_grid_and_rdm1(mol) + # Rebuild grid unsorted (required for atomic_grid_weights feature generation) + grid = minimal_grid(mol, sort_grids=False) + + exc_test = TestFunc() + _, nuc_grad = veff_and_expl_nuc_grad(exc_test, mol, grid, rdm1) + + assert torch.allclose(nuc_grad, torch.zeros_like(nuc_grad), atol=1e-15), ( + f"atomic_grid_weights gradient should be zero, got {nuc_grad}" + ) + + +def test_atomic_grid_features_passthrough(mol_name: str) -> None: + """Verify all three per-atom grid features pass through gradient computation without error. + + atomic_grid_sizes and atomic_grid_size_bound_shape are integer metadata and should be + auto-discarded. atomic_grid_weights should be discarded from VJP but passed as other_feats. + """ + + class TestFunc(ExcFunctionalBase): + def __init__(self) -> None: + super().__init__() + self.features = [ + "density", + "grid_weights", + "atomic_grid_weights", + "atomic_grid_sizes", + "atomic_grid_size_bound_shape", + ] + + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + # Use density and grid_weights (differentiable) plus atomic_grid_weights (other_feat) + n_electrons = (mol["density"] @ mol["grid_weights"]).sum() + agw_sum = mol["atomic_grid_weights"].sum() + return n_electrons + agw_sum + + mol = get_mol(mol_name) + grid, rdm1 = get_grid_and_rdm1(mol) + # Rebuild grid unsorted (required for per-atom grid features) + grid = minimal_grid(mol, sort_grids=False) + + exc_test = TestFunc() + # This should not raise NotImplementedError + veff, nuc_grad = veff_and_expl_nuc_grad(exc_test, mol, grid, rdm1) + + if mol.spin == 0: + assert veff.shape == (3, mol.nao, mol.nao) + else: + assert veff.shape == (2, 3, mol.nao, mol.nao) + assert nuc_grad.shape == (mol.natm, 3) + + +def test_explicit_nuc_grad_feats_with_integer_features(mol_name: str) -> None: + """Passing integer metadata features via explicit nuc_grad_feats should not raise.""" + + class TestFunc(ExcFunctionalBase): + def __init__(self) -> None: + super().__init__() + self.features = [ + "density", + "grid_weights", + "atomic_grid_weights", + "atomic_grid_sizes", + "atomic_grid_size_bound_shape", + ] + + def get_exc(self, mol: dict[str, torch.Tensor]) -> torch.Tensor: + return (mol["density"] @ mol["grid_weights"]).sum() + + mol = get_mol(mol_name) + grid, rdm1 = get_grid_and_rdm1(mol) + grid = minimal_grid(mol, sort_grids=False) + + exc_test = TestFunc() + # Explicitly pass all features including integer ones — should auto-discard them + veff, nuc_grad = veff_and_expl_nuc_grad( + exc_test, mol, grid, rdm1, nuc_grad_feats=set(exc_test.features) + ) + assert nuc_grad.shape == (mol.natm, 3) diff --git a/tests/test_scf_retry.py b/tests/test_scf_retry.py index ff8db9a..55035ed 100644 --- a/tests/test_scf_retry.py +++ b/tests/test_scf_retry.py @@ -40,7 +40,7 @@ def test_retry_newton(mol: gto.Mole) -> None: assert state.ntries > 1, "SCF should have been retried" assert ks.converged, "SCF did not converge with retry mechanism" - assert isinstance( - ks, _CIAH_SOSCF - ), "SCF should have used Newton's method after retries" + assert isinstance(ks, _CIAH_SOSCF), ( + "SCF should have used Newton's method after retries" + ) assert ks.level_shift == 0, "Level shift should be zero after Newton's method retry" diff --git a/tests/test_traditional.py b/tests/test_traditional.py index 2be6159..feac460 100644 --- a/tests/test_traditional.py +++ b/tests/test_traditional.py @@ -9,7 +9,7 @@ @pytest.fixture(params=["HF", "B", "H"]) -def mol(request) -> gto.Mole: +def mol(request: pytest.FixtureRequest) -> gto.Mole: if request.param == "HF": return gto.M(atom="H 0 0 0; F 0 0 1.1", basis="cc-pvdz") elif request.param == "B": @@ -20,14 +20,16 @@ def mol(request) -> gto.Mole: @pytest.fixture(params=["lda", "spw92", "pbe", "tpss"]) -def xc(request) -> str: +def xc(request: pytest.FixtureRequest) -> str: return request.param @pytest.fixture def xc_fun(xc: str) -> ExcFunctionalBase: """Fixture to load the functional.""" - return load_functional(xc) + func = load_functional(xc) + assert isinstance(func, ExcFunctionalBase) + return func @pytest.fixture diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b846cd9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: MIT + +"""Tests for the pad_ragged, irreps, and symmetric_contraction utilities.""" + +from collections.abc import Iterator + +import pytest +import torch + +from skala.functional.utils.irreps import Irrep, Irreps, MulIr +from skala.functional.utils.pad_ragged import pad_ragged, unpad_ragged + + +class TestPadRagged: + def test_round_trip_uniform(self) -> None: + """pad then unpad recovers the original flat tensor.""" + sizes = torch.tensor([5, 5, 5]) + flat = torch.randn(15, 3) + padded = pad_ragged(flat, sizes, 5) + assert padded.shape == (3, 5, 3) + recovered = unpad_ragged(padded, sizes, 15) + torch.testing.assert_close(recovered, flat) + + def test_round_trip_variable(self) -> None: + """pad then unpad works with variable sizes.""" + sizes = torch.tensor([3, 7, 5]) + total = sizes.sum().item() + assert isinstance(total, int) + flat = torch.randn(total, 2) + padded = pad_ragged(flat, sizes, 7) + assert padded.shape == (3, 7, 2) + recovered = unpad_ragged(padded, sizes, total) + torch.testing.assert_close(recovered, flat) + + def test_single_sequence_fast_path(self) -> None: + """Single sequence (1 atom) takes the fast path.""" + sizes = torch.tensor([10]) + flat = torch.randn(10, 4) + padded = pad_ragged(flat, sizes, 10) + assert padded.shape == (1, 10, 4) + recovered = unpad_ragged(padded, sizes, 10) + torch.testing.assert_close(recovered, flat) + + def test_1d_input(self) -> None: + """pad_ragged works with 1D inputs.""" + sizes = torch.tensor([3, 2]) + flat = torch.randn(5) + padded = pad_ragged(flat, sizes, 3) + assert padded.shape == (2, 3) + recovered = unpad_ragged(padded, sizes, 5) + torch.testing.assert_close(recovered, flat) + + def test_padding_is_zero(self) -> None: + """Padded elements should be zero.""" + sizes = torch.tensor([2, 5]) + flat = torch.randn(7, 3) + 10 # shift away from zero + padded = pad_ragged(flat, sizes, 5) + # Atom 0 has 2 elements, so positions [2:5] should be zero + assert (padded[0, 2:5, :] == 0).all() + + def test_negative_sizes_raises(self) -> None: + """pad_ragged should reject negative sizes.""" + sizes = torch.tensor([2, -1]) + flat = torch.randn(1, 3) + with pytest.raises(ValueError, match="non-negative"): + pad_ragged(flat, sizes, 5) + + def test_mismatched_data_length_raises(self) -> None: + """pad_ragged should reject data length != sum(sizes).""" + sizes = torch.tensor([2, 3]) + flat = torch.randn(6, 3) # should be 5 + with pytest.raises(ValueError, match="data length"): + pad_ragged(flat, sizes, 5) + + def test_unpad_negative_sizes_raises(self) -> None: + """unpad_ragged should reject negative sizes.""" + padded = torch.randn(2, 5, 3) + sizes = torch.tensor([2, -1]) + with pytest.raises(ValueError, match="non-negative"): + unpad_ragged(padded, sizes, 1) + + +class TestIrreps: + def test_irrep_from_string(self) -> None: + ir = Irrep("1e") + assert ir.l == 1 + assert ir.p == 1 + assert ir.dim == 3 + + def test_irrep_from_tuple(self) -> None: + ir = Irrep(2, -1) + assert ir.l == 2 + assert ir.p == -1 + assert ir.dim == 5 + + def test_irreps_from_string(self) -> None: + irreps = Irreps("3x0e+2x1o") + assert len(irreps) == 2 + assert irreps[0].mul == 3 + assert irreps[0].ir.l == 0 + assert irreps[1].mul == 2 + assert irreps[1].ir.l == 1 + assert irreps[1].ir.p == -1 + + def test_irreps_dim(self) -> None: + irreps = Irreps("3x0e+2x1e") + assert irreps.dim == 3 * 1 + 2 * 3 # 9 + + def test_irreps_spherical_harmonics(self) -> None: + irreps = Irreps.spherical_harmonics(2, p=1) + # 1x0e + 1x1e + 1x2e + assert irreps.dim == 1 + 3 + 5 + + def test_irreps_slices(self) -> None: + irreps = Irreps("2x0e+3x1e") + slices = irreps.slices() + assert slices[0] == slice(0, 2) + assert slices[1] == slice(2, 11) + + def test_irrep_multiplication(self) -> None: + ir0 = Irrep(0, 1) + ir1 = Irrep(1, -1) + products = list(ir0 * ir1) + assert len(products) == 1 + assert products[0].l == 1 + + def test_mul_ir_equality(self) -> None: + a = MulIr(3, Irrep(0, 1)) + b = MulIr(3, Irrep(0, 1)) + assert a.mul == b.mul and a.ir == b.ir + + def test_irreps_simplify(self) -> None: + irreps = Irreps("2x0e+3x0e+1x1e") + simplified = irreps.simplify() + assert simplified[0].mul == 5 + assert simplified[0].ir.l == 0 + + def test_irreps_sort(self) -> None: + irreps = Irreps("1x1e+1x0e") + sorted_irreps = irreps.sort() + assert sorted_irreps.irreps[0].ir.l == 0 + assert sorted_irreps.irreps[1].ir.l == 1 + + def test_irreps_lmax(self) -> None: + irreps = Irreps("1x0e+1x1e+1x3o") + assert irreps.lmax == 3 + + def test_irreps_mul_scalar(self) -> None: + irreps = Irreps("2x0e+1x1e") + scaled = irreps * 3 + # Irreps * int repeats the whole sequence 3 times + assert len(scaled) == 6 # 3 copies of [2x0e, 1x1e] + + +class TestSymmetricContraction: + @pytest.fixture(autouse=True) + def _isolated_rng(self) -> Iterator[None]: + with torch.random.fork_rng(): + yield + + def test_output_shape(self) -> None: + from e3nn import o3 as e3nn_o3 + + from skala.functional.utils.symmetric_contraction import SymmetricContraction + + torch.manual_seed(42) + irreps_in = e3nn_o3.Irreps("3x0e+3x1e") + irreps_out = e3nn_o3.Irreps("3x0e+3x1e") + sc = SymmetricContraction(irreps_in, irreps_out, correlation=2) + + x = torch.randn(5, irreps_in.dim) + out = sc(x) + assert out.shape == (5, irreps_out.dim) + + def test_output_deterministic(self) -> None: + from e3nn import o3 as e3nn_o3 + + from skala.functional.utils.symmetric_contraction import SymmetricContraction + + torch.manual_seed(42) + irreps_in = e3nn_o3.Irreps("3x0e+3x1e") + irreps_out = e3nn_o3.Irreps("3x0e+3x1e") + sc = SymmetricContraction(irreps_in, irreps_out, correlation=2) + + torch.manual_seed(123) + x = torch.randn(5, irreps_in.dim) + out1 = sc(x) + out2 = sc(x) + torch.testing.assert_close(out1, out2)