diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index 241974979..5492ed710 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -370,6 +370,7 @@ def __init__( c_hidden: int, no_heads: int, gating: bool = True, + inf:float = 1e9, ): """ Args: @@ -394,6 +395,7 @@ def __init__( self.c_hidden = c_hidden self.no_heads = no_heads self.gating = gating + self.inf = inf # DISCREPANCY: c_hidden is not the per-head channel dimension, as # stated in the supplement, but the overall channel dimension. @@ -539,12 +541,23 @@ def forward( if biases is None: biases = [] - # DeepSpeed attention kernel applies scaling internally - q, k, v = self._prep_qkv(q_x, kv_x, - apply_scale=not use_deepspeed_evo_attention or use_cuequivariance_attention) - if is_fp16_enabled(): use_memory_efficient_kernel = False + + if use_cuequivariance_attention: + # cuEquivariance -> Torch fallback for small sequence length and some shapes + if cueq_would_fall_back(q_x.shape[-2], q_x.shape[-1] // self.no_heads, q_x.dtype): + # convert the mask from boolean to float pre-mul + biases[0] = (self.inf * (biases[0] - 1)) + use_cuequivariance_attention = False + + # The EvoformerAttention kernel can only be used for sequence lengths > 16 + if use_deepspeed_evo_attention and q_x.shape[-2] <= 16: + use_deepspeed_evo_attention = False + + # DeepSpeed attention kernel applies scaling internally + q, k, v = self._prep_qkv(q_x, kv_x, + apply_scale = not (use_deepspeed_evo_attention or use_cuequivariance_attention)) # cuequivariance kernel takes precedence over use_deepspeed_evo_attention if use_cuequivariance_attention: