diff --git a/dev/test_tensornet_forward_backward.py b/dev/test_tensornet_forward_backward.py new file mode 100644 index 00000000..f4fcb4ad --- /dev/null +++ b/dev/test_tensornet_forward_backward.py @@ -0,0 +1,589 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Compare forward/backward/double-backward between matgl-main and current TensorNet.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any + +import torch +from pymatgen.core import Structure +from torch_geometric.data import Batch + +DEFAULT_MATGL_MAIN_PATH = str(Path(__file__).parent.parent / "matgl-main" / "src") + +BATCH_SIZE = 13 + +MODEL_CONFIG = { + "units": 64, + "nblocks": 2, + "num_rbf": 32, + "cutoff": 5.0, + "rbf_type": "Gaussian", + "activation_type": "swish", + "equivariance_invariance_group": "O(3)", + "is_intensive": False, + "ntargets": 1, +} + + +def clear_matgl_modules() -> None: + """Remove all matgl modules from sys.modules.""" + for mod in [k for k in sys.modules if k.startswith("matgl")]: + del sys.modules[mod] + + +def print_section(title: str) -> None: + """Print a formatted section header.""" + print(f"\n{'=' * 70}\n{title}\n{'=' * 70}") + + +def load_structure(path: str) -> Structure: + """Load structure from file.""" + return Structure.from_file(path) + + +def get_element_types(structure: Structure) -> tuple[str, ...]: + """Extract sorted unique element symbols.""" + return tuple(sorted({site.species_string for site in structure})) + + +def build_graph( + converter: Any, + structure: Structure, + device: torch.device, + compute_bond: Any = None, + requires_grad: bool = False, +) -> Any: + """Build graph from structure with optional gradient tracking.""" + graph, lat, _ = converter.get_graph(structure) + pos = graph.frac_coords @ lat[0] + graph.pos = pos.clone().detach().requires_grad_(requires_grad) if requires_grad else pos + graph.pbc_offshift = graph.pbc_offset @ lat[0] + + if compute_bond is not None: + bond_vec, bond_dist = compute_bond(graph) + graph.bond_vec = bond_vec + graph.bond_dist = bond_dist + + return graph.to(device) + + +def build_batched_graph( + converter: Any, + structure: Structure, + device: torch.device, + compute_bond: Any = None, + requires_grad: 
bool = False, + batch_size: int = BATCH_SIZE, +) -> Any: + """Build batched graph by repeating the same structure multiple times.""" + graphs = [] + for _ in range(batch_size): + graph, lat, _ = converter.get_graph(structure) + pos = graph.frac_coords @ lat[0] + graph.pos = pos.clone().detach().requires_grad_(requires_grad) if requires_grad else pos.clone() + graph.pbc_offshift = (graph.pbc_offset @ lat[0]).clone() + + if compute_bond is not None: + bond_vec, bond_dist = compute_bond(graph) + graph.bond_vec = bond_vec.clone() + graph.bond_dist = bond_dist.clone() + + # Clone all tensor attributes to ensure independence + for key in list(graph.keys()): + val = graph[key] + if isinstance(val, torch.Tensor): + graph[key] = val.clone() + + graphs.append(graph) + + batched = Batch.from_data_list(graphs) + return batched.to(device) + + +def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-6) -> bool: + """Compare two tensors element-wise.""" + if t1.shape != t2.shape: + print(f" {name}: SHAPE MISMATCH {t1.shape} vs {t2.shape}") + return False + + if torch.allclose(t1, t2, atol=atol): + print(f" {name}: MATCH") + return True + + diff = (t1 - t2).abs() + print(f" {name}: DIFF (max={diff.max():.2e}, mean={diff.mean():.2e})") + return False + + +def compare_weights(ref_model: Any, cur_model: Any) -> bool: + """Compare model weights with distance_proj layer remapping.""" + print_section("Weight Comparison") + + ref_sd, cur_sd = ref_model.state_dict(), cur_model.state_dict() + all_match = True + + # Handle distance_proj layers + print("--- distance_proj ---") + dp_keys = [f"tensor_embedding.distance_proj{i}" for i in range(1, 4)] + skip = set() + + if f"{dp_keys[0]}.weight" in ref_sd: + # Reference has separate distance_proj1/2/3 -> merge and compare + ref_w = torch.cat([ref_sd[f"{k}.weight"] for k in dp_keys], dim=0) + ref_b = torch.cat([ref_sd[f"{k}.bias"] for k in dp_keys], dim=0) + skip = {f"{k}.{p}" for k in dp_keys for p in ("weight", 
"bias")} + + all_match &= compare_tensors("weight", ref_w, cur_sd["tensor_embedding.distance_proj.weight"]) + all_match &= compare_tensors("bias", ref_b, cur_sd["tensor_embedding.distance_proj.bias"]) + elif "tensor_embedding.distance_proj.weight" in ref_sd: + # Reference has merged distance_proj -> compare directly + skip = {"tensor_embedding.distance_proj.weight", "tensor_embedding.distance_proj.bias"} + + all_match &= compare_tensors( + "weight", + ref_sd["tensor_embedding.distance_proj.weight"], + cur_sd["tensor_embedding.distance_proj.weight"], + ) + all_match &= compare_tensors( + "bias", + ref_sd["tensor_embedding.distance_proj.bias"], + cur_sd["tensor_embedding.distance_proj.bias"], + ) + else: + print(" WARNING: distance_proj not found in reference model") + + print("\n--- Other Parameters ---") + + for key in sorted(cur_sd): + if "distance_proj" in key: + continue + if key in ref_sd: + all_match &= compare_tensors(key, ref_sd[key], cur_sd[key]) + else: + print(f" {key}: NOT IN REFERENCE") + + for key in sorted(ref_sd): + if key not in skip and key not in cur_sd and "distance_proj" not in key: + print(f" {key}: IN REFERENCE ONLY") + all_match = False + + print(f"\n{'=' * 70}\nResult: {'ALL MATCH' if all_match else 'MISMATCH'}") + return all_match + + +def compare_forward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare forward pass energy predictions for batched graphs.""" + print_section("Forward Pass (Batched)") + + ref_model.eval() + cur_model.eval() + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) + + ref_e = ref_model(g=ref_graph, state_attr=state_attr) + cur_e = cur_model(g=cur_graph, state_attr=state_attr) + + print(f"Reference energies: {ref_e.detach().cpu().numpy()}") + print(f"Current energies: {cur_e.detach().cpu().numpy()}") + + diff = (ref_e - cur_e).abs() + print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") + + match = 
diff.max().item() < 1e-5 + print(f"Result: {'PASS' if match else 'FAIL'}") + return match + + +def compare_backward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare forces (F = -dE/dpos) for batched graphs.""" + print_section("Backward Pass (Forces, Batched)") + + ref_model.train() + cur_model.train() + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) + + def get_forces(model, graph): + energy = model(g=graph, state_attr=state_attr) + # Sum energies to get scalar for gradient + total_energy = energy.sum() + return -torch.autograd.grad(total_energy, graph.pos, create_graph=True)[0] + + ref_f = get_forces(ref_model, ref_graph) + cur_f = get_forces(cur_model, cur_graph) + + print(f"Reference: mean={ref_f.mean():.6f}, std={ref_f.std():.6f}") + print(f"Current: mean={cur_f.mean():.6f}, std={cur_f.std():.6f}") + + diff = (ref_f - cur_f).abs() + print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") + + match = diff.max().item() < 1e-5 + print(f"Result: {'PASS' if match else 'FAIL'}") + return match + + +def compare_double_backward( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare position gradients via loss = sum(forces^2) for batched graphs.""" + print_section("Double Backward (Position Gradients, Batched)") + + ref_model.train() + cur_model.train() + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) + + ref_graph.pos.retain_grad() + cur_graph.pos.retain_grad() + + # Reference + ref_energy = ref_model(g=ref_graph, state_attr=state_attr) + ref_total_energy = ref_energy.sum() + ref_forces = torch.autograd.grad(ref_total_energy, ref_graph.pos, create_graph=True)[0] + ref_loss = (ref_forces * ref_forces).sum() + ref_loss.backward() + ref_pos_grad = ref_graph.pos.grad.clone() + + # Current + cur_energy = cur_model(g=cur_graph, state_attr=state_attr) 
+ cur_total_energy = cur_energy.sum() + cur_forces = torch.autograd.grad(cur_total_energy, cur_graph.pos, create_graph=True)[0] + cur_loss = (cur_forces * cur_forces).sum() + cur_loss.backward() + cur_pos_grad = cur_graph.pos.grad.clone() + + forces_diff = (ref_forces - cur_forces).abs() + print(f"Forces: max_diff={forces_diff.max():.2e}, mean_diff={forces_diff.mean():.2e}") + + print(f"Reference pos.grad: mean={ref_pos_grad.mean():.6f}, std={ref_pos_grad.std():.6f}") + print(f"Current pos.grad: mean={cur_pos_grad.mean():.6f}, std={cur_pos_grad.std():.6f}") + + if ref_pos_grad.abs().max() < 1e-10 or cur_pos_grad.abs().max() < 1e-10: + print("WARNING: Position gradient is nearly zero") + + diff = (ref_pos_grad - cur_pos_grad).abs() + print(f"Diff: max={diff.max():.2e}, mean={diff.mean():.2e}") + + match = diff.max().item() < 1e-4 + print(f"Result: {'PASS' if match else 'FAIL'}") + return match + + +def compare_param_gradients( + ref_model: Any, cur_model: Any, ref_graph: Any, cur_graph: Any, device: torch.device, batch_size: int = BATCH_SIZE +) -> bool: + """Compare gradients on all model parameters after double backward (forces loss).""" + print_section("Parameter Gradients (Double Backward, Batched)") + + ref_model.train() + cur_model.train() + state_attr = torch.tensor([[0.0, 0.0]] * batch_size, device=device) + + # Zero gradients + ref_model.zero_grad() + cur_model.zero_grad() + + # Double backward: compute forces, then loss = sum(forces^2) + # Reference + ref_energy = ref_model(g=ref_graph, state_attr=state_attr) + ref_total_energy = ref_energy.sum() + ref_forces = torch.autograd.grad(ref_total_energy, ref_graph.pos, create_graph=True)[0] + ref_loss = (ref_forces * ref_forces).sum() + ref_loss.backward() + + # Current + cur_energy = cur_model(g=cur_graph, state_attr=state_attr) + cur_total_energy = cur_energy.sum() + cur_forces = torch.autograd.grad(cur_total_energy, cur_graph.pos, create_graph=True)[0] + cur_loss = (cur_forces * cur_forces).sum() + 
cur_loss.backward() + + print(f"Reference loss: {ref_loss.item():.6f}") + print(f"Current loss: {cur_loss.item():.6f}") + + # Build mapping for distance_proj layers (merged in current, separate in reference) + ref_sd = {k: p for k, p in ref_model.named_parameters()} + cur_sd = {k: p for k, p in cur_model.named_parameters()} + + all_match = True + max_diff_overall = 0.0 + mismatched_params = [] + + # Handle merged distance_proj layers + print("--- distance_proj (merged) ---") + dp_keys = [f"tensor_embedding.distance_proj{i}" for i in range(1, 4)] + skip_ref_keys = set() + skip_cur_keys = set() + for suffix in [".weight", ".bias"]: + ref_grads = [] + for dp_key in dp_keys: + key = dp_key + suffix + if key in ref_sd: + skip_ref_keys.add(key) + if ref_sd[key].grad is not None: + ref_grads.append(ref_sd[key].grad) + + cur_key = "tensor_embedding.distance_proj" + suffix + skip_cur_keys.add(cur_key) + + if not ref_grads: + # Reference doesn't have separate distance_proj layers, compare directly + ref_key = cur_key + if ref_key in ref_sd: + ref_param = ref_sd[ref_key] + cur_param = cur_sd.get(cur_key) + if ref_param.grad is None and (cur_param is None or cur_param.grad is None): + print(f" distance_proj{suffix}: NO GRAD (both)") + elif ref_param.grad is None: + print(f" distance_proj{suffix}: NO GRAD (reference)") + all_match = False + elif cur_param is None or cur_param.grad is None: + print(f" distance_proj{suffix}: NO GRAD (current)") + all_match = False + else: + diff = (ref_param.grad - cur_param.grad).abs() + max_diff = diff.max().item() + max_diff_overall = max(max_diff_overall, max_diff) + if max_diff > 5e-5: + mismatched_params.append(f"distance_proj{suffix}") + all_match = False + print(f" distance_proj{suffix}: DIFF (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: MATCH (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: NOT FOUND IN REFERENCE") + else: + # Reference has separate layers, concatenate and compare + ref_grad = 
torch.cat(ref_grads, dim=0) + if cur_key in cur_sd and cur_sd[cur_key].grad is not None: + cur_grad = cur_sd[cur_key].grad + if ref_grad.shape == cur_grad.shape: + diff = (ref_grad - cur_grad).abs() + max_diff = diff.max().item() + max_diff_overall = max(max_diff_overall, max_diff) + if max_diff > 5e-5: + mismatched_params.append(f"distance_proj{suffix}") + all_match = False + print(f" distance_proj{suffix}: DIFF (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: MATCH (max={max_diff:.2e})") + else: + print(f" distance_proj{suffix}: SHAPE MISMATCH {ref_grad.shape} vs {cur_grad.shape}") + all_match = False + else: + print(f" distance_proj{suffix}: NO GRAD (current)") + all_match = False + + # Compare other parameters + print("\n--- Other Parameters ---") + for cur_key, cur_param in cur_sd.items(): + if "distance_proj" in cur_key: + continue + + if cur_key in ref_sd: + ref_param = ref_sd[cur_key] + if ref_param.grad is None and cur_param.grad is None: + print(f" {cur_key}: NO GRAD (both)") + continue + if ref_param.grad is None: + print(f" {cur_key}: NO GRAD (reference)") + all_match = False + continue + if cur_param.grad is None: + print(f" {cur_key}: NO GRAD (current)") + all_match = False + continue + + if ref_param.grad.shape != cur_param.grad.shape: + print(f" {cur_key}: SHAPE MISMATCH {ref_param.grad.shape} vs {cur_param.grad.shape}") + all_match = False + continue + + diff = (ref_param.grad - cur_param.grad).abs() + max_diff = diff.max().item() + max_diff_overall = max(max_diff_overall, max_diff) + + if max_diff > 5e-5: + mismatched_params.append(cur_key) + all_match = False + print(f" {cur_key}: DIFF (max={max_diff:.2e}, mean={diff.mean():.2e})") + else: + print(f" {cur_key}: MATCH (max={max_diff:.2e})") + else: + print(f" {cur_key}: NOT IN REFERENCE") + + # Check for params in reference only + for ref_key in ref_sd: + if ref_key not in skip_ref_keys and ref_key not in cur_sd: + print(f" {ref_key}: IN REFERENCE ONLY") + + print(f"\nMax diff 
overall: {max_diff_overall:.2e}") + if mismatched_params: + print(f"Mismatched params: {mismatched_params}") + + print(f"Result: {'PASS' if all_match else 'FAIL'}") + return all_match + + +def main(structure_path: str, matgl_main_path: str, seed: int = 42, pretrained_path: str | None = None) -> bool: + """Run all comparison tests between reference and current implementations.""" + print_section("TensorNet Comparison: matgl-main vs Current") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Seed: {seed}, Device: {device}, Batch size: {BATCH_SIZE}") + print(f"matgl-main path: {matgl_main_path}") + if pretrained_path: + print(f"Pretrained model: {pretrained_path}") + + structure = load_structure(structure_path) + element_types = get_element_types(structure) + print(f"Structure: {structure_path} ({len(structure)} atoms, elements: {element_types})") + + # Reference model (matgl-main) + clear_matgl_modules() + sys.path.insert(0, matgl_main_path) + + from matgl.ext._pymatgen_pyg import Structure2Graph as RefConverter + from matgl.graph._compute_pyg import compute_pair_vector_and_distance as ref_compute_bond + from matgl.models._tensornet_pyg import TensorNet as RefTensorNet + from matgl.utils.io import load_model as ref_load_model + + if pretrained_path: + # Load pre-trained model (Potential wrapper contains TensorNet) + ref_potential = ref_load_model(pretrained_path) + ref_model = ref_potential.model.to(device) + ref_cutoff = ref_model.cutoff + ref_element_types = ref_model.element_types + else: + model_config = {**MODEL_CONFIG, "element_types": element_types} + torch.manual_seed(seed) + ref_model = RefTensorNet(**model_config).to(device) + ref_cutoff = MODEL_CONFIG["cutoff"] + ref_element_types = element_types + + ref_converter = RefConverter(element_types=ref_element_types, cutoff=ref_cutoff) + + # Build batched graphs for reference model + ref_graph = build_batched_graph(ref_converter, structure, device, ref_compute_bond) + 
ref_graph_grad = build_batched_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + ref_graph_grad2 = build_batched_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + ref_graph_param = build_batched_graph(ref_converter, structure, device, ref_compute_bond, requires_grad=True) + + sys.path.pop(0) + + # Current model (src) + clear_matgl_modules() + + from matgl.ext._pymatgen_pyg import Structure2Graph as CurConverter + from matgl.models._tensornet_pyg import TensorNet as CurTensorNet + from matgl.utils.io import load_model as cur_load_model + + if pretrained_path: + # Load pre-trained model (Potential wrapper contains TensorNet) + cur_potential = cur_load_model(pretrained_path) + cur_model = cur_potential.model.to(device) + cur_cutoff = cur_model.cutoff + cur_element_types = cur_model.element_types + else: + model_config = {**MODEL_CONFIG, "element_types": element_types} + torch.manual_seed(seed) + cur_model = CurTensorNet(**model_config).to(device) + cur_cutoff = MODEL_CONFIG["cutoff"] + cur_element_types = element_types + + cur_converter = CurConverter(element_types=cur_element_types, cutoff=cur_cutoff) + + # Build batched graphs for current model + cur_graph = build_batched_graph(cur_converter, structure, device) + cur_graph_grad = build_batched_graph(cur_converter, structure, device, requires_grad=True) + cur_graph_grad2 = build_batched_graph(cur_converter, structure, device, requires_grad=True) + cur_graph_param = build_batched_graph(cur_converter, structure, device, requires_grad=True) + + print(f"Models: {sum(p.numel() for p in ref_model.parameters())} params each") + print(f"Batched graph: {ref_graph.num_nodes} nodes, {ref_graph.num_edges} edges") + + # Run comparisons + results = { + "Weights": compare_weights(ref_model, cur_model), + "Forward": compare_forward(ref_model, cur_model, ref_graph, cur_graph, device), + "Backward": compare_backward(ref_model, cur_model, ref_graph_grad, cur_graph_grad, 
device), + "Double Backward": compare_double_backward(ref_model, cur_model, ref_graph_grad2, cur_graph_grad2, device), + "Param Gradients": compare_param_gradients(ref_model, cur_model, ref_graph_param, cur_graph_param, device), + } + + # Summary + print_section("SUMMARY") + all_pass = all(results.values()) + for name, passed in results.items(): + print(f" {name}: {'PASS' if passed else 'FAIL'}") + + print(f"\n{'=' * 70}") + print("ALL TESTS PASSED" if all_pass else "SOME TESTS FAILED") + print("=" * 70) + + assert all_pass, "Model comparison tests failed" + return all_pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Compare TensorNet implementations") + parser.add_argument( + "--structure", + "-s", + required=True, + help="Path to structure file (any format supported by pymatgen)", + ) + parser.add_argument( + "--matgl-main-path", + default=os.environ.get("MATGL_MAIN_PATH", DEFAULT_MATGL_MAIN_PATH), + help="Path to matgl-main/src (default: $MATGL_MAIN_PATH or ../matgl-main/src)", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument( + "--pretrained", + "-p", + default=None, + help="Path to pretrained model directory (e.g., pretrained_models/TensorNet-MatPES-PBE-v2025.1-PES)", + ) + + args = parser.parse_args() + main( + structure_path=args.structure, + matgl_main_path=args.matgl_main_path, + seed=args.seed, + pretrained_path=args.pretrained, + ) diff --git a/pyproject.toml b/pyproject.toml index 90635ca3..dbb181b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ classifiers = [ dependencies = [ "ase", "torch<=2.7.0", # TODO: Remove this pin. For some reason, torch 2.9 gives different results. 
+ "warp-lang>=1.10.1", + "torchdata", + "pymatgen", + "lightning<=2.6.0.dev20251123", @@ -158,6 +159,8 @@ lint.isort.required-imports = ["from __future__ import annotations"] "tests/**/*" = ["D", "PERF"] "docs/**/*" = ["D"] "examples/**/*" = ["D"] +"src/matgl/kernels/*" = ["D100", "D103", "E741", "I002"] +"src/matgl/ops/*" = ["D100", "D103"] [tool.pytest.ini_options] addopts = "--durations=30 --quiet -rXs --color=yes -p no:warnings" @@ -171,6 +174,10 @@ exclude = ['examples', 'tests'] module = ["requests.*", "tabulate.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["matgl.kernels.*"] +ignore_errors = true + [tool.coverage.run] relative_files = true diff --git a/src/matgl/kernels/__init__.py b/src/matgl/kernels/__init__.py new file mode 100644 index 00000000..7367c920 --- /dev/null +++ b/src/matgl/kernels/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Warp GPU kernels for TensorNet operations.""" + +from __future__ import annotations + +import warp as wp + +from .compose_tensor import generate_compose_tensor +from .decompose_tensor import generate_decompose_tensor +from .equivariant_o3_matmul import generate_tensor_matmul_o3_3x3 +from .equivariant_so3_matmul import generate_tensor_matmul_so3_3x3 +from .graph_transform import convert_to_sparse, count_row_col +from .tensor_norm3 import generate_tensor_norm3 +from .tensornet_mp import generate_message_passing +from .tensornet_radial_mp import generate_radial_message_passing +from .utils import add_module, get_module, get_stream + +wp.init() + + +# NOTE(review): each public name is listed exactly once; the previous revision of this +# hunk repeated five entries, which is redundant for `from ... import *` resolution. +__all__ = [ + "add_module", + "convert_to_sparse", + "count_row_col", + "generate_compose_tensor", + "generate_decompose_tensor", + "generate_message_passing", + "generate_radial_message_passing", + "generate_tensor_matmul_o3_3x3", + "generate_tensor_matmul_so3_3x3", + "generate_tensor_norm3", + "get_module", + "get_stream", +] diff --git a/src/matgl/kernels/compose_tensor.py b/src/matgl/kernels/compose_tensor.py new file mode 100644 index 00000000..1ed84191 --- /dev/null +++ b/src/matgl/kernels/compose_tensor.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for composing 3x3 tensors from I, A, S components.""" + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_compose_tensor(dtype: str, h_last: bool = True, use_irmem: bool = True): + """Generate Warp kernels for composing a 3x3 tensor from I, A, S components.""" + dtype_wp = get_wp_fp_dtype(dtype) + if not use_irmem: + raise ValueError(f"only supporting use_irmem True, but got {use_irmem}") + if not h_last: + raise ValueError(f"only supporting h_last True but got {h_last}") + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + dim = 3 if use_irmem else 4 + + def compose_tensor_fwd( + I: wp.array(ndim=dim, dtype=dtype_wp), + A: wp.array(ndim=dim, dtype=dtype_wp), + S: wp.array(ndim=dim, dtype=dtype_wp), + X: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + X_reg = mat3x3() + + I_reg = I[b, 0, h] + A_reg = vec3() + S_reg = vec5() + + for i in range(3): + A_reg[i] = A[b, i, h] + + for i in range(5): + S_reg[i] = S[b, i, h] + + for i in range(3): + X_reg[i, i] += I_reg + + cnt = wp.int32(0) + for i in range(3): + for j in range(i + 1, 3): + X_reg[i, j] += A_reg[cnt] + X_reg[j, i] -= A_reg[cnt] + cnt += 1 + + trace_S = -(S_reg[0] + S_reg[3]) + cnt = wp.int32(0) + for i in range(2): + X_reg[i, i] += S_reg[cnt] + cnt += 1 + for j in range(i + 1, 3): + X_reg[i, j] += S_reg[cnt] + X_reg[j, i] += S_reg[cnt] + cnt += 1 + + X_reg[2, 2] += trace_S + + for i in range(3): + for j in range(3): + X[b, i, j, h] = X_reg[i, j] + + def compose_tensor_bwd( + dX: wp.array(ndim=4, dtype=dtype_wp), + dI: wp.array(ndim=dim, dtype=dtype_wp), + dA: wp.array(ndim=dim, dtype=dtype_wp), + dS: wp.array(ndim=dim, dtype=dtype_wp), + ): + b, h = wp.tid() + + dX_reg = mat3x3() + for i in range(3): + for j in range(3): + dX_reg[i, j] = dX[b, i, j, h] + + dI_reg = dI.dtype(0) + dA_reg = 
vec3(dX.dtype(0)) + dS_reg = vec5(dX.dtype(0)) + + for i in range(3): + dI_reg += dX_reg[i, i] + + cnt = wp.int32(0) + for i in range(3): + for j in range(i + 1, 3): + dA_reg[cnt] += dX_reg[i, j] + dA_reg[cnt] -= dX_reg[j, i] + cnt += 1 + + dS_reg[0] += dX_reg[0, 0] + dS_reg[0] -= dX_reg[2, 2] + + dS_reg[1] += dX_reg[0, 1] + dS_reg[1] += dX_reg[1, 0] + + dS_reg[2] += dX_reg[0, 2] + dS_reg[2] += dX_reg[2, 0] + + dS_reg[3] += dX_reg[1, 1] + dS_reg[3] -= dX_reg[2, 2] + + dS_reg[4] += dX_reg[1, 2] + dS_reg[4] += dX_reg[2, 1] + + dI[b, 0, h] = dI_reg + + for i in range(3): + dA[b, i, h] = dA_reg[i] + + for i in range(5): + dS[b, i, h] = dS_reg[i] + + def compose_tensor_bwd_bwd( + dI: wp.array(ndim=dim, dtype=dtype_wp), + dA: wp.array(ndim=dim, dtype=dtype_wp), + dS: wp.array(ndim=dim, dtype=dtype_wp), + d2X: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + d2X_reg = mat3x3() + + dI_reg = dI[b, 0, h] + dA_reg = vec3(dI.dtype(0)) + dS_reg = vec5(dI.dtype(0)) + + for i in range(3): + dA_reg[i] = dA[b, i, h] + + for i in range(5): + dS_reg[i] = dS[b, i, h] + + for i in range(3): + d2X_reg[i, i] += dI_reg + + cnt = wp.int32(0) + for i in range(3): + for j in range(i + 1, 3): + d2X_reg[i, j] += dA_reg[cnt] + d2X_reg[j, i] -= dA_reg[cnt] + cnt += 1 + + cnt = wp.int32(0) + for i in range(2): + d2X_reg[i, i] += dS_reg[cnt] + cnt += 1 + + for j in range(i + 1, 3): + d2X_reg[i, j] += dS_reg[cnt] + d2X_reg[j, i] += dS_reg[cnt] + cnt += 1 + + d2X_reg[2, 2] -= dS_reg[0] + d2X_reg[2, 2] -= dS_reg[3] + + for i in range(3): + for j in range(3): + d2X[b, i, j, h] = d2X_reg[i, j] + + return ( + wp.Kernel( + compose_tensor_fwd, + key=f"compose_tensor_{dtype}", + module=wp.get_module(f"compose_tensor_{dtype}"), + ), + wp.Kernel( + compose_tensor_bwd, + key=f"compose_tensor_bwd_{dtype}", + module=wp.get_module(f"compose_tensor_bwd_{dtype}"), + ), + wp.Kernel( + compose_tensor_bwd_bwd, + key=f"compose_tensor_bwd_bwd_{dtype}", + 
module=wp.get_module(f"compose_tensor_bwd_bwd_{dtype}"), + ), + ) + + +( + compose_tensor_fwd_fp64, + compose_tensor_bwd_fp64, + compose_tensor_bwd_bwd_fp64, +) = generate_compose_tensor("float64") +( + compose_tensor_fwd_fp32, + compose_tensor_bwd_fp32, + compose_tensor_bwd_bwd_fp32, +) = generate_compose_tensor("float32") +( + compose_tensor_fwd_fp16, + compose_tensor_bwd_fp16, + compose_tensor_bwd_bwd_fp16, +) = generate_compose_tensor("float16") + +add_module("compose_tensor_fwd", ["float64"], compose_tensor_fwd_fp64) +add_module("compose_tensor_bwd", ["float64"], compose_tensor_bwd_fp64) +add_module("compose_tensor_bwd_bwd", ["float64"], compose_tensor_bwd_bwd_fp64) + +add_module("compose_tensor_fwd", ["float32"], compose_tensor_fwd_fp32) +add_module("compose_tensor_bwd", ["float32"], compose_tensor_bwd_fp32) +add_module("compose_tensor_bwd_bwd", ["float32"], compose_tensor_bwd_bwd_fp32) + +add_module("compose_tensor_fwd", ["float16"], compose_tensor_fwd_fp16) +add_module("compose_tensor_bwd", ["float16"], compose_tensor_bwd_fp16) +add_module("compose_tensor_bwd_bwd", ["float16"], compose_tensor_bwd_bwd_fp16) diff --git a/src/matgl/kernels/decompose_tensor.py b/src/matgl/kernels/decompose_tensor.py new file mode 100644 index 00000000..34652fc5 --- /dev/null +++ b/src/matgl/kernels/decompose_tensor.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Warp kernels for decomposing 3x3 tensors into I, A, S components."""

import warp as wp

from .utils import add_module, get_wp_fp_dtype


def generate_decompose_tensor(dtype: str, h_last: bool = True, use_irmem: bool = True):
    """Generate Warp kernels for decomposing a 3x3 tensor into I, A, S components.

    The decomposition is X = I + A + S with
      I = (tr X / 3) * Id        -> stored as 1 scalar per (batch, hidden),
      A = (X - X^T) / 2          -> stored as the 3 upper-triangle entries,
      S = (X + X^T) / 2 - I      -> stored as 5 scalars (the (2, 2) diagonal
                                    entry is implicit via tracelessness).

    Args:
        dtype: floating-point dtype name ("float64", "float32" or "float16").
        h_last: must be True; the hidden/channel axis is the last array axis.
        use_irmem: must be True; I/A/S use the compact [batch, component,
            hidden] layout (3-D arrays) rather than full 3x3 form.

    Returns:
        Tuple of (forward, backward, double-backward) ``wp.Kernel`` objects,
        each launched with a 2-D (batch, hidden) thread grid.
    """
    dtype_wp = get_wp_fp_dtype(dtype)

    if not use_irmem:
        raise ValueError(f"only supporting use_irmem True, but got {use_irmem}")

    if not h_last:
        raise ValueError(f"only supporting h_last True but got {h_last}")

    class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)):
        pass

    class vec3(wp.types.vector(length=3, dtype=dtype_wp)):
        pass

    class vec5(wp.types.vector(length=5, dtype=dtype_wp)):
        pass

    # With use_irmem=True the component arrays are 3-D: [batch, comp, hidden].
    dim = 3 if use_irmem else 4

    def decompose_tensor_fwd(
        X: wp.array(ndim=4, dtype=dtype_wp),
        I: wp.array(ndim=dim, dtype=dtype_wp),
        A: wp.array(ndim=dim, dtype=dtype_wp),
        S: wp.array(ndim=dim, dtype=dtype_wp),
    ):
        # One thread per (batch, hidden) pair; X is laid out [b, 3, 3, h].
        b, h = wp.tid()

        X_reg = mat3x3()
        for i in range(3):
            for j in range(3):
                X_reg[i, j] = X[b, i, j, h]

        # Isotropic component: trace / 3.
        res = X.dtype(0)
        for i in range(3):
            res += X_reg[i, i]
        res = res / X.dtype(3.0)

        I[b, 0, h] = res

        # Antisymmetric component: upper triangle of (X - X^T) / 2,
        # in order (0,1), (0,2), (1,2).
        denom = X.dtype(2.0)
        cnt = wp.int32(0)
        for i in range(2):
            for j in range(i + 1, 3):
                A[b, cnt, h] = (X_reg[i, j] - X_reg[j, i]) / denom
                cnt += 1

        # Symmetric-traceless component: deviatoric diagonal entries (0,0) and
        # (1,1) interleaved with the symmetrized upper-triangle entries, so the
        # order is S00, S01, S02, S11, S12; S22 = -(S[0] + S[3]) is implicit.
        cnt = wp.int32(0)
        for i in range(2):
            S[b, cnt, h] = X_reg[i, i] - res
            cnt += 1

            for j in range(i + 1, 3):
                S[b, cnt, h] = (X_reg[i, j] + X_reg[j, i]) / denom
                cnt += 1

    def decompose_tensor_bwd(
        dI: wp.array(ndim=dim, dtype=dtype_wp),
        dA: wp.array(ndim=dim, dtype=dtype_wp),
        dS: wp.array(ndim=dim, dtype=dtype_wp),
        dX: wp.array(ndim=4, dtype=dtype_wp),
    ):
        # Vector-Jacobian product of the forward decomposition:
        # dX = dI * d(I)/dX + dA * d(A)/dX + dS * d(S)/dX.
        b, h = wp.tid()

        dX_reg = mat3x3(dX.dtype(0))

        dI_reg = dI[b, 0, h]
        dA_reg = vec3(dX.dtype(0))
        dS_reg = vec5(dX.dtype(0))

        for i in range(3):
            dA_reg[i] = dA[b, i, h]

        for i in range(5):
            dS_reg[i] = dS[b, i, h]

        # d(trace/3)/dX puts 1/3 on every diagonal entry.
        for i in range(3):
            dX_reg[i, i] = dI_reg / dI.dtype(3.0)

        denom = dX.dtype(2.0)

        cnt = wp.int32(0)

        # Adjoint of A[cnt] = (X[i,j] - X[j,i]) / 2.  (range(3) visits the same
        # three (i, j) pairs as the forward's range(2) — i == 2 yields no j.)
        for i in range(3):
            for j in range(i + 1, 3):
                dX_reg[i, j] += dA_reg[cnt] / denom
                dX_reg[j, i] -= dA_reg[cnt] / denom
                cnt += 1

        cnt = wp.int32(0)
        for i in range(2):
            # Adjoint of the deviatoric diagonal S[cnt] = X[i,i] - trace/3.
            dX_reg[i, i] += dS_reg[cnt]
            for j in range(3):
                dX_reg[j, j] -= dS_reg[cnt] / dI.dtype(3.0)

            cnt += 1

            # Adjoint of the symmetrized off-diagonal entries.
            for j in range(i + 1, 3):
                dX_reg[i, j] += dS_reg[cnt] / denom
                dX_reg[j, i] += dS_reg[cnt] / denom
                cnt += 1

        for i in range(3):
            for j in range(3):
                dX[b, i, j, h] = dX_reg[i, j]

    def decompose_tensor_bwd_bwd(
        dX: wp.array(ndim=4, dtype=dtype_wp),
        d2I: wp.array(ndim=dim, dtype=dtype_wp),
        d2A: wp.array(ndim=dim, dtype=dtype_wp),
        d2S: wp.array(ndim=dim, dtype=dtype_wp),
    ):
        # Double backward: the decomposition is linear, so this is just the
        # forward map applied to the incoming gradient-of-gradient dX.
        b, h = wp.tid()

        dX_reg = mat3x3(dX.dtype(0))
        d2I_reg = dX.dtype(0)
        d2A_reg = vec3(dX.dtype(0))
        d2S_reg = vec5(dX.dtype(0))

        for i in range(3):
            for j in range(3):
                dX_reg[i, j] = dX[b, i, j, h]

        for i in range(3):
            d2I_reg += dX_reg[i, i] / d2I.dtype(3.0)

        denom = dX.dtype(2.0)

        cnt = wp.int32(0)
        for i in range(3):
            for j in range(i + 1, 3):
                d2A_reg[cnt] += dX_reg[i, j] / denom
                d2A_reg[cnt] -= dX_reg[j, i] / denom
                cnt += 1

        cnt = wp.int32(0)
        for i in range(2):
            d2S_reg[cnt] += dX_reg[i, i]
            for j in range(3):
                d2S_reg[cnt] -= dX_reg[j, j] / d2I.dtype(3.0)
            cnt += 1

            for j in range(i + 1, 3):
                d2S_reg[cnt] += dX_reg[i, j] / denom
                d2S_reg[cnt] += dX_reg[j, i] / denom
                cnt += 1

        d2I[b, 0, h] = d2I_reg
        for i in range(3):
            d2A[b, i, h] = d2A_reg[i]

        for i in range(5):
            d2S[b, i, h] = d2S_reg[i]

    # Each kernel lives in its own Warp module so dtypes compile independently.
    return (
        wp.Kernel(
            decompose_tensor_fwd,
            key=f"decompose_tensor_{dtype}",
            module=wp.get_module(f"decompose_tensor_{dtype}"),
        ),
        wp.Kernel(
            decompose_tensor_bwd,
            key=f"decompose_tensor_bwd_{dtype}",
            module=wp.get_module(f"decompose_tensor_bwd_{dtype}"),
        ),
        wp.Kernel(
            decompose_tensor_bwd_bwd,
            key=f"decompose_tensor_bwd_bwd_{dtype}",
            module=wp.get_module(f"decompose_tensor_bwd_bwd_{dtype}"),
        ),
    )


# Instantiate and register one (fwd, bwd, bwd_bwd) kernel triple per dtype.
decompose_tensor_fwd_fp64, decompose_tensor_bwd_fp64, decompose_tensor_bwd_bwd_fp64 = generate_decompose_tensor(
    "float64"
)
decompose_tensor_fwd_fp32, decompose_tensor_bwd_fp32, decompose_tensor_bwd_bwd_fp32 = generate_decompose_tensor(
    "float32"
)
decompose_tensor_fwd_fp16, decompose_tensor_bwd_fp16, decompose_tensor_bwd_bwd_fp16 = generate_decompose_tensor(
    "float16"
)

add_module("decompose_tensor_fwd", ["float64"], decompose_tensor_fwd_fp64)
add_module("decompose_tensor_bwd", ["float64"], decompose_tensor_bwd_fp64)
add_module("decompose_tensor_bwd_bwd", ["float64"], decompose_tensor_bwd_bwd_fp64)

add_module("decompose_tensor_fwd", ["float32"], decompose_tensor_fwd_fp32)
add_module("decompose_tensor_bwd", ["float32"], decompose_tensor_bwd_fp32)
add_module("decompose_tensor_bwd_bwd", ["float32"], decompose_tensor_bwd_bwd_fp32)

add_module("decompose_tensor_fwd", ["float16"],
decompose_tensor_fwd_fp16) +add_module("decompose_tensor_bwd", ["float16"], decompose_tensor_bwd_fp16) +add_module("decompose_tensor_bwd_bwd", ["float16"], decompose_tensor_bwd_bwd_fp16) diff --git a/src/matgl/kernels/equivariant_o3_matmul.py b/src/matgl/kernels/equivariant_o3_matmul.py new file mode 100644 index 00000000..a7d36a87 --- /dev/null +++ b/src/matgl/kernels/equivariant_o3_matmul.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Warp kernels for O(3)-equivariant 3x3 tensor matrix multiplication."""

import warp as wp

from .utils import add_module, get_wp_fp_dtype


def generate_tensor_matmul_o3_3x3(dtype: str):
    """Generate Warp kernels for O(3)-equivariant 3x3 matrix multiplication: C = AB + BA.

    The anticommutator AB + BA preserves O(3) equivariance (including parity),
    whereas the plain product AB only preserves SO(3) equivariance.

    Args:
        dtype: floating-point dtype name ("float64", "float32" or "float16").

    Returns:
        Tuple of (forward, backward, double-backward) ``wp.Kernel`` objects,
        each launched with a 2-D (batch, hidden) thread grid over arrays laid
        out as [batch, 3, 3, hidden].
    """
    dtype_wp = get_wp_fp_dtype(dtype)

    class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)):
        pass

    def tensor_matmul_o3_3x3_fwd(
        A: wp.array(ndim=4, dtype=dtype_wp),
        B: wp.array(ndim=4, dtype=dtype_wp),
        C: wp.array(ndim=4, dtype=dtype_wp),
    ):
        b, h = wp.tid()

        a_reg = mat3x3()
        b_reg = mat3x3()
        c_reg = mat3x3()

        for i in range(3):
            for j in range(3):
                a_reg[i, j] = A[b, i, j, h]
                b_reg[i, j] = B[b, i, j, h]

        # C = A @ B + B @ A (anticommutator).
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    c_reg[i, j] += a_reg[i, k] * b_reg[k, j] + b_reg[i, k] * a_reg[k, j]

        for i in range(3):
            for j in range(3):
                C[b, i, j, h] = c_reg[i, j]

    def tensor_matmul_o3_3x3_bwd(
        A: wp.array(ndim=4, dtype=dtype_wp),
        B: wp.array(ndim=4, dtype=dtype_wp),
        dC: wp.array(ndim=4, dtype=dtype_wp),
        dA: wp.array(ndim=4, dtype=dtype_wp),
        dB: wp.array(ndim=4, dtype=dtype_wp),
    ):
        # Adjoint of C = AB + BA:
        #   dA = dC @ B^T + B^T @ dC,  dB = dC @ A^T + A^T @ dC.
        # Each of the four += lines below accumulates one of those terms.
        b, h = wp.tid()

        a_reg = mat3x3()
        b_reg = mat3x3()

        da_reg = mat3x3()
        db_reg = mat3x3()

        dc_reg = mat3x3()

        for i in range(3):
            for j in range(3):
                a_reg[i, j] = A[b, i, j, h]
                b_reg[i, j] = B[b, i, j, h]
                dc_reg[i, j] = dC[b, i, j, h]

        for i in range(3):
            for j in range(3):
                for k in range(3):
                    da_reg[i, j] += dc_reg[i, k] * b_reg[j, k]
                    da_reg[j, k] += dc_reg[i, k] * b_reg[i, j]
                    db_reg[i, j] += dc_reg[i, k] * a_reg[j, k]
                    db_reg[j, k] += dc_reg[i, k] * a_reg[i, j]

        for i in range(3):
            for j in range(3):
                dA[b, i, j, h] = da_reg[i, j]
                dB[b, i, j, h] = db_reg[i, j]

    def tensor_matmul_o3_3x3_bwd_bwd(
        A: wp.array(ndim=4, dtype=dtype_wp),
        B: wp.array(ndim=4, dtype=dtype_wp),
        dA: wp.array(ndim=4, dtype=dtype_wp),
        dB: wp.array(ndim=4, dtype=dtype_wp),
        dC: wp.array(ndim=4, dtype=dtype_wp),
        d2A: wp.array(ndim=4, dtype=dtype_wp),
        d2B: wp.array(ndim=4, dtype=dtype_wp),
        d2C: wp.array(ndim=4, dtype=dtype_wp),
    ):
        # Double backward: dA/dB here carry the incoming gradients of the
        # first-order gradients; d2A/d2B/d2C are their pushforwards through
        # the (bilinear) forward and backward maps.
        b, h = wp.tid()

        a_reg = mat3x3()
        b_reg = mat3x3()

        da_reg = mat3x3()
        db_reg = mat3x3()

        dc_reg = mat3x3()

        d2a_reg = mat3x3()
        d2b_reg = mat3x3()

        d2c_reg = mat3x3()

        for i in range(3):
            for j in range(3):
                a_reg[i, j] = A[b, i, j, h]
                b_reg[i, j] = B[b, i, j, h]

                da_reg[i, j] = dA[b, i, j, h]
                db_reg[i, j] = dB[b, i, j, h]

                dc_reg[i, j] = dC[b, i, j, h]

        for i in range(3):
            for j in range(3):
                for k in range(3):
                    d2a_reg[i, j] += dc_reg[i, k] * db_reg[j, k]
                    d2a_reg[j, i] += dc_reg[k, i] * db_reg[k, j]

                    d2b_reg[i, j] += dc_reg[i, k] * da_reg[j, k]
                    d2b_reg[j, i] += dc_reg[k, i] * da_reg[k, j]

        for i in range(3):
            for j in range(3):
                for k in range(3):
                    # grad_grad_x @ y + x @ grad_grad_y
                    d2c_reg[i, j] += da_reg[i, k] * b_reg[k, j]
                    d2c_reg[i, j] += a_reg[i, k] * db_reg[k, j]

                    d2c_reg[i, j] += db_reg[i, k] * a_reg[k, j]
                    d2c_reg[i, j] += b_reg[i, k] * da_reg[k, j]

        for i in range(3):
            for j in range(3):
                d2A[b, i, j, h] = d2a_reg[i, j]
                d2B[b, i, j, h] = d2b_reg[i, j]
                d2C[b, i, j, h] = d2c_reg[i, j]

    # Each kernel lives in its own Warp module so dtypes compile independently.
    return (
        wp.Kernel(
            tensor_matmul_o3_3x3_fwd,
            key=f"tensor_matmul_o3_3x3_{dtype}",
            module=wp.get_module(f"tensor_matmul_o3_3x3_{dtype}"),
        ),
        wp.Kernel(
            tensor_matmul_o3_3x3_bwd,
            key=f"tensor_matmul_o3_3x3_bwd_{dtype}",
            module=wp.get_module(f"tensor_matmul_o3_3x3_bwd_{dtype}"),
        ),
        wp.Kernel(
            tensor_matmul_o3_3x3_bwd_bwd,
            key=f"tensor_matmul_o3_3x3_bwd_bwd_{dtype}",
            module=wp.get_module(f"tensor_matmul_o3_3x3_bwd_bwd_{dtype}"),
        ),
    )


# Instantiate and register one (fwd, bwd, bwd_bwd) kernel triple per dtype.
(
    tensor_matmul_o3_3x3_fwd_fp64,
    tensor_matmul_o3_3x3_bwd_fp64,
    tensor_matmul_o3_3x3_bwd_bwd_fp64,
) = generate_tensor_matmul_o3_3x3("float64")
(
    tensor_matmul_o3_3x3_fwd_fp32,
    tensor_matmul_o3_3x3_bwd_fp32,
    tensor_matmul_o3_3x3_bwd_bwd_fp32,
) = generate_tensor_matmul_o3_3x3("float32")
(
    tensor_matmul_o3_3x3_fwd_fp16,
    tensor_matmul_o3_3x3_bwd_fp16,
    tensor_matmul_o3_3x3_bwd_bwd_fp16,
) = generate_tensor_matmul_o3_3x3("float16")

add_module("tensor_matmul_o3_3x3_fwd", ["float64"], tensor_matmul_o3_3x3_fwd_fp64)
add_module("tensor_matmul_o3_3x3_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_fp64)
add_module("tensor_matmul_o3_3x3_bwd_bwd", ["float64"], tensor_matmul_o3_3x3_bwd_bwd_fp64)

add_module("tensor_matmul_o3_3x3_fwd", ["float32"], tensor_matmul_o3_3x3_fwd_fp32)
add_module("tensor_matmul_o3_3x3_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_fp32)
add_module("tensor_matmul_o3_3x3_bwd_bwd", ["float32"], tensor_matmul_o3_3x3_bwd_bwd_fp32)

add_module("tensor_matmul_o3_3x3_fwd", ["float16"], tensor_matmul_o3_3x3_fwd_fp16)
add_module("tensor_matmul_o3_3x3_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_fp16)
add_module("tensor_matmul_o3_3x3_bwd_bwd", ["float16"], tensor_matmul_o3_3x3_bwd_bwd_fp16)
diff --git a/src/matgl/kernels/equivariant_so3_matmul.py b/src/matgl/kernels/equivariant_so3_matmul.py
new file mode 100644
index 00000000..38ee5065
--- /dev/null
+++ b/src/matgl/kernels/equivariant_so3_matmul.py
@@ -0,0 +1,199 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for SO(3)-equivariant 3x3 tensor matrix multiplication.""" + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_tensor_matmul_so3_3x3(dtype: str): + """Generate Warp kernels for SO(3)-equivariant 3x3 matrix multiplication: C = AB.""" + dtype_wp = get_wp_fp_dtype(dtype) + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + def tensor_matmul_so3_3x3_fwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + C: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + c_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + c_reg[i, j] += a_reg[i, k] * b_reg[k, j] + + for i in range(3): + for j in range(3): + C[b, i, j, h] = c_reg[i, j] + + def tensor_matmul_so3_3x3_bwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + dC: wp.array(ndim=4, dtype=dtype_wp), + dA: wp.array(ndim=4, dtype=dtype_wp), + dB: wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + + da_reg = mat3x3() + db_reg = mat3x3() + + dc_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + dc_reg[i, j] = dC[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + da_reg[i, k] += dc_reg[i, j] * b_reg[k, j] + db_reg[k, j] += dc_reg[i, j] * a_reg[i, k] + + for i in range(3): + for j in range(3): + dA[b, i, j, h] = da_reg[i, j] + dB[b, i, j, h] = db_reg[i, j] + + def tensor_matmul_so3_3x3_bwd_bwd( + A: wp.array(ndim=4, dtype=dtype_wp), + B: wp.array(ndim=4, dtype=dtype_wp), + dA: wp.array(ndim=4, dtype=dtype_wp), + dB: wp.array(ndim=4, dtype=dtype_wp), + dC: wp.array(ndim=4, dtype=dtype_wp), + d2A: wp.array(ndim=4, dtype=dtype_wp), + d2B: wp.array(ndim=4, dtype=dtype_wp), + d2C: 
wp.array(ndim=4, dtype=dtype_wp), + ): + b, h = wp.tid() + + a_reg = mat3x3() + b_reg = mat3x3() + + da_reg = mat3x3() + db_reg = mat3x3() + + dc_reg = mat3x3() + + d2a_reg = mat3x3() + d2b_reg = mat3x3() + + d2c_reg = mat3x3() + + for i in range(3): + for j in range(3): + a_reg[i, j] = A[b, i, j, h] + b_reg[i, j] = B[b, i, j, h] + + da_reg[i, j] = dA[b, i, j, h] + db_reg[i, j] = dB[b, i, j, h] + + dc_reg[i, j] = dC[b, i, j, h] + + for i in range(3): + for j in range(3): + for k in range(3): + d2a_reg[i, k] += dc_reg[i, j] * db_reg[k, j] + d2b_reg[k, j] += dc_reg[i, j] * da_reg[i, k] + + for i in range(3): + for j in range(3): + for k in range(3): + d2c_reg[i, j] += da_reg[i, k] * b_reg[k, j] + d2c_reg[i, j] += a_reg[i, k] * db_reg[k, j] + + for i in range(3): + for j in range(3): + d2A[b, i, j, h] = d2a_reg[i, j] + d2B[b, i, j, h] = d2b_reg[i, j] + d2C[b, i, j, h] = d2c_reg[i, j] + + return ( + wp.Kernel( + tensor_matmul_so3_3x3_fwd, + key=f"tensor_matmul_so3_3x3_{dtype}", + module=wp.get_module(f"tensor_matmul_so3_3x3_{dtype}"), + ), + wp.Kernel( + tensor_matmul_so3_3x3_bwd, + key=f"tensor_matmul_so3_3x3_bwd_{dtype}", + module=wp.get_module(f"tensor_matmul_so3_3x3_bwd_{dtype}"), + ), + wp.Kernel( + tensor_matmul_so3_3x3_bwd_bwd, + key=f"tensor_matmul_so3_3x3_bwd_bwd_{dtype}", + module=wp.get_module(f"tensor_matmul_so3_3x3_bwd_bwd_{dtype}"), + ), + ) + + +( + tensor_matmul_so3_3x3_fwd_fp64, + tensor_matmul_so3_3x3_bwd_fp64, + tensor_matmul_so3_3x3_bwd_bwd_fp64, +) = generate_tensor_matmul_so3_3x3("float64") +( + tensor_matmul_so3_3x3_fwd_fp32, + tensor_matmul_so3_3x3_bwd_fp32, + tensor_matmul_so3_3x3_bwd_bwd_fp32, +) = generate_tensor_matmul_so3_3x3("float32") +( + tensor_matmul_so3_3x3_fwd_fp16, + tensor_matmul_so3_3x3_bwd_fp16, + tensor_matmul_so3_3x3_bwd_bwd_fp16, +) = generate_tensor_matmul_so3_3x3("float16") + +add_module("tensor_matmul_so3_3x3_fwd", ["float64"], tensor_matmul_so3_3x3_fwd_fp64) +add_module("tensor_matmul_so3_3x3_bwd", ["float64"], 
tensor_matmul_so3_3x3_bwd_fp64) +add_module("tensor_matmul_so3_3x3_bwd_bwd", ["float64"], tensor_matmul_so3_3x3_bwd_bwd_fp64) + +add_module("tensor_matmul_so3_3x3_fwd", ["float32"], tensor_matmul_so3_3x3_fwd_fp32) +add_module("tensor_matmul_so3_3x3_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_fp32) +add_module("tensor_matmul_so3_3x3_bwd_bwd", ["float32"], tensor_matmul_so3_3x3_bwd_bwd_fp32) + +add_module("tensor_matmul_so3_3x3_fwd", ["float16"], tensor_matmul_so3_3x3_fwd_fp16) +add_module("tensor_matmul_so3_3x3_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_fp16) +add_module("tensor_matmul_so3_3x3_bwd_bwd", ["float16"], tensor_matmul_so3_3x3_bwd_bwd_fp16) diff --git a/src/matgl/kernels/graph_transform.py b/src/matgl/kernels/graph_transform.py new file mode 100644 index 00000000..9d7e6077 --- /dev/null +++ b/src/matgl/kernels/graph_transform.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Warp kernels for graph edge index transformation to sparse CSR format."""

import warp as wp


@wp.kernel
def count_row_col(
    edge_index: wp.array(ndim=2, dtype=wp.int32),
    row_count: wp.array(ndim=1, dtype=wp.int32),
    col_count: wp.array(ndim=1, dtype=wp.int32),
):
    # One thread per edge: histogram the source (row 0) and destination
    # (row 1) node ids of edge_index, which is laid out [2, num_edges].
    tid = wp.tid()

    # Counts are written one slot past the node id so a later exclusive
    # prefix-sum over these arrays yields CSR-style indptr arrays.
    shift = edge_index.dtype(1)
    wp.atomic_add(row_count, edge_index[0, tid] + shift, wp.int32(1))
    wp.atomic_add(col_count, edge_index[1, tid] + shift, wp.int32(1))


@wp.kernel
def convert_to_sparse(
    edge_index: wp.array(ndim=2, dtype=wp.int32),
    row_count: wp.array(ndim=1, dtype=wp.int32),
    col_count: wp.array(ndim=1, dtype=wp.int32),
    row_indptr: wp.array(ndim=1, dtype=wp.int32),
    col_indptr: wp.array(ndim=1, dtype=wp.int32),
    row_indices: wp.array(ndim=1, dtype=wp.int32),
    col_indices: wp.array(ndim=1, dtype=wp.int32),
    row_data: wp.array(ndim=1, dtype=wp.int32),
    col_data: wp.array(ndim=1, dtype=wp.int32),
):
    # One thread per edge: scatter each edge into row-major (CSR) and
    # column-major index/data arrays, using the per-node counters from
    # count_row_col as "tickets" to claim unique output slots.
    tid = wp.tid()
    shift = edge_index.dtype(1)

    src_id = edge_index[0, tid]
    dst_id = edge_index[1, tid]

    # NOTE(review): slot = indptr[node + 1] - remaining_count, which fills
    # indptr[node] .. indptr[node + 1] - 1 exactly once per edge *provided*
    # wp.atomic_sub returns the value held before the subtraction — confirm
    # against the Warp atomics documentation.
    src_cnt = wp.atomic_sub(row_count, src_id + shift, wp.int32(1))
    dst_cnt = wp.atomic_sub(col_count, dst_id + shift, wp.int32(1))

    # row_* : neighbors and edge ids grouped by source node.
    row_indices[row_indptr[src_id + shift] - src_cnt] = dst_id
    row_data[row_indptr[src_id + shift] - src_cnt] = wp.int32(tid)

    # col_* : neighbors and edge ids grouped by destination node.
    col_indices[col_indptr[dst_id + shift] - dst_cnt] = src_id
    col_data[col_indptr[dst_id + shift] - dst_cnt] = wp.int32(tid)
diff --git
a/src/matgl/kernels/tensor_norm3.py b/src/matgl/kernels/tensor_norm3.py new file mode 100644 index 00000000..302651da --- /dev/null +++ b/src/matgl/kernels/tensor_norm3.py @@ -0,0 +1,275 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Warp kernels for computing 3x3 tensor norms (I, A, S components).""" + +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_tensor_norm3(dtype: str, h_last: bool = True, use_irmem: bool = True): + """Generate Warp kernels for computing squared norms of 3x3 tensor I, A, S components.""" + dtype_wp = get_wp_fp_dtype(dtype) + + if not use_irmem: + raise ValueError(f"only supporting use_irmem True, but got {use_irmem}") + + if not h_last: + raise ValueError(f"only supporting h_last True but got {h_last}") + + class mat3x3(wp.types.matrix(shape=(3, 3), dtype=dtype_wp)): + pass + + def tensor_norm3_fwd( + X: wp.array(ndim=4, dtype=dtype_wp), + output: wp.array(ndim=2, dtype=dtype_wp), + ): + """Computes I, A, S norms of 3x3 tensor: trace², antisym², sym_traceless².""" + b, h = wp.tid() + + x00 = X[b, 0, 0, h] + x01 = X[b, 0, 1, h] + x02 = X[b, 0, 2, h] + x10 = X[b, 1, 0, h] + x11 = X[b, 1, 1, h] + x12 = X[b, 1, 2, h] + x20 = X[b, 2, 0, h] + x21 = X[b, 2, 1, h] + x22 = X[b, 2, 2, h] + + one_half = X.dtype(0.5) + one_third = X.dtype(1.0 / 3.0) + + trace = x00 + x11 + x22 + trace_third = trace / X.dtype(3.0) + norm2_i = one_third * trace * trace + norm2_a = one_half * ((x01 - x10) * (x01 - x10) + (x02 - x20) * (x02 - x20) + (x12 - x21) * (x12 - x21)) + norm2_s = ( + one_half * ((x01 + x10) * (x01 + x10) + (x02 + x20) * (x02 + x20) + (x12 + x21) * (x12 + x21)) + + (x00 - trace_third) * (x00 - trace_third) + + (x11 - trace_third) * (x11 - trace_third) + + (x22 - trace_third) * (x22 - trace_third) + ) + + output[b, h] = norm2_i + output[b, h + X.shape[3]] = norm2_a + output[b, h + 2 * X.shape[3]] = norm2_s + + def tensor_norm3_bwd( + grad_output: wp.array(ndim=2, dtype=dtype_wp), + X: wp.array(ndim=4, dtype=dtype_wp), + grad_X: wp.array(ndim=4, dtype=dtype_wp), + ): + """Backward: grad_X = d(I,A,S norms)/dX · grad_output.""" + b, h = wp.tid() + + grad_i = grad_output[b, h] + grad_a = grad_output[b, h + X.shape[3]] + grad_s = grad_output[b, 
h + 2 * X.shape[3]] + + x00 = X[b, 0, 0, h] + x01 = X[b, 0, 1, h] + x02 = X[b, 0, 2, h] + x10 = X[b, 1, 0, h] + x11 = X[b, 1, 1, h] + x12 = X[b, 1, 2, h] + x20 = X[b, 2, 0, h] + x21 = X[b, 2, 1, h] + x22 = X[b, 2, 2, h] + + trace = x00 + x11 + x22 + trace_third = trace / X.dtype(3.0) + + diag_grad_i = X.dtype(2.0 / 3.0) * trace * grad_i + + dev00 = x00 - trace_third + dev11 = x11 - trace_third + dev22 = x22 - trace_third + + c4_3 = X.dtype(4.0) / X.dtype(3.0) + c2_3 = X.dtype(2.0) / X.dtype(3.0) + + grad_s_term_00 = c4_3 * dev00 - c2_3 * dev11 - c2_3 * dev22 + grad_s_term_11 = c4_3 * dev11 - c2_3 * dev00 - c2_3 * dev22 + grad_s_term_22 = c4_3 * dev22 - c2_3 * dev00 - c2_3 * dev11 + + grad_X[b, 0, 0, h] = diag_grad_i + grad_s * grad_s_term_00 + grad_X[b, 1, 1, h] = diag_grad_i + grad_s * grad_s_term_11 + grad_X[b, 2, 2, h] = diag_grad_i + grad_s * grad_s_term_22 + + diff01 = x01 - x10 + sum01 = x01 + x10 + grad_X[b, 0, 1, h] = grad_a * diff01 + grad_s * sum01 + grad_X[b, 1, 0, h] = -grad_a * diff01 + grad_s * sum01 + + diff02 = x02 - x20 + sum02 = x02 + x20 + grad_X[b, 0, 2, h] = grad_a * diff02 + grad_s * sum02 + grad_X[b, 2, 0, h] = -grad_a * diff02 + grad_s * sum02 + + diff12 = x12 - x21 + sum12 = x12 + x21 + grad_X[b, 1, 2, h] = grad_a * diff12 + grad_s * sum12 + grad_X[b, 2, 1, h] = -grad_a * diff12 + grad_s * sum12 + + def tensor_norm3_bwd_bwd( + grad_grad_X: wp.array(ndim=4, dtype=dtype_wp), + X: wp.array(ndim=4, dtype=dtype_wp), + grad_output: wp.array(ndim=2, dtype=dtype_wp), + grad_grad_output: wp.array(ndim=2, dtype=dtype_wp), + grad_x: wp.array(ndim=4, dtype=dtype_wp), + ): + """Computes d(grad_X)/d(grad_output) and d(grad_X)/d(X) contracted with grad_grad_X.""" + b, h = wp.tid() + + gg00 = grad_grad_X[b, 0, 0, h] + gg01 = grad_grad_X[b, 0, 1, h] + gg02 = grad_grad_X[b, 0, 2, h] + gg10 = grad_grad_X[b, 1, 0, h] + gg11 = grad_grad_X[b, 1, 1, h] + gg12 = grad_grad_X[b, 1, 2, h] + gg20 = grad_grad_X[b, 2, 0, h] + gg21 = grad_grad_X[b, 2, 1, h] + gg22 = 
grad_grad_X[b, 2, 2, h] + + x00 = X[b, 0, 0, h] + x01 = X[b, 0, 1, h] + x02 = X[b, 0, 2, h] + x10 = X[b, 1, 0, h] + x11 = X[b, 1, 1, h] + x12 = X[b, 1, 2, h] + x20 = X[b, 2, 0, h] + x21 = X[b, 2, 1, h] + x22 = X[b, 2, 2, h] + + grad_i = grad_output[b, h] + grad_a = grad_output[b, h + X.shape[3]] + grad_s = grad_output[b, h + 2 * X.shape[3]] + + trace_X = x00 + x11 + x22 + trace_gg = gg00 + gg11 + gg22 + c2_3 = X.dtype(2.0 / 3.0) + c4_3 = X.dtype(4.0 / 3.0) + + # Part 1: grad_grad_output = d(grad_X)/d(grad_output) · grad_grad_X + # I channel: (2/3) * trace(X) * trace(gg) + grad_grad_output[b, h] = c2_3 * trace_X * trace_gg + + # A channel: diff_X · diff_gg + diff01_X = x01 - x10 + diff02_X = x02 - x20 + diff12_X = x12 - x21 + diff01_gg = gg01 - gg10 + diff02_gg = gg02 - gg20 + diff12_gg = gg12 - gg21 + grad_grad_output[b, h + X.shape[3]] = diff01_X * diff01_gg + diff02_X * diff02_gg + diff12_X * diff12_gg + + # S channel: sum_X · sum_gg + dev_terms · diag_gg + trace_third_X = trace_X / X.dtype(3.0) + dev00 = x00 - trace_third_X + dev11 = x11 - trace_third_X + dev22 = x22 - trace_third_X + grad_s_term_00 = c4_3 * dev00 - c2_3 * dev11 - c2_3 * dev22 + grad_s_term_11 = c4_3 * dev11 - c2_3 * dev00 - c2_3 * dev22 + grad_s_term_22 = c4_3 * dev22 - c2_3 * dev00 - c2_3 * dev11 + sum01_X = x01 + x10 + sum02_X = x02 + x20 + sum12_X = x12 + x21 + sum01_gg = gg01 + gg10 + sum02_gg = gg02 + gg20 + sum12_gg = gg12 + gg21 + grad_grad_output_s = sum01_X * sum01_gg + sum02_X * sum02_gg + sum12_X * sum12_gg + grad_grad_output_s += grad_s_term_00 * gg00 + grad_s_term_11 * gg11 + grad_s_term_22 * gg22 + grad_grad_output[b, h + 2 * X.shape[3]] = grad_grad_output_s + + # Part 2: grad_x = d(grad_X)/d(X) · grad_grad_X + # I channel: (2/3) * grad_i * trace(gg) on diagonals + scalar_diag = c2_3 * grad_i * trace_gg + + # A channel: grad_a * diff_gg (antisymmetric) + antisym_01 = grad_a * diff01_gg + antisym_02 = grad_a * diff02_gg + antisym_12 = grad_a * diff12_gg + + # S channel off-diag: 
grad_s * sum_gg + sym_offdiag_01 = grad_s * sum01_gg + sym_offdiag_02 = grad_s * sum02_gg + sym_offdiag_12 = grad_s * sum12_gg + + # S channel diag: grad_s * (4/3 on self, -2/3 on others) + sym_diag_00 = grad_s * (c4_3 * gg00 - c2_3 * gg11 - c2_3 * gg22) + sym_diag_11 = grad_s * (c4_3 * gg11 - c2_3 * gg00 - c2_3 * gg22) + sym_diag_22 = grad_s * (c4_3 * gg22 - c2_3 * gg00 - c2_3 * gg11) + + # Diagonals + grad_x[b, 0, 0, h] = scalar_diag + sym_diag_00 + grad_x[b, 1, 1, h] = scalar_diag + sym_diag_11 + grad_x[b, 2, 2, h] = scalar_diag + sym_diag_22 + + # Off-diagonals + grad_x[b, 0, 1, h] = antisym_01 + sym_offdiag_01 + grad_x[b, 1, 0, h] = -antisym_01 + sym_offdiag_01 + grad_x[b, 0, 2, h] = antisym_02 + sym_offdiag_02 + grad_x[b, 2, 0, h] = -antisym_02 + sym_offdiag_02 + grad_x[b, 1, 2, h] = antisym_12 + sym_offdiag_12 + grad_x[b, 2, 1, h] = -antisym_12 + sym_offdiag_12 + + return ( + wp.Kernel( + tensor_norm3_fwd, + key=f"tensor_norm3_fwd_{dtype}", + module=wp.get_module(f"tensor_norm3_fwd_{dtype}"), + ), + wp.Kernel( + tensor_norm3_bwd, + key=f"tensor_norm3_bwd_{dtype}", + module=wp.get_module(f"tensor_norm3_bwd_{dtype}"), + ), + wp.Kernel( + tensor_norm3_bwd_bwd, + key=f"tensor_norm3_bwd_bwd_{dtype}", + module=wp.get_module(f"tensor_norm3_bwd_bwd_{dtype}"), + ), + ) + + +tensor_norm3_fwd_fp64, tensor_norm3_bwd_fp64, tensor_norm3_bwd_bwd_fp64 = generate_tensor_norm3("float64") +tensor_norm3_fwd_fp32, tensor_norm3_bwd_fp32, tensor_norm3_bwd_bwd_fp32 = generate_tensor_norm3("float32") +tensor_norm3_fwd_fp16, tensor_norm3_bwd_fp16, tensor_norm3_bwd_bwd_fp16 = generate_tensor_norm3("float16") + +add_module("tensor_norm3_fwd", ["float64"], tensor_norm3_fwd_fp64) +add_module("tensor_norm3_bwd", ["float64"], tensor_norm3_bwd_fp64) +add_module("tensor_norm3_bwd_bwd", ["float64"], tensor_norm3_bwd_bwd_fp64) + +add_module("tensor_norm3_fwd", ["float32"], tensor_norm3_fwd_fp32) +add_module("tensor_norm3_bwd", ["float32"], tensor_norm3_bwd_fp32) 
+add_module("tensor_norm3_bwd_bwd", ["float32"], tensor_norm3_bwd_bwd_fp32) + +add_module("tensor_norm3_fwd", ["float16"], tensor_norm3_fwd_fp16) +add_module("tensor_norm3_bwd", ["float16"], tensor_norm3_bwd_fp16) +add_module("tensor_norm3_bwd_bwd", ["float16"], tensor_norm3_bwd_bwd_fp16) diff --git a/src/matgl/kernels/tensornet_mp.py b/src/matgl/kernels/tensornet_mp.py new file mode 100644 index 00000000..e7bd34ba --- /dev/null +++ b/src/matgl/kernels/tensornet_mp.py @@ -0,0 +1,310 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_message_passing(dtype: str): + dtype_wp = get_wp_fp_dtype(dtype) + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + def message_passing_fwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + output_I: wp.array(ndim=3, dtype=dtype_wp), + output_A: wp.array(ndim=3, dtype=dtype_wp), + output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + output_I_reg = I.dtype(0) + output_A_reg = vec3(I.dtype(0)) + output_S_reg = vec5(I.dtype(0)) + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_j = row_indices[i] + idx_w = row_data[i] + wI = edge_attr[idx_w, 0, h] + wA = edge_attr[idx_w, 1, h] + wS = edge_attr[idx_w, 2, h] + + output_I_reg += I[idx_j, 0, h] * wI + for j in range(3): + output_A_reg[j] += A[idx_j, j, h] * wA + for j in range(5): + output_S_reg[j] += S[idx_j, j, h] * wS + + output_I[b, 0, h] = output_I_reg + for j in range(3): + output_A[b, j, h] = output_A_reg[j] + + for j in range(5): + output_S[b, j, h] = output_S_reg[j] + + def message_passing_bwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: 
wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + col_data: wp.array(ndim=1, dtype=wp.int32), + col_indices: wp.array(ndim=1, dtype=wp.int32), + col_indptr: wp.array(ndim=1, dtype=wp.int32), + dI: wp.array(ndim=3, dtype=dtype_wp), + dA: wp.array(ndim=3, dtype=dtype_wp), + dS: wp.array(ndim=3, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + dI_reg = I.dtype(0.0) + dA_reg = vec3(I.dtype(0.0)) + dS_reg = vec5(I.dtype(0.0)) + + for i in range(col_indptr[b], col_indptr[b + 1]): + idx_j = col_indices[i] + idx_w = col_data[i] + + wI = edge_attr[idx_w, 0, h] + doutput_I_j = doutput_I[idx_j, 0, h] + dI_reg += doutput_I_j * wI + dedge_attr[idx_w, 0, h] = doutput_I_j * I[b, 0, h] + + # A + wA = edge_attr[idx_w, 1, h] + dweight_A = I.dtype(0.0) + for j in range(3): + dA_reg[j] += doutput_A[idx_j, j, h] * wA + dweight_A += doutput_A[idx_j, j, h] * A[b, j, h] + dedge_attr[idx_w, 1, h] = dweight_A + + # S + wS = edge_attr[idx_w, 2, h] + dweight_S = I.dtype(0.0) + for j in range(5): + dS_reg[j] += doutput_S[idx_j, j, h] * wS + dweight_S += doutput_S[idx_j, j, h] * S[b, j, h] + dedge_attr[idx_w, 2, h] = dweight_S + + dI[b, 0, h] = dI_reg + for j in range(3): + dA[b, j, h] = dA_reg[j] + for j in range(5): + dS[b, j, h] = dS_reg[j] + + def message_passing_edge_bwd_bwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + dI: wp.array(ndim=3, dtype=dtype_wp), + dA: wp.array(ndim=3, dtype=dtype_wp), + dS: wp.array(ndim=3, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + col_data: wp.array(ndim=1, dtype=wp.int32), 
+ col_indices: wp.array(ndim=1, dtype=wp.int32), + col_indptr: wp.array(ndim=1, dtype=wp.int32), + d2I: wp.array(ndim=3, dtype=dtype_wp), + d2A: wp.array(ndim=3, dtype=dtype_wp), + d2S: wp.array(ndim=3, dtype=dtype_wp), + d2edge_attr: wp.array(ndim=3, dtype=dtype_wp), + ): + # Col-based iteration: b is source node, idx_j is destination node + # Computes d2I, d2A, d2S, d2edge_attr - no atomics needed + b, h = wp.tid() + + d2I_reg = I.dtype(0) + d2A_reg = vec3(I.dtype(0)) + d2S_reg = vec5(I.dtype(0)) + + for i in range(col_indptr[b], col_indptr[b + 1]): + idx_j = col_indices[i] # Destination node + idx_w = col_data[i] + + dweight_I = dedge_attr[idx_w, 0, h] + dweight_A = dedge_attr[idx_w, 1, h] + dweight_S = dedge_attr[idx_w, 2, h] + + # d2I[b] = Σ dedge_attr[edge] * doutput_I[dst] + d2I_reg += doutput_I[idx_j, 0, h] * dweight_I + + # d2edge_attr[edge] = dI[src] * doutput_I[dst] + d2edge_attr[idx_w, 0, h] = doutput_I[idx_j, 0, h] * dI[b, 0, h] + + # A + dweight_A_reg = I.dtype(0.0) + for j in range(3): + d2A_reg[j] += doutput_A[idx_j, j, h] * dweight_A + dweight_A_reg += doutput_A[idx_j, j, h] * dA[b, j, h] + d2edge_attr[idx_w, 1, h] = dweight_A_reg + + # S + dweight_S_reg = I.dtype(0.0) + for j in range(5): + d2S_reg[j] += doutput_S[idx_j, j, h] * dweight_S + dweight_S_reg += doutput_S[idx_j, j, h] * dS[b, j, h] + d2edge_attr[idx_w, 2, h] = dweight_S_reg + + d2I[b, 0, h] = d2I_reg + + for j in range(3): + d2A[b, j, h] = d2A_reg[j] + + for j in range(5): + d2S[b, j, h] = d2S_reg[j] + + def message_passing_output_bwd_bwd( + I: wp.array(ndim=3, dtype=dtype_wp), + A: wp.array(ndim=3, dtype=dtype_wp), + S: wp.array(ndim=3, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + dI: wp.array(ndim=3, dtype=dtype_wp), + dA: wp.array(ndim=3, dtype=dtype_wp), + dS: wp.array(ndim=3, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indices: wp.array(ndim=1, dtype=wp.int32), + row_indptr: 
wp.array(ndim=1, dtype=wp.int32), + d2output_I: wp.array(ndim=3, dtype=dtype_wp), + d2output_A: wp.array(ndim=3, dtype=dtype_wp), + d2output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + # Row-based iteration: b is destination node, idx_j is source node + # Computes d2output_I, d2output_A, d2output_S - no atomics needed + b, h = wp.tid() + + d2output_I_reg = I.dtype(0) + d2output_A_reg = vec3(I.dtype(0)) + d2output_S_reg = vec5(I.dtype(0)) + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_j = row_indices[i] # Source node + idx_w = row_data[i] + + wI = edge_attr[idx_w, 0, h] + wA = edge_attr[idx_w, 1, h] + wS = edge_attr[idx_w, 2, h] + + dweight_I = dedge_attr[idx_w, 0, h] + dweight_A = dedge_attr[idx_w, 1, h] + dweight_S = dedge_attr[idx_w, 2, h] + + # d2output_I[b] = Σ (dI[src] * edge_attr + I[src] * dedge_attr) + d2output_I_reg += dI[idx_j, 0, h] * wI + d2output_I_reg += I[idx_j, 0, h] * dweight_I + + # A + for j in range(3): + d2output_A_reg[j] += dA[idx_j, j, h] * wA + d2output_A_reg[j] += A[idx_j, j, h] * dweight_A + + # S + for j in range(5): + d2output_S_reg[j] += dS[idx_j, j, h] * wS + d2output_S_reg[j] += S[idx_j, j, h] * dweight_S + + d2output_I[b, 0, h] = d2output_I_reg + + for j in range(3): + d2output_A[b, j, h] = d2output_A_reg[j] + + for j in range(5): + d2output_S[b, j, h] = d2output_S_reg[j] + + return ( + wp.Kernel( + message_passing_fwd, + key=f"message_passing_fwd_{dtype}", + module=wp.get_module(f"message_passing_fwd_{dtype}"), + ), + wp.Kernel( + message_passing_bwd, + key=f"message_passing_bwd_{dtype}", + module=wp.get_module(f"message_passing_bwd_{dtype}"), + ), + wp.Kernel( + message_passing_edge_bwd_bwd, + key=f"message_passing_edge_bwd_bwd_{dtype}", + module=wp.get_module(f"message_passing_edge_bwd_bwd_{dtype}"), + ), + wp.Kernel( + message_passing_output_bwd_bwd, + key=f"message_passing_output_bwd_bwd_{dtype}", + module=wp.get_module(f"message_passing_output_bwd_bwd_{dtype}"), + ), + ) + + +( + message_passing_fwd_fp64, + 
message_passing_bwd_fp64, + message_passing_edge_bwd_bwd_fp64, + message_passing_output_bwd_bwd_fp64, +) = generate_message_passing("float64") +( + message_passing_fwd_fp32, + message_passing_bwd_fp32, + message_passing_edge_bwd_bwd_fp32, + message_passing_output_bwd_bwd_fp32, +) = generate_message_passing("float32") +( + message_passing_fwd_fp16, + message_passing_bwd_fp16, + message_passing_edge_bwd_bwd_fp16, + message_passing_output_bwd_bwd_fp16, +) = generate_message_passing("float16") + +add_module("message_passing_fwd", ["float64"], message_passing_fwd_fp64) +add_module("message_passing_bwd", ["float64"], message_passing_bwd_fp64) +add_module("message_passing_edge_bwd_bwd", ["float64"], message_passing_edge_bwd_bwd_fp64) +add_module("message_passing_output_bwd_bwd", ["float64"], message_passing_output_bwd_bwd_fp64) + +add_module("message_passing_fwd", ["float32"], message_passing_fwd_fp32) +add_module("message_passing_bwd", ["float32"], message_passing_bwd_fp32) +add_module("message_passing_edge_bwd_bwd", ["float32"], message_passing_edge_bwd_bwd_fp32) +add_module("message_passing_output_bwd_bwd", ["float32"], message_passing_output_bwd_bwd_fp32) + +add_module("message_passing_fwd", ["float16"], message_passing_fwd_fp16) +add_module("message_passing_bwd", ["float16"], message_passing_bwd_fp16) +add_module("message_passing_edge_bwd_bwd", ["float16"], message_passing_edge_bwd_bwd_fp16) +add_module("message_passing_output_bwd_bwd", ["float16"], message_passing_output_bwd_bwd_fp16) diff --git a/src/matgl/kernels/tensornet_radial_mp.py b/src/matgl/kernels/tensornet_radial_mp.py new file mode 100644 index 00000000..27d39a2b --- /dev/null +++ b/src/matgl/kernels/tensornet_radial_mp.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import warp as wp + +from .utils import add_module, get_wp_fp_dtype + + +def generate_radial_message_passing(dtype: str): + dtype_wp = get_wp_fp_dtype(dtype) + + class vec3(wp.types.vector(length=3, dtype=dtype_wp)): + pass + + class vec5(wp.types.vector(length=5, dtype=dtype_wp)): + pass + + def radial_message_passing_fwd( + edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + output_I: wp.array(ndim=3, dtype=dtype_wp), + output_A: wp.array(ndim=3, dtype=dtype_wp), + output_S: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + output_I_reg = output_I.dtype(0) + output_A_reg = vec3(output_I.dtype(0)) + output_S_reg = vec5(output_I.dtype(0)) + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + + weight_I_reg = edge_attr[idx_w, 0, h] + weight_A_reg = edge_attr[idx_w, 1, h] + weight_S_reg = edge_attr[idx_w, 2, h] + + r_ij = vec3(output_I.dtype(0)) + r_ij[0] = edge_vec_norm[idx_w, 0] + r_ij[1] = edge_vec_norm[idx_w, 1] + r_ij[2] = edge_vec_norm[idx_w, 2] + + output_I_reg += weight_I_reg + + output_A_reg[0] += r_ij[2] * weight_A_reg + output_A_reg[1] += -r_ij[1] * weight_A_reg + output_A_reg[2] += r_ij[0] * weight_A_reg + + S_reg = vec5() + mean_r2 = (r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2]) / output_I.dtype(3.0) + S_reg[0] = r_ij[0] * r_ij[0] - mean_r2 + S_reg[1] = r_ij[0] * r_ij[1] + S_reg[2] = r_ij[0] * r_ij[2] + S_reg[3] = r_ij[1] * r_ij[1] - mean_r2 + S_reg[4] = r_ij[1] * r_ij[2] + + output_S_reg[0] += S_reg[0] * weight_S_reg + output_S_reg[1] += S_reg[1] * weight_S_reg + output_S_reg[2] += S_reg[2] * weight_S_reg + output_S_reg[3] += S_reg[3] * weight_S_reg + output_S_reg[4] += S_reg[4] * weight_S_reg + + output_I[b, 0, h] = output_I_reg + for i in range(3): + output_A[b, i, h] = output_A_reg[i] + + for i in range(5): + output_S[b, i, h] = output_S_reg[i] + + def 
radial_message_passing_bwd( + edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + dedge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + ): + b, h = wp.tid() + + doutput_I_reg = doutput_I[b, 0, h] + doutput_A_reg = vec3() + doutput_A_reg[0] = doutput_A[b, 0, h] + doutput_A_reg[1] = doutput_A[b, 1, h] + doutput_A_reg[2] = doutput_A[b, 2, h] + + doutput_S_reg = vec5() + doutput_S_reg[0] = doutput_S[b, 0, h] + doutput_S_reg[1] = doutput_S[b, 1, h] + doutput_S_reg[2] = doutput_S[b, 2, h] + doutput_S_reg[3] = doutput_S[b, 3, h] + doutput_S_reg[4] = doutput_S[b, 4, h] + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + + edge_attr_A_reg = edge_attr[idx_w, 1, h] + edge_attr_S_reg = edge_attr[idx_w, 2, h] + + r_ij = vec3(doutput_I.dtype(0)) + dr_ij = vec3(doutput_I.dtype(0)) + r_ij[0] = edge_vec_norm[idx_w, 0] + r_ij[1] = edge_vec_norm[idx_w, 1] + r_ij[2] = edge_vec_norm[idx_w, 2] + + dr_ij[2] += doutput_A_reg[0] * edge_attr_A_reg + dr_ij[1] += -doutput_A_reg[1] * edge_attr_A_reg + dr_ij[0] += doutput_A_reg[2] * edge_attr_A_reg + + dedge_attr_I = doutput_I_reg + + dedge_attr_A = doutput_A_reg[0] * r_ij[2] - doutput_A_reg[1] * r_ij[1] + doutput_A_reg[2] * r_ij[0] + + S_reg = vec5() + mean_r2 = (r_ij[0] * r_ij[0] + r_ij[1] * r_ij[1] + r_ij[2] * r_ij[2]) / doutput_I.dtype(3.0) + S_reg[0] = r_ij[0] * r_ij[0] - mean_r2 + S_reg[1] = r_ij[0] * r_ij[1] + S_reg[2] = r_ij[0] * r_ij[2] + S_reg[3] = r_ij[1] * r_ij[1] - mean_r2 + S_reg[4] = r_ij[1] * r_ij[2] + + dedge_attr_S = (S_reg[0]) * doutput_S_reg[0] + dedge_attr_S += (S_reg[1]) * doutput_S_reg[1] + dedge_attr_S += (S_reg[2]) * doutput_S_reg[2] + dedge_attr_S += (S_reg[3]) * 
doutput_S_reg[3] + dedge_attr_S += (S_reg[4]) * doutput_S_reg[4] + + dS_reg = vec5() + dS_reg[0] = edge_attr_S_reg * doutput_S_reg[0] + dS_reg[1] = edge_attr_S_reg * doutput_S_reg[1] + dS_reg[2] = edge_attr_S_reg * doutput_S_reg[2] + dS_reg[3] = edge_attr_S_reg * doutput_S_reg[3] + dS_reg[4] = edge_attr_S_reg * doutput_S_reg[4] + + dr_ij[0] += ( + dS_reg[0] * (doutput_I.dtype(4.0) / doutput_I.dtype(3.0) * r_ij[0]) + + dS_reg[1] * r_ij[1] + + dS_reg[2] * r_ij[2] + + dS_reg[3] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[0]) + ) + dr_ij[1] += ( + dS_reg[0] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[1]) + + dS_reg[1] * r_ij[0] + + dS_reg[3] * (doutput_I.dtype(4.0) / doutput_I.dtype(3.0) * r_ij[1]) + + dS_reg[4] * r_ij[2] + ) + dr_ij[2] += ( + dS_reg[0] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[2]) + + dS_reg[2] * r_ij[0] + + dS_reg[3] * (-doutput_I.dtype(2.0) / doutput_I.dtype(3.0) * r_ij[2]) + + dS_reg[4] * r_ij[1] + ) + + wp.atomic_add(dedge_attr, idx_w, 0, h, dedge_attr_I) + wp.atomic_add(dedge_attr, idx_w, 1, h, dedge_attr_A) + wp.atomic_add(dedge_attr, idx_w, 2, h, dedge_attr_S) + + wp.atomic_add(dedge_vec_norm, idx_w, 0, dr_ij[0]) + wp.atomic_add(dedge_vec_norm, idx_w, 1, dr_ij[1]) + wp.atomic_add(dedge_vec_norm, idx_w, 2, dr_ij[2]) + + def radial_message_passing_bwd_bwd( + edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + edge_attr: wp.array(ndim=3, dtype=dtype_wp), + dedge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + dedge_attr: wp.array(ndim=3, dtype=dtype_wp), + doutput_I: wp.array(ndim=3, dtype=dtype_wp), + doutput_A: wp.array(ndim=3, dtype=dtype_wp), + doutput_S: wp.array(ndim=3, dtype=dtype_wp), + row_data: wp.array(ndim=1, dtype=wp.int32), + row_indptr: wp.array(ndim=1, dtype=wp.int32), + d2edge_vec_norm: wp.array(ndim=2, dtype=dtype_wp), + d2edge_attr: wp.array(ndim=3, dtype=dtype_wp), + d2output_I: wp.array(ndim=3, dtype=dtype_wp), + d2output_A: wp.array(ndim=3, dtype=dtype_wp), + d2output_S: wp.array(ndim=3, 
dtype=dtype_wp), + ): + b, h = wp.tid() + + d2output_I_reg = d2output_I.dtype(0.0) + d2output_A_reg = vec3() + d2output_S_reg = vec5() + + for i in range(row_indptr[b], row_indptr[b + 1]): + idx_w = row_data[i] + edge_attr_A_reg = edge_attr[idx_w, 1, h] + edge_attr_S_reg = edge_attr[idx_w, 2, h] + + dedge_attr_I = dedge_attr[idx_w, 0, h] + dedge_attr_A = dedge_attr[idx_w, 1, h] + dedge_attr_S = dedge_attr[idx_w, 2, h] + + r_ij = vec3(d2output_I.dtype(0)) + dr_ij = vec3(d2output_I.dtype(0)) + for j in range(3): + r_ij[j] = edge_vec_norm[idx_w, j] + dr_ij[j] = dedge_vec_norm[idx_w, j] + + d2output_I_reg += dedge_attr_I + + d2r_ij = vec3(d2output_I.dtype(0)) + + # No gradient contribution for edge_attr[*, 0, h] in forward pass + # d2edge_attr[idx_w, 0, h] = d2output_I.dtype(0.0) + + d2output_A_reg[0] += dr_ij[2] * edge_attr_A_reg + d2output_A_reg[1] += -dr_ij[1] * edge_attr_A_reg + d2output_A_reg[2] += dr_ij[0] * edge_attr_A_reg + + d2output_A_reg[0] += dedge_attr_A * r_ij[2] + d2output_A_reg[1] += -dedge_attr_A * r_ij[1] + d2output_A_reg[2] += dedge_attr_A * r_ij[0] + + dweight_A = doutput_A[b, 0, h] * dr_ij[2] - doutput_A[b, 1, h] * dr_ij[1] + doutput_A[b, 2, h] * dr_ij[0] + + d2r_ij[2] += dedge_attr_A * doutput_A[b, 0, h] + d2r_ij[1] += -dedge_attr_A * doutput_A[b, 1, h] + d2r_ij[0] += dedge_attr_A * doutput_A[b, 2, h] + + wp.atomic_add(d2edge_attr, idx_w, 1, h, dweight_A) + + c0 = doutput_S.dtype(4.0) / doutput_S.dtype(3.0) + c1 = -doutput_S.dtype(2.0) / doutput_S.dtype(3.0) + + c2 = doutput_S.dtype(2.0) / doutput_S.dtype(3.0) + c3 = -doutput_S.dtype(1.0) / doutput_S.dtype(3.0) + + d2output_S_reg[0] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * c0 * r_ij[0] + + dedge_vec_norm[idx_w, 1] * c1 * r_ij[1] + + dedge_vec_norm[idx_w, 2] * c1 * r_ij[2] + ) + d2output_S_reg[0] += dedge_attr_S * ( + c2 * r_ij[0] * r_ij[0] + c3 * r_ij[1] * r_ij[1] + c3 * r_ij[2] * r_ij[2] + ) + + d2output_S_reg[1] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * r_ij[1] + 
dedge_vec_norm[idx_w, 1] * r_ij[0] + ) + d2output_S_reg[1] += dedge_attr_S * (r_ij[1] * r_ij[0]) + + d2output_S_reg[2] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * r_ij[2] + dedge_vec_norm[idx_w, 2] * r_ij[0] + ) + d2output_S_reg[2] += dedge_attr_S * (r_ij[2] * r_ij[0]) + + d2output_S_reg[3] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 0] * c1 * r_ij[0] + + dedge_vec_norm[idx_w, 1] * c0 * r_ij[1] + + dedge_vec_norm[idx_w, 2] * c1 * r_ij[2] + ) + d2output_S_reg[3] += dedge_attr_S * ( + c3 * r_ij[0] * r_ij[0] + c2 * r_ij[1] * r_ij[1] + c3 * r_ij[2] * r_ij[2] + ) + + d2output_S_reg[4] += edge_attr_S_reg * ( + dedge_vec_norm[idx_w, 1] * r_ij[2] + dedge_vec_norm[idx_w, 2] * r_ij[1] + ) + d2output_S_reg[4] += dedge_attr_S * (r_ij[2] * r_ij[1]) + + d2r_ij[0] += doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c0) + d2r_ij[1] += doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c1) + d2r_ij[2] += doutput_S[b, 0, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * c1) + + d2r_ij[0] += doutput_S[b, 0, h] * dedge_attr_S * (c0 * r_ij[0]) + d2r_ij[1] += doutput_S[b, 0, h] * dedge_attr_S * (c1 * r_ij[1]) + d2r_ij[2] += doutput_S[b, 0, h] * dedge_attr_S * (c1 * r_ij[2]) + + d2r_ij[0] += doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) + d2r_ij[1] += doutput_S[b, 1, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) + + d2r_ij[0] += doutput_S[b, 1, h] * dedge_attr_S * (r_ij[1]) + d2r_ij[1] += doutput_S[b, 1, h] * dedge_attr_S * (r_ij[0]) + + d2r_ij[0] += doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) + d2r_ij[2] += doutput_S[b, 2, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0]) + + d2r_ij[0] += doutput_S[b, 2, h] * dedge_attr_S * (r_ij[2]) + d2r_ij[2] += doutput_S[b, 2, h] * dedge_attr_S * (r_ij[0]) + + d2r_ij[0] += doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 0] * c1) + d2r_ij[1] += doutput_S[b, 3, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1] * c0) + d2r_ij[2] += doutput_S[b, 3, 
h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2] * c1) + + d2r_ij[0] += doutput_S[b, 3, h] * dedge_attr_S * (c1 * r_ij[0]) + d2r_ij[1] += doutput_S[b, 3, h] * dedge_attr_S * (c0 * r_ij[1]) + d2r_ij[2] += doutput_S[b, 3, h] * dedge_attr_S * (c1 * r_ij[2]) + + d2r_ij[1] += doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 2]) + d2r_ij[2] += doutput_S[b, 4, h] * edge_attr_S_reg * (dedge_vec_norm[idx_w, 1]) + + d2r_ij[1] += doutput_S[b, 4, h] * dedge_attr_S * (r_ij[2]) + d2r_ij[2] += doutput_S[b, 4, h] * dedge_attr_S * (r_ij[1]) + + d2weight_S = doutput_S.dtype(0.0) + d2weight_S += doutput_S[b, 0, h] * ( + c0 * r_ij[0] * dedge_vec_norm[idx_w, 0] + + c1 * r_ij[1] * dedge_vec_norm[idx_w, 1] + + c1 * r_ij[2] * dedge_vec_norm[idx_w, 2] + ) + + d2weight_S += doutput_S[b, 1, h] * (r_ij[1] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 1]) + + d2weight_S += doutput_S[b, 2, h] * (r_ij[2] * dedge_vec_norm[idx_w, 0] + r_ij[0] * dedge_vec_norm[idx_w, 2]) + + d2weight_S += doutput_S[b, 3, h] * ( + c1 * r_ij[0] * dedge_vec_norm[idx_w, 0] + + c0 * r_ij[1] * dedge_vec_norm[idx_w, 1] + + c1 * r_ij[2] * dedge_vec_norm[idx_w, 2] + ) + + d2weight_S += doutput_S[b, 4, h] * (r_ij[2] * dedge_vec_norm[idx_w, 1] + r_ij[1] * dedge_vec_norm[idx_w, 2]) + + wp.atomic_add(d2edge_attr, idx_w, 2, h, d2weight_S) + + wp.atomic_add(d2edge_vec_norm, idx_w, 0, d2r_ij[0]) + wp.atomic_add(d2edge_vec_norm, idx_w, 1, d2r_ij[1]) + wp.atomic_add(d2edge_vec_norm, idx_w, 2, d2r_ij[2]) + + d2output_I[b, 0, h] = d2output_I_reg + + for i in range(3): + d2output_A[b, i, h] = d2output_A_reg[i] + + for i in range(5): + d2output_S[b, i, h] = d2output_S_reg[i] + + return ( + wp.Kernel( + radial_message_passing_fwd, + key=f"radial_message_passing_fwd_{dtype}", + module=wp.get_module(f"radial_message_passing_fwd_{dtype}"), + ), + wp.Kernel( + radial_message_passing_bwd, + key=f"radial_message_passing_bwd_{dtype}", + module=wp.get_module(f"radial_message_passing_bwd_{dtype}"), + ), + wp.Kernel( + 
radial_message_passing_bwd_bwd, + key=f"radial_message_passing_bwd_bwd_{dtype}", + module=wp.get_module(f"radial_message_passing_bwd_bwd_{dtype}"), + ), + ) + + +( + radial_message_passing_fwd_fp64, + radial_message_passing_bwd_fp64, + radial_message_passing_bwd_bwd_fp64, +) = generate_radial_message_passing("float64") +( + radial_message_passing_fwd_fp32, + radial_message_passing_bwd_fp32, + radial_message_passing_bwd_bwd_fp32, +) = generate_radial_message_passing("float32") +( + radial_message_passing_fwd_fp16, + radial_message_passing_bwd_fp16, + radial_message_passing_bwd_bwd_fp16, +) = generate_radial_message_passing("float16") + +add_module("radial_message_passing_fwd", ["float64"], radial_message_passing_fwd_fp64) +add_module("radial_message_passing_bwd", ["float64"], radial_message_passing_bwd_fp64) +add_module("radial_message_passing_bwd_bwd", ["float64"], radial_message_passing_bwd_bwd_fp64) + +add_module("radial_message_passing_fwd", ["float32"], radial_message_passing_fwd_fp32) +add_module("radial_message_passing_bwd", ["float32"], radial_message_passing_bwd_fp32) +add_module("radial_message_passing_bwd_bwd", ["float32"], radial_message_passing_bwd_bwd_fp32) + +add_module("radial_message_passing_fwd", ["float16"], radial_message_passing_fwd_fp16) +add_module("radial_message_passing_bwd", ["float16"], radial_message_passing_bwd_fp16) +add_module("radial_message_passing_bwd_bwd", ["float16"], radial_message_passing_bwd_bwd_fp16) diff --git a/src/matgl/kernels/utils.py b/src/matgl/kernels/utils.py new file mode 100644 index 00000000..e2ef2d30 --- /dev/null +++ b/src/matgl/kernels/utils.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import torch +import warp as wp + +MODULES = {} + + +def get_module(name: str, dtype: list[str]): + """Get the module for the given name and dtype.""" + full_name = name + "_" + "_".join(get_dtype(d) for d in dtype) + if full_name not in MODULES: + print(f"Module {full_name} not found in MODULES dictionary") + print(f"Available modules: {list(MODULES.keys())}") + raise ValueError(f"Module {full_name} not found") + return MODULES[full_name] + + +def add_module(name: str, dtype: list[str], kernel: wp.Kernel): + """Add the module for the given name and dtype.""" + full_name = name + "_" + "_".join(get_dtype(d) for d in dtype) + if full_name not in MODULES: + MODULES[full_name] = kernel + return MODULES[full_name] + + +def get_dtype(dtype: str): + """Get the dtype string representation for the given dtype (WIP).""" + if dtype.endswith("16"): + return "fp16" + if dtype.endswith("32"): + return "fp32" + if dtype.endswith("64"): + return "fp64" + raise ValueError(f"Unsupported dtype: {dtype}") + + +def get_wp_fp_dtype(dtype: str): + """Get the warp dtype for the given dtype (WIP).""" + if dtype.endswith("16"): + return wp.float16 + if dtype.endswith("32"): + return wp.float32 + if dtype.endswith("64"): + return wp.float64 + raise ValueError(f"Unsupported dtype: {dtype}") + + +def list_modules(): + """List all modules in the MODULES dictionary.""" + print("Available modules:") + for name in MODULES: + print(f" - {name}") + return list(MODULES.keys()) + + +def get_stream(device: torch.device): + """Get the stream for the given device.""" + if device.type == "cuda": + return wp.stream_from_torch(torch.cuda.current_stream(device)) + return None diff --git a/src/matgl/models/_tensornet_pyg.py b/src/matgl/models/_tensornet_pyg.py index 0fab57f6..378a7210 100644 --- a/src/matgl/models/_tensornet_pyg.py +++ b/src/matgl/models/_tensornet_pyg.py @@ -28,14 +28,18 @@ WeightedAtomReadOut, WeightedReadOut, ) -from matgl.utils.cutoff import 
cosine_cutoff -from matgl.utils.maths import ( - decompose_tensor, - new_radial_tensor, - scatter_add, - vector_to_skewtensor, - vector_to_symtensor, +from matgl.ops import ( + fn_compose_tensor, + fn_decompose_tensor, + fn_message_passing, + fn_radial_message_passing, + fn_tensor_matmul_o3_3x3, + fn_tensor_matmul_so3_3x3, + fn_tensor_norm3, + graph_transform, ) +from matgl.utils.cutoff import cosine_cutoff +from matgl.utils.maths import scatter_add from ._core import MatGLModel @@ -45,40 +49,6 @@ logger = logging.getLogger(__file__) -def compose_tensor(I_tensor: torch.Tensor, A: torch.Tensor, S: torch.Tensor) -> torch.Tensor: - """Compose tensor from scalar (I_tensor), skew-symmetric (A), and traceless symmetric (S) components. - - Args: - I_tensor: Scalar component, shape (num_nodes, 1, 1, units) or (num_nodes, 3, 3, units) - A: Skew-symmetric component, shape (num_nodes, 3, 3, units) - S: Traceless symmetric component, shape (num_nodes, 3, 3, units) - - Returns: - Composed tensor, shape (num_nodes, 3, 3, units) - """ - # I_tensor is scalar (1x1), A is skew (3x3), S is traceless symmetric (3x3) - # For I_tensor, we need to expand it to 3x3 identity matrix - if I_tensor.shape[1] == 1 and I_tensor.shape[2] == 1: - # I_tensor has shape (num_nodes, 1, 1, units) - # Expand scalar to 3x3 identity matrix: multiply I_tensor by identity - eye = torch.eye(3, 3, device=I_tensor.device, dtype=I_tensor.dtype) # (3, 3) - # I_tensor: (num_nodes, 1, 1, units) - # We need: I_expanded[i, :, :, u] = I_tensor[i, 0, 0, u] * eye - # I_values: (num_nodes, units) - I_values = I_tensor.squeeze(1).squeeze(1) # (num_nodes, units) - # eye_expanded: (1, 3, 3, 1) for broadcasting - eye_expanded = eye.unsqueeze(0).unsqueeze(-1) # (1, 3, 3, 1) - # I_values.unsqueeze(1).unsqueeze(1): (num_nodes, 1, 1, units) - # Multiply: (num_nodes, 1, 1, units) * (1, 3, 3, 1) -> (num_nodes, 3, 3, units) - I_expanded = I_values.unsqueeze(1).unsqueeze(1) * eye_expanded # (num_nodes, 3, 3, units) - else: - 
I_expanded = I_tensor - - # A is already 3x3 skew-symmetric, shape (num_nodes, 3, 3, units) - # S is already 3x3 traceless symmetric, shape (num_nodes, 3, 3, units) - return I_expanded + A + S - - def compute_pair_vector_and_distance( pos: torch.Tensor, edge_index: torch.Tensor, @@ -108,151 +78,9 @@ def compute_pair_vector_and_distance( return bond_vec, bond_dist -def radial_message_passing( - edge_vec_norm: torch.Tensor, - edge_attr: torch.Tensor, - edge_index: torch.Tensor, - num_nodes: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform radial message passing to aggregate edge information to nodes. - - Args: - edge_vec_norm: Normalized edge vectors, shape (num_edges, 3) - edge_attr: Edge attributes, shape (num_edges, 3, units) - edge_index: Edge indices, shape (2, num_edges) - num_nodes: Number of nodes - - Returns: - I: Scalar components, shape (num_nodes, 1, 1, units) - A: Skew-symmetric components, shape (num_nodes, 3, 3, units) - S: Traceless symmetric components, shape (num_nodes, 3, 3, units) - """ - dst = edge_index[1] - - # Create radial tensors from edge vectors - # For scalars: use (1, 1, 1, 1) which will broadcast with f_I - eye_scalar_base = torch.ones(1, 1, 1, 1, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype) - A_skew_base = vector_to_skewtensor(edge_vec_norm).unsqueeze(-3) # (num_edges, 1, 3, 3) - S_sym_base = vector_to_symtensor(edge_vec_norm).unsqueeze(-3) # (num_edges, 1, 3, 3) - - # Split edge_attr into three components - edge_attr_I = edge_attr[:, 0, :] # (num_edges, units) - edge_attr_A = edge_attr[:, 1, :] # (num_edges, units) - edge_attr_S = edge_attr[:, 2, :] # (num_edges, units) - - # Call new_radial_tensor - # new_radial_tensor multiplies f_I[..., None, None] * scalars - # f_I: (num_edges, units) -> (num_edges, units, 1, 1) - # scalars: (1, 1, 1, 1) -> broadcasts to (num_edges, units, 1, 1) - # Result: I_ij (num_edges, units, 1, 1), A_ij (num_edges, units, 3, 3), S_ij (num_edges, units, 3, 3) - I_ij, A_ij, 
S_ij = new_radial_tensor( - eye_scalar_base, - A_skew_base, - S_sym_base, - edge_attr_I, - edge_attr_A, - edge_attr_S, - ) - - # new_radial_tensor returns with units in position 1, we need units in position -1 - # Transpose: (num_edges, units, 1, 1) -> (num_edges, 1, 1, units) - # Transpose: (num_edges, units, 3, 3) -> (num_edges, 3, 3, units) - I_ij = I_ij.permute(0, 2, 3, 1) # (num_edges, 1, 1, units) - A_ij = A_ij.permute(0, 2, 3, 1) # (num_edges, 3, 3, units) - S_ij = S_ij.permute(0, 2, 3, 1) # (num_edges, 3, 3, units) - - # Aggregate to nodes - I_tensor = scatter_add(I_ij, dst, dim_size=num_nodes, dim=0) - A = scatter_add(A_ij, dst, dim_size=num_nodes, dim=0) - S = scatter_add(S_ij, dst, dim_size=num_nodes, dim=0) - - return I_tensor, A, S - - -def message_passing( - I_tensor: torch.Tensor, - A: torch.Tensor, - S: torch.Tensor, - edge_attr: torch.Tensor, - edge_index: torch.Tensor, - num_nodes: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform message passing for tensor components. 
- - Args: - I_tensor: Scalar components, shape (num_nodes, 1, 1, units) - A: Skew-symmetric components, shape (num_nodes, 3, 3, units) - S: Traceless symmetric components, shape (num_nodes, 3, 3, units) - edge_attr: Edge attributes, shape (num_edges, 3, units) - edge_index: Edge indices, shape (2, num_edges) - num_nodes: Number of nodes - - Returns: - Im: Aggregated scalar messages, shape (num_nodes, 1, 1, units) - Am: Aggregated skew messages, shape (num_nodes, 3, 3, units) - Sm: Aggregated traceless messages, shape (num_nodes, 3, 3, units) - """ - dst = edge_index[1] - - # Get node features for destination nodes - I_j = I_tensor[dst] - A_j = A[dst] - S_j = S[dst] - - # Extract edge attribute components - # edge_attr has shape (num_edges, 3, units) where dim 1 is (I, A, S) components - edge_attr_I = edge_attr[:, 0, :] # (num_edges, units) - edge_attr_A = edge_attr[:, 1, :] # (num_edges, units) - edge_attr_S = edge_attr[:, 2, :] # (num_edges, units) - - # After linear transformations, I_tensor, A, S all have shape (num_nodes, 3, 3, units) - # So I_j, A_j, S_j have shape (num_edges, 3, 3, units) - # Expand edge attributes for broadcasting: (num_edges, units) -> (num_edges, 1, 1, units) - edge_attr_I = edge_attr_I.unsqueeze(1).unsqueeze(1) # (num_edges, 1, 1, units) - edge_attr_A = edge_attr_A.unsqueeze(1).unsqueeze(1) # (num_edges, 1, 1, units) - edge_attr_S = edge_attr_S.unsqueeze(1).unsqueeze(1) # (num_edges, 1, 1, units) - - # Apply edge attributes to node features - I_m = I_j * edge_attr_I # (num_edges, 3, 3, units) - A_m = A_j * edge_attr_A # (num_edges, 3, 3, units) - S_m = S_j * edge_attr_S # (num_edges, 3, 3, units) - - # Aggregate messages - Im = scatter_add(I_m, dst, dim_size=num_nodes, dim=0) - Am = scatter_add(A_m, dst, dim_size=num_nodes, dim=0) - Sm = scatter_add(S_m, dst, dim_size=num_nodes, dim=0) - - return Im, Am, Sm - - -def tensor_matmul_o3(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: - """O(3) equivariant tensor multiplication. 
- - Args: - X: First tensor, shape (num_nodes, 3, 3, units) - Y: Second tensor, shape (num_nodes, 3, 3, units) - - Returns: - Result tensor, shape (num_nodes, 3, 3, units) - """ - # O(3) equivariant: A + B where A = X @ Y, B = Y @ X - A = torch.einsum("nijk,njlk->nilk", X, Y) - B = torch.einsum("nijk,njlk->nilk", Y, X) - return A + B - - -def tensor_matmul_so3(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: - """SO(3) equivariant tensor multiplication. - - Args: - X: First tensor, shape (num_nodes, 3, 3, units) - Y: Second tensor, shape (num_nodes, 3, 3, units) - - Returns: - Result tensor, shape (num_nodes, 3, 3, units) - """ - # SO(3) equivariant: 2 * (X @ Y) - return 2 * torch.einsum("nijk,njlk->nilk", X, Y) +def tensor_norm(tensor): + """Computes Frobenius norm.""" + return (tensor * tensor).sum((-3, -2)) class TensorEmbedding(nn.Module): @@ -271,9 +99,9 @@ def __init__( self.units = units self.cutoff = cutoff - self.distance_proj1 = nn.Linear(degree_rbf, units, dtype=dtype) - self.distance_proj2 = nn.Linear(degree_rbf, units, dtype=dtype) - self.distance_proj3 = nn.Linear(degree_rbf, units, dtype=dtype) + # Create unified distance_proj from 3 temp layers (matches reference RNG pattern). 
+ self.distance_proj = self._create_distance_proj(degree_rbf, units, dtype=dtype) + self.emb = nn.Embedding(ntypes_node, units, dtype=dtype) self.emb2 = nn.Linear(2 * units, units, dtype=dtype) self.act = activation @@ -288,10 +116,58 @@ def __init__( self.reset_parameters() + def _create_distance_proj( + self, + in_features: int, + units: int, + dtype: torch.dtype = matgl.float_th, + ) -> nn.Linear: + """Create unified distance_proj from 3 separate layers to match reference RNG pattern.""" + d_proj1 = nn.Linear(in_features, units, bias=True, dtype=dtype) + d_proj2 = nn.Linear(in_features, units, bias=True, dtype=dtype) + d_proj3 = nn.Linear(in_features, units, bias=True, dtype=dtype) + + layer = torch.nn.utils.skip_init(nn.Linear, in_features, 3 * units, bias=True, dtype=dtype) + with torch.no_grad(): + layer.weight.copy_(torch.cat([d_proj1.weight, d_proj2.weight, d_proj3.weight], dim=0)) + layer.bias.copy_(torch.cat([d_proj1.bias, d_proj2.bias, d_proj3.bias], dim=0)) + return layer + + def _reset_distance_proj(self) -> None: + """Reset distance_proj weights using 3 temp layers to match reference RNG pattern.""" + dtype = self.distance_proj.weight.dtype + d_proj1 = torch.nn.utils.skip_init( + nn.Linear, self.distance_proj.in_features, self.units, bias=True, dtype=dtype + ) + d_proj2 = torch.nn.utils.skip_init( + nn.Linear, self.distance_proj.in_features, self.units, bias=True, dtype=dtype + ) + d_proj3 = torch.nn.utils.skip_init( + nn.Linear, self.distance_proj.in_features, self.units, bias=True, dtype=dtype + ) + d_proj1.reset_parameters() + d_proj2.reset_parameters() + d_proj3.reset_parameters() + with torch.no_grad(): + self.distance_proj.weight.copy_(torch.cat([d_proj1.weight, d_proj2.weight, d_proj3.weight], dim=0)) + self.distance_proj.bias.copy_(torch.cat([d_proj1.bias, d_proj2.bias, d_proj3.bias], dim=0)) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + """Handle legacy checkpoints with separate distance_proj1/2/3 layers.""" + 
w_keys = [f"{prefix}distance_proj{i}.weight" for i in (1, 2, 3)] + b_keys = [f"{prefix}distance_proj{i}.bias" for i in (1, 2, 3)] + new_w = f"{prefix}distance_proj.weight" + new_b = f"{prefix}distance_proj.bias" + + if all(k in state_dict for k in w_keys + b_keys): + state_dict[new_w] = torch.cat([state_dict.pop(k) for k in w_keys], dim=0) + state_dict[new_b] = torch.cat([state_dict.pop(k) for k in b_keys], dim=0) + + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + def reset_parameters(self): - self.distance_proj1.reset_parameters() - self.distance_proj2.reset_parameters() - self.distance_proj3.reset_parameters() + """Reinitialize parameters with RNG pattern matching reference implementation.""" + self._reset_distance_proj() self.emb.reset_parameters() self.emb2.reset_parameters() for linear in self.linears_tensor: @@ -307,6 +183,8 @@ def forward( edge_weight: torch.Tensor, edge_vec: torch.Tensor, edge_attr: torch.Tensor, + col_data: torch.Tensor, + col_indptr: torch.Tensor, ) -> torch.Tensor: """Forward pass. 
@@ -316,94 +194,51 @@ def forward( edge_weight: Edge weights (distances), shape (num_edges,) edge_vec: Edge vectors, shape (num_edges, 3) edge_attr: Edge attributes (RBF), shape (num_edges, num_rbf) + col_data: CSR col data for destination aggregation, shape (num_edges,) + col_indptr: CSR col indptr for destination aggregation, shape (num_nodes+1,) Returns: X: Tensor representation, shape (num_nodes, 3, 3, units) """ - num_nodes = z.shape[0] - # Node embedding x = self.emb(z) # (num_nodes, units) # Edge processing C = cosine_cutoff(edge_weight, self.cutoff) - W1 = self.distance_proj1(edge_attr) * C.view(-1, 1) # (num_edges, units) - W2 = self.distance_proj2(edge_attr) * C.view(-1, 1) - W3 = self.distance_proj3(edge_attr) * C.view(-1, 1) - - edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) + edge_attr = self.distance_proj(edge_attr).view(-1, 3, self.units) # Get atomic number messages - src, dst = edge_index[0], edge_index[1] - vi = x[src] - vj = x[dst] - zij = torch.cat([vi, vj], dim=-1) + zij = x.index_select(0, edge_index.t().reshape(-1)).view(-1, self.units * 2) Zij = self.emb2(zij) # (num_edges, units) # Create edge attributes with Zij - edge_attr_processed = torch.stack([W1, W2, W3], dim=1) # (num_edges, 3, units) - edge_attr_processed = edge_attr_processed * Zij.unsqueeze(1) # (num_edges, 3, units) + edge_attr_processed = edge_attr.view(-1, 3, self.units) * C.view(-1, 1, 1) * Zij.view(-1, 1, self.units) # Radial message passing - I_tensor, A, S = radial_message_passing(edge_vec_norm, edge_attr_processed, edge_index, num_nodes) + edge_vec_norm = edge_vec / torch.norm(edge_vec, dim=1, keepdim=True).clamp(min=1e-6) + I, A, S = fn_radial_message_passing(edge_vec_norm, edge_attr_processed, col_data, col_indptr) # noqa: E741 # Compose initial tensor to get proper shape for norm computation - X = compose_tensor(I_tensor, A, S) # (num_nodes, 3, 3, units) + X = fn_compose_tensor(I, A, S) # (num_nodes, 3, 3, units) # Normalize and 
process - # Following original: norm = tensor_norm(scalars + skew_matrices + traceless_tensors) - # For X with shape (num_nodes, 3, 3, units), we need to sum over (-3, -2) - # which are the (3, 3) spatial dimensions - # tensor_norm sums over (-2, -1), but we need (-3, -2) for our tensor shape - # So we compute the norm manually: sum over the spatial (3, 3) dimensions - norm = (X**2).sum((-3, -2)) # (num_nodes, units) - norm = self.init_norm(norm) # (num_nodes, units) - - # Apply tensor linear transformations - # I_tensor has shape (num_nodes, 1, 1, units), A and S have (num_nodes, 3, 3, units) - # The linear layer expects (..., units) as the last dimension - # Original code: permute(0, 2, 3, 1) puts units in position -2, then linear, then permute back - # For (num_nodes, 3, 3, units): permute(0, 2, 3, 1) -> (num_nodes, 3, units, 3) - # But linear expects (..., units), so we need to reshape or use a different approach - # Actually, the linear is applied to each spatial position independently - # So we reshape to (num_nodes * 3 * 3, units), apply linear, reshape back - if I_tensor.shape[1] == 1 and I_tensor.shape[2] == 1: - # Expand I_tensor from (num_nodes, 1, 1, units) to (num_nodes, 3, 3, units) - eye = torch.eye(3, 3, device=I_tensor.device, dtype=I_tensor.dtype) # (3, 3) - I_values = I_tensor.squeeze(1).squeeze(1) # (num_nodes, units) - I_expanded = I_values.unsqueeze(1).unsqueeze(1) * eye.unsqueeze(0).unsqueeze(-1) # (num_nodes, 3, 3, units) - # Reshape to (num_nodes * 3 * 3, units), apply linear, reshape back - I_reshaped = I_expanded.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[0](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_expanded.shape) # (num_nodes, 3, 3, units) - else: - # Reshape to (num_nodes * 3 * 3, units), apply linear, reshape back - I_reshaped = I_tensor.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[0](I_reshaped) # (num_nodes * 9, units) - 
I_tensor = I_reshaped.reshape(I_tensor.shape) # (num_nodes, 3, 3, units) - # Same for A and S - A_reshaped = A.reshape(-1, self.units) # (num_nodes * 9, units) - A_reshaped = self.linears_tensor[1](A_reshaped) # (num_nodes * 9, units) - A = A_reshaped.reshape(A.shape) # (num_nodes, 3, 3, units) - - S_reshaped = S.reshape(-1, self.units) # (num_nodes * 9, units) - S_reshaped = self.linears_tensor[2](S_reshaped) # (num_nodes * 9, units) - S = S_reshaped.reshape(S.shape) # (num_nodes, 3, 3, units) + norm = tensor_norm(X) # (num_nodes, units) + norm = self.init_norm(norm) # (num_nodes, units) # Process norm through scalar layers for linear_scalar in self.linears_scalar: norm = self.act(linear_scalar(norm)) - norm = norm.reshape(norm.shape[0], self.units, 3) - norm_I, norm_A, norm_S = norm[..., 0], norm[..., 1], norm[..., 2] + norm = norm.view(-1, self.units, 3) + norm_I, norm_A, norm_S = norm.unbind(dim=-1) # Apply norm to tensors - I_tensor = I_tensor * norm_I.unsqueeze(1).unsqueeze(1) - A = A * norm_A.unsqueeze(1).unsqueeze(1) - S = S * norm_S.unsqueeze(1).unsqueeze(1) + I = self.linears_tensor[0](I) * norm_I.unsqueeze(-2) # noqa: E741 + A = self.linears_tensor[1](A) * norm_A.unsqueeze(-2) + S = self.linears_tensor[2](S) * norm_S.unsqueeze(-2) - X = compose_tensor(I_tensor, A, S) + X = fn_compose_tensor(I, A, S) return X @@ -453,6 +288,12 @@ def forward( edge_index: torch.Tensor, edge_weight: torch.Tensor, edge_attr: torch.Tensor, + row_data: torch.Tensor, + row_indices: torch.Tensor, + row_indptr: torch.Tensor, + col_data: torch.Tensor, + col_indices: torch.Tensor, + col_indptr: torch.Tensor, ) -> torch.Tensor: """Forward pass. @@ -461,97 +302,74 @@ def forward( edge_index: Edge indices, shape (2, num_edges) edge_weight: Edge weights (distances), shape (num_edges,) edge_attr: Edge attributes (RBF), shape (num_edges, num_rbf) + row_data: CSR row data indices for message passing. + row_indices: CSR row indices for message passing. 
+ row_indptr: CSR row pointers for message passing. + col_data: CSC column data indices for message passing. + col_indices: CSC column indices for message passing. + col_indptr: CSC column pointers for message passing. Returns: X: Updated tensor representations, shape (num_nodes, 3, 3, units) """ - num_nodes = X.shape[0] - # Process edge attributes C = cosine_cutoff(edge_weight, self.cutoff) edge_attr_processed = edge_attr for linear_scalar in self.linears_scalar: edge_attr_processed = self.act(linear_scalar(edge_attr_processed)) - edge_attr_processed = (edge_attr_processed * C.view(-1, 1)).reshape( - edge_attr.shape[0], 3, self.units + edge_attr_processed = ( + (edge_attr_processed * C.view(-1, 1)).view(edge_attr.shape[0], self.units, 3).mT.contiguous() ) # (num_edges, 3, units) # Normalize input tensor # For X with shape (num_nodes, 3, 3, units), we need to sum over (-3, -2) # which are the (3, 3) spatial dimensions to get (num_nodes, units) - norm_X = (X**2).sum((-3, -2)) + 1 # (num_nodes, units) - X = X / norm_X.reshape(-1, 1, 1, X.shape[-1]) + norm_X = (X * X).sum((-3, -2)) + 1 # (num_nodes, units) + X = X / norm_X.view(-1, 1, 1, X.shape[-1]) # Decompose input tensor - # X has shape (num_nodes, 3, 3, units) - # decompose_tensor expects (..., 3, 3), so we permute to (num_nodes, units, 3, 3) - # then apply decompose_tensor which works on the last two dimensions (3, 3) - X_permuted = X.permute(0, 3, 1, 2) # (num_nodes, units, 3, 3) - # decompose_tensor works on last two dims, so this will work for each (num_nodes, units) slice - I_permuted, A_permuted, S_permuted = decompose_tensor(X_permuted) # Each: (num_nodes, units, 3, 3) - # Permute back: (num_nodes, units, 3, 3) -> (num_nodes, 3, 3, units) - I_tensor = I_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - A = A_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - S = S_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) + I, A, S = fn_decompose_tensor(X) # noqa: E741 # Apply tensor linear 
transformations - # Reshape to (num_nodes * 9, units), apply linear, reshape back - I_reshaped = I_tensor.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[0](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_tensor.shape) # (num_nodes, 3, 3, units) - - A_reshaped = A.reshape(-1, self.units) # (num_nodes * 9, units) - A_reshaped = self.linears_tensor[1](A_reshaped) # (num_nodes * 9, units) - A = A_reshaped.reshape(A.shape) # (num_nodes, 3, 3, units) + I = self.linears_tensor[0](I) # noqa: E741 + A = self.linears_tensor[1](A) + S = self.linears_tensor[2](S) - S_reshaped = S.reshape(-1, self.units) # (num_nodes * 9, units) - S_reshaped = self.linears_tensor[2](S_reshaped) # (num_nodes * 9, units) - S = S_reshaped.reshape(S.shape) # (num_nodes, 3, 3, units) - Y = compose_tensor(I_tensor, A, S) + # compose back + Y = fn_compose_tensor(I, A, S) # Message passing - Im, Am, Sm = message_passing(I_tensor, A, S, edge_attr_processed, edge_index, num_nodes) - msg = compose_tensor(Im, Am, Sm) + Im, Am, Sm = fn_message_passing( + I, + A, + S, + edge_attr_processed, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + msg = fn_compose_tensor(Im, Am, Sm) # Apply group action if self.equivariance_invariance_group == "O(3)": - C = tensor_matmul_o3(Y, msg) # (num_nodes, 3, 3, units) + C = fn_tensor_matmul_o3_3x3(Y, msg) else: # SO(3) - C = tensor_matmul_so3(Y, msg) # (num_nodes, 3, 3, units) - C = 2 * C - - # decompose_tensor expects (..., 3, 3), so permute to (num_nodes, units, 3, 3) - C_permuted = C.permute(0, 3, 1, 2) # (num_nodes, units, 3, 3) - I_permuted, A_permuted, S_permuted = decompose_tensor(C_permuted) # Each: (num_nodes, units, 3, 3) - # Permute back: (num_nodes, units, 3, 3) -> (num_nodes, 3, 3, units) - I_tensor = I_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - A = A_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - S = S_permuted.permute(0, 2, 3, 1) # 
(num_nodes, 3, 3, units) + C = 2 * fn_tensor_matmul_so3_3x3(Y, msg) + I, A, S = fn_decompose_tensor(C) # noqa: E741 # Normalize - # For compose_tensor(I_tensor, A, S) with shape (num_nodes, 3, 3, units), - # we need to sum over (-3, -2) to get (num_nodes, units) - X_composed = compose_tensor(I_tensor, A, S) # (num_nodes, 3, 3, units) - normp1 = ((X_composed**2).sum((-3, -2)) + 1).reshape(-1, 1, 1, X_composed.shape[-1]) - I_tensor, A, S = I_tensor / normp1, A / normp1, S / normp1 + normp1 = (tensor_norm(C) + 1).unsqueeze(-2) + I, A, S = I / normp1, A / normp1, S / normp1 # noqa: E741 # Final tensor transformations - # Reshape to (num_nodes * 9, units), apply linear, reshape back - I_reshaped = I_tensor.reshape(-1, self.units) # (num_nodes * 9, units) - I_reshaped = self.linears_tensor[3](I_reshaped) # (num_nodes * 9, units) - I_tensor = I_reshaped.reshape(I_tensor.shape) # (num_nodes, 3, 3, units) - - A_reshaped = A.reshape(-1, self.units) # (num_nodes * 9, units) - A_reshaped = self.linears_tensor[4](A_reshaped) # (num_nodes * 9, units) - A = A_reshaped.reshape(A.shape) # (num_nodes, 3, 3, units) - - S_reshaped = S.reshape(-1, self.units) # (num_nodes * 9, units) - S_reshaped = self.linears_tensor[5](S_reshaped) # (num_nodes * 9, units) - S = S_reshaped.reshape(S.shape) # (num_nodes, 3, 3, units) - dX = compose_tensor(I_tensor, A, S) - - # Update - X = X + dX + torch.einsum("nijk,njlk->nilk", dX, dX) + I = self.linears_tensor[3](I) # noqa: E741 + A = self.linears_tensor[4](A) + S = self.linears_tensor[5](S) + dX = fn_compose_tensor(I, A, S) + X = X + dX + fn_tensor_matmul_so3_3x3(dX, dX) return X @@ -751,8 +569,8 @@ def forward( else: # PyG Data object - extract tensors z = getattr(g, "node_type", getattr(g, "z", None)) - pos = g.pos # type: ignore[union-attr] - edge_index = g.edge_index # type: ignore[union-attr] + pos = g.pos # type: ignore[attr-defined] + edge_index = g.edge_index # type: ignore[attr-defined] pbc_offshift = getattr(g, "pbc_offshift", None) batch 
= getattr(g, "batch", None) num_graphs = getattr(g, "num_graphs", None) @@ -760,31 +578,36 @@ def forward( # Obtain graph, with distances and relative position vectors bond_vec, bond_dist = compute_pair_vector_and_distance(pos, edge_index, pbc_offshift) + # perpare graph indices for message passing + row_data, row_indices, row_indptr, col_data, col_indices, col_indptr = graph_transform( + edge_index.int(), + z.shape[0], # type: ignore[union-attr] + ) + # Expand distances with radial basis functions edge_attr = self.bond_expansion(bond_dist) # Embedding layer - X = self.tensor_embedding(z, edge_index, bond_dist, bond_vec, edge_attr) + X = self.tensor_embedding(z, edge_index, bond_dist, bond_vec, edge_attr, col_data, col_indptr) # Interaction layers for layer in self.layers: - X = layer(X, edge_index, bond_dist, edge_attr) - - # decompose_tensor expects (..., 3, 3), so permute to (num_nodes, units, 3, 3) - # X has shape (num_nodes, 3, 3, units) - X_permuted = X.permute(0, 3, 1, 2) # (num_nodes, units, 3, 3) - scalars_permuted, skew_metrices_permuted, traceless_tensors_permuted = decompose_tensor(X_permuted) - # Permute back: (num_nodes, units, 3, 3) -> (num_nodes, 3, 3, units) - scalars = scalars_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - skew_metrices = skew_metrices_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - traceless_tensors = traceless_tensors_permuted.permute(0, 2, 3, 1) # (num_nodes, 3, 3, units) - - # tensor_norm sums over (-2, -1), but for (num_nodes, 3, 3, units) we need to sum over (-3, -2) - # to get (num_nodes, units) - scalars_norm = (scalars**2).sum((-3, -2)) # (num_nodes, units) - skew_norm = (skew_metrices**2).sum((-3, -2)) # (num_nodes, units) - traceless_norm = (traceless_tensors**2).sum((-3, -2)) # (num_nodes, units) - x = torch.cat((scalars_norm, skew_norm, traceless_norm), dim=-1) # (num_nodes, 3 * units) + X = layer( + X, + edge_index, + bond_dist, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + 
col_indices, + col_indptr, + ) + + # compute I, A, S norms + x = fn_tensor_norm3(X) + # normalize x = self.out_norm(x) x = self.linear(x) @@ -807,7 +630,7 @@ def forward( batch_long = batch.to(torch.long) if num_graphs is None: num_graphs = int(batch_long.max().item()) + 1 - return scatter_add(atomic_energies, batch_long, dim_size=num_graphs) + return scatter_add(atomic_energies, batch_long, dim_size=num_graphs) # type: ignore[arg-type] # Single graph case: Sum all energies (equivalent to scatter_add with all nodes in one graph) return torch.sum(atomic_energies, dim=0, keepdim=True).squeeze() diff --git a/src/matgl/ops/__init__.py b/src/matgl/ops/__init__.py new file mode 100644 index 00000000..a2d408aa --- /dev/null +++ b/src/matgl/ops/__init__.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Custom tensor operations with Warp kernel implementations.""" + +from __future__ import annotations + +import warp as wp + +from .compose_tensor import fn_compose_tensor +from .decompose_tensor import fn_decompose_tensor +from .equivariant_o3_matmul import fn_tensor_matmul_o3_3x3 +from .equivariant_so3_matmul import fn_tensor_matmul_so3_3x3 +from .graph_transform import graph_transform +from .tensor_norm3 import fn_tensor_norm3 +from .tensornet_mp import fn_message_passing +from .tensornet_radial_mp import fn_radial_message_passing + +wp.init() + +__all__ = [ + "fn_compose_tensor", + "fn_decompose_tensor", + "fn_message_passing", + "fn_radial_message_passing", + "fn_radial_message_passing", + "fn_tensor_matmul_o3_3x3", + "fn_tensor_matmul_so3_3x3", + "fn_tensor_norm3", + "graph_transform", +] diff --git a/src/matgl/ops/compose_tensor.py b/src/matgl/ops/compose_tensor.py new file mode 100644 index 00000000..78e141b2 --- /dev/null +++ b/src/matgl/ops/compose_tensor.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::compose_tensor_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty((x.shape[0], 3, 3, x.shape[-1]), dtype=x.dtype, device=x.device) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + compose_tensor_fwd = get_module("compose_tensor_fwd", [str(x.dtype)]) + wp.launch( + compose_tensor_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, z_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("tensornet::compose_tensor_fwd_primitive") +def _(x: Tensor, y: Tensor, z: Tensor) -> Tensor: + return torch.empty((z.shape[0], 3, 3, z.shape[-1]), dtype=x.dtype, device=x.device) + + +@torch.library.custom_op( + "tensornet::compose_tensor_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.zeros_like(x) + grad_y = torch.zeros_like(y) + grad_z = torch.zeros_like(z) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_z_wp = wp.from_torch(grad_z.detach(), return_ctype=True) + + compose_tensor_bwd = get_module("compose_tensor_bwd", [str(x.dtype)]) + wp.launch( + compose_tensor_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + 
inputs=(grad_output_wp, grad_x_wp, grad_y_wp, grad_z_wp), + ) + + return [grad_x, grad_y, grad_z] + + +@torch.library.register_fake("tensornet::compose_tensor_bwd_primitive") +def _(grad_output: Tensor, x: Tensor, y: Tensor, z: Tensor) -> list[Tensor]: + return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] + + +@torch.library.custom_op( + "tensornet::compose_tensor_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, +) -> list[Tensor]: + stream = get_stream(grad_output.device) + device = wp.device_from_torch(grad_output.device) + grad_x = torch.zeros_like(grad_grad_x) + grad_y = torch.zeros_like(grad_grad_y) + grad_z = torch.zeros_like(grad_grad_z) + + grad_grad_output = torch.zeros_like(grad_output) + + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_grad_z_wp = wp.from_torch(grad_grad_z.detach(), return_ctype=True) + + compose_tensor_bwd_bwd = get_module("compose_tensor_bwd_bwd", [str(x.dtype)]) + wp.launch( + compose_tensor_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_grad_x_wp, grad_grad_y_wp, grad_grad_z_wp, grad_grad_output_wp), + ) + + return [grad_grad_output, grad_x, grad_y, grad_z] + + +@torch.library.register_fake("tensornet::compose_tensor_bwd_bwd_primitive") +def _( + grad_output: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, +) -> list[Tensor]: + return [ + torch.empty_like(grad_output), + torch.empty_like(x), + torch.empty_like(y), + torch.empty_like(z), + ] + + +def compose_tensor_setup_fwd_context(ctx, inputs, output): + (x, y, z) = inputs + ctx.save_for_backward(x, 
y, z) + + +def compose_tensor_setup_bwd_context(ctx, inputs, output): + (grad_output, x, y, z) = inputs + ctx.save_for_backward(grad_output, x, y, z) + + +@torch.compiler.allow_in_graph +def compose_tensor_fwd(*args): + return torch.ops.tensornet.compose_tensor_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def compose_tensor_bwd(ctx, grad_output): + x, y, z = ctx.saved_tensors + dx, dy, dz = torch.ops.tensornet.compose_tensor_bwd_primitive(grad_output, x, y, z) + return dx, dy, dz + + +@torch.compiler.allow_in_graph +def compose_tensor_bwd_bwd(ctx, *grad_outputs): + grad_grad_x = grad_outputs[0][0] + grad_grad_y = grad_outputs[0][1] + grad_grad_z = grad_outputs[0][2] + + grad_output_saved, x, y, z = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + if grad_grad_y is None: + grad_grad_y = torch.zeros_like(y) + if grad_grad_z is None: + grad_grad_z = torch.zeros_like(z) + + outputs = torch.ops.tensornet.compose_tensor_bwd_bwd_primitive( + grad_output_saved, grad_grad_x, grad_grad_y, grad_grad_z, x, y, z + ) + + return outputs[0], outputs[1], outputs[2], outputs[3] + + +torch.library.register_autograd( + "tensornet::compose_tensor_fwd_primitive", + compose_tensor_bwd, + setup_context=compose_tensor_setup_fwd_context, +) + +torch.library.register_autograd( + "tensornet::compose_tensor_bwd_primitive", + compose_tensor_bwd_bwd, + setup_context=compose_tensor_setup_bwd_context, +) + + +def fn_compose_tensor(x: Tensor, y: Tensor, z: Tensor) -> Tensor: + output = torch.ops.tensornet.compose_tensor_fwd_primitive(x, y, z) + return output diff --git a/src/matgl/ops/decompose_tensor.py b/src/matgl/ops/decompose_tensor.py new file mode 100644 index 00000000..42a92744 --- /dev/null +++ b/src/matgl/ops/decompose_tensor.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::decompose_tensor_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output_i = torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device) + output_a = torch.empty((x.shape[0], 3, x.shape[-1]), dtype=x.dtype, device=x.device) + output_s = torch.empty((x.shape[0], 5, x.shape[-1]), dtype=x.dtype, device=x.device) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + output_i_wp = wp.from_torch(output_i.detach(), return_ctype=True) + output_a_wp = wp.from_torch(output_a.detach(), return_ctype=True) + output_s_wp = wp.from_torch(output_s.detach(), return_ctype=True) + + decompose_tensor_fwd = get_module("decompose_tensor_fwd", [str(x.dtype)]) + wp.launch( + decompose_tensor_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, output_i_wp, output_a_wp, output_s_wp), + ) + + return [output_i, output_a, output_s] + + +@torch.library.register_fake("tensornet::decompose_tensor_fwd_primitive") +def _(x: Tensor) -> list[Tensor]: + return [ + torch.empty((x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device), + torch.empty((x.shape[0], 3, x.shape[-1]), dtype=x.dtype, device=x.device), + torch.empty((x.shape[0], 5, x.shape[-1]), dtype=x.dtype, device=x.device), + ] + + +@torch.library.custom_op( + "tensornet::decompose_tensor_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + + grad_output_i_wp = wp.from_torch(grad_output_i.detach(), return_ctype=True) + grad_output_a_wp = 
wp.from_torch(grad_output_a.detach(), return_ctype=True) + grad_output_s_wp = wp.from_torch(grad_output_s.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + + decompose_tensor_bwd = get_module("decompose_tensor_bwd", [str(x.dtype)]) + wp.launch( + decompose_tensor_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_output_i_wp, grad_output_a_wp, grad_output_s_wp, grad_x_wp), + ) + + return [grad_x] + + +@torch.library.register_fake("tensornet::decompose_tensor_bwd_primitive") +def _(grad_output_i: Tensor, grad_output_a: Tensor, grad_output_s: Tensor, x: Tensor) -> list[Tensor]: + return [torch.empty_like(x)] + + +@torch.library.custom_op( + "tensornet::decompose_tensor_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_i: Tensor, + grad_output_a: Tensor, + grad_output_s: Tensor, + grad_grad_x: Tensor, + x: Tensor, +) -> list[Tensor]: + stream = get_stream(grad_output_i.device) + device = wp.device_from_torch(grad_output_i.device) + grad_x = torch.zeros_like(grad_grad_x) + + grad_grad_output_i = torch.empty_like(grad_output_i) + grad_grad_output_a = torch.empty_like(grad_output_a) + grad_grad_output_s = torch.empty_like(grad_output_s) + + grad_grad_output_i_wp = wp.from_torch(grad_grad_output_i.detach(), return_ctype=True) + grad_grad_output_a_wp = wp.from_torch(grad_grad_output_a.detach(), return_ctype=True) + grad_grad_output_s_wp = wp.from_torch(grad_grad_output_s.detach(), return_ctype=True) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + + decompose_tensor_bwd_bwd = get_module("decompose_tensor_bwd_bwd", [str(x.dtype)]) + wp.launch( + decompose_tensor_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + grad_grad_x_wp, + grad_grad_output_i_wp, + grad_grad_output_a_wp, + grad_grad_output_s_wp, + ), + ) + + return [grad_grad_output_i, grad_grad_output_a, grad_grad_output_s, 
grad_x] + + +@torch.library.register_fake("tensornet::decompose_tensor_bwd_bwd_primitive") +def _( + grad_output_i: Tensor, + grad_output_a: Tensor, + grad_output_s: Tensor, + grad_grad_x: Tensor, + x: Tensor, +) -> list[Tensor]: + return [ + torch.empty_like(grad_output_i), + torch.empty_like(grad_output_a), + torch.empty_like(grad_output_s), + torch.empty_like(grad_grad_x), + ] + + +def decompose_tensor_setup_fwd_context(ctx, inputs, output): + (x,) = inputs # Unpack the single input tensor + ctx.save_for_backward(x) + + +def decompose_tensor_setup_bwd_context(ctx, inputs, output): + (grad_output_i, grad_output_a, grad_output_s, x) = inputs + ctx.save_for_backward(grad_output_i, grad_output_a, grad_output_s, x) + + +@torch.compiler.allow_in_graph +def decompose_tensor_fwd(*args): + return torch.ops.tensornet.decompose_tensor_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def decompose_tensor_bwd(ctx, *grad_outputs): + (x,) = ctx.saved_tensors + grad_output_i, grad_output_a, grad_output_s = grad_outputs[0] + dx = torch.ops.tensornet.decompose_tensor_bwd_primitive(grad_output_i, grad_output_a, grad_output_s, x) + return dx[0] + + +@torch.compiler.allow_in_graph +def decompose_tensor_bwd_bwd(ctx, *grad_outputs): + (grad_grad_x,) = grad_outputs[0] + + grad_output_i, grad_output_a, grad_output_s, x = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + + outputs = torch.ops.tensornet.decompose_tensor_bwd_bwd_primitive( + grad_output_i, grad_output_a, grad_output_s, grad_grad_x, x + ) + + return outputs[0], outputs[1], outputs[2], outputs[3] + + +torch.library.register_autograd( + "tensornet::decompose_tensor_fwd_primitive", + decompose_tensor_bwd, + setup_context=decompose_tensor_setup_fwd_context, +) + +torch.library.register_autograd( + "tensornet::decompose_tensor_bwd_primitive", + decompose_tensor_bwd_bwd, + setup_context=decompose_tensor_setup_bwd_context, +) + + +def fn_decompose_tensor(x: Tensor) -> list[Tensor]: + 
output = torch.ops.tensornet.decompose_tensor_fwd_primitive(x) + return output diff --git a/src/matgl/ops/equivariant_o3_matmul.py b/src/matgl/ops/equivariant_o3_matmul.py new file mode 100644 index 00000000..c042fd10 --- /dev/null +++ b/src/matgl/ops/equivariant_o3_matmul.py @@ -0,0 +1,224 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::tensor_matmul_o3_3x3_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor, y: Tensor) -> Tensor: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty_like(x) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + tensor_matmul_o3_3x3_fwd = get_module("tensor_matmul_o3_3x3_fwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_o3_3x3_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_fwd_primitive") +def _(x: Tensor, y: Tensor) -> Tensor: + return torch.empty_like(x) + + +@torch.library.custom_op( + "tensornet::tensor_matmul_o3_3x3_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + grad_y = torch.empty_like(y) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + grad_x_wp = 
wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + tensor_matmul_o3_3x3_bwd = get_module("tensor_matmul_o3_3x3_bwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_o3_3x3_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, grad_output_wp, grad_x_wp, grad_y_wp), + ) + + return [grad_x, grad_y] + + +@torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_primitive") +def _(grad_output: list[Tensor], x: Tensor, y: Tensor) -> list[Tensor]: + return [torch.empty_like(x), torch.empty_like(y)] + + +@torch.library.custom_op( + "tensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + + stream = get_stream(grad_output.device) + device = wp.device_from_torch(grad_output.device) + grad_x = torch.empty_like(grad_output) + grad_y = torch.empty_like(grad_output) + + grad_grad_output = torch.empty_like(grad_output) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + tensor_matmul_o3_3x3_bwd_bwd = get_module("tensor_matmul_o3_3x3_bwd_bwd", [str(grad_output.dtype)]) + wp.launch( + tensor_matmul_o3_3x3_bwd_bwd, + 
dim=(grad_output.shape[0], grad_output.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_output_wp, + grad_x_wp, + grad_y_wp, + grad_grad_output_wp, + ), + ) + + return [grad_grad_output, grad_x, grad_y] + + +@torch.library.register_fake("tensornet::tensor_matmul_o3_3x3_bwd_bwd_primitive") +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: + return [ + torch.empty_like(grad_output), + torch.empty_like(grad_output), + torch.empty_like(grad_output), + ] + + +def tensor_matmul_o3_3x3_setup_fwd_context(ctx, inputs, output): + (x, y) = inputs + ctx.save_for_backward(x, y) + + +def tensor_matmul_o3_3x3_setup_bwd_context(ctx, inputs, output): + (grad_output, x, y) = inputs + ctx.save_for_backward(grad_output, x, y) + + +@torch.compiler.allow_in_graph +def tensor_matmul_o3_3x3_fwd(*args): + return torch.ops.tensornet.tensor_matmul_o3_3x3_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def tensor_matmul_o3_3x3_bwd(ctx, grad_output): + x, y = ctx.saved_tensors + dx, dy = torch.ops.tensornet.tensor_matmul_o3_3x3_bwd_primitive(grad_output, x, y) + return dx, dy + + +@torch.compiler.allow_in_graph +def tensor_matmul_o3_3x3_bwd_bwd(ctx, *grad_outputs): + grad_grad_x = grad_outputs[0][0] + grad_grad_y = grad_outputs[0][1] + + grad_output_saved, x, y = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + if grad_grad_y is None: + grad_grad_y = torch.zeros_like(y) + + outputs = torch.ops.tensornet.tensor_matmul_o3_3x3_bwd_bwd_primitive( + grad_output_saved, grad_grad_x, grad_grad_y, x, y + ) + return outputs[0], outputs[1], outputs[2] + + +torch.library.register_autograd( + "tensornet::tensor_matmul_o3_3x3_fwd_primitive", + tensor_matmul_o3_3x3_bwd, + setup_context=tensor_matmul_o3_3x3_setup_fwd_context, +) + +torch.library.register_autograd( + "tensornet::tensor_matmul_o3_3x3_bwd_primitive", + 
tensor_matmul_o3_3x3_bwd_bwd, + setup_context=tensor_matmul_o3_3x3_setup_bwd_context, +) + + +def fn_tensor_matmul_o3_3x3(x: Tensor, y: Tensor) -> Tensor: + z = torch.ops.tensornet.tensor_matmul_o3_3x3_fwd_primitive(x, y) + return z diff --git a/src/matgl/ops/equivariant_so3_matmul.py b/src/matgl/ops/equivariant_so3_matmul.py new file mode 100644 index 00000000..7ea61565 --- /dev/null +++ b/src/matgl/ops/equivariant_so3_matmul.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::tensor_matmul_so3_3x3_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor, y: Tensor) -> Tensor: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty_like(x) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + tensor_matmul_so3_3x3_fwd = get_module("tensor_matmul_so3_3x3_fwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_so3_3x3_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_fwd_primitive") +def _(x: Tensor, y: Tensor) -> Tensor: + return torch.empty_like(x) + + +@torch.library.custom_op( + "tensornet::tensor_matmul_so3_3x3_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 
or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + grad_y = torch.empty_like(y) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + tensor_matmul_so3_3x3_bwd = get_module("tensor_matmul_so3_3x3_bwd", [str(x.dtype)]) + wp.launch( + tensor_matmul_so3_3x3_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, y_wp, grad_output_wp, grad_x_wp, grad_y_wp), + ) + + return [grad_x, grad_y] + + +@torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_primitive") +def _(grad_output: list[Tensor], x: Tensor, y: Tensor) -> list[Tensor]: + return [torch.empty_like(x), torch.empty_like(y)] + + +@torch.library.custom_op( + "tensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: + if x.shape[1] != 3 or x.shape[2] != 3 or y.shape[1] != 3 or y.shape[2] != 3: + raise ValueError("x and y must be 3x3 matrices") + if x.ndim != 4 or y.ndim != 4: + raise ValueError("x and y must be 4D tensors") + stream = get_stream(grad_output.device) + device = wp.device_from_torch(grad_output.device) + grad_x = torch.empty_like(grad_output) + grad_y = torch.empty_like(grad_output) + + grad_grad_output = torch.empty_like(grad_output) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_output_wp = 
wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + + tensor_matmul_so3_3x3_bwd_bwd = get_module("tensor_matmul_so3_3x3_bwd_bwd", [str(grad_output.dtype)]) + wp.launch( + tensor_matmul_so3_3x3_bwd_bwd, + dim=(grad_output.shape[0], grad_output.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_output_wp, + grad_x_wp, + grad_y_wp, + grad_grad_output_wp, + ), + ) + + return [grad_grad_output, grad_x, grad_y] + + +@torch.library.register_fake("tensornet::tensor_matmul_so3_3x3_bwd_bwd_primitive") +def _(grad_output: Tensor, grad_grad_x: Tensor, grad_grad_y: Tensor, x: Tensor, y: Tensor) -> list[Tensor]: + return [ + torch.empty_like(grad_output), + torch.empty_like(grad_output), + torch.empty_like(grad_output), + ] + + +def tensor_matmul_so3_3x3_setup_fwd_context(ctx, inputs, output): + (x, y) = inputs + ctx.save_for_backward(x, y) + + +def tensor_matmul_so3_3x3_setup_bwd_context(ctx, inputs, output): + (grad_output, x, y) = inputs + ctx.save_for_backward(grad_output, x, y) + + +@torch.compiler.allow_in_graph +def tensor_matmul_so3_3x3_fwd(*args): + return torch.ops.tensornet.tensor_matmul_so3_3x3_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def tensor_matmul_so3_3x3_bwd(ctx, grad_output): + x, y = ctx.saved_tensors + dx, dy = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_primitive(grad_output, x, y) + return dx, dy + + +@torch.compiler.allow_in_graph +def tensor_matmul_so3_3x3_bwd_bwd(ctx, *grad_outputs): + grad_grad_x = grad_outputs[0][0] + grad_grad_y = grad_outputs[0][1] + + grad_output_saved, x, y = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = 
torch.zeros_like(x) + if grad_grad_y is None: + grad_grad_y = torch.zeros_like(y) + + outputs = torch.ops.tensornet.tensor_matmul_so3_3x3_bwd_bwd_primitive( + grad_output_saved, grad_grad_x, grad_grad_y, x, y + ) + return outputs[0], outputs[1], outputs[2] + + +torch.library.register_autograd( + "tensornet::tensor_matmul_so3_3x3_fwd_primitive", + tensor_matmul_so3_3x3_bwd, + setup_context=tensor_matmul_so3_3x3_setup_fwd_context, +) + +torch.library.register_autograd( + "tensornet::tensor_matmul_so3_3x3_bwd_primitive", + tensor_matmul_so3_3x3_bwd_bwd, + setup_context=tensor_matmul_so3_3x3_setup_bwd_context, +) + + +def fn_tensor_matmul_so3_3x3(x: Tensor, y: Tensor) -> Tensor: + z = torch.ops.tensornet.tensor_matmul_so3_3x3_fwd_primitive(x, y) + return z diff --git a/src/matgl/ops/graph_transform.py b/src/matgl/ops/graph_transform.py new file mode 100644 index 00000000..8e4c481e --- /dev/null +++ b/src/matgl/ops/graph_transform.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import convert_to_sparse, count_row_col, get_stream + + +@torch.library.custom_op( + "nvtnet::count_row_col_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(edge_index: Tensor, num_nodes: int) -> tuple[Tensor, Tensor]: + stream = get_stream(edge_index.device) + device = wp.device_from_torch(edge_index.device) + row_count = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + col_count = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + + edge_index_wp = wp.from_torch(edge_index, return_ctype=True) + row_count_wp = wp.from_torch(row_count, return_ctype=True) + col_count_wp = wp.from_torch(col_count, return_ctype=True) + + wp.launch( + count_row_col, + dim=(edge_index.shape[1]), + stream=stream, + device=device, + inputs=(edge_index_wp, row_count_wp, col_count_wp), + ) + + return row_count, col_count + + +@torch.library.register_fake("nvtnet::count_row_col_primitive") +def _(edge_index: Tensor, num_nodes: int) -> tuple[Tensor, Tensor]: + output = torch.zeros(num_nodes + 1, dtype=torch.int32, 
device=edge_index.device) + output2 = torch.zeros(num_nodes + 1, dtype=torch.int32, device=edge_index.device) + return output, output2 + + +@torch.library.custom_op( + "nvtnet::convert_to_sparse_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + edge_index: Tensor, + row_count: Tensor, + col_count: Tensor, + row_indptr: Tensor, + col_indptr: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + stream = get_stream(edge_index.device) + device = wp.device_from_torch(edge_index.device) + edge_index_wp = wp.from_torch(edge_index, return_ctype=True) + + row_count_wp = wp.from_torch(row_count, return_ctype=True) + col_count_wp = wp.from_torch(col_count, return_ctype=True) + + row_indptr_wp = wp.from_torch(row_indptr, return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr, return_ctype=True) + + row_indices = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + col_indices = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + + row_data = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + col_data = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + + row_indices_wp = wp.from_torch(row_indices, return_ctype=True) + col_indices_wp = wp.from_torch(col_indices, return_ctype=True) + + row_data_wp = wp.from_torch(row_data, return_ctype=True) + col_data_wp = wp.from_torch(col_data, return_ctype=True) + + wp.launch( + convert_to_sparse, + dim=(edge_index.shape[1]), + stream=stream, + device=device, + inputs=( + edge_index_wp, + row_count_wp, + col_count_wp, + row_indptr_wp, + col_indptr_wp, + row_indices_wp, + col_indices_wp, + row_data_wp, + col_data_wp, + ), + ) + + return row_indices, col_indices, row_data, col_data + + +@torch.library.register_fake("nvtnet::convert_to_sparse_primitive") +def _( + edge_index: Tensor, + row_count: Tensor, + col_count: Tensor, + row_indptr: Tensor, + col_indptr: Tensor, +) -> tuple[Tensor, Tensor, 
Tensor, Tensor]: + output = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + output2 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + output3 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + output4 = torch.empty(edge_index.shape[1], dtype=torch.int32, device=edge_index.device) + return output, output2, output3, output4 + + +@torch.compiler.allow_in_graph +def graph_transform(edge_index: Tensor, num_nodes: int) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + row_count, col_count = torch.ops.nvtnet.count_row_col_primitive(edge_index, num_nodes) + row_indptr, col_indptr = ( + torch.cumsum(row_count, dim=0, dtype=torch.int32), + torch.cumsum(col_count, dim=0, dtype=torch.int32), + ) + ( + row_indices, + col_indices, + row_data, + col_data, + ) = torch.ops.nvtnet.convert_to_sparse_primitive(edge_index, row_count, col_count, row_indptr, col_indptr) + return row_data, row_indices, row_indptr, col_data, col_indices, col_indptr diff --git a/src/matgl/ops/tensor_norm3.py b/src/matgl/ops/tensor_norm3.py new file mode 100644 index 00000000..00e0f971 --- /dev/null +++ b/src/matgl/ops/tensor_norm3.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::tensor_norm3_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(x: Tensor) -> Tensor: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output = torch.empty((x.shape[0], 3 * x.shape[-1]), dtype=x.dtype, device=x.device) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + output_wp = wp.from_torch(output.detach(), return_ctype=True) + + tensor_norm3_fwd = get_module("tensor_norm3_fwd", [str(x.dtype)]) + wp.launch( + tensor_norm3_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(x_wp, output_wp), + ) + + return output + + +@torch.library.register_fake("tensornet::tensor_norm3_fwd_primitive") +def _(x: Tensor) -> Tensor: + return torch.empty((x.shape[0], 3 * x.shape[-1]), dtype=x.dtype, device=x.device) + + 
+@torch.library.custom_op( + "tensornet::tensor_norm3_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(grad_output: Tensor, x: Tensor) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + + tensor_norm3_bwd = get_module("tensor_norm3_bwd", [str(x.dtype)]) + wp.launch( + tensor_norm3_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=(grad_output_wp, x_wp, grad_x_wp), + ) + + return [grad_x] + + +@torch.library.register_fake("tensornet::tensor_norm3_bwd_primitive") +def _(grad_output: Tensor, x: Tensor) -> list[Tensor]: + return [torch.empty_like(x)] + + +@torch.library.custom_op( + "tensornet::tensor_norm3_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_grad_x: Tensor, + x: Tensor, + grad_output: Tensor, +) -> list[Tensor]: + stream = get_stream(grad_grad_x.device) + device = wp.device_from_torch(grad_grad_x.device) + grad_grad_output = torch.empty( + (grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), + dtype=grad_grad_x.dtype, + device=grad_grad_x.device, + ) + grad_x = torch.empty_like(x) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + x_wp = wp.from_torch(x.detach(), return_ctype=True) + grad_output_wp = wp.from_torch(grad_output.detach(), return_ctype=True) + grad_grad_output_wp = wp.from_torch(grad_grad_output.detach(), return_ctype=True) + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + + tensor_norm3_bwd_bwd = get_module("tensor_norm3_bwd_bwd", [str(grad_grad_x.dtype)]) + wp.launch( + tensor_norm3_bwd_bwd, + dim=(grad_grad_x.shape[0], grad_grad_x.shape[-1]), + stream=stream, + device=device, + inputs=( + grad_grad_x_wp, + x_wp, + grad_output_wp, + 
grad_grad_output_wp, + grad_x_wp, + ), + ) + + return [grad_grad_output, grad_x] + + +@torch.library.register_fake("tensornet::tensor_norm3_bwd_bwd_primitive") +def _( + grad_grad_x: Tensor, + x: Tensor, + grad_output: Tensor, +) -> list[Tensor]: + return [ + torch.empty( + (grad_grad_x.shape[0], 3 * grad_grad_x.shape[-1]), + dtype=grad_grad_x.dtype, + device=grad_grad_x.device, + ), + torch.empty_like(x), + ] + + +def tensor_norm3_fwd_setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + +def tensor_norm3_bwd_setup_context(ctx, inputs, output): + (grad_output, x) = inputs + ctx.save_for_backward(grad_output, x) + + +@torch.compiler.allow_in_graph +def tensor_norm3_fwd(*args): + """Forward: computes I, A, S norms of 3x3 tensor.""" + return torch.ops.tensornet.tensor_norm3_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def tensor_norm3_bwd(ctx, grad_output): + """Backward: returns grad for x.""" + (x,) = ctx.saved_tensors + return torch.ops.tensornet.tensor_norm3_bwd_primitive(grad_output, x)[0] + + +@torch.compiler.allow_in_graph +def tensor_norm3_bwd_bwd(ctx, *grad_outputs): + """Double backward: returns (grad for grad_output, grad for x).""" + (grad_grad_x,) = grad_outputs[0] + grad_output, x = ctx.saved_tensors + + if grad_grad_x is None: + grad_grad_x = torch.zeros_like(x) + + outputs = torch.ops.tensornet.tensor_norm3_bwd_bwd_primitive(grad_grad_x, x, grad_output) + return outputs[0], outputs[1] + + +torch.library.register_autograd( + "tensornet::tensor_norm3_fwd_primitive", + tensor_norm3_bwd, + setup_context=tensor_norm3_fwd_setup_context, +) + +torch.library.register_autograd( + "tensornet::tensor_norm3_bwd_primitive", + tensor_norm3_bwd_bwd, + setup_context=tensor_norm3_bwd_setup_context, +) + + +def fn_tensor_norm3(x: Tensor) -> Tensor: + return torch.ops.tensornet.tensor_norm3_fwd_primitive(x) diff --git a/src/matgl/ops/tensornet_mp.py b/src/matgl/ops/tensornet_mp.py new file mode 100644 index 00000000..da507102 
--- /dev/null +++ b/src/matgl/ops/tensornet_mp.py @@ -0,0 +1,568 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::message_passing_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + output_x = torch.empty_like(x) + output_y = torch.empty_like(y) + output_z = torch.empty_like(z) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + output_x_wp = wp.from_torch(output_x.detach(), return_ctype=True) + output_y_wp = wp.from_torch(output_y.detach(), return_ctype=True) + output_z_wp = wp.from_torch(output_z.detach(), return_ctype=True) + + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_fwd = get_module("message_passing_fwd", [str(x.dtype)]) + wp.launch( + message_passing_fwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + row_data_wp, + row_indices_wp, + row_indptr_wp, + output_x_wp, + output_y_wp, + output_z_wp, + ), + ) + + return [output_x, output_y, output_z] + + +@torch.library.register_fake("tensornet::message_passing_fwd_primitive") +def _( + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> 
list[Tensor]: + return [torch.empty_like(x), torch.empty_like(y), torch.empty_like(z)] + + +@torch.library.custom_op( + "tensornet::message_passing_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + grad_x = torch.empty_like(x) + grad_y = torch.empty_like(y) + grad_z = torch.empty_like(z) + + grad_edge_attr = torch.zeros_like(edge_attr) + + grad_output_x_wp = wp.from_torch(grad_output_x.detach(), return_ctype=True) + grad_output_y_wp = wp.from_torch(grad_output_y.detach(), return_ctype=True) + grad_output_z_wp = wp.from_torch(grad_output_z.detach(), return_ctype=True) + + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + col_data_wp = wp.from_torch(col_data.detach(), return_ctype=True) + col_indices_wp = wp.from_torch(col_indices.detach(), return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr.detach(), return_ctype=True) + + grad_x_wp = wp.from_torch(grad_x.detach(), return_ctype=True) + grad_y_wp = wp.from_torch(grad_y.detach(), return_ctype=True) + grad_z_wp = wp.from_torch(grad_z.detach(), return_ctype=True) + grad_edge_attr_wp = wp.from_torch(grad_edge_attr.detach(), return_ctype=True) + + message_passing_bwd = get_module("message_passing_bwd", [str(x.dtype)]) + + wp.launch( + message_passing_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + grad_output_x_wp, + grad_output_y_wp, + grad_output_z_wp, + col_data_wp, + 
col_indices_wp, + col_indptr_wp, + grad_x_wp, + grad_y_wp, + grad_z_wp, + grad_edge_attr_wp, + ), + ) + + return [grad_x, grad_y, grad_z, grad_edge_attr] + + +@torch.library.register_fake("tensornet::message_passing_bwd_primitive") +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> list[Tensor]: + return [ + torch.empty_like(x), + torch.empty_like(y), + torch.empty_like(z), + torch.empty_like(edge_attr), + ] + + +@torch.library.custom_op( + "tensornet::message_passing_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + grad_grad_edge_attr: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> list[Tensor]: + stream = get_stream(x.device) + device = wp.device_from_torch(x.device) + + # Convert inputs to warp arrays + x_wp = wp.from_torch(x.detach(), return_ctype=True) + y_wp = wp.from_torch(y.detach(), return_ctype=True) + z_wp = wp.from_torch(z.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + grad_grad_x_wp = wp.from_torch(grad_grad_x.detach(), return_ctype=True) + grad_grad_y_wp = wp.from_torch(grad_grad_y.detach(), return_ctype=True) + grad_grad_z_wp = wp.from_torch(grad_grad_z.detach(), return_ctype=True) + grad_grad_edge_attr_wp = wp.from_torch(grad_grad_edge_attr.detach(), return_ctype=True) + + grad_output_x_wp = wp.from_torch(grad_output_x.detach(), return_ctype=True) + grad_output_y_wp = wp.from_torch(grad_output_y.detach(), return_ctype=True) + 
grad_output_z_wp = wp.from_torch(grad_output_z.detach(), return_ctype=True) + + col_data_wp = wp.from_torch(col_data.detach(), return_ctype=True) + col_indices_wp = wp.from_torch(col_indices.detach(), return_ctype=True) + col_indptr_wp = wp.from_torch(col_indptr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indices_wp = wp.from_torch(row_indices.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + # Allocate output tensors (no zero-init needed with two-kernel approach) + dgrad_x = torch.empty_like(x) + dgrad_y = torch.empty_like(y) + dgrad_z = torch.empty_like(z) + dgrad_edge_attr = torch.empty_like(edge_attr) + dgrad_output_x = torch.empty_like(grad_output_x) + dgrad_output_y = torch.empty_like(grad_output_y) + dgrad_output_z = torch.empty_like(grad_output_z) + + dgrad_x_wp = wp.from_torch(dgrad_x.detach(), return_ctype=True) + dgrad_y_wp = wp.from_torch(dgrad_y.detach(), return_ctype=True) + dgrad_z_wp = wp.from_torch(dgrad_z.detach(), return_ctype=True) + dgrad_edge_attr_wp = wp.from_torch(dgrad_edge_attr.detach(), return_ctype=True) + dgrad_output_x_wp = wp.from_torch(dgrad_output_x.detach(), return_ctype=True) + dgrad_output_y_wp = wp.from_torch(dgrad_output_y.detach(), return_ctype=True) + dgrad_output_z_wp = wp.from_torch(dgrad_output_z.detach(), return_ctype=True) + + # Kernel 1: col-based - computes d2I, d2A, d2S, d2edge_attr + message_passing_edge_bwd_bwd = get_module("message_passing_edge_bwd_bwd", [str(x.dtype)]) + wp.launch( + message_passing_edge_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_grad_z_wp, + grad_grad_edge_attr_wp, + grad_output_x_wp, + grad_output_y_wp, + grad_output_z_wp, + col_data_wp, + col_indices_wp, + col_indptr_wp, + dgrad_x_wp, + dgrad_y_wp, + dgrad_z_wp, + dgrad_edge_attr_wp, + ), + ) + + # Kernel 2: row-based - 
computes d2output_I, d2output_A, d2output_S + message_passing_output_bwd_bwd = get_module("message_passing_output_bwd_bwd", [str(x.dtype)]) + wp.launch( + message_passing_output_bwd_bwd, + dim=(x.shape[0], x.shape[-1]), + stream=stream, + device=device, + inputs=( + x_wp, + y_wp, + z_wp, + edge_attr_wp, + grad_grad_x_wp, + grad_grad_y_wp, + grad_grad_z_wp, + grad_grad_edge_attr_wp, + row_data_wp, + row_indices_wp, + row_indptr_wp, + dgrad_output_x_wp, + dgrad_output_y_wp, + dgrad_output_z_wp, + ), + ) + + return [ + dgrad_output_x, + dgrad_output_y, + dgrad_output_z, + dgrad_x, + dgrad_y, + dgrad_z, + dgrad_edge_attr, + ] + + +@torch.library.register_fake("tensornet::message_passing_bwd_bwd_primitive") +def _( + grad_output_x: Tensor, + grad_output_y: Tensor, + grad_output_z: Tensor, + grad_grad_x: Tensor, + grad_grad_y: Tensor, + grad_grad_z: Tensor, + grad_grad_edge_attr: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> list[Tensor]: + return [ + torch.empty_like(grad_output_x), + torch.empty_like(grad_output_y), + torch.empty_like(grad_output_z), + torch.empty_like(grad_grad_x), + torch.empty_like(grad_grad_y), + torch.empty_like(grad_grad_z), + torch.empty_like(grad_grad_edge_attr), + ] + + +def message_passing_setup_fwd_context(ctx, inputs, output): + ( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = inputs + ctx.save_for_backward( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + +def message_passing_setup_bwd_context(ctx, inputs, output): + ( + grad_output_x, + grad_output_y, + grad_output_z, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = inputs + ctx.save_for_backward( + grad_output_x, + grad_output_y, + 
grad_output_z, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + +@torch.compiler.allow_in_graph +def message_passing_fwd(*args): + return torch.ops.tensornet.message_passing_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def message_passing_bwd(ctx, grad_outputs): + ( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = ctx.saved_tensors + + result = torch.ops.tensornet.message_passing_bwd_primitive( + grad_outputs[0], + grad_outputs[1], + grad_outputs[2], + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + grad_x, grad_y, grad_z, grad_edge_attr = result + + return grad_x, grad_y, grad_z, grad_edge_attr, None, None, None, None, None, None + + +@torch.compiler.allow_in_graph +def message_passing_bwd_bwd(ctx, *grad_outputs): + grad_grad_x, grad_grad_y, grad_grad_z, grad_grad_edge_attr = grad_outputs[0] + + ( + grad_output_x, + grad_output_y, + grad_output_z, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) = ctx.saved_tensors + + result = torch.ops.tensornet.message_passing_bwd_bwd_primitive( + grad_output_x, + grad_output_y, + grad_output_z, + grad_grad_x, + grad_grad_y, + grad_grad_z, + grad_grad_edge_attr, + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) + + return ( + result[0], + result[1], + result[2], + result[3], + result[4], + result[5], + result[6], + None, + None, + None, + None, + None, + None, + ) + + +torch.library.register_autograd( + "tensornet::message_passing_fwd_primitive", + message_passing_bwd, + setup_context=message_passing_setup_fwd_context, +) + +torch.library.register_autograd( + "tensornet::message_passing_bwd_primitive", + message_passing_bwd_bwd, + setup_context=message_passing_setup_bwd_context, +) + + 
+def fn_message_passing( + x: Tensor, + y: Tensor, + z: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indices: Tensor, + row_indptr: Tensor, + col_data: Tensor, + col_indices: Tensor, + col_indptr: Tensor, +) -> list[Tensor]: + return torch.ops.tensornet.message_passing_fwd_primitive( + x, + y, + z, + edge_attr, + row_data, + row_indices, + row_indptr, + col_data, + col_indices, + col_indptr, + ) diff --git a/src/matgl/ops/tensornet_radial_mp.py b/src/matgl/ops/tensornet_radial_mp.py new file mode 100644 index 00000000..8e3fccbe --- /dev/null +++ b/src/matgl/ops/tensornet_radial_mp.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from __future__ import annotations + +import torch +import warp as wp +from torch import Tensor + +from matgl.kernels import get_module, get_stream + + +@torch.library.custom_op( + "tensornet::radial_message_passing_fwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> list[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + stream = get_stream(edge_vec_norm.device) + device = wp.device_from_torch(edge_vec_norm.device) + output_I = torch.zeros( + (num_atoms, 1, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ) + output_A = torch.zeros( + (num_atoms, 3, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ) + output_S = torch.zeros( + (num_atoms, 5, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ) + + output_I_wp = wp.from_torch(output_I.detach(), return_ctype=True) + output_A_wp = wp.from_torch(output_A.detach(), return_ctype=True) + output_S_wp = wp.from_torch(output_S.detach(), return_ctype=True) + + edge_vec_norm_wp = wp.from_torch(edge_vec_norm.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_fwd = 
get_module("radial_message_passing_fwd", [str(edge_vec_norm.dtype)]) + wp.launch( + message_passing_fwd, + dim=(num_atoms, edge_attr.shape[-1]), + stream=stream, + device=device, + inputs=( + edge_vec_norm_wp, + edge_attr_wp, + row_data_wp, + row_indptr_wp, + output_I_wp, + output_A_wp, + output_S_wp, + ), + ) + + return [output_I, output_A, output_S] + + +@torch.library.register_fake("tensornet::radial_message_passing_fwd_primitive") +def _(edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor) -> list[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + return [ + torch.empty( + (num_atoms, 1, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ), + torch.empty( + (num_atoms, 3, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ), + torch.empty( + (num_atoms, 5, edge_attr.shape[-1]), + dtype=edge_vec_norm.dtype, + device=edge_vec_norm.device, + ), + ] + + +@torch.library.custom_op( + "tensornet::radial_message_passing_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> list[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + stream = get_stream(grad_output_I.device) + device = wp.device_from_torch(grad_output_I.device) + + grad_output_I_wp = wp.from_torch(grad_output_I.detach(), return_ctype=True) + grad_output_A_wp = wp.from_torch(grad_output_A.detach(), return_ctype=True) + grad_output_S_wp = wp.from_torch(grad_output_S.detach(), return_ctype=True) + + grad_edge_vec_norm = torch.zeros_like(edge_vec_norm) + grad_edge_vec_norm_wp = wp.from_torch(grad_edge_vec_norm.detach(), return_ctype=True) + + grad_edge_attr = torch.zeros_like(edge_attr) + grad_edge_attr_wp = wp.from_torch(grad_edge_attr.detach(), return_ctype=True) + + edge_vec_norm_wp = wp.from_torch(edge_vec_norm.detach(), 
return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + message_passing_bwd = get_module("radial_message_passing_bwd", [str(edge_vec_norm.dtype)]) + wp.launch( + message_passing_bwd, + dim=(num_atoms, edge_attr.shape[-1]), + stream=stream, + device=device, + inputs=( + edge_vec_norm_wp, + edge_attr_wp, + row_data_wp, + row_indptr_wp, + grad_output_I_wp, + grad_output_A_wp, + grad_output_S_wp, + grad_edge_vec_norm_wp, + grad_edge_attr_wp, + ), + ) + + return [grad_edge_vec_norm, grad_edge_attr] + + +@torch.library.register_fake("tensornet::radial_message_passing_bwd_primitive") +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> list[Tensor]: + return [torch.empty_like(edge_vec_norm), torch.empty_like(edge_attr)] + + +@torch.library.custom_op( + "tensornet::radial_message_passing_bwd_bwd_primitive", + mutates_args=(), + device_types=["cpu", "cuda"], +) +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + grad_grad_edge_vec_norm: Tensor, + grad_grad_edge_attr: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + row_indptr: Tensor, +) -> list[Tensor]: + num_atoms = row_indptr.shape[0] - 1 + stream = get_stream(grad_output_I.device) + device = wp.device_from_torch(grad_output_I.device) + + edge_vec_norm_wp = wp.from_torch(edge_vec_norm.detach(), return_ctype=True) + edge_attr_wp = wp.from_torch(edge_attr.detach(), return_ctype=True) + + row_data_wp = wp.from_torch(row_data.detach(), return_ctype=True) + row_indptr_wp = wp.from_torch(row_indptr.detach(), return_ctype=True) + + grad_grad_edge_vec_norm_wp = wp.from_torch(grad_grad_edge_vec_norm.detach(), return_ctype=True) + grad_grad_edge_attr_wp = 
wp.from_torch(grad_grad_edge_attr.detach(), return_ctype=True) + + grad_output_I_wp = wp.from_torch(grad_output_I.detach(), return_ctype=True) + grad_output_A_wp = wp.from_torch(grad_output_A.detach(), return_ctype=True) + grad_output_S_wp = wp.from_torch(grad_output_S.detach(), return_ctype=True) + dgrad_output_I = torch.zeros_like(grad_output_I) + dgrad_output_A = torch.zeros_like(grad_output_A) + dgrad_output_S = torch.zeros_like(grad_output_S) + dgrad_output_I_wp = wp.from_torch(dgrad_output_I.detach(), return_ctype=True) + dgrad_output_A_wp = wp.from_torch(dgrad_output_A.detach(), return_ctype=True) + dgrad_output_S_wp = wp.from_torch(dgrad_output_S.detach(), return_ctype=True) + + dgrad_grad_edge_vec_norm = torch.zeros_like(grad_grad_edge_vec_norm) + dgrad_grad_edge_vec_norm_wp = wp.from_torch(dgrad_grad_edge_vec_norm.detach(), return_ctype=True) + + dgrad_grad_edge_attr = torch.zeros_like(grad_grad_edge_attr) + dgrad_grad_edge_attr_wp = wp.from_torch(dgrad_grad_edge_attr.detach(), return_ctype=True) + + message_passing_bwd_bwd = get_module("radial_message_passing_bwd_bwd", [str(edge_vec_norm.dtype)]) + wp.launch( + message_passing_bwd_bwd, + dim=(num_atoms, edge_attr.shape[-1]), + stream=stream, + device=device, + inputs=( + edge_vec_norm_wp, + edge_attr_wp, + grad_grad_edge_vec_norm_wp, + grad_grad_edge_attr_wp, + grad_output_I_wp, + grad_output_A_wp, + grad_output_S_wp, + row_data_wp, + row_indptr_wp, + dgrad_grad_edge_vec_norm_wp, + dgrad_grad_edge_attr_wp, + dgrad_output_I_wp, + dgrad_output_A_wp, + dgrad_output_S_wp, + ), + ) + + return [ + dgrad_output_I, + dgrad_output_A, + dgrad_output_S, + dgrad_grad_edge_vec_norm, + dgrad_grad_edge_attr, + ] + + +@torch.library.register_fake("tensornet::radial_message_passing_bwd_bwd_primitive") +def _( + grad_output_I: Tensor, + grad_output_A: Tensor, + grad_output_S: Tensor, + grad_grad_edge_vec_norm: Tensor, + grad_grad_edge_attr: Tensor, + edge_vec_norm: Tensor, + edge_attr: Tensor, + row_data: Tensor, + 
row_indptr: Tensor, +) -> list[Tensor]: + return [ + torch.empty_like(grad_output_I), + torch.empty_like(grad_output_A), + torch.empty_like(grad_output_S), + torch.empty_like(grad_grad_edge_vec_norm), + torch.empty_like(grad_grad_edge_attr), + ] + + +def radial_message_passing_setup_fwd_context(ctx, inputs, output): + (edge_vec_norm, edge_attr, row_data, row_indptr) = inputs + ctx.save_for_backward(edge_vec_norm, edge_attr, row_data, row_indptr) + + +def radial_message_passing_setup_bwd_context(ctx, inputs, output): + ( + grad_output_I, + grad_output_A, + grad_output_S, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) = inputs + ctx.save_for_backward( + grad_output_I, + grad_output_A, + grad_output_S, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) + + +@torch.compiler.allow_in_graph +def radial_message_passing_fwd(*args): + return torch.ops.tensornet.radial_message_passing_fwd_primitive(*args) + + +@torch.compiler.allow_in_graph +def radial_message_passing_bwd(ctx, grad_outputs): + edge_vec_norm, edge_attr, row_data, row_indptr = ctx.saved_tensors + + result = torch.ops.tensornet.radial_message_passing_bwd_primitive( + grad_outputs[0], + grad_outputs[1], + grad_outputs[2], + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) + + grad_edge_vec_norm, grad_edge_attr = result + + return grad_edge_vec_norm, grad_edge_attr, None, None + + +@torch.compiler.allow_in_graph +def radial_message_passing_bwd_bwd(ctx, *grad_outputs): + grad_grad_edge_vec_norm, grad_grad_edge_attr = grad_outputs[0] + ( + grad_output_I, + grad_output_A, + grad_output_S, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) = ctx.saved_tensors + + result = torch.ops.tensornet.radial_message_passing_bwd_bwd_primitive( + grad_output_I, + grad_output_A, + grad_output_S, + grad_grad_edge_vec_norm, + grad_grad_edge_attr, + edge_vec_norm, + edge_attr, + row_data, + row_indptr, + ) + + ( + dgrad_output_I, + dgrad_output_A, + dgrad_output_S, + dgrad_grad_edge_vec_norm, 
+ dgrad_grad_edge_attr, + ) = result + + return ( + dgrad_output_I, + dgrad_output_A, + dgrad_output_S, + dgrad_grad_edge_vec_norm, + dgrad_grad_edge_attr, + None, + None, + ) + + +torch.library.register_autograd( + "tensornet::radial_message_passing_fwd_primitive", + radial_message_passing_bwd, + setup_context=radial_message_passing_setup_fwd_context, +) + +torch.library.register_autograd( + "tensornet::radial_message_passing_bwd_primitive", + radial_message_passing_bwd_bwd, + setup_context=radial_message_passing_setup_bwd_context, +) + + +def fn_radial_message_passing( + edge_vec_norm: Tensor, edge_attr: Tensor, row_data: Tensor, row_indptr: Tensor +) -> list[Tensor]: + return torch.ops.tensornet.radial_message_passing_fwd_primitive(edge_vec_norm, edge_attr, row_data, row_indptr) diff --git a/tests/models/test_tensornet_pyg.py b/tests/models/test_tensornet_pyg.py index a91673ce..2a761422 100644 --- a/tests/models/test_tensornet_pyg.py +++ b/tests/models/test_tensornet_pyg.py @@ -106,3 +106,60 @@ def test_model_intensive_with_classification(self, graph_MoS_pyg): ) output = model(g=graph) assert torch.numel(output) == 1 + + def test_backward(self, graph_MoS_pyg): + """Test cell gradient (dE/dcell).""" + torch.manual_seed(0) + torch.use_deterministic_algorithms(True) + + EXPECTED_CELL_GRAD = torch.tensor( + [ + [-0.000967, 0.000000, 0.000000], + [0.000000, -0.000967, 0.000000], + [0.000000, 0.000000, -0.000967], + ] + ) + + structure, graph, _ = graph_MoS_pyg + cell = torch.tensor(structure.lattice.matrix, dtype=matgl.float_th).requires_grad_(True) + + graph.pbc_offshift = torch.matmul(graph.pbc_offset, cell) + graph.pos = graph.frac_coords @ cell + + model = TensorNet(is_intensive=False, activation_type="swish") + model.train() + + energy = model(g=graph) + cell_grad = torch.autograd.grad(energy, cell, create_graph=True)[0] + + assert torch.allclose(cell_grad, EXPECTED_CELL_GRAD, atol=1e-6) + + def test_double_backward(self, graph_MoS_pyg): + """Test double 
backward: loss = sum(cell_grad^2), compare cell.grad.""" + torch.manual_seed(0) + torch.use_deterministic_algorithms(True) + + EXPECTED_CELL_GRAD2 = torch.tensor( + [ + [-0.000010, -0.000000, -0.000000], + [-0.000000, -0.000010, -0.000000], + [-0.000000, -0.000000, -0.000010], + ] + ) + + structure, graph, _ = graph_MoS_pyg + cell = torch.tensor(structure.lattice.matrix, dtype=matgl.float_th).requires_grad_(True) + cell.retain_grad() + + graph.pbc_offshift = torch.matmul(graph.pbc_offset, cell) + graph.pos = graph.frac_coords @ cell + + model = TensorNet(is_intensive=False, activation_type="swish") + model.train() + + energy = model(g=graph) + cell_grad = torch.autograd.grad(energy, cell, create_graph=True)[0] + loss = (cell_grad * cell_grad).sum() + loss.backward() + + assert torch.allclose(cell.grad, EXPECTED_CELL_GRAD2, atol=1e-6)