Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions openfold3/core/model/latent/template_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from openfold3.core.model.feature_embedders.template_embedders import (
TemplatePairEmbedderAllAtom,
)
from openfold3.core.model.latent.base_blocks import PairBlock
from openfold3.core.model.primitives import LayerNorm, Linear
from openfold3.core.utils.checkpointing import checkpoint_blocks, checkpoint_section
from openfold3.core.utils.chunk_utils import (
Expand All @@ -41,8 +42,6 @@
)
from openfold3.core.utils.tensor_utils import add

from .base_blocks import PairBlock


# TODO: Make arguments match PairBlock
class TemplatePairBlock(PairBlock):
Expand Down
8 changes: 5 additions & 3 deletions openfold3/core/model/layers/diffusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from ml_collections import ConfigDict

import openfold3.core.config.default_linear_init_config as lin_init
from openfold3.core.model.layers.attention_pair_bias import (
AttentionPairBias,
CrossAttentionPairBias,
)
from openfold3.core.model.layers.transition import ConditionedTransitionBlock
from openfold3.core.model.primitives import LayerNorm
from openfold3.core.utils.checkpointing import checkpoint_blocks

from .attention_pair_bias import AttentionPairBias, CrossAttentionPairBias
from .transition import ConditionedTransitionBlock


class DiffusionTransformerBlock(nn.Module):
"""Diffusion transformer block.
Expand Down
16 changes: 10 additions & 6 deletions openfold3/core/model/primitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .activations import SwiGLU
from .attention import (
from openfold3.core.model.primitives.activations import SwiGLU
from openfold3.core.model.primitives.attention import (
DEFAULT_LMA_KV_CHUNK_SIZE,
DEFAULT_LMA_Q_CHUNK_SIZE,
Attention,
GlobalAttention,
)
from .dropout import Dropout, DropoutColumnwise, DropoutRowwise
from .initialization import (
from openfold3.core.model.primitives.dropout import (
Dropout,
DropoutColumnwise,
DropoutRowwise,
)
from openfold3.core.model.primitives.initialization import (
final_init_,
gating_init_,
glorot_uniform_init_,
Expand All @@ -29,8 +33,8 @@
lecun_normal_init_,
trunc_normal_init_,
)
from .linear import Linear
from .normalization import AdaLN, LayerNorm
from openfold3.core.model.primitives.linear import Linear
from openfold3.core.model.primitives.normalization import AdaLN, LayerNorm

__all__ = [
"SwiGLU",
Expand Down
3 changes: 1 addition & 2 deletions openfold3/core/model/primitives/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from torch import nn

import openfold3.core.config.default_linear_init_config as lin_init

from .linear import Linear
from openfold3.core.model.primitives.linear import Linear

triton_is_installed = importlib.util.find_spec("triton") is not None
if triton_is_installed:
Expand Down
3 changes: 1 addition & 2 deletions openfold3/core/model/primitives/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@

import openfold3.core.config.default_linear_init_config as lin_init
from openfold3.core.kernels.cueq_utils import is_cuequivariance_available
from openfold3.core.model.primitives.linear import Linear
from openfold3.core.utils.checkpointing import get_checkpoint_fn
from openfold3.core.utils.tensor_utils import flatten_final_dims

from .linear import Linear

warnings.filterwarnings("once")

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
Expand Down
2 changes: 1 addition & 1 deletion openfold3/core/model/primitives/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch
import torch.nn as nn

from .initialization import (
from openfold3.core.model.primitives.initialization import (
final_init_,
gating_init_,
glorot_uniform_init_,
Expand Down
3 changes: 1 addition & 2 deletions openfold3/core/model/primitives/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from ml_collections import ConfigDict

import openfold3.core.config.default_linear_init_config as lin_init

from .linear import Linear
from openfold3.core.model.primitives.linear import Linear

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
if deepspeed_is_installed:
Expand Down
2 changes: 1 addition & 1 deletion openfold3/core/utils/geometry/rigid_matrix_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def from_tensor_4x4(cls, array):
return cls.from_array(array)

@classmethod
def from_array4x4(cls, array: torch.tensor) -> Rigid3Array:
def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
"""Construct Rigid3Array from homogeneous 4x4 array."""
rotation = rotation_matrix.Rot3Array(
array[..., 0, 0],
Expand Down
6 changes: 3 additions & 3 deletions openfold3/core/utils/geometry/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
def assert_rotation_matrix_equal(
matrix1: rotation_matrix.Rot3Array, matrix2: rotation_matrix.Rot3Array
):
for field in dataclasses.fields(rotation_matrix.Rot3Array):
field = field.name
assert torch.equal(getattr(matrix1, field), getattr(matrix2, field))
for f in dataclasses.fields(rotation_matrix.Rot3Array):
name = f.name
assert torch.equal(getattr(matrix1, name), getattr(matrix2, name))


def assert_rotation_matrix_close(
Expand Down
18 changes: 10 additions & 8 deletions openfold3/core/utils/geometry/vector.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It would be slightly more clear if we annotated the return types that are scalars (formerly Float) to with something like

def dot(self, other: Vec3Array) -> torch.Tensor: # return: scalar

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, this is something I still need to add

Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ def cross(self, other: Vec3Array) -> Vec3Array:
new_z = self.x * other.y - self.y * other.x
return Vec3Array(new_x, new_y, new_z)

def dot(self, other: Vec3Array) -> Float:
def dot(self, other: Vec3Array) -> torch.Tensor:
"""Compute dot product between 'self' and 'other'."""
return self.x * other.x + self.y * other.y + self.z * other.z

def norm(self, epsilon: float = 1e-6) -> Float:
def norm(self, epsilon: float = 1e-6) -> torch.Tensor:
"""Compute Norm of Vec3Array, clipped to epsilon."""
# To avoid NaN on the backward pass, we must use maximum before the sqrt
norm2 = self.dot(self)
Expand Down Expand Up @@ -180,7 +180,7 @@ def cat(cls, vecs: list[Vec3Array], dim: int) -> Vec3Array:

def square_euclidean_distance(
vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
) -> Float:
) -> torch.Tensor:
"""Computes square of euclidean distance between 'vec1' and 'vec2'.

