[JAX] Support segment_ids/pos as FA inputs #1406
/te-ci jax L1
Signed-off-by: Reese Wang <[email protected]>
08a7582 to e62c049
    @pytest.mark.parametrize(
        "seq_desc_format",
        [
            pytest.param(SeqDescFormat.Mask, id="Mask"),
Would it help to cut down on test cases by creating a standalone unit test for the `SequenceDesc` to cover and check all of the cases? Then in this unit test we can use either `Seqlens` or `SegmentIDs` depending on THD or BSHD?
Good idea
I think we still need to keep the different `SeqDescFormat` variants in `test_fused_attn` for a while. I don't have a better way to verify that the converted seqlens/offsets are correct; the best check is the FA results. CI time is still under 1 hr, so I think it is OK for now. But yes, if we find in the future that some `SeqDescFormat` variants are rarely used, we can remove them or test them only on specific input shapes instead of all input shapes.
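For illustration, the seqlens/offsets that a segment-ID layout implies can also be computed by a small reference routine and compared against the converted values. The sketch below is not the PR's code. The function name is hypothetical, and it assumes that ID 0 marks padding and that each segment is one contiguous run of a single non-zero ID.

```python
from itertools import groupby


def seqlens_and_offsets_from_segment_ids(segment_ids, pad_id=0):
    """Return (seqlens, offsets) for one packed row of segment IDs.

    Assumption: every segment is a contiguous run of one non-padding ID,
    and `pad_id` marks padding tokens.
    """
    seqlens, offsets = [], []
    position = 0
    for seg_id, run in groupby(segment_ids):
        run_len = sum(1 for _ in run)
        if seg_id != pad_id:
            seqlens.append(run_len)   # number of tokens in this segment
            offsets.append(position)  # index of the segment's first token
        position += run_len
    return seqlens, offsets


# Three packed segments, with padding between and after them.
row = [1, 1, 1, 2, 2, 0, 3, 3, 3, 3, 0, 0]
seqlens, offsets = seqlens_and_offsets_from_segment_ids(row)
```

A check like this only covers the layout conversion. Comparing the FA results, as the tests do, still exercises the full path.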
LGTM.
…g/test-segment-ids
Description
This PR adds limited `segment_ids/pos` support and deprecates the `fused_attn_thd` API.

Type of change

Changes
- `SequenceDescriptor` class for different sequence-description scenarios:
  - `from_seqlens` for non-THD
  - `from_seqlens_and_offsets` for THD
  - `from_segment_ids_and_pos` for THD + ring attn (not implemented yet)
- `fused_attn` `mask` parameter to `SequenceDescriptor`. Passing `mask` in that argument position will keep working for a while but will generate a deprecation warning.
- `fused_attn_thd` API, as the refactored `fused_attn` can also support the THD format.
- `test_fused_attn.py`, as the long sequence inputs should cover.
- `test_fused_attn.py`

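The backward-compatible handling described above, accepting a legacy mask for a while but warning about it, can be sketched roughly as follows. This is a minimal pure-Python illustration, not the library's implementation. `SeqDescSketch` and `normalize_sequence_descriptor` are hypothetical names, and the legacy mask is assumed to be a per-token list where `True` means "padded".

```python
import warnings
from dataclasses import dataclass


@dataclass(frozen=True)
class SeqDescSketch:
    """Hypothetical stand-in for a sequence descriptor (not the library class)."""
    seqlens: tuple
    offsets: tuple = ()

    @classmethod
    def from_seqlens(cls, seqlens):
        # Non-THD case: only the sequence lengths are needed.
        return cls(tuple(seqlens))

    @classmethod
    def from_seqlens_and_offsets(cls, seqlens, offsets):
        # THD case: each packed segment also needs its start offset.
        return cls(tuple(seqlens), tuple(offsets))


def normalize_sequence_descriptor(arg):
    """Accept a descriptor, or a legacy padding mask with a deprecation warning."""
    if isinstance(arg, SeqDescSketch):
        return arg
    warnings.warn(
        "Passing a mask is deprecated; pass a sequence descriptor instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Assumed legacy convention: True marks a padded token.
    valid_tokens = sum(1 for is_padded in arg if not is_padded)
    return SeqDescSketch.from_seqlens([valid_tokens])


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = normalize_sequence_descriptor([False, False, True])
```

Routing both input kinds through one normalization step keeps the rest of the call path descriptor-only, so the legacy branch can later be deleted in one place.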
Checklist: