diff --git a/openfold3/core/model/latent/template_module.py b/openfold3/core/model/latent/template_module.py index cbd00ef82..6917265da 100644 --- a/openfold3/core/model/latent/template_module.py +++ b/openfold3/core/model/latent/template_module.py @@ -32,6 +32,7 @@ from openfold3.core.model.feature_embedders.template_embedders import ( TemplatePairEmbedderAllAtom, ) +from openfold3.core.model.latent.base_blocks import PairBlock from openfold3.core.model.primitives import LayerNorm, Linear from openfold3.core.utils.checkpointing import checkpoint_blocks, checkpoint_section from openfold3.core.utils.chunk_utils import ( @@ -41,8 +42,6 @@ ) from openfold3.core.utils.tensor_utils import add -from .base_blocks import PairBlock - # TODO: Make arguments match PairBlock class TemplatePairBlock(PairBlock): diff --git a/openfold3/core/model/layers/diffusion_transformer.py b/openfold3/core/model/layers/diffusion_transformer.py index 6577df4e1..59d400d02 100644 --- a/openfold3/core/model/layers/diffusion_transformer.py +++ b/openfold3/core/model/layers/diffusion_transformer.py @@ -22,12 +22,14 @@ from ml_collections import ConfigDict import openfold3.core.config.default_linear_init_config as lin_init +from openfold3.core.model.layers.attention_pair_bias import ( + AttentionPairBias, + CrossAttentionPairBias, +) +from openfold3.core.model.layers.transition import ConditionedTransitionBlock from openfold3.core.model.primitives import LayerNorm from openfold3.core.utils.checkpointing import checkpoint_blocks -from .attention_pair_bias import AttentionPairBias, CrossAttentionPairBias -from .transition import ConditionedTransitionBlock - class DiffusionTransformerBlock(nn.Module): """Diffusion transformer block. diff --git a/openfold3/core/model/primitives/__init__.py b/openfold3/core/model/primitives/__init__.py index dd715f242..87d2d9231 100644 --- a/openfold3/core/model/primitives/__init__.py +++ b/openfold3/core/model/primitives/__init__.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .activations import SwiGLU -from .attention import ( +from openfold3.core.model.primitives.activations import SwiGLU +from openfold3.core.model.primitives.attention import ( DEFAULT_LMA_KV_CHUNK_SIZE, DEFAULT_LMA_Q_CHUNK_SIZE, Attention, GlobalAttention, ) -from .dropout import Dropout, DropoutColumnwise, DropoutRowwise -from .initialization import ( +from openfold3.core.model.primitives.dropout import ( + Dropout, + DropoutColumnwise, + DropoutRowwise, +) +from openfold3.core.model.primitives.initialization import ( final_init_, gating_init_, glorot_uniform_init_, @@ -29,8 +33,8 @@ lecun_normal_init_, trunc_normal_init_, ) -from .linear import Linear -from .normalization import AdaLN, LayerNorm +from openfold3.core.model.primitives.linear import Linear +from openfold3.core.model.primitives.normalization import AdaLN, LayerNorm __all__ = [ "SwiGLU", diff --git a/openfold3/core/model/primitives/activations.py b/openfold3/core/model/primitives/activations.py index 308cc4c21..9ff0f5c80 100644 --- a/openfold3/core/model/primitives/activations.py +++ b/openfold3/core/model/primitives/activations.py @@ -22,8 +22,7 @@ from torch import nn import openfold3.core.config.default_linear_init_config as lin_init - -from .linear import Linear +from openfold3.core.model.primitives.linear import Linear triton_is_installed = importlib.util.find_spec("triton") is not None if triton_is_installed: diff --git a/openfold3/core/model/primitives/attention.py b/openfold3/core/model/primitives/attention.py index e6c74bd7f..ca8979358 100644 --- a/openfold3/core/model/primitives/attention.py +++ b/openfold3/core/model/primitives/attention.py @@ -30,11 +30,10 @@ import openfold3.core.config.default_linear_init_config as lin_init from openfold3.core.kernels.cueq_utils import is_cuequivariance_available +from openfold3.core.model.primitives.linear import Linear from openfold3.core.utils.checkpointing import get_checkpoint_fn from openfold3.core.utils.tensor_utils import flatten_final_dims -from .linear import Linear - warnings.filterwarnings("once") deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None diff --git a/openfold3/core/model/primitives/linear.py b/openfold3/core/model/primitives/linear.py index b52c8c478..66ecc273e 100644 --- a/openfold3/core/model/primitives/linear.py +++ b/openfold3/core/model/primitives/linear.py @@ -21,7 +21,7 @@ import torch import torch.nn as nn -from .initialization import ( +from openfold3.core.model.primitives.initialization import ( final_init_, gating_init_, glorot_uniform_init_, diff --git a/openfold3/core/model/primitives/normalization.py b/openfold3/core/model/primitives/normalization.py index 6680e08c2..935705f08 100644 --- a/openfold3/core/model/primitives/normalization.py +++ b/openfold3/core/model/primitives/normalization.py @@ -22,8 +22,7 @@ from ml_collections import ConfigDict import openfold3.core.config.default_linear_init_config as lin_init - -from .linear import Linear +from openfold3.core.model.primitives.linear import Linear deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None if deepspeed_is_installed: diff --git a/openfold3/core/utils/geometry/rigid_matrix_vector.py b/openfold3/core/utils/geometry/rigid_matrix_vector.py index f7f6afc79..999308537 100644 --- a/openfold3/core/utils/geometry/rigid_matrix_vector.py +++ b/openfold3/core/utils/geometry/rigid_matrix_vector.py @@ -163,7 +163,7 @@ def from_tensor_4x4(cls, array): return cls.from_array(array) @classmethod - def from_array4x4(cls, array: torch.tensor) -> Rigid3Array: + def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array: """Construct Rigid3Array from homogeneous 4x4 array.""" rotation = rotation_matrix.Rot3Array( array[..., 0, 0], diff --git a/openfold3/core/utils/geometry/test_utils.py b/openfold3/core/utils/geometry/test_utils.py index 3948cb9d4..1b5f3fbac 100644 --- a/openfold3/core/utils/geometry/test_utils.py +++ b/openfold3/core/utils/geometry/test_utils.py @@ -25,9 +25,9 @@ def assert_rotation_matrix_equal( matrix1: rotation_matrix.Rot3Array, matrix2: rotation_matrix.Rot3Array ): - for field in dataclasses.fields(rotation_matrix.Rot3Array): - field = field.name - assert torch.equal(getattr(matrix1, field), getattr(matrix2, field)) + for f in dataclasses.fields(rotation_matrix.Rot3Array): + name = f.name + assert torch.equal(getattr(matrix1, name), getattr(matrix2, name)) def assert_rotation_matrix_close( diff --git a/openfold3/core/utils/geometry/vector.py b/openfold3/core/utils/geometry/vector.py index 332c83ab2..d7d731359 100644 --- a/openfold3/core/utils/geometry/vector.py +++ b/openfold3/core/utils/geometry/vector.py @@ -106,11 +106,11 @@ def cross(self, other: Vec3Array) -> Vec3Array: new_z = self.x * other.y - self.y * other.x return Vec3Array(new_x, new_y, new_z) - def dot(self, other: Vec3Array) -> Float: + def dot(self, other: Vec3Array) -> torch.Tensor: """Compute dot product between 'self' and 'other'.""" return self.x * other.x + self.y * other.y + self.z * other.z - def norm(self, epsilon: float = 1e-6) -> Float: + def norm(self, epsilon: float = 1e-6) -> torch.Tensor: """Compute Norm of Vec3Array, clipped to epsilon.""" # To avoid NaN on the backward pass, we must use maximum before the sqrt norm2 = self.dot(self) @@ -180,7 +180,7 @@ def cat(cls, vecs: list[Vec3Array], dim: int) -> Vec3Array: def square_euclidean_distance( vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6 -) -> Float: +) -> torch.Tensor: """Computes square of euclidean distance between 'vec1' and 'vec2'. Args: @@ -200,15 +200,15 @@ def square_euclidean_distance( return distance -def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: +def dot(vector1: Vec3Array, vector2: Vec3Array) -> torch.Tensor: return vector1.dot(vector2) -def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array: return vector1.cross(vector2) -def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> torch.Tensor: return vector.norm(epsilon) @@ -218,7 +218,7 @@ def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: def euclidean_distance( vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6 -) -> Float: +) -> torch.Tensor: """Computes euclidean distance between 'vec1' and 'vec2'. Args: @@ -236,7 +236,9 @@ def euclidean_distance( return distance -def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array) -> Float: +def dihedral_angle( + a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array +) -> torch.Tensor: """Computes torsion angle for a quadruple of points. For points (a, b, c, d), this is the angle between the planes defined by diff --git a/openfold3/tests/utils/__init__.py b/openfold3/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfold3/tests/utils/geometry/__init__.py b/openfold3/tests/utils/geometry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfold3/tests/utils/geometry/helpers.py b/openfold3/tests/utils/geometry/helpers.py new file mode 100644 index 000000000..a3f59be90 --- /dev/null +++ b/openfold3/tests/utils/geometry/helpers.py @@ -0,0 +1,75 @@ +"""Shared helpers for geometry tests.""" + +from __future__ import annotations + +import math + +import torch + +from openfold3.core.utils.geometry.rigid_matrix_vector import Rigid3Array +from openfold3.core.utils.geometry.rotation_matrix import Rot3Array +from openfold3.core.utils.geometry.vector import Vec3Array + +_Translation = tuple[float, float, float] + + +def v(x: float, y: float, z: float) -> Vec3Array: + """Build a scalar Vec3Array from three floats.""" + return Vec3Array( + torch.tensor(x, dtype=torch.float32), + torch.tensor(y, dtype=torch.float32), + torch.tensor(z, dtype=torch.float32), + ) + + +def vb(coords: list[list[float]]) -> Vec3Array: + """Build a batched Vec3Array from a list of [x, y, z] triples.""" + t = torch.tensor(coords, dtype=torch.float32) + return Vec3Array(t[:, 0], t[:, 1], t[:, 2]) + + +def rot_x(theta: float) -> Rot3Array: + """Rotation about the X axis by *theta* radians.""" + c, s = math.cos(theta), math.sin(theta) + return Rot3Array.from_array( + torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, c, -s], + [0.0, s, c], + ] + ) + ) + + +def rot_y(theta: float) -> Rot3Array: + """Rotation about the Y axis by *theta* radians.""" + c, s = math.cos(theta), math.sin(theta) + return Rot3Array.from_array( + torch.tensor( + [ + [c, 0.0, s], + [0.0, 1.0, 0.0], + [-s, 0.0, c], + ] + ) + ) + + +def rot_z(theta: float) -> Rot3Array: + """Rotation about the Z axis by *theta* radians.""" + c, s = math.cos(theta), math.sin(theta) + return Rot3Array.from_array( + torch.tensor( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ] + ) + ) + + +def rigid(rot: Rot3Array, translation: _Translation) -> Rigid3Array: + """Build a Rigid3Array from a rotation and a translation triple.""" + return Rigid3Array(rot, v(*translation)) diff --git a/openfold3/tests/utils/geometry/test_rigid_matrix_vector.py b/openfold3/tests/utils/geometry/test_rigid_matrix_vector.py new file mode 100644 index 000000000..578056dfe --- /dev/null +++ b/openfold3/tests/utils/geometry/test_rigid_matrix_vector.py @@ -0,0 +1,396 @@ +"""Tests for Rigid3Array SE(3) transformation class.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from openfold3.core.utils.geometry.rigid_matrix_vector import Rigid3Array +from openfold3.core.utils.geometry.rotation_matrix import Rot3Array +from openfold3.core.utils.geometry.vector import Vec3Array +from openfold3.tests.utils.geometry.helpers import rigid as _rigid +from openfold3.tests.utils.geometry.helpers import rot_x as _rot_x +from openfold3.tests.utils.geometry.helpers import rot_z as _rot_z +from openfold3.tests.utils.geometry.helpers import v as _v + +# =================================================================== +# Construction & conversion +# =================================================================== + + +class TestConstruction: + def test_identity(self): + rig = Rigid3Array.identity((), device="cpu") + assert torch.allclose(rig.rotation.to_tensor(), torch.eye(3)) + assert torch.allclose(rig.translation.to_tensor(), torch.zeros(3)) + + def test_identity_batched(self): + rig = Rigid3Array.identity((4,), device="cpu") + assert rig.shape == (4,) + + def test_from_array_round_trip(self): + mat = torch.eye(4) + mat[:3, 3] = torch.tensor([1.0, 2.0, 3.0]) + rig = Rigid3Array.from_array(mat) + recovered = rig.to_tensor() + assert torch.allclose(recovered, mat) + + def test_from_array4x4(self): + mat = torch.eye(4) + mat[:3, 3] = torch.tensor([10.0, 20.0, 30.0]) + rig = Rigid3Array.from_array4x4(mat) + assert torch.allclose( + rig.translation.to_tensor(), torch.tensor([10.0, 20.0, 30.0]) + ) + + def test_to_tensor_has_homogeneous_row(self): + rig = _rigid(_rot_z(0.5), (1.0, 2.0, 3.0)) + mat = rig.to_tensor() + assert mat.shape == (4, 4) + assert torch.allclose(mat[3, :], torch.tensor([0.0, 0.0, 0.0, 1.0])) + + +# =================================================================== +# Properties +# =================================================================== + + +class TestProperties: + def test_shape(self): + rig = Rigid3Array.identity((2, 3), device="cpu") + assert rig.shape == (2, 3) + + def test_dtype(self): + rig = Rigid3Array.identity((), device="cpu") + assert rig.dtype == torch.float32 + + def test_device(self): + rig = Rigid3Array.identity((), device="cpu") + assert rig.device == torch.device("cpu") + + +# =================================================================== +# apply_to_point +# =================================================================== + +_APPLY_CASES = [ + pytest.param( + _rigid(Rot3Array.identity((), "cpu"), (1, 0, 0)), + _v(0, 0, 0), + _v(1, 0, 0), + id="pure-translation-x", + ), + pytest.param( + _rigid(Rot3Array.identity((), "cpu"), (0, 5, 0)), + _v(1, 2, 3), + _v(1, 7, 3), + id="translate-y-by-5", + ), + pytest.param( + _rigid(_rot_z(math.pi / 2), (0, 0, 0)), + _v(1, 0, 0), + _v(0, 1, 0), + id="pure-90z-rotation", + ), + pytest.param( + _rigid(_rot_z(math.pi / 2), (10, 0, 0)), + _v(1, 0, 0), + _v(10, 1, 0), + id="rotate-90z-then-translate", + ), +] + + +class TestApplyToPoint: + @pytest.mark.parametrize("rig,point,expected", _APPLY_CASES) + def test_known_transforms(self, rig, point, expected): + result = rig.apply_to_point(point) + assert torch.allclose(result.to_tensor(), expected.to_tensor(), atol=1e-5) + + def test_apply_tensor_interface(self): + """apply() accepts a raw [..., 3] tensor and returns one.""" + rig = _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)) + point = torch.tensor([0.0, 0.0, 0.0]) + result = rig.apply(point) + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + def test_identity_is_noop(self): + rig = Rigid3Array.identity((), device="cpu") + p = _v(5, 10, 15) + result = rig.apply_to_point(p) + assert torch.allclose(result.to_tensor(), p.to_tensor()) + + +# =================================================================== +# apply_inverse_to_point +# =================================================================== + + +class TestApplyInverse: + @pytest.mark.parametrize( + "rig,point", + [ + pytest.param( + _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)), + _v(5, 6, 7), + id="pure-translation", + ), + pytest.param( + _rigid(_rot_z(math.pi / 4), (0, 0, 0)), + _v(1, 0, 0), + id="pure-rotation-45z", + ), + pytest.param( + _rigid(_rot_x(math.pi / 3), (10, 20, 30)), + _v(1, 2, 3), + id="rotation-and-translation", + ), + ], + ) + def test_apply_then_inverse_recovers_point(self, rig, point): + transformed = rig.apply_to_point(point) + recovered = rig.apply_inverse_to_point(transformed) + assert torch.allclose(recovered.to_tensor(), point.to_tensor(), atol=1e-5) + + def test_invert_apply_tensor_interface(self): + """invert_apply() accepts a raw [..., 3] tensor and returns one.""" + rig = _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)) + point = torch.tensor([1.0, 2.0, 3.0]) + result = rig.invert_apply(point) + assert torch.allclose(result, torch.tensor([0.0, 0.0, 0.0]), atol=1e-5) + + +# =================================================================== +# Inverse +# =================================================================== + + +class TestInverse: + @pytest.mark.parametrize( + "rig", + [ + pytest.param( + Rigid3Array.identity((), device="cpu"), + id="identity", + ), + pytest.param( + _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)), + id="pure-translation", + ), + pytest.param( + _rigid(_rot_z(math.pi / 2), (0, 0, 0)), + id="pure-rotation", + ), + pytest.param( + _rigid(_rot_z(0.7), (5, -3, 8)), + id="general-transform", + ), + ], + ) + def test_T_Tinv_eq_identity(self, rig): + composed = rig @ rig.inverse() + # Rotation should be identity + assert torch.allclose(composed.rotation.to_tensor(), torch.eye(3), atol=1e-5) + # Translation should be zero + assert torch.allclose( + composed.translation.to_tensor(), torch.zeros(3), atol=1e-5 + ) + + def test_inverse_of_pure_translation(self): + rig = _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)) + inv = rig.inverse() + assert torch.allclose( + inv.translation.to_tensor(), torch.tensor([-1.0, -2.0, -3.0]) + ) + + +# =================================================================== +# Composition (matmul) +# =================================================================== + +_COMPOSE_CASES = [ + pytest.param( + _rigid(Rot3Array.identity((), "cpu"), (1, 0, 0)), + _rigid(Rot3Array.identity((), "cpu"), (0, 2, 0)), + _v(0, 0, 0), + _v(1, 2, 0), + id="two-translations-add", + ), + pytest.param( + _rigid(_rot_z(math.pi / 2), (0, 0, 0)), + _rigid(Rot3Array.identity((), "cpu"), (1, 0, 0)), + _v(0, 0, 0), + # First translate (1,0,0), then rotate 90-Z -> (0,1,0) + _v(0, 1, 0), + id="rotate-after-translate", + ), + pytest.param( + _rigid(Rot3Array.identity((), "cpu"), (1, 0, 0)), + _rigid(_rot_z(math.pi / 2), (0, 0, 0)), + _v(1, 0, 0), + # First rotate (1,0,0) -> (0,1,0), then translate by (1,0,0) -> (1,1,0) + _v(1, 1, 0), + id="translate-after-rotate", + ), +] + + +class TestComposition: + @pytest.mark.parametrize("t1,t2,point,expected", _COMPOSE_CASES) + def test_compose_apply(self, t1, t2, point, expected): + composed = t1 @ t2 + result = composed.apply_to_point(point) + assert torch.allclose(result.to_tensor(), expected.to_tensor(), atol=1e-5) + + def test_compose_method_matches_matmul(self): + a = _rigid(_rot_z(0.5), (1, 2, 3)) + b = _rigid(_rot_x(0.3), (4, 5, 6)) + via_matmul = a @ b + via_method = a.compose(b) + assert torch.allclose(via_matmul.to_tensor(), via_method.to_tensor(), atol=1e-6) + + def test_identity_is_neutral(self): + rig = _rigid(_rot_z(1.0), (5, 10, 15)) + eye = Rigid3Array.identity((), device="cpu") + assert torch.allclose((eye @ rig).to_tensor(), rig.to_tensor(), atol=1e-6) + assert torch.allclose((rig @ eye).to_tensor(), rig.to_tensor(), atol=1e-6) + + def test_compose_rotation(self): + rig = _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)) + rot = _rot_z(math.pi / 2) + result = rig.compose_rotation(rot) + # Translation is unchanged (cloned) + assert torch.allclose( + result.translation.to_tensor(), torch.tensor([1.0, 2.0, 3.0]) + ) + # Rotation is now the composed rotation + assert torch.allclose(result.rotation.to_tensor(), rot.to_tensor(), atol=1e-6) + + +# =================================================================== +# Scalar multiply & scale_translation +# =================================================================== + + +class TestScaling: + def test_scalar_mul(self): + rig = _rigid(Rot3Array.identity((), "cpu"), (1, 2, 3)) + scaled = rig * torch.tensor(2.0) + # Both rotation entries and translation scale + assert torch.allclose( + scaled.translation.to_tensor(), torch.tensor([2.0, 4.0, 6.0]) + ) + + def test_scale_translation(self): + rig = _rigid(_rot_z(math.pi / 4), (2, 4, 6)) + scaled = rig.scale_translation(0.5) + # Rotation should be unchanged + assert torch.allclose(scaled.rotation.to_tensor(), rig.rotation.to_tensor()) + # Translation should be halved + assert torch.allclose( + scaled.translation.to_tensor(), torch.tensor([1.0, 2.0, 3.0]) + ) + + +# =================================================================== +# Indexing, reshape, unsqueeze, cat +# =================================================================== + + +class TestIndexing: + def test_getitem(self): + a = _rigid(Rot3Array.identity((), "cpu"), (1, 0, 0)) + b = _rigid(Rot3Array.identity((), "cpu"), (0, 2, 0)) + batch = Rigid3Array.cat( + [ + a.map_tensor_fn(lambda t: t.unsqueeze(0)), + b.map_tensor_fn(lambda t: t.unsqueeze(0)), + ], + dim=0, + ) + assert batch.shape == (2,) + first = batch[0] + assert torch.allclose( + first.translation.to_tensor(), torch.tensor([1.0, 0.0, 0.0]) + ) + + def test_unsqueeze(self): + rig = Rigid3Array.identity((3,), device="cpu") + u = rig.unsqueeze(0) + assert u.shape == (1, 3) + + def test_reshape(self): + rig = Rigid3Array.identity((2, 3), device="cpu") + reshaped = rig.reshape((6,)) + assert reshaped.shape == (6,) + + def test_cat(self): + a = Rigid3Array.identity((2,), device="cpu") + b = Rigid3Array.identity((3,), device="cpu") + c = Rigid3Array.cat([a, b], dim=0) + assert c.shape == (5,) + + +# =================================================================== +# Gradient control +# =================================================================== + + +class TestStopRotGradient: + def test_rotation_detached_translation_kept(self): + rot_param = torch.tensor(1.0, requires_grad=True) + rot = Rot3Array( + rot_param, + rot_param, + rot_param, + rot_param, + rot_param, + rot_param, + rot_param, + rot_param, + rot_param, + ) + trans_param = torch.tensor(2.0, requires_grad=True) + trans = Vec3Array(trans_param, trans_param, trans_param) + + rig = Rigid3Array(rot, trans) + stopped = rig.stop_rot_gradient() + assert not stopped.rotation.xx.requires_grad + assert stopped.translation.x.requires_grad + + +# =================================================================== +# Batched apply +# =================================================================== + + +class TestBatchedApply: + def test_batch_of_transforms_on_batch_of_points(self): + """Two different transforms applied element-wise to two points.""" + # Transform 0: pure translation by (10, 0, 0) + # Transform 1: 90-deg Z rotation, no translation + r0 = Rot3Array.identity((), "cpu") + r1 = _rot_z(math.pi / 2) + rot = Rot3Array.from_array(torch.stack([r0.to_tensor(), r1.to_tensor()])) + trans = Vec3Array( + torch.tensor([10.0, 0.0]), + torch.tensor([0.0, 0.0]), + torch.tensor([0.0, 0.0]), + ) + rig = Rigid3Array(rot, trans) + + points = Vec3Array( + torch.tensor([1.0, 1.0]), + torch.tensor([0.0, 0.0]), + torch.tensor([0.0, 0.0]), + ) + + result = rig.apply_to_point(points) + # Point 0: identity rot + translate (10,0,0) -> (11, 0, 0) + assert torch.allclose(result.x[0], torch.tensor(11.0), atol=1e-5) + assert torch.allclose(result.y[0], torch.tensor(0.0), atol=1e-5) + # Point 1: 90-Z rotation on (1,0,0) -> (0,1,0), no translation + assert torch.allclose(result.x[1], torch.tensor(0.0), atol=1e-5) + assert torch.allclose(result.y[1], torch.tensor(1.0), atol=1e-5) diff --git a/openfold3/tests/utils/geometry/test_rotation_matrix.py b/openfold3/tests/utils/geometry/test_rotation_matrix.py new file mode 100644 index 000000000..11706caa3 --- /dev/null +++ b/openfold3/tests/utils/geometry/test_rotation_matrix.py @@ -0,0 +1,350 @@ +"""Tests for Rot3Array rotation matrix class.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from openfold3.core.utils.geometry.rotation_matrix import Rot3Array +from openfold3.core.utils.geometry.vector import Vec3Array +from openfold3.tests.utils.geometry.helpers import rot_x as _rot_x +from openfold3.tests.utils.geometry.helpers import rot_y as _rot_y +from openfold3.tests.utils.geometry.helpers import rot_z as _rot_z +from openfold3.tests.utils.geometry.helpers import v as _v + +# =================================================================== +# Construction & conversion +# =================================================================== + + +class TestConstruction: + def test_identity_is_eye(self): + eye = Rot3Array.identity((1,), device="cpu") + expected = torch.eye(3).unsqueeze(0) + assert torch.allclose(eye.to_tensor(), expected) + + def test_from_array_round_trip(self): + mat = torch.tensor( + [ + [0.0, -1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + rot = Rot3Array.from_array(mat) + assert torch.allclose(rot.to_tensor(), mat) + + def test_from_array_batched(self): + batch = torch.stack([torch.eye(3), torch.eye(3)]) + rot = Rot3Array.from_array(batch) + assert rot.xx.shape == (2,) + + def test_from_quaternion_identity(self): + # Quaternion (1, 0, 0, 0) -> identity rotation + rot = Rot3Array.from_quaternion( + w=torch.tensor(1.0), + x=torch.tensor(0.0), + y=torch.tensor(0.0), + z=torch.tensor(0.0), + ) + expected = torch.eye(3) + assert torch.allclose(rot.to_tensor(), expected, atol=1e-6) + + @pytest.mark.parametrize( + "w,x,y,z", + [ + pytest.param(0.0, 1.0, 0.0, 0.0, id="180-about-x"), + pytest.param(0.0, 0.0, 1.0, 0.0, id="180-about-y"), + pytest.param(0.0, 0.0, 0.0, 1.0, id="180-about-z"), + ], + ) + def test_from_quaternion_180(self, w, x, y, z): + rot = Rot3Array.from_quaternion( + w=torch.tensor(w), + x=torch.tensor(x), + y=torch.tensor(y), + z=torch.tensor(z), + ) + # R @ R should give identity for 180-degree rotations + composed = rot @ rot + assert torch.allclose(composed.to_tensor(), torch.eye(3), atol=1e-6) + + def test_from_quaternion_90_about_z(self): + # 90 deg about Z: w=cos(45)=sqrt(2)/2, z=sin(45)=sqrt(2)/2 + s = math.sqrt(2) / 2 + rot = Rot3Array.from_quaternion( + w=torch.tensor(s), + x=torch.tensor(0.0), + y=torch.tensor(0.0), + z=torch.tensor(s), + ) + expected = _rot_z(math.pi / 2) + assert torch.allclose(rot.to_tensor(), expected.to_tensor(), atol=1e-6) + + +# =================================================================== +# from_two_vectors +# =================================================================== + +_TWO_VEC_CASES = [ + pytest.param( + _v(1, 0, 0), + _v(0, 1, 0), + id="standard-xy", + ), + pytest.param( + _v(0, 0, 1), + _v(0, 1, 0), + id="z-and-y", + ), + pytest.param( + _v(3, 0, 0), + _v(1, 2, 0), + id="scaled-e0-tilted-e1", + ), +] + + +class TestFromTwoVectors: + @pytest.mark.parametrize("e0,e1", _TWO_VEC_CASES) + def test_result_is_orthogonal(self, e0, e1): + """R^T R should equal identity.""" + rot = Rot3Array.from_two_vectors(e0, e1) + rtr = (rot.inverse() @ rot).to_tensor() + assert torch.allclose(rtr, torch.eye(3), atol=1e-5) + + @pytest.mark.parametrize("e0,e1", _TWO_VEC_CASES) + def test_det_is_one(self, e0, e1): + """Proper rotation has determinant +1.""" + rot = Rot3Array.from_two_vectors(e0, e1) + det = torch.det(rot.to_tensor()) + assert torch.allclose(det, torch.tensor(1.0), atol=1e-5) + + def test_e0_maps_to_x_axis(self): + e0 = _v(0, 0, 5) + e1 = _v(0, 3, 0) + rot = Rot3Array.from_two_vectors(e0, e1) + # The constructed frame's first column is the normalized e0 direction, + # so applying the *inverse* frame to e0 should give the +x axis. + result = rot.apply_inverse_to_point(e0.normalized(epsilon=0)) + assert torch.allclose(result.to_tensor(), _v(1, 0, 0).to_tensor(), atol=1e-5) + + +# =================================================================== +# Inverse +# =================================================================== + + +class TestInverse: + @pytest.mark.parametrize( + "rot", + [ + pytest.param(Rot3Array.identity((), device="cpu"), id="identity"), + pytest.param(_rot_z(math.pi / 4), id="45-deg-z"), + pytest.param(_rot_x(math.pi / 3), id="60-deg-x"), + ], + ) + def test_inverse_is_transpose(self, rot): + inv = rot.inverse() + assert torch.allclose(inv.to_tensor(), rot.to_tensor().T, atol=1e-6) + + @pytest.mark.parametrize( + "rot", + [ + pytest.param(Rot3Array.identity((), device="cpu"), id="identity"), + pytest.param(_rot_z(math.pi / 2), id="90-deg-z"), + pytest.param(_rot_y(1.0), id="1-rad-y"), + ], + ) + def test_R_Rinv_eq_identity(self, rot): + composed = rot @ rot.inverse() + assert torch.allclose(composed.to_tensor(), torch.eye(3), atol=1e-6) + + +# =================================================================== +# apply_to_point +# =================================================================== + +_APPLY_CASES = [ + pytest.param( + _rot_z(math.pi / 2), + _v(1, 0, 0), + _v(0, 1, 0), + id="90z-rotates-x-to-y", + ), + pytest.param( + _rot_x(math.pi / 2), + _v(0, 1, 0), + _v(0, 0, 1), + id="90x-rotates-y-to-z", + ), + pytest.param( + _rot_y(math.pi / 2), + _v(0, 0, 1), + _v(1, 0, 0), + id="90y-rotates-z-to-x", + ), + pytest.param( + _rot_z(math.pi), + _v(1, 0, 0), + _v(-1, 0, 0), + id="180z-flips-x", + ), +] + + +class TestApplyToPoint: + @pytest.mark.parametrize("rot,point,expected", _APPLY_CASES) + def test_known_rotations(self, rot, point, expected): + result = rot.apply_to_point(point) + assert torch.allclose(result.to_tensor(), expected.to_tensor(), atol=1e-5) + + def test_identity_is_noop(self): + p = _v(1, 2, 3) + eye = Rot3Array.identity((), device="cpu") + result = eye.apply_to_point(p) + assert torch.allclose(result.to_tensor(), p.to_tensor()) + + def test_apply_inverse_undoes_apply(self): + rot = _rot_z(math.pi / 6) # 30 degrees about Z + p = _v(1, 2, 3) + transformed = rot.apply_to_point(p) + recovered = rot.apply_inverse_to_point(transformed) + assert torch.allclose(recovered.to_tensor(), p.to_tensor(), atol=1e-5) + + +# =================================================================== +# Composition (matmul) +# =================================================================== + + +class TestComposition: + def test_two_90z_eq_180z(self): + r90 = _rot_z(math.pi / 2) + r180 = r90 @ r90 + expected = _rot_z(math.pi) + assert torch.allclose(r180.to_tensor(), expected.to_tensor(), atol=1e-5) + + def test_xyz_composition(self): + """Composing rotations about different axes.""" + rx = _rot_x(math.pi / 2) + ry = _rot_y(math.pi / 2) + composed = rx @ ry + # Verify by applying to a test point + p = _v(1, 0, 0) + result_composed = composed.apply_to_point(p) + result_sequential = rx.apply_to_point(ry.apply_to_point(p)) + assert torch.allclose( + result_composed.to_tensor(), result_sequential.to_tensor(), atol=1e-5 + ) + + def test_identity_is_neutral(self): + r = _rot_z(0.7) + eye = Rot3Array.identity((), device="cpu") + assert torch.allclose((eye @ r).to_tensor(), r.to_tensor(), atol=1e-6) + assert torch.allclose((r @ eye).to_tensor(), r.to_tensor(), atol=1e-6) + + +# =================================================================== +# Scalar multiply +# =================================================================== + + +class TestScalarMultiply: + def test_mul_by_one_is_noop(self): + r = _rot_z(math.pi / 4) + result = r * torch.tensor(1.0) + assert torch.allclose(result.to_tensor(), r.to_tensor()) + + def test_mul_by_zero(self): + r = _rot_z(math.pi / 4) + result = r * torch.tensor(0.0) + assert torch.allclose(result.to_tensor(), torch.zeros(3, 3)) + + def test_mul_by_pi(self): + r = _rot_z(math.pi / 4) + result = r * torch.tensor(math.pi) + s = math.pi / math.sqrt(2) + expected_tensor = torch.tensor( + [ + [s, -s, 0.0], + [s, s, 0.0], + [0.0, 0.0, math.pi], + ] + ) + assert torch.allclose(result.to_tensor(), expected_tensor, atol=1e-6) + + +# =================================================================== +# Batched operations +# =================================================================== + + +class TestBatched: + def test_identity_batch(self): + rot = Rot3Array.identity((3,), device="cpu") + assert rot.xx.shape == (3,) + assert torch.allclose(rot.to_tensor(), torch.eye(3).expand(3, 3, 3)) + + def test_indexing(self): + r0 = _rot_z(0.0) + r1 = _rot_z(math.pi / 2) + batch = Rot3Array.cat( + [ + r0.map_tensor_fn(lambda t: t.unsqueeze(0)), + r1.map_tensor_fn(lambda t: t.unsqueeze(0)), + ], + dim=0, + ) + assert batch.xx.shape == (2,) + first = batch[0] + assert torch.allclose(first.to_tensor(), r0.to_tensor(), atol=1e-6) + + def test_apply_to_point_batched(self): + """Batch of two rotations applied to a batch of two points.""" + batch = Rot3Array.from_array( + torch.stack( + [ + _rot_z(0.0).to_tensor(), + _rot_z(math.pi / 2).to_tensor(), + ] + ) + ) + points = Vec3Array( + torch.tensor([1.0, 1.0]), + torch.tensor([0.0, 0.0]), + torch.tensor([0.0, 0.0]), + ) + result = batch.apply_to_point(points) + # First rotation is identity -> (1,0,0) unchanged + assert torch.allclose(result.x[0], torch.tensor(1.0), atol=1e-5) + assert torch.allclose(result.y[0], torch.tensor(0.0), atol=1e-5) + # Second rotation is 90-deg Z -> (1,0,0) becomes (0,1,0) + assert torch.allclose(result.x[1], torch.tensor(0.0), atol=1e-5) + assert torch.allclose(result.y[1], torch.tensor(1.0), atol=1e-5) + + def test_reshape(self): + rot = Rot3Array.identity((2, 3), device="cpu") + reshaped = rot.reshape((6,)) + assert reshaped.xx.shape == (6,) + + def test_cat(self): + a = Rot3Array.identity((2,), device="cpu") + b = Rot3Array.identity((3,), device="cpu") + c = Rot3Array.cat([a, b], dim=0) + assert c.xx.shape == (5,) + + +# =================================================================== +# Gradient control +# =================================================================== + + +class TestStopGradient: + def test_stop_gradient_detaches(self): + t = torch.tensor(1.0, requires_grad=True) + rot = Rot3Array(t, t, t, t, t, t, t, t, t) + stopped = rot.stop_gradient() + assert not stopped.xx.requires_grad diff --git a/openfold3/tests/utils/geometry/test_vector.py b/openfold3/tests/utils/geometry/test_vector.py new file mode 100644 index 000000000..09743542b --- /dev/null +++ b/openfold3/tests/utils/geometry/test_vector.py @@ -0,0 +1,372 @@ +"""Tests for Vec3Array and geometry free functions.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from openfold3.core.utils.geometry.vector import ( + Vec3Array, + cross, + dihedral_angle, + dot, + euclidean_distance, + norm, + normalized, + square_euclidean_distance, +) +from openfold3.tests.utils.geometry.helpers import v as _v +from openfold3.tests.utils.geometry.helpers import vb as _vb + +# =================================================================== +# Construction & round-trip +# =================================================================== + + +class TestConstruction: + def test_from_array_round_trip(self): + tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + v = Vec3Array.from_array(tensor) + assert torch.equal(v.x, torch.tensor([1.0, 4.0])) + assert torch.equal(v.y, torch.tensor([2.0, 5.0])) + assert torch.equal(v.z, torch.tensor([3.0, 6.0])) + assert torch.equal(v.to_tensor(), tensor) + + def test_zeros(self): + v = Vec3Array.zeros((2, 3)) + assert v.shape == (2, 3) + assert torch.equal(v.x, torch.zeros(2, 3)) + + def test_shape_property(self): + v = _vb([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert v.shape == (3,) + + def test_cat(self): + a = _vb([[1, 0, 0], [0, 1, 0]]) + b = _vb([[0, 0, 1]]) + c = Vec3Array.cat([a, b], dim=0) + assert c.shape == (3,) + assert torch.allclose(c.z, torch.tensor([0.0, 0.0, 1.0])) + + +# =================================================================== +# Arithmetic operators +# =================================================================== + +_ARITH_CASES = [ + pytest.param( + _v(1, 2, 3), + _v(4, 5, 6), + _v(5, 7, 9), + _v(-3, -3, -3), + id="simple-integers", + ), + pytest.param( + _v(0, 0, 0), + _v(1, 2, 3), + _v(1, 2, 3), + _v(-1, -2, -3), + id="zero-vector", + ), +] + + +class TestArithmetic: + @pytest.mark.parametrize("a,b,expected_sum,expected_diff", _ARITH_CASES) + def test_add(self, a, b, expected_sum, expected_diff): + result = a + b + assert torch.allclose(result.to_tensor(), expected_sum.to_tensor()) + + @pytest.mark.parametrize("a,b,expected_sum,expected_diff", _ARITH_CASES) + def test_sub(self, a, b, expected_sum, expected_diff): + result = a - b + assert torch.allclose(result.to_tensor(), expected_diff.to_tensor()) + + def test_scalar_mul(self): + v = _v(1, 2, 3) + result = v * 2.0 + assert torch.allclose(result.to_tensor(), _v(2, 4, 6).to_tensor()) + + def test_rmul(self): + v = _v(1, 2, 3) + result = 3.0 * v + assert torch.allclose(result.to_tensor(), _v(3, 6, 9).to_tensor()) + + def test_truediv(self): + v = _v(2, 4, 6) + result = v / 2.0 + assert torch.allclose(result.to_tensor(), _v(1, 2, 3).to_tensor()) + + def test_neg(self): + v = _v(1, -2, 3) + result = -v + assert torch.allclose(result.to_tensor(), _v(-1, 2, -3).to_tensor()) + + def test_pos(self): + v = _v(1, -2, 3) + result = +v + assert torch.allclose(result.to_tensor(), v.to_tensor()) + + +# =================================================================== +# Indexing, iteration, reshape +# =================================================================== + + +class TestIndexing: + def test_getitem_single(self): + v = _vb([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + second = v[1] + assert torch.allclose(second.to_tensor(), torch.tensor([0.0, 1.0, 0.0])) + + def test_getitem_slice(self): + v = _vb([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + sliced = v[:2] + assert sliced.shape == (2,) + + def test_iter_yields_xyz_tensors(self): + v = _v(1, 2, 3) + x, y, z = v + assert torch.equal(x, torch.tensor(1.0)) + assert torch.equal(y, torch.tensor(2.0)) + assert torch.equal(z, torch.tensor(3.0)) + + def test_reshape(self): + v = _vb([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]) + reshaped = v.reshape((2, 2)) + assert reshaped.shape == (2, 2) + + def test_unsqueeze(self): + v = _vb([[1, 0, 0], [0, 1, 0]]) + u = v.unsqueeze(0) + assert u.shape == (1, 2) + + def test_sum(self): + v = _vb([[1, 2, 3], [4, 5, 6]]) + s = v.sum(dim=0) + assert torch.allclose(s.to_tensor(), torch.tensor([5.0, 7.0, 9.0])) + + def test_clone(self): + v = _v(1, 2, 3) + c = v.clone() + assert torch.equal(v.to_tensor(), c.to_tensor()) + # clone should produce independent storage + assert v.x.data_ptr() != c.x.data_ptr() + + +# =================================================================== +# Dot product +# =================================================================== + +_DOT_CASES = [ + pytest.param(_v(1, 0, 0), _v(1, 0, 0), 1.0, id="parallel-unit"), + pytest.param(_v(1, 0, 0), _v(0, 1, 0), 0.0, id="perpendicular"), + pytest.param(_v(1, 0, 0), _v(-1, 0, 0), -1.0, id="antiparallel"), + pytest.param(_v(1, 2, 3), _v(4, 5, 6), 32.0, id="general"), +] + + +class TestDot: + @pytest.mark.parametrize("a,b,expected", _DOT_CASES) + def test_method(self, a, b, expected): + assert torch.allclose(a.dot(b), torch.tensor(expected)) + + @pytest.mark.parametrize("a,b,expected", _DOT_CASES) + def test_free_function(self, a, b, expected): + assert torch.allclose(dot(a, b), torch.tensor(expected)) + + +# =================================================================== +# Cross product +# =================================================================== + +_CROSS_CASES = [ + pytest.param(_v(1, 0, 0), _v(0, 1, 0), _v(0, 0, 1), id="x-cross-y-eq-z"), + pytest.param(_v(0, 1, 0), _v(0, 0, 1), _v(1, 0, 0), id="y-cross-z-eq-x"), + pytest.param(_v(0, 0, 1), _v(1, 0, 0), _v(0, 1, 0), id="z-cross-x-eq-y"), + pytest.param(_v(1, 0, 0), _v(1, 0, 0), _v(0, 0, 0), id="parallel-gives-zero"), + pytest.param(_v(2, 0, 0), _v(0, 3, 0), _v(0, 0, 6), id="scaled-axes"), +] + + +class TestCross: + @pytest.mark.parametrize("a,b,expected", _CROSS_CASES) + def test_method(self, a, b, expected): + result = a.cross(b) + assert torch.allclose(result.to_tensor(), expected.to_tensor()) + + @pytest.mark.parametrize("a,b,expected", _CROSS_CASES) + def test_free_function(self, a, b, expected): + result = cross(a, b) + assert torch.allclose(result.to_tensor(), expected.to_tensor()) + + def test_anticommutativity(self): + a, b = _v(1, 2, 3), _v(4, 5, 6) + assert torch.allclose(a.cross(b).to_tensor(), (-b.cross(a)).to_tensor()) + + +# =================================================================== +# Norm / normalized +# =================================================================== + +_NORM_CASES = [ + pytest.param(_v(1, 0, 0), 1.0, id="unit-x"), + pytest.param(_v(0, 3, 0), 3.0, id="along-y"), + pytest.param(_v(3, 4, 0), 5.0, id="3-4-5-triangle"), + pytest.param(_v(1, 1, 1), math.sqrt(3), id="diagonal"), +] + + +class TestNorm: + @pytest.mark.parametrize("v,expected", _NORM_CASES) + def test_norm_method(self, v, expected): + assert torch.allclose(v.norm(epsilon=0), torch.tensor(expected)) + + @pytest.mark.parametrize("v,expected", _NORM_CASES) + def test_norm_free(self, v, expected): + assert torch.allclose(norm(v, epsilon=0), torch.tensor(expected)) + + def test_norm2(self): + v = _v(3, 4, 0) + assert torch.allclose(v.norm2(), torch.tensor(25.0)) + + def test_norm_epsilon_clamp(self): + """Near-zero vector is clamped so norm >= epsilon.""" + v = _v(0, 0, 0) + eps = 1e-6 + assert v.norm(epsilon=eps) >= eps + + +class TestNormalized: + @pytest.mark.parametrize( + "v", + [ + pytest.param(_v(5, 0, 0), id="along-x"), + pytest.param(_v(0, 0, -7), id="along-neg-z"), + pytest.param(_v(1, 1, 1), id="diagonal"), + ], + ) + def test_unit_length(self, v): + u = v.normalized(epsilon=0) + assert torch.allclose(u.norm(epsilon=0), torch.tensor(1.0), atol=1e-6) + + def test_direction_preserved(self): + v = _v(3, 0, 0) + u = v.normalized(epsilon=0) + assert torch.allclose(u.to_tensor(), _v(1, 0, 0).to_tensor(), atol=1e-6) + + def test_free_function(self): + v = _v(0, 4, 0) + u = normalized(v, epsilon=0) + assert torch.allclose(u.to_tensor(), _v(0, 1, 0).to_tensor(), atol=1e-6) + + +# =================================================================== +# Distance functions +# =================================================================== + +_DIST_CASES = [ + pytest.param(_v(0, 0, 0), _v(1, 0, 0), 1.0, id="unit-apart-x"), + pytest.param(_v(0, 0, 0), _v(3, 4, 0), 5.0, id="3-4-5"), + pytest.param(_v(1, 1, 1), _v(1, 1, 1), 0.0, id="same-point"), +] + + +class TestDistance: + @pytest.mark.parametrize("a,b,expected", _DIST_CASES) + def test_euclidean(self, a, b, expected): + # epsilon=0 for exact match on these clean cases + result = euclidean_distance(a, b, epsilon=0) + assert torch.allclose(result, torch.tensor(expected), atol=1e-6) + + @pytest.mark.parametrize("a,b,expected", _DIST_CASES) + def test_square_euclidean(self, a, b, expected): + result = square_euclidean_distance(a, b, epsilon=0) + assert torch.allclose(result, torch.tensor(expected**2), atol=1e-6) + + def test_symmetry(self): + a, b = _v(1, 2, 3), _v(4, 5, 6) + assert torch.allclose( + euclidean_distance(a, b, epsilon=0), + euclidean_distance(b, a, epsilon=0), + ) + + +# =================================================================== +# Dihedral angle +# =================================================================== + +_DIHEDRAL_CASES = [ + pytest.param( + _v(1, 1, 0), + _v(0, 1, 0), + _v(0, 0, 0), + _v(1, 0, 0), + 0.0, + id="coplanar-cis", + ), + pytest.param( + _v(1, 1, 0), + _v(0, 1, 0), + _v(0, 0, 0), + _v(-1, 0, 0), + math.pi, + id="coplanar-trans", + ), + pytest.param( + _v(1, 1, 0), + _v(0, 1, 0), + _v(0, 0, 0), + _v(0, 0, 1), + math.pi / 2, + id="perpendicular-pos", + ), + pytest.param( + _v(1, 1, 0), + _v(0, 1, 0), + _v(0, 0, 0), + _v(0, 0, -1), + -math.pi / 2, + id="perpendicular-neg", + ), +] + + +class TestDihedralAngle: + @pytest.mark.parametrize("a,b,c,d,expected_rad", _DIHEDRAL_CASES) + def test_known_angles(self, a, b, c, d, expected_rad): + result = dihedral_angle(a, b, c, d) + assert torch.allclose(result, torch.tensor(expected_rad), atol=1e-5) + + +# =================================================================== +# Batched operations +# =================================================================== + + +class TestBatched: + def test_dot_batched(self): + a = _vb([[1, 0, 0], [0, 1, 0]]) + b = _vb([[1, 0, 0], [0, 0, 1]]) + result = a.dot(b) + assert torch.allclose(result, torch.tensor([1.0, 0.0])) + + def test_cross_batched(self): + a = _vb([[1, 0, 0], [0, 1, 0]]) + b = _vb([[0, 1, 0], [0, 0, 1]]) + result = a.cross(b) + expected = _vb([[0, 0, 1], [1, 0, 0]]) + assert torch.allclose(result.to_tensor(), expected.to_tensor()) + + def test_norm_batched(self): + v = _vb([[3, 4, 0], [0, 0, 5]]) + result = v.norm(epsilon=0) + assert torch.allclose(result, torch.tensor([5.0, 5.0])) + + def test_map_tensor_fn(self): + v = _vb([[1, 2, 3], [4, 5, 6]]) + doubled = v.map_tensor_fn(lambda t: t * 2) + expected = _vb([[2, 4, 6], [8, 10, 12]]) + assert torch.allclose(doubled.to_tensor(), expected.to_tensor()) diff --git a/openfold3/tests/test_utils.py b/openfold3/tests/utils/test_utils.py similarity index 100% rename from openfold3/tests/test_utils.py rename to openfold3/tests/utils/test_utils.py diff --git a/pyproject.toml b/pyproject.toml index 40d04e8b7..424bee07b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,6 +153,8 @@ select = [ "SIM", # isort "I", + # flake8-tidy-imports (TID252 bans relative imports) + "TID", ] ignore = [ "E741", @@ -161,6 +163,9 @@ ignore = [ ] extend-safe-fixes = ["UP006"] +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + [tool.ruff.lint.per-file-ignores] "**/tests/**" = ["E501"]