
Causal mask ignored in DotProductAttention #1524

Open
anthony-Neo opened this issue Feb 28, 2025 · 1 comment


anthony-Neo commented Feb 28, 2025

For transformer_engine version 1.14.0+87fbe812f, in transformer_engine.jax.flax.module, the Softmax class's __call__ method (lines 191-198):

# For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
# and kernel is unavailable, then try on pure scaled softmax custom calls.
if is_softmax_kernel_available(
    SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
):
    outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
else:
    outputs = jax_nn.softmax(logits * self.scale_factor)  # <- self.softmax_type ignored

In the else branch, self.softmax_type is ignored, so no causal masking is performed when, for example, self.softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED.
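
A minimal sketch of how the fallback branch could honor the causal case, assuming logits has shape (..., q_seqlen, k_seqlen), that jax.numpy is available as jnp in the module, and that masking with a large negative value before the softmax is acceptable (this is an illustration, not an upstream patch):

import jax.numpy as jnp
from jax import nn as jax_nn

if self.softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
    # Lower-triangular boolean mask: True where a query position may
    # attend, i.e. keys at or before that position.
    q_seqlen, k_seqlen = logits.shape[-2], logits.shape[-1]
    causal_mask = jnp.tril(jnp.ones((q_seqlen, k_seqlen), dtype=bool))
    # Push masked-out logits to a large negative value so their
    # softmax weight is effectively zero.
    masked_logits = jnp.where(causal_mask, logits * self.scale_factor, -1e10)
    outputs = jax_nn.softmax(masked_logits)
else:
    outputs = jax_nn.softmax(logits * self.scale_factor)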

philip-essential commented:

I just ran into this as well. At least for unfused attention with no softmax kernel available (the only configuration I've tried), it doesn't apply any kind of causal mask.
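
As a quick pure-JAX illustration of the symptom (not the library's API), the unmasked fallback gives nonzero weights to future key positions that a causal mask should zero out:

import jax.numpy as jnp
from jax import nn as jax_nn

# (batch, heads, q_seqlen, k_seqlen) logits, all zeros for clarity.
logits = jnp.zeros((1, 1, 4, 4))
weights = jax_nn.softmax(logits * 1.0)  # fallback path: no mask applied
# Every row is uniform (0.25), including upper-triangular "future"
# positions that causal attention should never attend to.
print(weights[0, 0])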

phu0ngng added the good first issue (Good for newcomers) label on Mar 4, 2025