[JAX] Add support for sink attention in JAX #2225
Conversation
/te-ci jax
/te-ci jax L1
```python
    )
else:
    assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
    softmax_offset = jnp.zeros(0, dtype=jnp.float32)
```
I think we will probably need the shape of (1,) for shardy.
It passes the tests for shardy, so I didn't change that. I also didn't change anything in shared_shardy_rule; the sequence_offsets is not returned as an output, so maybe it is not needed? To be honest, I don't fully understand why the shardy rule in attention is built the way it is.
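For context, a minimal sketch of the two placeholder shapes under discussion (the variable name mirrors the diff above; whether shardy actually requires the non-empty shape is the reviewer's suggestion, not verified here):

```python
import jax.numpy as jnp

# Shape used in the diff for the vanilla-softmax case: a zero-size array.
softmax_offset = jnp.zeros(0, dtype=jnp.float32)       # shape (0,)

# Shape the reviewer suggests instead, giving shardy a real (size-1)
# dimension to replicate or partition:
softmax_offset_alt = jnp.zeros(1, dtype=jnp.float32)   # shape (1,)
```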
```python
    return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)


def jax_general_softmax(
```
This softmax is copied from JAX. JAX also defines a custom jvp for this softmax, which could be generalized to support the softmax offset. I don't know how efficient we want unfused attention to be, so I decided not to include the custom jvp.
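As a rough sketch of what such a generalized softmax computes (the name, signature, and offset handling below are illustrative assumptions, not the PR's actual code): the offset contributes one extra term to the normalizer, so some probability mass can flow to the implicit sink.

```python
import jax
import jax.numpy as jnp

def general_softmax_with_offset(logits, offset=None, axis=-1):
    """Numerically stable softmax whose normalizer optionally includes
    exp(offset), so the probabilities over `logits` sum to less than 1
    and the remaining mass is absorbed by the implicit sink.
    `offset` is assumed scalar here for simplicity."""
    x_max = jnp.max(logits, axis=axis, keepdims=True)
    if offset is not None:
        x_max = jnp.maximum(x_max, offset)
    x_max = jax.lax.stop_gradient(x_max)
    unnormalized = jnp.exp(logits - x_max)
    denom = jnp.sum(unnormalized, axis=axis, keepdims=True)
    if offset is not None:
        denom = denom + jnp.exp(offset - x_max)
    return unnormalized / denom
```

Without a custom jvp, JAX differentiates through these primitives directly, which is correct but may be slower than the hand-written rule JAX provides for its own softmax.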
/te-ci jax L1
/te-ci jax L1
Description
PR #2148 added support for sink attention to the common library and PyTorch. This PR adds the corresponding support for JAX.
Fixes #2070
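For readers new to the term: sink attention augments the softmax normalizer with an extra learnable sink logit, so attention weights need not sum to one over the real tokens. Roughly (notation mine, not taken from the PR):

```math
\mathrm{softmax}_{\mathrm{sink}}(s)_i = \frac{e^{s_i}}{e^{s_{\mathrm{sink}}} + \sum_j e^{s_j}}
```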