Skip to content

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Oct 1, 2025

Description

PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX.

Fixes #2070

BEFORE
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    1.97s | avg:   0.16s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  612.61s | avg:   3.83s
test_context_parallel_allgather_attn_shardy                  |  20x |   90.95s | avg:   4.55s
test_context_parallel_ring_attn                              | 640x | 1042.37s | avg:   1.63s
test_context_parallel_ring_attn_shardy                       |  20x |   37.74s | avg:   1.89s
test_cross_attn                                              |   6x |   31.82s | avg:   5.30s
test_distributed_gemm                                        |   6x |    6.10s | avg:   1.02s
test_layernorm                                               | 144x |   81.39s | avg:   0.57s
test_layernorm_mlp_grad                                      | 240x |  301.51s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  293.58s | avg:   1.22s
test_layernorm_mlp_layer                                     |  48x |   21.58s | avg:   0.45s
test_layernorm_mlp_layer_fp8                                 | 192x |   81.58s | avg:   0.42s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   91.23s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   25.98s | avg:   0.54s
test_rmsnorm                                                 |  72x |   29.43s | avg:   0.41s
test_self_attn                                               |  18x |   89.75s | avg:   4.99s
test_self_attn_shardy                                        |   6x |   17.32s | avg:   2.89s
test_softmax                                                 | 288x |  185.44s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   13.07s | avg:   0.54s
test_te_distributed_dense_grad                               |   6x |    5.12s | avg:   0.85s
================================================================================
TOTAL RUNTIME                                                |      | 3060.56s |
================================================================================

AFTER
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    2.20s | avg:   0.18s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  587.44s | avg:   3.67s
test_context_parallel_allgather_attn_shardy                  |  20x |   87.95s | avg:   4.40s
test_context_parallel_ring_attn                              | 640x | 1037.16s | avg:   1.62s
test_context_parallel_ring_attn_shardy                       |  20x |   41.83s | avg:   2.09s
test_cross_attn                                              |  18x |   89.76s | avg:   4.99s
test_distributed_gemm                                        |   6x |    5.74s | avg:   0.96s
test_layernorm                                               | 144x |   83.85s | avg:   0.58s
test_layernorm_mlp_grad                                      | 240x |  301.73s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  309.08s | avg:   1.29s
test_layernorm_mlp_layer                                     |  48x |   24.98s | avg:   0.52s
test_layernorm_mlp_layer_fp8                                 | 192x |   89.17s | avg:   0.46s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   92.58s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   26.29s | avg:   0.55s
test_rmsnorm                                                 |  72x |   29.52s | avg:   0.41s
test_self_attn                                               |  54x |  259.63s | avg:   4.81s
test_self_attn_shardy                                        |  18x |   43.51s | avg:   2.42s
test_softmax                                                 | 288x |  183.87s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   12.72s | avg:   0.53s
test_te_distributed_dense_grad                               |   6x |    4.74s | avg:   0.79s
================================================================================
TOTAL RUNTIME                                                |      | 3313.74s |
================================================================================

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 4 commits October 1, 2025 14:45
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
pre-commit-ci bot and others added 5 commits October 2, 2025 15:54
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL
Copy link
Collaborator Author

pggPL commented Oct 6, 2025

/te-ci jax

pggPL and others added 2 commits October 7, 2025 14:30
@phu0ngng phu0ngng self-requested a review October 8, 2025 18:33
pggPL and others added 7 commits October 14, 2025 14:41
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL
Copy link
Collaborator Author

pggPL commented Oct 14, 2025

/te-ci jax L1

@pggPL
Copy link
Collaborator Author

pggPL commented Oct 15, 2025

/te-ci jax L1

@pggPL
Copy link
Collaborator Author

pggPL commented Oct 15, 2025

/te-ci jax L1

)
else:
assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
softmax_offset = jnp.zeros(0, dtype=jnp.float32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will probably need the shape of (1,) for shardy.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It passes the tests for shardy, so I didn't change that. I didn't change anything in shared_shardy_rule - the sequence_offsets is not returned as an output, so maybe it is not needed? Tbh I do not fully get why shardy rule in attention is build as it is.

pggPL and others added 6 commits October 16, 2025 19:48
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)


def jax_general_softmax(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This softmax is copied from JAX. In JAX there is also custom jvp for this softmax, it can be generalized to support softmax offset. I don't know how efficient we want unfused attention to be - I decided to not include custom jvp.

@pggPL pggPL marked this pull request as ready for review October 17, 2025 11:04
@pggPL pggPL requested a review from KshitijLakhani October 17, 2025 11:05
@pggPL
Copy link
Collaborator Author

pggPL commented Oct 17, 2025

/te-ci jax L1

pggPL and others added 4 commits October 21, 2025 18:25
@pggPL
Copy link
Collaborator Author

pggPL commented Oct 21, 2025

/te-ci jax L1

pggPL and others added 2 commits October 22, 2025 00:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add attention sink to flash attention

2 participants