Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions openfold/model/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def __init__(
c_hidden: int,
no_heads: int,
gating: bool = True,
inf:float = 1e9,
):
"""
Args:
Expand All @@ -394,6 +395,7 @@ def __init__(
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
self.inf = inf

# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension.
Expand Down Expand Up @@ -539,12 +541,23 @@ def forward(
if biases is None:
biases = []

# DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale=not use_deepspeed_evo_attention or use_cuequivariance_attention)

if is_fp16_enabled():
use_memory_efficient_kernel = False

if use_cuequivariance_attention:
# cuEquivariance -> Torch fallback for small sequence length and some shapes
if cueq_would_fall_back(q_x.shape[-2], q_x.shape[-1] // self.no_heads, q_x.dtype):
# convert the mask from boolean to float pre-mul
biases[0] = (self.inf * (biases[0] - 1))
use_cuequivariance_attention = False

# The EvoformerAttention kernel can only be used for sequence lengths > 16
if use_deepspeed_evo_attention and q_x.shape[-2] <= 16:
use_deepspeed_evo_attention = False

# DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale = not (use_deepspeed_evo_attention or use_cuequivariance_attention))

# cuequivariance kernel takes precedence over use_deepspeed_evo_attention
if use_cuequivariance_attention:
Expand Down
Loading