Accelerated TensorNet with NVIDIA Warp Kernels #709
zubatyuk wants to merge 19 commits into materialyzeai:pure_torch
Conversation
Signed-off-by: Roman Zubatyuk <[email protected]>
@kenko911 you need to review this ASAP. I am not sure why the unit-test GitHub Actions are not running on this PR.
Hi @zubatyuk, thanks for the contributions! Could you please invite me to your repo? I think the first step is to update the repo to match the latest version of MatGL, and we need to write a series of unit tests for the new components.
Hi @zubatyuk, the issue appears during module import, when `generate_compose_tensor` is called, and results in the following error: (stack trace) `NameError: name 'dim' is not defined`. Any insight into how this should be handled (e.g., lazy kernel construction or annotation changes) would be greatly appreciated.
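One common way to avoid an import-time `NameError` in generated kernels is to defer kernel construction until the first call, caching one kernel per dimension. The sketch below is a hedged illustration of that pattern in plain Python (`make_compose_kernel` and its body are hypothetical stand-ins, not the actual Warp codegen in this PR):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def make_compose_kernel(dim: int):
    """Build the kernel lazily, so `dim` is bound when the factory is
    first called rather than at module import time (hypothetical sketch,
    not the PR's real Warp kernel generator)."""
    def kernel(x):
        # Placeholder body standing in for the generated Warp kernel.
        return [v * dim for v in x]
    return kernel

k3 = make_compose_kernel(3)
assert k3([1, 2]) == [3, 6]
assert make_compose_kernel(3) is k3  # cached: built once per dim
```

With this pattern, importing the module never triggers code generation, so names like `dim` only need to exist once a caller actually requests a kernel.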
@kenko911 can you point to the error? I did not write unit tests for the layers; those are very basic and only cover forward output shapes. I did end-to-end model tests instead.
@zubatyuk I also ran pytest on `test_tensornet_pyg.py` after cloning your repository and switching to the `pure_torch` branch, but I still encountered the same error:

```
(mavrl) kenko@KenkosMachine-100 matgl % pytest tests/models/test_tensornet_pyg.py
================================================================ ERRORS ================================================================
```
Warp uses … Also, linting introduced a few other problems with the kernels. Fixed them and added an ignore for "I002" in `src/matgl/kernels/*`.
Made a few fixes: 1. Fixed double-backward support in the tensor kernels. 2. Fixed the mypy configuration and type annotations.
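For double backward to work, the backward primitive must itself be differentiable. For a linear message-passing op `y = A @ x`, the vjp is `g -> A.T @ g`, and the vjp of that vjp is the forward map again, so double backward can reuse the forward kernel. A minimal NumPy sketch of this reasoning (illustrative only, not the actual kernel code in this PR):

```python
import numpy as np

A = np.array([[0.0, 2.0], [3.0, 0.0]])  # toy stand-in for the sparse operator

def forward(x):
    # y = A @ x (stands in for the Warp message-passing kernel)
    return A @ x

def backward(g):
    # vjp of forward: grad_x = A.T @ g
    return A.T @ g

def backward_of_backward(v):
    # backward is linear with matrix A.T, so its own vjp is (A.T).T = A:
    # double backward reuses the forward map.
    return A @ v

v = np.array([1.0, -1.0])
assert np.allclose(backward_of_backward(v), forward(v))
```

This is why registering the backward as a first-class differentiable op (rather than an opaque kernel call) is enough to support second-order gradients.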
Hello, I am a developer who uses the https://github.com/torchmd/torchmd-net version of TensorNet. I have been looking at these kernels, and they seem very useful! I have noticed a potential small error in the backward function of matgl/src/matgl/ops/tensornet_mp.py (lines 151 to 153 in 127d7a0). For symmetric `edge_attr` (the default in TensorNet, when they are based only on RBFs) I don't think there will be any error, but for non-symmetric `edge_attr` I find it gives incorrect gradients (observed via an exploding loss when trying to train).
…_bwd

The backward primitive was incorrectly using `row_*` (CSR) tensors instead of `col_*` (CSC) tensors. For gradient computation, the transpose of the forward sparse matrix is needed. This caused incorrect gradients for non-symmetric edge attributes.

Also updates the test script:
- Fix pymatgen API compatibility (`site.specie` -> `site.species_string`)
- Loosen parameter gradient thresholds from 1e-5 to 5e-5 for double backward
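The CSR-vs-CSC distinction above can be seen in a few lines of NumPy: the backward of a linear aggregation `y = A @ x` is `A.T @ g`, so traversing the forward (row/CSR) structure instead of the transposed (column/CSC) structure only happens to be correct when the edge weights are symmetric. A hedged toy illustration (not the PR's kernel code):

```python
import numpy as np

# Toy directed graph with non-symmetric edge weights (edge_attr).
src = np.array([0, 1, 2])
dst = np.array([1, 2, 0])
w = np.array([2.0, 3.0, 5.0])

A = np.zeros((3, 3))
A[dst, src] = w          # forward aggregation: y[dst] += w * x[src]

x = np.array([1.0, 10.0, 100.0])
y = A @ x                # forward pass

g = np.ones(3)           # upstream gradient dL/dy
grad_x_correct = A.T @ g # transpose (CSC traversal) -> correct gradient
grad_x_wrong = A @ g     # reusing the forward (CSR) structure -> wrong
                         # whenever A is not symmetric

assert not np.allclose(grad_x_correct, grad_x_wrong)
```

With symmetric weights (`A == A.T`) the two traversals coincide, which is why the bug only shows up for non-symmetric `edge_attr`.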
Thanks for catching this! I have pushed a fix in 3242b35 that switches these to use the `col_*` tensors.
Just wanted to say we are very interested in this as well! Looking forward to seeing it merged.
Now that the neighborlist PR is merged, we (@atulcthakur @kenko911) are going to work on moving this PR over to the main branch. The plan is to move over the newly introduced … Any and all suggestions are welcome.
This PR integrates GPU-accelerated custom kernels for TensorNet operations using NVIDIA Warp and updates the model implementation accordingly.
Summary of Changes
New Warp Kernel Infrastructure:
Added the `matgl.kernels` package with Warp kernel generators for key tensor operations:
- `compose_tensor` / `decompose_tensor` – tensor composition and decomposition operations
- `equivariant_o3_matmul` / `equivariant_so3_matmul` – equivariant matrix multiplication for the O(3) and SO(3) groups
- `tensor_norm3` – Frobenius norm computation for rank-3 tensors
- `tensornet_mp` – TensorNet message passing
- `tensornet_radial_mp` – radial message passing for tensor embedding
- `graph_transform` – utilities for converting edge indices to CSR format

PyTorch Custom Ops with Autograd Support:
- Kernels are wrapped with `torch.library.custom_op` for seamless integration

Model Refactoring:
- Updated `_tensornet_pyg.py` to use the new Warp-based operations
- Refactored the `TensorEmbedding` and `TensorNetInteraction` layers to use CSR-based message passing

Dependencies:
- Added `warp-lang>=10.1` to project dependencies

Testing:
- Added `dev/test_tensornet_forward_backward.py` for comparing forward/backward/double-backward passes with the reference implementation.
- Since we introduce custom backward kernels, `tests/models/test_tensornet_pyg.py` now also tests backward and double backward.

Technical Notes
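The `compose_tensor`/`decompose_tensor` pair corresponds to the standard TensorNet split of a 3×3 tensor into isotropic, antisymmetric, and symmetric-traceless parts. A plain-NumPy sketch of that decomposition (illustrative only; the function names stand in for the Warp kernels):

```python
import numpy as np

def decompose_tensor(X):
    # Isotropic part: trace/3 on the diagonal.
    I = (np.trace(X) / 3.0) * np.eye(3)
    # Antisymmetric part.
    A = 0.5 * (X - X.T)
    # Symmetric traceless part.
    S = 0.5 * (X + X.T) - I
    return I, A, S

def compose_tensor(I, A, S):
    # Inverse operation: the three parts sum back to the original tensor.
    return I + A + S

X = np.arange(9, dtype=float).reshape(3, 3)
I, A, S = decompose_tensor(X)
assert np.allclose(compose_tensor(I, A, S), X)
```

The Warp kernels batch this over all node/channel tensors at once; the per-tensor math is the same.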
- Kernels use Warp's code generation for type-specialized implementations (fp16/fp32/fp64)
- Message passing operations leverage the CSR sparse format for efficient neighbor aggregation
- The implementation maintains numerical equivalence with the original PyTorch operations
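A `graph_transform`-style conversion from COO edge indices to CSR can be sketched in a few lines of NumPy (a hedged illustration; `edge_index_to_csr` is a hypothetical stand-in for the PR's utilities):

```python
import numpy as np

def edge_index_to_csr(dst, num_nodes):
    # Sort edges by destination node, then count edges per node to build
    # CSR row pointers (hypothetical stand-in for graph_transform).
    order = np.argsort(dst, kind="stable")
    counts = np.bincount(dst, minlength=num_nodes)
    rowptr = np.concatenate(([0], np.cumsum(counts)))
    return rowptr, order

dst = np.array([1, 0, 1, 2])
rowptr, order = edge_index_to_csr(dst, 3)
assert rowptr.tolist() == [0, 1, 3, 4]  # node i's edges are order[rowptr[i]:rowptr[i+1]]
```

With this layout, each Warp thread can aggregate a node's incoming messages with a contiguous loop over `order[rowptr[i]:rowptr[i+1]]`, which is what makes the CSR form efficient for neighbor aggregation.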