From a54b0803bb3d17ef25059ea67cd0c4eebe35169a Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Wed, 17 Dec 2025 17:27:07 -0500 Subject: [PATCH 01/18] initial add nvtnet Signed-off-by: Roman Zubatyuk --- pyproject.toml | 1 + src/matgl/kernels/__init__.py | 56 ++ src/matgl/kernels/compose_tensor.py | 242 +++++++++ src/matgl/kernels/decompose_tensor.py | 232 ++++++++ src/matgl/kernels/equivariant_o3_matmul.py | 212 ++++++++ src/matgl/kernels/equivariant_so3_matmul.py | 203 +++++++ src/matgl/kernels/graph_transform.py | 70 +++ src/matgl/kernels/tensor_norm3.py | 217 ++++++++ src/matgl/kernels/tensornet_mp.py | 328 ++++++++++++ src/matgl/kernels/tensornet_radial_mp.py | 450 ++++++++++++++++ src/matgl/kernels/utils.py | 106 ++++ src/matgl/models/_tensornet_pyg.py | 461 +++++----------- src/matgl/ops/__init__.py | 52 ++ src/matgl/ops/compose_tensor.py | 228 ++++++++ src/matgl/ops/decompose_tensor.py | 237 +++++++++ src/matgl/ops/equivariant_o3_matmul.py | 230 ++++++++ src/matgl/ops/equivariant_so3_matmul.py | 226 ++++++++ src/matgl/ops/graph_transform.py | 176 +++++++ src/matgl/ops/tensor_norm3.py | 195 +++++++ src/matgl/ops/tensornet_mp.py | 554 ++++++++++++++++++++ src/matgl/ops/tensornet_radial_mp.py | 418 +++++++++++++++ 21 files changed, 4562 insertions(+), 332 deletions(-) create mode 100644 src/matgl/kernels/__init__.py create mode 100644 src/matgl/kernels/compose_tensor.py create mode 100644 src/matgl/kernels/decompose_tensor.py create mode 100644 src/matgl/kernels/equivariant_o3_matmul.py create mode 100644 src/matgl/kernels/equivariant_so3_matmul.py create mode 100644 src/matgl/kernels/graph_transform.py create mode 100644 src/matgl/kernels/tensor_norm3.py create mode 100644 src/matgl/kernels/tensornet_mp.py create mode 100644 src/matgl/kernels/tensornet_radial_mp.py create mode 100644 src/matgl/kernels/utils.py create mode 100644 src/matgl/ops/__init__.py create mode 100644 src/matgl/ops/compose_tensor.py create mode 100644 src/matgl/ops/decompose_tensor.py create mode 100644 src/matgl/ops/equivariant_o3_matmul.py create mode 100644 src/matgl/ops/equivariant_so3_matmul.py create mode 100644 src/matgl/ops/graph_transform.py create mode 100644 src/matgl/ops/tensor_norm3.py create mode 100644 src/matgl/ops/tensornet_mp.py create mode 100644 src/matgl/ops/tensornet_radial_mp.py diff --git a/pyproject.toml b/pyproject.toml index 90635ca3..24e78f7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ classifiers = [ dependencies = [ "ase", "torch<=2.7.0", # TODO: Remove this pin. For some reason, torch 2.9 gives different results. + "warp-lang>=10.1", "torchdata", "pymatgen", "lightning<=2.6.0.dev20251123", diff --git a/src/matgl/kernels/__init__.py b/src/matgl/kernels/__init__.py new file mode 100644 index 00000000..ee234854 --- /dev/null +++ b/src/matgl/kernels/__init__.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .compose_tensor import generate_compose_tensor +from .decompose_tensor import generate_decompose_tensor +from .equivariant_o3_matmul import generate_tensor_matmul_o3_3x3 +from .equivariant_so3_matmul import generate_tensor_matmul_so3_3x3 +from .graph_transform import count_row_col, convert_to_sparse +from .tensor_norm3 import generate_tensor_norm3 +from .tensornet_mp import generate_message_passing +from .tensornet_radial_mp import generate_radial_message_passing +from .utils import add_module, get_module, get_stream + +import warp as wp +wp.init() + + +__all__ = [ + "generate_compose_tensor", + "generate_decompose_tensor", + "generate_tensor_matmul_o3_3x3", + "generate_tensor_matmul_so3_3x3", + "generate_radial_message_passing", + "generate_message_passing", + "generate_tensor_norm3", + "count_row_col", + "convert_to_sparse", + "add_module", + "get_module", + "get_stream", +] \ No newline at end of file diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py new file mode 100644 index 00000000..91d46b3a --- /dev/null +++ b/src/matgl/kernels/compose_tensor.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_compose_tensor(dtype: str, h_last: bool = True, use_irmem: bool = True): + dtype_wp = get_wp_fp_dtype(dtype) + if not use_irmem: + raise ValueError(f"only supporting use_irmem True, but got {use_irmem}") + if not h_last: + raise ValueError(f"only supporting h_last True but got {h_last}") + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + if use_irmem: + dim = 3 + else: + dim = 4 + + def compose_tensor_fwd( + I: wp.array(ndim=dim, dtype=dtype_wp), + A: wp.array(ndim=dim, dtype=dtype_wp), + S: wp.array(ndim=dim, dtype=dtype_wp), + X: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + X_reg = mat3x3() + + I_reg = I[b, 0, h] + A_reg = vec3() + S_reg = vec5() + + for i in range(3): + A_reg[i] = A[b, i, h] + + for i in range(5): + S_reg[i] = S[b, i, h] + + for i in range(3): + X_reg[i, i] += I_reg + + cnt = int(0) + for i in range(3): + for j in range(i + 1, 3): + X_reg[i, j] += A_reg[cnt] + X_reg[j, i] -= A_reg[cnt] + cnt += 1 + + trace_S = -(S_reg[0] + S_reg[3]) + cnt = int(0) + for i in range(2): + X_reg[i, i] += S_reg[cnt] + cnt += 1 + for j in range(i + 1, 3): + X_reg[i, j] += S_reg[cnt] + X_reg[j, i] += S_reg[cnt] + cnt += 1 + + X_reg[2, 2] += trace_S + + for i in range(3): + for j in range(3): + X[b, i, j, h] = X_reg[i, j] + + def compose_tensor_bwd( + dX: wp.array(ndim=4, dtype=dtype_wp), + dI: wp.array(ndim=dim, dtype=dtype_wp), + dA: wp.array(ndim=dim, dtype=dtype_wp), + dS: wp.array(ndim=dim, dtype=dtype_wp), + ): + b, h = wp.tid() + + dX_reg = mat3x3() + for i in range(3): + for j in range(3): + dX_reg[i, j] = dX[b, i, j, h] + + dI_reg = dI.dtype(0) + dA_reg = vec3(dX.dtype(0)) + dS_reg = vec5(dX.dtype(0)) + + for i in range(3): + dI_reg += dX_reg[i, i] + + cnt = int(0) + for i in range(3): + for j in range(i + 1, 3): + dA_reg[cnt] += dX_reg[i, j] + dA_reg[cnt] -= dX_reg[j, i] + cnt += int(1) + + dS_reg[0] += dX_reg[0, 0] + dS_reg[0] -= dX_reg[2, 2] + + dS_reg[1] += dX_reg[0, 1] + dS_reg[1] += dX_reg[1, 0] + + dS_reg[2] += dX_reg[0, 2] + dS_reg[2] += dX_reg[2, 0] + + dS_reg[3] += dX_reg[1, 1] + dS_reg[3] -= dX_reg[2, 2] + + dS_reg[4] += dX_reg[1, 2] + dS_reg[4] += dX_reg[2, 1] + + dI[b, 0, h] = dI_reg + + for i in range(3): + dA[b, i, h] = dA_reg[i] + + for i in range(5): + dS[b, i, h] = dS_reg[i] + + def compose_tensor_bwd_bwd( + dI: wp.array(ndim=dim, dtype=dtype_wp), + dA: wp.array(ndim=dim, dtype=dtype_wp), + dS: wp.array(ndim=dim, dtype=dtype_wp), + d2X: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + d2X_reg = mat3x3() + + dI_reg = dI[b, 0, h] + dA_reg = vec3(dI.dtype(0)) + dS_reg = vec5(dI.dtype(0)) + + for i in range(3): + dA_reg[i] = dA[b, i, h] + + for i in range(5): + dS_reg[i] = dS[b, i, h] + + for i in range(3): + d2X_reg[i, i] += dI_reg + + cnt = int(0) + for i in range(3): + for j in 
range(i + 1, 3): + d2X_reg[i, j] += dA_reg[cnt] + d2X_reg[j, i] -= dA_reg[cnt] + cnt += int(1) + + cnt = int(0) + for i in range(2): + d2X_reg[i, i] += dS_reg[cnt] + cnt += int(1) + + for j in range(i + 1, 3): + d2X_reg[i, j] += dS_reg[cnt] + d2X_reg[j, i] += dS_reg[cnt] + cnt += int(1) + + d2X_reg[2, 2] -= dS_reg[0] + d2X_reg[2, 2] -= dS_reg[3] + + for i in range(3): + for j in range(3): + d2X[b, i, j, h] = d2X_reg[i, j] + + return ( + wp.Kernel( + compose_tensor_fwd, + key=f"compose_tensor_{dtype}", + module=wp.get_module(f"compose_tensor_{dtype}"), + ), + wp.Kernel( + compose_tensor_bwd, + key=f"compose_tensor_bwd_{dtype}", + module=wp.get_module(f"compose_tensor_bwd_{dtype}"), + ), + wp.Kernel( + compose_tensor_bwd_bwd, + key=f"compose_tensor_bwd_bwd_{dtype}", + module=wp.get_module(f"compose_tensor_bwd_bwd_{dtype}"), + ), + ) + + +( + compose_tensor_fwd_fp64, + compose_tensor_bwd_fp64, + compose_tensor_bwd_bwd_fp64, +) = generate_compose_tensor("float64") +( + compose_tensor_fwd_fp32, + compose_tensor_bwd_fp32, + compose_tensor_bwd_bwd_fp32, +) = generate_compose_tensor("float32") +( + compose_tensor_fwd_fp16, + compose_tensor_bwd_fp16, + compose_tensor_bwd_bwd_fp16, +) = generate_compose_tensor("float16") + +add_module("compose_tensor_fwd", ["float64"], compose_tensor_fwd_fp64) +add_module("compose_tensor_bwd", ["float64"], compose_tensor_bwd_fp64) +add_module("compose_tensor_bwd_bwd", ["float64"], compose_tensor_bwd_bwd_fp64) + +add_module("compose_tensor_fwd", ["float32"], compose_tensor_fwd_fp32) +add_module("compose_tensor_bwd", ["float32"], compose_tensor_bwd_fp32) +add_module("compose_tensor_bwd_bwd", ["float32"], compose_tensor_bwd_bwd_fp32) + +add_module("compose_tensor_fwd", ["float16"], compose_tensor_fwd_fp16) +add_module("compose_tensor_bwd", ["float16"], compose_tensor_bwd_fp16) +add_module("compose_tensor_bwd_bwd", ["float16"], compose_tensor_bwd_bwd_fp16) diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py new file mode 100644 index 00000000..685b6f7f --- /dev/null +++ b/src/matgl/kernels/decompose_tensor.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_decompose_tensor(dtype: str, h_last: bool = True, use_irmem: bool = True): + dtype_wp = get_wp_fp_dtype(dtype) + + if not use_irmem: + raise ValueError(f"only supporting use_irmem True, but got {use_irmem}") + + if not h_last: + raise ValueError(f"only supporting h_last True but got {h_last}") + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + if use_irmem: + dim = 3 + else: + dim = 4 + + def decompose_tensor_fwd( + X: wp.array(ndim=4, dtype=dtype_wp), + I: wp.array(ndim=dim, dtype=dtype_wp), + A: wp.array(ndim=dim, dtype=dtype_wp), + S: wp.array(ndim=dim, dtype=dtype_wp), + ): + b, h = wp.tid() + + X_reg = mat3x3() + for i in range(3): + for j in range(3): + X_reg[i, j] = X[b, i, j, h] + + res = X.dtype(0) + for i in range(3): + res += X_reg[i, i] + res = res / X.dtype(3.0) + + I[b, 0, h] = res + + denom = X.dtype(2.0) + cnt = int(0) + for i in range(2): + for j in range(i + 1, 3): + A[b, cnt, h] = (X_reg[i, j] - X_reg[j, i]) / denom + cnt += int(1) + + cnt = int(0) + for i in range(2): + S[b, cnt, h] = X_reg[i, i] - res + cnt += int(1) + + for j in range(i + 1, 3): + S[b, cnt, h] = (X_reg[i, j] + X_reg[j, i]) / denom + cnt += int(1) + + def decompose_tensor_bwd( + dI: wp.array(ndim=dim, dtype=dtype_wp), + dA: wp.array(ndim=dim, dtype=dtype_wp), + dS: wp.array(ndim=dim, dtype=dtype_wp), + dX: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + dX_reg = mat3x3(dX.dtype(0)) + + dI_reg = dI[b, 0, h] + dA_reg = vec3(dX.dtype(0)) + dS_reg = vec5(dX.dtype(0)) + + for i in range(3): + dA_reg[i] = dA[b, i, h] + + for i in range(5): + dS_reg[i] = dS[b, i, h] + + for i in range(3): + dX_reg[i, i] = dI_reg / dI.dtype(3.0) + + denom = dX.dtype(2.0) + + cnt = int(0) + + for i in range(3): + for j in range(i + 1, 3): + dX_reg[i, j] += dA_reg[cnt] / denom + dX_reg[j, i] -= dA_reg[cnt] / denom + + cnt += int(1) + + cnt = int(0) + for i in range(2): + dX_reg[i, i] += dS_reg[cnt] + for j in range(3): + dX_reg[j, j] -= dS_reg[cnt] / dI.dtype(3.0) + + cnt += int(1) + + for j in range(i + 1, 3): + dX_reg[i, j] += dS_reg[cnt] / denom + dX_reg[j, i] += dS_reg[cnt] / denom + cnt += int(1) + + for i in range(3): + for j in range(3): + dX[b, i, j, h] = dX_reg[i, j] + + def decompose_tensor_bwd_bwd( + dX: wp.array(ndim=4, dtype=dtype_wp), + d2I: wp.array(ndim=dim, dtype=dtype_wp), + d2A: wp.array(ndim=dim, dtype=dtype_wp), + d2S: wp.array(ndim=dim, dtype=dtype_wp), + ): + b, h = wp.tid() + + dX_reg = mat3x3(dX.dtype(0)) + d2I_reg = dX.dtype(0) + d2A_reg = vec3(dX.dtype(0)) + d2S_reg = vec5(dX.dtype(0)) + + for i in range(3): + for j in range(3): + dX_reg[i, j] = dX[b, i, j, h] + + for i in range(3): + d2I_reg += dX_reg[i, i] / d2I.dtype(3.0) + + denom = dX.dtype(2.0) + + cnt = int(0) + + for i in range(3): + for j in range(i + 1, 3): + 
d2A_reg[cnt] += dX_reg[i, j] / denom + d2A_reg[cnt] -= dX_reg[j, i] / denom + cnt += int(1) + + cnt = int(0) + for i in range(2): + d2S_reg[cnt] += dX_reg[i, i] + for j in range(3): + d2S_reg[cnt] -= dX_reg[j, j] / d2I.dtype(3.0) + cnt += int(1) + + for j in range(i + 1, 3): + d2S_reg[cnt] += dX_reg[i, j] / denom + d2S_reg[cnt] += dX_reg[j, i] / denom + cnt += int(1) + + d2I[b, 0, h] = d2I_reg + for i in range(3): + d2A[b, i, h] = d2A_reg[i] + + for i in range(5): + d2S[b, i, h] = d2S_reg[i] + + return ( + wp.Kernel( + decompose_tensor_fwd, + key=f"decompose_tensor_{dtype}", + module=wp.get_module(f"decompose_tensor_{dtype}"), + ), + wp.Kernel( + decompose_tensor_bwd, + key=f"decompose_tensor_bwd_{dtype}", + module=wp.get_module(f"decompose_tensor_bwd_{dtype}"), + ), + wp.Kernel( + decompose_tensor_bwd_bwd, + key=f"decompose_tensor_bwd_bwd_{dtype}", + module=wp.get_module(f"decompose_tensor_bwd_bwd_{dtype}"), + ), + ) + + +decompose_tensor_fwd_fp64, decompose_tensor_bwd_fp64, decompose_tensor_bwd_bwd_fp64 = ( + generate_decompose_tensor("float64") +) +decompose_tensor_fwd_fp32, decompose_tensor_bwd_fp32, decompose_tensor_bwd_bwd_fp32 = ( + generate_decompose_tensor("float32") +) +decompose_tensor_fwd_fp16, decompose_tensor_bwd_fp16, decompose_tensor_bwd_bwd_fp16 = ( + generate_decompose_tensor("float16") +) + +add_module("decompose_tensor_fwd", ["float64"], decompose_tensor_fwd_fp64) +add_module("decompose_tensor_bwd", ["float64"], decompose_tensor_bwd_fp64) +add_module("decompose_tensor_bwd_bwd", ["float64"], decompose_tensor_bwd_bwd_fp64) + +add_module("decompose_tensor_fwd", ["float32"], decompose_tensor_fwd_fp32) +add_module("decompose_tensor_bwd", ["float32"], decompose_tensor_bwd_fp32) +add_module("decompose_tensor_bwd_bwd", ["float32"], decompose_tensor_bwd_bwd_fp32) + +add_module("decompose_tensor_fwd", ["float16"], decompose_tensor_fwd_fp16) +add_module("decompose_tensor_bwd", ["float16"], decompose_tensor_bwd_fp16) +add_module("decompose_tensor_bwd_bwd", ["float16"], decompose_tensor_bwd_bwd_fp16) diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py new file mode 100644 index 00000000..dcf027b0 --- /dev/null +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -0,0 +1,212 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_tensor_matmul_o3_3x3(dtype: str): + dtype_wp = get_wp_fp_dtype(dtype) + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + def tensor_matmul_o3_3x3_fwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + C: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + c_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + c_reg[i, j] += a_reg[i, k] * b_reg[k, j] + b_reg[i, k] * a_reg[k, j] + + for i in range(3): + for j in range(3): + C[b, i, j, h] = c_reg[i, j] + + def tensor_matmul_o3_3x3_bwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + dC: wp.array(ndim=4, dtype=dtype_wp), + dA: wp.array(ndim=4, dtype=dtype_wp), + dB: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + + da_reg = mat3x3() + db_reg = mat3x3() + + dc_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + dc_reg[i, j] = dC[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + da_reg[i, j] += dc_reg[i, k] * b_reg[j, k] + da_reg[j, k] += dc_reg[i, k] * b_reg[i, j] + db_reg[i, j] += dc_reg[i, k] * a_reg[j, k] + db_reg[j, k] += dc_reg[i, k] * a_reg[i, j] + + for i in range(3): + for j in range(3): + dA[b, i, j, h] = da_reg[i, j] + dB[b, i, j, h] = db_reg[i, j] + + def tensor_matmul_o3_3x3_bwd_bwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + dA: wp.array(ndim=4, dtype=dtype_wp), + dB: wp.array(ndim=4, dtype=dtype_wp), + dC: wp.array(ndim=4, dtype=dtype_wp), + d2A: wp.array(ndim=4, dtype=dtype_wp), + d2B: wp.array(ndim=4, dtype=dtype_wp), + d2C: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + + da_reg = mat3x3() + db_reg = mat3x3() + + dc_reg = mat3x3() + + d2a_reg = mat3x3() + d2b_reg = mat3x3() + + d2c_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + + da_reg[i, j] = dA[b, i, j, h] + db_reg[i, j] = dB[b, i, j, h] + + dc_reg[i, j] = dC[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + d2a_reg[i, j] += dc_reg[i, k] * db_reg[j, k] + d2a_reg[j, i] += dc_reg[k, i] * db_reg[k, j] + + d2b_reg[i, j] += dc_reg[i, k] * da_reg[j, k] + d2b_reg[j, i] += dc_reg[k, i] * da_reg[k, j] + + for i in range(3): + for j in range(3): + for k in range(3): + # grad_grad_x @ y + x @ grad_grad_y + d2c_reg[i, j] += da_reg[i, k] * b_reg[k, j] + d2c_reg[i, j] += a_reg[i, k] * db_reg[k, j] + + d2c_reg[i, j] += db_reg[i, k] * a_reg[k, j] + d2c_reg[i, j] += b_reg[i, k] * da_reg[k, j] + + for i in range(3): + for j in range(3): + d2A[b, i, j, h] = d2a_reg[i, j] + d2B[b, i, j, 
h] = d2b_reg[i, j] + d2C[b, i, j, h] = d2c_reg[i, j] + + return ( + wp.Kernel( + tensor_matmul_o3_3x3_fwd, + key=f"tensor_matmul_o3_3x3_{dtype}", + module=wp.get_module(f"tensor_matmul_o3_3x3_{dtype}"), + ), + wp.Kernel( + tensor_matmul_o3_3x3_bwd, + key=f"tensor_matmul_o3_3x3_bwd_{dtype}", + module=wp.get_module(f"tensor_matmul_o3_3x3_bwd_{dtype}"), + ), + wp.Kernel( + tensor_matmul_o3_3x3_bwd_bwd, + key=f"tensor_matmul_o3_3x3_bwd_bwd_{dtype}", + module=wp.get_module(f"tensor_matmul_o3_3x3_bwd_bwd_{dtype}"), + ), + ) + + +( + tensor_matmul_o3_3x3_fwd_fp64, + tensor_matmul_o3_3x3_bwd_fp64, + tensor_matmul_o3_3x3_bwd_bwd_fp64, +) = generate_tensor_matmul_o3_3x3("float64") +( + tensor_matmul_o3_3x3_fwd_fp32, + tensor_matmul_o3_3x3_bwd_fp32, + tensor_matmul_o3_3x3_bwd_bwd_fp32, +) = generate_tensor_matmul_o3_3x3("float32") +( + tensor_matmul_o3_3x3_fwd_fp16, + tensor_matmul_o3_3x3_bwd_fp16, + tensor_matmul_o3_3x3_bwd_bwd_fp16, +) = generate_tensor_matmul_o3_3x3("float16") + +add_module("tensor_matmul_o3_3x3_fwd", ["float64"], tensor_matmul_o3_3x3_fwd_fp64) +add_module("tensor_matmul_o3_3x3_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_fp64) +add_module( + "tensor_matmul_o3_3x3_bwd_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_bwd_fp64 +) + +add_module("tensor_matmul_o3_3x3_fwd", ["float32"], tensor_matmul_o3_3x3_fwd_fp32) +add_module("tensor_matmul_o3_3x3_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_fp32) +add_module( + "tensor_matmul_o3_3x3_bwd_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_bwd_fp32 +) + +add_module("tensor_matmul_o3_3x3_fwd", ["float16"], tensor_matmul_o3_3x3_fwd_fp16) +add_module("tensor_matmul_o3_3x3_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_fp16) +add_module( + "tensor_matmul_o3_3x3_bwd_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_bwd_fp16 +) diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py new file mode 100644 index 00000000..c50d0258 --- /dev/null +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_tensor_matmul_so3_3x3(dtype: str): + dtype_wp = get_wp_fp_dtype(dtype) + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + def tensor_matmul_so3_3x3_fwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + C: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + c_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + c_reg[i, j] += a_reg[i, k] * b_reg[k, j] + + for i in range(3): + for j in range(3): + C[b, i, j, h] = c_reg[i, j] + + def tensor_matmul_so3_3x3_bwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + dC: wp.array(ndim=4, dtype=dtype_wp), + dA: wp.array(ndim=4, dtype=dtype_wp), + dB: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + + da_reg = mat3x3() + db_reg = mat3x3() + + dc_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + dc_reg[i, j] = dC[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + da_reg[i, k] += dc_reg[i, j] * b_reg[k, j] + db_reg[k, j] += dc_reg[i, j] * a_reg[i, k] + + for i in range(3): + for j in range(3): + dA[b, i, j, h] = da_reg[i, j] + dB[b, i, j, h] = db_reg[i, j] + + def tensor_matmul_so3_3x3_bwd_bwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + dA: wp.array(ndim=4, dtype=dtype_wp), + dB: wp.array(ndim=4, dtype=dtype_wp), + dC: wp.array(ndim=4, dtype=dtype_wp), + d2A: wp.array(ndim=4, dtype=dtype_wp), + d2B: wp.array(ndim=4, dtype=dtype_wp), + d2C: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + + da_reg = mat3x3() + db_reg = mat3x3() + + dc_reg = mat3x3() + + d2a_reg = mat3x3() + d2b_reg = mat3x3() + + d2c_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + + da_reg[i, j] = dA[b, i, j, h] + db_reg[i, j] = dB[b, i, j, h] + + dc_reg[i, j] = dC[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + d2a_reg[i, k] += dc_reg[i, j] * db_reg[k, j] + d2b_reg[k, j] += dc_reg[i, j] * da_reg[i, k] + + for i in range(3): + for j in range(3): + for k in range(3): + d2c_reg[i, j] += da_reg[i, k] * b_reg[k, j] + d2c_reg[i, j] += a_reg[i, k] * db_reg[k, j] + + for i in range(3): + for j in range(3): + d2A[b, i, j, h] = d2a_reg[i, j] + d2B[b, i, j, h] = d2b_reg[i, j] + d2C[b, i, j, h] = d2c_reg[i, j] + + return ( + wp.Kernel( + tensor_matmul_so3_3x3_fwd, + key=f"tensor_matmul_so3_3x3_{dtype}", + module=wp.get_module(f"tensor_matmul_so3_3x3_{dtype}"), + ), + wp.Kernel( + tensor_matmul_so3_3x3_bwd, + key=f"tensor_matmul_so3_3x3_bwd_{dtype}", + 
module=wp.get_module(f"tensor_matmul_o3_3x3_bwd_{dtype}"), + ), + wp.Kernel( + tensor_matmul_so3_3x3_bwd_bwd, + key=f"tensor_matmul_so3_3x3_bwd_bwd_{dtype}", + module=wp.get_module(f"tensor_matmul_so3_3x3_bwd_bwd_{dtype}"), + ), + ) + + +( + tensor_matmul_so3_3x3_fwd_fp64, + tensor_matmul_so3_3x3_bwd_fp64, + tensor_matmul_so3_3x3_bwd_bwd_fp64, +) = generate_tensor_matmul_so3_3x3("float64") +( + tensor_matmul_so3_3x3_fwd_fp32, + tensor_matmul_so3_3x3_bwd_fp32, + tensor_matmul_so3_3x3_bwd_bwd_fp32, +) = generate_tensor_matmul_so3_3x3("float32") +( + tensor_matmul_so3_3x3_fwd_fp16, + tensor_matmul_so3_3x3_bwd_fp16, + tensor_matmul_so3_3x3_bwd_bwd_fp16, +) = generate_tensor_matmul_so3_3x3("float16") + +add_module("tensor_matmul_so3_3x3_fwd", ["float64"], tensor_matmul_so3_3x3_fwd_fp64) +add_module("tensor_matmul_so3_3x3_bwd", ["float64"], tensor_matmul_so3_3x3_bwd_fp64) +add_module( + "tensor_matmul_so3_3x3_bwd_bwd", ["float64"], tensor_matmul_so3_3x3_bwd_bwd_fp64 +) + +add_module("tensor_matmul_so3_3x3_fwd", ["float32"], tensor_matmul_so3_3x3_fwd_fp32) +add_module("tensor_matmul_so3_3x3_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_fp32) +add_module( + "tensor_matmul_so3_3x3_bwd_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_bwd_fp32 +) + +add_module("tensor_matmul_so3_3x3_fwd", ["float16"], tensor_matmul_so3_3x3_fwd_fp16) +add_module("tensor_matmul_so3_3x3_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_fp16) +add_module( + "tensor_matmul_so3_3x3_bwd_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_bwd_fp16 +) diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py new file mode 100644 index 00000000..655bb73c --- /dev/null +++ b/src/matgl/kernels/graph_transform.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import warp as wp + + +@wp.kernel +def count_row_col( + edge_index: wp.array(ndim=2, dtype=wp.int32), + row_count: wp.array(ndim=1, dtype=wp.int32), + col_count: wp.array(ndim=1, dtype=wp.int32), +): + tid = wp.tid() + + shift = edge_index.dtype(1) + wp.atomic_add(row_count, edge_index[0, tid] + shift, wp.int32(1)) + wp.atomic_add(col_count, edge_index[1, tid] + shift, wp.int32(1)) + + +@wp.kernel +def convert_to_sparse( + edge_index: wp.array(ndim=2, dtype=wp.int32), + row_count: wp.array(ndim=1, dtype=wp.int32), + col_count: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + col_indptr: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + col_indices: wp.array(ndim=1, dtype=wp.int32), + row_data: wp.array(ndim=1, dtype=wp.int32), + col_data: wp.array(ndim=1, dtype=wp.int32), +): + tid = wp.tid() + shift = edge_index.dtype(1) + + src_id = edge_index[0, tid] + dst_id = edge_index[1, tid] + + src_cnt = wp.atomic_sub(row_count, src_id + shift, wp.int32(1)) + dst_cnt = wp.atomic_sub(col_count, dst_id + shift, wp.int32(1)) + + row_indices[row_indptr[src_id + shift] - src_cnt] = dst_id + row_data[row_indptr[src_id + shift] - src_cnt] = wp.int32(tid) + + col_indices[col_indptr[dst_id + shift] - dst_cnt] = src_id + col_data[col_indptr[dst_id + shift] - dst_cnt] = wp.int32(tid) diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py new file mode 100644 index 00000000..463089c6 --- /dev/null +++ b/src/matgl/kernels/tensor_norm3.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
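+# tensor_norm3: for each (item, channel) pair the forward kernel computes the squared norms of the
+# three irreducible parts of a 3x3 tensor X (isotropic/trace, antisymmetric, traceless symmetric)
+# and stacks them along the feature dimension of `output`; the bwd and bwd_bwd kernels provide the
+# hand-written gradients used for (double) backpropagation.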
+ +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_tensor_norm3(dtype: str, h_last: bool = True, use_irmem: bool = True): + dtype_wp = get_wp_fp_dtype(dtype) + + if not use_irmem: + raise ValueError(f"only supporting use_irmem True, but got {use_irmem}") + + if not h_last: + raise ValueError(f"only supporting h_last True but got {h_last}") + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + def tensor_norm3_fwd( + X: wp.array(ndim=4, dtype=dtype_wp), + output: wp.array(ndim=2, dtype=dtype_wp), + ): + b, h = wp.tid() + + x00 = X[b, 0, 0, h] + x01 = X[b, 0, 1, h] + x02 = X[b, 0, 2, h] + x10 = X[b, 1, 0, h] + x11 = X[b, 1, 1, h] + x12 = X[b, 1, 2, h] + x20 = X[b, 2, 0, h] + x21 = X[b, 2, 1, h] + x22 = X[b, 2, 2, h] + + one_half = X.dtype(0.5) + one_third = X.dtype(1.0 / 3.0) + + trace = x00 + x11 + x22 + trace_third = trace / X.dtype(3.0) + norm2_i = one_third * trace * trace + norm2_a = one_half * ((x01 - x10) * (x01 - x10) + (x02 - x20) * (x02 - x20) + (x12 - x21) * (x12 - x21)) + norm2_s = one_half * ( + (x01 + x10) * (x01 + x10) + + (x02 + x20) * (x02 + x20) + + (x12 + x21) * (x12 + x21) + ) + (x00 - trace_third) * (x00 - trace_third) + (x11 - trace_third) * (x11 - trace_third) + (x22 - trace_third) * (x22 - trace_third) + + output[b, h] = norm2_i + output[b, h + X.shape[3]] = norm2_a + output[b, h + 2 * X.shape[3]] = norm2_s + + def tensor_norm3_bwd( + grad_output: wp.array(ndim=2, dtype=dtype_wp), + X: wp.array(ndim=4, dtype=dtype_wp), + grad_X: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + grad_i = grad_output[b, h] + grad_a = grad_output[b, h + X.shape[3]] + grad_s = grad_output[b, h + 2 * X.shape[3]] + + x00 = X[b, 0, 0, h] + x01 = X[b, 0, 1, h] + x02 = X[b, 0, 2, h] + x10 = X[b, 1, 0, h] + x11 = X[b, 1, 1, h] + x12 = X[b, 1, 2, h] + x20 = X[b, 2, 0, h] + x21 = X[b, 2, 1, h] + x22 = X[b, 2, 2, h] + + trace = x00 + x11 + x22 + trace_third = trace / X.dtype(3.0) + + diag_grad_i = X.dtype(2.0 / 3.0) * trace * grad_i + + dev00 = x00 - trace_third + dev11 = x11 - trace_third + dev22 = x22 - trace_third + + c4_3 = X.dtype(4.0) / X.dtype(3.0) + c2_3 = X.dtype(2.0) / X.dtype(3.0) + + grad_s_term_00 = c4_3 * dev00 - c2_3 * dev11 - c2_3 * dev22 + grad_s_term_11 = c4_3 * dev11 - c2_3 * dev00 - c2_3 * dev22 + grad_s_term_22 = c4_3 * dev22 - c2_3 * dev00 - c2_3 * dev11 + + grad_X[b, 0, 0, h] = diag_grad_i + grad_s * grad_s_term_00 + grad_X[b, 1, 1, h] = diag_grad_i + grad_s * grad_s_term_11 + grad_X[b, 2, 2, h] = diag_grad_i + grad_s * grad_s_term_22 + + diff01 = x01 - x10 + sum01 = x01 + x10 + grad_X[b, 0, 1, h] = grad_a * diff01 + grad_s * sum01 + grad_X[b, 1, 0, h] = -grad_a * diff01 + grad_s * sum01 + + diff02 = x02 - x20 + sum02 = x02 + x20 + grad_X[b, 0, 2, h] = grad_a * diff02 + grad_s * sum02 + grad_X[b, 2, 0, h] = -grad_a * diff02 + grad_s * sum02 + + diff12 = x12 - x21 + sum12 = x12 + x21 + grad_X[b, 1, 2, h] = grad_a * diff12 + grad_s * sum12 + grad_X[b, 2, 1, h] = -grad_a * diff12 + grad_s * sum12 + + def tensor_norm3_bwd_bwd( + grad_grad_X: wp.array(ndim=4, dtype=dtype_wp), + grad_grad_output: wp.array(ndim=2, dtype=dtype_wp), + ): + b, h = wp.tid() + + gg00 = grad_grad_X[b, 0, 0, h] + gg01 = grad_grad_X[b, 0, 1, h] + gg02 = grad_grad_X[b, 0, 2, h] + gg10 = grad_grad_X[b, 1, 0, h] + gg11 = grad_grad_X[b, 1, 1, h] + gg12 = grad_grad_X[b, 1, 2, h] + gg20 = grad_grad_X[b, 2, 0, h] + gg21 = grad_grad_X[b, 2, 1, h] + gg22 = grad_grad_X[b, 2, 2, h] + + trace_gg = gg00 + gg11 + gg22 + trace_third_gg = trace_gg / 
grad_grad_X.dtype(3.0) + + one_half = grad_grad_X.dtype(0.5) + one_third = grad_grad_X.dtype(1.0 / 3.0) + + norm2_i_gg = one_third * trace_gg * trace_gg + grad_grad_output[b, h] = norm2_i_gg + + diff01 = gg01 - gg10 + diff02 = gg02 - gg20 + diff12 = gg12 - gg21 + norm2_a_gg = one_half * (diff01 * diff01 + diff02 * diff02 + diff12 * diff12) + grad_grad_output[b, h + grad_grad_X.shape[3]] = norm2_a_gg + + sum01 = gg01 + gg10 + sum02 = gg02 + gg20 + sum12 = gg12 + gg21 + + dev00 = gg00 - trace_third_gg + dev11 = gg11 - trace_third_gg + dev22 = gg22 - trace_third_gg + + norm2_s_gg = one_half * (sum01 * sum01 + sum02 * sum02 + sum12 * sum12) + norm2_s_gg += dev00 * dev00 + dev11 * dev11 + dev22 * dev22 + grad_grad_output[b, h + 2 * grad_grad_X.shape[3]] = norm2_s_gg + + return ( + wp.Kernel( + tensor_norm3_fwd, + key=f"tensor_norm3_fwd_{dtype}", + module=wp.get_module(f"tensor_norm3_fwd_{dtype}"), + ), + wp.Kernel( + tensor_norm3_bwd, + key=f"tensor_norm3_bwd_{dtype}", + module=wp.get_module(f"tensor_norm3_bwd_{dtype}"), + ), + wp.Kernel( + tensor_norm3_bwd_bwd, + key=f"tensor_norm3_bwd_bwd_{dtype}", + module=wp.get_module(f"tensor_norm3_bwd_bwd_{dtype}"), + ), + ) + + +tensor_norm3_fwd_fp64, tensor_norm3_bwd_fp64, tensor_norm3_bwd_bwd_fp64 = ( + generate_tensor_norm3("float64") +) +tensor_norm3_fwd_fp32, tensor_norm3_bwd_fp32, tensor_norm3_bwd_bwd_fp32 = ( + generate_tensor_norm3("float32") +) +tensor_norm3_fwd_fp16, tensor_norm3_bwd_fp16, tensor_norm3_bwd_bwd_fp16 = ( + generate_tensor_norm3("float16") +) + +add_module("tensor_norm3_fwd", ["float64"], tensor_norm3_fwd_fp64) +add_module("tensor_norm3_bwd", ["float64"], tensor_norm3_bwd_fp64) +add_module("tensor_norm3_bwd_bwd", ["float64"], tensor_norm3_bwd_bwd_fp64) + +add_module("tensor_norm3_fwd", ["float32"], tensor_norm3_fwd_fp32) +add_module("tensor_norm3_bwd", ["float32"], tensor_norm3_bwd_fp32) +add_module("tensor_norm3_bwd_bwd", ["float32"], tensor_norm3_bwd_bwd_fp32) + +add_module("tensor_norm3_fwd", ["float16"], tensor_norm3_fwd_fp16) +add_module("tensor_norm3_bwd", ["float16"], tensor_norm3_bwd_fp16) +add_module("tensor_norm3_bwd_bwd", ["float16"], tensor_norm3_bwd_bwd_fp16) diff --git a/src/matgl/kernels/tensornet_mp.py b/src/matgl/kernels/tensornet_mp.py new file mode 100644 index 00000000..01baaed4 --- /dev/null +++ b/src/matgl/kernels/tensornet_mp.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_message_passing(dtype: str): + dtype_wp = get_wp_fp_dtype(dtype) + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + def message_passing_fwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + output_I: wp.array(ndim=3, dtype=dtype_wp), + output_A: wp.array(ndim=3, dtype=dtype_wp), + output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + output_I_reg = I.dtype(0) + output_A_reg = vec3(I.dtype(0)) + output_S_reg = vec5(I.dtype(0)) + + _I = I[b, 0, h] + _A = vec3(A[b, 0, h], A[b, 1, h], A[b, 2, h]) + _S = vec5(S[b, 0, h], S[b, 1, h], S[b, 2, h], S[b, 3, h], S[b, 4, h]) + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + wI = edge_attr[idx_w, 0, h] + wA = edge_attr[idx_w, 1, h] + wS = edge_attr[idx_w, 2, h] + + output_I_reg += _I * wI + for j in range(3): + output_A_reg[j] += _A[j] * wA + for j in range(5): + output_S_reg[j] += _S[j] * wS + + output_I[b, 0, h] = output_I_reg + for j in range(3): + output_A[b, j, h] = output_A_reg[j] + + for j in range(5): + output_S[b, j, h] = output_S_reg[j] + + def message_passing_bwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + dI: wp.array(ndim=3, dtype=dtype_wp), + dA: wp.array(ndim=3, dtype=dtype_wp), + dS: wp.array(ndim=3, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + dI_reg = I.dtype(0.0) + dA_reg = vec3(I.dtype(0.0)) + dS_reg = vec5(I.dtype(0.0)) + + dI_b = doutput_I[b, 0, h] + I_b = I[b, 0, h] + + dA_b0 = doutput_A[b, 0, h] + dA_b1 = doutput_A[b, 1, h] + dA_b2 = doutput_A[b, 2, h] + A_b0 = A[b, 0, h] + A_b1 = A[b, 1, h] + A_b2 = A[b, 2, h] + + dS_b0 = doutput_S[b, 0, h] + dS_b1 = doutput_S[b, 1, h] + dS_b2 = doutput_S[b, 2, h] + dS_b3 = doutput_S[b, 3, h] + dS_b4 = doutput_S[b, 4, h] + S_b0 = S[b, 0, h] + S_b1 = S[b, 1, h] + S_b2 = S[b, 2, h] + S_b3 = S[b, 3, h] + S_b4 = S[b, 4, h] + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + + wI = edge_attr[idx_w, 0, h] + wA = edge_attr[idx_w, 1, h] + wS = edge_attr[idx_w, 2, h] + + dI_reg += dI_b * wI + dedge_attr[idx_w, 0, h] = dI_b * I_b + + dA_reg[0] += dA_b0 * wA + dA_reg[1] += dA_b1 * wA + dA_reg[2] += dA_b2 * wA + dedge_attr[idx_w, 1, h] = dA_b0 
* A_b0 + dA_b1 * A_b1 + dA_b2 * A_b2 + + dS_reg[0] += dS_b0 * wS + dS_reg[1] += dS_b1 * wS + dS_reg[2] += dS_b2 * wS + dS_reg[3] += dS_b3 * wS + dS_reg[4] += dS_b4 * wS + dedge_attr[idx_w, 2, h] = ( + dS_b0 * S_b0 + dS_b1 * S_b1 + dS_b2 * S_b2 + dS_b3 * S_b3 + dS_b4 * S_b4 + ) + + dI[b, 0, h] = dI_reg + for j in range(3): + dA[b, j, h] = dA_reg[j] + for j in range(5): + dS[b, j, h] = dS_reg[j] + + def message_passing_bwd_bwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + dI: wp.array(ndim=3, dtype=dtype_wp), + dA: wp.array(ndim=3, dtype=dtype_wp), + dS: wp.array(ndim=3, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + col_data: wp.array(ndim=1, dtype=wp.int32), + col_indices: wp.array(ndim=1, dtype=wp.int32), + col_indptr: wp.array(ndim=1, dtype=wp.int32), + d2I: wp.array(ndim=3, dtype=dtype_wp), + d2A: wp.array(ndim=3, dtype=dtype_wp), + d2S: wp.array(ndim=3, dtype=dtype_wp), + d2edge_attr: wp.array(ndim=3, dtype=dtype_wp), + d2output_I: wp.array(ndim=3, dtype=dtype_wp), + d2output_A: wp.array(ndim=3, dtype=dtype_wp), + d2output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + d2I_reg = d2output_I.dtype(0) + d2output_I_reg = d2output_I.dtype(0) + + d2A_reg = vec3(d2output_A.dtype(0)) + d2output_A_reg = vec3(d2output_A.dtype(0)) + + d2S_reg = vec5(d2output_S.dtype(0)) + d2output_S_reg = vec5(d2output_S.dtype(0)) + + for i in range(col_indptr[b], col_indptr[b + 1]): + idx_j = col_indices[i] + idx_w = col_data[i] + + dweight_I = dedge_attr[idx_w, 0, h] + doutput_I_reg = doutput_I[idx_j, 0, h] + + d2I_reg += doutput_I_reg * dweight_I + + dweight_A = dedge_attr[idx_w, 1, h] + for j in range(3): + d2A_reg[j] += doutput_A[idx_j, j, h] * dweight_A + + dweight_S = dedge_attr[idx_w, 2, h] + + for j in range(5): + d2S_reg[j] += doutput_S[idx_j, j, h] * dweight_S + + I_b = I[b, 0, h] + dI_b = dI[b, 0, h] + dO_I_b = doutput_I[b, 0, h] + + A_b0 = A[b, 0, h] + A_b1 = A[b, 1, h] + A_b2 = A[b, 2, h] + dA_b0 = dA[b, 0, h] + dA_b1 = dA[b, 1, h] + dA_b2 = dA[b, 2, h] + + S_b0 = S[b, 0, h] + S_b1 = S[b, 1, h] + S_b2 = S[b, 2, h] + S_b3 = S[b, 3, h] + S_b4 = S[b, 4, h] + dS_b0 = dS[b, 0, h] + dS_b1 = dS[b, 1, h] + dS_b2 = dS[b, 2, h] + dS_b3 = dS[b, 3, h] + dS_b4 = dS[b, 4, h] + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + + wI = edge_attr[idx_w, 0, h] + wA = edge_attr[idx_w, 1, h] + wS = edge_attr[idx_w, 2, h] + + d2output_I_reg += dI_b * wI + d2output_I_reg += I_b * dedge_attr[idx_w, 0, h] + + d2edge_attr[idx_w, 0, h] = dO_I_b * dI_b + + d2output_A_reg[0] += dA_b0 * wA + d2output_A_reg[1] += dA_b1 * wA + d2output_A_reg[2] += dA_b2 * wA + d2output_A_reg[0] += A_b0 * dedge_attr[idx_w, 1, h] + d2output_A_reg[1] += A_b1 * dedge_attr[idx_w, 1, h] + d2output_A_reg[2] += A_b2 * dedge_attr[idx_w, 1, h] + + d2edge_attr[idx_w, 1, h] = ( + doutput_A[b, 0, h] * dA_b0 + + doutput_A[b, 1, h] * dA_b1 + + doutput_A[b, 2, h] * dA_b2 + ) + + d2output_S_reg[0] += dS_b0 * wS + d2output_S_reg[1] += dS_b1 * wS + d2output_S_reg[2] += dS_b2 * wS + d2output_S_reg[3] += dS_b3 * wS + d2output_S_reg[4] += dS_b4 * wS + d2output_S_reg[0] += S_b0 * dedge_attr[idx_w, 2, h] + 
d2output_S_reg[1] += S_b1 * dedge_attr[idx_w, 2, h] + d2output_S_reg[2] += S_b2 * dedge_attr[idx_w, 2, h] + d2output_S_reg[3] += S_b3 * dedge_attr[idx_w, 2, h] + d2output_S_reg[4] += S_b4 * dedge_attr[idx_w, 2, h] + + d2edge_attr[idx_w, 2, h] = ( + doutput_S[b, 0, h] * dS_b0 + + doutput_S[b, 1, h] * dS_b1 + + doutput_S[b, 2, h] * dS_b2 + + doutput_S[b, 3, h] * dS_b3 + + doutput_S[b, 4, h] * dS_b4 + ) + + d2output_I[b, 0, h] = d2output_I_reg + d2I[b, 0, h] = d2I_reg + + for j in range(3): + d2A[b, j, h] = d2A_reg[j] + d2output_A[b, j, h] = d2output_A_reg[j] + + for j in range(5): + d2S[b, j, h] = d2S_reg[j] + d2output_S[b, j, h] = d2output_S_reg[j] + + return ( + wp.Kernel( + message_passing_fwd, + key=f"message_passing_fwd_{dtype}", + module=wp.get_module(f"message_passing_fwd_{dtype}"), + ), + wp.Kernel( + message_passing_bwd, + key=f"message_passing_bwd_{dtype}", + module=wp.get_module(f"message_passing_bwd_{dtype}"), + ), + wp.Kernel( + message_passing_bwd_bwd, + key=f"message_passing_bwd_bwd_{dtype}", + module=wp.get_module(f"message_passing_bwd_bwd_{dtype}"), + ), + ) + + +message_passing_fwd_fp64, message_passing_bwd_fp64, message_passing_bwd_bwd_fp64 = ( + generate_message_passing("float64") +) +message_passing_fwd_fp32, message_passing_bwd_fp32, message_passing_bwd_bwd_fp32 = ( + generate_message_passing("float32") +) +message_passing_fwd_fp16, message_passing_bwd_fp16, message_passing_bwd_bwd_fp16 = ( + generate_message_passing("float16") +) + +add_module("message_passing_fwd", ["float64"], message_passing_fwd_fp64) +add_module("message_passing_bwd", ["float64"], message_passing_bwd_fp64) +add_module("message_passing_bwd_bwd", ["float64"], message_passing_bwd_bwd_fp64) + +add_module("message_passing_fwd", ["float32"], message_passing_fwd_fp32) +add_module("message_passing_bwd", ["float32"], message_passing_bwd_fp32) +add_module("message_passing_bwd_bwd", ["float32"], message_passing_bwd_bwd_fp32) + +add_module("message_passing_fwd", ["float16"], message_passing_fwd_fp16) +add_module("message_passing_bwd", ["float16"], message_passing_bwd_fp16) +add_module("message_passing_bwd_bwd", ["float16"], message_passing_bwd_bwd_fp16) diff --git a/src/matgl/kernels/tensornet_radial_mp.py b/src/matgl/kernels/tensornet_radial_mp.py new file mode 100644 index 00000000..5e069910 --- /dev/null +++ b/src/matgl/kernels/tensornet_radial_mp.py @@ -0,0 +1,450 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_radial_message_passing(dtype: str): + dtype_wp = get_wp_fp_dtype(dtype) + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + def radial_message_passing_fwd( + edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + output_I: wp.array(ndim=3, dtype=dtype_wp), + output_A: wp.array(ndim=3, dtype=dtype_wp), + output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + output_I_reg = output_I.dtype(0) + output_A_reg = vec3(output_I.dtype(0)) + output_S_reg = vec5(output_I.dtype(0)) + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + + weight_I_reg = edge_attr[idx_w, 0, h] + weight_A_reg = edge_attr[idx_w, 1, h] + weight_S_reg = edge_attr[idx_w, 2, h] + + r_ij = vec3(output_I.dtype(0)) + r_ij[0] = edge_vec_norm[idx_w, 0] + r_ij[1] = edge_vec_norm[idx_w, 1] + r_ij[2] = edge_vec_norm[idx_w, 2] + + output_I_reg += weight_I_reg + + output_A_reg[0] += r_ij[2] * weight_A_reg + output_A_reg[1] += -r_ij[1] * weight_A_reg + output_A_reg[2] += r_ij[0] * weight_A_reg + + S_reg = vec5() + mean_r2 = ( + r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2] + ) / output_I.dtype(3.0) + S_reg[0] = r_ij[0] * r_ij[0] - mean_r2 + S_reg[1] = r_ij[0] * r_ij[1] + S_reg[2] = r_ij[0] * r_ij[2] + S_reg[3] = r_ij[1] * r_ij[1] - mean_r2 + S_reg[4] = r_ij[1] * r_ij[2] + + output_S_reg[0] += S_reg[0] * weight_S_reg + output_S_reg[1] += S_reg[1] * weight_S_reg + output_S_reg[2] += S_reg[2] * weight_S_reg + output_S_reg[3] += S_reg[3] * weight_S_reg + output_S_reg[4] += S_reg[4] * weight_S_reg + + output_I[b, 0, h] = output_I_reg + for i in range(3): + output_A[b, i, h] = output_A_reg[i] + + for i in range(5): + output_S[b, i, h] = output_S_reg[i] + + def radial_message_passing_bwd( + edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + dedge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + doutput_I_reg = doutput_I[b, 0, h] + doutput_A_reg = vec3() + doutput_A_reg[0] = doutput_A[b, 0, h] + doutput_A_reg[1] = doutput_A[b, 1, h] + doutput_A_reg[2] = doutput_A[b, 2, h] + + doutput_S_reg = vec5() + doutput_S_reg[0] = doutput_S[b, 0, h] + doutput_S_reg[1] = doutput_S[b, 1, h] + doutput_S_reg[2] = doutput_S[b, 2, h] + doutput_S_reg[3] = doutput_S[b, 3, h] + doutput_S_reg[4] = doutput_S[b, 4, h] + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + + edge_attr_A_reg = edge_attr[idx_w, 1, h] + 
edge_attr_S_reg = edge_attr[idx_w, 2, h] + + r_ij = vec3(doutput_I.dtype(0)) + dr_ij = vec3(doutput_I.dtype(0)) + r_ij[0] = edge_vec_norm[idx_w, 0] + r_ij[1] = edge_vec_norm[idx_w, 1] + r_ij[2] = edge_vec_norm[idx_w, 2] + + dr_ij[2] += doutput_A_reg[0] * edge_attr_A_reg + dr_ij[1] += -doutput_A_reg[1] * edge_attr_A_reg + dr_ij[0] += doutput_A_reg[2] * edge_attr_A_reg + + dedge_attr_I = doutput_I_reg + + dedge_attr_A = ( + doutput_A_reg[0] * r_ij[2] + - doutput_A_reg[1] * r_ij[1] + + doutput_A_reg[2] * r_ij[0] + ) + + S_reg = vec5() + mean_r2 = ( + r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2] + ) / doutput_I.dtype(3.0) + S_reg[0] = r_ij[0] * r_ij[0] - mean_r2 + S_reg[1] = r_ij[0] * r_ij[1] + S_reg[2] = r_ij[0] * r_ij[2] + S_reg[3] = r_ij[1] * r_ij[1] - mean_r2 + S_reg[4] = r_ij[1] * r_ij[2] + + dedge_attr_S = (S_reg[0]) * doutput_S_reg[0] + dedge_attr_S += (S_reg[1]) * doutput_S_reg[1] + dedge_attr_S += (S_reg[2]) * doutput_S_reg[2] + dedge_attr_S += (S_reg[3]) * doutput_S_reg[3] + dedge_attr_S += (S_reg[4]) * doutput_S_reg[4] + + dS_reg = vec5() + dS_reg[0] = edge_attr_S_reg * doutput_S_reg[0] + dS_reg[1] = edge_attr_S_reg * doutput_S_reg[1] + dS_reg[2] = edge_attr_S_reg * doutput_S_reg[2] + dS_reg[3] = edge_attr_S_reg * doutput_S_reg[3] + dS_reg[4] = edge_attr_S_reg * doutput_S_reg[4] + + dr_ij[0] += ( + dS_reg[0] * (doutput_I.dtype(4.0) / doutput_I.dtype(3.0) * r_ij[0]) + + dS_reg[1] * r_ij[1] + + dS_reg[2] * r_ij[2] + + dS_reg[3] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[0]) + ) + dr_ij[1] += ( + dS_reg[0] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[1]) + + dS_reg[1] * r_ij[0] + + dS_reg[3] * (doutput_I.dtype(4.0) / doutput_I.dtype(3.0) * r_ij[1]) + + dS_reg[4] * r_ij[2] + ) + dr_ij[2] += ( + dS_reg[0] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[2]) + + dS_reg[2] * r_ij[0] + + dS_reg[3] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[2]) + + dS_reg[4] * r_ij[1] + ) + + wp.atomic_add(dedge_attr, idx_w, 0, h, dedge_attr_I) + wp.atomic_add(dedge_attr, idx_w, 1, h, dedge_attr_A) + wp.atomic_add(dedge_attr, idx_w, 2, h, dedge_attr_S) + + wp.atomic_add(dedge_vec_norm, idx_w, 0, dr_ij[0]) + wp.atomic_add(dedge_vec_norm, idx_w, 1, dr_ij[1]) + wp.atomic_add(dedge_vec_norm, idx_w, 2, dr_ij[2]) + + def radial_message_passing_bwd_bwd( + edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + dedge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + d2edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + d2edge_attr: wp.array(ndim=3, dtype=dtype_wp), + d2output_I: wp.array(ndim=3, dtype=dtype_wp), + d2output_A: wp.array(ndim=3, dtype=dtype_wp), + d2output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + d2output_I_reg = d2output_I.dtype(0.0) + d2output_A_reg = vec3() + d2output_S_reg = vec5() + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + edge_attr_A_reg = edge_attr[idx_w, 1, h] + edge_attr_S_reg = edge_attr[idx_w, 2, h] + + dedge_attr_I = dedge_attr[idx_w, 0, h] + dedge_attr_A = dedge_attr[idx_w, 1, h] + dedge_attr_S = dedge_attr[idx_w, 2, h] + + r_ij = vec3(d2output_I.dtype(0)) + dr_ij = vec3(d2output_I.dtype(0)) + for j in range(3): + r_ij[j] = edge_vec_norm[idx_w, j] + dr_ij[j] = 
dedge_vec_norm[idx_w, j] + + d2output_I_reg += dedge_attr_I + + d2r_ij = vec3(d2output_I.dtype(0)) + + # No gradient contribution for edge_attr[*, 0, h] in forward pass + # d2edge_attr[idx_w, 0, h] = d2output_I.dtype(0.0) + + d2output_A_reg[0] += dr_ij[2] * edge_attr_A_reg + d2output_A_reg[1] += -dr_ij[1] * edge_attr_A_reg + d2output_A_reg[2] += dr_ij[0] * edge_attr_A_reg + + d2output_A_reg[0] += dedge_attr_A * r_ij[2] + d2output_A_reg[1] += -dedge_attr_A * r_ij[1] + d2output_A_reg[2] += dedge_attr_A * r_ij[0] + + dweight_A = ( + doutput_A[b, 0, h] * dr_ij[2] + - doutput_A[b, 1, h] * dr_ij[1] + + doutput_A[b, 2, h] * dr_ij[0] + ) + + d2r_ij[2] += dedge_attr_A * doutput_A[b, 0, h] + d2r_ij[1] += -dedge_attr_A * doutput_A[b, 1, h] + d2r_ij[0] += dedge_attr_A * doutput_A[b, 2, h] + + wp.atomic_add(d2edge_attr, idx_w, 1, h, dweight_A) + + c0 = doutput_S.dtype(4.0) / doutput_S.dtype(3.0) + c1 = -doutput_S.dtype(2.0) / doutput_S.dtype(3.0) + + c2 = doutput_S.dtype(2.0) / doutput_S.dtype(3.0) + c3 = -doutput_S.dtype(1.0) / doutput_S.dtype(3.0) + + d2output_S_reg[0] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * c0 * r_ij[0] + + dedge_vec_norm[idx_w, 1] * c1 * r_ij[1] + + dedge_vec_norm[idx_w, 2] * c1 * r_ij[2] + ) + d2output_S_reg[0] += dedge_attr_S * ( + c2 * r_ij[0] * r_ij[0] + c3 * r_ij[1] * r_ij[1] + c3 * r_ij[2] * r_ij[2] + ) + + d2output_S_reg[1] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * r_ij[1] + dedge_vec_norm[idx_w, 1] * r_ij[0] + ) + d2output_S_reg[1] += dedge_attr_S * (r_ij[1] * r_ij[0]) + + d2output_S_reg[2] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * r_ij[2] + dedge_vec_norm[idx_w, 2] * r_ij[0] + ) + d2output_S_reg[2] += dedge_attr_S * (r_ij[2] * r_ij[0]) + + d2output_S_reg[3] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * c1 * r_ij[0] + + dedge_vec_norm[idx_w, 1] * c0 * r_ij[1] + + dedge_vec_norm[idx_w, 2] * c1 * r_ij[2] + ) + d2output_S_reg[3] += dedge_attr_S * ( + c3 * r_ij[0] * r_ij[0] + c2 * r_ij[1] * r_ij[1] + c3 * r_ij[2] * r_ij[2] + ) + + d2output_S_reg[4] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 1] * r_ij[2] + dedge_vec_norm[idx_w, 2] * r_ij[1] + ) + d2output_S_reg[4] += dedge_attr_S * (r_ij[2] * r_ij[1]) + + d2r_ij[0] += ( + doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c0) + ) + d2r_ij[1] += ( + doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c1) + ) + d2r_ij[2] += ( + doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * c1) + ) + + d2r_ij[0] += doutput_S[b, 0, h] * dedge_attr_S * (c0 * r_ij[0]) + d2r_ij[1] += doutput_S[b, 0, h] * dedge_attr_S * (c1 * r_ij[1]) + d2r_ij[2] += doutput_S[b, 0, h] * dedge_attr_S * (c1 * r_ij[2]) + + d2r_ij[0] += ( + doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) + ) + d2r_ij[1] += ( + doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) + ) + + d2r_ij[0] += doutput_S[b, 1, h] * dedge_attr_S * (r_ij[1]) + d2r_ij[1] += doutput_S[b, 1, h] * dedge_attr_S * (r_ij[0]) + + d2r_ij[0] += ( + doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) + ) + d2r_ij[2] += ( + doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) + ) + + d2r_ij[0] += doutput_S[b, 2, h] * dedge_attr_S * (r_ij[2]) + d2r_ij[2] += doutput_S[b, 2, h] * dedge_attr_S * (r_ij[0]) + + d2r_ij[0] += ( + doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c1) + ) + d2r_ij[1] += ( + doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c0) + ) + d2r_ij[2] += ( + doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 
2] * c1) + ) + + d2r_ij[0] += doutput_S[b, 3, h] * dedge_attr_S * (c1 * r_ij[0]) + d2r_ij[1] += doutput_S[b, 3, h] * dedge_attr_S * (c0 * r_ij[1]) + d2r_ij[2] += doutput_S[b, 3, h] * dedge_attr_S * (c1 * r_ij[2]) + + d2r_ij[1] += ( + doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) + ) + d2r_ij[2] += ( + doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) + ) + + d2r_ij[1] += doutput_S[b, 4, h] * dedge_attr_S * (r_ij[2]) + d2r_ij[2] += doutput_S[b, 4, h] * dedge_attr_S * (r_ij[1]) + + d2weight_S = doutput_S.dtype(0.0) + d2weight_S += doutput_S[b, 0, h] * ( + c0 * r_ij[0] * dedge_vec_norm[idx_w, 0] + + c1 * r_ij[1] * dedge_vec_norm[idx_w, 1] + + c1 * r_ij[2] * dedge_vec_norm[idx_w, 2] + ) + + d2weight_S += doutput_S[b, 1, h] * ( + r_ij[1] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 1] + ) + + d2weight_S += doutput_S[b, 2, h] * ( + r_ij[2] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 2] + ) + + d2weight_S += doutput_S[b, 3, h] * ( + c1 * r_ij[0] * dedge_vec_norm[idx_w, 0] + + c0 * r_ij[1] * dedge_vec_norm[idx_w, 1] + + c1 * r_ij[2] * dedge_vec_norm[idx_w, 2] + ) + + d2weight_S += doutput_S[b, 4, h] * ( + r_ij[2] * dedge_vec_norm[idx_w, 1] + r_ij[1] * dedge_vec_norm[idx_w, 2] + ) + + wp.atomic_add(d2edge_attr, idx_w, 2, h, d2weight_S) + + wp.atomic_add(d2edge_vec_norm, idx_w, 0, d2r_ij[0]) + wp.atomic_add(d2edge_vec_norm, idx_w, 1, d2r_ij[1]) + wp.atomic_add(d2edge_vec_norm, idx_w, 2, d2r_ij[2]) + + d2output_I[b, 0, h] = d2output_I_reg + + for i in range(3): + d2output_A[b, i, h] = d2output_A_reg[i] + + for i in range(5): + d2output_S[b, i, h] = d2output_S_reg[i] + + return ( + wp.Kernel( + radial_message_passing_fwd, + key=f"radial_message_passing_fwd_{dtype}", + module=wp.get_module(f"radial_message_passing_fwd_{dtype}"), + ), + wp.Kernel( + radial_message_passing_bwd, + key=f"radial_message_passing_bwd_{dtype}", + module=wp.get_module(f"radial_message_passing_bwd_{dtype}"), + ), + wp.Kernel( + radial_message_passing_bwd_bwd, + key=f"radial_message_passing_bwd_bwd_{dtype}", + module=wp.get_module(f"radial_message_passing_bwd_bwd_{dtype}"), + ), + ) + + +( + radial_message_passing_fwd_fp64, + radial_message_passing_bwd_fp64, + radial_message_passing_bwd_bwd_fp64, +) = generate_radial_message_passing("float64") +( + radial_message_passing_fwd_fp32, + radial_message_passing_bwd_fp32, + radial_message_passing_bwd_bwd_fp32, +) = generate_radial_message_passing("float32") +( + radial_message_passing_fwd_fp16, + radial_message_passing_bwd_fp16, + radial_message_passing_bwd_bwd_fp16, +) = generate_radial_message_passing("float16") + +add_module("radial_message_passing_fwd", ["float64"], radial_message_passing_fwd_fp64) +add_module("radial_message_passing_bwd", ["float64"], radial_message_passing_bwd_fp64) +add_module( + "radial_message_passing_bwd_bwd", ["float64"], radial_message_passing_bwd_bwd_fp64 +) + +add_module("radial_message_passing_fwd", ["float32"], radial_message_passing_fwd_fp32) +add_module("radial_message_passing_bwd", ["float32"], radial_message_passing_bwd_fp32) +add_module( + "radial_message_passing_bwd_bwd", ["float32"], radial_message_passing_bwd_bwd_fp32 +) + +add_module("radial_message_passing_fwd", ["float16"], radial_message_passing_fwd_fp16) +add_module("radial_message_passing_bwd", ["float16"], radial_message_passing_bwd_fp16) +add_module( + "radial_message_passing_bwd_bwd", ["float16"], radial_message_passing_bwd_bwd_fp16 +) diff --git a/src/matgl/kernels/utils.py b/src/matgl/kernels/utils.py new file mode 
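For reference, the forward kernel above sums, per destination node and per channel, an invariant weight (I), the three independent entries of a skew-symmetric tensor built from the unit bond vector (A), and the five independent entries of the traceless symmetric part of the bond-vector outer product (S), each gated by its learned per-edge weight. A dense PyTorch sketch of the same math (names hypothetical; it uses index_add_ over edge destinations instead of the CSR traversal, assuming graph_transform groups edges by destination node as the previous scatter_add path did):

import torch

def radial_message_passing_reference(edge_vec_norm, edge_attr, dst, num_nodes):
    """Dense sketch of radial_message_passing_fwd.

    edge_vec_norm: (E, 3) unit bond vectors; edge_attr: (E, 3, H) per-edge
    I/A/S weights; dst: (E,) long tensor with the aggregation node of each edge.
    """
    r = edge_vec_norm
    H = edge_attr.shape[-1]
    w_I, w_A, w_S = edge_attr[:, 0], edge_attr[:, 1], edge_attr[:, 2]      # (E, H) each

    # Scalar channel: plain weight sum.
    I_msg = w_I.unsqueeze(1)                                               # (E, 1, H)

    # Skew channel: three independent entries of the skew basis built from r.
    A_basis = torch.stack((r[:, 2], -r[:, 1], r[:, 0]), dim=1)             # (E, 3)
    A_msg = A_basis.unsqueeze(-1) * w_A.unsqueeze(1)                       # (E, 3, H)

    # Traceless symmetric channel: five independent entries of r r^T - |r|^2/3 Id.
    m = (r * r).sum(dim=-1) / 3.0                                          # (E,)
    S_basis = torch.stack(
        (r[:, 0] * r[:, 0] - m, r[:, 0] * r[:, 1], r[:, 0] * r[:, 2],
         r[:, 1] * r[:, 1] - m, r[:, 1] * r[:, 2]), dim=1)                 # (E, 5)
    S_msg = S_basis.unsqueeze(-1) * w_S.unsqueeze(1)                       # (E, 5, H)

    # Sum edge messages onto nodes (the Warp kernel walks a CSR row instead).
    zeros = lambda c: torch.zeros(num_nodes, c, H, dtype=r.dtype, device=r.device)
    I = zeros(1).index_add_(0, dst, I_msg)
    A = zeros(3).index_add_(0, dst, A_msg)
    S = zeros(5).index_add_(0, dst, S_msg)
    return I, A, S
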
100644 index 00000000..4da11a59 --- /dev/null +++ b/src/matgl/kernels/utils.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from typing import List +import warp as wp +import torch + +MODULES = {} + + +def get_module(name: str, dtype: List[str]): + """ + Get the module for the given name and dtype + """ + full_name = name + "_" + "_".join(get_dtype(d) for d in dtype) + if full_name not in MODULES: + print(f"Module {full_name} not found in MODULES dictionary") + print(f"Available modules: {list(MODULES.keys())}") + raise ValueError(f"Module {full_name} not found") + return MODULES[full_name] + + +def add_module(name: str, dtype: List[str], kernel: wp.Kernel): + """ + Add the module for the given name and dtype + """ + full_name = name + "_" + "_".join(get_dtype(d) for d in dtype) + if full_name not in MODULES: + MODULES[full_name] = kernel + return MODULES[full_name] + + +def get_dtype(dtype: str): + """ + Get the dtype for the given dtype + WIP + """ + if dtype.endswith("16"): + return "fp16" + elif dtype.endswith("32"): + return "fp32" + elif dtype.endswith("64"): + return "fp64" + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def get_wp_fp_dtype(dtype: str): + """ + Get the warp dtype for the given dtype + WIP + """ + if dtype.endswith("16"): + return wp.float16 + elif dtype.endswith("32"): + return wp.float32 + elif dtype.endswith("64"): + return wp.float64 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def list_modules(): + """ + List all modules in the MODULES dictionary + """ + print("Available modules:") + for name in MODULES.keys(): + print(f" - {name}") + return list(MODULES.keys()) + + +def get_stream(device: torch.device): + """ + Get the stream for the given device + """ + if device.type == "cuda": + return wp.stream_from_torch(torch.cuda.current_stream(device)) + else: + return None \ No newline at end of file diff --git a/src/matgl/models/_tensornet_pyg.py 
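The registry above keys compiled kernels by op name plus a short dtype tag, and get_stream bridges the current torch CUDA stream to Warp. An illustrative lookup, as the ops modules below perform it (str(torch.float32) is "torch.float32", which ends with "32" and therefore maps to the "fp32" tag):

import torch
from matgl.kernels import get_module, get_stream

kernel = get_module("compose_tensor_fwd", [str(torch.float32)])  # resolves "compose_tensor_fwd_fp32"
stream = get_stream(torch.device("cuda", 0))  # Warp stream on the current torch stream; None on CPU
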
b/src/matgl/models/_tensornet_pyg.py index 0fab57f6..72192325 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -11,10 +11,10 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Mapping, Any import torch -from torch import nn +from torch import nn, Tensor import matgl from matgl.config import DEFAULT_ELEMENTS @@ -29,12 +29,18 @@ WeightedReadOut, ) from matgl.utils.cutoff import cosine_cutoff -from matgl.utils.maths import ( - decompose_tensor, - new_radial_tensor, - scatter_add, - vector_to_skewtensor, - vector_to_symtensor, +from matgl.utils.maths import scatter_add + +from matgl.ops import ( + fn_radial_message_passing, + fn_compose_tensor, + fn_decompose_tensor, + fn_tensor_norm3, + fn_message_passing, + fn_radial_message_passing, + fn_tensor_matmul_o3_3x3, + fn_tensor_matmul_so3_3x3, + graph_transform, ) from ._core import MatGLModel @@ -45,40 +51,6 @@ logger = logging.getLogger(__file__) -def compose_tensor(I_tensor: torch.Tensor, A: torch.Tensor, S: torch.Tensor) -> torch.Tensor: - """Compose tensor from scalar (I_tensor), skew-symmetric (A), and traceless symmetric (S) components. - - Args: - I_tensor: Scalar component, shape (num_nodes, 1, 1, units) or (num_nodes, 3, 3, units) - A: Skew-symmetric component, shape (num_nodes, 3, 3, units) - S: Traceless symmetric component, shape (num_nodes, 3, 3, units) - - Returns: - Composed tensor, shape (num_nodes, 3, 3, units) - """ - # I_tensor is scalar (1x1), A is skew (3x3), S is traceless symmetric (3x3) - # For I_tensor, we need to expand it to 3x3 identity matrix - if I_tensor.shape[1] == 1 and I_tensor.shape[2] == 1: - # I_tensor has shape (num_nodes, 1, 1, units) - # Expand scalar to 3x3 identity matrix: multiply I_tensor by identity - eye = torch.eye(3, 3, device=I_tensor.device, dtype=I_tensor.dtype) # (3, 3) - # I_tensor: (num_nodes, 1, 1, units) - # We need: I_expanded[i, :, :, u] = I_tensor[i, 0, 0, u] * eye - # I_values: (num_nodes, units) - I_values = I_tensor.squeeze(1).squeeze(1) # (num_nodes, units) - # eye_expanded: (1, 3, 3, 1) for broadcasting - eye_expanded = eye.unsqueeze(0).unsqueeze(-1) # (1, 3, 3, 1) - # I_values.unsqueeze(1).unsqueeze(1): (num_nodes, 1, 1, units) - # Multiply: (num_nodes, 1, 1, units) * (1, 3, 3, 1) -> (num_nodes, 3, 3, units) - I_expanded = I_values.unsqueeze(1).unsqueeze(1) * eye_expanded # (num_nodes, 3, 3, units) - else: - I_expanded = I_tensor - - # A is already 3x3 skew-symmetric, shape (num_nodes, 3, 3, units) - # S is already 3x3 traceless symmetric, shape (num_nodes, 3, 3, units) - return I_expanded + A + S - - def compute_pair_vector_and_distance( pos: torch.Tensor, edge_index: torch.Tensor, @@ -108,151 +80,9 @@ def compute_pair_vector_and_distance( return bond_vec, bond_dist -def radial_message_passing( - edge_vec_norm: torch.Tensor, - edge_attr: torch.Tensor, - edge_index: torch.Tensor, - num_nodes: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform radial message passing to aggregate edge information to nodes. 
- - Args: - edge_vec_norm: Normalized edge vectors, shape (num_edges, 3) - edge_attr: Edge attributes, shape (num_edges, 3, units) - edge_index: Edge indices, shape (2, num_edges) - num_nodes: Number of nodes - - Returns: - I: Scalar components, shape (num_nodes, 1, 1, units) - A: Skew-symmetric components, shape (num_nodes, 3, 3, units) - S: Traceless symmetric components, shape (num_nodes, 3, 3, units) - """ - dst = edge_index[1] - - # Create radial tensors from edge vectors - # For scalars: use (1, 1, 1, 1) which will broadcast with f_I - eye_scalar_base = torch.ones(1, 1, 1, 1, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype) - A_skew_base = vector_to_skewtensor(edge_vec_norm).unsqueeze(-3) # (num_edges, 1, 3, 3) - S_sym_base = vector_to_symtensor(edge_vec_norm).unsqueeze(-3) # (num_edges, 1, 3, 3) - - # Split edge_attr into three components - edge_attr_I = edge_attr[:, 0, :] # (num_edges, units) - edge_attr_A = edge_attr[:, 1, :] # (num_edges, units) - edge_attr_S = edge_attr[:, 2, :] # (num_edges, units) - - # Call new_radial_tensor - # new_radial_tensor multiplies f_I[..., None, None] * scalars - # f_I: (num_edges, units) -> (num_edges, units, 1, 1) - # scalars: (1, 1, 1, 1) -> broadcasts to (num_edges, units, 1, 1) - # Result: I_ij (num_edges, units, 1, 1), A_ij (num_edges, units, 3, 3), S_ij (num_edges, units, 3, 3) - I_ij, A_ij, S_ij = new_radial_tensor( - eye_scalar_base, - A_skew_base, - S_sym_base, - edge_attr_I, - edge_attr_A, - edge_attr_S, - ) - - # new_radial_tensor returns with units in position 1, we need units in position -1 - # Transpose: (num_edges, units, 1, 1) -> (num_edges, 1, 1, units) - # Transpose: (num_edges, units, 3, 3) -> (num_edges, 3, 3, units) - I_ij = I_ij.permute(0, 2, 3, 1) # (num_edges, 1, 1, units) - A_ij = A_ij.permute(0, 2, 3, 1) # (num_edges, 3, 3, units) - S_ij = S_ij.permute(0, 2, 3, 1) # (num_edges, 3, 3, units) - - # Aggregate to nodes - I_tensor = scatter_add(I_ij, dst, dim_size=num_nodes, dim=0) - A = scatter_add(A_ij, dst, dim_size=num_nodes, dim=0) - S = scatter_add(S_ij, dst, dim_size=num_nodes, dim=0) - - return I_tensor, A, S - - -def message_passing( - I_tensor: torch.Tensor, - A: torch.Tensor, - S: torch.Tensor, - edge_attr: torch.Tensor, - edge_index: torch.Tensor, - num_nodes: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform message passing for tensor components. 
- - Args: - I_tensor: Scalar components, shape (num_nodes, 1, 1, units) - A: Skew-symmetric components, shape (num_nodes, 3, 3, units) - S: Traceless symmetric components, shape (num_nodes, 3, 3, units) - edge_attr: Edge attributes, shape (num_edges, 3, units) - edge_index: Edge indices, shape (2, num_edges) - num_nodes: Number of nodes - - Returns: - Im: Aggregated scalar messages, shape (num_nodes, 1, 1, units) - Am: Aggregated skew messages, shape (num_nodes, 3, 3, units) - Sm: Aggregated traceless messages, shape (num_nodes, 3, 3, units) - """ - dst = edge_index[1] - - # Get node features for destination nodes - I_j = I_tensor[dst] - A_j = A[dst] - S_j = S[dst] - - # Extract edge attribute components - # edge_attr has shape (num_edges, 3, units) where dim 1 is (I, A, S) components - edge_attr_I = edge_attr[:, 0, :] # (num_edges, units) - edge_attr_A = edge_attr[:, 1, :] # (num_edges, units) - edge_attr_S = edge_attr[:, 2, :] # (num_edges, units) - - # After linear transformations, I_tensor, A, S all have shape (num_nodes, 3, 3, units) - # So I_j, A_j, S_j have shape (num_edges, 3, 3, units) - # Expand edge attributes for broadcasting: (num_edges, units) -> (num_edges, 1, 1, units) - edge_attr_I = edge_attr_I.unsqueeze(1).unsqueeze(1) # (num_edges, 1, 1, units) - edge_attr_A = edge_attr_A.unsqueeze(1).unsqueeze(1) # (num_edges, 1, 1, units) - edge_attr_S = edge_attr_S.unsqueeze(1).unsqueeze(1) # (num_edges, 1, 1, units) - - # Apply edge attributes to node features - I_m = I_j * edge_attr_I # (num_edges, 3, 3, units) - A_m = A_j * edge_attr_A # (num_edges, 3, 3, units) - S_m = S_j * edge_attr_S # (num_edges, 3, 3, units) - - # Aggregate messages - Im = scatter_add(I_m, dst, dim_size=num_nodes, dim=0) - Am = scatter_add(A_m, dst, dim_size=num_nodes, dim=0) - Sm = scatter_add(S_m, dst, dim_size=num_nodes, dim=0) - - return Im, Am, Sm - - -def tensor_matmul_o3(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: - """O(3) equivariant tensor multiplication. - - Args: - X: First tensor, shape (num_nodes, 3, 3, units) - Y: Second tensor, shape (num_nodes, 3, 3, units) - - Returns: - Result tensor, shape (num_nodes, 3, 3, units) - """ - # O(3) equivariant: A + B where A = X @ Y, B = Y @ X - A = torch.einsum("nijk,njlk->nilk", X, Y) - B = torch.einsum("nijk,njlk->nilk", Y, X) - return A + B - - -def tensor_matmul_so3(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: - """SO(3) equivariant tensor multiplication. 
- - Args: - X: First tensor, shape (num_nodes, 3, 3, units) - Y: Second tensor, shape (num_nodes, 3, 3, units) - - Returns: - Result tensor, shape (num_nodes, 3, 3, units) - """ - # SO(3) equivariant: 2 * (X @ Y) - return 2 * torch.einsum("nijk,njlk->nilk", X, Y) +def tensor_norm(tensor): + """Computes Frobenius norm.""" + return (tensor*tensor).sum((-3, -2)) class TensorEmbedding(nn.Module): @@ -271,9 +101,7 @@ def __init__( self.units = units self.cutoff = cutoff - self.distance_proj1 = nn.Linear(degree_rbf, units, dtype=dtype) - self.distance_proj2 = nn.Linear(degree_rbf, units, dtype=dtype) - self.distance_proj3 = nn.Linear(degree_rbf, units, dtype=dtype) + self.distance_proj = nn.Linear(degree_rbf, 3 * units, dtype=dtype) self.emb = nn.Embedding(ntypes_node, units, dtype=dtype) self.emb2 = nn.Linear(2 * units, units, dtype=dtype) self.act = activation @@ -288,10 +116,25 @@ def __init__( self.reset_parameters() + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # since the we changed distance_proj to be a single linear layer, + # we need to concatenate the weights and biases of the three distance_proj layers + # into a single weight and bias tensor + w_keys = [f"{prefix}distance_proj{i}.weight" for i in (1, 2, 3)] + b_keys = [f"{prefix}distance_proj{i}.bias" for i in (1, 2, 3)] + new_w = f"{prefix}distance_proj.weight" + new_b = f"{prefix}distance_proj.bias" + + if all(k in state_dict for k in (w_keys + b_keys)): + state_dict = dict(state_dict) + + state_dict[new_w] = torch.cat([state_dict.pop(k) for k in w_keys], dim=0) + state_dict[new_b] = torch.cat([state_dict.pop(k) for k in b_keys], dim=0) + + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + def reset_parameters(self): - self.distance_proj1.reset_parameters() - self.distance_proj2.reset_parameters() - self.distance_proj3.reset_parameters() + self.distance_proj.reset_parameters() self.emb.reset_parameters() self.emb2.reset_parameters() for linear in self.linears_tensor: @@ -307,6 +150,8 @@ def forward( edge_weight: torch.Tensor, edge_vec: torch.Tensor, edge_attr: torch.Tensor, + row_data: torch.Tensor, + row_indptr: torch.Tensor, ) -> torch.Tensor: """Forward pass. 
@@ -320,90 +165,52 @@ def forward( Returns: X: Tensor representation, shape (num_nodes, 3, 3, units) """ - num_nodes = z.shape[0] - # Node embedding x = self.emb(z) # (num_nodes, units) # Edge processing C = cosine_cutoff(edge_weight, self.cutoff) - W1 = self.distance_proj1(edge_attr) * C.view(-1, 1) # (num_edges, units) - W2 = self.distance_proj2(edge_attr) * C.view(-1, 1) - W3 = self.distance_proj3(edge_attr) * C.view(-1, 1) - - edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) + edge_attr = self.distance_proj(edge_attr).view(-1, 3, self.units) # Get atomic number messages - src, dst = edge_index[0], edge_index[1] - vi = x[src] - vj = x[dst] - zij = torch.cat([vi, vj], dim=-1) + zij = x.index_select(0, edge_index.t().flip(-1).reshape(-1)).view( + -1, self.units * 2 + ) Zij = self.emb2(zij) # (num_edges, units) # Create edge attributes with Zij - edge_attr_processed = torch.stack([W1, W2, W3], dim=1) # (num_edges, 3, units) - edge_attr_processed = edge_attr_processed * Zij.unsqueeze(1) # (num_edges, 3, units) + edge_attr_processed = \ + edge_attr.view(-1, 3, self.units) \ + * C.view(-1, 1, 1) \ + * Zij.view(-1, 1, self.units) # Radial message passing - I_tensor, A, S = radial_message_passing(edge_vec_norm, edge_attr_processed, edge_index, num_nodes) + edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) + I, A, S = fn_radial_message_passing( + edge_vec_norm, edge_attr_processed, row_data, row_indptr + ) # Compose initial tensor to get proper shape for norm computation - X = compose_tensor(I_tensor, A, S) # (num_nodes, 3, 3, units) + X = fn_compose_tensor(I, A, S) # (num_nodes, 3, 3, units) # Normalize and process - # Following original: norm = tensor_norm(scalars + skew_matrices + traceless_tensors) - # For X with shape (num_nodes, 3, 3, units), we need to sum over (-3, -2) - # which are the (3, 3) spatial dimensions - # tensor_norm sums over (-2, -1), but we need (-3, -2) for our tensor shape - # So we compute the norm manually: sum over the spatial (3, 3) dimensions - norm = (X**2).sum((-3, -2)) # (num_nodes, units) - norm = self.init_norm(norm) # (num_nodes, units) - # Apply tensor linear transformations - # I_tensor has shape (num_nodes, 1, 1, units), A and S have (num_nodes, 3, 3, units) - # The linear layer expects (..., units) as the last dimension - # Original code: permute(0, 2, 3, 1) puts units in position -2, then linear, then permute back - # For (num_nodes, 3, 3, units): permute(0, 2, 3, 1) -> (num_nodes, 3, units, 3) - # But linear expects (..., units), so we need to reshape or use a different approach - # Actually, the linear is applied to each spatial position independently - # So we reshape to (num_nodes * 3 * 3, units), apply linear, reshape back - if I_tensor.shape[1] == 1 and I_tensor.shape[2] == 1: - # Expand I_tensor from (num_nodes, 1, 1, units) to (num_nodes, 3, 3, units) - eye = torch.eye(3, 3, device=I_tensor.device, dtype=I_tensor.dtype) # (3, 3) - I_values = I_tensor.squeeze(1).squeeze(1) # (num_nodes, units) - I_expanded = I_values.unsqueeze(1).unsqueeze(1) * eye.unsqueeze(0).unsqueeze(-1) # (num_nodes, 3, 3, units) - # Reshape to (num_nodes * 3 * 3, units), apply linear, reshape back - I_reshaped = I_expanded.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[0](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_expanded.shape) # (num_nodes, 3, 3, units) - else: - # Reshape to (num_nodes * 3 * 3, units), apply linear, reshape back - 
I_reshaped = I_tensor.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[0](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_tensor.shape) # (num_nodes, 3, 3, units) - - # Same for A and S - A_reshaped = A.reshape(-1, self.units) # (num_nodes * 9, units) - A_reshaped = self.linears_tensor[1](A_reshaped) # (num_nodes * 9, units) - A = A_reshaped.reshape(A.shape) # (num_nodes, 3, 3, units) - - S_reshaped = S.reshape(-1, self.units) # (num_nodes * 9, units) - S_reshaped = self.linears_tensor[2](S_reshaped) # (num_nodes * 9, units) - S = S_reshaped.reshape(S.shape) # (num_nodes, 3, 3, units) + norm = tensor_norm(X) # (num_nodes, units) + norm = self.init_norm(norm) # (num_nodes, units) # Process norm through scalar layers for linear_scalar in self.linears_scalar: norm = self.act(linear_scalar(norm)) - norm = norm.reshape(norm.shape[0], self.units, 3) - norm_I, norm_A, norm_S = norm[..., 0], norm[..., 1], norm[..., 2] + norm = norm.view(-1, self.units, 3) + norm_I, norm_A, norm_S = norm.unbind(dim=-1) # Apply norm to tensors - I_tensor = I_tensor * norm_I.unsqueeze(1).unsqueeze(1) - A = A * norm_A.unsqueeze(1).unsqueeze(1) - S = S * norm_S.unsqueeze(1).unsqueeze(1) + I = self.linears_tensor[0](I) * norm_I.unsqueeze(-2) + A = self.linears_tensor[1](A) * norm_A.unsqueeze(-2) + S = self.linears_tensor[2](S) * norm_S.unsqueeze(-2) - X = compose_tensor(I_tensor, A, S) + X = fn_compose_tensor(I, A, S) return X @@ -453,6 +260,12 @@ def forward( edge_index: torch.Tensor, edge_weight: torch.Tensor, edge_attr: torch.Tensor, + row_data: torch.Tensor, + row_indices: torch.Tensor, + row_indptr: torch.Tensor, + col_data: torch.Tensor, + col_indices: torch.Tensor, + col_indptr: torch.Tensor, ) -> torch.Tensor: """Forward pass. 
@@ -465,93 +278,65 @@ def forward( Returns: X: Updated tensor representations, shape (num_nodes, 3, 3, units) """ - num_nodes = X.shape[0] - # Process edge attributes C = cosine_cutoff(edge_weight, self.cutoff) edge_attr_processed = edge_attr for linear_scalar in self.linears_scalar: edge_attr_processed = self.act(linear_scalar(edge_attr_processed)) - edge_attr_processed = (edge_attr_processed * C.view(-1, 1)).reshape( + edge_attr_processed = (edge_attr_processed * C.view(-1, 1)).view( edge_attr.shape[0], 3, self.units - ) # (num_edges, 3, units) + ).mT.contiguous() # (num_edges, units, 3) # Normalize input tensor # For X with shape (num_nodes, 3, 3, units), we need to sum over (-3, -2) # which are the (3, 3) spatial dimensions to get (num_nodes, units) - norm_X = (X**2).sum((-3, -2)) + 1 # (num_nodes, units) - X = X / norm_X.reshape(-1, 1, 1, X.shape[-1]) + norm_X = (X*X).sum((-3, -2)) + 1 # (num_nodes, units) + X = X / norm_X.view(-1, 1, 1, X.shape[-1]) # Decompose input tensor - # X has shape (num_nodes, 3, 3, units) - # decompose_tensor expects (..., 3, 3), so we permute to (num_nodes, units, 3, 3) - # then apply decompose_tensor which works on the last two dimensions (3, 3) - X_permuted = X.permute(0, 3, 1, 2) # (num_nodes, units, 3, 3) - # decompose_tensor works on last two dims, so this will work for each (num_nodes, units) slice - I_permuted, A_permuted, S_permuted = decompose_tensor(X_permuted) # Each: (num_nodes, units, 3, 3) - # Permute back: (num_nodes, units, 3, 3) -> (num_nodes, 3, 3, units) - I_tensor = I_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - A = A_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - S = S_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) + I, A, S = fn_decompose_tensor(X) # Apply tensor linear transformations - # Reshape to (num_nodes * 9, units), apply linear, reshape back - I_reshaped = I_tensor.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[0](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_tensor.shape) # (num_nodes, 3, 3, units) - - A_reshaped = A.reshape(-1, self.units) # (num_nodes * 9, units) - A_reshaped = self.linears_tensor[1](A_reshaped) # (num_nodes * 9, units) - A = A_reshaped.reshape(A.shape) # (num_nodes, 3, 3, units) + I = self.linears_tensor[0](I) + A = self.linears_tensor[1](A) + S = self.linears_tensor[2](S) - S_reshaped = S.reshape(-1, self.units) # (num_nodes * 9, units) - S_reshaped = self.linears_tensor[2](S_reshaped) # (num_nodes * 9, units) - S = S_reshaped.reshape(S.shape) # (num_nodes, 3, 3, units) - Y = compose_tensor(I_tensor, A, S) + # compose back + Y = fn_compose_tensor(I, A, S) # Message passing - Im, Am, Sm = message_passing(I_tensor, A, S, edge_attr_processed, edge_index, num_nodes) - msg = compose_tensor(Im, Am, Sm) + Im, Am, Sm = fn_message_passing( + I, + A, + S, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + msg = fn_compose_tensor(Im, Am, Sm) # Apply group action if self.equivariance_invariance_group == "O(3)": - C = tensor_matmul_o3(Y, msg) # (num_nodes, 3, 3, units) + C = fn_tensor_matmul_o3_3x3(Y, msg) else: # SO(3) - C = tensor_matmul_so3(Y, msg) # (num_nodes, 3, 3, units) - C = 2 * C - - # decompose_tensor expects (..., 3, 3), so permute to (num_nodes, units, 3, 3) - C_permuted = C.permute(0, 3, 1, 2) # (num_nodes, units, 3, 3) - I_permuted, A_permuted, S_permuted = decompose_tensor(C_permuted) # Each: (num_nodes, units, 3, 3) - # Permute back: (num_nodes, units, 3, 3) 
-> (num_nodes, 3, 3, units) - I_tensor = I_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - A = A_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - S = S_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) + C = fn_tensor_matmul_so3_3x3(Y, msg) + C = C = C + I, A, S = fn_decompose_tensor(C) # Normalize - # For compose_tensor(I_tensor, A, S) with shape (num_nodes, 3, 3, units), - # we need to sum over (-3, -2) to get (num_nodes, units) - X_composed = compose_tensor(I_tensor, A, S) # (num_nodes, 3, 3, units) - normp1 = ((X_composed**2).sum((-3, -2)) + 1).reshape(-1, 1, 1, X_composed.shape[-1]) - I_tensor, A, S = I_tensor / normp1, A / normp1, S / normp1 + normp1 = (tensor_norm(C) + 1).unsqueeze(-2) + I, A, S = I / normp1, A / normp1, S / normp1 # Final tensor transformations - # Reshape to (num_nodes * 9, units), apply linear, reshape back - I_reshaped = I_tensor.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[3](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_tensor.shape) # (num_nodes, 3, 3, units) - - A_reshaped = A.reshape(-1, self.units) # (num_nodes * 9, units) - A_reshaped = self.linears_tensor[4](A_reshaped) # (num_nodes * 9, units) - A = A_reshaped.reshape(A.shape) # (num_nodes, 3, 3, units) - - S_reshaped = S.reshape(-1, self.units) # (num_nodes * 9, units) - S_reshaped = self.linears_tensor[5](S_reshaped) # (num_nodes * 9, units) - S = S_reshaped.reshape(S.shape) # (num_nodes, 3, 3, units) - dX = compose_tensor(I_tensor, A, S) - - # Update - X = X + dX + torch.einsum("nijk,njlk->nilk", dX, dX) + I = self.linears_tensor[3](I) + A = self.linears_tensor[4](A) + S = self.linears_tensor[5](S) + dX = fn_compose_tensor(I, A, S) + X = X + dX + fn_tensor_matmul_so3_3x3(dX, dX) return X @@ -760,31 +545,43 @@ def forward( # Obtain graph, with distances and relative position vectors bond_vec, bond_dist = compute_pair_vector_and_distance(pos, edge_index, pbc_offshift) + # perpare graph indices for message passing + row_data, row_indices, row_indptr, col_data, col_indices, col_indptr = ( + graph_transform(edge_index.int(), z.shape[0]) + ) + # Expand distances with radial basis functions edge_attr = self.bond_expansion(bond_dist) # Embedding layer - X = self.tensor_embedding(z, edge_index, bond_dist, bond_vec, edge_attr) + X = self.tensor_embedding( + z, + edge_index, + bond_dist, + bond_vec, + edge_attr, + row_data, + row_indptr + ) # Interaction layers for layer in self.layers: - X = layer(X, edge_index, bond_dist, edge_attr) - - # decompose_tensor expects (..., 3, 3), so permute to (num_nodes, units, 3, 3) - # X has shape (num_nodes, 3, 3, units) - X_permuted = X.permute(0, 3, 1, 2) # (num_nodes, units, 3, 3) - scalars_permuted, skew_metrices_permuted, traceless_tensors_permuted = decompose_tensor(X_permuted) - # Permute back: (num_nodes, units, 3, 3) -> (num_nodes, 3, 3, units) - scalars = scalars_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - skew_metrices = skew_metrices_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - traceless_tensors = traceless_tensors_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - - # tensor_norm sums over (-2, -1), but for (num_nodes, 3, 3, units) we need to sum over (-3, -2) - # to get (num_nodes, units) - scalars_norm = (scalars**2).sum((-3, -2)) # (num_nodes, units) - skew_norm = (skew_metrices**2).sum((-3, -2)) # (num_nodes, units) - traceless_norm = (traceless_tensors**2).sum((-3, -2)) # (num_nodes, units) - x = torch.cat((scalars_norm, skew_norm, 
traceless_norm), dim=-1) # (num_nodes, 3 * units) + X = layer( + X, + edge_index, + bond_dist, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + # compute I, A, S norms + x = fn_tensor_norm3(X) + # normalize x = self.out_norm(x) x = self.linear(x) diff --git a/src/matgl/ops/__init__.py b/src/matgl/ops/__init__.py new file mode 100644 index 00000000..42b8141a --- /dev/null +++ b/src/matgl/ops/__init__.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .tensornet_radial_mp import fn_radial_message_passing +from .compose_tensor import fn_compose_tensor +from .decompose_tensor import fn_decompose_tensor +from .equivariant_o3_matmul import fn_tensor_matmul_o3_3x3 +from .equivariant_so3_matmul import fn_tensor_matmul_so3_3x3 +from .tensor_norm3 import fn_tensor_norm3 +from .tensornet_mp import fn_message_passing +from .tensornet_radial_mp import fn_radial_message_passing +from .graph_transform import graph_transform + +import warp as wp +wp.init() + +__all__ = [ + "fn_radial_message_passing", + "fn_compose_tensor", + "fn_decompose_tensor", + "fn_tensor_matmul_o3_3x3", + "fn_tensor_matmul_so3_3x3", + "fn_tensor_norm3", + "fn_message_passing", + "fn_radial_message_passing", + "graph_transform", +] \ No newline at end of file diff --git a/src/matgl/ops/compose_tensor.py b/src/matgl/ops/compose_tensor.py new file mode 100644 index 00000000..95dd87e3 --- /dev/null +++ b/src/matgl/ops/compose_tensor.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import List + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::compose_tensor_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty( + (x.shape[0], 3, 3, x.shape[-1]), dtype=x.dtype, device=x.device + ) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + compose_tensor_fwd = get_module("compose_tensor_fwd", [str(x.dtype)]) + wp.launch( + compose_tensor_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, z_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("nvtensornet::compose_tensor_fwd_primitive") +def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: + return torch.empty((z.shape[0], 3, 3, z.shape[-1]), dtype=x.dtype, device=x.device) + + +@torch.library.custom_op( + "nvtensornet::compose_tensor_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.zeros_like(x) + grad_y = torch.zeros_like(y) + grad_z = torch.zeros_like(z) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_z_wp = wp.from_torch(grad_z.detach(), return_ctype=True) + + compose_tensor_bwd = get_module("compose_tensor_bwd", [str(x.dtype)]) + wp.launch( + compose_tensor_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_output_wp, grad_x_wp, grad_y_wp, grad_z_wp), + ) + + return [grad_x, grad_y, grad_z] + + +@torch.library.register_fake("nvtensornet::compose_tensor_bwd_primitive") +def _(grad_output: List[Tensor], x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: + return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] + + 
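# Note on the autograd wiring in this file: the backward of the forward primitive
# is itself a registered custom op (compose_tensor_bwd_primitive), and that op in
# turn gets its own autograd rule via the *_bwd_bwd primitive defined next.
# Chaining the registrations this way lets torch.autograd differentiate twice
# through the Warp kernels, which matters when quantities obtained by autograd
# (e.g. forces) are themselves part of the training loss. A minimal
# double-backward sketch, assuming a device with compiled Warp kernels:
#
#     import torch
#     from matgl.ops import fn_compose_tensor
#
#     I = torch.randn(8, 1, 16, dtype=torch.float64, device="cuda", requires_grad=True)
#     A = torch.randn(8, 3, 16, dtype=torch.float64, device="cuda", requires_grad=True)
#     S = torch.randn(8, 5, 16, dtype=torch.float64, device="cuda", requires_grad=True)
#     X = fn_compose_tensor(I, A, S)                          # (8, 3, 3, 16)
#     g_I = torch.autograd.grad((X * X).sum(), I, create_graph=True)[0]
#     gg_I = torch.autograd.grad(g_I.sum(), I)[0]             # routed through *_bwd_bwd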
+@torch.library.custom_op( + "nvtensornet::compose_tensor_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, +) -> List[Tensor]: + stream = get_stream(grad_output.device) + device = wp.device_from_torch(grad_output.device) + grad_x = torch.zeros_like(grad_grad_x) + grad_y = torch.zeros_like(grad_grad_y) + grad_z = torch.zeros_like(grad_grad_z) + + grad_grad_output = torch.zeros_like(grad_output) + + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_grad_z_wp = wp.from_torch(grad_grad_z.detach(), return_ctype=True) + + compose_tensor_bwd_bwd = get_module("compose_tensor_bwd_bwd", [str(x.dtype)]) + wp.launch( + compose_tensor_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_grad_x_wp, grad_grad_y_wp, grad_grad_z_wp, grad_grad_output_wp), + ) + + return [grad_grad_output, grad_x, grad_y, grad_z] + + +@torch.library.register_fake("nvtensornet::compose_tensor_bwd_bwd_primitive") +def _( + grad_output: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, +) -> List[Tensor]: + return [ + torch.empty_like(grad_output), + torch.empty_like(x), + torch.empty_like(y), + torch.empty_like(z), + ] + + +def compose_tensor_setup_fwd_context(ctx, inputs, output): + (x, y, z) = inputs + ctx.save_for_backward(x, y, z) + + +def compose_tensor_setup_bwd_context(ctx, inputs, output): + (grad_output, x, y, z) = inputs + ctx.save_for_backward(grad_output, x, y, z) + + +@torch.compiler.allow_in_graph +def compose_tensor_fwd(*args): + return torch.ops.nvtensornet.compose_tensor_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def compose_tensor_bwd(ctx, grad_output): + x, y, z = ctx.saved_tensors + dx, dy, dz = torch.ops.nvtensornet.compose_tensor_bwd_primitive( + grad_output, x, y, z + ) + return dx, dy, dz + + +@torch.compiler.allow_in_graph +def compose_tensor_bwd_bwd(ctx, *grad_outputs): + grad_grad_x = grad_outputs[0][0] + grad_grad_y = grad_outputs[0][1] + grad_grad_z = grad_outputs[0][2] + + grad_output_saved, x, y, z = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + if grad_grad_y is None: + grad_grad_y = torch.zeros_like(y) + if grad_grad_z is None: + grad_grad_z = torch.zeros_like(z) + + outputs = torch.ops.nvtensornet.compose_tensor_bwd_bwd_primitive( + grad_output_saved, grad_grad_x, grad_grad_y, grad_grad_z, x, y, z + ) + + return outputs[0], outputs[1], outputs[2], outputs[3] + + +torch.library.register_autograd( + "nvtensornet::compose_tensor_fwd_primitive", + compose_tensor_bwd, + setup_context=compose_tensor_setup_fwd_context, +) + +torch.library.register_autograd( + "nvtensornet::compose_tensor_bwd_primitive", + compose_tensor_bwd_bwd, + setup_context=compose_tensor_setup_bwd_context, +) + + +def fn_compose_tensor(x: Tensor, y: Tensor, z: Tensor) -> Tensor: + output = torch.ops.nvtensornet.compose_tensor_fwd_primitive(x, y, z) + return output diff --git a/src/matgl/ops/decompose_tensor.py b/src/matgl/ops/decompose_tensor.py new file mode 100644 index 00000000..70f1589d --- /dev/null +++ b/src/matgl/ops/decompose_tensor.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import List + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::decompose_tensor_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor) -> List[Tensor]: + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output_i = torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device) + output_a = torch.empty((x.shape[0], 3, x.shape[-1]), dtype=x.dtype, device=x.device) + output_s = torch.empty((x.shape[0], 5, x.shape[-1]), dtype=x.dtype, device=x.device) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + output_i_wp = wp.from_torch(output_i.detach(), return_ctype=True) + output_a_wp = wp.from_torch(output_a.detach(), return_ctype=True) + output_s_wp = wp.from_torch(output_s.detach(), return_ctype=True) + + decompose_tensor_fwd = get_module("decompose_tensor_fwd", [str(x.dtype)]) + wp.launch( + decompose_tensor_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, output_i_wp, output_a_wp, output_s_wp), + ) + + return [output_i, output_a, output_s] + + +@torch.library.register_fake("nvtensornet::decompose_tensor_fwd_primitive") +def _(x: Tensor) -> List[Tensor]: + return [ + torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device), + torch.empty((x.shape[0], 3, x.shape[-1]), dtype=x.dtype, device=x.device), + torch.empty((x.shape[0], 5, x.shape[-1]), dtype=x.dtype, device=x.device), + ] + + +@torch.library.custom_op( + "nvtensornet::decompose_tensor_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor +) -> List[Tensor]: + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + + grad_output_i_wp = wp.from_torch(grad_output_i.detach(), 
return_ctype=True) + grad_output_a_wp = wp.from_torch(grad_output_a.detach(), return_ctype=True) + grad_output_s_wp = wp.from_torch(grad_output_s.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + + decompose_tensor_bwd = get_module("decompose_tensor_bwd", [str(x.dtype)]) + wp.launch( + decompose_tensor_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_output_i_wp, grad_output_a_wp, grad_output_s_wp, grad_x_wp), + ) + + return [grad_x] + + +@torch.library.register_fake("nvtensornet::decompose_tensor_bwd_primitive") +def _( + grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor +) -> List[Tensor]: + return [torch.empty_like(x)] + + +@torch.library.custom_op( + "nvtensornet::decompose_tensor_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_i: Tensor, + grad_output_a: Tensor, + grad_output_s: Tensor, + grad_grad_x: Tensor, + x: Tensor, +) -> List[Tensor]: + stream = get_stream(grad_output_i.device) + device = wp.device_from_torch(grad_output_i.device) + grad_x = torch.zeros_like(grad_grad_x) + + grad_grad_output_i = torch.empty_like(grad_output_i) + grad_grad_output_a = torch.empty_like(grad_output_a) + grad_grad_output_s = torch.empty_like(grad_output_s) + + grad_grad_output_i_wp = wp.from_torch( + grad_grad_output_i.detach(), return_ctype=True + ) + grad_grad_output_a_wp = wp.from_torch( + grad_grad_output_a.detach(), return_ctype=True + ) + grad_grad_output_s_wp = wp.from_torch( + grad_grad_output_s.detach(), return_ctype=True + ) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + + decompose_tensor_bwd_bwd = get_module("decompose_tensor_bwd_bwd", [str(x.dtype)]) + wp.launch( + decompose_tensor_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + grad_grad_x_wp, + grad_grad_output_i_wp, + grad_grad_output_a_wp, + grad_grad_output_s_wp, + ), + ) + + return [grad_grad_output_i, grad_grad_output_a, grad_grad_output_s, grad_x] + + +@torch.library.register_fake("nvtensornet::decompose_tensor_bwd_bwd_primitive") +def _( + grad_output_i: Tensor, + grad_output_a: Tensor, + grad_output_s: Tensor, + grad_grad_x: Tensor, + x: Tensor, +) -> List[Tensor]: + return [ + torch.empty_like(grad_output_i), + torch.empty_like(grad_output_a), + torch.empty_like(grad_output_s), + torch.empty_like(grad_grad_x), + ] + + +def decompose_tensor_setup_fwd_context(ctx, inputs, output): + (x,) = inputs # Unpack the single input tensor + ctx.save_for_backward(x) + + +def decompose_tensor_setup_bwd_context(ctx, inputs, output): + (grad_output_i, grad_output_a, grad_output_s, x) = inputs + ctx.save_for_backward(grad_output_i, grad_output_a, grad_output_s, x) + + +@torch.compiler.allow_in_graph +def decompose_tensor_fwd(*args): + return torch.ops.nvtensornet.decompose_tensor_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def decompose_tensor_bwd(ctx, *grad_outputs): + (x,) = ctx.saved_tensors + grad_output_i, grad_output_a, grad_output_s = grad_outputs[0] + dx = torch.ops.nvtensornet.decompose_tensor_bwd_primitive( + grad_output_i, grad_output_a, grad_output_s, x + ) + return dx[0] + + +@torch.compiler.allow_in_graph +def decompose_tensor_bwd_bwd(ctx, *grad_outputs): + (grad_grad_x,) = grad_outputs[0] + + grad_output_i, grad_output_a, grad_output_s, x = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + + outputs = 
torch.ops.nvtensornet.decompose_tensor_bwd_bwd_primitive( + grad_output_i, grad_output_a, grad_output_s, grad_grad_x, x + ) + + return outputs[0], outputs[1], outputs[2], outputs[3] + + +torch.library.register_autograd( + "nvtensornet::decompose_tensor_fwd_primitive", + decompose_tensor_bwd, + setup_context=decompose_tensor_setup_fwd_context, +) + +torch.library.register_autograd( + "nvtensornet::decompose_tensor_bwd_primitive", + decompose_tensor_bwd_bwd, + setup_context=decompose_tensor_setup_bwd_context, +) + + +def fn_decompose_tensor(x: Tensor) -> List[Tensor]: + output = torch.ops.nvtensornet.decompose_tensor_fwd_primitive(x) + return output diff --git a/src/matgl/ops/equivariant_o3_matmul.py b/src/matgl/ops/equivariant_o3_matmul.py new file mode 100644 index 00000000..fd706da2 --- /dev/null +++ b/src/matgl/ops/equivariant_o3_matmul.py @@ -0,0 +1,230 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
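For reference, the pure-PyTorch helpers removed from _tensornet_pyg.py earlier in this patch express the two equivariant products directly on the channel-last (N, 3, 3, H) layout. A dense einsum sketch like the following (function names illustrative) is a convenient cross-check for the fn_tensor_matmul_* ops defined below:

import torch

def tensor_matmul_o3_reference(X, Y):
    # X, Y: (N, 3, 3, H); multiply the 3x3 spatial blocks channel-wise.
    A = torch.einsum("nijh,njlh->nilh", X, Y)
    B = torch.einsum("nijh,njlh->nilh", Y, X)
    return A + B                                            # XY + YX, O(3)-equivariant

def tensor_matmul_so3_reference(X, Y):
    return 2.0 * torch.einsum("nijh,njlh->nilh", X, Y)      # 2*XY, SO(3)-equivariant
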
+ +from typing import List + +import torch +from torch import Tensor +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::tensor_matmul_o3_3x3_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor, y: Tensor) -> Tensor: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty_like(x) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + tensor_matmul_o3_3x3_fwd = get_module("tensor_matmul_o3_3x3_fwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_o3_3x3_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("nvtensornet::tensor_matmul_o3_3x3_fwd_primitive") +def _(x: Tensor, y: Tensor) -> Tensor: + return torch.empty_like(x) + + +@torch.library.custom_op( + "nvtensornet::tensor_matmul_o3_3x3_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + grad_y = torch.empty_like(y) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + tensor_matmul_o3_3x3_bwd = get_module("tensor_matmul_o3_3x3_bwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_o3_3x3_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, grad_output_wp, grad_x_wp, grad_y_wp), + ) + + return [grad_x, grad_y] + + +@torch.library.register_fake("nvtensornet::tensor_matmul_o3_3x3_bwd_primitive") +def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: + return [torch.empty_like(x), torch.empty_like(y)] + + +@torch.library.custom_op( + "nvtensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor +) -> List[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + + stream = get_stream(grad_output.device) + device = wp.device_from_torch(grad_output.device) + grad_x = torch.empty_like(grad_output) + grad_y = torch.empty_like(grad_output) + + grad_grad_output = torch.empty_like(grad_output) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) 
+ y_wp = wp.from_torch(y.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + tensor_matmul_o3_3x3_bwd_bwd = get_module( + "tensor_matmul_o3_3x3_bwd_bwd", [str(grad_output.dtype)] + ) + wp.launch( + tensor_matmul_o3_3x3_bwd_bwd, + dim=(grad_output.shape[0], grad_output.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_output_wp, + grad_x_wp, + grad_y_wp, + grad_grad_output_wp, + ), + ) + + return [grad_grad_output, grad_x, grad_y] + + +@torch.library.register_fake("nvtensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive") +def _( + grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor +) -> List[Tensor]: + return [ + torch.empty_like(grad_output), + torch.empty_like(grad_output), + torch.empty_like(grad_output), + ] + + +def tensor_matmul_o3_3x3_setup_fwd_context(ctx, inputs, output): + (x, y) = inputs + ctx.save_for_backward(x, y) + + +def tensor_matmul_o3_3x3_setup_bwd_context(ctx, inputs, output): + (grad_output, x, y) = inputs + ctx.save_for_backward(grad_output, x, y) + + +@torch.compiler.allow_in_graph +def tensor_matmul_o3_3x3_fwd(*args): + return getattr(torch.ops.nvtensornet, "tensor_matmul_o3_3x3_fwd_primitive")( + *args + ) + + +@torch.compiler.allow_in_graph +def tensor_matmul_o3_3x3_bwd(ctx, grad_output): + x, y = ctx.saved_tensors + dx, dy = getattr(torch.ops.nvtensornet, "tensor_matmul_o3_3x3_bwd_primitive")( + grad_output, x, y + ) + return dx, dy + + +@torch.compiler.allow_in_graph +def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): + grad_grad_x = grad_outputs[0][0] + grad_grad_y = grad_outputs[0][1] + + grad_output_saved, x, y = ctx.saved_tensors + + outputs = getattr( + torch.ops.nvtensornet, "tensor_matmul_o3_3x3_bwd_bwd_primitive" + )(grad_output_saved, grad_grad_x, grad_grad_y, x, y) + return outputs[0], outputs[1], outputs[2] + + +torch.library.register_autograd( + "nvtensornet::tensor_matmul_o3_3x3_fwd_primitive", + tensor_matmul_o3_3x3_bwd, + setup_context=tensor_matmul_o3_3x3_setup_fwd_context, +) + +torch.library.register_autograd( + "nvtensornet::tensor_matmul_o3_3x3_bwd_primitive", + tensor_matmul_o3_3x3_bwd_bwd, + setup_context=tensor_matmul_o3_3x3_setup_bwd_context, +) + + +def fn_tensor_matmul_o3_3x3(x: Tensor, y: Tensor) -> Tensor: + z = getattr(torch.ops.nvtensornet, "tensor_matmul_o3_3x3_fwd_primitive")(x, y) + return z diff --git a/src/matgl/ops/equivariant_so3_matmul.py b/src/matgl/ops/equivariant_so3_matmul.py new file mode 100644 index 00000000..16d61748 --- /dev/null +++ b/src/matgl/ops/equivariant_so3_matmul.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import List + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::tensor_matmul_so3_3x3_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor, y: Tensor) -> Tensor: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty_like(x) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + tensor_matmul_so3_3x3_fwd = get_module("tensor_matmul_so3_3x3_fwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_so3_3x3_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("nvtensornet::tensor_matmul_so3_3x3_fwd_primitive") +def _(x: Tensor, y: Tensor) -> Tensor: + return torch.empty_like(x) + + +@torch.library.custom_op( + "nvtensornet::tensor_matmul_so3_3x3_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + grad_y = torch.empty_like(y) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + tensor_matmul_so3_3x3_bwd = get_module("tensor_matmul_so3_3x3_bwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_so3_3x3_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, grad_output_wp, grad_x_wp, grad_y_wp), + ) + + return [grad_x, grad_y] + + +@torch.library.register_fake("nvtensornet::tensor_matmul_so3_3x3_bwd_primitive") +def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: + return [torch.empty_like(x), 
torch.empty_like(y)] + + +@torch.library.custom_op( + "nvtensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor +) -> List[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + stream = get_stream(grad_output.device) + device = wp.device_from_torch(grad_output.device) + grad_x = torch.empty_like(grad_output) + grad_y = torch.empty_like(grad_output) + + grad_grad_output = torch.empty_like(grad_output) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + tensor_matmul_so3_3x3_bwd_bwd = get_module( + "tensor_matmul_so3_3x3_bwd_bwd", [str(grad_output.dtype)] + ) + wp.launch( + tensor_matmul_so3_3x3_bwd_bwd, + dim=(grad_output.shape[0], grad_output.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_output_wp, + grad_x_wp, + grad_y_wp, + grad_grad_output_wp, + ), + ) + + return [grad_grad_output, grad_x, grad_y] + + +@torch.library.register_fake("nvtensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive") +def _( + grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor +) -> List[Tensor]: + return [ + torch.empty_like(grad_output), + torch.empty_like(grad_output), + torch.empty_like(grad_output), + ] + + +def tensor_matmul_so3_3x3_setup_fwd_context(ctx, inputs, output): + (x, y) = inputs + ctx.save_for_backward(x, y) + + +def tensor_matmul_so3_3x3_setup_bwd_context(ctx, inputs, output): + (grad_output, x, y) = inputs + ctx.save_for_backward(grad_output, x, y) + + +@torch.compiler.allow_in_graph +def tensor_matmul_so3_3x3_fwd(*args): + return torch.ops.nvtensornet.tensor_matmul_so3_3x3_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def tensor_matmul_so3_3x3_bwd(ctx, grad_output): + x, y = ctx.saved_tensors + dx, dy = torch.ops.nvtensornet.tensor_matmul_so3_3x3_bwd_primitive( + grad_output, x, y + ) + return dx, dy + + +@torch.compiler.allow_in_graph +def tensor_matmul_so3_3x3_bwd_bwd(ctx, *grad_outputs): + grad_grad_x = grad_outputs[0][0] + grad_grad_y = grad_outputs[0][1] + + grad_output_saved, x, y = ctx.saved_tensors + + outputs = torch.ops.nvtensornet.tensor_matmul_so3_3x3_bwd_bwd_primitive( + grad_output_saved, grad_grad_x, grad_grad_y, x, y + ) + return outputs[0], outputs[1], outputs[2] + + +torch.library.register_autograd( + "nvtensornet::tensor_matmul_so3_3x3_fwd_primitive", + tensor_matmul_so3_3x3_bwd, + setup_context=tensor_matmul_so3_3x3_setup_fwd_context, +) + +torch.library.register_autograd( + "nvtensornet::tensor_matmul_so3_3x3_bwd_primitive", + tensor_matmul_so3_3x3_bwd_bwd, + setup_context=tensor_matmul_so3_3x3_setup_bwd_context, +) + + +def fn_tensor_matmul_so3_3x3(x: Tensor, y: Tensor) -> Tensor: + z = torch.ops.nvtensornet.tensor_matmul_so3_3x3_fwd_primitive(x, y) + return z diff 
--git a/src/matgl/ops/graph_transform.py b/src/matgl/ops/graph_transform.py new file mode 100644 index 00000000..14e2e944 --- /dev/null +++ b/src/matgl/ops/graph_transform.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
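+
+# Custom torch.library ops around the Warp `count_row_col` and `convert_to_sparse`
+# kernels. `graph_transform` converts a COO `edge_index` into CSR-style structures:
+# per-node edge counts are cumsummed into row/col indptr arrays, and
+# `convert_to_sparse` fills the corresponding row/col indices and data buffers
+# used by the TensorNet message-passing ops.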
+ +from typing import Tuple + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import count_row_col, convert_to_sparse, get_stream + + +@torch.library.custom_op( + "nvtnet::count_row_col_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor]: + stream = get_stream(edge_index.device) + device = wp.device_from_torch(edge_index.device) + row_count = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + col_count = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + + edge_index_wp = wp.from_torch(edge_index, return_ctype=True) + row_count_wp = wp.from_torch(row_count, return_ctype=True) + col_count_wp = wp.from_torch(col_count, return_ctype=True) + + wp.launch( + count_row_col, + dim=(edge_index.shape[1]), + stream=stream, + device=device, + inputs=(edge_index_wp, row_count_wp, col_count_wp), + ) + + return row_count, col_count + + +@torch.library.register_fake("nvtnet::count_row_col_primitive") +def _(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor]: + output = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + output2 = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + return output, output2 + + +@torch.library.custom_op( + "nvtnet::convert_to_sparse_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + edge_index: Tensor, + row_count: Tensor, + col_count: Tensor, + row_indptr: Tensor, + col_indptr: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + stream = get_stream(edge_index.device) + device = wp.device_from_torch(edge_index.device) + edge_index_wp = wp.from_torch(edge_index, return_ctype=True) + + row_count_wp = wp.from_torch(row_count, return_ctype=True) + col_count_wp = wp.from_torch(col_count, return_ctype=True) + + row_indptr_wp = wp.from_torch(row_indptr, return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr, return_ctype=True) + + row_indices = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + col_indices = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + + row_data = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + col_data = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + + row_indices_wp = wp.from_torch(row_indices, return_ctype=True) + col_indices_wp = wp.from_torch(col_indices, return_ctype=True) + + row_data_wp = wp.from_torch(row_data, return_ctype=True) + col_data_wp = wp.from_torch(col_data, return_ctype=True) + + wp.launch( + convert_to_sparse, + dim=(edge_index.shape[1]), + stream=stream, + device=device, + inputs=( + edge_index_wp, + row_count_wp, + col_count_wp, + row_indptr_wp, + col_indptr_wp, + row_indices_wp, + col_indices_wp, + row_data_wp, + col_data_wp, + ), + ) + + return row_indices, col_indices, row_data, col_data + + +@torch.library.register_fake("nvtnet::convert_to_sparse_primitive") +def _( + edge_index: Tensor, + row_count: Tensor, + col_count: Tensor, + row_indptr: Tensor, + col_indptr: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + output = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + output2 = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + output3 = torch.empty( + edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + output4 = torch.empty( + 
edge_index.shape[1], dtype=torch.int32, device=edge_index.device + ) + return output, output2, output3, output4 + + +@torch.compiler.allow_in_graph +def graph_transform( + edge_index: Tensor, num_nodes: int +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + row_count, col_count = torch.ops.nvtnet.count_row_col_primitive( + edge_index, num_nodes + ) + row_indptr, col_indptr = torch.cumsum( + row_count, dim=0, dtype=torch.int32 + ), torch.cumsum(col_count, dim=0, dtype=torch.int32) + ( + row_indices, + col_indices, + row_data, + col_data, + ) = torch.ops.nvtnet.convert_to_sparse_primitive( + edge_index, row_count, col_count, row_indptr, col_indptr + ) + return row_data, row_indices, row_indptr, col_data, col_indices, col_indptr diff --git a/src/matgl/ops/tensor_norm3.py b/src/matgl/ops/tensor_norm3.py new file mode 100644 index 00000000..7433a7e5 --- /dev/null +++ b/src/matgl/ops/tensor_norm3.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
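+
+# Custom torch.library ops around the Warp `tensor_norm3` kernels: the forward
+# primitive reduces a per-node tensor feature x to an output of shape
+# (x.shape[0], 3 * x.shape[-1]), and the bwd / bwd_bwd primitives supply first-
+# and second-order gradients wired up via torch.library.register_autograd below.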
+ +from typing import List + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::tensor_norm3_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor) -> Tensor: + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty((x.shape[0], 3 * x.shape[-1]), dtype=x.dtype, device=x.device) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + tensor_norm3_fwd = get_module("tensor_norm3_fwd", [str(x.dtype)]) + wp.launch( + tensor_norm3_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("nvtensornet::tensor_norm3_fwd_primitive") +def _(x: Tensor) -> Tensor: + return torch.empty((x.shape[0], 3 * x.shape[-1]), dtype=x.dtype, device=x.device) + + +@torch.library.custom_op( + "nvtensornet::tensor_norm3_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output: Tensor, x: Tensor +) -> List[Tensor]: + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + + tensor_norm3_bwd = get_module("tensor_norm3_bwd", [str(x.dtype)]) + wp.launch( + tensor_norm3_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_output_wp, x_wp, grad_x_wp), + ) + + return [grad_x] + + +@torch.library.register_fake("nvtensornet::tensor_norm3_bwd_primitive") +def _( + grad_output: Tensor, x: Tensor +) -> List[Tensor]: + return [torch.empty_like(x)] + + +@torch.library.custom_op( + "nvtensornet::tensor_norm3_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_grad_x: Tensor, +) -> Tensor: + stream = get_stream(grad_grad_x.device) + device = wp.device_from_torch(grad_grad_x.device) + grad_grad_output = torch.empty((grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), dtype=grad_grad_x.dtype, device=grad_grad_x.device) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + tensor_norm3_bwd_bwd = get_module("tensor_norm3_bwd_bwd", [str(grad_grad_x.dtype)]) + wp.launch( + tensor_norm3_bwd_bwd, + dim=(grad_grad_x.shape[0], grad_grad_x.shape[-1]), + stream=stream, + device=device, + inputs=( + grad_grad_x_wp, + grad_grad_output_wp, + ), + ) + + return grad_grad_output + + +@torch.library.register_fake("nvtensornet::tensor_norm3_bwd_bwd_primitive") +def _( + grad_grad_x: Tensor, +) -> Tensor: + return torch.empty((grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), dtype=grad_grad_x.dtype, device=grad_grad_x.device) + + +def tensor_norm3_fwd_setup_context(ctx, inputs, output): + (x,) = inputs # Unpack the single input tensor + ctx.save_for_backward(x) + + +def tensor_norm3_bwd_setup_context(ctx, inputs, output): + (grad_output, x) = inputs + ctx.save_for_backward(x) + + +@torch.compiler.allow_in_graph +def tensor_norm3_fwd(*args): + return torch.ops.nvtensornet.tensor_norm3_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def tensor_norm3_bwd(ctx, grad_output): + (x,) = ctx.saved_tensors + dx = 
torch.ops.nvtensornet.tensor_norm3_bwd_primitive( + grad_output, x + ) + return dx[0] + + +@torch.compiler.allow_in_graph +def tensor_norm3_bwd_bwd(ctx, grad_grad_x): + (x,) = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + + grad_grad_output = torch.ops.nvtensornet.tensor_norm3_bwd_bwd_primitive( + grad_grad_x + ) + + return grad_grad_output + + +torch.library.register_autograd( + "nvtensornet::tensor_norm3_fwd_primitive", + tensor_norm3_bwd, + setup_context=tensor_norm3_fwd_setup_context, +) + +torch.library.register_autograd( + "nvtensornet::tensor_norm3_bwd_primitive", + tensor_norm3_bwd_bwd, + setup_context=tensor_norm3_bwd_setup_context, +) + + +def fn_tensor_norm3(x: Tensor) -> Tensor: + return torch.ops.nvtensornet.tensor_norm3_fwd_primitive(x) diff --git a/src/matgl/ops/tensornet_mp.py b/src/matgl/ops/tensornet_mp.py new file mode 100644 index 00000000..a8829d7a --- /dev/null +++ b/src/matgl/ops/tensornet_mp.py @@ -0,0 +1,554 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
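+
+# Custom torch.library ops around the Warp TensorNet message-passing kernels.
+# The forward primitive combines the per-node features (x, y, z) with the
+# per-edge edge_attr over the CSR neighbor structure produced by graph_transform;
+# the bwd and bwd_bwd primitives implement the first- and second-order gradients
+# registered through torch.library.register_autograd below.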
+ +from typing import List + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::message_passing_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output_x = torch.empty_like(x) + output_y = torch.empty_like(y) + output_z = torch.empty_like(z) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + output_x_wp = wp.from_torch(output_x.detach(), return_ctype=True) + output_y_wp = wp.from_torch(output_y.detach(), return_ctype=True) + output_z_wp = wp.from_torch(output_z.detach(), return_ctype=True) + + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_fwd = get_module("message_passing_fwd", [str(x.dtype)]) + wp.launch( + message_passing_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + row_data_wp, + row_indices_wp, + row_indptr_wp, + output_x_wp, + output_y_wp, + output_z_wp, + ), + ) + + return [output_x, output_y, output_z] + + +@torch.library.register_fake("nvtensornet::message_passing_fwd_primitive") +def _( + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] + + +@torch.library.custom_op( + "nvtensornet::message_passing_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + grad_y = torch.empty_like(y) + grad_z = torch.empty_like(z) + + grad_edge_attr = torch.zeros_like(edge_attr) + + grad_output_x_wp = wp.from_torch(grad_output_x.detach(), return_ctype=True) + grad_output_y_wp = wp.from_torch(grad_output_y.detach(), return_ctype=True) + grad_output_z_wp = wp.from_torch(grad_output_z.detach(), return_ctype=True) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_z_wp = 
wp.from_torch(grad_z.detach(), return_ctype=True) + grad_edge_attr_wp = wp.from_torch(grad_edge_attr.detach(), return_ctype=True) + + message_passing_bwd = get_module("message_passing_bwd", [str(x.dtype)]) + + wp.launch( + message_passing_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + grad_output_x_wp, + grad_output_y_wp, + grad_output_z_wp, + row_data_wp, + row_indices_wp, + row_indptr_wp, + grad_x_wp, + grad_y_wp, + grad_z_wp, + grad_edge_attr_wp, + ), + ) + + return [grad_x, grad_y, grad_z, grad_edge_attr] + + +@torch.library.register_fake("nvtensornet::message_passing_bwd_primitive") +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + return [ + torch.empty_like(x), + torch.empty_like(y), + torch.empty_like(z), + torch.empty_like(edge_attr), + ] + + +@torch.library.custom_op( + "nvtensornet::message_passing_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + grad_grad_edge_attr: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_grad_z_wp = wp.from_torch(grad_grad_z.detach(), return_ctype=True) + grad_grad_edge_attr_wp = wp.from_torch( + grad_grad_edge_attr.detach(), return_ctype=True + ) + grad_output_x_wp = wp.from_torch(grad_output_x.detach(), return_ctype=True) + grad_output_y_wp = wp.from_torch(grad_output_y.detach(), return_ctype=True) + grad_output_z_wp = wp.from_torch(grad_output_z.detach(), return_ctype=True) + + dgrad_output_x = torch.empty_like(grad_output_x) + dgrad_output_y = torch.empty_like(grad_output_y) + dgrad_output_z = torch.empty_like(grad_output_z) + + dgrad_x = torch.empty_like(x) + dgrad_y = torch.empty_like(y) + dgrad_z = torch.empty_like(z) + + dgrad_edge_attr = torch.empty_like(edge_attr) + + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + dgrad_x_wp = wp.from_torch(dgrad_x.detach(), return_ctype=True) + dgrad_y_wp = wp.from_torch(dgrad_y.detach(), return_ctype=True) + dgrad_z_wp = wp.from_torch(dgrad_z.detach(), return_ctype=True) + + dgrad_edge_attr_wp = wp.from_torch(dgrad_edge_attr.detach(), return_ctype=True) + + dgrad_output_x_wp = wp.from_torch(dgrad_output_x.detach(), return_ctype=True) + dgrad_output_y_wp = wp.from_torch(dgrad_output_y.detach(), return_ctype=True) + dgrad_output_z_wp = wp.from_torch(dgrad_output_z.detach(), return_ctype=True) + + col_data_wp = wp.from_torch(col_data.detach(), return_ctype=True) + col_indices_wp = wp.from_torch(col_indices.detach(), return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr.detach(), return_ctype=True) + + row_data_wp = 
wp.from_torch(row_data.detach(), return_ctype=True) + row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_bwd_bwd = get_module("message_passing_bwd_bwd", [str(x.dtype)]) + + wp.launch( + message_passing_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_grad_z_wp, + grad_grad_edge_attr_wp, + grad_output_x_wp, + grad_output_y_wp, + grad_output_z_wp, + row_data_wp, + row_indices_wp, + row_indptr_wp, + col_data_wp, + col_indices_wp, + col_indptr_wp, + dgrad_x_wp, + dgrad_y_wp, + dgrad_z_wp, + dgrad_edge_attr_wp, + dgrad_output_x_wp, + dgrad_output_y_wp, + dgrad_output_z_wp, + ), + ) + return [ + dgrad_output_x, + dgrad_output_y, + dgrad_output_z, + dgrad_x, + dgrad_y, + dgrad_z, + dgrad_edge_attr, + ] + + +@torch.library.register_fake("nvtensornet::message_passing_bwd_bwd_primitive") +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + grad_grad_edge_attr: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + return [ + torch.empty_like(grad_output_x), + torch.empty_like(grad_output_y), + torch.empty_like(grad_output_z), + torch.empty_like(grad_grad_x), + torch.empty_like(grad_grad_y), + torch.empty_like(grad_grad_z), + torch.empty_like(grad_grad_edge_attr), + ] + + +def message_passing_setup_fwd_context(ctx, inputs, output): + ( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = inputs + ctx.save_for_backward( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + +def message_passing_setup_bwd_context(ctx, inputs, output): + ( + grad_output_x, + grad_output_y, + grad_output_z, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = inputs + ctx.save_for_backward( + grad_output_x, + grad_output_y, + grad_output_z, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + +@torch.compiler.allow_in_graph +def message_passing_fwd(*args): + return torch.ops.nvtensornet.message_passing_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def message_passing_bwd(ctx, grad_outputs): + ( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = ctx.saved_tensors + + result = torch.ops.nvtensornet.message_passing_bwd_primitive( + grad_outputs[0], + grad_outputs[1], + grad_outputs[2], + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + grad_x, grad_y, grad_z, grad_edge_attr = result + + return grad_x, grad_y, grad_z, grad_edge_attr, None, None, None, None, None, None + + +@torch.compiler.allow_in_graph +def message_passing_bwd_bwd(ctx, *grad_outputs): + grad_grad_x, grad_grad_y, grad_grad_z, grad_grad_edge_attr = grad_outputs[0] + + ( + grad_output_x, + grad_output_y, + grad_output_z, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = ctx.saved_tensors + + result = 
torch.ops.nvtensornet.message_passing_bwd_bwd_primitive( + grad_output_x, + grad_output_y, + grad_output_z, + grad_grad_x, + grad_grad_y, + grad_grad_z, + grad_grad_edge_attr, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + return ( + result[0], + result[1], + result[2], + result[3], + result[4], + result[5], + result[6], + None, + None, + None, + None, + None, + None, + ) + + +torch.library.register_autograd( + "nvtensornet::message_passing_fwd_primitive", + message_passing_bwd, + setup_context=message_passing_setup_fwd_context, +) + +torch.library.register_autograd( + "nvtensornet::message_passing_bwd_primitive", + message_passing_bwd_bwd, + setup_context=message_passing_setup_bwd_context, +) + + +def fn_message_passing( + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> List[Tensor]: + return torch.ops.nvtensornet.message_passing_fwd_primitive( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) diff --git a/src/matgl/ops/tensornet_radial_mp.py b/src/matgl/ops/tensornet_radial_mp.py new file mode 100644 index 00000000..524d98ee --- /dev/null +++ b/src/matgl/ops/tensornet_radial_mp.py @@ -0,0 +1,418 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
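+
+# Custom torch.library ops around the Warp radial message-passing kernels: from
+# the edge vectors (edge_vec_norm) and per-edge radial features (edge_attr), the
+# forward primitive accumulates per-atom outputs I, A and S of shapes
+# (num_atoms, 1, F), (num_atoms, 3, F) and (num_atoms, 5, F), with
+# F = edge_attr.shape[-1]; the bwd and bwd_bwd primitives provide the matching
+# first- and second-order gradients.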
+ +from typing import List + +import torch +from torch import Tensor + +import warp as wp + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "nvtensornet::radial_message_passing_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor +) -> List[Tensor]: + + num_atoms = row_indptr.shape[0] - 1 + stream = get_stream(edge_vec_norm.device) + device = wp.device_from_torch(edge_vec_norm.device) + output_I = torch.zeros( + (num_atoms, 1, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ) + output_A = torch.zeros( + (num_atoms, 3, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ) + output_S = torch.zeros( + (num_atoms, 5, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ) + + output_I_wp = wp.from_torch(output_I.detach(), return_ctype=True) + output_A_wp = wp.from_torch(output_A.detach(), return_ctype=True) + output_S_wp = wp.from_torch(output_S.detach(), return_ctype=True) + + edge_vec_norm_wp = wp.from_torch(edge_vec_norm.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_fwd = get_module( + "radial_message_passing_fwd", [str(edge_vec_norm.dtype)] + ) + wp.launch( + message_passing_fwd, + dim=(num_atoms, edge_attr.shape[-1]), + stream=stream, + device=device, + inputs=( + edge_vec_norm_wp, + edge_attr_wp, + row_data_wp, + row_indptr_wp, + output_I_wp, + output_A_wp, + output_S_wp, + ), + ) + + return [output_I, output_A, output_S] + + +@torch.library.register_fake("nvtensornet::radial_message_passing_fwd_primitive") +def _( + edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor +) -> List[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + return [ + torch.empty( + (num_atoms, 1, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ), + torch.empty( + (num_atoms, 3, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ), + torch.empty( + (num_atoms, 5, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ), + ] + + +@torch.library.custom_op( + "nvtensornet::radial_message_passing_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> List[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + stream = get_stream(grad_output_I.device) + device = wp.device_from_torch(grad_output_I.device) + + grad_output_I_wp = wp.from_torch(grad_output_I.detach(), return_ctype=True) + grad_output_A_wp = wp.from_torch(grad_output_A.detach(), return_ctype=True) + grad_output_S_wp = wp.from_torch(grad_output_S.detach(), return_ctype=True) + + grad_edge_vec_norm = torch.zeros_like(edge_vec_norm) + grad_edge_vec_norm_wp = wp.from_torch( + grad_edge_vec_norm.detach(), return_ctype=True + ) + + grad_edge_attr = torch.zeros_like(edge_attr) + grad_edge_attr_wp = wp.from_torch(grad_edge_attr.detach(), return_ctype=True) + + edge_vec_norm_wp = wp.from_torch(edge_vec_norm.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp 
= wp.from_torch(row_data.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_bwd = get_module( + "radial_message_passing_bwd", [str(edge_vec_norm.dtype)] + ) + wp.launch( + message_passing_bwd, + dim=(num_atoms, edge_attr.shape[-1]), + stream=stream, + device=device, + inputs=( + edge_vec_norm_wp, + edge_attr_wp, + row_data_wp, + row_indptr_wp, + grad_output_I_wp, + grad_output_A_wp, + grad_output_S_wp, + grad_edge_vec_norm_wp, + grad_edge_attr_wp, + ), + ) + + return [grad_edge_vec_norm, grad_edge_attr] + + +@torch.library.register_fake("nvtensornet::radial_message_passing_bwd_primitive") +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> List[Tensor]: + return [torch.empty_like(edge_vec_norm), torch.empty_like(edge_attr)] + + +@torch.library.custom_op( + "nvtensornet::radial_message_passing_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + grad_grad_edge_vec_norm: Tensor, + grad_grad_edge_attr: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> List[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + stream = get_stream(grad_output_I.device) + device = wp.device_from_torch(grad_output_I.device) + + edge_vec_norm_wp = wp.from_torch(edge_vec_norm.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + grad_grad_edge_vec_norm_wp = wp.from_torch( + grad_grad_edge_vec_norm.detach(), return_ctype=True + ) + grad_grad_edge_attr_wp = wp.from_torch( + grad_grad_edge_attr.detach(), return_ctype=True + ) + + grad_output_I_wp = wp.from_torch(grad_output_I.detach(), return_ctype=True) + grad_output_A_wp = wp.from_torch(grad_output_A.detach(), return_ctype=True) + grad_output_S_wp = wp.from_torch(grad_output_S.detach(), return_ctype=True) + dgrad_output_I = torch.zeros_like(grad_output_I) + dgrad_output_A = torch.zeros_like(grad_output_A) + dgrad_output_S = torch.zeros_like(grad_output_S) + dgrad_output_I_wp = wp.from_torch(dgrad_output_I.detach(), return_ctype=True) + dgrad_output_A_wp = wp.from_torch(dgrad_output_A.detach(), return_ctype=True) + dgrad_output_S_wp = wp.from_torch(dgrad_output_S.detach(), return_ctype=True) + + dgrad_grad_edge_vec_norm = torch.zeros_like(grad_grad_edge_vec_norm) + dgrad_grad_edge_vec_norm_wp = wp.from_torch( + dgrad_grad_edge_vec_norm.detach(), return_ctype=True + ) + + dgrad_grad_edge_attr = torch.zeros_like(grad_grad_edge_attr) + dgrad_grad_edge_attr_wp = wp.from_torch( + dgrad_grad_edge_attr.detach(), return_ctype=True + ) + + message_passing_bwd_bwd = get_module( + "radial_message_passing_bwd_bwd", [str(edge_vec_norm.dtype)] + ) + wp.launch( + message_passing_bwd_bwd, + dim=(num_atoms, edge_attr.shape[-1]), + stream=stream, + device=device, + inputs=( + edge_vec_norm_wp, + edge_attr_wp, + grad_grad_edge_vec_norm_wp, + grad_grad_edge_attr_wp, + grad_output_I_wp, + grad_output_A_wp, + grad_output_S_wp, + row_data_wp, + row_indptr_wp, + dgrad_grad_edge_vec_norm_wp, + dgrad_grad_edge_attr_wp, + dgrad_output_I_wp, + dgrad_output_A_wp, + dgrad_output_S_wp, + ), + ) + + return [ + dgrad_output_I, + dgrad_output_A, + dgrad_output_S, + 
dgrad_grad_edge_vec_norm, + dgrad_grad_edge_attr, + ] + + +@torch.library.register_fake( + "nvtensornet::radial_message_passing_bwd_bwd_primitive" +) +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_grad_edge_vec_norm: Tensor, + grad_grad_edge_attr: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> List[Tensor]: + return [ + torch.empty_like(grad_output_I), + torch.empty_like(grad_output_A), + torch.empty_like(grad_grad_edge_vec_norm), + torch.empty_like(grad_grad_edge_attr), + ] + + +def radial_message_passing_setup_fwd_context(ctx, inputs, output): + (edge_vec_norm, edge_attr, row_data, row_indptr) = inputs + ctx.save_for_backward(edge_vec_norm, edge_attr, row_data, row_indptr) + + +def radial_message_passing_setup_bwd_context(ctx, inputs, output): + ( + grad_output_I, + grad_output_A, + grad_output_S, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) = inputs + ctx.save_for_backward( + grad_output_I, + grad_output_A, + grad_output_S, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) + + +@torch.compiler.allow_in_graph +def radial_message_passing_fwd(*args): + return torch.ops.nvtensornet.radial_message_passing_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def radial_message_passing_bwd(ctx, grad_outputs): + edge_vec_norm, edge_attr, row_data, row_indptr = ctx.saved_tensors + + result = torch.ops.nvtensornet.radial_message_passing_bwd_primitive( + grad_outputs[0], + grad_outputs[1], + grad_outputs[2], + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) + + grad_edge_vec_norm, grad_edge_attr = result + + return grad_edge_vec_norm, grad_edge_attr, None, None + + +@torch.compiler.allow_in_graph +def radial_message_passing_bwd_bwd(ctx, *grad_outputs): + + grad_grad_edge_vec_norm, grad_grad_edge_attr = grad_outputs[0] + ( + grad_output_I, + grad_output_A, + grad_output_S, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) = ctx.saved_tensors + + result = torch.ops.nvtensornet.radial_message_passing_bwd_bwd_primitive( + grad_output_I, + grad_output_A, + grad_output_S, + grad_grad_edge_vec_norm, + grad_grad_edge_attr, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) + + ( + dgrad_output_I, + dgrad_output_A, + dgrad_output_S, + dgrad_grad_edge_vec_norm, + dgrad_grad_edge_attr, + ) = result + + return ( + dgrad_output_I, + dgrad_output_A, + dgrad_output_S, + dgrad_grad_edge_vec_norm, + dgrad_grad_edge_attr, + None, + None, + ) + + +torch.library.register_autograd( + "nvtensornet::radial_message_passing_fwd_primitive", + radial_message_passing_bwd, + setup_context=radial_message_passing_setup_fwd_context, +) + +torch.library.register_autograd( + "nvtensornet::radial_message_passing_bwd_primitive", + radial_message_passing_bwd_bwd, + setup_context=radial_message_passing_setup_bwd_context, +) + + +def fn_radial_message_passing( + edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor +) -> List[Tensor]: + return torch.ops.nvtensornet.radial_message_passing_fwd_primitive( + edge_vec_norm, edge_attr, row_data, row_indptr + ) From 483125440db8de31e86ec7aeb7f470eac1df973b Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Wed, 7 Jan 2026 16:47:50 -0500 Subject: [PATCH 02/18] fix model Signed-off-by: Roman Zubatyuk --- dev/test_model_forward_backward.py | 334 ++++++++++++++++++++ src/matgl/kernels/equivariant_so3_matmul.py | 2 +- src/matgl/kernels/tensor_norm3.py | 111 +++++-- src/matgl/kernels/tensornet_mp.py | 277 ++++++++-------- 
src/matgl/models/_tensornet_pyg.py | 99 ++++-- src/matgl/ops/compose_tensor.py | 24 +- src/matgl/ops/decompose_tensor.py | 24 +- src/matgl/ops/equivariant_o3_matmul.py | 24 +- src/matgl/ops/equivariant_so3_matmul.py | 24 +- src/matgl/ops/tensor_norm3.py | 76 +++-- src/matgl/ops/tensornet_mp.py | 92 ++++-- src/matgl/ops/tensornet_radial_mp.py | 24 +- 12 files changed, 794 insertions(+), 317 deletions(-) create mode 100644 dev/test_model_forward_backward.py diff --git a/dev/test_model_forward_backward.py b/dev/test_model_forward_backward.py new file mode 100644 index 00000000..a38864e5 --- /dev/null +++ b/dev/test_model_forward_backward.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +"""Compare forward/backward/double-backward between matgl-main and current TensorNet.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any + +import torch +from pymatgen.core import Structure + + +# ============================================================================= +# Configuration +# ============================================================================= + +DEFAULT_MATGL_MAIN_PATH = str(Path(__file__).parent.parent / "matgl-main" / "src") + +MODEL_CONFIG = { + "units": 64, + "nblocks": 2, + "num_rbf": 32, + "cutoff": 5.0, + "rbf_type": "Gaussian", + "activation_type": "swish", + "equivariance_invariance_group": "O(3)", + "is_intensive": False, + "ntargets": 1, +} + + +# ============================================================================= +# Utilities +# ============================================================================= + +def clear_matgl_modules() -> None: + """Remove all matgl modules from sys.modules.""" + for mod in [k for k in sys.modules if k.startswith("matgl")]: + del sys.modules[mod] + + +def print_section(title: str) -> None: + """Print a section header.""" + print(f"\n{'=' * 70}\n{title}\n{'=' * 70}") + + +def load_structure(path: str) -> Structure: + """Load structure from file using pymatgen.""" + return Structure.from_file(path) + + +def get_element_types(structure: Structure) -> tuple[str, ...]: + """Extract sorted unique element symbols from structure.""" + return tuple(sorted({site.specie.symbol for site in structure})) + + +def build_graph( + converter: Any, + structure: Structure, + device: torch.device, + compute_bond: Any = None, + requires_grad: bool = False, +) -> Any: + """Build graph from structure.""" + graph, lat, _ = converter.get_graph(structure) + pos = graph.frac_coords @ lat[0] + graph.pos = pos.clone().detach().requires_grad_(requires_grad) if requires_grad else pos + graph.pbc_offshift = graph.pbc_offset @ lat[0] + + if compute_bond is not None: + bond_vec, bond_dist = compute_bond(graph) + graph.bond_vec = bond_vec + graph.bond_dist = bond_dist + + return graph.to(device) + + +# ============================================================================= +# Comparison Functions +# ============================================================================= + +def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-6) -> bool: + """Compare two tensors, return True if matching.""" + if t1.shape != t2.shape: + print(f" {name}: SHAPE MISMATCH {t1.shape} vs {t2.shape}") + return False + + if torch.allclose(t1, t2, atol=atol): + print(f" {name}: MATCH") + return True + + diff = (t1 - t2).abs() + print(f" {name}: DIFF (max={diff.max():.2e}, mean={diff.mean():.2e})") + return False + + +def compare_weights(ref_model: Any, cur_model: Any) -> bool: + """Compare model weights, handling distance_proj1/2/3 -> distance_proj mapping.""" + print_section("Weight Comparison") + + ref_sd, cur_sd = ref_model.state_dict(), cur_model.state_dict() + all_match = True + + # Handle merged distance_proj layers + dp_keys = [f"tensor_embedding.distance_proj{i}" for i in range(1, 4)] + if f"{dp_keys[0]}.weight" in ref_sd: + ref_w = torch.cat([ref_sd[f"{k}.weight"] for k in dp_keys], dim=0) + ref_b = torch.cat([ref_sd[f"{k}.bias"] for k in dp_keys], dim=0) + + print("\n--- distance_proj (merged) ---") + all_match &= compare_tensors("weight", ref_w, cur_sd["tensor_embedding.distance_proj.weight"]) + all_match &= compare_tensors("bias", ref_b, 
cur_sd["tensor_embedding.distance_proj.bias"]) + + # Compare remaining parameters + skip = {f"{k}.{p}" for k in dp_keys for p in ("weight", "bias")} + print("\n--- Other Parameters ---") + + for key in sorted(cur_sd): + if "distance_proj" in key: + continue + if key in ref_sd: + all_match &= compare_tensors(key, ref_sd[key], cur_sd[key]) + else: + print(f" {key}: NOT IN REFERENCE") + + for key in sorted(ref_sd): + if key not in skip and key not in cur_sd: + print(f" {key}: IN REFERENCE ONLY") + all_match = False + + print(f"\n{'=' * 70}\nResult: {'ALL MATCH' if all_match else 'MISMATCH'}") + return all_match + + +def compare_forward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device +) -> bool: + """Compare forward pass outputs.""" + print_section("Forward Pass") + + ref_model.eval() + cur_model.eval() + state_attr = torch.tensor([0.0, 0.0], device=device) + + ref_e = ref_model(g=ref_graph, state_attr=state_attr) + cur_e = cur_model(g=cur_graph, state_attr=state_attr) + diff = abs(float(ref_e) - float(cur_e)) + + print(f"Reference: {float(ref_e):.10f}") + print(f"Current: {float(cur_e):.10f}") + print(f"Diff: {diff:.2e}") + + match = diff < 1e-5 + print(f"Result: {'PASS' if match else 'FAIL'}") + return match + + +def compare_backward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device +) -> tuple[bool, torch.Tensor, torch.Tensor, Any, Any]: + """Compare backward pass (forces = -dE/dpos).""" + print_section("Backward Pass (Forces)") + + ref_model.train() + cur_model.train() + state_attr = torch.tensor([0.0, 0.0], device=device) + + def get_forces(model, graph): + energy = model(g=graph, state_attr=state_attr) + return -torch.autograd.grad(energy, graph.pos, create_graph=True, retain_graph=True)[0] + + ref_f = get_forces(ref_model, ref_graph) + cur_f = get_forces(cur_model, cur_graph) + + print(f"Reference: mean={ref_f.mean():.6f}, std={ref_f.std():.6f}") + print(f"Current: mean={cur_f.mean():.6f}, std={cur_f.std():.6f}") + + diff = (ref_f - cur_f).abs() + print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") + + match = diff.max().item() < 1e-5 + print(f"Result: {'PASS' if match else 'FAIL'}") + return match, ref_f, cur_f, ref_graph, cur_graph + + +def compare_double_backward( + ref_forces: torch.Tensor, cur_forces: torch.Tensor, ref_graph: Any, cur_graph: Any +) -> bool: + """Compare Hessian-vector product: d(F·v)/dpos.""" + print_section("Double Backward (Hessian-Vector Product)") + + torch.manual_seed(123) + v = torch.randn_like(ref_forces) + + ref_Hv = torch.autograd.grad((ref_forces * v).sum(), ref_graph.pos, retain_graph=True)[0] + cur_Hv = torch.autograd.grad((cur_forces * v).sum(), cur_graph.pos, retain_graph=True)[0] + + print(f"Reference: mean={ref_Hv.mean():.6f}, std={ref_Hv.std():.6f}") + print(f"Current: mean={cur_Hv.mean():.6f}, std={cur_Hv.std():.6f}") + + if ref_Hv.abs().max() < 1e-10 or cur_Hv.abs().max() < 1e-10: + print("WARNING: Hessian-vector product is nearly zero") + + diff = (ref_Hv - cur_Hv).abs() + print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") + + match = diff.max().item() < 1e-4 + print(f"Result: {'PASS' if match else 'FAIL'}") + return match + + +# ============================================================================= +# Main +# ============================================================================= + +def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: + """Run all comparison tests.""" + print_section("TensorNet Comparison: 
matgl-main vs Current") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Seed: {seed}, Device: {device}") + print(f"matgl-main path: {matgl_main_path}") + + structure = load_structure(structure_path) + element_types = get_element_types(structure) + print(f"Structure: {structure_path} ({len(structure)} atoms, elements: {element_types})") + + model_config = {**MODEL_CONFIG, "element_types": element_types} + + # Load reference model (matgl-main) + clear_matgl_modules() + sys.path.insert(0, matgl_main_path) + + from matgl.models._tensornet_pyg import TensorNet as RefTensorNet + from matgl.ext._pymatgen_pyg import Structure2Graph as RefConverter + from matgl.graph._compute_pyg import compute_pair_vector_and_distance as ref_compute_bond + + torch.manual_seed(seed) + ref_model = RefTensorNet(**model_config).to(device) + ref_converter = RefConverter(element_types=element_types, cutoff=MODEL_CONFIG["cutoff"]) + + ref_graph = build_graph(ref_converter, structure, device, ref_compute_bond) + ref_graph_grad = build_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + + sys.path.pop(0) + + # Load current model (src) + clear_matgl_modules() + + from matgl.models._tensornet_pyg import TensorNet as CurTensorNet + from matgl.ext._pymatgen_pyg import Structure2Graph as CurConverter + + torch.manual_seed(seed) + cur_model = CurTensorNet(**model_config).to(device) + cur_converter = CurConverter(element_types=element_types, cutoff=MODEL_CONFIG["cutoff"]) + + cur_graph = build_graph(cur_converter, structure, device) + cur_graph_grad = build_graph(cur_converter, structure, device, requires_grad=True) + + print(f"Models: {sum(p.numel() for p in ref_model.parameters())} params each") + + # Run comparisons + results = { + "Weights": compare_weights(ref_model, cur_model), + "Forward": compare_forward(ref_model, cur_model, ref_graph, cur_graph, device), + } + + back_ok, ref_f, cur_f, ref_g, cur_g = compare_backward( + ref_model, cur_model, ref_graph_grad, cur_graph_grad, device + ) + results["Backward"] = back_ok + results["Double Backward"] = compare_double_backward(ref_f, cur_f, ref_g, cur_g) + + # Summary + print_section("SUMMARY") + all_pass = all(results.values()) + for name, passed in results.items(): + print(f" {name}: {'PASS' if passed else 'FAIL'}") + + print(f"\n{'=' * 70}") + print("ALL TESTS PASSED" if all_pass else "SOME TESTS FAILED") + print("=" * 70) + + assert all_pass, "Model comparison tests failed" + return all_pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare TensorNet implementations") + parser.add_argument( + "--structure", "-s", + required=True, + help="Path to structure file (any format supported by pymatgen)", + ) + parser.add_argument( + "--matgl-main-path", + default=os.environ.get("MATGL_MAIN_PATH", DEFAULT_MATGL_MAIN_PATH), + help="Path to matgl-main/src (default: $MATGL_MAIN_PATH or ../matgl-main/src)", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + args = parser.parse_args() + main(structure_path=args.structure, matgl_main_path=args.matgl_main_path, seed=args.seed) diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index c50d0258..49c2a7c5 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -158,7 +158,7 @@ def tensor_matmul_so3_3x3_bwd_bwd( wp.Kernel( tensor_matmul_so3_3x3_bwd, key=f"tensor_matmul_so3_3x3_bwd_{dtype}", 
- module=wp.get_module(f"tensor_matmul_o3_3x3_bwd_{dtype}"), + module=wp.get_module(f"tensor_matmul_so3_3x3_bwd_{dtype}"), ), wp.Kernel( tensor_matmul_so3_3x3_bwd_bwd, diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py index 463089c6..3d24adda 100644 --- a/src/matgl/kernels/tensor_norm3.py +++ b/src/matgl/kernels/tensor_norm3.py @@ -47,6 +47,7 @@ def tensor_norm3_fwd( X: wp.array(ndim=4, dtype=dtype_wp), output: wp.array(ndim=2, dtype=dtype_wp), ): + """Computes I, A, S norms of 3x3 tensor: trace², antisym², sym_traceless².""" b, h = wp.tid() x00 = X[b, 0, 0, h] @@ -81,6 +82,7 @@ def tensor_norm3_bwd( X: wp.array(ndim=4, dtype=dtype_wp), grad_X: wp.array(ndim=4, dtype=dtype_wp), ): + """Backward: grad_X = d(I,A,S norms)/dX · grad_output.""" b, h = wp.tid() grad_i = grad_output[b, h] @@ -134,8 +136,12 @@ def tensor_norm3_bwd( def tensor_norm3_bwd_bwd( grad_grad_X: wp.array(ndim=4, dtype=dtype_wp), + X: wp.array(ndim=4, dtype=dtype_wp), + grad_output: wp.array(ndim=2, dtype=dtype_wp), grad_grad_output: wp.array(ndim=2, dtype=dtype_wp), + grad_x: wp.array(ndim=4, dtype=dtype_wp), ): + """Computes d(grad_X)/d(grad_output) and d(grad_X)/d(X) contracted with grad_grad_X.""" b, h = wp.tid() gg00 = grad_grad_X[b, 0, 0, h] @@ -148,32 +154,87 @@ def tensor_norm3_bwd_bwd( gg21 = grad_grad_X[b, 2, 1, h] gg22 = grad_grad_X[b, 2, 2, h] - trace_gg = gg00 + gg11 + gg22 - trace_third_gg = trace_gg / grad_grad_X.dtype(3.0) + x00 = X[b, 0, 0, h] + x01 = X[b, 0, 1, h] + x02 = X[b, 0, 2, h] + x10 = X[b, 1, 0, h] + x11 = X[b, 1, 1, h] + x12 = X[b, 1, 2, h] + x20 = X[b, 2, 0, h] + x21 = X[b, 2, 1, h] + x22 = X[b, 2, 2, h] - one_half = grad_grad_X.dtype(0.5) - one_third = grad_grad_X.dtype(1.0 / 3.0) - - norm2_i_gg = one_third * trace_gg * trace_gg - grad_grad_output[b, h] = norm2_i_gg - - diff01 = gg01 - gg10 - diff02 = gg02 - gg20 - diff12 = gg12 - gg21 - norm2_a_gg = one_half * (diff01 * diff01 + diff02 * diff02 + diff12 * diff12) - grad_grad_output[b, h + grad_grad_X.shape[3]] = norm2_a_gg - - sum01 = gg01 + gg10 - sum02 = gg02 + gg20 - sum12 = gg12 + gg21 - - dev00 = gg00 - trace_third_gg - dev11 = gg11 - trace_third_gg - dev22 = gg22 - trace_third_gg - - norm2_s_gg = one_half * (sum01 * sum01 + sum02 * sum02 + sum12 * sum12) - norm2_s_gg += dev00 * dev00 + dev11 * dev11 + dev22 * dev22 - grad_grad_output[b, h + 2 * grad_grad_X.shape[3]] = norm2_s_gg + grad_i = grad_output[b, h] + grad_a = grad_output[b, h + X.shape[3]] + grad_s = grad_output[b, h + 2 * X.shape[3]] + + trace_X = x00 + x11 + x22 + trace_gg = gg00 + gg11 + gg22 + c2_3 = X.dtype(2.0 / 3.0) + c4_3 = X.dtype(4.0 / 3.0) + + # Part 1: grad_grad_output = d(grad_X)/d(grad_output) · grad_grad_X + # I channel: (2/3) * trace(X) * trace(gg) + grad_grad_output[b, h] = c2_3 * trace_X * trace_gg + + # A channel: diff_X · diff_gg + diff01_X = x01 - x10 + diff02_X = x02 - x20 + diff12_X = x12 - x21 + diff01_gg = gg01 - gg10 + diff02_gg = gg02 - gg20 + diff12_gg = gg12 - gg21 + grad_grad_output[b, h + X.shape[3]] = diff01_X * diff01_gg + diff02_X * diff02_gg + diff12_X * diff12_gg + + # S channel: sum_X · sum_gg + dev_terms · diag_gg + trace_third_X = trace_X / X.dtype(3.0) + dev00 = x00 - trace_third_X + dev11 = x11 - trace_third_X + dev22 = x22 - trace_third_X + grad_s_term_00 = c4_3 * dev00 - c2_3 * dev11 - c2_3 * dev22 + grad_s_term_11 = c4_3 * dev11 - c2_3 * dev00 - c2_3 * dev22 + grad_s_term_22 = c4_3 * dev22 - c2_3 * dev00 - c2_3 * dev11 + sum01_X = x01 + x10 + sum02_X = x02 + x20 + sum12_X = x12 + x21 + sum01_gg = gg01 + 
gg10 + sum02_gg = gg02 + gg20 + sum12_gg = gg12 + gg21 + grad_grad_output_s = sum01_X * sum01_gg + sum02_X * sum02_gg + sum12_X * sum12_gg + grad_grad_output_s += grad_s_term_00 * gg00 + grad_s_term_11 * gg11 + grad_s_term_22 * gg22 + grad_grad_output[b, h + 2 * X.shape[3]] = grad_grad_output_s + + # Part 2: grad_x = d(grad_X)/d(X) · grad_grad_X + # I channel: (2/3) * grad_i * trace(gg) on diagonals + scalar_diag = c2_3 * grad_i * trace_gg + + # A channel: grad_a * diff_gg (antisymmetric) + antisym_01 = grad_a * diff01_gg + antisym_02 = grad_a * diff02_gg + antisym_12 = grad_a * diff12_gg + + # S channel off-diag: grad_s * sum_gg + sym_offdiag_01 = grad_s * sum01_gg + sym_offdiag_02 = grad_s * sum02_gg + sym_offdiag_12 = grad_s * sum12_gg + + # S channel diag: grad_s * (4/3 on self, -2/3 on others) + sym_diag_00 = grad_s * (c4_3 * gg00 - c2_3 * gg11 - c2_3 * gg22) + sym_diag_11 = grad_s * (c4_3 * gg11 - c2_3 * gg00 - c2_3 * gg22) + sym_diag_22 = grad_s * (c4_3 * gg22 - c2_3 * gg00 - c2_3 * gg11) + + # Diagonals + grad_x[b, 0, 0, h] = scalar_diag + sym_diag_00 + grad_x[b, 1, 1, h] = scalar_diag + sym_diag_11 + grad_x[b, 2, 2, h] = scalar_diag + sym_diag_22 + + # Off-diagonals + grad_x[b, 0, 1, h] = antisym_01 + sym_offdiag_01 + grad_x[b, 1, 0, h] = -antisym_01 + sym_offdiag_01 + grad_x[b, 0, 2, h] = antisym_02 + sym_offdiag_02 + grad_x[b, 2, 0, h] = -antisym_02 + sym_offdiag_02 + grad_x[b, 1, 2, h] = antisym_12 + sym_offdiag_12 + grad_x[b, 2, 1, h] = -antisym_12 + sym_offdiag_12 return ( wp.Kernel( diff --git a/src/matgl/kernels/tensornet_mp.py b/src/matgl/kernels/tensornet_mp.py index 01baaed4..c2a8246f 100644 --- a/src/matgl/kernels/tensornet_mp.py +++ b/src/matgl/kernels/tensornet_mp.py @@ -58,21 +58,18 @@ def message_passing_fwd( output_A_reg = vec3(I.dtype(0)) output_S_reg = vec5(I.dtype(0)) - _I = I[b, 0, h] - _A = vec3(A[b, 0, h], A[b, 1, h], A[b, 2, h]) - _S = vec5(S[b, 0, h], S[b, 1, h], S[b, 2, h], S[b, 3, h], S[b, 4, h]) - for i in range(row_indptr[b], row_indptr[b + 1]): + idx_j = row_indices[i] idx_w = row_data[i] wI = edge_attr[idx_w, 0, h] wA = edge_attr[idx_w, 1, h] wS = edge_attr[idx_w, 2, h] - output_I_reg += _I * wI + output_I_reg += I[idx_j, 0, h] * wI for j in range(3): - output_A_reg[j] += _A[j] * wA + output_A_reg[j] += A[idx_j, j, h] * wA for j in range(5): - output_S_reg[j] += _S[j] * wS + output_S_reg[j] += S[idx_j, j, h] * wS output_I[b, 0, h] = output_I_reg for j in range(3): @@ -89,9 +86,9 @@ def message_passing_bwd( doutput_I: wp.array(ndim=3, dtype=dtype_wp), doutput_A: wp.array(ndim=3, dtype=dtype_wp), doutput_S: wp.array(ndim=3, dtype=dtype_wp), - row_data: wp.array(ndim=1, dtype=wp.int32), - row_indices: wp.array(ndim=1, dtype=wp.int32), - row_indptr: wp.array(ndim=1, dtype=wp.int32), + col_data: wp.array(ndim=1, dtype=wp.int32), + col_indices: wp.array(ndim=1, dtype=wp.int32), + col_indptr: wp.array(ndim=1, dtype=wp.int32), dI: wp.array(ndim=3, dtype=dtype_wp), dA: wp.array(ndim=3, dtype=dtype_wp), dS: wp.array(ndim=3, dtype=dtype_wp), @@ -103,50 +100,30 @@ def message_passing_bwd( dA_reg = vec3(I.dtype(0.0)) dS_reg = vec5(I.dtype(0.0)) - dI_b = doutput_I[b, 0, h] - I_b = I[b, 0, h] - - dA_b0 = doutput_A[b, 0, h] - dA_b1 = doutput_A[b, 1, h] - dA_b2 = doutput_A[b, 2, h] - A_b0 = A[b, 0, h] - A_b1 = A[b, 1, h] - A_b2 = A[b, 2, h] - - dS_b0 = doutput_S[b, 0, h] - dS_b1 = doutput_S[b, 1, h] - dS_b2 = doutput_S[b, 2, h] - dS_b3 = doutput_S[b, 3, h] - dS_b4 = doutput_S[b, 4, h] - S_b0 = S[b, 0, h] - S_b1 = S[b, 1, h] - S_b2 = S[b, 2, h] - S_b3 = S[b, 3, h] - 
S_b4 = S[b, 4, h] - - for i in range(row_indptr[b], row_indptr[b + 1]): - idx_w = row_data[i] + for i in range(col_indptr[b], col_indptr[b + 1]): + idx_j = col_indices[i] + idx_w = col_data[i] wI = edge_attr[idx_w, 0, h] - wA = edge_attr[idx_w, 1, h] - wS = edge_attr[idx_w, 2, h] + doutput_I_j = doutput_I[idx_j, 0, h] + dI_reg += doutput_I_j * wI + dedge_attr[idx_w, 0, h] = doutput_I_j * I[b, 0, h] - dI_reg += dI_b * wI - dedge_attr[idx_w, 0, h] = dI_b * I_b - - dA_reg[0] += dA_b0 * wA - dA_reg[1] += dA_b1 * wA - dA_reg[2] += dA_b2 * wA - dedge_attr[idx_w, 1, h] = dA_b0 * A_b0 + dA_b1 * A_b1 + dA_b2 * A_b2 + # A + wA = edge_attr[idx_w, 1, h] + dweight_A = I.dtype(0.0) + for j in range(3): + dA_reg[j] += doutput_A[idx_j, j, h] * wA + dweight_A += doutput_A[idx_j, j, h] * A[b, j, h] + dedge_attr[idx_w, 1, h] = dweight_A - dS_reg[0] += dS_b0 * wS - dS_reg[1] += dS_b1 * wS - dS_reg[2] += dS_b2 * wS - dS_reg[3] += dS_b3 * wS - dS_reg[4] += dS_b4 * wS - dedge_attr[idx_w, 2, h] = ( - dS_b0 * S_b0 + dS_b1 * S_b1 + dS_b2 * S_b2 + dS_b3 * S_b3 + dS_b4 * S_b4 - ) + # S + wS = edge_attr[idx_w, 2, h] + dweight_S = I.dtype(0.0) + for j in range(5): + dS_reg[j] += doutput_S[idx_j, j, h] * wS + dweight_S += doutput_S[idx_j, j, h] * S[b, j, h] + dedge_attr[idx_w, 2, h] = dweight_S dI[b, 0, h] = dI_reg for j in range(3): @@ -154,11 +131,10 @@ def message_passing_bwd( for j in range(5): dS[b, j, h] = dS_reg[j] - def message_passing_bwd_bwd( + def message_passing_edge_bwd_bwd( I: wp.array(ndim=3, dtype=dtype_wp), A: wp.array(ndim=3, dtype=dtype_wp), S: wp.array(ndim=3, dtype=dtype_wp), - edge_attr: wp.array(ndim=3, dtype=dtype_wp), dI: wp.array(ndim=3, dtype=dtype_wp), dA: wp.array(ndim=3, dtype=dtype_wp), dS: wp.array(ndim=3, dtype=dtype_wp), @@ -166,9 +142,6 @@ def message_passing_bwd_bwd( doutput_I: wp.array(ndim=3, dtype=dtype_wp), doutput_A: wp.array(ndim=3, dtype=dtype_wp), doutput_S: wp.array(ndim=3, dtype=dtype_wp), - row_data: wp.array(ndim=1, dtype=wp.int32), - row_indices: wp.array(ndim=1, dtype=wp.int32), - row_indptr: wp.array(ndim=1, dtype=wp.int32), col_data: wp.array(ndim=1, dtype=wp.int32), col_indices: wp.array(ndim=1, dtype=wp.int32), col_indptr: wp.array(ndim=1, dtype=wp.int32), @@ -176,114 +149,107 @@ def message_passing_bwd_bwd( d2A: wp.array(ndim=3, dtype=dtype_wp), d2S: wp.array(ndim=3, dtype=dtype_wp), d2edge_attr: wp.array(ndim=3, dtype=dtype_wp), - d2output_I: wp.array(ndim=3, dtype=dtype_wp), - d2output_A: wp.array(ndim=3, dtype=dtype_wp), - d2output_S: wp.array(ndim=3, dtype=dtype_wp), ): + # Col-based iteration: b is source node, idx_j is destination node + # Computes d2I, d2A, d2S, d2edge_attr - no atomics needed b, h = wp.tid() - d2I_reg = d2output_I.dtype(0) - d2output_I_reg = d2output_I.dtype(0) - - d2A_reg = vec3(d2output_A.dtype(0)) - d2output_A_reg = vec3(d2output_A.dtype(0)) - - d2S_reg = vec5(d2output_S.dtype(0)) - d2output_S_reg = vec5(d2output_S.dtype(0)) + d2I_reg = I.dtype(0) + d2A_reg = vec3(I.dtype(0)) + d2S_reg = vec5(I.dtype(0)) for i in range(col_indptr[b], col_indptr[b + 1]): - idx_j = col_indices[i] + idx_j = col_indices[i] # Destination node idx_w = col_data[i] dweight_I = dedge_attr[idx_w, 0, h] - doutput_I_reg = doutput_I[idx_j, 0, h] + dweight_A = dedge_attr[idx_w, 1, h] + dweight_S = dedge_attr[idx_w, 2, h] - d2I_reg += doutput_I_reg * dweight_I + # d2I[b] = Σ dedge_attr[edge] * doutput_I[dst] + d2I_reg += doutput_I[idx_j, 0, h] * dweight_I - dweight_A = dedge_attr[idx_w, 1, h] + # d2edge_attr[edge] = dI[src] * doutput_I[dst] + d2edge_attr[idx_w, 0, h] = 
doutput_I[idx_j, 0, h] * dI[b, 0, h] + + # A + dweight_A_reg = I.dtype(0.0) for j in range(3): d2A_reg[j] += doutput_A[idx_j, j, h] * dweight_A + dweight_A_reg += doutput_A[idx_j, j, h] * dA[b, j, h] + d2edge_attr[idx_w, 1, h] = dweight_A_reg - dweight_S = dedge_attr[idx_w, 2, h] - + # S + dweight_S_reg = I.dtype(0.0) for j in range(5): d2S_reg[j] += doutput_S[idx_j, j, h] * dweight_S + dweight_S_reg += doutput_S[idx_j, j, h] * dS[b, j, h] + d2edge_attr[idx_w, 2, h] = dweight_S_reg + + d2I[b, 0, h] = d2I_reg - I_b = I[b, 0, h] - dI_b = dI[b, 0, h] - dO_I_b = doutput_I[b, 0, h] - - A_b0 = A[b, 0, h] - A_b1 = A[b, 1, h] - A_b2 = A[b, 2, h] - dA_b0 = dA[b, 0, h] - dA_b1 = dA[b, 1, h] - dA_b2 = dA[b, 2, h] - - S_b0 = S[b, 0, h] - S_b1 = S[b, 1, h] - S_b2 = S[b, 2, h] - S_b3 = S[b, 3, h] - S_b4 = S[b, 4, h] - dS_b0 = dS[b, 0, h] - dS_b1 = dS[b, 1, h] - dS_b2 = dS[b, 2, h] - dS_b3 = dS[b, 3, h] - dS_b4 = dS[b, 4, h] + for j in range(3): + d2A[b, j, h] = d2A_reg[j] + + for j in range(5): + d2S[b, j, h] = d2S_reg[j] + + def message_passing_output_bwd_bwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + dI: wp.array(ndim=3, dtype=dtype_wp), + dA: wp.array(ndim=3, dtype=dtype_wp), + dS: wp.array(ndim=3, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + d2output_I: wp.array(ndim=3, dtype=dtype_wp), + d2output_A: wp.array(ndim=3, dtype=dtype_wp), + d2output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + # Row-based iteration: b is destination node, idx_j is source node + # Computes d2output_I, d2output_A, d2output_S - no atomics needed + b, h = wp.tid() + + d2output_I_reg = I.dtype(0) + d2output_A_reg = vec3(I.dtype(0)) + d2output_S_reg = vec5(I.dtype(0)) for i in range(row_indptr[b], row_indptr[b + 1]): + idx_j = row_indices[i] # Source node idx_w = row_data[i] wI = edge_attr[idx_w, 0, h] wA = edge_attr[idx_w, 1, h] wS = edge_attr[idx_w, 2, h] - d2output_I_reg += dI_b * wI - d2output_I_reg += I_b * dedge_attr[idx_w, 0, h] - - d2edge_attr[idx_w, 0, h] = dO_I_b * dI_b - - d2output_A_reg[0] += dA_b0 * wA - d2output_A_reg[1] += dA_b1 * wA - d2output_A_reg[2] += dA_b2 * wA - d2output_A_reg[0] += A_b0 * dedge_attr[idx_w, 1, h] - d2output_A_reg[1] += A_b1 * dedge_attr[idx_w, 1, h] - d2output_A_reg[2] += A_b2 * dedge_attr[idx_w, 1, h] - - d2edge_attr[idx_w, 1, h] = ( - doutput_A[b, 0, h] * dA_b0 - + doutput_A[b, 1, h] * dA_b1 - + doutput_A[b, 2, h] * dA_b2 - ) - - d2output_S_reg[0] += dS_b0 * wS - d2output_S_reg[1] += dS_b1 * wS - d2output_S_reg[2] += dS_b2 * wS - d2output_S_reg[3] += dS_b3 * wS - d2output_S_reg[4] += dS_b4 * wS - d2output_S_reg[0] += S_b0 * dedge_attr[idx_w, 2, h] - d2output_S_reg[1] += S_b1 * dedge_attr[idx_w, 2, h] - d2output_S_reg[2] += S_b2 * dedge_attr[idx_w, 2, h] - d2output_S_reg[3] += S_b3 * dedge_attr[idx_w, 2, h] - d2output_S_reg[4] += S_b4 * dedge_attr[idx_w, 2, h] - - d2edge_attr[idx_w, 2, h] = ( - doutput_S[b, 0, h] * dS_b0 - + doutput_S[b, 1, h] * dS_b1 - + doutput_S[b, 2, h] * dS_b2 - + doutput_S[b, 3, h] * dS_b3 - + doutput_S[b, 4, h] * dS_b4 - ) + dweight_I = dedge_attr[idx_w, 0, h] + dweight_A = dedge_attr[idx_w, 1, h] + dweight_S = dedge_attr[idx_w, 2, h] + + # d2output_I[b] = Σ (dI[src] * edge_attr + I[src] * dedge_attr) + d2output_I_reg += dI[idx_j, 0, h] * wI + d2output_I_reg += I[idx_j, 0, h] * 
dweight_I + + # A + for j in range(3): + d2output_A_reg[j] += dA[idx_j, j, h] * wA + d2output_A_reg[j] += A[idx_j, j, h] * dweight_A + + # S + for j in range(5): + d2output_S_reg[j] += dS[idx_j, j, h] * wS + d2output_S_reg[j] += S[idx_j, j, h] * dweight_S d2output_I[b, 0, h] = d2output_I_reg - d2I[b, 0, h] = d2I_reg for j in range(3): - d2A[b, j, h] = d2A_reg[j] d2output_A[b, j, h] = d2output_A_reg[j] for j in range(5): - d2S[b, j, h] = d2S_reg[j] d2output_S[b, j, h] = d2output_S_reg[j] return ( @@ -298,31 +264,48 @@ def message_passing_bwd_bwd( module=wp.get_module(f"message_passing_bwd_{dtype}"), ), wp.Kernel( - message_passing_bwd_bwd, - key=f"message_passing_bwd_bwd_{dtype}", - module=wp.get_module(f"message_passing_bwd_bwd_{dtype}"), + message_passing_edge_bwd_bwd, + key=f"message_passing_edge_bwd_bwd_{dtype}", + module=wp.get_module(f"message_passing_edge_bwd_bwd_{dtype}"), + ), + wp.Kernel( + message_passing_output_bwd_bwd, + key=f"message_passing_output_bwd_bwd_{dtype}", + module=wp.get_module(f"message_passing_output_bwd_bwd_{dtype}"), ), ) -message_passing_fwd_fp64, message_passing_bwd_fp64, message_passing_bwd_bwd_fp64 = ( - generate_message_passing("float64") -) -message_passing_fwd_fp32, message_passing_bwd_fp32, message_passing_bwd_bwd_fp32 = ( - generate_message_passing("float32") -) -message_passing_fwd_fp16, message_passing_bwd_fp16, message_passing_bwd_bwd_fp16 = ( - generate_message_passing("float16") -) +( + message_passing_fwd_fp64, + message_passing_bwd_fp64, + message_passing_edge_bwd_bwd_fp64, + message_passing_output_bwd_bwd_fp64, +) = generate_message_passing("float64") +( + message_passing_fwd_fp32, + message_passing_bwd_fp32, + message_passing_edge_bwd_bwd_fp32, + message_passing_output_bwd_bwd_fp32, +) = generate_message_passing("float32") +( + message_passing_fwd_fp16, + message_passing_bwd_fp16, + message_passing_edge_bwd_bwd_fp16, + message_passing_output_bwd_bwd_fp16, +) = generate_message_passing("float16") add_module("message_passing_fwd", ["float64"], message_passing_fwd_fp64) add_module("message_passing_bwd", ["float64"], message_passing_bwd_fp64) -add_module("message_passing_bwd_bwd", ["float64"], message_passing_bwd_bwd_fp64) +add_module("message_passing_edge_bwd_bwd", ["float64"], message_passing_edge_bwd_bwd_fp64) +add_module("message_passing_output_bwd_bwd", ["float64"], message_passing_output_bwd_bwd_fp64) add_module("message_passing_fwd", ["float32"], message_passing_fwd_fp32) add_module("message_passing_bwd", ["float32"], message_passing_bwd_fp32) -add_module("message_passing_bwd_bwd", ["float32"], message_passing_bwd_bwd_fp32) +add_module("message_passing_edge_bwd_bwd", ["float32"], message_passing_edge_bwd_bwd_fp32) +add_module("message_passing_output_bwd_bwd", ["float32"], message_passing_output_bwd_bwd_fp32) add_module("message_passing_fwd", ["float16"], message_passing_fwd_fp16) add_module("message_passing_bwd", ["float16"], message_passing_bwd_fp16) -add_module("message_passing_bwd_bwd", ["float16"], message_passing_bwd_bwd_fp16) +add_module("message_passing_edge_bwd_bwd", ["float16"], message_passing_edge_bwd_bwd_fp16) +add_module("message_passing_output_bwd_bwd", ["float16"], message_passing_output_bwd_bwd_fp16) diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 72192325..269b56ca 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -101,7 +101,9 @@ def __init__( self.units = units self.cutoff = cutoff - self.distance_proj = nn.Linear(degree_rbf, 3 * units, 
dtype=dtype) + # Create unified distance_proj from 3 temp layers (matches reference RNG pattern). + self.distance_proj = self._create_distance_proj(degree_rbf, units, dtype=dtype) + self.emb = nn.Embedding(ntypes_node, units, dtype=dtype) self.emb2 = nn.Linear(2 * units, units, dtype=dtype) self.act = activation @@ -116,25 +118,77 @@ def __init__( self.reset_parameters() + def _create_distance_proj( + self, + in_features: int, + units: int, + dtype: torch.dtype = matgl.float_th, + ) -> nn.Linear: + """Create unified distance_proj from 3 separate layers to match reference RNG pattern.""" + d_proj1 = nn.Linear(in_features, units, bias=True, dtype=dtype) + d_proj2 = nn.Linear(in_features, units, bias=True, dtype=dtype) + d_proj3 = nn.Linear(in_features, units, bias=True, dtype=dtype) + + layer = torch.nn.utils.skip_init( + nn.Linear, + in_features, + 3 * units, + bias=True, + dtype=dtype + ) + with torch.no_grad(): + layer.weight.copy_(torch.cat([d_proj1.weight, d_proj2.weight, d_proj3.weight], dim=0)) + layer.bias.copy_(torch.cat([d_proj1.bias, d_proj2.bias, d_proj3.bias], dim=0)) + return layer + + def _reset_distance_proj(self) -> None: + """Reset distance_proj weights using 3 temp layers to match reference RNG pattern.""" + dtype = self.distance_proj.weight.dtype + d_proj1 = torch.nn.utils.skip_init( + nn.Linear, + self.distance_proj.in_features, + self.units, + bias=True, + dtype=dtype + ) + d_proj2 = torch.nn.utils.skip_init( + nn.Linear, + self.distance_proj.in_features, + self.units, + bias=True, + dtype=dtype + ) + d_proj3 = torch.nn.utils.skip_init( + nn.Linear, + self.distance_proj.in_features, + self.units, + bias=True, + dtype=dtype + ) + d_proj1.reset_parameters() + d_proj2.reset_parameters() + d_proj3.reset_parameters() + with torch.no_grad(): + self.distance_proj.weight.copy_(torch.cat([d_proj1.weight, d_proj2.weight, d_proj3.weight], dim=0)) + self.distance_proj.bias.copy_(torch.cat([d_proj1.bias, d_proj2.bias, d_proj3.bias], dim=0)) + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - # since the we changed distance_proj to be a single linear layer, - # we need to concatenate the weights and biases of the three distance_proj layers - # into a single weight and bias tensor + """Handle legacy checkpoints with separate distance_proj1/2/3 layers.""" w_keys = [f"{prefix}distance_proj{i}.weight" for i in (1, 2, 3)] - b_keys = [f"{prefix}distance_proj{i}.bias" for i in (1, 2, 3)] - new_w = f"{prefix}distance_proj.weight" - new_b = f"{prefix}distance_proj.bias" + b_keys = [f"{prefix}distance_proj{i}.bias" for i in (1, 2, 3)] + new_w = f"{prefix}distance_proj.weight" + new_b = f"{prefix}distance_proj.bias" - if all(k in state_dict for k in (w_keys + b_keys)): + if all(k in state_dict for k in w_keys + b_keys): state_dict = dict(state_dict) - state_dict[new_w] = torch.cat([state_dict.pop(k) for k in w_keys], dim=0) state_dict[new_b] = torch.cat([state_dict.pop(k) for k in b_keys], dim=0) - return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def reset_parameters(self): - self.distance_proj.reset_parameters() + """Reinitialize parameters with RNG pattern matching reference implementation.""" + self._reset_distance_proj() self.emb.reset_parameters() self.emb2.reset_parameters() for linear in self.linears_tensor: @@ -150,8 +204,8 @@ def forward( edge_weight: torch.Tensor, edge_vec: torch.Tensor, edge_attr: torch.Tensor, - row_data: torch.Tensor, - row_indptr: 
torch.Tensor, + col_data: torch.Tensor, + col_indptr: torch.Tensor, ) -> torch.Tensor: """Forward pass. @@ -161,6 +215,8 @@ def forward( edge_weight: Edge weights (distances), shape (num_edges,) edge_vec: Edge vectors, shape (num_edges, 3) edge_attr: Edge attributes (RBF), shape (num_edges, num_rbf) + col_data: CSR col data for destination aggregation, shape (num_edges,) + col_indptr: CSR col indptr for destination aggregation, shape (num_nodes+1,) Returns: X: Tensor representation, shape (num_nodes, 3, 3, units) @@ -173,7 +229,7 @@ def forward( edge_attr = self.distance_proj(edge_attr).view(-1, 3, self.units) # Get atomic number messages - zij = x.index_select(0, edge_index.t().flip(-1).reshape(-1)).view( + zij = x.index_select(0, edge_index.t().reshape(-1)).view( -1, self.units * 2 ) Zij = self.emb2(zij) # (num_edges, units) @@ -187,7 +243,7 @@ def forward( # Radial message passing edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) I, A, S = fn_radial_message_passing( - edge_vec_norm, edge_attr_processed, row_data, row_indptr + edge_vec_norm, edge_attr_processed, col_data, col_indptr ) # Compose initial tensor to get proper shape for norm computation @@ -284,8 +340,8 @@ def forward( for linear_scalar in self.linears_scalar: edge_attr_processed = self.act(linear_scalar(edge_attr_processed)) edge_attr_processed = (edge_attr_processed * C.view(-1, 1)).view( - edge_attr.shape[0], 3, self.units - ).mT.contiguous() # (num_edges, units, 3) + edge_attr.shape[0], self.units, 3 + ).mT.contiguous() # (num_edges, 3, units) # Normalize input tensor # For X with shape (num_nodes, 3, 3, units), we need to sum over (-3, -2) @@ -309,7 +365,7 @@ def forward( I, A, S, - edge_attr, + edge_attr_processed, row_data, row_indices, row_indptr, @@ -323,8 +379,7 @@ def forward( if self.equivariance_invariance_group == "O(3)": C = fn_tensor_matmul_o3_3x3(Y, msg) else: # SO(3) - C = fn_tensor_matmul_so3_3x3(Y, msg) - C = C = C + C = 2 * fn_tensor_matmul_so3_3x3(Y, msg) I, A, S = fn_decompose_tensor(C) # Normalize @@ -560,8 +615,8 @@ def forward( bond_dist, bond_vec, edge_attr, - row_data, - row_indptr + col_data, + col_indptr ) # Interaction layers diff --git a/src/matgl/ops/compose_tensor.py b/src/matgl/ops/compose_tensor.py index 95dd87e3..6fd56079 100644 --- a/src/matgl/ops/compose_tensor.py +++ b/src/matgl/ops/compose_tensor.py @@ -37,7 +37,7 @@ @torch.library.custom_op( - "nvtensornet::compose_tensor_fwd_primitive", + "tensornet::compose_tensor_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -66,13 +66,13 @@ def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: return output -@torch.library.register_fake("nvtensornet::compose_tensor_fwd_primitive") +@torch.library.register_fake("tensornet::compose_tensor_fwd_primitive") def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: return torch.empty((z.shape[0], 3, 3, z.shape[-1]), dtype=x.dtype, device=x.device) @torch.library.custom_op( - "nvtensornet::compose_tensor_bwd_primitive", + "tensornet::compose_tensor_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -101,13 +101,13 @@ def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: return [grad_x, grad_y, grad_z] -@torch.library.register_fake("nvtensornet::compose_tensor_bwd_primitive") +@torch.library.register_fake("tensornet::compose_tensor_bwd_primitive") def _(grad_output: List[Tensor], x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] 
@torch.library.custom_op( - "nvtensornet::compose_tensor_bwd_bwd_primitive", + "tensornet::compose_tensor_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -146,7 +146,7 @@ def _( return [grad_grad_output, grad_x, grad_y, grad_z] -@torch.library.register_fake("nvtensornet::compose_tensor_bwd_bwd_primitive") +@torch.library.register_fake("tensornet::compose_tensor_bwd_bwd_primitive") def _( grad_output: Tensor, grad_grad_x: Tensor, @@ -176,13 +176,13 @@ def compose_tensor_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def compose_tensor_fwd(*args): - return torch.ops.nvtensornet.compose_tensor_fwd_primitive(*args) + return torch.ops.tensornet.compose_tensor_fwd_primitive(*args) @torch.compiler.allow_in_graph def compose_tensor_bwd(ctx, grad_output): x, y, z = ctx.saved_tensors - dx, dy, dz = torch.ops.nvtensornet.compose_tensor_bwd_primitive( + dx, dy, dz = torch.ops.tensornet.compose_tensor_bwd_primitive( grad_output, x, y, z ) return dx, dy, dz @@ -203,7 +203,7 @@ def compose_tensor_bwd_bwd(ctx, *grad_outputs): if grad_grad_z is None: grad_grad_z = torch.zeros_like(z) - outputs = torch.ops.nvtensornet.compose_tensor_bwd_bwd_primitive( + outputs = torch.ops.tensornet.compose_tensor_bwd_bwd_primitive( grad_output_saved, grad_grad_x, grad_grad_y, grad_grad_z, x, y, z ) @@ -211,18 +211,18 @@ def compose_tensor_bwd_bwd(ctx, *grad_outputs): torch.library.register_autograd( - "nvtensornet::compose_tensor_fwd_primitive", + "tensornet::compose_tensor_fwd_primitive", compose_tensor_bwd, setup_context=compose_tensor_setup_fwd_context, ) torch.library.register_autograd( - "nvtensornet::compose_tensor_bwd_primitive", + "tensornet::compose_tensor_bwd_primitive", compose_tensor_bwd_bwd, setup_context=compose_tensor_setup_bwd_context, ) def fn_compose_tensor(x: Tensor, y: Tensor, z: Tensor) -> Tensor: - output = torch.ops.nvtensornet.compose_tensor_fwd_primitive(x, y, z) + output = torch.ops.tensornet.compose_tensor_fwd_primitive(x, y, z) return output diff --git a/src/matgl/ops/decompose_tensor.py b/src/matgl/ops/decompose_tensor.py index 70f1589d..a2c1973a 100644 --- a/src/matgl/ops/decompose_tensor.py +++ b/src/matgl/ops/decompose_tensor.py @@ -37,7 +37,7 @@ @torch.library.custom_op( - "nvtensornet::decompose_tensor_fwd_primitive", + "tensornet::decompose_tensor_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -66,7 +66,7 @@ def _(x: Tensor) -> List[Tensor]: return [output_i, output_a, output_s] -@torch.library.register_fake("nvtensornet::decompose_tensor_fwd_primitive") +@torch.library.register_fake("tensornet::decompose_tensor_fwd_primitive") def _(x: Tensor) -> List[Tensor]: return [ torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device), @@ -76,7 +76,7 @@ def _(x: Tensor) -> List[Tensor]: @torch.library.custom_op( - "nvtensornet::decompose_tensor_bwd_primitive", + "tensornet::decompose_tensor_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -106,7 +106,7 @@ def _( return [grad_x] -@torch.library.register_fake("nvtensornet::decompose_tensor_bwd_primitive") +@torch.library.register_fake("tensornet::decompose_tensor_bwd_primitive") def _( grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor ) -> List[Tensor]: @@ -114,7 +114,7 @@ def _( @torch.library.custom_op( - "nvtensornet::decompose_tensor_bwd_bwd_primitive", + "tensornet::decompose_tensor_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -162,7 +162,7 @@ def _( return [grad_grad_output_i, 
grad_grad_output_a, grad_grad_output_s, grad_x] -@torch.library.register_fake("nvtensornet::decompose_tensor_bwd_bwd_primitive") +@torch.library.register_fake("tensornet::decompose_tensor_bwd_bwd_primitive") def _( grad_output_i: Tensor, grad_output_a: Tensor, @@ -190,14 +190,14 @@ def decompose_tensor_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def decompose_tensor_fwd(*args): - return torch.ops.nvtensornet.decompose_tensor_fwd_primitive(*args) + return torch.ops.tensornet.decompose_tensor_fwd_primitive(*args) @torch.compiler.allow_in_graph def decompose_tensor_bwd(ctx, *grad_outputs): (x,) = ctx.saved_tensors grad_output_i, grad_output_a, grad_output_s = grad_outputs[0] - dx = torch.ops.nvtensornet.decompose_tensor_bwd_primitive( + dx = torch.ops.tensornet.decompose_tensor_bwd_primitive( grad_output_i, grad_output_a, grad_output_s, x ) return dx[0] @@ -212,7 +212,7 @@ def decompose_tensor_bwd_bwd(ctx, *grad_outputs): if grad_grad_x is None: grad_grad_x = torch.zeros_like(x) - outputs = torch.ops.nvtensornet.decompose_tensor_bwd_bwd_primitive( + outputs = torch.ops.tensornet.decompose_tensor_bwd_bwd_primitive( grad_output_i, grad_output_a, grad_output_s, grad_grad_x, x ) @@ -220,18 +220,18 @@ def decompose_tensor_bwd_bwd(ctx, *grad_outputs): torch.library.register_autograd( - "nvtensornet::decompose_tensor_fwd_primitive", + "tensornet::decompose_tensor_fwd_primitive", decompose_tensor_bwd, setup_context=decompose_tensor_setup_fwd_context, ) torch.library.register_autograd( - "nvtensornet::decompose_tensor_bwd_primitive", + "tensornet::decompose_tensor_bwd_primitive", decompose_tensor_bwd_bwd, setup_context=decompose_tensor_setup_bwd_context, ) def fn_decompose_tensor(x: Tensor) -> List[Tensor]: - output = torch.ops.nvtensornet.decompose_tensor_fwd_primitive(x) + output = torch.ops.tensornet.decompose_tensor_fwd_primitive(x) return output diff --git a/src/matgl/ops/equivariant_o3_matmul.py b/src/matgl/ops/equivariant_o3_matmul.py index fd706da2..c6d37d2d 100644 --- a/src/matgl/ops/equivariant_o3_matmul.py +++ b/src/matgl/ops/equivariant_o3_matmul.py @@ -36,7 +36,7 @@ @torch.library.custom_op( - "nvtensornet::tensor_matmul_o3_3x3_fwd_primitive", + "tensornet::tensor_matmul_o3_3x3_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -66,13 +66,13 @@ def _(x: Tensor, y: Tensor) -> Tensor: return output -@torch.library.register_fake("nvtensornet::tensor_matmul_o3_3x3_fwd_primitive") +@torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_fwd_primitive") def _(x: Tensor, y: Tensor) -> Tensor: return torch.empty_like(x) @torch.library.custom_op( - "nvtensornet::tensor_matmul_o3_3x3_bwd_primitive", + "tensornet::tensor_matmul_o3_3x3_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -104,13 +104,13 @@ def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: return [grad_x, grad_y] -@torch.library.register_fake("nvtensornet::tensor_matmul_o3_3x3_bwd_primitive") +@torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_primitive") def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: return [torch.empty_like(x), torch.empty_like(y)] @torch.library.custom_op( - "nvtensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive", + "tensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -162,7 +162,7 @@ def _( return [grad_grad_output, grad_x, grad_y] -@torch.library.register_fake("nvtensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive") 
+@torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive") def _( grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor ) -> List[Tensor]: @@ -185,7 +185,7 @@ def tensor_matmul_o3_3x3_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def tensor_matmul_o3_3x3_fwd(*args): - return getattr(torch.ops.nvtensornet, "tensor_matmul_o3_3x3_fwd_primitive")( + return getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_fwd_primitive")( *args ) @@ -193,7 +193,7 @@ def tensor_matmul_o3_3x3_fwd(*args): @torch.compiler.allow_in_graph def tensor_matmul_o3_3x3_bwd(ctx, grad_output): x, y = ctx.saved_tensors - dx, dy = getattr(torch.ops.nvtensornet, "tensor_matmul_o3_3x3_bwd_primitive")( + dx, dy = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_primitive")( grad_output, x, y ) return dx, dy @@ -207,24 +207,24 @@ def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): grad_output_saved, x, y = ctx.saved_tensors outputs = getattr( - torch.ops.nvtensornet, "tensor_matmul_o3_3x3_bwd_bwd_primitive" + torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_bwd_primitive" )(grad_output_saved, grad_grad_x, grad_grad_y, x, y) return outputs[0], outputs[1], outputs[2] torch.library.register_autograd( - "nvtensornet::tensor_matmul_o3_3x3_fwd_primitive", + "tensornet::tensor_matmul_o3_3x3_fwd_primitive", tensor_matmul_o3_3x3_bwd, setup_context=tensor_matmul_o3_3x3_setup_fwd_context, ) torch.library.register_autograd( - "nvtensornet::tensor_matmul_o3_3x3_bwd_primitive", + "tensornet::tensor_matmul_o3_3x3_bwd_primitive", tensor_matmul_o3_3x3_bwd_bwd, setup_context=tensor_matmul_o3_3x3_setup_bwd_context, ) def fn_tensor_matmul_o3_3x3(x: Tensor, y: Tensor) -> Tensor: - z = getattr(torch.ops.nvtensornet, "tensor_matmul_o3_3x3_fwd_primitive")(x, y) + z = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_fwd_primitive")(x, y) return z diff --git a/src/matgl/ops/equivariant_so3_matmul.py b/src/matgl/ops/equivariant_so3_matmul.py index 16d61748..44d27958 100644 --- a/src/matgl/ops/equivariant_so3_matmul.py +++ b/src/matgl/ops/equivariant_so3_matmul.py @@ -37,7 +37,7 @@ @torch.library.custom_op( - "nvtensornet::tensor_matmul_so3_3x3_fwd_primitive", + "tensornet::tensor_matmul_so3_3x3_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -66,13 +66,13 @@ def _(x: Tensor, y: Tensor) -> Tensor: return output -@torch.library.register_fake("nvtensornet::tensor_matmul_so3_3x3_fwd_primitive") +@torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_fwd_primitive") def _(x: Tensor, y: Tensor) -> Tensor: return torch.empty_like(x) @torch.library.custom_op( - "nvtensornet::tensor_matmul_so3_3x3_bwd_primitive", + "tensornet::tensor_matmul_so3_3x3_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -103,13 +103,13 @@ def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: return [grad_x, grad_y] -@torch.library.register_fake("nvtensornet::tensor_matmul_so3_3x3_bwd_primitive") +@torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_primitive") def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: return [torch.empty_like(x), torch.empty_like(y)] @torch.library.custom_op( - "nvtensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive", + "tensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -160,7 +160,7 @@ def _( return [grad_grad_output, grad_x, grad_y] -@torch.library.register_fake("nvtensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive") 
+@torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive") def _( grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor ) -> List[Tensor]: @@ -183,13 +183,13 @@ def tensor_matmul_so3_3x3_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def tensor_matmul_so3_3x3_fwd(*args): - return torch.ops.nvtensornet.tensor_matmul_so3_3x3_fwd_primitive(*args) + return torch.ops.tensornet.tensor_matmul_so3_3x3_fwd_primitive(*args) @torch.compiler.allow_in_graph def tensor_matmul_so3_3x3_bwd(ctx, grad_output): x, y = ctx.saved_tensors - dx, dy = torch.ops.nvtensornet.tensor_matmul_so3_3x3_bwd_primitive( + dx, dy = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_primitive( grad_output, x, y ) return dx, dy @@ -202,25 +202,25 @@ def tensor_matmul_so3_3x3_bwd_bwd(ctx, *grad_outputs): grad_output_saved, x, y = ctx.saved_tensors - outputs = torch.ops.nvtensornet.tensor_matmul_so3_3x3_bwd_bwd_primitive( + outputs = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_bwd_primitive( grad_output_saved, grad_grad_x, grad_grad_y, x, y ) return outputs[0], outputs[1], outputs[2] torch.library.register_autograd( - "nvtensornet::tensor_matmul_so3_3x3_fwd_primitive", + "tensornet::tensor_matmul_so3_3x3_fwd_primitive", tensor_matmul_so3_3x3_bwd, setup_context=tensor_matmul_so3_3x3_setup_fwd_context, ) torch.library.register_autograd( - "nvtensornet::tensor_matmul_so3_3x3_bwd_primitive", + "tensornet::tensor_matmul_so3_3x3_bwd_primitive", tensor_matmul_so3_3x3_bwd_bwd, setup_context=tensor_matmul_so3_3x3_setup_bwd_context, ) def fn_tensor_matmul_so3_3x3(x: Tensor, y: Tensor) -> Tensor: - z = torch.ops.nvtensornet.tensor_matmul_so3_3x3_fwd_primitive(x, y) + z = torch.ops.tensornet.tensor_matmul_so3_3x3_fwd_primitive(x, y) return z diff --git a/src/matgl/ops/tensor_norm3.py b/src/matgl/ops/tensor_norm3.py index 7433a7e5..3eabb174 100644 --- a/src/matgl/ops/tensor_norm3.py +++ b/src/matgl/ops/tensor_norm3.py @@ -37,7 +37,7 @@ @torch.library.custom_op( - "nvtensornet::tensor_norm3_fwd_primitive", + "tensornet::tensor_norm3_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -62,13 +62,13 @@ def _(x: Tensor) -> Tensor: return output -@torch.library.register_fake("nvtensornet::tensor_norm3_fwd_primitive") +@torch.library.register_fake("tensornet::tensor_norm3_fwd_primitive") def _(x: Tensor) -> Tensor: return torch.empty((x.shape[0], 3 * x.shape[-1]), dtype=x.dtype, device=x.device) @torch.library.custom_op( - "nvtensornet::tensor_norm3_bwd_primitive", + "tensornet::tensor_norm3_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -96,7 +96,7 @@ def _( return [grad_x] -@torch.library.register_fake("nvtensornet::tensor_norm3_bwd_primitive") +@torch.library.register_fake("tensornet::tensor_norm3_bwd_primitive") def _( grad_output: Tensor, x: Tensor ) -> List[Tensor]: @@ -104,19 +104,29 @@ def _( @torch.library.custom_op( - "nvtensornet::tensor_norm3_bwd_bwd_primitive", + "tensornet::tensor_norm3_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) def _( grad_grad_x: Tensor, -) -> Tensor: + x: Tensor, + grad_output: Tensor, +) -> List[Tensor]: stream = get_stream(grad_grad_x.device) device = wp.device_from_torch(grad_grad_x.device) - grad_grad_output = torch.empty((grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), dtype=grad_grad_x.dtype, device=grad_grad_x.device) + grad_grad_output = torch.empty( + (grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), + dtype=grad_grad_x.dtype, + device=grad_grad_x.device, + ) + 
grad_x = torch.empty_like(x) grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) tensor_norm3_bwd_bwd = get_module("tensor_norm3_bwd_bwd", [str(grad_grad_x.dtype)]) wp.launch( @@ -126,70 +136,82 @@ def _( device=device, inputs=( grad_grad_x_wp, + x_wp, + grad_output_wp, grad_grad_output_wp, + grad_x_wp, ), ) - return grad_grad_output + return [grad_grad_output, grad_x] -@torch.library.register_fake("nvtensornet::tensor_norm3_bwd_bwd_primitive") +@torch.library.register_fake("tensornet::tensor_norm3_bwd_bwd_primitive") def _( grad_grad_x: Tensor, -) -> Tensor: - return torch.empty((grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), dtype=grad_grad_x.dtype, device=grad_grad_x.device) + x: Tensor, + grad_output: Tensor, +) -> List[Tensor]: + return [ + torch.empty( + (grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), + dtype=grad_grad_x.dtype, + device=grad_grad_x.device, + ), + torch.empty_like(x), + ] def tensor_norm3_fwd_setup_context(ctx, inputs, output): - (x,) = inputs # Unpack the single input tensor + (x,) = inputs ctx.save_for_backward(x) def tensor_norm3_bwd_setup_context(ctx, inputs, output): (grad_output, x) = inputs - ctx.save_for_backward(x) + ctx.save_for_backward(grad_output, x) @torch.compiler.allow_in_graph def tensor_norm3_fwd(*args): - return torch.ops.nvtensornet.tensor_norm3_fwd_primitive(*args) + """Forward: computes I, A, S norms of 3x3 tensor.""" + return torch.ops.tensornet.tensor_norm3_fwd_primitive(*args) @torch.compiler.allow_in_graph def tensor_norm3_bwd(ctx, grad_output): + """Backward: returns grad for x.""" (x,) = ctx.saved_tensors - dx = torch.ops.nvtensornet.tensor_norm3_bwd_primitive( - grad_output, x - ) - return dx[0] + return torch.ops.tensornet.tensor_norm3_bwd_primitive(grad_output, x)[0] @torch.compiler.allow_in_graph -def tensor_norm3_bwd_bwd(ctx, grad_grad_x): - (x,) = ctx.saved_tensors +def tensor_norm3_bwd_bwd(ctx, *grad_outputs): + """Double backward: returns (grad for grad_output, grad for x).""" + (grad_grad_x,) = grad_outputs[0] + grad_output, x = ctx.saved_tensors if grad_grad_x is None: grad_grad_x = torch.zeros_like(x) - grad_grad_output = torch.ops.nvtensornet.tensor_norm3_bwd_bwd_primitive( - grad_grad_x + outputs = torch.ops.tensornet.tensor_norm3_bwd_bwd_primitive( + grad_grad_x, x, grad_output ) - - return grad_grad_output + return outputs[0], outputs[1] torch.library.register_autograd( - "nvtensornet::tensor_norm3_fwd_primitive", + "tensornet::tensor_norm3_fwd_primitive", tensor_norm3_bwd, setup_context=tensor_norm3_fwd_setup_context, ) torch.library.register_autograd( - "nvtensornet::tensor_norm3_bwd_primitive", + "tensornet::tensor_norm3_bwd_primitive", tensor_norm3_bwd_bwd, setup_context=tensor_norm3_bwd_setup_context, ) def fn_tensor_norm3(x: Tensor) -> Tensor: - return torch.ops.nvtensornet.tensor_norm3_fwd_primitive(x) + return torch.ops.tensornet.tensor_norm3_fwd_primitive(x) diff --git a/src/matgl/ops/tensornet_mp.py b/src/matgl/ops/tensornet_mp.py index a8829d7a..a4d7ec3e 100644 --- a/src/matgl/ops/tensornet_mp.py +++ b/src/matgl/ops/tensornet_mp.py @@ -37,7 +37,7 @@ @torch.library.custom_op( - "nvtensornet::message_passing_fwd_primitive", + "tensornet::message_passing_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ 
-96,7 +96,7 @@ def _( return [output_x, output_y, output_z] -@torch.library.register_fake("nvtensornet::message_passing_fwd_primitive") +@torch.library.register_fake("tensornet::message_passing_fwd_primitive") def _( x: Tensor, y: Tensor, @@ -113,7 +113,7 @@ def _( @torch.library.custom_op( - "nvtensornet::message_passing_bwd_primitive", + "tensornet::message_passing_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -187,7 +187,7 @@ def _( return [grad_x, grad_y, grad_z, grad_edge_attr] -@torch.library.register_fake("nvtensornet::message_passing_bwd_primitive") +@torch.library.register_fake("tensornet::message_passing_bwd_primitive") def _( grad_output_x: Tensor, grad_output_y: Tensor, @@ -212,7 +212,7 @@ def _( @torch.library.custom_op( - "nvtensornet::message_passing_bwd_bwd_primitive", + "tensornet::message_passing_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -237,9 +237,12 @@ def _( ) -> List[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) + + # Convert inputs to warp arrays x_wp = wp.from_torch(x.detach(), return_ctype=True) y_wp = wp.from_torch(y.detach(), return_ctype=True) z_wp = wp.from_torch(z.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) @@ -247,44 +250,42 @@ def _( grad_grad_edge_attr_wp = wp.from_torch( grad_grad_edge_attr.detach(), return_ctype=True ) + grad_output_x_wp = wp.from_torch(grad_output_x.detach(), return_ctype=True) grad_output_y_wp = wp.from_torch(grad_output_y.detach(), return_ctype=True) grad_output_z_wp = wp.from_torch(grad_output_z.detach(), return_ctype=True) - dgrad_output_x = torch.empty_like(grad_output_x) - dgrad_output_y = torch.empty_like(grad_output_y) - dgrad_output_z = torch.empty_like(grad_output_z) + col_data_wp = wp.from_torch(col_data.detach(), return_ctype=True) + col_indices_wp = wp.from_torch(col_indices.detach(), return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr.detach(), return_ctype=True) + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + # Allocate output tensors (no zero-init needed with two-kernel approach) dgrad_x = torch.empty_like(x) dgrad_y = torch.empty_like(y) dgrad_z = torch.empty_like(z) - dgrad_edge_attr = torch.empty_like(edge_attr) - - edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + dgrad_output_x = torch.empty_like(grad_output_x) + dgrad_output_y = torch.empty_like(grad_output_y) + dgrad_output_z = torch.empty_like(grad_output_z) dgrad_x_wp = wp.from_torch(dgrad_x.detach(), return_ctype=True) dgrad_y_wp = wp.from_torch(dgrad_y.detach(), return_ctype=True) dgrad_z_wp = wp.from_torch(dgrad_z.detach(), return_ctype=True) - dgrad_edge_attr_wp = wp.from_torch(dgrad_edge_attr.detach(), return_ctype=True) - dgrad_output_x_wp = wp.from_torch(dgrad_output_x.detach(), return_ctype=True) dgrad_output_y_wp = wp.from_torch(dgrad_output_y.detach(), return_ctype=True) dgrad_output_z_wp = wp.from_torch(dgrad_output_z.detach(), return_ctype=True) - col_data_wp = wp.from_torch(col_data.detach(), return_ctype=True) - col_indices_wp = wp.from_torch(col_indices.detach(), return_ctype=True) - col_indptr_wp = wp.from_torch(col_indptr.detach(), return_ctype=True) - - row_data_wp = 
wp.from_torch(row_data.detach(), return_ctype=True) - row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) - row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) - - message_passing_bwd_bwd = get_module("message_passing_bwd_bwd", [str(x.dtype)]) - + # Kernel 1: col-based - computes d2I, d2A, d2S, d2edge_attr + message_passing_edge_bwd_bwd = get_module( + "message_passing_edge_bwd_bwd", [str(x.dtype)] + ) wp.launch( - message_passing_bwd_bwd, + message_passing_edge_bwd_bwd, dim=(x.shape[0], x.shape[-1]), stream=stream, device=device, @@ -292,7 +293,6 @@ def _( x_wp, y_wp, z_wp, - edge_attr_wp, grad_grad_x_wp, grad_grad_y_wp, grad_grad_z_wp, @@ -300,9 +300,6 @@ def _( grad_output_x_wp, grad_output_y_wp, grad_output_z_wp, - row_data_wp, - row_indices_wp, - row_indptr_wp, col_data_wp, col_indices_wp, col_indptr_wp, @@ -310,11 +307,36 @@ def _( dgrad_y_wp, dgrad_z_wp, dgrad_edge_attr_wp, + ), + ) + + # Kernel 2: row-based - computes d2output_I, d2output_A, d2output_S + message_passing_output_bwd_bwd = get_module( + "message_passing_output_bwd_bwd", [str(x.dtype)] + ) + wp.launch( + message_passing_output_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_grad_z_wp, + grad_grad_edge_attr_wp, + row_data_wp, + row_indices_wp, + row_indptr_wp, dgrad_output_x_wp, dgrad_output_y_wp, dgrad_output_z_wp, ), ) + return [ dgrad_output_x, dgrad_output_y, @@ -326,7 +348,7 @@ def _( ] -@torch.library.register_fake("nvtensornet::message_passing_bwd_bwd_primitive") +@torch.library.register_fake("tensornet::message_passing_bwd_bwd_primitive") def _( grad_output_x: Tensor, grad_output_y: Tensor, @@ -419,7 +441,7 @@ def message_passing_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def message_passing_fwd(*args): - return torch.ops.nvtensornet.message_passing_fwd_primitive(*args) + return torch.ops.tensornet.message_passing_fwd_primitive(*args) @torch.compiler.allow_in_graph @@ -437,7 +459,7 @@ def message_passing_bwd(ctx, grad_outputs): col_indptr, ) = ctx.saved_tensors - result = torch.ops.nvtensornet.message_passing_bwd_primitive( + result = torch.ops.tensornet.message_passing_bwd_primitive( grad_outputs[0], grad_outputs[1], grad_outputs[2], @@ -478,7 +500,7 @@ def message_passing_bwd_bwd(ctx, *grad_outputs): col_indptr, ) = ctx.saved_tensors - result = torch.ops.nvtensornet.message_passing_bwd_bwd_primitive( + result = torch.ops.tensornet.message_passing_bwd_bwd_primitive( grad_output_x, grad_output_y, grad_output_z, @@ -516,13 +538,13 @@ def message_passing_bwd_bwd(ctx, *grad_outputs): torch.library.register_autograd( - "nvtensornet::message_passing_fwd_primitive", + "tensornet::message_passing_fwd_primitive", message_passing_bwd, setup_context=message_passing_setup_fwd_context, ) torch.library.register_autograd( - "nvtensornet::message_passing_bwd_primitive", + "tensornet::message_passing_bwd_primitive", message_passing_bwd_bwd, setup_context=message_passing_setup_bwd_context, ) @@ -540,7 +562,7 @@ def fn_message_passing( col_indices: Tensor, col_indptr: Tensor, ) -> List[Tensor]: - return torch.ops.nvtensornet.message_passing_fwd_primitive( + return torch.ops.tensornet.message_passing_fwd_primitive( x, y, z, diff --git a/src/matgl/ops/tensornet_radial_mp.py b/src/matgl/ops/tensornet_radial_mp.py index 524d98ee..52371def 100644 --- a/src/matgl/ops/tensornet_radial_mp.py +++ b/src/matgl/ops/tensornet_radial_mp.py @@ -37,7 +37,7 @@ 
@torch.library.custom_op( - "nvtensornet::radial_message_passing_fwd_primitive", + "tensornet::radial_message_passing_fwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -96,7 +96,7 @@ def _( return [output_I, output_A, output_S] -@torch.library.register_fake("nvtensornet::radial_message_passing_fwd_primitive") +@torch.library.register_fake("tensornet::radial_message_passing_fwd_primitive") def _( edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor ) -> List[Tensor]: @@ -121,7 +121,7 @@ def _( @torch.library.custom_op( - "nvtensornet::radial_message_passing_bwd_primitive", + "tensornet::radial_message_passing_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -180,7 +180,7 @@ def _( return [grad_edge_vec_norm, grad_edge_attr] -@torch.library.register_fake("nvtensornet::radial_message_passing_bwd_primitive") +@torch.library.register_fake("tensornet::radial_message_passing_bwd_primitive") def _( grad_output_I: Tensor, grad_output_A: Tensor, @@ -194,7 +194,7 @@ def _( @torch.library.custom_op( - "nvtensornet::radial_message_passing_bwd_bwd_primitive", + "tensornet::radial_message_passing_bwd_bwd_primitive", mutates_args=(), device_types=["cpu", "cuda"], ) @@ -282,7 +282,7 @@ def _( @torch.library.register_fake( - "nvtensornet::radial_message_passing_bwd_bwd_primitive" + "tensornet::radial_message_passing_bwd_bwd_primitive" ) def _( grad_output_I: Tensor, @@ -330,14 +330,14 @@ def radial_message_passing_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def radial_message_passing_fwd(*args): - return torch.ops.nvtensornet.radial_message_passing_fwd_primitive(*args) + return torch.ops.tensornet.radial_message_passing_fwd_primitive(*args) @torch.compiler.allow_in_graph def radial_message_passing_bwd(ctx, grad_outputs): edge_vec_norm, edge_attr, row_data, row_indptr = ctx.saved_tensors - result = torch.ops.nvtensornet.radial_message_passing_bwd_primitive( + result = torch.ops.tensornet.radial_message_passing_bwd_primitive( grad_outputs[0], grad_outputs[1], grad_outputs[2], @@ -366,7 +366,7 @@ def radial_message_passing_bwd_bwd(ctx, *grad_outputs): row_indptr, ) = ctx.saved_tensors - result = torch.ops.nvtensornet.radial_message_passing_bwd_bwd_primitive( + result = torch.ops.tensornet.radial_message_passing_bwd_bwd_primitive( grad_output_I, grad_output_A, grad_output_S, @@ -398,13 +398,13 @@ def radial_message_passing_bwd_bwd(ctx, *grad_outputs): torch.library.register_autograd( - "nvtensornet::radial_message_passing_fwd_primitive", + "tensornet::radial_message_passing_fwd_primitive", radial_message_passing_bwd, setup_context=radial_message_passing_setup_fwd_context, ) torch.library.register_autograd( - "nvtensornet::radial_message_passing_bwd_primitive", + "tensornet::radial_message_passing_bwd_primitive", radial_message_passing_bwd_bwd, setup_context=radial_message_passing_setup_bwd_context, ) @@ -413,6 +413,6 @@ def radial_message_passing_bwd_bwd(ctx, *grad_outputs): def fn_radial_message_passing( edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor ) -> List[Tensor]: - return torch.ops.nvtensornet.radial_message_passing_fwd_primitive( + return torch.ops.tensornet.radial_message_passing_fwd_primitive( edge_vec_norm, edge_attr, row_data, row_indptr ) From aabf71f736dd4ebcd4a079939abe92a4ab232774 Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Wed, 7 Jan 2026 17:16:52 -0500 Subject: [PATCH 03/18] linting Signed-off-by: Roman Zubatyuk --- src/matgl/kernels/__init__.py | 3 +- 
src/matgl/kernels/decompose_tensor.py | 12 +-- src/matgl/kernels/equivariant_o3_matmul.py | 12 +-- src/matgl/kernels/equivariant_so3_matmul.py | 12 +-- src/matgl/kernels/tensor_norm3.py | 25 +++--- src/matgl/kernels/tensornet_radial_mp.py | 92 +++++---------------- src/matgl/kernels/utils.py | 2 +- src/matgl/models/_tensornet_pyg.py | 67 ++++----------- src/matgl/ops/__init__.py | 3 +- src/matgl/ops/compose_tensor.py | 8 +- src/matgl/ops/decompose_tensor.py | 26 ++---- src/matgl/ops/equivariant_o3_matmul.py | 26 ++---- src/matgl/ops/equivariant_so3_matmul.py | 16 +--- src/matgl/ops/graph_transform.py | 49 ++++------- src/matgl/ops/tensor_norm3.py | 14 +--- src/matgl/ops/tensornet_mp.py | 12 +-- src/matgl/ops/tensornet_radial_mp.py | 50 +++-------- 17 files changed, 118 insertions(+), 311 deletions(-) diff --git a/src/matgl/kernels/__init__.py b/src/matgl/kernels/__init__.py index ee234854..5a945896 100644 --- a/src/matgl/kernels/__init__.py +++ b/src/matgl/kernels/__init__.py @@ -37,6 +37,7 @@ from .utils import add_module, get_module, get_stream import warp as wp + wp.init() @@ -53,4 +54,4 @@ "add_module", "get_module", "get_stream", -] \ No newline at end of file +] diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py index 685b6f7f..9ce8adfb 100644 --- a/src/matgl/kernels/decompose_tensor.py +++ b/src/matgl/kernels/decompose_tensor.py @@ -209,14 +209,14 @@ def decompose_tensor_bwd_bwd( ) -decompose_tensor_fwd_fp64, decompose_tensor_bwd_fp64, decompose_tensor_bwd_bwd_fp64 = ( - generate_decompose_tensor("float64") +decompose_tensor_fwd_fp64, decompose_tensor_bwd_fp64, decompose_tensor_bwd_bwd_fp64 = generate_decompose_tensor( + "float64" ) -decompose_tensor_fwd_fp32, decompose_tensor_bwd_fp32, decompose_tensor_bwd_bwd_fp32 = ( - generate_decompose_tensor("float32") +decompose_tensor_fwd_fp32, decompose_tensor_bwd_fp32, decompose_tensor_bwd_bwd_fp32 = generate_decompose_tensor( + "float32" ) -decompose_tensor_fwd_fp16, decompose_tensor_bwd_fp16, decompose_tensor_bwd_bwd_fp16 = ( - generate_decompose_tensor("float16") +decompose_tensor_fwd_fp16, decompose_tensor_bwd_fp16, decompose_tensor_bwd_bwd_fp16 = generate_decompose_tensor( + "float16" ) add_module("decompose_tensor_fwd", ["float64"], decompose_tensor_fwd_fp64) diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py index dcf027b0..8f170ee8 100644 --- a/src/matgl/kernels/equivariant_o3_matmul.py +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -195,18 +195,12 @@ def tensor_matmul_o3_3x3_bwd_bwd( add_module("tensor_matmul_o3_3x3_fwd", ["float64"], tensor_matmul_o3_3x3_fwd_fp64) add_module("tensor_matmul_o3_3x3_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_fp64) -add_module( - "tensor_matmul_o3_3x3_bwd_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_bwd_fp64 -) +add_module("tensor_matmul_o3_3x3_bwd_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_bwd_fp64) add_module("tensor_matmul_o3_3x3_fwd", ["float32"], tensor_matmul_o3_3x3_fwd_fp32) add_module("tensor_matmul_o3_3x3_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_fp32) -add_module( - "tensor_matmul_o3_3x3_bwd_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_bwd_fp32 -) +add_module("tensor_matmul_o3_3x3_bwd_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_bwd_fp32) add_module("tensor_matmul_o3_3x3_fwd", ["float16"], tensor_matmul_o3_3x3_fwd_fp16) add_module("tensor_matmul_o3_3x3_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_fp16) -add_module( - "tensor_matmul_o3_3x3_bwd_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_bwd_fp16 -) 
+add_module("tensor_matmul_o3_3x3_bwd_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_bwd_fp16) diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index 49c2a7c5..abe88120 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -186,18 +186,12 @@ def tensor_matmul_so3_3x3_bwd_bwd( add_module("tensor_matmul_so3_3x3_fwd", ["float64"], tensor_matmul_so3_3x3_fwd_fp64) add_module("tensor_matmul_so3_3x3_bwd", ["float64"], tensor_matmul_so3_3x3_bwd_fp64) -add_module( - "tensor_matmul_so3_3x3_bwd_bwd", ["float64"], tensor_matmul_so3_3x3_bwd_bwd_fp64 -) +add_module("tensor_matmul_so3_3x3_bwd_bwd", ["float64"], tensor_matmul_so3_3x3_bwd_bwd_fp64) add_module("tensor_matmul_so3_3x3_fwd", ["float32"], tensor_matmul_so3_3x3_fwd_fp32) add_module("tensor_matmul_so3_3x3_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_fp32) -add_module( - "tensor_matmul_so3_3x3_bwd_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_bwd_fp32 -) +add_module("tensor_matmul_so3_3x3_bwd_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_bwd_fp32) add_module("tensor_matmul_so3_3x3_fwd", ["float16"], tensor_matmul_so3_3x3_fwd_fp16) add_module("tensor_matmul_so3_3x3_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_fp16) -add_module( - "tensor_matmul_so3_3x3_bwd_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_bwd_fp16 -) +add_module("tensor_matmul_so3_3x3_bwd_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_bwd_fp16) diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py index 3d24adda..032cb6c7 100644 --- a/src/matgl/kernels/tensor_norm3.py +++ b/src/matgl/kernels/tensor_norm3.py @@ -67,12 +67,13 @@ def tensor_norm3_fwd( trace_third = trace / X.dtype(3.0) norm2_i = one_third * trace * trace norm2_a = one_half * ((x01 - x10) * (x01 - x10) + (x02 - x20) * (x02 - x20) + (x12 - x21) * (x12 - x21)) - norm2_s = one_half * ( - (x01 + x10) * (x01 + x10) - + (x02 + x20) * (x02 + x20) - + (x12 + x21) * (x12 + x21) - ) + (x00 - trace_third) * (x00 - trace_third) + (x11 - trace_third) * (x11 - trace_third) + (x22 - trace_third) * (x22 - trace_third) - + norm2_s = ( + one_half * ((x01 + x10) * (x01 + x10) + (x02 + x20) * (x02 + x20) + (x12 + x21) * (x12 + x21)) + + (x00 - trace_third) * (x00 - trace_third) + + (x11 - trace_third) * (x11 - trace_third) + + (x22 - trace_third) * (x22 - trace_third) + ) + output[b, h] = norm2_i output[b, h + X.shape[3]] = norm2_a output[b, h + 2 * X.shape[3]] = norm2_s @@ -255,15 +256,9 @@ def tensor_norm3_bwd_bwd( ) -tensor_norm3_fwd_fp64, tensor_norm3_bwd_fp64, tensor_norm3_bwd_bwd_fp64 = ( - generate_tensor_norm3("float64") -) -tensor_norm3_fwd_fp32, tensor_norm3_bwd_fp32, tensor_norm3_bwd_bwd_fp32 = ( - generate_tensor_norm3("float32") -) -tensor_norm3_fwd_fp16, tensor_norm3_bwd_fp16, tensor_norm3_bwd_bwd_fp16 = ( - generate_tensor_norm3("float16") -) +tensor_norm3_fwd_fp64, tensor_norm3_bwd_fp64, tensor_norm3_bwd_bwd_fp64 = generate_tensor_norm3("float64") +tensor_norm3_fwd_fp32, tensor_norm3_bwd_fp32, tensor_norm3_bwd_bwd_fp32 = generate_tensor_norm3("float32") +tensor_norm3_fwd_fp16, tensor_norm3_bwd_fp16, tensor_norm3_bwd_bwd_fp16 = generate_tensor_norm3("float16") add_module("tensor_norm3_fwd", ["float64"], tensor_norm3_fwd_fp64) add_module("tensor_norm3_bwd", ["float64"], tensor_norm3_bwd_fp64) diff --git a/src/matgl/kernels/tensornet_radial_mp.py b/src/matgl/kernels/tensornet_radial_mp.py index 5e069910..27d39a2b 100644 --- a/src/matgl/kernels/tensornet_radial_mp.py +++ 
b/src/matgl/kernels/tensornet_radial_mp.py @@ -74,9 +74,7 @@ def radial_message_passing_fwd( output_A_reg[2] += r_ij[0] * weight_A_reg S_reg = vec5() - mean_r2 = ( - r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2] - ) / output_I.dtype(3.0) + mean_r2 = (r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2]) / output_I.dtype(3.0) S_reg[0] = r_ij[0] * r_ij[0] - mean_r2 S_reg[1] = r_ij[0] * r_ij[1] S_reg[2] = r_ij[0] * r_ij[2] @@ -140,16 +138,10 @@ def radial_message_passing_bwd( dedge_attr_I = doutput_I_reg - dedge_attr_A = ( - doutput_A_reg[0] * r_ij[2] - - doutput_A_reg[1] * r_ij[1] - + doutput_A_reg[2] * r_ij[0] - ) + dedge_attr_A = doutput_A_reg[0] * r_ij[2] - doutput_A_reg[1] * r_ij[1] + doutput_A_reg[2] * r_ij[0] S_reg = vec5() - mean_r2 = ( - r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2] - ) / doutput_I.dtype(3.0) + mean_r2 = (r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2]) / doutput_I.dtype(3.0) S_reg[0] = r_ij[0] * r_ij[0] - mean_r2 S_reg[1] = r_ij[0] * r_ij[1] S_reg[2] = r_ij[0] * r_ij[2] @@ -248,11 +240,7 @@ def radial_message_passing_bwd_bwd( d2output_A_reg[1] += -dedge_attr_A * r_ij[1] d2output_A_reg[2] += dedge_attr_A * r_ij[0] - dweight_A = ( - doutput_A[b, 0, h] * dr_ij[2] - - doutput_A[b, 1, h] * dr_ij[1] - + doutput_A[b, 2, h] * dr_ij[0] - ) + dweight_A = doutput_A[b, 0, h] * dr_ij[2] - doutput_A[b, 1, h] * dr_ij[1] + doutput_A[b, 2, h] * dr_ij[0] d2r_ij[2] += dedge_attr_A * doutput_A[b, 0, h] d2r_ij[1] += -dedge_attr_A * doutput_A[b, 1, h] @@ -299,60 +287,36 @@ def radial_message_passing_bwd_bwd( ) d2output_S_reg[4] += dedge_attr_S * (r_ij[2] * r_ij[1]) - d2r_ij[0] += ( - doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c0) - ) - d2r_ij[1] += ( - doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c1) - ) - d2r_ij[2] += ( - doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * c1) - ) + d2r_ij[0] += doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c0) + d2r_ij[1] += doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c1) + d2r_ij[2] += doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * c1) d2r_ij[0] += doutput_S[b, 0, h] * dedge_attr_S * (c0 * r_ij[0]) d2r_ij[1] += doutput_S[b, 0, h] * dedge_attr_S * (c1 * r_ij[1]) d2r_ij[2] += doutput_S[b, 0, h] * dedge_attr_S * (c1 * r_ij[2]) - d2r_ij[0] += ( - doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) - ) - d2r_ij[1] += ( - doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) - ) + d2r_ij[0] += doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) + d2r_ij[1] += doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) d2r_ij[0] += doutput_S[b, 1, h] * dedge_attr_S * (r_ij[1]) d2r_ij[1] += doutput_S[b, 1, h] * dedge_attr_S * (r_ij[0]) - d2r_ij[0] += ( - doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) - ) - d2r_ij[2] += ( - doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) - ) + d2r_ij[0] += doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) + d2r_ij[2] += doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) d2r_ij[0] += doutput_S[b, 2, h] * dedge_attr_S * (r_ij[2]) d2r_ij[2] += doutput_S[b, 2, h] * dedge_attr_S * (r_ij[0]) - d2r_ij[0] += ( - doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c1) - ) - d2r_ij[1] += ( - doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c0) - ) - d2r_ij[2] += ( - doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * 
c1) - ) + d2r_ij[0] += doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c1) + d2r_ij[1] += doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c0) + d2r_ij[2] += doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * c1) d2r_ij[0] += doutput_S[b, 3, h] * dedge_attr_S * (c1 * r_ij[0]) d2r_ij[1] += doutput_S[b, 3, h] * dedge_attr_S * (c0 * r_ij[1]) d2r_ij[2] += doutput_S[b, 3, h] * dedge_attr_S * (c1 * r_ij[2]) - d2r_ij[1] += ( - doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) - ) - d2r_ij[2] += ( - doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) - ) + d2r_ij[1] += doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) + d2r_ij[2] += doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) d2r_ij[1] += doutput_S[b, 4, h] * dedge_attr_S * (r_ij[2]) d2r_ij[2] += doutput_S[b, 4, h] * dedge_attr_S * (r_ij[1]) @@ -364,13 +328,9 @@ def radial_message_passing_bwd_bwd( + c1 * r_ij[2] * dedge_vec_norm[idx_w, 2] ) - d2weight_S += doutput_S[b, 1, h] * ( - r_ij[1] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 1] - ) + d2weight_S += doutput_S[b, 1, h] * (r_ij[1] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 1]) - d2weight_S += doutput_S[b, 2, h] * ( - r_ij[2] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 2] - ) + d2weight_S += doutput_S[b, 2, h] * (r_ij[2] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 2]) d2weight_S += doutput_S[b, 3, h] * ( c1 * r_ij[0] * dedge_vec_norm[idx_w, 0] @@ -378,9 +338,7 @@ def radial_message_passing_bwd_bwd( + c1 * r_ij[2] * dedge_vec_norm[idx_w, 2] ) - d2weight_S += doutput_S[b, 4, h] * ( - r_ij[2] * dedge_vec_norm[idx_w, 1] + r_ij[1] * dedge_vec_norm[idx_w, 2] - ) + d2weight_S += doutput_S[b, 4, h] * (r_ij[2] * dedge_vec_norm[idx_w, 1] + r_ij[1] * dedge_vec_norm[idx_w, 2]) wp.atomic_add(d2edge_attr, idx_w, 2, h, d2weight_S) @@ -433,18 +391,12 @@ def radial_message_passing_bwd_bwd( add_module("radial_message_passing_fwd", ["float64"], radial_message_passing_fwd_fp64) add_module("radial_message_passing_bwd", ["float64"], radial_message_passing_bwd_fp64) -add_module( - "radial_message_passing_bwd_bwd", ["float64"], radial_message_passing_bwd_bwd_fp64 -) +add_module("radial_message_passing_bwd_bwd", ["float64"], radial_message_passing_bwd_bwd_fp64) add_module("radial_message_passing_fwd", ["float32"], radial_message_passing_fwd_fp32) add_module("radial_message_passing_bwd", ["float32"], radial_message_passing_bwd_fp32) -add_module( - "radial_message_passing_bwd_bwd", ["float32"], radial_message_passing_bwd_bwd_fp32 -) +add_module("radial_message_passing_bwd_bwd", ["float32"], radial_message_passing_bwd_bwd_fp32) add_module("radial_message_passing_fwd", ["float16"], radial_message_passing_fwd_fp16) add_module("radial_message_passing_bwd", ["float16"], radial_message_passing_bwd_fp16) -add_module( - "radial_message_passing_bwd_bwd", ["float16"], radial_message_passing_bwd_bwd_fp16 -) +add_module("radial_message_passing_bwd_bwd", ["float16"], radial_message_passing_bwd_bwd_fp16) diff --git a/src/matgl/kernels/utils.py b/src/matgl/kernels/utils.py index 4da11a59..cb567243 100644 --- a/src/matgl/kernels/utils.py +++ b/src/matgl/kernels/utils.py @@ -103,4 +103,4 @@ def get_stream(device: torch.device): if device.type == "cuda": return wp.stream_from_torch(torch.cuda.current_stream(device)) else: - return None \ No newline at end of file + return None diff --git a/src/matgl/models/_tensornet_pyg.py 
b/src/matgl/models/_tensornet_pyg.py index 269b56ca..7a0e52a9 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -82,7 +82,7 @@ def compute_pair_vector_and_distance( def tensor_norm(tensor): """Computes Frobenius norm.""" - return (tensor*tensor).sum((-3, -2)) + return (tensor * tensor).sum((-3, -2)) class TensorEmbedding(nn.Module): @@ -129,13 +129,7 @@ def _create_distance_proj( d_proj2 = nn.Linear(in_features, units, bias=True, dtype=dtype) d_proj3 = nn.Linear(in_features, units, bias=True, dtype=dtype) - layer = torch.nn.utils.skip_init( - nn.Linear, - in_features, - 3 * units, - bias=True, - dtype=dtype - ) + layer = torch.nn.utils.skip_init(nn.Linear, in_features, 3 * units, bias=True, dtype=dtype) with torch.no_grad(): layer.weight.copy_(torch.cat([d_proj1.weight, d_proj2.weight, d_proj3.weight], dim=0)) layer.bias.copy_(torch.cat([d_proj1.bias, d_proj2.bias, d_proj3.bias], dim=0)) @@ -145,25 +139,13 @@ def _reset_distance_proj(self) -> None: """Reset distance_proj weights using 3 temp layers to match reference RNG pattern.""" dtype = self.distance_proj.weight.dtype d_proj1 = torch.nn.utils.skip_init( - nn.Linear, - self.distance_proj.in_features, - self.units, - bias=True, - dtype=dtype + nn.Linear, self.distance_proj.in_features, self.units, bias=True, dtype=dtype ) d_proj2 = torch.nn.utils.skip_init( - nn.Linear, - self.distance_proj.in_features, - self.units, - bias=True, - dtype=dtype - ) + nn.Linear, self.distance_proj.in_features, self.units, bias=True, dtype=dtype + ) d_proj3 = torch.nn.utils.skip_init( - nn.Linear, - self.distance_proj.in_features, - self.units, - bias=True, - dtype=dtype + nn.Linear, self.distance_proj.in_features, self.units, bias=True, dtype=dtype ) d_proj1.reset_parameters() d_proj2.reset_parameters() @@ -229,22 +211,15 @@ def forward( edge_attr = self.distance_proj(edge_attr).view(-1, 3, self.units) # Get atomic number messages - zij = x.index_select(0, edge_index.t().reshape(-1)).view( - -1, self.units * 2 - ) + zij = x.index_select(0, edge_index.t().reshape(-1)).view(-1, self.units * 2) Zij = self.emb2(zij) # (num_edges, units) # Create edge attributes with Zij - edge_attr_processed = \ - edge_attr.view(-1, 3, self.units) \ - * C.view(-1, 1, 1) \ - * Zij.view(-1, 1, self.units) + edge_attr_processed = edge_attr.view(-1, 3, self.units) * C.view(-1, 1, 1) * Zij.view(-1, 1, self.units) # Radial message passing edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) - I, A, S = fn_radial_message_passing( - edge_vec_norm, edge_attr_processed, col_data, col_indptr - ) + I, A, S = fn_radial_message_passing(edge_vec_norm, edge_attr_processed, col_data, col_indptr) # Compose initial tensor to get proper shape for norm computation X = fn_compose_tensor(I, A, S) # (num_nodes, 3, 3, units) @@ -321,7 +296,7 @@ def forward( row_indptr: torch.Tensor, col_data: torch.Tensor, col_indices: torch.Tensor, - col_indptr: torch.Tensor, + col_indptr: torch.Tensor, ) -> torch.Tensor: """Forward pass. 
@@ -339,14 +314,14 @@ def forward( edge_attr_processed = edge_attr for linear_scalar in self.linears_scalar: edge_attr_processed = self.act(linear_scalar(edge_attr_processed)) - edge_attr_processed = (edge_attr_processed * C.view(-1, 1)).view( - edge_attr.shape[0], self.units, 3 - ).mT.contiguous() # (num_edges, 3, units) + edge_attr_processed = ( + (edge_attr_processed * C.view(-1, 1)).view(edge_attr.shape[0], self.units, 3).mT.contiguous() + ) # (num_edges, 3, units) # Normalize input tensor # For X with shape (num_nodes, 3, 3, units), we need to sum over (-3, -2) # which are the (3, 3) spatial dimensions to get (num_nodes, units) - norm_X = (X*X).sum((-3, -2)) + 1 # (num_nodes, units) + norm_X = (X * X).sum((-3, -2)) + 1 # (num_nodes, units) X = X / norm_X.view(-1, 1, 1, X.shape[-1]) # Decompose input tensor @@ -601,23 +576,15 @@ def forward( bond_vec, bond_dist = compute_pair_vector_and_distance(pos, edge_index, pbc_offshift) # perpare graph indices for message passing - row_data, row_indices, row_indptr, col_data, col_indices, col_indptr = ( - graph_transform(edge_index.int(), z.shape[0]) + row_data, row_indices, row_indptr, col_data, col_indices, col_indptr = graph_transform( + edge_index.int(), z.shape[0] ) # Expand distances with radial basis functions edge_attr = self.bond_expansion(bond_dist) # Embedding layer - X = self.tensor_embedding( - z, - edge_index, - bond_dist, - bond_vec, - edge_attr, - col_data, - col_indptr - ) + X = self.tensor_embedding(z, edge_index, bond_dist, bond_vec, edge_attr, col_data, col_indptr) # Interaction layers for layer in self.layers: diff --git a/src/matgl/ops/__init__.py b/src/matgl/ops/__init__.py index 42b8141a..20428c65 100644 --- a/src/matgl/ops/__init__.py +++ b/src/matgl/ops/__init__.py @@ -37,6 +37,7 @@ from .graph_transform import graph_transform import warp as wp + wp.init() __all__ = [ @@ -49,4 +50,4 @@ "fn_message_passing", "fn_radial_message_passing", "graph_transform", -] \ No newline at end of file +] diff --git a/src/matgl/ops/compose_tensor.py b/src/matgl/ops/compose_tensor.py index 6fd56079..bf948d3d 100644 --- a/src/matgl/ops/compose_tensor.py +++ b/src/matgl/ops/compose_tensor.py @@ -44,9 +44,7 @@ def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: stream = get_stream(x.device) device = wp.device_from_torch(x.device) - output = torch.empty( - (x.shape[0], 3, 3, x.shape[-1]), dtype=x.dtype, device=x.device - ) + output = torch.empty((x.shape[0], 3, 3, x.shape[-1]), dtype=x.dtype, device=x.device) x_wp = wp.from_torch(x.detach(), return_ctype=True) y_wp = wp.from_torch(y.detach(), return_ctype=True) @@ -182,9 +180,7 @@ def compose_tensor_fwd(*args): @torch.compiler.allow_in_graph def compose_tensor_bwd(ctx, grad_output): x, y, z = ctx.saved_tensors - dx, dy, dz = torch.ops.tensornet.compose_tensor_bwd_primitive( - grad_output, x, y, z - ) + dx, dy, dz = torch.ops.tensornet.compose_tensor_bwd_primitive(grad_output, x, y, z) return dx, dy, dz diff --git a/src/matgl/ops/decompose_tensor.py b/src/matgl/ops/decompose_tensor.py index a2c1973a..39197618 100644 --- a/src/matgl/ops/decompose_tensor.py +++ b/src/matgl/ops/decompose_tensor.py @@ -42,7 +42,6 @@ device_types=["cpu", "cuda"], ) def _(x: Tensor) -> List[Tensor]: - stream = get_stream(x.device) device = wp.device_from_torch(x.device) output_i = torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device) @@ -80,10 +79,7 @@ def _(x: Tensor) -> List[Tensor]: mutates_args=(), device_types=["cpu", "cuda"], ) -def _( - grad_output_i: Tensor, grad_output_a: Tensor, 
grad_output_s: Tensor, x: Tensor -) -> List[Tensor]: - +def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> List[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) grad_x = torch.empty_like(x) @@ -107,9 +103,7 @@ def _( @torch.library.register_fake("tensornet::decompose_tensor_bwd_primitive") -def _( - grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor -) -> List[Tensor]: +def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> List[Tensor]: return [torch.empty_like(x)] @@ -133,15 +127,9 @@ def _( grad_grad_output_a = torch.empty_like(grad_output_a) grad_grad_output_s = torch.empty_like(grad_output_s) - grad_grad_output_i_wp = wp.from_torch( - grad_grad_output_i.detach(), return_ctype=True - ) - grad_grad_output_a_wp = wp.from_torch( - grad_grad_output_a.detach(), return_ctype=True - ) - grad_grad_output_s_wp = wp.from_torch( - grad_grad_output_s.detach(), return_ctype=True - ) + grad_grad_output_i_wp = wp.from_torch(grad_grad_output_i.detach(), return_ctype=True) + grad_grad_output_a_wp = wp.from_torch(grad_grad_output_a.detach(), return_ctype=True) + grad_grad_output_s_wp = wp.from_torch(grad_grad_output_s.detach(), return_ctype=True) grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) @@ -197,9 +185,7 @@ def decompose_tensor_fwd(*args): def decompose_tensor_bwd(ctx, *grad_outputs): (x,) = ctx.saved_tensors grad_output_i, grad_output_a, grad_output_s = grad_outputs[0] - dx = torch.ops.tensornet.decompose_tensor_bwd_primitive( - grad_output_i, grad_output_a, grad_output_s, x - ) + dx = torch.ops.tensornet.decompose_tensor_bwd_primitive(grad_output_i, grad_output_a, grad_output_s, x) return dx[0] diff --git a/src/matgl/ops/equivariant_o3_matmul.py b/src/matgl/ops/equivariant_o3_matmul.py index c6d37d2d..c9db9989 100644 --- a/src/matgl/ops/equivariant_o3_matmul.py +++ b/src/matgl/ops/equivariant_o3_matmul.py @@ -114,9 +114,7 @@ def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: mutates_args=(), device_types=["cpu", "cuda"], ) -def _( - grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor -) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: raise ValueError("x and y must be 3x3 matrices") if x.ndim != 4 or y.ndim != 4: @@ -139,9 +137,7 @@ def _( grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) - tensor_matmul_o3_3x3_bwd_bwd = get_module( - "tensor_matmul_o3_3x3_bwd_bwd", [str(grad_output.dtype)] - ) + tensor_matmul_o3_3x3_bwd_bwd = get_module("tensor_matmul_o3_3x3_bwd_bwd", [str(grad_output.dtype)]) wp.launch( tensor_matmul_o3_3x3_bwd_bwd, dim=(grad_output.shape[0], grad_output.shape[-1]), @@ -163,9 +159,7 @@ def _( @torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive") -def _( - grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor -) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: return [ torch.empty_like(grad_output), torch.empty_like(grad_output), @@ -185,17 +179,13 @@ def tensor_matmul_o3_3x3_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def tensor_matmul_o3_3x3_fwd(*args): - return 
getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_fwd_primitive")( - *args - ) + return getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_fwd_primitive")(*args) @torch.compiler.allow_in_graph def tensor_matmul_o3_3x3_bwd(ctx, grad_output): x, y = ctx.saved_tensors - dx, dy = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_primitive")( - grad_output, x, y - ) + dx, dy = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_primitive")(grad_output, x, y) return dx, dy @@ -206,9 +196,9 @@ def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): grad_output_saved, x, y = ctx.saved_tensors - outputs = getattr( - torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_bwd_primitive" - )(grad_output_saved, grad_grad_x, grad_grad_y, x, y) + outputs = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_bwd_primitive")( + grad_output_saved, grad_grad_x, grad_grad_y, x, y + ) return outputs[0], outputs[1], outputs[2] diff --git a/src/matgl/ops/equivariant_so3_matmul.py b/src/matgl/ops/equivariant_so3_matmul.py index 44d27958..1c0880c6 100644 --- a/src/matgl/ops/equivariant_so3_matmul.py +++ b/src/matgl/ops/equivariant_so3_matmul.py @@ -113,9 +113,7 @@ def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: mutates_args=(), device_types=["cpu", "cuda"], ) -def _( - grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor -) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: raise ValueError("x and y must be 3x3 matrices") if x.ndim != 4 or y.ndim != 4: @@ -137,9 +135,7 @@ def _( grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) - tensor_matmul_so3_3x3_bwd_bwd = get_module( - "tensor_matmul_so3_3x3_bwd_bwd", [str(grad_output.dtype)] - ) + tensor_matmul_so3_3x3_bwd_bwd = get_module("tensor_matmul_so3_3x3_bwd_bwd", [str(grad_output.dtype)]) wp.launch( tensor_matmul_so3_3x3_bwd_bwd, dim=(grad_output.shape[0], grad_output.shape[-1]), @@ -161,9 +157,7 @@ def _( @torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive") -def _( - grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor -) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: return [ torch.empty_like(grad_output), torch.empty_like(grad_output), @@ -189,9 +183,7 @@ def tensor_matmul_so3_3x3_fwd(*args): @torch.compiler.allow_in_graph def tensor_matmul_so3_3x3_bwd(ctx, grad_output): x, y = ctx.saved_tensors - dx, dy = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_primitive( - grad_output, x, y - ) + dx, dy = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_primitive(grad_output, x, y) return dx, dy diff --git a/src/matgl/ops/graph_transform.py b/src/matgl/ops/graph_transform.py index 14e2e944..a4a3274a 100644 --- a/src/matgl/ops/graph_transform.py +++ b/src/matgl/ops/graph_transform.py @@ -91,19 +91,11 @@ def _( row_indptr_wp = wp.from_torch(row_indptr, return_ctype=True) col_indptr_wp = wp.from_torch(col_indptr, return_ctype=True) - row_indices = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) - col_indices = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) + row_indices = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + col_indices = 
torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) - row_data = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) - col_data = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) + row_data = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + col_data = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) row_indices_wp = wp.from_torch(row_indices, return_ctype=True) col_indices_wp = wp.from_torch(col_indices, return_ctype=True) @@ -140,37 +132,24 @@ def _( row_indptr: Tensor, col_indptr: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - output = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) - output2 = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) - output3 = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) - output4 = torch.empty( - edge_index.shape[1], dtype=torch.int32, device=edge_index.device - ) + output = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + output2 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + output3 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + output4 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) return output, output2, output3, output4 @torch.compiler.allow_in_graph -def graph_transform( - edge_index: Tensor, num_nodes: int -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - row_count, col_count = torch.ops.nvtnet.count_row_col_primitive( - edge_index, num_nodes +def graph_transform(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + row_count, col_count = torch.ops.nvtnet.count_row_col_primitive(edge_index, num_nodes) + row_indptr, col_indptr = ( + torch.cumsum(row_count, dim=0, dtype=torch.int32), + torch.cumsum(col_count, dim=0, dtype=torch.int32), ) - row_indptr, col_indptr = torch.cumsum( - row_count, dim=0, dtype=torch.int32 - ), torch.cumsum(col_count, dim=0, dtype=torch.int32) ( row_indices, col_indices, row_data, col_data, - ) = torch.ops.nvtnet.convert_to_sparse_primitive( - edge_index, row_count, col_count, row_indptr, col_indptr - ) + ) = torch.ops.nvtnet.convert_to_sparse_primitive(edge_index, row_count, col_count, row_indptr, col_indptr) return row_data, row_indices, row_indptr, col_data, col_indices, col_indptr diff --git a/src/matgl/ops/tensor_norm3.py b/src/matgl/ops/tensor_norm3.py index 3eabb174..60d0e793 100644 --- a/src/matgl/ops/tensor_norm3.py +++ b/src/matgl/ops/tensor_norm3.py @@ -42,7 +42,6 @@ device_types=["cpu", "cuda"], ) def _(x: Tensor) -> Tensor: - stream = get_stream(x.device) device = wp.device_from_torch(x.device) output = torch.empty((x.shape[0], 3 * x.shape[-1]), dtype=x.dtype, device=x.device) @@ -72,10 +71,7 @@ def _(x: Tensor) -> Tensor: mutates_args=(), device_types=["cpu", "cuda"], ) -def _( - grad_output: Tensor, x: Tensor -) -> List[Tensor]: - +def _(grad_output: Tensor, x: Tensor) -> List[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) grad_x = torch.empty_like(x) @@ -97,9 +93,7 @@ def _( @torch.library.register_fake("tensornet::tensor_norm3_bwd_primitive") -def _( - grad_output: Tensor, x: Tensor -) -> List[Tensor]: +def _(grad_output: Tensor, x: Tensor) -> List[Tensor]: return [torch.empty_like(x)] @@ -194,9 +188,7 @@ def 
tensor_norm3_bwd_bwd(ctx, *grad_outputs): if grad_grad_x is None: grad_grad_x = torch.zeros_like(x) - outputs = torch.ops.tensornet.tensor_norm3_bwd_bwd_primitive( - grad_grad_x, x, grad_output - ) + outputs = torch.ops.tensornet.tensor_norm3_bwd_bwd_primitive(grad_grad_x, x, grad_output) return outputs[0], outputs[1] diff --git a/src/matgl/ops/tensornet_mp.py b/src/matgl/ops/tensornet_mp.py index a4d7ec3e..f22c72d4 100644 --- a/src/matgl/ops/tensornet_mp.py +++ b/src/matgl/ops/tensornet_mp.py @@ -247,9 +247,7 @@ def _( grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) grad_grad_z_wp = wp.from_torch(grad_grad_z.detach(), return_ctype=True) - grad_grad_edge_attr_wp = wp.from_torch( - grad_grad_edge_attr.detach(), return_ctype=True - ) + grad_grad_edge_attr_wp = wp.from_torch(grad_grad_edge_attr.detach(), return_ctype=True) grad_output_x_wp = wp.from_torch(grad_output_x.detach(), return_ctype=True) grad_output_y_wp = wp.from_torch(grad_output_y.detach(), return_ctype=True) @@ -281,9 +279,7 @@ def _( dgrad_output_z_wp = wp.from_torch(dgrad_output_z.detach(), return_ctype=True) # Kernel 1: col-based - computes d2I, d2A, d2S, d2edge_attr - message_passing_edge_bwd_bwd = get_module( - "message_passing_edge_bwd_bwd", [str(x.dtype)] - ) + message_passing_edge_bwd_bwd = get_module("message_passing_edge_bwd_bwd", [str(x.dtype)]) wp.launch( message_passing_edge_bwd_bwd, dim=(x.shape[0], x.shape[-1]), @@ -311,9 +307,7 @@ def _( ) # Kernel 2: row-based - computes d2output_I, d2output_A, d2output_S - message_passing_output_bwd_bwd = get_module( - "message_passing_output_bwd_bwd", [str(x.dtype)] - ) + message_passing_output_bwd_bwd = get_module("message_passing_output_bwd_bwd", [str(x.dtype)]) wp.launch( message_passing_output_bwd_bwd, dim=(x.shape[0], x.shape[-1]), diff --git a/src/matgl/ops/tensornet_radial_mp.py b/src/matgl/ops/tensornet_radial_mp.py index 52371def..1a10b575 100644 --- a/src/matgl/ops/tensornet_radial_mp.py +++ b/src/matgl/ops/tensornet_radial_mp.py @@ -41,10 +41,7 @@ mutates_args=(), device_types=["cpu", "cuda"], ) -def _( - edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor -) -> List[Tensor]: - +def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> List[Tensor]: num_atoms = row_indptr.shape[0] - 1 stream = get_stream(edge_vec_norm.device) device = wp.device_from_torch(edge_vec_norm.device) @@ -74,9 +71,7 @@ def _( row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) - message_passing_fwd = get_module( - "radial_message_passing_fwd", [str(edge_vec_norm.dtype)] - ) + message_passing_fwd = get_module("radial_message_passing_fwd", [str(edge_vec_norm.dtype)]) wp.launch( message_passing_fwd, dim=(num_atoms, edge_attr.shape[-1]), @@ -97,9 +92,7 @@ def _( @torch.library.register_fake("tensornet::radial_message_passing_fwd_primitive") -def _( - edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor -) -> List[Tensor]: +def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> List[Tensor]: num_atoms = row_indptr.shape[0] - 1 return [ torch.empty( @@ -143,9 +136,7 @@ def _( grad_output_S_wp = wp.from_torch(grad_output_S.detach(), return_ctype=True) grad_edge_vec_norm = torch.zeros_like(edge_vec_norm) - grad_edge_vec_norm_wp = wp.from_torch( - grad_edge_vec_norm.detach(), return_ctype=True - ) + 
grad_edge_vec_norm_wp = wp.from_torch(grad_edge_vec_norm.detach(), return_ctype=True) grad_edge_attr = torch.zeros_like(edge_attr) grad_edge_attr_wp = wp.from_torch(grad_edge_attr.detach(), return_ctype=True) @@ -156,9 +147,7 @@ def _( row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) - message_passing_bwd = get_module( - "radial_message_passing_bwd", [str(edge_vec_norm.dtype)] - ) + message_passing_bwd = get_module("radial_message_passing_bwd", [str(edge_vec_norm.dtype)]) wp.launch( message_passing_bwd, dim=(num_atoms, edge_attr.shape[-1]), @@ -219,12 +208,8 @@ def _( row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) - grad_grad_edge_vec_norm_wp = wp.from_torch( - grad_grad_edge_vec_norm.detach(), return_ctype=True - ) - grad_grad_edge_attr_wp = wp.from_torch( - grad_grad_edge_attr.detach(), return_ctype=True - ) + grad_grad_edge_vec_norm_wp = wp.from_torch(grad_grad_edge_vec_norm.detach(), return_ctype=True) + grad_grad_edge_attr_wp = wp.from_torch(grad_grad_edge_attr.detach(), return_ctype=True) grad_output_I_wp = wp.from_torch(grad_output_I.detach(), return_ctype=True) grad_output_A_wp = wp.from_torch(grad_output_A.detach(), return_ctype=True) @@ -237,18 +222,12 @@ def _( dgrad_output_S_wp = wp.from_torch(dgrad_output_S.detach(), return_ctype=True) dgrad_grad_edge_vec_norm = torch.zeros_like(grad_grad_edge_vec_norm) - dgrad_grad_edge_vec_norm_wp = wp.from_torch( - dgrad_grad_edge_vec_norm.detach(), return_ctype=True - ) + dgrad_grad_edge_vec_norm_wp = wp.from_torch(dgrad_grad_edge_vec_norm.detach(), return_ctype=True) dgrad_grad_edge_attr = torch.zeros_like(grad_grad_edge_attr) - dgrad_grad_edge_attr_wp = wp.from_torch( - dgrad_grad_edge_attr.detach(), return_ctype=True - ) + dgrad_grad_edge_attr_wp = wp.from_torch(dgrad_grad_edge_attr.detach(), return_ctype=True) - message_passing_bwd_bwd = get_module( - "radial_message_passing_bwd_bwd", [str(edge_vec_norm.dtype)] - ) + message_passing_bwd_bwd = get_module("radial_message_passing_bwd_bwd", [str(edge_vec_norm.dtype)]) wp.launch( message_passing_bwd_bwd, dim=(num_atoms, edge_attr.shape[-1]), @@ -281,9 +260,7 @@ def _( ] -@torch.library.register_fake( - "tensornet::radial_message_passing_bwd_bwd_primitive" -) +@torch.library.register_fake("tensornet::radial_message_passing_bwd_bwd_primitive") def _( grad_output_I: Tensor, grad_output_A: Tensor, @@ -354,7 +331,6 @@ def radial_message_passing_bwd(ctx, grad_outputs): @torch.compiler.allow_in_graph def radial_message_passing_bwd_bwd(ctx, *grad_outputs): - grad_grad_edge_vec_norm, grad_grad_edge_attr = grad_outputs[0] ( grad_output_I, @@ -413,6 +389,4 @@ def radial_message_passing_bwd_bwd(ctx, *grad_outputs): def fn_radial_message_passing( edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor ) -> List[Tensor]: - return torch.ops.tensornet.radial_message_passing_fwd_primitive( - edge_vec_norm, edge_attr, row_data, row_indptr - ) + return torch.ops.tensornet.radial_message_passing_fwd_primitive(edge_vec_norm, edge_attr, row_data, row_indptr) From 313385cb7a159759b8653a0111219de3948b4f28 Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Fri, 9 Jan 2026 16:24:32 -0500 Subject: [PATCH 04/18] add bw and bbw tests Signed-off-by: Roman Zubatyuk --- ....py => test_tensornet_forward_backward.py} | 103 +++++++++--------- tests/models/test_tensornet_pyg.py | 53 +++++++++ 2 files changed, 104 insertions(+), 
52 deletions(-) rename dev/{test_model_forward_backward.py => test_tensornet_forward_backward.py} (79%) diff --git a/dev/test_model_forward_backward.py b/dev/test_tensornet_forward_backward.py similarity index 79% rename from dev/test_model_forward_backward.py rename to dev/test_tensornet_forward_backward.py index a38864e5..d9466720 100644 --- a/dev/test_model_forward_backward.py +++ b/dev/test_tensornet_forward_backward.py @@ -25,7 +25,6 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# """Compare forward/backward/double-backward between matgl-main and current TensorNet.""" from __future__ import annotations @@ -38,11 +37,6 @@ import torch from pymatgen.core import Structure - -# ============================================================================= -# Configuration -# ============================================================================= - DEFAULT_MATGL_MAIN_PATH = str(Path(__file__).parent.parent / "matgl-main" / "src") MODEL_CONFIG = { @@ -58,10 +52,6 @@ } -# ============================================================================= -# Utilities -# ============================================================================= - def clear_matgl_modules() -> None: """Remove all matgl modules from sys.modules.""" for mod in [k for k in sys.modules if k.startswith("matgl")]: @@ -69,17 +59,17 @@ def clear_matgl_modules() -> None: def print_section(title: str) -> None: - """Print a section header.""" + """Print a formatted section header.""" print(f"\n{'=' * 70}\n{title}\n{'=' * 70}") def load_structure(path: str) -> Structure: - """Load structure from file using pymatgen.""" + """Load structure from file.""" return Structure.from_file(path) def get_element_types(structure: Structure) -> tuple[str, ...]: - """Extract sorted unique element symbols from structure.""" + """Extract sorted unique element symbols.""" return tuple(sorted({site.specie.symbol for site in structure})) @@ -90,7 +80,7 @@ def build_graph( compute_bond: Any = None, requires_grad: bool = False, ) -> Any: - """Build graph from structure.""" + """Build graph from structure with optional gradient tracking.""" graph, lat, _ = converter.get_graph(structure) pos = graph.frac_coords @ lat[0] graph.pos = pos.clone().detach().requires_grad_(requires_grad) if requires_grad else pos @@ -104,12 +94,8 @@ def build_graph( return graph.to(device) -# ============================================================================= -# Comparison Functions -# ============================================================================= - def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-6) -> bool: - """Compare two tensors, return True if matching.""" + """Compare two tensors element-wise.""" if t1.shape != t2.shape: print(f" {name}: SHAPE MISMATCH {t1.shape} vs {t2.shape}") return False @@ -124,13 +110,13 @@ def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = def compare_weights(ref_model: Any, cur_model: Any) -> bool: - """Compare model weights, handling distance_proj1/2/3 -> distance_proj mapping.""" + """Compare model weights with distance_proj layer remapping.""" print_section("Weight Comparison") ref_sd, cur_sd = ref_model.state_dict(), cur_model.state_dict() all_match = True - # Handle merged distance_proj layers + # Handle merged distance_proj layers (distance_proj1/2/3 -> 
distance_proj) dp_keys = [f"tensor_embedding.distance_proj{i}" for i in range(1, 4)] if f"{dp_keys[0]}.weight" in ref_sd: ref_w = torch.cat([ref_sd[f"{k}.weight"] for k in dp_keys], dim=0) @@ -140,7 +126,6 @@ def compare_weights(ref_model: Any, cur_model: Any) -> bool: all_match &= compare_tensors("weight", ref_w, cur_sd["tensor_embedding.distance_proj.weight"]) all_match &= compare_tensors("bias", ref_b, cur_sd["tensor_embedding.distance_proj.bias"]) - # Compare remaining parameters skip = {f"{k}.{p}" for k in dp_keys for p in ("weight", "bias")} print("\n--- Other Parameters ---") @@ -164,7 +149,7 @@ def compare_weights(ref_model: Any, cur_model: Any) -> bool: def compare_forward( ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device ) -> bool: - """Compare forward pass outputs.""" + """Compare forward pass energy predictions.""" print_section("Forward Pass") ref_model.eval() @@ -186,8 +171,8 @@ def compare_forward( def compare_backward( ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device -) -> tuple[bool, torch.Tensor, torch.Tensor, Any, Any]: - """Compare backward pass (forces = -dE/dpos).""" +) -> bool: + """Compare forces (F = -dE/dpos).""" print_section("Backward Pass (Forces)") ref_model.train() @@ -196,7 +181,7 @@ def compare_backward( def get_forces(model, graph): energy = model(g=graph, state_attr=state_attr) - return -torch.autograd.grad(energy, graph.pos, create_graph=True, retain_graph=True)[0] + return -torch.autograd.grad(energy, graph.pos, create_graph=True)[0] ref_f = get_forces(ref_model, ref_graph) cur_f = get_forces(cur_model, cur_graph) @@ -209,28 +194,46 @@ def get_forces(model, graph): match = diff.max().item() < 1e-5 print(f"Result: {'PASS' if match else 'FAIL'}") - return match, ref_f, cur_f, ref_graph, cur_graph + return match def compare_double_backward( - ref_forces: torch.Tensor, cur_forces: torch.Tensor, ref_graph: Any, cur_graph: Any + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device ) -> bool: - """Compare Hessian-vector product: d(F·v)/dpos.""" - print_section("Double Backward (Hessian-Vector Product)") + """Compare position gradients via loss = sum(forces^2).""" + print_section("Double Backward (Position Gradients)") - torch.manual_seed(123) - v = torch.randn_like(ref_forces) + ref_model.train() + cur_model.train() + state_attr = torch.tensor([0.0, 0.0], device=device) + + ref_graph.pos.retain_grad() + cur_graph.pos.retain_grad() + + # Reference + ref_energy = ref_model(g=ref_graph, state_attr=state_attr) + ref_forces = torch.autograd.grad(ref_energy, ref_graph.pos, create_graph=True)[0] + ref_loss = (ref_forces * ref_forces).sum() + ref_loss.backward() + ref_pos_grad = ref_graph.pos.grad.clone() - ref_Hv = torch.autograd.grad((ref_forces * v).sum(), ref_graph.pos, retain_graph=True)[0] - cur_Hv = torch.autograd.grad((cur_forces * v).sum(), cur_graph.pos, retain_graph=True)[0] + # Current + cur_energy = cur_model(g=cur_graph, state_attr=state_attr) + cur_forces = torch.autograd.grad(cur_energy, cur_graph.pos, create_graph=True)[0] + cur_loss = (cur_forces * cur_forces).sum() + cur_loss.backward() + cur_pos_grad = cur_graph.pos.grad.clone() - print(f"Reference: mean={ref_Hv.mean():.6f}, std={ref_Hv.std():.6f}") - print(f"Current: mean={cur_Hv.mean():.6f}, std={cur_Hv.std():.6f}") + forces_diff = (ref_forces - cur_forces).abs() + print(f"Forces: max_diff={forces_diff.max():.2e}, mean_diff={forces_diff.mean():.2e}") - if ref_Hv.abs().max() < 1e-10 or 
cur_Hv.abs().max() < 1e-10: - print("WARNING: Hessian-vector product is nearly zero") + print(f"Reference pos.grad: mean={ref_pos_grad.mean():.6f}, std={ref_pos_grad.std():.6f}") + print(f"Current pos.grad: mean={cur_pos_grad.mean():.6f}, std={cur_pos_grad.std():.6f}") - diff = (ref_Hv - cur_Hv).abs() + if ref_pos_grad.abs().max() < 1e-10 or cur_pos_grad.abs().max() < 1e-10: + print("WARNING: Position gradient is nearly zero") + + diff = (ref_pos_grad - cur_pos_grad).abs() print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") match = diff.max().item() < 1e-4 @@ -238,12 +241,8 @@ def compare_double_backward( return match -# ============================================================================= -# Main -# ============================================================================= - def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: - """Run all comparison tests.""" + """Run all comparison tests between reference and current implementations.""" print_section("TensorNet Comparison: matgl-main vs Current") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -256,7 +255,7 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: model_config = {**MODEL_CONFIG, "element_types": element_types} - # Load reference model (matgl-main) + # Reference model (matgl-main) clear_matgl_modules() sys.path.insert(0, matgl_main_path) @@ -270,10 +269,11 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: ref_graph = build_graph(ref_converter, structure, device, ref_compute_bond) ref_graph_grad = build_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + ref_graph_grad2 = build_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) sys.path.pop(0) - # Load current model (src) + # Current model (src) clear_matgl_modules() from matgl.models._tensornet_pyg import TensorNet as CurTensorNet @@ -285,6 +285,7 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: cur_graph = build_graph(cur_converter, structure, device) cur_graph_grad = build_graph(cur_converter, structure, device, requires_grad=True) + cur_graph_grad2 = build_graph(cur_converter, structure, device, requires_grad=True) print(f"Models: {sum(p.numel() for p in ref_model.parameters())} params each") @@ -292,14 +293,12 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: results = { "Weights": compare_weights(ref_model, cur_model), "Forward": compare_forward(ref_model, cur_model, ref_graph, cur_graph, device), + "Backward": compare_backward(ref_model, cur_model, ref_graph_grad, cur_graph_grad, device), + "Double Backward": compare_double_backward( + ref_model, cur_model, ref_graph_grad2, cur_graph_grad2, device + ), } - back_ok, ref_f, cur_f, ref_g, cur_g = compare_backward( - ref_model, cur_model, ref_graph_grad, cur_graph_grad, device - ) - results["Backward"] = back_ok - results["Double Backward"] = compare_double_backward(ref_f, cur_f, ref_g, cur_g) - # Summary print_section("SUMMARY") all_pass = all(results.values()) diff --git a/tests/models/test_tensornet_pyg.py b/tests/models/test_tensornet_pyg.py index a91673ce..c2ff494f 100644 --- a/tests/models/test_tensornet_pyg.py +++ b/tests/models/test_tensornet_pyg.py @@ -106,3 +106,56 @@ def test_model_intensive_with_classification(self, graph_MoS_pyg): ) output = model(g=graph) assert torch.numel(output) == 1 + + def test_backward(self, graph_MoS_pyg): + """Test cell gradient (dE/dcell).""" + 
torch.manual_seed(0) + torch.use_deterministic_algorithms(True) + + EXPECTED_CELL_GRAD = torch.tensor([ + [-0.000967, 0.000000, 0.000000], + [0.000000, -0.000967, 0.000000], + [0.000000, 0.000000, -0.000967], + ]) + + structure, graph, _ = graph_MoS_pyg + cell = torch.tensor(structure.lattice.matrix, dtype=matgl.float_th).requires_grad_(True) + + graph.pbc_offshift = torch.matmul(graph.pbc_offset, cell) + graph.pos = graph.frac_coords @ cell + + model = TensorNet(is_intensive=False, activation_type="swish") + model.train() + + energy = model(g=graph) + cell_grad = torch.autograd.grad(energy, cell, create_graph=True)[0] + + assert torch.allclose(cell_grad, EXPECTED_CELL_GRAD, atol=1e-6) + + def test_double_backward(self, graph_MoS_pyg): + """Test double backward: loss = sum(cell_grad^2), compare cell.grad.""" + torch.manual_seed(0) + torch.use_deterministic_algorithms(True) + + EXPECTED_CELL_GRAD2 = torch.tensor([ + [-0.000010, -0.000000, -0.000000], + [-0.000000, -0.000010, -0.000000], + [-0.000000, -0.000000, -0.000010], + ]) + + structure, graph, _ = graph_MoS_pyg + cell = torch.tensor(structure.lattice.matrix, dtype=matgl.float_th).requires_grad_(True) + cell.retain_grad() + + graph.pbc_offshift = torch.matmul(graph.pbc_offset, cell) + graph.pos = graph.frac_coords @ cell + + model = TensorNet(is_intensive=False, activation_type="swish") + model.train() + + energy = model(g=graph) + cell_grad = torch.autograd.grad(energy, cell, create_graph=True)[0] + loss = (cell_grad * cell_grad).sum() + loss.backward() + + assert torch.allclose(cell.grad, EXPECTED_CELL_GRAD2, atol=1e-6) From 10bc7e62b7280b51af988773f30be91982d547f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 21:42:14 +0000 Subject: [PATCH 05/18] pre-commit auto-fixes --- dev/test_tensornet_forward_backward.py | 19 +++++-------- src/matgl/kernels/__init__.py | 17 +++++------ src/matgl/kernels/compose_tensor.py | 19 +++++++------ src/matgl/kernels/decompose_tensor.py | 31 +++++++++++---------- src/matgl/kernels/equivariant_o3_matmul.py | 1 + src/matgl/kernels/equivariant_so3_matmul.py | 1 + src/matgl/kernels/graph_transform.py | 1 + src/matgl/kernels/tensor_norm3.py | 1 + src/matgl/kernels/tensornet_mp.py | 1 + src/matgl/kernels/tensornet_radial_mp.py | 1 + src/matgl/kernels/utils.py | 28 ++++++++----------- src/matgl/models/_tensornet_pyg.py | 12 ++++---- src/matgl/ops/__init__.py | 14 +++++----- src/matgl/ops/compose_tensor.py | 14 ++++------ src/matgl/ops/decompose_tensor.py | 20 ++++++------- src/matgl/ops/equivariant_o3_matmul.py | 21 +++++++------- src/matgl/ops/equivariant_so3_matmul.py | 14 ++++------ src/matgl/ops/graph_transform.py | 18 ++++++------ src/matgl/ops/tensor_norm3.py | 14 ++++------ src/matgl/ops/tensornet_mp.py | 20 ++++++------- src/matgl/ops/tensornet_radial_mp.py | 20 ++++++------- tests/models/test_tensornet_pyg.py | 24 +++++++++------- 22 files changed, 149 insertions(+), 162 deletions(-) diff --git a/dev/test_tensornet_forward_backward.py b/dev/test_tensornet_forward_backward.py index d9466720..ead8a877 100644 --- a/dev/test_tensornet_forward_backward.py +++ b/dev/test_tensornet_forward_backward.py @@ -146,9 +146,7 @@ def compare_weights(ref_model: Any, cur_model: Any) -> bool: return all_match -def compare_forward( - ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device -) -> bool: +def compare_forward(ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device) -> 
bool: """Compare forward pass energy predictions.""" print_section("Forward Pass") @@ -169,9 +167,7 @@ def compare_forward( return match -def compare_backward( - ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device -) -> bool: +def compare_backward(ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device) -> bool: """Compare forces (F = -dE/dpos).""" print_section("Backward Pass (Forces)") @@ -259,9 +255,9 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: clear_matgl_modules() sys.path.insert(0, matgl_main_path) - from matgl.models._tensornet_pyg import TensorNet as RefTensorNet from matgl.ext._pymatgen_pyg import Structure2Graph as RefConverter from matgl.graph._compute_pyg import compute_pair_vector_and_distance as ref_compute_bond + from matgl.models._tensornet_pyg import TensorNet as RefTensorNet torch.manual_seed(seed) ref_model = RefTensorNet(**model_config).to(device) @@ -276,8 +272,8 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: # Current model (src) clear_matgl_modules() - from matgl.models._tensornet_pyg import TensorNet as CurTensorNet from matgl.ext._pymatgen_pyg import Structure2Graph as CurConverter + from matgl.models._tensornet_pyg import TensorNet as CurTensorNet torch.manual_seed(seed) cur_model = CurTensorNet(**model_config).to(device) @@ -294,9 +290,7 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: "Weights": compare_weights(ref_model, cur_model), "Forward": compare_forward(ref_model, cur_model, ref_graph, cur_graph, device), "Backward": compare_backward(ref_model, cur_model, ref_graph_grad, cur_graph_grad, device), - "Double Backward": compare_double_backward( - ref_model, cur_model, ref_graph_grad2, cur_graph_grad2, device - ), + "Double Backward": compare_double_backward(ref_model, cur_model, ref_graph_grad2, cur_graph_grad2, device), } # Summary @@ -318,7 +312,8 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: parser = argparse.ArgumentParser(description="Compare TensorNet implementations") parser.add_argument( - "--structure", "-s", + "--structure", + "-s", required=True, help="Path to structure file (any format supported by pymatgen)", ) diff --git a/src/matgl/kernels/__init__.py b/src/matgl/kernels/__init__.py index 5a945896..5e8c9a3f 100644 --- a/src/matgl/kernels/__init__.py +++ b/src/matgl/kernels/__init__.py @@ -25,33 +25,34 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import warp as wp from .compose_tensor import generate_compose_tensor from .decompose_tensor import generate_decompose_tensor from .equivariant_o3_matmul import generate_tensor_matmul_o3_3x3 from .equivariant_so3_matmul import generate_tensor_matmul_so3_3x3 -from .graph_transform import count_row_col, convert_to_sparse +from .graph_transform import convert_to_sparse, count_row_col from .tensor_norm3 import generate_tensor_norm3 from .tensornet_mp import generate_message_passing from .tensornet_radial_mp import generate_radial_message_passing from .utils import add_module, get_module, get_stream -import warp as wp - wp.init() __all__ = [ + "add_module", + "convert_to_sparse", + "count_row_col", "generate_compose_tensor", "generate_decompose_tensor", + "generate_message_passing", + "generate_radial_message_passing", "generate_tensor_matmul_o3_3x3", "generate_tensor_matmul_so3_3x3", - "generate_radial_message_passing", - "generate_message_passing", "generate_tensor_norm3", - "count_row_col", - "convert_to_sparse", - "add_module", "get_module", "get_stream", ] diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py index 91d46b3a..d40f113e 100644 --- a/src/matgl/kernels/compose_tensor.py +++ b/src/matgl/kernels/compose_tensor.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import warp as wp @@ -75,7 +76,7 @@ def compose_tensor_fwd( for i in range(3): X_reg[i, i] += I_reg - cnt = int(0) + cnt = 0 for i in range(3): for j in range(i + 1, 3): X_reg[i, j] += A_reg[cnt] @@ -83,7 +84,7 @@ def compose_tensor_fwd( cnt += 1 trace_S = -(S_reg[0] + S_reg[3]) - cnt = int(0) + cnt = 0 for i in range(2): X_reg[i, i] += S_reg[cnt] cnt += 1 @@ -118,12 +119,12 @@ def compose_tensor_bwd( for i in range(3): dI_reg += dX_reg[i, i] - cnt = int(0) + cnt = 0 for i in range(3): for j in range(i + 1, 3): dA_reg[cnt] += dX_reg[i, j] dA_reg[cnt] -= dX_reg[j, i] - cnt += int(1) + cnt += 1 dS_reg[0] += dX_reg[0, 0] dS_reg[0] -= dX_reg[2, 2] @@ -170,22 +171,22 @@ def compose_tensor_bwd_bwd( for i in range(3): d2X_reg[i, i] += dI_reg - cnt = int(0) + cnt = 0 for i in range(3): for j in range(i + 1, 3): d2X_reg[i, j] += dA_reg[cnt] d2X_reg[j, i] -= dA_reg[cnt] - cnt += int(1) + cnt += 1 - cnt = int(0) + cnt = 0 for i in range(2): d2X_reg[i, i] += dS_reg[cnt] - cnt += int(1) + cnt += 1 for j in range(i + 1, 3): d2X_reg[i, j] += dS_reg[cnt] d2X_reg[j, i] += dS_reg[cnt] - cnt += int(1) + cnt += 1 d2X_reg[2, 2] -= dS_reg[0] d2X_reg[2, 2] -= dS_reg[3] diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py index 9ce8adfb..3f81bd71 100644 --- a/src/matgl/kernels/decompose_tensor.py +++ b/src/matgl/kernels/decompose_tensor.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations import warp as wp @@ -75,20 +76,20 @@ def decompose_tensor_fwd( I[b, 0, h] = res denom = X.dtype(2.0) - cnt = int(0) + cnt = 0 for i in range(2): for j in range(i + 1, 3): A[b, cnt, h] = (X_reg[i, j] - X_reg[j, i]) / denom - cnt += int(1) + cnt += 1 - cnt = int(0) + cnt = 0 for i in range(2): S[b, cnt, h] = X_reg[i, i] - res - cnt += int(1) + cnt += 1 for j in range(i + 1, 3): S[b, cnt, h] = (X_reg[i, j] + X_reg[j, i]) / denom - cnt += int(1) + cnt += 1 def decompose_tensor_bwd( dI: wp.array(ndim=dim, dtype=dtype_wp), @@ -115,27 +116,27 @@ def decompose_tensor_bwd( denom = dX.dtype(2.0) - cnt = int(0) + cnt = 0 for i in range(3): for j in range(i + 1, 3): dX_reg[i, j] += dA_reg[cnt] / denom dX_reg[j, i] -= dA_reg[cnt] / denom - cnt += int(1) + cnt += 1 - cnt = int(0) + cnt = 0 for i in range(2): dX_reg[i, i] += dS_reg[cnt] for j in range(3): dX_reg[j, j] -= dS_reg[cnt] / dI.dtype(3.0) - cnt += int(1) + cnt += 1 for j in range(i + 1, 3): dX_reg[i, j] += dS_reg[cnt] / denom dX_reg[j, i] += dS_reg[cnt] / denom - cnt += int(1) + cnt += 1 for i in range(3): for j in range(3): @@ -163,25 +164,25 @@ def decompose_tensor_bwd_bwd( denom = dX.dtype(2.0) - cnt = int(0) + cnt = 0 for i in range(3): for j in range(i + 1, 3): d2A_reg[cnt] += dX_reg[i, j] / denom d2A_reg[cnt] -= dX_reg[j, i] / denom - cnt += int(1) + cnt += 1 - cnt = int(0) + cnt = 0 for i in range(2): d2S_reg[cnt] += dX_reg[i, i] for j in range(3): d2S_reg[cnt] -= dX_reg[j, j] / d2I.dtype(3.0) - cnt += int(1) + cnt += 1 for j in range(i + 1, 3): d2S_reg[cnt] += dX_reg[i, j] / denom d2S_reg[cnt] += dX_reg[j, i] / denom - cnt += int(1) + cnt += 1 d2I[b, 0, h] = d2I_reg for i in range(3): diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py index 8f170ee8..1d9d47fc 100644 --- a/src/matgl/kernels/equivariant_o3_matmul.py +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index abe88120..fab2ebca 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py index 655bb73c..64b20bf5 100644 --- a/src/matgl/kernels/graph_transform.py +++ b/src/matgl/kernels/graph_transform.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py index 032cb6c7..c255a753 100644 --- a/src/matgl/kernels/tensor_norm3.py +++ b/src/matgl/kernels/tensor_norm3.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/tensornet_mp.py b/src/matgl/kernels/tensornet_mp.py index c2a8246f..016de513 100644 --- a/src/matgl/kernels/tensornet_mp.py +++ b/src/matgl/kernels/tensornet_mp.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/tensornet_radial_mp.py b/src/matgl/kernels/tensornet_radial_mp.py index 27d39a2b..f223884c 100644 --- a/src/matgl/kernels/tensornet_radial_mp.py +++ b/src/matgl/kernels/tensornet_radial_mp.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/utils.py b/src/matgl/kernels/utils.py index cb567243..241a3754 100644 --- a/src/matgl/kernels/utils.py +++ b/src/matgl/kernels/utils.py @@ -25,16 +25,15 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations - -from typing import List -import warp as wp import torch +import warp as wp MODULES = {} -def get_module(name: str, dtype: List[str]): +def get_module(name: str, dtype: list[str]): """ Get the module for the given name and dtype """ @@ -46,7 +45,7 @@ def get_module(name: str, dtype: List[str]): return MODULES[full_name] -def add_module(name: str, dtype: List[str], kernel: wp.Kernel): +def add_module(name: str, dtype: list[str], kernel: wp.Kernel): """ Add the module for the given name and dtype """ @@ -63,12 +62,11 @@ def get_dtype(dtype: str): """ if dtype.endswith("16"): return "fp16" - elif dtype.endswith("32"): + if dtype.endswith("32"): return "fp32" - elif dtype.endswith("64"): + if dtype.endswith("64"): return "fp64" - else: - raise ValueError(f"Unsupported dtype: {dtype}") + raise ValueError(f"Unsupported dtype: {dtype}") def get_wp_fp_dtype(dtype: str): @@ -78,12 +76,11 @@ def get_wp_fp_dtype(dtype: str): """ if dtype.endswith("16"): return wp.float16 - elif dtype.endswith("32"): + if dtype.endswith("32"): return wp.float32 - elif dtype.endswith("64"): + if dtype.endswith("64"): return wp.float64 - else: - raise ValueError(f"Unsupported dtype: {dtype}") + raise ValueError(f"Unsupported dtype: {dtype}") def list_modules(): @@ -91,7 +88,7 @@ def list_modules(): List all modules in the MODULES dictionary """ print("Available modules:") - for name in MODULES.keys(): + for name in MODULES: print(f" - {name}") return list(MODULES.keys()) @@ -102,5 +99,4 @@ def get_stream(device: torch.device): """ if device.type == "cuda": return wp.stream_from_torch(torch.cuda.current_stream(device)) - else: - return None + return None diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 7a0e52a9..78ce0f12 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -11,10 +11,10 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Literal, Mapping, Any +from typing import TYPE_CHECKING, Literal import torch -from torch import nn, Tensor +from torch import nn import matgl from matgl.config import DEFAULT_ELEMENTS @@ -28,20 +28,18 @@ WeightedAtomReadOut, WeightedReadOut, ) -from matgl.utils.cutoff import cosine_cutoff -from matgl.utils.maths import scatter_add - from matgl.ops import ( - fn_radial_message_passing, fn_compose_tensor, fn_decompose_tensor, - fn_tensor_norm3, fn_message_passing, fn_radial_message_passing, fn_tensor_matmul_o3_3x3, fn_tensor_matmul_so3_3x3, + fn_tensor_norm3, graph_transform, ) +from matgl.utils.cutoff import cosine_cutoff +from matgl.utils.maths import scatter_add from ._core import MatGLModel diff --git a/src/matgl/ops/__init__.py b/src/matgl/ops/__init__.py index 20428c65..18b905f9 100644 --- a/src/matgl/ops/__init__.py +++ b/src/matgl/ops/__init__.py @@ -25,29 +25,29 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
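[Editor's illustration, not part of the patch.] The utils.py helpers above keep generated Warp kernels in a module-level MODULES dict keyed by kernel name plus fp16/fp32/fp64 suffixes derived from the dtype strings. A stand-alone sketch of that caching idea (generic Python, not the matgl API; make_kernel is a hypothetical factory standing in for one of the generate_* functions):

# Name-plus-dtype-suffix kernel cache, sketched independently of matgl.
_MODULES = {}

def _suffix(dtype: str) -> str:
    # Mirrors get_dtype(): dtype strings ending in 16/32/64 map to fp16/fp32/fp64.
    for bits in ("16", "32", "64"):
        if dtype.endswith(bits):
            return "fp" + bits
    raise ValueError(f"Unsupported dtype: {dtype}")

def get_or_build(name: str, dtypes: list, make_kernel):
    key = name + "_" + "_".join(_suffix(d) for d in dtypes)
    if key not in _MODULES:
        _MODULES[key] = make_kernel()
    return _MODULES[key]

k1 = get_or_build("compose_tensor", ["torch.float32"], dict)
k2 = get_or_build("compose_tensor", ["torch.float32"], dict)
assert k1 is k2  # second lookup hits the cache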
+from __future__ import annotations + +import warp as wp -from .tensornet_radial_mp import fn_radial_message_passing from .compose_tensor import fn_compose_tensor from .decompose_tensor import fn_decompose_tensor from .equivariant_o3_matmul import fn_tensor_matmul_o3_3x3 from .equivariant_so3_matmul import fn_tensor_matmul_so3_3x3 +from .graph_transform import graph_transform from .tensor_norm3 import fn_tensor_norm3 from .tensornet_mp import fn_message_passing from .tensornet_radial_mp import fn_radial_message_passing -from .graph_transform import graph_transform - -import warp as wp wp.init() __all__ = [ - "fn_radial_message_passing", "fn_compose_tensor", "fn_decompose_tensor", + "fn_message_passing", + "fn_radial_message_passing", + "fn_radial_message_passing", "fn_tensor_matmul_o3_3x3", "fn_tensor_matmul_so3_3x3", "fn_tensor_norm3", - "fn_message_passing", - "fn_radial_message_passing", "graph_transform", ] diff --git a/src/matgl/ops/compose_tensor.py b/src/matgl/ops/compose_tensor.py index bf948d3d..78e141b2 100644 --- a/src/matgl/ops/compose_tensor.py +++ b/src/matgl/ops/compose_tensor.py @@ -25,13 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from typing import List +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -74,7 +72,7 @@ def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) grad_x = torch.zeros_like(x) @@ -100,7 +98,7 @@ def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: @torch.library.register_fake("tensornet::compose_tensor_bwd_primitive") -def _(grad_output: List[Tensor], x: Tensor, y: Tensor, z: Tensor) -> List[Tensor]: +def _(grad_output: list[Tensor], x: Tensor, y: Tensor, z: Tensor) -> list[Tensor]: return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] @@ -117,7 +115,7 @@ def _( x: Tensor, y: Tensor, z: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: stream = get_stream(grad_output.device) device = wp.device_from_torch(grad_output.device) grad_x = torch.zeros_like(grad_grad_x) @@ -153,7 +151,7 @@ def _( x: Tensor, y: Tensor, z: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [ torch.empty_like(grad_output), torch.empty_like(x), diff --git a/src/matgl/ops/decompose_tensor.py b/src/matgl/ops/decompose_tensor.py index 39197618..42a92744 100644 --- a/src/matgl/ops/decompose_tensor.py +++ b/src/matgl/ops/decompose_tensor.py @@ -25,13 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from typing import List +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -41,7 +39,7 @@ mutates_args=(), device_types=["cpu", "cuda"], ) -def _(x: Tensor) -> List[Tensor]: +def _(x: Tensor) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) output_i = torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device) @@ -66,7 +64,7 @@ def _(x: Tensor) -> List[Tensor]: @torch.library.register_fake("tensornet::decompose_tensor_fwd_primitive") -def _(x: Tensor) -> List[Tensor]: +def _(x: Tensor) -> list[Tensor]: return [ torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device), torch.empty((x.shape[0], 3, x.shape[-1]), dtype=x.dtype, device=x.device), @@ -79,7 +77,7 @@ def _(x: Tensor) -> List[Tensor]: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> List[Tensor]: +def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) grad_x = torch.empty_like(x) @@ -103,7 +101,7 @@ def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Te @torch.library.register_fake("tensornet::decompose_tensor_bwd_primitive") -def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> List[Tensor]: +def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> list[Tensor]: return [torch.empty_like(x)] @@ -118,7 +116,7 @@ def _( grad_output_s: Tensor, grad_grad_x: Tensor, x: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: stream = get_stream(grad_output_i.device) device = wp.device_from_torch(grad_output_i.device) grad_x = torch.zeros_like(grad_grad_x) @@ -157,7 +155,7 @@ def _( grad_output_s: Tensor, grad_grad_x: Tensor, x: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [ torch.empty_like(grad_output_i), torch.empty_like(grad_output_a), @@ -218,6 +216,6 @@ def decompose_tensor_bwd_bwd(ctx, *grad_outputs): ) -def fn_decompose_tensor(x: Tensor) -> List[Tensor]: +def fn_decompose_tensor(x: Tensor) -> list[Tensor]: output = torch.ops.tensornet.decompose_tensor_fwd_primitive(x) return output diff --git a/src/matgl/ops/equivariant_o3_matmul.py b/src/matgl/ops/equivariant_o3_matmul.py index c9db9989..570f0836 100644 --- a/src/matgl/ops/equivariant_o3_matmul.py +++ b/src/matgl/ops/equivariant_o3_matmul.py @@ -25,12 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from typing import List +from __future__ import annotations import torch -from torch import Tensor import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -76,7 +75,7 @@ def _(x: Tensor, y: Tensor) -> Tensor: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: raise ValueError("x and y must be 3x3 matrices") if x.ndim != 4 or y.ndim != 4: @@ -105,7 +104,7 @@ def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: @torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_primitive") -def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: list[Tensor], x: Tensor, y: Tensor) -> list[Tensor]: return [torch.empty_like(x), torch.empty_like(y)] @@ -114,7 +113,7 @@ def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: raise ValueError("x and y must be 3x3 matrices") if x.ndim != 4 or y.ndim != 4: @@ -159,7 +158,7 @@ def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, @torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive") -def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: return [ torch.empty_like(grad_output), torch.empty_like(grad_output), @@ -179,13 +178,13 @@ def tensor_matmul_o3_3x3_setup_bwd_context(ctx, inputs, output): @torch.compiler.allow_in_graph def tensor_matmul_o3_3x3_fwd(*args): - return getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_fwd_primitive")(*args) + return torch.ops.tensornet.tensor_matmul_o3_3x3_fwd_primitive(*args) @torch.compiler.allow_in_graph def tensor_matmul_o3_3x3_bwd(ctx, grad_output): x, y = ctx.saved_tensors - dx, dy = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_primitive")(grad_output, x, y) + dx, dy = torch.ops.tensornet.tensor_matmul_o3_3x3_bwd_primitive(grad_output, x, y) return dx, dy @@ -196,7 +195,7 @@ def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): grad_output_saved, x, y = ctx.saved_tensors - outputs = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_bwd_bwd_primitive")( + outputs = torch.ops.tensornet.tensor_matmul_o3_3x3_bwd_bwd_primitive( grad_output_saved, grad_grad_x, grad_grad_y, x, y ) return outputs[0], outputs[1], outputs[2] @@ -216,5 +215,5 @@ def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): def fn_tensor_matmul_o3_3x3(x: Tensor, y: Tensor) -> Tensor: - z = getattr(torch.ops.tensornet, "tensor_matmul_o3_3x3_fwd_primitive")(x, y) + z = torch.ops.tensornet.tensor_matmul_o3_3x3_fwd_primitive(x, y) return z diff --git a/src/matgl/ops/equivariant_so3_matmul.py b/src/matgl/ops/equivariant_so3_matmul.py index 1c0880c6..739b9366 100644 --- a/src/matgl/ops/equivariant_so3_matmul.py +++ b/src/matgl/ops/equivariant_so3_matmul.py @@ -25,13 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING 
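[Editor's illustration, not part of the patch.] The ops modules above share one structure: a forward primitive registered with torch.library.custom_op, a register_fake shape stub so the op works under torch.compile and meta tensors, and backward (and backward-of-backward) primitives wired into autograd. A minimal self-contained sketch of that registration pattern, using a made-up toy op ("demo::scale2", not part of matgl) and torch.library.register_autograd as one common way to attach the gradient:

# Toy custom op + fake registration + autograd wiring; names are illustrative.
import torch

@torch.library.custom_op("demo::scale2", mutates_args=(), device_types=["cpu", "cuda"])
def scale2(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

@torch.library.register_fake("demo::scale2")
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used for tracing and meta tensors.
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    # Nothing to save for this toy op; real ops stash tensors for backward here.
    pass

def _backward(ctx, grad_output):
    return grad_output * 2.0

torch.library.register_autograd("demo::scale2", _backward, setup_context=_setup_context)

x = torch.randn(4, requires_grad=True)
torch.ops.demo.scale2(x).sum().backward()
print(x.grad)  # all twos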
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from typing import List +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -76,7 +74,7 @@ def _(x: Tensor, y: Tensor) -> Tensor: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: raise ValueError("x and y must be 3x3 matrices") if x.ndim != 4 or y.ndim != 4: @@ -104,7 +102,7 @@ def _(grad_output: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: @torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_primitive") -def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: list[Tensor], x: Tensor, y: Tensor) -> list[Tensor]: return [torch.empty_like(x), torch.empty_like(y)] @@ -113,7 +111,7 @@ def _(grad_output: List[Tensor], x: Tensor, y: Tensor) -> List[Tensor]: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: raise ValueError("x and y must be 3x3 matrices") if x.ndim != 4 or y.ndim != 4: @@ -157,7 +155,7 @@ def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, @torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive") -def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: return [ torch.empty_like(grad_output), torch.empty_like(grad_output), diff --git a/src/matgl/ops/graph_transform.py b/src/matgl/ops/graph_transform.py index a4a3274a..8e4c481e 100644 --- a/src/matgl/ops/graph_transform.py +++ b/src/matgl/ops/graph_transform.py @@ -25,15 +25,13 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from typing import Tuple +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor -from matgl.kernels import count_row_col, convert_to_sparse, get_stream +from matgl.kernels import convert_to_sparse, count_row_col, get_stream @torch.library.custom_op( @@ -41,7 +39,7 @@ mutates_args=(), device_types=["cpu", "cuda"], ) -def _(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor]: +def _(edge_index: Tensor, num_nodes: int) -> tuple[Tensor, Tensor]: stream = get_stream(edge_index.device) device = wp.device_from_torch(edge_index.device) row_count = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) @@ -63,7 +61,7 @@ def _(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor]: @torch.library.register_fake("nvtnet::count_row_col_primitive") -def _(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor]: +def _(edge_index: Tensor, num_nodes: int) -> tuple[Tensor, Tensor]: output = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) output2 = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) return output, output2 @@ -80,7 +78,7 @@ def _( col_count: Tensor, row_indptr: Tensor, col_indptr: Tensor, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: stream = get_stream(edge_index.device) device = wp.device_from_torch(edge_index.device) edge_index_wp = wp.from_torch(edge_index, return_ctype=True) @@ -131,7 +129,7 @@ def _( col_count: Tensor, row_indptr: Tensor, col_indptr: Tensor, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[Tensor, Tensor, Tensor, Tensor]: output = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) output2 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) output3 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) @@ -140,7 +138,7 @@ def _( @torch.compiler.allow_in_graph -def graph_transform(edge_index: Tensor, num_nodes: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: +def graph_transform(edge_index: Tensor, num_nodes: int) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: row_count, col_count = torch.ops.nvtnet.count_row_col_primitive(edge_index, num_nodes) row_indptr, col_indptr = ( torch.cumsum(row_count, dim=0, dtype=torch.int32), diff --git a/src/matgl/ops/tensor_norm3.py b/src/matgl/ops/tensor_norm3.py index 60d0e793..00e0f971 100644 --- a/src/matgl/ops/tensor_norm3.py +++ b/src/matgl/ops/tensor_norm3.py @@ -25,13 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
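[Editor's illustration, not part of the patch.] graph_transform above converts a COO edge_index into paired CSR- and CSC-style views (per-node counts, cumulative indptr arrays, and edge permutation/data arrays) so the message-passing kernels can walk a node's edges contiguously. A rough pure-PyTorch reference for the row/CSR half only (the Warp kernels also build the column/CSC view, and their exact data layout may differ):

# Illustrative CSR construction from a COO edge_index; torch-only reference.
import torch

edge_index = torch.tensor([[0, 0, 1, 2, 2, 2],   # source nodes
                           [1, 2, 0, 0, 1, 3]])  # destination nodes
num_nodes = 4

row = edge_index[0]
row_count = torch.bincount(row, minlength=num_nodes)   # edges per source node
row_indptr = torch.zeros(num_nodes + 1, dtype=torch.long)
row_indptr[1:] = torch.cumsum(row_count, dim=0)        # CSR offsets

order = torch.argsort(row, stable=True)                # edge ids grouped by row
row_data = order                                       # permutation into edge arrays
row_indices = edge_index[1, order]                     # neighbor ids per row

# Neighbors of node 2 are then row_indices[row_indptr[2]:row_indptr[3]].
print(row_indices[row_indptr[2]:row_indptr[3]])        # tensor([0, 1, 3])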
- -from typing import List +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -71,7 +69,7 @@ def _(x: Tensor) -> Tensor: mutates_args=(), device_types=["cpu", "cuda"], ) -def _(grad_output: Tensor, x: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, x: Tensor) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) grad_x = torch.empty_like(x) @@ -93,7 +91,7 @@ def _(grad_output: Tensor, x: Tensor) -> List[Tensor]: @torch.library.register_fake("tensornet::tensor_norm3_bwd_primitive") -def _(grad_output: Tensor, x: Tensor) -> List[Tensor]: +def _(grad_output: Tensor, x: Tensor) -> list[Tensor]: return [torch.empty_like(x)] @@ -106,7 +104,7 @@ def _( grad_grad_x: Tensor, x: Tensor, grad_output: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: stream = get_stream(grad_grad_x.device) device = wp.device_from_torch(grad_grad_x.device) grad_grad_output = torch.empty( @@ -145,7 +143,7 @@ def _( grad_grad_x: Tensor, x: Tensor, grad_output: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [ torch.empty( (grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), diff --git a/src/matgl/ops/tensornet_mp.py b/src/matgl/ops/tensornet_mp.py index f22c72d4..22ff2848 100644 --- a/src/matgl/ops/tensornet_mp.py +++ b/src/matgl/ops/tensornet_mp.py @@ -25,13 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from typing import List +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -52,7 +50,7 @@ def _( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) output_x = torch.empty_like(x) @@ -108,7 +106,7 @@ def _( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] @@ -131,7 +129,7 @@ def _( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) grad_x = torch.empty_like(x) @@ -202,7 +200,7 @@ def _( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [ torch.empty_like(x), torch.empty_like(y), @@ -234,7 +232,7 @@ def _( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: stream = get_stream(x.device) device = wp.device_from_torch(x.device) @@ -361,7 +359,7 @@ def _( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [ torch.empty_like(grad_output_x), torch.empty_like(grad_output_y), @@ -555,7 +553,7 @@ def fn_message_passing( col_data: Tensor, col_indices: Tensor, col_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return torch.ops.tensornet.message_passing_fwd_primitive( x, y, diff --git a/src/matgl/ops/tensornet_radial_mp.py b/src/matgl/ops/tensornet_radial_mp.py index 1a10b575..d3070ec9 100644 --- a/src/matgl/ops/tensornet_radial_mp.py +++ b/src/matgl/ops/tensornet_radial_mp.py @@ -25,13 +25,11 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER 
IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from typing import List +from __future__ import annotations import torch -from torch import Tensor - import warp as wp +from torch import Tensor from matgl.kernels import get_module, get_stream @@ -41,7 +39,7 @@ mutates_args=(), device_types=["cpu", "cuda"], ) -def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> List[Tensor]: +def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> list[Tensor]: num_atoms = row_indptr.shape[0] - 1 stream = get_stream(edge_vec_norm.device) device = wp.device_from_torch(edge_vec_norm.device) @@ -92,7 +90,7 @@ def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Te @torch.library.register_fake("tensornet::radial_message_passing_fwd_primitive") -def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> List[Tensor]: +def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> list[Tensor]: num_atoms = row_indptr.shape[0] - 1 return [ torch.empty( @@ -126,7 +124,7 @@ def _( edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: num_atoms = row_indptr.shape[0] - 1 stream = get_stream(grad_output_I.device) device = wp.device_from_torch(grad_output_I.device) @@ -178,7 +176,7 @@ def _( edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [torch.empty_like(edge_vec_norm), torch.empty_like(edge_attr)] @@ -197,7 +195,7 @@ def _( edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: num_atoms = row_indptr.shape[0] - 1 stream = get_stream(grad_output_I.device) device = wp.device_from_torch(grad_output_I.device) @@ -270,7 +268,7 @@ def _( edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor, -) -> List[Tensor]: +) -> list[Tensor]: return [ torch.empty_like(grad_output_I), torch.empty_like(grad_output_A), @@ -388,5 +386,5 @@ def radial_message_passing_bwd_bwd(ctx, *grad_outputs): def fn_radial_message_passing( edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor -) -> List[Tensor]: +) -> list[Tensor]: return torch.ops.tensornet.radial_message_passing_fwd_primitive(edge_vec_norm, edge_attr, row_data, row_indptr) diff --git a/tests/models/test_tensornet_pyg.py b/tests/models/test_tensornet_pyg.py index c2ff494f..2a761422 100644 --- a/tests/models/test_tensornet_pyg.py +++ b/tests/models/test_tensornet_pyg.py @@ -112,11 +112,13 @@ def test_backward(self, graph_MoS_pyg): torch.manual_seed(0) torch.use_deterministic_algorithms(True) - EXPECTED_CELL_GRAD = torch.tensor([ - [-0.000967, 0.000000, 0.000000], - [0.000000, -0.000967, 0.000000], - [0.000000, 0.000000, -0.000967], - ]) + EXPECTED_CELL_GRAD = torch.tensor( + [ + [-0.000967, 0.000000, 0.000000], + [0.000000, -0.000967, 0.000000], + [0.000000, 0.000000, -0.000967], + ] + ) structure, graph, _ = graph_MoS_pyg cell = torch.tensor(structure.lattice.matrix, dtype=matgl.float_th).requires_grad_(True) @@ -137,11 +139,13 @@ def test_double_backward(self, graph_MoS_pyg): torch.manual_seed(0) torch.use_deterministic_algorithms(True) - EXPECTED_CELL_GRAD2 = torch.tensor([ - [-0.000010, -0.000000, -0.000000], - [-0.000000, -0.000010, -0.000000], - [-0.000000, -0.000000, -0.000010], - ]) + EXPECTED_CELL_GRAD2 = torch.tensor( + [ + 
[-0.000010, -0.000000, -0.000000], + [-0.000000, -0.000010, -0.000000], + [-0.000000, -0.000000, -0.000010], + ] + ) structure, graph, _ = graph_MoS_pyg cell = torch.tensor(structure.lattice.matrix, dtype=matgl.float_th).requires_grad_(True) From 4ec5210a041c27175bad22b4289aaa641e571d57 Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Fri, 9 Jan 2026 17:01:27 -0500 Subject: [PATCH 06/18] linting --- src/matgl/kernels/__init__.py | 7 ++++++ src/matgl/kernels/compose_tensor.py | 18 +++++++++++----- src/matgl/kernels/decompose_tensor.py | 24 ++++++++++++++++----- src/matgl/kernels/equivariant_o3_matmul.py | 2 ++ src/matgl/kernels/equivariant_so3_matmul.py | 2 ++ src/matgl/kernels/graph_transform.py | 5 +++-- src/matgl/kernels/tensor_norm3.py | 2 ++ src/matgl/models/_tensornet_pyg.py | 20 +++++++++++------ 8 files changed, 61 insertions(+), 19 deletions(-) diff --git a/src/matgl/kernels/__init__.py b/src/matgl/kernels/__init__.py index 5e8c9a3f..76b620e2 100644 --- a/src/matgl/kernels/__init__.py +++ b/src/matgl/kernels/__init__.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Warp GPU kernels for TensorNet operations.""" from __future__ import annotations import warp as wp @@ -34,6 +35,7 @@ from .equivariant_o3_matmul import generate_tensor_matmul_o3_3x3 from .equivariant_so3_matmul import generate_tensor_matmul_so3_3x3 from .graph_transform import convert_to_sparse, count_row_col +from .graph_transform import convert_to_sparse, count_row_col from .tensor_norm3 import generate_tensor_norm3 from .tensornet_mp import generate_message_passing from .tensornet_radial_mp import generate_radial_message_passing @@ -43,6 +45,9 @@ __all__ = [ + "add_module", + "convert_to_sparse", + "count_row_col", "add_module", "convert_to_sparse", "count_row_col", @@ -50,6 +55,8 @@ "generate_decompose_tensor", "generate_message_passing", "generate_radial_message_passing", + "generate_message_passing", + "generate_radial_message_passing", "generate_tensor_matmul_o3_3x3", "generate_tensor_matmul_so3_3x3", "generate_tensor_norm3", diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py index d40f113e..9b3d9139 100644 --- a/src/matgl/kernels/compose_tensor.py +++ b/src/matgl/kernels/compose_tensor.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for composing 3x3 tensors from I, A, S components.""" from __future__ import annotations import warp as wp @@ -33,6 +34,7 @@ def generate_compose_tensor(dtype: str, h_last: bool = True, use_irmem: bool = True): + """Generate Warp kernels for composing a 3x3 tensor from I, A, S components.""" dtype_wp = get_wp_fp_dtype(dtype) if not use_irmem: raise ValueError(f"only supporting use_irmem True, but got {use_irmem}") @@ -48,13 +50,10 @@ class vec3(wp.types.vector(length=3, dtype=dtype_wp)): class vec5(wp.types.vector(length=5, dtype=dtype_wp)): pass - if use_irmem: - dim = 3 - else: - dim = 4 + dim = 3 if use_irmem else 4 def compose_tensor_fwd( - I: wp.array(ndim=dim, dtype=dtype_wp), + I: wp.array(ndim=dim, dtype=dtype_wp), # noqa: E741 A: wp.array(ndim=dim, dtype=dtype_wp), S: wp.array(ndim=dim, dtype=dtype_wp), X: wp.array(ndim=4, dtype=dtype_wp), @@ -76,6 +75,7 @@ def compose_tensor_fwd( for i in range(3): X_reg[i, i] += I_reg + cnt = 0 cnt = 0 for i in range(3): for j in range(i + 1, 3): @@ -85,6 +85,7 @@ def compose_tensor_fwd( trace_S = -(S_reg[0] + S_reg[3]) cnt = 0 + cnt = 0 for i in range(2): X_reg[i, i] += S_reg[cnt] cnt += 1 @@ -119,12 +120,14 @@ def compose_tensor_bwd( for i in range(3): dI_reg += dX_reg[i, i] + cnt = 0 cnt = 0 for i in range(3): for j in range(i + 1, 3): dA_reg[cnt] += dX_reg[i, j] dA_reg[cnt] -= dX_reg[j, i] cnt += 1 + cnt += 1 dS_reg[0] += dX_reg[0, 0] dS_reg[0] -= dX_reg[2, 2] @@ -171,22 +174,27 @@ def compose_tensor_bwd_bwd( for i in range(3): d2X_reg[i, i] += dI_reg + cnt = 0 cnt = 0 for i in range(3): for j in range(i + 1, 3): d2X_reg[i, j] += dA_reg[cnt] d2X_reg[j, i] -= dA_reg[cnt] cnt += 1 + cnt += 1 + cnt = 0 cnt = 0 for i in range(2): d2X_reg[i, i] += dS_reg[cnt] cnt += 1 + cnt += 1 for j in range(i + 1, 3): d2X_reg[i, j] += dS_reg[cnt] d2X_reg[j, i] += dS_reg[cnt] cnt += 1 + cnt += 1 d2X_reg[2, 2] -= dS_reg[0] d2X_reg[2, 2] -= dS_reg[3] diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py index 3f81bd71..3dbf3cfa 100644 --- a/src/matgl/kernels/decompose_tensor.py +++ b/src/matgl/kernels/decompose_tensor.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for decomposing 3x3 tensors into I, A, S components.""" from __future__ import annotations import warp as wp @@ -33,6 +34,7 @@ def generate_decompose_tensor(dtype: str, h_last: bool = True, use_irmem: bool = True): + """Generate Warp kernels for decomposing a 3x3 tensor into I, A, S components.""" dtype_wp = get_wp_fp_dtype(dtype) if not use_irmem: @@ -50,14 +52,11 @@ class vec3(wp.types.vector(length=3, dtype=dtype_wp)): class vec5(wp.types.vector(length=5, dtype=dtype_wp)): pass - if use_irmem: - dim = 3 - else: - dim = 4 + dim = 3 if use_irmem else 4 def decompose_tensor_fwd( X: wp.array(ndim=4, dtype=dtype_wp), - I: wp.array(ndim=dim, dtype=dtype_wp), + I: wp.array(ndim=dim, dtype=dtype_wp), # noqa: E741 A: wp.array(ndim=dim, dtype=dtype_wp), S: wp.array(ndim=dim, dtype=dtype_wp), ): @@ -77,19 +76,24 @@ def decompose_tensor_fwd( denom = X.dtype(2.0) cnt = 0 + cnt = 0 for i in range(2): for j in range(i + 1, 3): A[b, cnt, h] = (X_reg[i, j] - X_reg[j, i]) / denom cnt += 1 + cnt += 1 + cnt = 0 cnt = 0 for i in range(2): S[b, cnt, h] = X_reg[i, i] - res cnt += 1 + cnt += 1 for j in range(i + 1, 3): S[b, cnt, h] = (X_reg[i, j] + X_reg[j, i]) / denom cnt += 1 + cnt += 1 def decompose_tensor_bwd( dI: wp.array(ndim=dim, dtype=dtype_wp), @@ -116,6 +120,7 @@ def decompose_tensor_bwd( denom = dX.dtype(2.0) + cnt = 0 cnt = 0 for i in range(3): @@ -124,7 +129,9 @@ def decompose_tensor_bwd( dX_reg[j, i] -= dA_reg[cnt] / denom cnt += 1 + cnt += 1 + cnt = 0 cnt = 0 for i in range(2): dX_reg[i, i] += dS_reg[cnt] @@ -132,11 +139,13 @@ def decompose_tensor_bwd( dX_reg[j, j] -= dS_reg[cnt] / dI.dtype(3.0) cnt += 1 + cnt += 1 for j in range(i + 1, 3): dX_reg[i, j] += dS_reg[cnt] / denom dX_reg[j, i] += dS_reg[cnt] / denom cnt += 1 + cnt += 1 for i in range(3): for j in range(3): @@ -164,6 +173,7 @@ def decompose_tensor_bwd_bwd( denom = dX.dtype(2.0) + cnt = 0 cnt = 0 for i in range(3): @@ -171,18 +181,22 @@ def decompose_tensor_bwd_bwd( d2A_reg[cnt] += dX_reg[i, j] / denom d2A_reg[cnt] -= dX_reg[j, i] / denom cnt += 1 + cnt += 1 + cnt = 0 cnt = 0 for i in range(2): d2S_reg[cnt] += dX_reg[i, i] for j in range(3): d2S_reg[cnt] -= dX_reg[j, j] / d2I.dtype(3.0) cnt += 1 + cnt += 1 for j in range(i + 1, 3): d2S_reg[cnt] += dX_reg[i, j] / denom d2S_reg[cnt] += dX_reg[j, i] / denom cnt += 1 + cnt += 1 d2I[b, 0, h] = d2I_reg for i in range(3): diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py index 1d9d47fc..389f3438 100644 --- a/src/matgl/kernels/equivariant_o3_matmul.py +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for O(3)-equivariant 3x3 tensor matrix multiplication.""" from __future__ import annotations import warp as wp @@ -33,6 +34,7 @@ def generate_tensor_matmul_o3_3x3(dtype: str): + """Generate Warp kernels for O(3)-equivariant 3x3 matrix multiplication: C = AB + BA.""" dtype_wp = get_wp_fp_dtype(dtype) class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index fab2ebca..063bf82f 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Warp kernels for SO(3)-equivariant 3x3 tensor matrix multiplication.""" from __future__ import annotations import warp as wp @@ -33,6 +34,7 @@ def generate_tensor_matmul_so3_3x3(dtype: str): + """Generate Warp kernels for SO(3)-equivariant 3x3 matrix multiplication: C = AB.""" dtype_wp = get_wp_fp_dtype(dtype) class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py index 64b20bf5..b75a4a5e 100644 --- a/src/matgl/kernels/graph_transform.py +++ b/src/matgl/kernels/graph_transform.py @@ -25,13 +25,14 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Warp kernels for graph edge index transformation to sparse CSR format.""" from __future__ import annotations import warp as wp @wp.kernel -def count_row_col( +def count_row_col( # noqa: D103 edge_index: wp.array(ndim=2, dtype=wp.int32), row_count: wp.array(ndim=1, dtype=wp.int32), col_count: wp.array(ndim=1, dtype=wp.int32), @@ -44,7 +45,7 @@ def count_row_col( @wp.kernel -def convert_to_sparse( +def convert_to_sparse( # noqa: D103 edge_index: wp.array(ndim=2, dtype=wp.int32), row_count: wp.array(ndim=1, dtype=wp.int32), col_count: wp.array(ndim=1, dtype=wp.int32), diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py index c255a753..45d0ab84 100644 --- a/src/matgl/kernels/tensor_norm3.py +++ b/src/matgl/kernels/tensor_norm3.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for computing 3x3 tensor norms (I, A, S components).""" from __future__ import annotations import warp as wp @@ -33,6 +34,7 @@ def generate_tensor_norm3(dtype: str, h_last: bool = True, use_irmem: bool = True): + """Generate Warp kernels for computing squared norms of 3x3 tensor I, A, S components.""" dtype_wp = get_wp_fp_dtype(dtype) if not use_irmem: diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 78ce0f12..4b686a39 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -217,7 +217,7 @@ def forward( # Radial message passing edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) - I, A, S = fn_radial_message_passing(edge_vec_norm, edge_attr_processed, col_data, col_indptr) + I, A, S = fn_radial_message_passing(edge_vec_norm, edge_attr_processed, col_data, col_indptr) # noqa: E741 # Compose initial tensor to get proper shape for norm computation X = fn_compose_tensor(I, A, S) # (num_nodes, 3, 3, units) @@ -235,7 +235,7 @@ def forward( norm_I, norm_A, norm_S = norm.unbind(dim=-1) # Apply norm to tensors - I = self.linears_tensor[0](I) * norm_I.unsqueeze(-2) + I = self.linears_tensor[0](I) * norm_I.unsqueeze(-2) # noqa: E741 A = self.linears_tensor[1](A) * norm_A.unsqueeze(-2) S = self.linears_tensor[2](S) * norm_S.unsqueeze(-2) @@ -303,6 +303,12 @@ def forward( edge_index: Edge indices, shape (2, num_edges) edge_weight: Edge weights (distances), shape (num_edges,) edge_attr: Edge attributes (RBF), shape (num_edges, num_rbf) + row_data: CSR row data indices for message passing. + row_indices: CSR row indices for message passing. + row_indptr: CSR row pointers for message passing. + col_data: CSC column data indices for message passing. + col_indices: CSC column indices for message passing. + col_indptr: CSC column pointers for message passing. 
Returns: X: Updated tensor representations, shape (num_nodes, 3, 3, units) @@ -323,10 +329,10 @@ def forward( X = X / norm_X.view(-1, 1, 1, X.shape[-1]) # Decompose input tensor - I, A, S = fn_decompose_tensor(X) + I, A, S = fn_decompose_tensor(X) # noqa: E741 # Apply tensor linear transformations - I = self.linears_tensor[0](I) + I = self.linears_tensor[0](I) # noqa: E741 A = self.linears_tensor[1](A) S = self.linears_tensor[2](S) @@ -353,14 +359,14 @@ def forward( C = fn_tensor_matmul_o3_3x3(Y, msg) else: # SO(3) C = 2 * fn_tensor_matmul_so3_3x3(Y, msg) - I, A, S = fn_decompose_tensor(C) + I, A, S = fn_decompose_tensor(C) # noqa: E741 # Normalize normp1 = (tensor_norm(C) + 1).unsqueeze(-2) - I, A, S = I / normp1, A / normp1, S / normp1 + I, A, S = I / normp1, A / normp1, S / normp1 # noqa: E741 # Final tensor transformations - I = self.linears_tensor[3](I) + I = self.linears_tensor[3](I) # noqa: E741 A = self.linears_tensor[4](A) S = self.linears_tensor[5](S) dX = fn_compose_tensor(I, A, S) From 78ddd3f1f84b8360dc597fd56b7240ddcd30bcb8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 22:15:09 +0000 Subject: [PATCH 07/18] pre-commit auto-fixes --- src/matgl/kernels/__init__.py | 8 ++++---- src/matgl/kernels/compose_tensor.py | 1 + src/matgl/kernels/decompose_tensor.py | 1 + src/matgl/kernels/equivariant_o3_matmul.py | 1 + src/matgl/kernels/equivariant_so3_matmul.py | 1 + src/matgl/kernels/graph_transform.py | 1 + src/matgl/kernels/tensor_norm3.py | 1 + 7 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/matgl/kernels/__init__.py b/src/matgl/kernels/__init__.py index 76b620e2..7367c920 100644 --- a/src/matgl/kernels/__init__.py +++ b/src/matgl/kernels/__init__.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp GPU kernels for TensorNet operations.""" + from __future__ import annotations import warp as wp @@ -35,7 +36,6 @@ from .equivariant_o3_matmul import generate_tensor_matmul_o3_3x3 from .equivariant_so3_matmul import generate_tensor_matmul_so3_3x3 from .graph_transform import convert_to_sparse, count_row_col -from .graph_transform import convert_to_sparse, count_row_col from .tensor_norm3 import generate_tensor_norm3 from .tensornet_mp import generate_message_passing from .tensornet_radial_mp import generate_radial_message_passing @@ -46,17 +46,17 @@ __all__ = [ "add_module", - "convert_to_sparse", - "count_row_col", "add_module", "convert_to_sparse", + "convert_to_sparse", + "count_row_col", "count_row_col", "generate_compose_tensor", "generate_decompose_tensor", "generate_message_passing", - "generate_radial_message_passing", "generate_message_passing", "generate_radial_message_passing", + "generate_radial_message_passing", "generate_tensor_matmul_o3_3x3", "generate_tensor_matmul_so3_3x3", "generate_tensor_norm3", diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py index 9b3d9139..8fe134e6 100644 --- a/src/matgl/kernels/compose_tensor.py +++ b/src/matgl/kernels/compose_tensor.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Warp kernels for composing 3x3 tensors from I, A, S components.""" + from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py index 3dbf3cfa..da2986db 100644 --- a/src/matgl/kernels/decompose_tensor.py +++ b/src/matgl/kernels/decompose_tensor.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for decomposing 3x3 tensors into I, A, S components.""" + from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py index 389f3438..639da41d 100644 --- a/src/matgl/kernels/equivariant_o3_matmul.py +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for O(3)-equivariant 3x3 tensor matrix multiplication.""" + from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index 063bf82f..bcf33238 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for SO(3)-equivariant 3x3 tensor matrix multiplication.""" + from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py index b75a4a5e..400ea33a 100644 --- a/src/matgl/kernels/graph_transform.py +++ b/src/matgl/kernels/graph_transform.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for graph edge index transformation to sparse CSR format.""" + from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py index 45d0ab84..540cc576 100644 --- a/src/matgl/kernels/tensor_norm3.py +++ b/src/matgl/kernels/tensor_norm3.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for computing 3x3 tensor norms (I, A, S components).""" + from __future__ import annotations import warp as wp From adf4fa256da00654232cd754ed762508894c6cdc Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Fri, 9 Jan 2026 18:20:48 -0500 Subject: [PATCH 08/18] fix warp version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 24e78f7c..54b5babc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ classifiers = [ dependencies = [ "ase", "torch<=2.7.0", # TODO: Remove this pin. For some reason, torch 2.9 gives different results. 
- "warp-lang>=10.1", + "warp-lang>=1.10.1", "torchdata", "pymatgen", "lightning<=2.6.0.dev20251123", From e979d3f819464f229528f934a4b86e4143393885 Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Fri, 9 Jan 2026 18:25:14 -0500 Subject: [PATCH 09/18] linting --- pyproject.toml | 2 ++ src/matgl/kernels/compose_tensor.py | 2 +- src/matgl/kernels/decompose_tensor.py | 2 +- src/matgl/kernels/graph_transform.py | 4 ++-- src/matgl/kernels/utils.py | 26 ++++++-------------------- src/matgl/ops/__init__.py | 1 + 6 files changed, 13 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 54b5babc..8bcfee38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,6 +159,8 @@ lint.isort.required-imports = ["from __future__ import annotations"] "tests/**/*" = ["D", "PERF"] "docs/**/*" = ["D"] "examples/**/*" = ["D"] +"src/matgl/kernels/*" = ["D100", "D103", "E741"] +"src/matgl/ops/*" = ["D100", "D103"] [tool.pytest.ini_options] addopts = "--durations=30 --quiet -rXs --color=yes -p no:warnings" diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py index 8fe134e6..f3766ce7 100644 --- a/src/matgl/kernels/compose_tensor.py +++ b/src/matgl/kernels/compose_tensor.py @@ -54,7 +54,7 @@ class vec5(wp.types.vector(length=5, dtype=dtype_wp)): dim = 3 if use_irmem else 4 def compose_tensor_fwd( - I: wp.array(ndim=dim, dtype=dtype_wp), # noqa: E741 + I: wp.array(ndim=dim, dtype=dtype_wp), A: wp.array(ndim=dim, dtype=dtype_wp), S: wp.array(ndim=dim, dtype=dtype_wp), X: wp.array(ndim=4, dtype=dtype_wp), diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py index da2986db..1a14d1dd 100644 --- a/src/matgl/kernels/decompose_tensor.py +++ b/src/matgl/kernels/decompose_tensor.py @@ -57,7 +57,7 @@ class vec5(wp.types.vector(length=5, dtype=dtype_wp)): def decompose_tensor_fwd( X: wp.array(ndim=4, dtype=dtype_wp), - I: wp.array(ndim=dim, dtype=dtype_wp), # noqa: E741 + I: wp.array(ndim=dim, dtype=dtype_wp), A: wp.array(ndim=dim, dtype=dtype_wp), S: wp.array(ndim=dim, dtype=dtype_wp), ): diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py index 400ea33a..b21f6281 100644 --- a/src/matgl/kernels/graph_transform.py +++ b/src/matgl/kernels/graph_transform.py @@ -33,7 +33,7 @@ @wp.kernel -def count_row_col( # noqa: D103 +def count_row_col( edge_index: wp.array(ndim=2, dtype=wp.int32), row_count: wp.array(ndim=1, dtype=wp.int32), col_count: wp.array(ndim=1, dtype=wp.int32), @@ -46,7 +46,7 @@ def count_row_col( # noqa: D103 @wp.kernel -def convert_to_sparse( # noqa: D103 +def convert_to_sparse( edge_index: wp.array(ndim=2, dtype=wp.int32), row_count: wp.array(ndim=1, dtype=wp.int32), col_count: wp.array(ndim=1, dtype=wp.int32), diff --git a/src/matgl/kernels/utils.py b/src/matgl/kernels/utils.py index 241a3754..e2ef2d30 100644 --- a/src/matgl/kernels/utils.py +++ b/src/matgl/kernels/utils.py @@ -34,9 +34,7 @@ def get_module(name: str, dtype: list[str]): - """ - Get the module for the given name and dtype - """ + """Get the module for the given name and dtype.""" full_name = name + "_" + "_".join(get_dtype(d) for d in dtype) if full_name not in MODULES: print(f"Module {full_name} not found in MODULES dictionary") @@ -46,9 +44,7 @@ def get_module(name: str, dtype: list[str]): def add_module(name: str, dtype: list[str], kernel: wp.Kernel): - """ - Add the module for the given name and dtype - """ + """Add the module for the given name and dtype.""" full_name = name + "_" + "_".join(get_dtype(d) for d 
in dtype) if full_name not in MODULES: MODULES[full_name] = kernel @@ -56,10 +52,7 @@ def add_module(name: str, dtype: list[str], kernel: wp.Kernel): def get_dtype(dtype: str): - """ - Get the dtype for the given dtype - WIP - """ + """Get the dtype string representation for the given dtype (WIP).""" if dtype.endswith("16"): return "fp16" if dtype.endswith("32"): @@ -70,10 +63,7 @@ def get_dtype(dtype: str): def get_wp_fp_dtype(dtype: str): - """ - Get the warp dtype for the given dtype - WIP - """ + """Get the warp dtype for the given dtype (WIP).""" if dtype.endswith("16"): return wp.float16 if dtype.endswith("32"): @@ -84,9 +74,7 @@ def get_wp_fp_dtype(dtype: str): def list_modules(): - """ - List all modules in the MODULES dictionary - """ + """List all modules in the MODULES dictionary.""" print("Available modules:") for name in MODULES: print(f" - {name}") @@ -94,9 +82,7 @@ def list_modules(): def get_stream(device: torch.device): - """ - Get the stream for the given device - """ + """Get the stream for the given device.""" if device.type == "cuda": return wp.stream_from_torch(torch.cuda.current_stream(device)) return None diff --git a/src/matgl/ops/__init__.py b/src/matgl/ops/__init__.py index 18b905f9..dc5bfd2d 100644 --- a/src/matgl/ops/__init__.py +++ b/src/matgl/ops/__init__.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Custom tensor operations with Warp kernel implementations.""" from __future__ import annotations import warp as wp From fb91cecf5f6636cf9adc5ce813cd5313d919426d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 23:25:42 +0000 Subject: [PATCH 10/18] pre-commit auto-fixes --- src/matgl/ops/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/matgl/ops/__init__.py b/src/matgl/ops/__init__.py index dc5bfd2d..a2d408aa 100644 --- a/src/matgl/ops/__init__.py +++ b/src/matgl/ops/__init__.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Custom tensor operations with Warp kernel implementations.""" + from __future__ import annotations import warp as wp From ffee3889a8059a54ee407a27d69ee8194973e4de Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Sun, 11 Jan 2026 20:57:23 -0500 Subject: [PATCH 11/18] fix kernels --- pyproject.toml | 2 +- src/matgl/kernels/compose_tensor.py | 21 ++++---------- src/matgl/kernels/decompose_tensor.py | 31 ++++----------------- src/matgl/kernels/equivariant_o3_matmul.py | 2 -- src/matgl/kernels/equivariant_so3_matmul.py | 1 - src/matgl/kernels/graph_transform.py | 2 -- src/matgl/kernels/tensor_norm3.py | 2 -- src/matgl/kernels/tensornet_mp.py | 2 -- src/matgl/kernels/tensornet_radial_mp.py | 1 - 9 files changed, 12 insertions(+), 52 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8bcfee38..07d1c904 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,7 +159,7 @@ lint.isort.required-imports = ["from __future__ import annotations"] "tests/**/*" = ["D", "PERF"] "docs/**/*" = ["D"] "examples/**/*" = ["D"] -"src/matgl/kernels/*" = ["D100", "D103", "E741"] +"src/matgl/kernels/*" = ["D100", "D103", "E741", "I002"] "src/matgl/ops/*" = ["D100", "D103"] [tool.pytest.ini_options] diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py index f3766ce7..1ed84191 100644 --- a/src/matgl/kernels/compose_tensor.py +++ b/src/matgl/kernels/compose_tensor.py @@ -27,8 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for composing 3x3 tensors from I, A, S components.""" -from __future__ import annotations - import warp as wp from .utils import add_module, get_wp_fp_dtype @@ -76,8 +74,7 @@ def compose_tensor_fwd( for i in range(3): X_reg[i, i] += I_reg - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(3): for j in range(i + 1, 3): X_reg[i, j] += A_reg[cnt] @@ -85,8 +82,7 @@ def compose_tensor_fwd( cnt += 1 trace_S = -(S_reg[0] + S_reg[3]) - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(2): X_reg[i, i] += S_reg[cnt] cnt += 1 @@ -121,14 +117,12 @@ def compose_tensor_bwd( for i in range(3): dI_reg += dX_reg[i, i] - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(3): for j in range(i + 1, 3): dA_reg[cnt] += dX_reg[i, j] dA_reg[cnt] -= dX_reg[j, i] cnt += 1 - cnt += 1 dS_reg[0] += dX_reg[0, 0] dS_reg[0] -= dX_reg[2, 2] @@ -175,27 +169,22 @@ def compose_tensor_bwd_bwd( for i in range(3): d2X_reg[i, i] += dI_reg - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(3): for j in range(i + 1, 3): d2X_reg[i, j] += dA_reg[cnt] d2X_reg[j, i] -= dA_reg[cnt] cnt += 1 - cnt += 1 - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(2): d2X_reg[i, i] += dS_reg[cnt] cnt += 1 - cnt += 1 for j in range(i + 1, 3): d2X_reg[i, j] += dS_reg[cnt] d2X_reg[j, i] += dS_reg[cnt] cnt += 1 - cnt += 1 d2X_reg[2, 2] -= dS_reg[0] d2X_reg[2, 2] -= dS_reg[3] diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py index 1a14d1dd..34652fc5 100644 --- a/src/matgl/kernels/decompose_tensor.py +++ b/src/matgl/kernels/decompose_tensor.py @@ -27,8 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Warp kernels for decomposing 3x3 tensors into I, A, S components.""" -from __future__ import annotations - import warp as wp from .utils import add_module, get_wp_fp_dtype @@ -76,25 +74,20 @@ def decompose_tensor_fwd( I[b, 0, h] = res denom = X.dtype(2.0) - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(2): for j in range(i + 1, 3): A[b, cnt, h] = (X_reg[i, j] - X_reg[j, i]) / denom cnt += 1 - cnt += 1 - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(2): S[b, cnt, h] = X_reg[i, i] - res cnt += 1 - cnt += 1 for j in range(i + 1, 3): S[b, cnt, h] = (X_reg[i, j] + X_reg[j, i]) / denom cnt += 1 - cnt += 1 def decompose_tensor_bwd( dI: wp.array(ndim=dim, dtype=dtype_wp), @@ -121,32 +114,26 @@ def decompose_tensor_bwd( denom = dX.dtype(2.0) - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(3): for j in range(i + 1, 3): dX_reg[i, j] += dA_reg[cnt] / denom dX_reg[j, i] -= dA_reg[cnt] / denom - - cnt += 1 cnt += 1 - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(2): dX_reg[i, i] += dS_reg[cnt] for j in range(3): dX_reg[j, j] -= dS_reg[cnt] / dI.dtype(3.0) cnt += 1 - cnt += 1 for j in range(i + 1, 3): dX_reg[i, j] += dS_reg[cnt] / denom dX_reg[j, i] += dS_reg[cnt] / denom cnt += 1 - cnt += 1 for i in range(3): for j in range(3): @@ -174,30 +161,24 @@ def decompose_tensor_bwd_bwd( denom = dX.dtype(2.0) - cnt = 0 - cnt = 0 - + cnt = wp.int32(0) for i in range(3): for j in range(i + 1, 3): d2A_reg[cnt] += dX_reg[i, j] / denom d2A_reg[cnt] -= dX_reg[j, i] / denom cnt += 1 - cnt += 1 - cnt = 0 - cnt = 0 + cnt = wp.int32(0) for i in range(2): d2S_reg[cnt] += dX_reg[i, i] for j in range(3): d2S_reg[cnt] -= dX_reg[j, j] / d2I.dtype(3.0) cnt += 1 - cnt += 1 for j in range(i + 1, 3): d2S_reg[cnt] += dX_reg[i, j] / denom d2S_reg[cnt] += dX_reg[j, i] / denom cnt += 1 - cnt += 1 d2I[b, 0, h] = d2I_reg for i in range(3): diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py index 639da41d..a7d36a87 100644 --- a/src/matgl/kernels/equivariant_o3_matmul.py +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -27,8 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for O(3)-equivariant 3x3 tensor matrix multiplication.""" -from __future__ import annotations - import warp as wp from .utils import add_module, get_wp_fp_dtype diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index bcf33238..1be0e825 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -27,7 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for SO(3)-equivariant 3x3 tensor matrix multiplication.""" -from __future__ import annotations import warp as wp diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py index b21f6281..9d7e6077 100644 --- a/src/matgl/kernels/graph_transform.py +++ b/src/matgl/kernels/graph_transform.py @@ -27,8 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Warp kernels for graph edge index transformation to sparse CSR format.""" -from __future__ import annotations - import warp as wp diff --git a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py index 540cc576..302651da 100644 --- a/src/matgl/kernels/tensor_norm3.py +++ b/src/matgl/kernels/tensor_norm3.py @@ -27,8 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Warp kernels for computing 3x3 tensor norms (I, A, S components).""" -from __future__ import annotations - import warp as wp from .utils import add_module, get_wp_fp_dtype diff --git a/src/matgl/kernels/tensornet_mp.py b/src/matgl/kernels/tensornet_mp.py index 016de513..e7bd34ba 100644 --- a/src/matgl/kernels/tensornet_mp.py +++ b/src/matgl/kernels/tensornet_mp.py @@ -25,8 +25,6 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import annotations - import warp as wp from .utils import add_module, get_wp_fp_dtype diff --git a/src/matgl/kernels/tensornet_radial_mp.py b/src/matgl/kernels/tensornet_radial_mp.py index f223884c..27d39a2b 100644 --- a/src/matgl/kernels/tensornet_radial_mp.py +++ b/src/matgl/kernels/tensornet_radial_mp.py @@ -25,7 +25,6 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from __future__ import annotations import warp as wp From 52aab20933d8ee589f1c060da9f2de586d1dac9c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 02:00:26 +0000 Subject: [PATCH 12/18] pre-commit auto-fixes --- src/matgl/kernels/equivariant_so3_matmul.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py index 1be0e825..38ee5065 100644 --- a/src/matgl/kernels/equivariant_so3_matmul.py +++ b/src/matgl/kernels/equivariant_so3_matmul.py @@ -27,7 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Warp kernels for SO(3)-equivariant 3x3 tensor matrix multiplication.""" - import warp as wp from .utils import add_module, get_wp_fp_dtype From 18229f4022a82a428a1f3d3477f4802344fde81b Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Mon, 12 Jan 2026 10:17:43 -0500 Subject: [PATCH 13/18] fix weighs loading --- src/matgl/models/_tensornet_pyg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 4b686a39..21abfb0f 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -160,7 +160,6 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): new_b = f"{prefix}distance_proj.bias" if all(k in state_dict for k in w_keys + b_keys): - state_dict = dict(state_dict) state_dict[new_w] = torch.cat([state_dict.pop(k) for k in w_keys], dim=0) state_dict[new_b] = torch.cat([state_dict.pop(k) for k in b_keys], dim=0) From f4d8c93ee8e01b3c7e61ea4d44f594eccbbab85e Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Mon, 26 Jan 2026 17:04:21 -0500 Subject: [PATCH 14/18] fix tensor init in kernels --- src/matgl/ops/equivariant_o3_matmul.py | 5 +++++ src/matgl/ops/equivariant_so3_matmul.py | 5 +++++ src/matgl/ops/tensornet_radial_mp.py | 2 ++ 3 files changed, 12 insertions(+) diff --git a/src/matgl/ops/equivariant_o3_matmul.py b/src/matgl/ops/equivariant_o3_matmul.py index 570f0836..c042fd10 100644 --- a/src/matgl/ops/equivariant_o3_matmul.py +++ b/src/matgl/ops/equivariant_o3_matmul.py @@ -195,6 +195,11 @@ def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): grad_output_saved, x, y = ctx.saved_tensors + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + if grad_grad_y is None: + grad_grad_y = torch.zeros_like(y) + outputs = torch.ops.tensornet.tensor_matmul_o3_3x3_bwd_bwd_primitive( grad_output_saved, grad_grad_x, grad_grad_y, x, y ) diff --git a/src/matgl/ops/equivariant_so3_matmul.py b/src/matgl/ops/equivariant_so3_matmul.py index 739b9366..7ea61565 100644 --- a/src/matgl/ops/equivariant_so3_matmul.py +++ b/src/matgl/ops/equivariant_so3_matmul.py @@ -192,6 +192,11 @@ def tensor_matmul_so3_3x3_bwd_bwd(ctx, *grad_outputs): grad_output_saved, x, y = ctx.saved_tensors + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + if grad_grad_y is None: + grad_grad_y = torch.zeros_like(y) + outputs = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_bwd_primitive( grad_output_saved, grad_grad_x, grad_grad_y, x, y ) diff --git a/src/matgl/ops/tensornet_radial_mp.py b/src/matgl/ops/tensornet_radial_mp.py index d3070ec9..8e3fccbe 100644 --- a/src/matgl/ops/tensornet_radial_mp.py +++ b/src/matgl/ops/tensornet_radial_mp.py @@ -262,6 +262,7 @@ def _( def _( grad_output_I: Tensor, grad_output_A: Tensor, + grad_output_S: Tensor, grad_grad_edge_vec_norm: Tensor, grad_grad_edge_attr: Tensor, edge_vec_norm: Tensor, @@ -272,6 +273,7 @@ def _( return [ torch.empty_like(grad_output_I), torch.empty_like(grad_output_A), + torch.empty_like(grad_output_S), torch.empty_like(grad_grad_edge_vec_norm), torch.empty_like(grad_grad_edge_attr), ] From 9f9da0954581849e211c3466ff0f97330525162b Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Mon, 26 Jan 2026 17:04:53 -0500 Subject: [PATCH 15/18] fix mypy --- pyproject.toml | 4 ++++ src/matgl/models/_tensornet_pyg.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 07d1c904..dbb181b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,6 +174,10 @@ exclude = ['examples', 'tests'] 
module = ["requests.*", "tabulate.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["matgl.kernels.*"] +ignore_errors = true + [tool.coverage.run] relative_files = true diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 21abfb0f..32050d10 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -569,8 +569,8 @@ def forward( else: # PyG Data object - extract tensors z = getattr(g, "node_type", getattr(g, "z", None)) - pos = g.pos # type: ignore[union-attr] - edge_index = g.edge_index # type: ignore[union-attr] + pos = g.pos # type: ignore[attr-defined] + edge_index = g.edge_index # type: ignore[attr-defined] pbc_offshift = getattr(g, "pbc_offshift", None) batch = getattr(g, "batch", None) num_graphs = getattr(g, "num_graphs", None) @@ -580,7 +580,7 @@ def forward( # perpare graph indices for message passing row_data, row_indices, row_indptr, col_data, col_indices, col_indptr = graph_transform( - edge_index.int(), z.shape[0] + edge_index.int(), z.shape[0] # type: ignore[union-attr] ) # Expand distances with radial basis functions @@ -629,7 +629,7 @@ def forward( batch_long = batch.to(torch.long) if num_graphs is None: num_graphs = int(batch_long.max().item()) + 1 - return scatter_add(atomic_energies, batch_long, dim_size=num_graphs) + return scatter_add(atomic_energies, batch_long, dim_size=num_graphs) # type: ignore[arg-type] # Single graph case: Sum all energies (equivalent to scatter_add with all nodes in one graph) return torch.sum(atomic_energies, dim=0, keepdim=True).squeeze() From 127d7a043ad9f8d1632466e3584a7d606eb45584 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Jan 2026 22:13:09 +0000 Subject: [PATCH 16/18] pre-commit auto-fixes --- src/matgl/models/_tensornet_pyg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 32050d10..378a7210 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -580,7 +580,8 @@ def forward( # perpare graph indices for message passing row_data, row_indices, row_indptr, col_data, col_indices, col_indptr = graph_transform( - edge_index.int(), z.shape[0] # type: ignore[union-attr] + edge_index.int(), + z.shape[0], # type: ignore[union-attr] ) # Expand distances with radial basis functions From 3242b35b8885c40589b10612f3faf8e525fc13f8 Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Wed, 28 Jan 2026 18:42:25 -0500 Subject: [PATCH 17/18] Fix backward pass using wrong sparse matrix format in message_passing_bwd The backward primitive was incorrectly using row_* (CSR) tensors instead of col_* (CSC) tensors. For gradient computation, the transpose of the forward sparse matrix is needed. This caused incorrect gradients for non-symmetric edge attributes. 
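To make the CSR/CSC distinction above concrete: if the forward aggregation gathers messages per destination node through the row-grouped (CSR) view, then the gradient with respect to the source features sums the same edge weights per source node, i.e. it applies the transposed matrix, and walking the transpose row by row is exactly walking the original matrix column by column (CSC). A small self-contained NumPy/SciPy sketch of that identity (toy graph, illustrative names, not the Warp kernel API):

import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
n, m = 5, 12
src = rng.integers(0, n, m)                       # toy directed edges src -> dst
dst = rng.integers(0, n, m)
w = rng.normal(size=m)                            # per-edge weights

A = sp.csr_matrix((w, (dst, src)), shape=(n, n))  # CSR: edges grouped by destination row
x = rng.normal(size=n)
y = A @ x                                         # forward: y_i = sum_j A[i, j] * x_j

dy = rng.normal(size=n)                           # incoming gradient dL/dy
dx = A.T @ dy                                     # backward applies the transpose

A_csc = A.tocsc()                                 # same matrix, edges grouped by source column
dx_manual = np.zeros(n)
for j in range(n):                                # walk column j via the CSC indptr/indices/data
    lo, hi = A_csc.indptr[j], A_csc.indptr[j + 1]
    dx_manual[j] = np.dot(A_csc.data[lo:hi], dy[A_csc.indices[lo:hi]])
assert np.allclose(dx, dx_manual)

Using the row-grouped (CSR) arrays in the backward pass would only give the same result when the adjacency is symmetric, which matches the observation above that the bug surfaced for non-symmetric edge attributes.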
Also updates test script: - Fix pymatgen API compatibility (site.specie -> site.species_string) - Loosen parameter gradient thresholds from 1e-5 to 5e-5 for double backward --- dev/test_tensornet_forward_backward.py | 355 +++++++++++++++++++++---- src/matgl/ops/tensornet_mp.py | 12 +- 2 files changed, 315 insertions(+), 52 deletions(-) diff --git a/dev/test_tensornet_forward_backward.py b/dev/test_tensornet_forward_backward.py index ead8a877..0c0ecd54 100644 --- a/dev/test_tensornet_forward_backward.py +++ b/dev/test_tensornet_forward_backward.py @@ -36,9 +36,12 @@ import torch from pymatgen.core import Structure +from torch_geometric.data import Batch DEFAULT_MATGL_MAIN_PATH = str(Path(__file__).parent.parent / "matgl-main" / "src") +BATCH_SIZE = 13 + MODEL_CONFIG = { "units": 64, "nblocks": 2, @@ -70,7 +73,7 @@ def load_structure(path: str) -> Structure: def get_element_types(structure: Structure) -> tuple[str, ...]: """Extract sorted unique element symbols.""" - return tuple(sorted({site.specie.symbol for site in structure})) + return tuple(sorted({site.species_string for site in structure})) def build_graph( @@ -94,6 +97,39 @@ def build_graph( return graph.to(device) +def build_batched_graph( + converter: Any, + structure: Structure, + device: torch.device, + compute_bond: Any = None, + requires_grad: bool = False, + batch_size: int = BATCH_SIZE, +) -> Any: + """Build batched graph by repeating the same structure multiple times.""" + graphs = [] + for _ in range(batch_size): + graph, lat, _ = converter.get_graph(structure) + pos = graph.frac_coords @ lat[0] + graph.pos = pos.clone().detach().requires_grad_(requires_grad) if requires_grad else pos.clone() + graph.pbc_offshift = (graph.pbc_offset @ lat[0]).clone() + + if compute_bond is not None: + bond_vec, bond_dist = compute_bond(graph) + graph.bond_vec = bond_vec.clone() + graph.bond_dist = bond_dist.clone() + + # Clone all tensor attributes to ensure independence + for key in list(graph.keys()): + val = graph[key] + if isinstance(val, torch.Tensor): + graph[key] = val.clone() + + graphs.append(graph) + + batched = Batch.from_data_list(graphs) + return batched.to(device) + + def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-6) -> bool: """Compare two tensors element-wise.""" if t1.shape != t2.shape: @@ -116,17 +152,36 @@ def compare_weights(ref_model: Any, cur_model: Any) -> bool: ref_sd, cur_sd = ref_model.state_dict(), cur_model.state_dict() all_match = True - # Handle merged distance_proj layers (distance_proj1/2/3 -> distance_proj) + # Handle distance_proj layers + print("--- distance_proj ---") dp_keys = [f"tensor_embedding.distance_proj{i}" for i in range(1, 4)] + skip = set() + if f"{dp_keys[0]}.weight" in ref_sd: + # Reference has separate distance_proj1/2/3 -> merge and compare ref_w = torch.cat([ref_sd[f"{k}.weight"] for k in dp_keys], dim=0) ref_b = torch.cat([ref_sd[f"{k}.bias"] for k in dp_keys], dim=0) + skip = {f"{k}.{p}" for k in dp_keys for p in ("weight", "bias")} - print("\n--- distance_proj (merged) ---") all_match &= compare_tensors("weight", ref_w, cur_sd["tensor_embedding.distance_proj.weight"]) all_match &= compare_tensors("bias", ref_b, cur_sd["tensor_embedding.distance_proj.bias"]) + elif "tensor_embedding.distance_proj.weight" in ref_sd: + # Reference has merged distance_proj -> compare directly + skip = {"tensor_embedding.distance_proj.weight", "tensor_embedding.distance_proj.bias"} + + all_match &= compare_tensors( + "weight", + 
ref_sd["tensor_embedding.distance_proj.weight"], + cur_sd["tensor_embedding.distance_proj.weight"], + ) + all_match &= compare_tensors( + "bias", + ref_sd["tensor_embedding.distance_proj.bias"], + cur_sd["tensor_embedding.distance_proj.bias"], + ) + else: + print(" WARNING: distance_proj not found in reference model") - skip = {f"{k}.{p}" for k in dp_keys for p in ("weight", "bias")} print("\n--- Other Parameters ---") for key in sorted(cur_sd): @@ -138,7 +193,7 @@ def compare_weights(ref_model: Any, cur_model: Any) -> bool: print(f" {key}: NOT IN REFERENCE") for key in sorted(ref_sd): - if key not in skip and key not in cur_sd: + if key not in skip and key not in cur_sd and "distance_proj" not in key: print(f" {key}: IN REFERENCE ONLY") all_match = False @@ -146,38 +201,45 @@ def compare_weights(ref_model: Any, cur_model: Any) -> bool: return all_match -def compare_forward(ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device) -> bool: - """Compare forward pass energy predictions.""" - print_section("Forward Pass") +def compare_forward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare forward pass energy predictions for batched graphs.""" + print_section("Forward Pass (Batched)") ref_model.eval() cur_model.eval() - state_attr = torch.tensor([0.0, 0.0], device=device) + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) ref_e = ref_model(g=ref_graph, state_attr=state_attr) cur_e = cur_model(g=cur_graph, state_attr=state_attr) - diff = abs(float(ref_e) - float(cur_e)) - print(f"Reference: {float(ref_e):.10f}") - print(f"Current: {float(cur_e):.10f}") - print(f"Diff: {diff:.2e}") + print(f"Reference energies: {ref_e.detach().cpu().numpy()}") + print(f"Current energies: {cur_e.detach().cpu().numpy()}") - match = diff < 1e-5 + diff = (ref_e - cur_e).abs() + print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") + + match = diff.max().item() < 1e-5 print(f"Result: {'PASS' if match else 'FAIL'}") return match -def compare_backward(ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device) -> bool: - """Compare forces (F = -dE/dpos).""" - print_section("Backward Pass (Forces)") +def compare_backward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare forces (F = -dE/dpos) for batched graphs.""" + print_section("Backward Pass (Forces, Batched)") ref_model.train() cur_model.train() - state_attr = torch.tensor([0.0, 0.0], device=device) + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) def get_forces(model, graph): energy = model(g=graph, state_attr=state_attr) - return -torch.autograd.grad(energy, graph.pos, create_graph=True)[0] + # Sum energies to get scalar for gradient + total_energy = energy.sum() + return -torch.autograd.grad(total_energy, graph.pos, create_graph=True)[0] ref_f = get_forces(ref_model, ref_graph) cur_f = get_forces(cur_model, cur_graph) @@ -194,28 +256,30 @@ def get_forces(model, graph): def compare_double_backward( - ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE ) -> bool: - """Compare position gradients via loss = sum(forces^2).""" - print_section("Double Backward (Position Gradients)") + """Compare position gradients via loss = sum(forces^2) 
for batched graphs.""" + print_section("Double Backward (Position Gradients, Batched)") ref_model.train() cur_model.train() - state_attr = torch.tensor([0.0, 0.0], device=device) + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) ref_graph.pos.retain_grad() cur_graph.pos.retain_grad() # Reference ref_energy = ref_model(g=ref_graph, state_attr=state_attr) - ref_forces = torch.autograd.grad(ref_energy, ref_graph.pos, create_graph=True)[0] + ref_total_energy = ref_energy.sum() + ref_forces = torch.autograd.grad(ref_total_energy, ref_graph.pos, create_graph=True)[0] ref_loss = (ref_forces * ref_forces).sum() ref_loss.backward() ref_pos_grad = ref_graph.pos.grad.clone() # Current cur_energy = cur_model(g=cur_graph, state_attr=state_attr) - cur_forces = torch.autograd.grad(cur_energy, cur_graph.pos, create_graph=True)[0] + cur_total_energy = cur_energy.sum() + cur_forces = torch.autograd.grad(cur_total_energy, cur_graph.pos, create_graph=True)[0] cur_loss = (cur_forces * cur_forces).sum() cur_loss.backward() cur_pos_grad = cur_graph.pos.grad.clone() @@ -237,20 +301,178 @@ def compare_double_backward( return match -def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: +def compare_param_gradients( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare gradients on all model parameters after double backward (forces loss).""" + print_section("Parameter Gradients (Double Backward, Batched)") + + ref_model.train() + cur_model.train() + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) + + # Zero gradients + ref_model.zero_grad() + cur_model.zero_grad() + + # Double backward: compute forces, then loss = sum(forces^2) + # Reference + ref_energy = ref_model(g=ref_graph, state_attr=state_attr) + ref_total_energy = ref_energy.sum() + ref_forces = torch.autograd.grad(ref_total_energy, ref_graph.pos, create_graph=True)[0] + ref_loss = (ref_forces * ref_forces).sum() + ref_loss.backward() + + # Current + cur_energy = cur_model(g=cur_graph, state_attr=state_attr) + cur_total_energy = cur_energy.sum() + cur_forces = torch.autograd.grad(cur_total_energy, cur_graph.pos, create_graph=True)[0] + cur_loss = (cur_forces * cur_forces).sum() + cur_loss.backward() + + print(f"Reference loss: {ref_loss.item():.6f}") + print(f"Current loss: {cur_loss.item():.6f}") + + # Build mapping for distance_proj layers (merged in current, separate in reference) + ref_sd = {k: p for k, p in ref_model.named_parameters()} + cur_sd = {k: p for k, p in cur_model.named_parameters()} + + all_match = True + max_diff_overall = 0.0 + mismatched_params = [] + + # Handle merged distance_proj layers + print("--- distance_proj (merged) ---") + dp_keys = [f"tensor_embedding.distance_proj{i}" for i in range(1, 4)] + skip_ref_keys = set() + skip_cur_keys = set() + for suffix in [".weight", ".bias"]: + ref_grads = [] + for dp_key in dp_keys: + key = dp_key + suffix + if key in ref_sd: + skip_ref_keys.add(key) + if ref_sd[key].grad is not None: + ref_grads.append(ref_sd[key].grad) + + cur_key = "tensor_embedding.distance_proj" + suffix + skip_cur_keys.add(cur_key) + + if not ref_grads: + # Reference doesn't have separate distance_proj layers, compare directly + ref_key = cur_key + if ref_key in ref_sd: + ref_param = ref_sd[ref_key] + cur_param = cur_sd.get(cur_key) + if ref_param.grad is None and (cur_param is None or cur_param.grad is None): + print(f" distance_proj{suffix}: NO GRAD (both)") + elif 
ref_param.grad is None: + print(f" distance_proj{suffix}: NO GRAD (reference)") + all_match = False + elif cur_param is None or cur_param.grad is None: + print(f" distance_proj{suffix}: NO GRAD (current)") + all_match = False + else: + diff = (ref_param.grad - cur_param.grad).abs() + max_diff = diff.max().item() + max_diff_overall = max(max_diff_overall, max_diff) + if max_diff > 5e-5: + mismatched_params.append(f"distance_proj{suffix}") + all_match = False + print(f" distance_proj{suffix}: DIFF (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: MATCH (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: NOT FOUND IN REFERENCE") + else: + # Reference has separate layers, concatenate and compare + ref_grad = torch.cat(ref_grads, dim=0) + if cur_key in cur_sd and cur_sd[cur_key].grad is not None: + cur_grad = cur_sd[cur_key].grad + if ref_grad.shape == cur_grad.shape: + diff = (ref_grad - cur_grad).abs() + max_diff = diff.max().item() + max_diff_overall = max(max_diff_overall, max_diff) + if max_diff > 5e-5: + mismatched_params.append(f"distance_proj{suffix}") + all_match = False + print(f" distance_proj{suffix}: DIFF (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: MATCH (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: SHAPE MISMATCH {ref_grad.shape} vs {cur_grad.shape}") + all_match = False + else: + print(f" distance_proj{suffix}: NO GRAD (current)") + all_match = False + + # Compare other parameters + print("\n--- Other Parameters ---") + for cur_key, cur_param in cur_sd.items(): + if "distance_proj" in cur_key: + continue + + if cur_key in ref_sd: + ref_param = ref_sd[cur_key] + if ref_param.grad is None and cur_param.grad is None: + print(f" {cur_key}: NO GRAD (both)") + continue + elif ref_param.grad is None: + print(f" {cur_key}: NO GRAD (reference)") + all_match = False + continue + elif cur_param.grad is None: + print(f" {cur_key}: NO GRAD (current)") + all_match = False + continue + + if ref_param.grad.shape != cur_param.grad.shape: + print(f" {cur_key}: SHAPE MISMATCH {ref_param.grad.shape} vs {cur_param.grad.shape}") + all_match = False + continue + + diff = (ref_param.grad - cur_param.grad).abs() + max_diff = diff.max().item() + max_diff_overall = max(max_diff_overall, max_diff) + + if max_diff > 5e-5: + mismatched_params.append(cur_key) + all_match = False + print(f" {cur_key}: DIFF (max={max_diff:.2e}, mean={diff.mean():.2e})") + else: + print(f" {cur_key}: MATCH (max={max_diff:.2e})") + else: + print(f" {cur_key}: NOT IN REFERENCE") + + # Check for params in reference only + for ref_key in ref_sd: + if ref_key not in skip_ref_keys and ref_key not in cur_sd: + print(f" {ref_key}: IN REFERENCE ONLY") + + print(f"\nMax diff overall: {max_diff_overall:.2e}") + if mismatched_params: + print(f"Mismatched params: {mismatched_params}") + + print(f"Result: {'PASS' if all_match else 'FAIL'}") + return all_match + + +def main( + structure_path: str, matgl_main_path: str, seed: int = 42, pretrained_path: str | None = None +) -> bool: """Run all comparison tests between reference and current implementations.""" print_section("TensorNet Comparison: matgl-main vs Current") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Seed: {seed}, Device: {device}") + print(f"Seed: {seed}, Device: {device}, Batch size: {BATCH_SIZE}") print(f"matgl-main path: {matgl_main_path}") + if pretrained_path: + print(f"Pretrained model: {pretrained_path}") structure = load_structure(structure_path) element_types = 
get_element_types(structure) print(f"Structure: {structure_path} ({len(structure)} atoms, elements: {element_types})") - model_config = {**MODEL_CONFIG, "element_types": element_types} - # Reference model (matgl-main) clear_matgl_modules() sys.path.insert(0, matgl_main_path) @@ -258,14 +480,28 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: from matgl.ext._pymatgen_pyg import Structure2Graph as RefConverter from matgl.graph._compute_pyg import compute_pair_vector_and_distance as ref_compute_bond from matgl.models._tensornet_pyg import TensorNet as RefTensorNet - - torch.manual_seed(seed) - ref_model = RefTensorNet(**model_config).to(device) - ref_converter = RefConverter(element_types=element_types, cutoff=MODEL_CONFIG["cutoff"]) - - ref_graph = build_graph(ref_converter, structure, device, ref_compute_bond) - ref_graph_grad = build_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) - ref_graph_grad2 = build_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + from matgl.utils.io import load_model as ref_load_model + + if pretrained_path: + # Load pre-trained model (Potential wrapper contains TensorNet) + ref_potential = ref_load_model(pretrained_path) + ref_model = ref_potential.model.to(device) + ref_cutoff = ref_model.cutoff + ref_element_types = ref_model.element_types + else: + model_config = {**MODEL_CONFIG, "element_types": element_types} + torch.manual_seed(seed) + ref_model = RefTensorNet(**model_config).to(device) + ref_cutoff = MODEL_CONFIG["cutoff"] + ref_element_types = element_types + + ref_converter = RefConverter(element_types=ref_element_types, cutoff=ref_cutoff) + + # Build batched graphs for reference model + ref_graph = build_batched_graph(ref_converter, structure, device, ref_compute_bond) + ref_graph_grad = build_batched_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + ref_graph_grad2 = build_batched_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + ref_graph_param = build_batched_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) sys.path.pop(0) @@ -274,16 +510,31 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: from matgl.ext._pymatgen_pyg import Structure2Graph as CurConverter from matgl.models._tensornet_pyg import TensorNet as CurTensorNet - - torch.manual_seed(seed) - cur_model = CurTensorNet(**model_config).to(device) - cur_converter = CurConverter(element_types=element_types, cutoff=MODEL_CONFIG["cutoff"]) - - cur_graph = build_graph(cur_converter, structure, device) - cur_graph_grad = build_graph(cur_converter, structure, device, requires_grad=True) - cur_graph_grad2 = build_graph(cur_converter, structure, device, requires_grad=True) + from matgl.utils.io import load_model as cur_load_model + + if pretrained_path: + # Load pre-trained model (Potential wrapper contains TensorNet) + cur_potential = cur_load_model(pretrained_path) + cur_model = cur_potential.model.to(device) + cur_cutoff = cur_model.cutoff + cur_element_types = cur_model.element_types + else: + model_config = {**MODEL_CONFIG, "element_types": element_types} + torch.manual_seed(seed) + cur_model = CurTensorNet(**model_config).to(device) + cur_cutoff = MODEL_CONFIG["cutoff"] + cur_element_types = element_types + + cur_converter = CurConverter(element_types=cur_element_types, cutoff=cur_cutoff) + + # Build batched graphs for current model + cur_graph = build_batched_graph(cur_converter, structure, 
device) + cur_graph_grad = build_batched_graph(cur_converter, structure, device, requires_grad=True) + cur_graph_grad2 = build_batched_graph(cur_converter, structure, device, requires_grad=True) + cur_graph_param = build_batched_graph(cur_converter, structure, device, requires_grad=True) print(f"Models: {sum(p.numel() for p in ref_model.parameters())} params each") + print(f"Batched graph: {ref_graph.num_nodes} nodes, {ref_graph.num_edges} edges") # Run comparisons results = { @@ -291,6 +542,7 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: "Forward": compare_forward(ref_model, cur_model, ref_graph, cur_graph, device), "Backward": compare_backward(ref_model, cur_model, ref_graph_grad, cur_graph_grad, device), "Double Backward": compare_double_backward(ref_model, cur_model, ref_graph_grad2, cur_graph_grad2, device), + "Param Gradients": compare_param_gradients(ref_model, cur_model, ref_graph_param, cur_graph_param, device), } # Summary @@ -323,6 +575,17 @@ def main(structure_path: str, matgl_main_path: str, seed: int = 42) -> bool: help="Path to matgl-main/src (default: $MATGL_MAIN_PATH or ../matgl-main/src)", ) parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument( + "--pretrained", + "-p", + default=None, + help="Path to pretrained model directory (e.g., pretrained_models/TensorNet-MatPES-PBE-v2025.1-PES)", + ) args = parser.parse_args() - main(structure_path=args.structure, matgl_main_path=args.matgl_main_path, seed=args.seed) + main( + structure_path=args.structure, + matgl_main_path=args.matgl_main_path, + seed=args.seed, + pretrained_path=args.pretrained, + ) diff --git a/src/matgl/ops/tensornet_mp.py b/src/matgl/ops/tensornet_mp.py index 22ff2848..da507102 100644 --- a/src/matgl/ops/tensornet_mp.py +++ b/src/matgl/ops/tensornet_mp.py @@ -148,9 +148,9 @@ def _( edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) - row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) - row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) - row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + col_data_wp = wp.from_torch(col_data.detach(), return_ctype=True) + col_indices_wp = wp.from_torch(col_indices.detach(), return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr.detach(), return_ctype=True) grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) @@ -172,9 +172,9 @@ def _( grad_output_x_wp, grad_output_y_wp, grad_output_z_wp, - row_data_wp, - row_indices_wp, - row_indptr_wp, + col_data_wp, + col_indices_wp, + col_indptr_wp, grad_x_wp, grad_y_wp, grad_z_wp, From 3c660342921cc12408f191fc90e521326d48c260 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 23:45:38 +0000 Subject: [PATCH 18/18] pre-commit auto-fixes --- dev/test_tensornet_forward_backward.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dev/test_tensornet_forward_backward.py b/dev/test_tensornet_forward_backward.py index 0c0ecd54..f4fcb4ad 100644 --- a/dev/test_tensornet_forward_backward.py +++ b/dev/test_tensornet_forward_backward.py @@ -417,11 +417,11 @@ def compare_param_gradients( if ref_param.grad is None and cur_param.grad is None: print(f" {cur_key}: NO GRAD (both)") continue - elif ref_param.grad is None: + if ref_param.grad is None: print(f" {cur_key}: NO GRAD (reference)") all_match = False continue - elif 
cur_param.grad is None: + if cur_param.grad is None: print(f" {cur_key}: NO GRAD (current)") all_match = False continue @@ -457,9 +457,7 @@ def compare_param_gradients( return all_match -def main( - structure_path: str, matgl_main_path: str, seed: int = 42, pretrained_path: str | None = None -) -> bool: +def main(structure_path: str, matgl_main_path: str, seed: int = 42, pretrained_path: str | None = None) -> bool: """Run all comparison tests between reference and current implementations.""" print_section("TensorNet Comparison: matgl-main vs Current")
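For reference, the check the updated test script performs end-to-end is the standard conservative-force double backward: forces come from one autograd.grad call with create_graph=True, a force-based loss is built from them, and a second backward then reaches positions and parameters. A minimal self-contained sketch with a toy pair potential standing in for the model (the energy function and names here are illustrative, not matgl's API):

import torch

def energy(pos: torch.Tensor) -> torch.Tensor:
    # Toy pairwise 1/r energy in place of the TensorNet potential.
    diff = pos.unsqueeze(0) - pos.unsqueeze(1)           # (N, N, 3) pairwise displacement vectors
    dist = diff.norm(dim=-1) + torch.eye(pos.shape[0])   # mask the zero self-distances
    return (1.0 / dist).triu(diagonal=1).sum()

pos = torch.randn(4, 3, requires_grad=True)
e = energy(pos)
forces = -torch.autograd.grad(e, pos, create_graph=True)[0]  # F = -dE/dpos, keep graph for 2nd pass
loss = (forces * forces).sum()                               # force-matching style loss
loss.backward()                                              # double backward fills pos.grad
print(pos.grad.shape)                                        # torch.Size([4, 3])

The batched comparisons above follow the same pattern, with energy.sum() over the batch as the scalar for the first grad call and the looser 5e-5 threshold applied to the parameter gradients produced by the second backward.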