
Accelerated TensorNet with NVIDIA Warp Kernels #709

Open
zubatyuk wants to merge 19 commits into materialyzeai:pure_torch from zubatyuk:pure_torch

Conversation

@zubatyuk (Contributor) commented Jan 9, 2026

This PR integrates GPU-accelerated custom kernels for TensorNet operations using NVIDIA Warp and updates the model implementation accordingly.

Summary of Changes

New Warp Kernel Infrastructure:

Added matgl.kernels package with Warp kernel generators for key tensor operations:

  • compose_tensor / decompose_tensor – Tensor composition and decomposition operations
  • equivariant_o3_matmul / equivariant_so3_matmul – Equivariant matrix multiplication for O(3) and SO(3) groups
  • tensor_norm3 – Frobenius norm computation for rank-3 tensors
  • tensornet_mp – TensorNet message passing
  • tensornet_radial_mp – Radial message passing for tensor embedding
  • graph_transform – Utilities for converting edge indices to CSR format
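
To make the generator pattern concrete, here is a minimal sketch: a factory closes over a Warp dtype and emits a type-specialized kernel. The `tensor_norm3` body below is illustrative only, not the PR's actual code.

```python
import warp as wp

def generate_tensor_norm3(dtype: str):
    # Illustrative factory: specialize one kernel per floating-point dtype.
    dtype_wp = {"float16": wp.float16, "float32": wp.float32, "float64": wp.float64}[dtype]

    @wp.kernel
    def tensor_norm3(x: wp.array(ndim=3, dtype=dtype_wp),      # [N, 3, 3]
                     out: wp.array(ndim=1, dtype=dtype_wp)):   # [N]
        i = wp.tid()
        acc = dtype_wp(0.0)
        for a in range(3):
            for b in range(3):
                acc += x[i, a, b] * x[i, a, b]
        out[i] = wp.sqrt(acc)

    return tensor_norm3
```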

PyTorch Custom Ops with Autograd Support:

Added matgl.ops package providing PyTorch-compatible wrappers for Warp kernels:

  • All ops registered via torch.library.custom_op for seamless integration
  • Full forward and backward pass support with custom autograd functions
  • Works with both CPU and CUDA devices
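
As a rough illustration of that registration pattern, here is a self-contained sketch with the forward written in plain PyTorch instead of a Warp launch; the op name `matgl::tensor_norm3` is only a stand-in:

```python
import torch

@torch.library.custom_op("matgl::tensor_norm3", mutates_args=())
def tensor_norm3(x: torch.Tensor) -> torch.Tensor:
    # Per-sample Frobenius norm of [N, 3, 3] tensors; in the PR this body
    # would launch the corresponding Warp kernel.
    return x.flatten(1).norm(dim=1)

@tensor_norm3.register_fake
def _(x):
    # Shape/dtype propagation so the op works under torch.compile and FakeTensor.
    return x.new_empty(x.shape[0])

def _setup_context(ctx, inputs, output):
    ctx.save_for_backward(inputs[0], output)

def _backward(ctx, grad_out):
    x, out = ctx.saved_tensors
    # d||X||_F / dX = X / ||X||_F
    return grad_out[:, None, None] * x / out.clamp_min(1e-12)[:, None, None]

tensor_norm3.register_autograd(_backward, setup_context=_setup_context)
```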

Model Refactoring:

  • Refactored TensorNet model in _tensornet_pyg.py to use new Warp-based operations
  • Replaced inline PyTorch operations with fused kernel calls for better performance
  • Updated TensorEmbedding and TensorNetInteraction layers to use CSR-based message passing

Dependencies:
Added warp-lang>=10.1 to project dependencies

Testing:
Added dev/test_tensornet_forward_backward.py for comparing forward/backward/double-backward passes with reference implementation.

Since we introduce custom backward kernels, tests/models/test_tensornet_pyg.py now also tests backward and double backward passes.
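
One standard way to exercise backward and double backward against a numeric reference is torch.autograd's gradcheck and gradgradcheck in double precision (illustrative only; frob_norm stands in for a Warp-backed op):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

def frob_norm(x):
    # Stand-in for a Warp-backed op under test.
    return x.flatten(1).norm(dim=1)

x = torch.randn(8, 3, 3, dtype=torch.float64, requires_grad=True)
assert gradcheck(frob_norm, (x,), eps=1e-6, atol=1e-5)       # first derivatives
assert gradgradcheck(frob_norm, (x,), eps=1e-6, atol=1e-5)   # double backward
```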

Technical Notes

Kernels use Warp's code generation for type-specialized implementations (fp16/fp32/fp64)

Message passing operations leverage CSR sparse format for efficient neighbor aggregation
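
For reference, converting COO edge indices to CSR amounts to sorting edges by destination node and taking a prefix sum of per-node edge counts; a minimal sketch (the helper name is illustrative, not necessarily the PR's graph_transform API):

```python
import torch

def edge_index_to_csr(edge_index: torch.Tensor, num_nodes: int):
    # edge_index: [2, E] in COO layout, row = destination, col = source.
    row, col = edge_index
    perm = torch.argsort(row)                      # group edges by destination node
    counts = torch.bincount(row, minlength=num_nodes)
    indptr = torch.zeros(num_nodes + 1, dtype=torch.long, device=row.device)
    indptr[1:] = torch.cumsum(counts, dim=0)       # row pointers via prefix sum
    return indptr, col[perm], perm                 # perm reorders per-edge data to match
```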

The implementation maintains numerical equivalence with the original PyTorch operations

Signed-off-by: Roman Zubatyuk <[email protected]>
@shyuep (Contributor) commented Jan 10, 2026

@kenko911 you need to review this asap. I am not sure why the unit-test GitHub Actions are not running on this PR.

@kenko911 (Collaborator) commented:
Hi @zubatyuk, thanks for the contributions! Could you please invite me to your repo? I think the first step is to update the repo to match the latest version of MatGL, and then we need to write a series of unit tests for the new components.

@kenko911 (Collaborator) commented:
Hi @zubatyuk,
I’ve started adding a few unit tests for the kernel components, but I consistently run into errors originating from the Warp library during test collection. Would you mind taking a look when you have a moment?

The issue appears during module import, when generate_compose_tensor is called, and results in the following error:

(stack trace)

NameError: name 'dim' is not defined

Any insight into how this should be handled (e.g., lazy kernel construction or annotation changes) would be greatly appreciated.

@zubatyuk (Contributor, Author) commented:
@kenko911 can you point me to the error? I did not write unit tests for the layers; those are very basic and only cover forward output shapes. I wrote end-to-end model tests instead.

@kenko911 (Collaborator) commented:
@zubatyuk I also ran pytest on test_tensornet_pyg.py after cloning your repository and switching to the pure_torch branch, but I still encountered the same error.

```
(mavrl) kenko@KenkosMachine-100 matgl % pytest tests/models/test_tensornet_pyg.py

================================================================ ERRORS ================================================================
_________________________________________ ERROR collecting tests/models/test_tensornet_pyg.py __________________________________________
tests/models/test_tensornet_pyg.py:14: in <module>
    from matgl.models._tensornet_pyg import TensorNet
src/matgl/models/__init__.py:16: in <module>
    from ._tensornet_pyg import TensorNet  # type: ignore[assignment]
src/matgl/models/_tensornet_pyg.py:31: in <module>
    from matgl.ops import (
src/matgl/ops/__init__.py:34: in <module>
    from .compose_tensor import fn_compose_tensor
src/matgl/ops/compose_tensor.py:34: in <module>
    from matgl.kernels import get_module, get_stream
src/matgl/kernels/__init__.py:34: in <module>
    from .compose_tensor import generate_compose_tensor
src/matgl/kernels/compose_tensor.py:230: in <module>
    ) = generate_compose_tensor("float64")
src/matgl/kernels/compose_tensor.py:208: in generate_compose_tensor
    wp.Kernel(
../../../miniconda3/envs/mavrl/lib/python3.11/site-packages/warp/_src/context.py:749: in __init__
    self.adj = warp._src.codegen.Adjoint(func, transformers=code_transformers, source=source)
../../../miniconda3/envs/mavrl/lib/python3.11/site-packages/warp/_src/codegen.py:1038: in __init__
    argspec = get_full_arg_spec(func)
../../../miniconda3/envs/mavrl/lib/python3.11/site-packages/warp/_src/codegen.py:246: in get_full_arg_spec
    return spec._replace(annotations=inspect.get_annotations(func, eval_str=True, locals=closure_vars))
../../../miniconda3/envs/mavrl/lib/python3.11/inspect.py:276: in get_annotations
    return_value = {key:
../../../miniconda3/envs/mavrl/lib/python3.11/inspect.py:277: in <dictcomp>
    value if not isinstance(value, str) else eval(value, globals, locals)
<string>:1: in <module>
    ???
E   NameError: name 'dim' is not defined
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
1 error in 0.82s
```

@zubatyuk (Contributor, Author) commented:
`ruff --fix` introduced `from __future__ import annotations`.

Warp uses inspect.get_annotations() to evaluate the function's type annotations. It tries to evaluate strings like wp.array(ndim=dim, dtype=dtype_wp). However, dim is a local variable defined inside generate_compose_tensor. When Python tries to evaluate the annotation string wp.array(ndim=dim, dtype=dtype_wp), it doesn't have access to the local variable dim from the enclosing function's scope — hence the NameError: name 'dim' is not defined.
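
A minimal repro of that failure mode (assuming a generator structured like the PR's; the NameError fires as soon as the annotations are evaluated):

```python
from __future__ import annotations  # turns annotations into strings

import inspect
import warp as wp

def generate_kernel(dim: int):
    def body(x: wp.array(ndim=dim, dtype=wp.float32)):
        pass
    # wp.Kernel() eventually does the equivalent of this call; evaluating the
    # string "wp.array(ndim=dim, ...)" cannot see the enclosing function's
    # local `dim`, so it raises NameError: name 'dim' is not defined.
    inspect.get_annotations(body, eval_str=True)
    return body

generate_kernel(2)
```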

Linting also introduced a few other problems with the kernels. Fixed those and added an ignore for "I002" in "src/matgl/kernels/*".

@zubatyuk (Contributor, Author) commented:
Done a few fixes:

1. Fix double backward support in tensor kernels
File: src/matgl/ops/equivariant_o3_matmul.py, src/matgl/ops/equivariant_so3_matmul.py, src/matgl/ops/tensornet_radial_mp.py
Description: Added initialization for grad_grad tensors in backward-backward pass of equivariant matmul operations to handle cases where gradients are None. Updated fake tensor registration for radial message passing to include the S component gradient.
Fixes: Issues with the double backward pass (e.g., force calculations) and autograd correctness; see the sketch after this list.

2. Fix mypy configuration and type annotations
File: pyproject.toml, src/matgl/models/_tensornet_pyg.py
Description: Configured mypy to ignore the matgl.kernels module which uses Warp DSL (incompatible with mypy). Fixed type annotation errors in _tensornet_pyg.py by updating error codes and adding missing ignores.
Fixes: Mypy validation errors and CI lint checks.
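
A minimal sketch of the pattern behind fix 1 (the assumed shape of the code, not the PR's exact implementation): None gradients are materialized as zeros before the Warp double-backward kernel is launched.

```python
import torch

def materialize_grads(grads, refs):
    # In backward-of-backward, grad_grad inputs arrive as None for tensors
    # that did not require grad; replace them with zero buffers so the Warp
    # kernel always receives concrete arrays.
    return [g if g is not None else torch.zeros_like(r) for g, r in zip(grads, refs)]

a, b = torch.randn(4, 3, 3), torch.randn(4, 3, 3)
gg_a, gg_b = materialize_grads([None, torch.ones_like(b)], [a, b])
```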

@sef43 commented Jan 27, 2026

Hello, I am a developer who uses the https://github.com/torchmd/torchmd-net version of TensorNet. I have been looking at these kernels, and they seem very useful!

I have noticed a potential small error. In the backward function of ops/tensor_mp.py, I think the col_* tensors should be used instead of the row_* ones:

```python
row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True)
row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True)
row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True)
```

For the case of symmetric edge_attr (the default in TensorNet when they are based only on RBFs) I don't think there will be any error, but for non-symmetric edge_attr I find it gives incorrect gradients (observed via an exploding loss when trying to train).
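
In dense notation (a toy check, not the PR's code): if the forward aggregation is y = A @ x, the input gradient is A.T @ grad_y, so the backward must traverse the transposed (CSC) structure; reusing the CSR rows computes A @ grad_y, which only coincides when A is symmetric.

```python
import torch

A = torch.tensor([[0., 1., 0.],
                  [0., 0., 2.],
                  [3., 0., 0.]])            # non-symmetric "edge weights"
x = torch.randn(3, requires_grad=True)
(A @ x).sum().backward()
assert torch.allclose(x.grad, A.T @ torch.ones(3))      # transpose is required
assert not torch.allclose(x.grad, A @ torch.ones(3))    # CSR-row traversal would give this
```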

zubatyuk and others added 2 commits January 28, 2026 18:44
…_bwd

The backward primitive was incorrectly using row_* (CSR) tensors instead of
col_* (CSC) tensors. For gradient computation, the transpose of the forward
sparse matrix is needed. This caused incorrect gradients for non-symmetric
edge attributes.

Also updates test script:
- Fix pymatgen API compatibility (site.specie -> site.species_string)
- Loosen parameter gradient thresholds from 1e-5 to 5e-5 for double backward
@zubatyuk (Contributor, Author) commented:
Thanks for catching this! I have pushed a fix in 3242b35 that switches these to use the col_* tensors.

@Andrew-S-Rosen (Contributor) commented:
Just wanted to say we are very interested in this as well! Looking forward to seeing it merged.

@atulcthakur (Contributor) commented:
Now that the neighborlist PR is merged, we (@atulcthakur @kenko911) are going to work on moving this PR over to the main branch.
We don't plan on maintaining a separate pure_torch branch, as the main branch models are already in torch.

The plan is to move over the newly introduced matgl.kernels and matgl.ops packages.
The kernels will become part of a new TensorNetWarp model, leaving the current TensorNet implementation untouched for future flexibility and development. All changes will be based on this PR by @zubatyuk.
We will discuss whether we need tests for each individual kernel introduced in matgl.kernels.

Any and all suggestions are welcome.