Args:
Expand All @@ -200,15 +200,15 @@ def square_euclidean_distance(
return distance


def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
def dot(vector1: Vec3Array, vector2: Vec3Array) -> torch.Tensor:
return vector1.dot(vector2)


def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float:
def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
return vector1.cross(vector2)


def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
def norm(vector: Vec3Array, epsilon: float = 1e-6) -> torch.Tensor:
return vector.norm(epsilon)


Expand All @@ -218,7 +218,7 @@ def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:

def euclidean_distance(
vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
) -> Float:
) -> torch.Tensor:
"""Computes euclidean distance between 'vec1' and 'vec2'.

Args:
Expand All @@ -236,7 +236,9 @@ def euclidean_distance(
return distance


def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array) -> Float:
def dihedral_angle(
a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array
) -> torch.Tensor:
"""Computes torsion angle for a quadruple of points.

For points (a, b, c, d), this is the angle between the planes defined by
Expand Down
Empty file.
Empty file.
75 changes: 75 additions & 0 deletions openfold3/tests/utils/geometry/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Shared helpers for geometry tests."""

from __future__ import annotations

import math

import torch

from openfold3.core.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold3.core.utils.geometry.rotation_matrix import Rot3Array
from openfold3.core.utils.geometry.vector import Vec3Array

_Translation = tuple[float, float, float]


def v(x: float, y: float, z: float) -> Vec3Array:
"""Build a scalar Vec3Array from three floats."""
return Vec3Array(
torch.tensor(x, dtype=torch.float32),
torch.tensor(y, dtype=torch.float32),
torch.tensor(z, dtype=torch.float32),
)


def vb(coords: list[list[float]]) -> Vec3Array:
"""Build a batched Vec3Array from a list of [x, y, z] triples."""
t = torch.tensor(coords, dtype=torch.float32)
return Vec3Array(t[:, 0], t[:, 1], t[:, 2])


def rot_x(theta: float) -> Rot3Array:
"""Rotation about the X axis by *theta* radians."""
c, s = math.cos(theta), math.sin(theta)
return Rot3Array.from_array(
torch.tensor(
[
[1.0, 0.0, 0.0],
[0.0, c, -s],
[0.0, s, c],
]
)
)


def rot_y(theta: float) -> Rot3Array:
"""Rotation about the Y axis by *theta* radians."""
c, s = math.cos(theta), math.sin(theta)
return Rot3Array.from_array(
torch.tensor(
[
[c, 0.0, s],
[0.0, 1.0, 0.0],
[-s, 0.0, c],
]
)
)


def rot_z(theta: float) -> Rot3Array:
"""Rotation about the Z axis by *theta* radians."""
c, s = math.cos(theta), math.sin(theta)
return Rot3Array.from_array(
torch.tensor(
[
[c, -s, 0.0],
[s, c, 0.0],
[0.0, 0.0, 1.0],
]
)
)


def rigid(rot: Rot3Array, translation: _Translation) -> Rigid3Array:
"""Build a Rigid3Array from a rotation and a translation triple."""
return Rigid3Array(rot, v(*translation))
Loading
Loading