[JAX] Support segment_ids/pos as FA inputs #1406

Merged: 24 commits, Jan 24, 2025

Collaborator

@zlsh80826 zlsh80826 commented Jan 13, 2025

Description

This PR adds limited support for segment_ids/pos inputs and deprecates the fused_attn_thd API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Add a new SequenceDescriptor class that covers the different ways of describing sequences:
    • from_seqlens for non-THD
    • from_seqlens_and_offsets for THD
    • from_segment_ids_and_pos for THD + ring attn (not implemented yet)
  • Replace the old fused_attn mask parameter with a SequenceDescriptor. Passing a mask in that argument position still works for now but emits a deprecation warning.
  • Deprecate the fused_attn_thd API, since the refactored fused_attn also supports the THD format.
  • Remove the small inputs in test_fused_attn.py, since the long-sequence inputs should cover them.
  • Add tests for the different sequence input formats to test_fused_attn.py.
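
To illustrate how the descriptions in the list above relate to each other, here is a minimal pure-Python sketch that turns one packed row of segment IDs into per-segment lengths and start offsets. It is not the Transformer Engine implementation: the function name is hypothetical, and it assumes that ID 0 marks padding and that every non-zero segment occupies one contiguous run of tokens.

```python
from typing import List, Tuple


def segment_ids_to_seqlens_and_offsets(
    segment_ids: List[int],
) -> Tuple[List[int], List[int]]:
    """Convert one packed row of segment IDs into lengths and start offsets.

    Assumptions (illustrative only): ID 0 marks padding, and each non-zero
    segment occupies one contiguous run of tokens.
    """
    seqlens: List[int] = []
    offsets: List[int] = []
    prev = 0  # treat the position before the row as padding
    for idx, seg in enumerate(segment_ids):
        if seg != 0 and seg != prev:
            # A new segment starts at this index.
            offsets.append(idx)
            seqlens.append(1)
        elif seg != 0:
            # The current segment continues.
            seqlens[-1] += 1
        prev = seg
    return seqlens, offsets


row = [1, 1, 1, 2, 2, 0, 3, 3, 3, 3, 0, 0]
seqlens, offsets = segment_ids_to_seqlens_and_offsets(row)
print(seqlens)  # [3, 2, 4]
print(offsets)  # [0, 3, 6]
```

The row holds three packed sequences with padding after the second and third, so the lengths and offsets together carry the same information as the segment IDs under the stated assumptions.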

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zlsh80826
Collaborator Author

/te-ci jax L1

Signed-off-by: Reese Wang <[email protected]>
@zlsh80826 zlsh80826 force-pushed the rewang/test-segment-ids branch from 08a7582 to e62c049 on January 14, 2025 09:57
@zlsh80826
Collaborator Author

/te-ci jax L1

tests/jax/test_fused_attn.py
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
Collaborator

Would it help to cut down on test cases by creating a standalone unit test for the SequenceDesc to cover and check all of the cases? Then in this unit test we can use either Seqlens or SegmentIDs depending on THD or BSHD?

Collaborator Author

Good idea

Collaborator Author

I think we still need to keep the different SeqDescFormat variants in test_fused_attn for a while. I don't have a better way to verify that the converted seqlens/offsets are correct; the best way to verify that is to check the FA results. The CI time is still under 1 hour, so I think it is OK for now. But yes, if we find that some SeqDescFormat variants are rarely used in the future, we can remove them or test them only on specific input shapes instead of all input shapes.
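
For reference, a standalone check along the lines the reviewer suggested could look roughly like the sketch below. It is only an illustration under stated assumptions, not code from this PR: both function names are hypothetical, padding is assumed to use segment ID 0 and position 0, and every segment is assumed to have a non-zero length. It checks the conversion's bookkeeping only, so it would complement the FA-result comparison rather than replace it.

```python
from typing import List, Tuple


def seqlens_and_offsets_to_segment_ids_and_pos(
    seqlens: List[int], offsets: List[int], max_seqlen: int
) -> Tuple[List[int], List[int]]:
    """Expand lengths and offsets into per-token segment IDs and positions.

    Assumptions (illustrative only): segments are numbered from 1, padding
    tokens get ID 0 and position 0.
    """
    segment_ids = [0] * max_seqlen
    segment_pos = [0] * max_seqlen
    for seg_idx, (length, start) in enumerate(zip(seqlens, offsets), start=1):
        for pos in range(length):
            segment_ids[start + pos] = seg_idx
            segment_pos[start + pos] = pos
    return segment_ids, segment_pos


def check_descriptor(seqlens: List[int], offsets: List[int], max_seqlen: int) -> None:
    """Assert that the expanded IDs and positions agree with the inputs."""
    ids, pos = seqlens_and_offsets_to_segment_ids_and_pos(seqlens, offsets, max_seqlen)
    for seg_idx, (length, start) in enumerate(zip(seqlens, offsets), start=1):
        assert ids.count(seg_idx) == length  # token count matches the length
        assert ids.index(seg_idx) == start  # first token sits at the offset
        assert pos[start : start + length] == list(range(length))


check_descriptor(seqlens=[3, 2], offsets=[0, 4], max_seqlen=8)
```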

transformer_engine/jax/attention.py
Signed-off-by: Reese Wang <[email protected]>
Collaborator

@mgoldfarb-nvidia mgoldfarb-nvidia left a comment

LGTM.

@zlsh80826
Collaborator Author

/te-ci jax L1

Signed-off-by: Reese Wang <[email protected]>
@zlsh80826
Collaborator Author

/te-ci jax L1

Signed-off-by: Reese Wang <[email protected]>
@zlsh80826
Collaborator Author

/te-ci jax L1

@zlsh80826 zlsh80826 merged commit c2c3d54 into NVIDIA:main Jan 24, 2025
23 checks passed