for transformer_engine version 1.14.0+87fbe812f
in transformer_engine.jax.flax.module
in the Softmax class
in __call__, lines 191-198:
# For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
# and kernel is unavailable, then try on pure scaled softmax custom calls.
if is_softmax_kernel_available(
    SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
):
    outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
else:
    outputs = jax_nn.softmax(logits * self.scale_factor)  # <- self.softmax_type ignored
In the else branch, self.softmax_type is ignored entirely, so no causal masking is performed when e.g. self.softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED.
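For reference, a minimal sketch of what the fallback could do instead: apply the causal mask in plain JAX before calling jax_nn.softmax. This is not TE's actual fix; the helper name fallback_softmax and the [..., q_seqlen, k_seqlen] logits layout are assumptions for illustration, and only the reported SCALED_UPPER_TRIANG_MASKED case is handled.

import jax.numpy as jnp
from jax import nn as jax_nn
# Import path assumed from the snippet above.
from transformer_engine.jax.softmax import SoftmaxType

def fallback_softmax(logits, scale_factor, softmax_type):
    # logits: [..., q_seqlen, k_seqlen] attention scores (assumed layout).
    scaled = logits * scale_factor
    if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        q_seqlen, k_seqlen = scaled.shape[-2], scaled.shape[-1]
        # Lower-triangular mask: query position i may only attend to
        # key positions j <= i; masked entries get the dtype minimum so
        # they vanish after softmax.
        causal = jnp.tril(jnp.ones((q_seqlen, k_seqlen), dtype=bool))
        scaled = jnp.where(causal, scaled, jnp.finfo(scaled.dtype).min)
    return jax_nn.softmax(scaled)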
I just ran into this as well. At least for unfused attention with no softmax kernel available (the only configuration I've tried), it doesn't apply any kind of causal mask.