Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .devcontainer/build_docker.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env bash
# Build the Skala dev-container image. Run this script from the repo root.
#
# Usage: .devcontainer/build_docker.sh [cpu|gpu]   (default: cpu)

# exit if any command fails, including failures inside pipelines,
# and treat unset variables as errors
set -euo pipefail

FILE=".devcontainer/skala_dev.Dockerfile"
ENV_VARIANT="${1:-cpu}"
BASE_NAME="skala-dev-${ENV_VARIANT}"
# Timestamped tag so successive builds do not overwrite each other.
TAG="${BASE_NAME}:$(date +"%Y%m%dT%H%M%S")"
# Replace the ':' in the tag for the log name — colons are not portable in
# filenames (e.g. on Windows checkouts / some archive tools).
LOG_FILE="build_${TAG//:/_}.log"

echo "Building Docker image with tag \"${TAG}\" (ENV_VARIANT=${ENV_VARIANT})"

# To ignore the cache, add --no-cache
# All expansions are quoted (ShellCheck SC2086) so the build does not break on
# unexpected characters in the variant/tag.
docker build \
    --progress=plain \
    --build-arg ENV_VARIANT="${ENV_VARIANT}" \
    --tag="${TAG}" \
    --file="${FILE}" \
    . \
    2>&1 | tee -a "${LOG_FILE}"

# Also point the variant's "latest" tag at the image we just built.
docker tag "${TAG}" "${BASE_NAME}:latest"
52 changes: 52 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{
    // Dev-container definition: builds the GPU variant of the Skala dev image
    // from the Dockerfile in this folder, with the repo root as build context.
    "name": "skala-dev-gpu",
    "build": {
        "dockerfile": "skala_dev.Dockerfile",
        "context": "..",
        "args": {
            // Selects the environment-gpu.yml conda environment in the Dockerfile.
            "ENV_VARIANT": "gpu"
        }
    },
    "postCreateCommand": "",
    // Make the in-repo sources importable without installing the package.
    "containerEnv": {
        "PYTHONPATH": "${containerWorkspaceFolder}/src/"
    },
    // Larger shared memory for PyTorch dataloaders; host networking so the
    // container can reach local services directly.
    "runArgs": [
        "--shm-size=2gb",
        "--network=host"
    ],
    "forwardPorts": [],
    "containerUser": "ubuntu", // or "${localEnv:USER}",
    "features": {
        "ghcr.io/devcontainers/features/common-utils:2": {
            "installZsh": false,
            "installOhMyZsh": false,
            "upgradePackages": false
            // If a user with this UID already exists in the image, comment the following lines and set "containerUser" to that user instead.
            // "username": "${localEnv:USER}",
            // "userUid": "${localEnv:UID}",
            // "userGid": "${localEnv:UID}"
        }
    },
    // VS Code tooling: Python + Ruff (lint/format) + mypy, with pytest wired
    // to the repo's tests/ folder and mypy reading its config from pyproject.toml.
    "customizations": {
        "vscode": {
            "extensions": [
                "ms-python.python",
                "charliermarsh.ruff",
                "ms-python.mypy-type-checker"
            ],
            "settings": {
                "[python]": {
                    "editor.defaultFormatter": "charliermarsh.ruff"
                },
                "python.testing.pytestEnabled": true,
                "python.testing.pytestArgs": [
                    "${containerWorkspaceFolder}/tests/"
                ],
                "mypy-type-checker.args": [
                    "--config-file=pyproject.toml"
                ]
            }
        }
    }
}
43 changes: 43 additions & 0 deletions .devcontainer/skala_dev.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# syntax=docker/dockerfile:1
FROM ubuntu:noble

