⚛️ DIPM: Deep Interatomic Potentials Models in JAX 🚀


👀 Overview

dipm, short for Deep Interatomic Potentials Models (DIPM), is an enhancement of MLIP. It provides the following functionality:

  • Multiple NNX model architectures (for now: MACE, NequIP, ViSNet, LiTEN, EquiformerV2 and UMA)
  • Dataset loading and preprocessing
  • Training and fine-tuning MLIP models
  • Batched inference with trained MLIP models
  • MD simulations with MLIP models using multiple simulation backends (for now: JAX-MD and ASE)
  • Batched MD simulations and energy minimizations with the JAX-MD simulation backend
  • Energy minimizations with MLIP models using the same simulation backends as for MD

The purpose of the library is to give users a toolbox for working with MLIP models in a true end-to-end fashion. We follow three key design principles: (1) ease of use, also for non-expert users who mainly care about applying pre-trained models to relevant biological or materials-science applications; (2) extensibility and flexibility for users more experienced with MLIP and JAX; and (3) a focus on high inference speed, which enables long MD simulations on large systems and which we believe is necessary to bring MLIP to large-scale industrial application. See our inference speed benchmark below: with our library, we observe a 10x speedup on 138 atoms and up to a 4x speedup on 1205 atoms over equivalent implementations relying on Torch and ASE.

See the Installation section for details on how to install dipm, and the example Google Colab notebooks linked below for a quick way to get started. For detailed instructions, visit our extensive code documentation.

This repository currently supports implementations of the model architectures listed above (MACE, NequIP, ViSNet, LiTEN, EquiformerV2, and UMA). As the backend for equivariant operations, the current version of the code relies on the e3nn library.

📦 Installation

dipm can be installed via pip like this:

pip install dipm

However, this command only installs the regular CPU version of JAX. We recommend running the library on a GPU; use this command instead to install the GPU-compatible version:

pip install "dipm[cuda]"

This command installs the CUDA 12 version of JAX. For a different version, install dipm without the cuda flag and install the desired JAX version via pip.
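
A quick way to check which JAX build ended up in the environment is to query the available devices; this uses only the standard JAX API and is independent of dipm:

import jax

# With the CUDA build, this lists GPU devices (e.g. cuda:0);
# the CPU-only build reports CPU devices instead.
print(jax.devices())
print(jax.default_backend())  # expected to print "gpu" on a working CUDA install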

Simulation-related tasks such as MD or energy minimization require JAX-MD and ASE as dependencies. ASE can be installed as an optional dependency, while the newest version of JAX-MD must be installed directly from the GitHub repository to avoid critical bugs. Here are the installation commands:

pip install git+https://github.com/jax-md/jax-md.git
pip install "dipm[cuda,md]"

To use TensorBoard or Weights and Biases logging in the training loop, install the corresponding optional dependencies:

pip install "dipm[cuda,visual]"

Furthermore, note that among our library dependencies we have pinned jaxlib, matscipy, and orbax-checkpoint to one specific version each to prioritize reliability. We plan to allow a more flexible definition of our dependencies in upcoming releases.

📚 Dataset preparation

We only support HDF5 format datasets (compatible with HDF5 used in MACE). We provided a dataset conversion toolkit DIPM-Cvt for this purpose. We recommend to install it in a different environment than dipm to avoid conflicts. We provided a command-line interface dipm-cvt-cli for user-friendly usage.
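
As a sanity check after conversion, the resulting HDF5 file can be inspected with plain h5py. The file name below is a placeholder, and the exact group layout depends on the converter, so this sketch only walks the hierarchy rather than assuming specific keys:

import h5py

# "dataset.h5" is a placeholder path for a file produced by dipm-cvt-cli.
with h5py.File("dataset.h5", "r") as f:
    def show(name, obj):
        # Print every group and dataset with its shape/dtype so the
        # MACE-compatible layout can be verified by eye.
        if isinstance(obj, h5py.Dataset):
            print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")
        else:
            print(f"{name}/")
    f.visititems(show)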

⚡ Examples

For coding-free training, use python scripts/train.py scripts/train.yaml and adapt the config file (scripts/train.yaml) to your needs. See the documentation for more details.

For coding-free MD simulations with JAX-MD, run python scripts/run_md.py --model path_to_model --path path_to_structures. Run python scripts/run_md.py --help for more options.
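
For MD driven directly from Python with the ASE backend, the loop is standard ASE molecular dynamics. The sketch below uses ASE's built-in EMT calculator on a toy copper cell as a stand-in; swapping in the dipm ASE calculator (see the documentation for the exact import, which we do not assume here) gives the MLIP-powered equivalent:

from ase import units
from ase.build import bulk
from ase.calculators.emt import EMT
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

# Toy system; replace EMT() with the dipm ASE calculator to use an MLIP model.
atoms = bulk("Cu", cubic=True) * (3, 3, 3)
atoms.calc = EMT()

MaxwellBoltzmannDistribution(atoms, temperature_K=300)
dyn = Langevin(atoms, timestep=1.0 * units.fs, temperature_K=300, friction=0.02)
dyn.run(100)  # 100 Langevin MD steps
print(atoms.get_potential_energy())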

In addition to the in-depth tutorials provided as part of our documentation, we also provide example Jupyter notebooks that can be used as simple templates to build your own MLIP pipelines.

To run the tutorials, install Jupyter via pip and launch it from a directory that contains the notebooks:

pip install notebook && jupyter notebook

The installation of dipm itself is included within the notebooks. We recommend running these notebooks with GPU acceleration enabled.

🤗 Pre-trained models (via HuggingFace)

We have prepared pre-trained models trained on a subset of the SPICE2 dataset, as described in MLIP's white paper (see below). MACE-S and ViSNet-S are converted from pre-trained MLIP models (please refer to InstaDeep's MLIP collection for details), while LiTEN-M is trained from scratch. Models and dataset can be downloaded through the huggingface-hub Python API:

from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="bhcao/dipm-pretrained-models", filename="mace_s_organics_mlip.safetensors", local_dir="")
hf_hub_download(repo_id="bhcao/dipm-pretrained-models", filename="visnet_s_organics_mlip.safetensors", local_dir="")
hf_hub_download(repo_id="bhcao/dipm-pretrained-models", filename="liten_m_organics_dipm.safetensors", local_dir="")
# hf_hub_download(repo_id="InstaDeepAI/nequip-organics", filename="nequip_organics_01.zip", local_dir="") # Broken
hf_hub_download(repo_id="InstaDeepAI/SPICE2-curated", filename="SPICE2_curated.zip", local_dir="")

Note that the pre-trained models are released under a different license than this library; please refer to the model cards of the relevant HuggingFace repos.

🚀 Inference time benchmarks

To showcase the runtime efficiency, we conducted benchmarks across three of the models (MACE, ViSNet, and NequIP) on two different systems: Chignolin (1UAO, 138 atoms) and Alpha-bungarotoxin (1ABT, 1205 atoms), both run for 1 ns of MD simulation on an NVIDIA H100 GPU. These JAX-based model implementations are our own and should not be considered representative of the performance of the code developed by the original authors of the methods. In the tables below, we compare our integrations with the JAX-MD and ASE simulation engines, respectively. Further details can be found in our white paper (see below).

MACE (2,139,152 parameters):

System   JAX-MD          ASE
1UAO     6.3 ms/step     11.6 ms/step
1ABT     66.8 ms/step    99.5 ms/step

ViSNet (1,137,922 parameters):

System   JAX-MD          ASE
1UAO     2.9 ms/step     6.2 ms/step
1ABT     25.4 ms/step    46.4 ms/step

NequIP (1,327,792 parameters):

System   JAX-MD          ASE
1UAO     3.8 ms/step     8.5 ms/step
1ABT     67.0 ms/step    105.7 ms/step
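
Timings like the ms/step values above can be measured for any jitted JAX function with a simple wall-clock pattern. The sketch below uses a dummy workload and is not dipm's actual benchmark harness; it only illustrates the compile-then-time approach:

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Dummy workload standing in for one MD step.
    return jnp.tanh(x @ x.T).sum() * x

x = jnp.ones((1205, 128))          # roughly the atom count of the 1ABT system
x = step(x).block_until_ready()    # first call triggers compilation; exclude it

n_steps = 100
start = time.perf_counter()
for _ in range(n_steps):
    x = step(x)
x.block_until_ready()              # wait for all asynchronous work to finish
print(f"{1e3 * (time.perf_counter() - start) / n_steps:.2f} ms/step")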

🙏 Acknowledgments

We would like to acknowledge beta testers for this library: Isabel Wilkinson, Nick Venanzi, Hassan Sirelkhatim, Leon Wehrhan, Sebastien Boyer, Massimo Bortone, Scott Cameron, Louis Robinson, Tom Barrett, and Alex Laterre.

📝 License

The upstream repository MLIP is licensed under the Apache 2.0 license. This repository is licensed under the GNU Lesser General Public License v3.0.

📚 Citing

Here is the citation entry for the original MLIP repository:

C. Brunken, O. Peltre, H. Chomet, L. Walewski, M. McAuliffe, V. Heyraud, S. Attias, M. Maarand, Y. Khanfir, E. Toledo, F. Falcioni, M. Bluntzer, S. Acosta-Gutiérrez and J. Tilly, Machine Learning Interatomic Potentials: library for efficient training, model development and simulation of molecular systems, arXiv, 2025, arXiv:2505.22397.