# Minimal OS packages needed to bootstrap the conda installation; the apt
# cache is removed in the same layer so it never reaches the image.
RUN apt-get update --quiet \
    && apt-get install --yes --quiet --no-install-recommends \
        ca-certificates \
        wget \
    && apt-get clean --yes \
    && rm -rf /var/lib/apt/lists/*

# bash with pipefail so a failure on the left side of any pipe aborts the
# build (hadolint DL4006).
SHELL [ "/bin/bash", "-o", "pipefail", "-c" ]

# mamba installation (mamba>2.0)
ENV MAMBA_DIR=/opt/miniforge3
ENV MAMBA_ROOT_PREFIX=${MAMBA_DIR}
ENV PATH=${PATH}:${MAMBA_DIR}/bin

RUN wget --no-hsts --quiet --output-document=miniforge.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \
    && bash miniforge.sh -b -p ${MAMBA_DIR} \
    && rm miniforge.sh \
    # Remove python3.1 symlink if it exists, as it causes issues with conda
    # https://github.com/conda/conda/issues/11423
    && (test -L ${MAMBA_DIR}/bin/python3.1 && unlink ${MAMBA_DIR}/bin/python3.1 || true) \
    && ${MAMBA_DIR}/bin/mamba update -n base --all -y \
    && mamba clean --all --yes \
    # Mamba initialization script. Single quotes keep the $(...) literal so the
    # hook runs when the profile script is *sourced*; an unquoted double-quoted
    # echo would expand the hook at build time and word-split its multi-line
    # output onto one line, producing a broken script.
    && echo 'eval "$(mamba shell hook --shell bash)"' >> /etc/profile.d/source_mamba.sh \
    # for interactive shells:
    && echo "source /etc/profile.d/source_mamba.sh" >> /etc/bash.bashrc

# for non-interactive, not login shells:
# https://www.solipsys.co.uk/images/BashStartupFiles1.png
ENV BASH_ENV="/etc/profile.d/source_mamba.sh"

# create environment (ENV_VARIANT: "cpu" or "gpu")
ARG ENV_VARIANT=cpu
COPY ./environment-${ENV_VARIANT}.yml ./environment.yml
# CONDA_OVERRIDE_CUDA lets mamba solve CUDA deps without a GPU present during build
# Cache the mamba package downloads so rebuilds after environment.yml changes are faster.
# The cache mount is not part of the image layer, so mamba clean is unnecessary.
RUN --mount=type=cache,target=${MAMBA_DIR}/pkgs \
    CONDA_OVERRIDE_CUDA="12.0" mamba env create --file environment.yml \
    && rm environment.yml
5 changes: 5 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ jobs:
touch docs/_build/html/.nojekyll
shell: micromamba-shell {0}

- name: Check external links
run: |
sphinx-build -b linkcheck docs docs/_build/linkcheck
shell: micromamba-shell {0}

- name: Upload Pages artifact
uses: actions/upload-pages-artifact@v3
with:
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ jobs:

- name: Download checkpoint
run: >-
hf download microsoft/skala-1.0 skala-1.0.fun --local-dir .
hf download microsoft/skala-1.1 skala-1.1.fun --local-dir .

- name: Upload checkpoint
uses: actions/upload-artifact@v4
with:
name: skala-checkpoint
path: skala-1.0.fun
path: skala-1.1.fun

pt-features:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -153,7 +153,7 @@ jobs:
run: >-
Skala
./gauxc/tests/ref_data/onedft_he_def2qzvp_tpss_uks.hdf5
--model ./skala-1.0.fun
--model ./skala-1.1.fun
shell: micromamba-shell {0}

ftorch:
Expand Down Expand Up @@ -205,6 +205,6 @@ jobs:
- name: Run example
run: >-
Skala
./skala-1.0.fun
./skala-1.1.fun
./features
shell: micromamba-shell {0}
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,6 @@ FodyWeavers.xsd

# Sphinx documentation build output
docs/_build/

# Checkpoint files
*.fun
24 changes: 8 additions & 16 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,16 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.4
rev: v0.15.10
hooks:
# Run the linter.
- id: ruff
args: [--fix, --select, I]
exclude: '^(third_party)/.*'
args: [--config, pyproject.toml, --fix]
# Run the formatter.
- id: ruff-format
exclude: '^(third_party)/.*'
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.1
hooks:
- id: mypy
exclude: '^(third_party|tests)/.*'
args:
- --strict
- --ignore-missing-imports
- --no-namespace-packages
- --python-version=3.11
entry: mypy
args: [--config, pyproject.toml]
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.20.0
# hooks:
# - id: mypy
# args: [--config-file, pyproject.toml]
75 changes: 58 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
[![PyPI](https://img.shields.io/pypi/v/skala?logo=pypi&logoColor=white)](https://pypi.org/project/skala/)
[![Paper](https://img.shields.io/badge/arXiv-2506.14665-b31b1b?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2506.14665)

Skala is a neural network-based exchange-correlation functional for density functional theory (DFT), developed by Microsoft Research AI for Science. It leverages deep learning to predict exchange-correlation energies from electron density features, achieving chemical accuracy for atomization energies and strong performance on broad thermochemistry and kinetics benchmarks, all at a computational cost similar to semi-local DFT.
Skala is a neural network-based exchange-correlation functional for density functional theory (DFT), developed by Microsoft Research AI for Science. It uses deep learning to predict exchange-correlation energies from electron density features, achieving chemical accuracy for atomization energies and strong performance on broad thermochemistry and kinetics benchmarks, all at a computational cost similar to semi-local DFT.

Trained on a large, diverse dataset—including coupled cluster atomization energies and public benchmarks—Skala uses scalable message passing and local layers to learn both local and non-local effects. The model has about 276,000 parameters and matches the accuracy of leading hybrid functionals.
Trained on a large, diverse dataset — including coupled-cluster atomization energies and public benchmarks — Skala uses scalable message passing and local layers to learn both local and non-local effects. The model has about 276,000 parameters and matches the accuracy of leading hybrid functionals.

The recommended neural functional is `skala-1.1`, which uses per-atom packed grids, multiple non-local layers, and symmetric contraction. The legacy `skala-1.0` traced model is still loadable via `load_functional("skala-1.0")`.

Learn more about Skala in our [ArXiv paper](https://arxiv.org/abs/2506.14665).

## What's in here

This repository contains two main components:

1. The Python package `skala`, which is also distributed [on PyPI](https://pypi.org/project/skala/) and contains a PyTorch implementation of the Skala model, its hookups to quantum chemistry packages [PySCF](https://pyscf.org/), [GPU4PySCF](https://pyscf.org/user/gpu.html) and [ASE](https://ase-lib.org/).
2. Example of using Skala applications through LibTorch and GauXC, the following examples are available
1. The Python package `skala`, distributed [on PyPI](https://pypi.org/project/skala/) and on conda-forge. It contains a PyTorch implementation of the Skala model and its bindings to the quantum-chemistry packages [PySCF](https://pyscf.org/), [GPU4PySCF](https://pyscf.org/user/gpu.html), and [ASE](https://ase-lib.org/).
2. Examples of using Skala from compiled code through LibTorch and GauXC:
- [Skala in C++ with libtorch](examples/cpp/cpp_integration)
- [Skala in Fortran with FTorch](https://microsoft.github.io/skala/ftorch)
- [Skala in C++ with GauXC](https://microsoft.github.io/skala/gauxc/cpp-library)
Expand All @@ -39,18 +41,23 @@ For detailed documentation on using GauXC visit the [Skala integration guide](ht

All information below relates to the Python package `skala`.

Install using Pip:
`pip install skala` works out of the box and pulls every dependency from PyPI.
If you don't already have PyTorch installed, install the CPU-only wheel first
to avoid pulling a large CUDA build:

```bash
# Install CPU-only PyTorch (skip if you already have CPU or GPU-enabled PyTorch installed)
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install skala
```

Or using Conda (Mamba):
For a reproducible conda environment, use the provided
[`environment-cpu.yml`](environment-cpu.yml), which pins CPU-only PyTorch and
all runtime dependencies:

```bash
mamba install -c conda-forge skala "pytorch=*=cpu*"
mamba env create -n skala -f environment-cpu.yml
mamba activate skala
pip install skala
```

Run an SCF calculation with Skala for a hydrogen molecule:
Expand All @@ -63,21 +70,35 @@ mol = gto.M(
atom="""H 0 0 0; H 0 0 1.4""",
basis="def2-tzvp",
)
ks = SkalaKS(mol, xc="skala")
ks = SkalaKS(mol, xc="skala-1.1")
ks.kernel()
```

## Getting started: GPU4PySCF (GPU)
## Getting started: GPU4PySCF (GPU)

These instructions use Mamba and pip to install CUDA toolkit, Torch, and CuPy. It supports CUDA version 11, 12 or 13. You can find the most recent CUDA version that is supported on your system using `nvidia-smi`.
The GPU install is more involved because `gpu4pyscf` ships CUDA-version-specific
wheels that must match your CUDA toolkit. The recommended path is the provided
[`environment-gpu.yml`](environment-gpu.yml), which pins `pytorch-gpu`,
`cuda-toolkit` 12, `cutensor`, and installs `gpu4pyscf-cuda12x` 1.5 from PyPI:

```bash
cu_version=12 #or 11 or 13 depending on your CUDA version
mamba env create -n skala -f environment-gpu.yml "cuda-version==${cu_version}.*" skala
mamba env create -n skala -f environment-gpu.yml
mamba activate skala
pip install --no-deps "gpu4pyscf-cuda${cu_version}x>=1.0,<2" "gpu4pyscf-libxc-cuda${cu_version}x>=0.4,<1"
pip install skala
```

If you are building inside a container without a GPU attached (e.g., CI or a
Docker image built on a CPU-only host), set `CONDA_OVERRIDE_CUDA` so the solver
proceeds without a device:

```bash
CONDA_OVERRIDE_CUDA=12.0 mamba env create -n skala -f environment-gpu.yml
```

For CUDA 11 or 13, adjust `cuda-toolkit`, `cuda-version`, and the
`gpu4pyscf-cuda{11,13}x` pin in `environment-gpu.yml` accordingly. Check your
driver's maximum supported CUDA version with `nvidia-smi`.

Run an SCF calculation with Skala for a hydrogen molecule on GPU:

```python
Expand All @@ -88,19 +109,39 @@ mol = gto.M(
atom="""H 0 0 0; H 0 0 1.4""",
basis="def2-tzvp",
)
ks = SkalaKS(mol, xc="skala")
ks = SkalaKS(mol, xc="skala-1.1")
ks.kernel()
```

## Getting started: ASE calculator

Skala also provides an [ASE](https://wiki.fysik.dtu.dk/ase/) calculator for energy, force, and geometry optimization workflows:

```python
from ase.build import molecule
from ase.optimize import LBFGSLineSearch
from skala.ase import Skala

atoms = molecule("H2O")
atoms.calc = Skala(xc="skala-1.1", basis="def2-tzvp")

# Single-point energy (eV)
print(atoms.get_potential_energy())

# Geometry optimization
opt = LBFGSLineSearch(atoms)
opt.run(fmax=0.01)
```

## Documentation and examples

Go to [microsoft.github.io/skala](https://microsoft.github.io/skala) for a more detailed installation guide and further examples of how to use the Skala functional with PySCF, GPU4PySCF and ASE and in [Azure AI Foundry](https://ai.azure.com/catalog/models/Skala).
See [microsoft.github.io/skala](https://microsoft.github.io/skala) for a more detailed installation guide and further examples of how to use the Skala functional with PySCF, GPU4PySCF, and ASE, as well as in [Azure AI Foundry](https://ai.azure.com/catalog/models/Skala).

## Security: loading `.fun` files

Skala model files (`.fun`) use TorchScript serialization, which can execute arbitrary code when loaded. **Never load `.fun` files from untrusted sources.**

When loading the official Skala model via `load_functional("skala")`, file integrity is automatically verified against pinned SHA-256 hashes before deserialization. If you load `.fun` files directly with `TracedFunctional.load()`, pass the `expected_hash` parameter to enable verification:
When loading the official Skala models via `load_functional("skala-1.1")` or `load_functional("skala-1.0")`, file integrity is automatically verified against pinned SHA-256 hashes before deserialization. If you load `.fun` files directly with `TracedFunctional.load()`, pass the `expected_hash` parameter to enable verification:

```python
TracedFunctional.load("model.fun", expected_hash="<sha256-hex-digest>")
Expand Down
2 changes: 2 additions & 0 deletions docs/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Artifacts generated by executing notebooks during the Sphinx build.
water_opt.traj
Loading
Loading