Commit
[JAX] Support segment_ids/pos as FA inputs (#1406)
* POC for segment_ids/segment_pos
* Change segment_pos position
* Use RemainingArgs to solve mismatches in the number of parameters
* Test mask_descriptor for accommodating different mask representations
* Fix bugs
* Use descriptor in bwd
* Primitives only accept pure jnp arrays
* segment_ids/pos support POC
* Move seqlens/offsets generation to mask descriptor
* Rename MaskDescriptor to SequenceDescriptor
* Generalize get_seqlens_and_offsets
* Utilize sequence desc on FA bwd
* Migrate to new API
* Add docstrings
* Remove small inputs and test different input formats
* Fix lint
* Fix seed shardings
* Optimize sequence conversion overhead
* Optimize seq_offsets calculation
* Fix up
* Fix lint
* Fix conflicts
* Remove redundant line

---------

Signed-off-by: Reese Wang <[email protected]>
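For readers unfamiliar with the two input formats mentioned above, here is a minimal plain-Python sketch of how per-token segment IDs relate to sequence lengths and offsets for one packed row. It is an illustration only, not the code from this commit: the function names `seqlens_and_offsets` and `segment_pos_from_ids` are hypothetical, and the convention that ID 0 marks padding is an assumption.

```python
def seqlens_and_offsets(segment_ids):
    """Return (seqlens, offsets) for one packed row.

    Assumption: segment ID 0 marks padding, and each run of equal
    non-zero IDs is one sequence.
    """
    seqlens, offsets = [], []
    prev = 0
    for i, seg in enumerate(segment_ids):
        if seg != 0 and seg != prev:
            # A new non-padding segment starts at token index i.
            offsets.append(i)
            seqlens.append(0)
        if seg != 0:
            seqlens[-1] += 1
        prev = seg
    return seqlens, offsets


def segment_pos_from_ids(segment_ids):
    """Return each token's position within its run of equal IDs."""
    pos = []
    prev, count = None, 0
    for seg in segment_ids:
        # Restart the counter whenever the segment ID changes.
        count = count + 1 if seg == prev else 0
        pos.append(count)
        prev = seg
    return pos


# Three packed sequences (lengths 3, 2, 1) with two padding tokens.
ids = [1, 1, 1, 2, 2, 0, 0, 3]
seqlens, offsets = seqlens_and_offsets(ids)
positions = segment_pos_from_ids(ids)
```

The real implementation works on batched `jnp` arrays and has to be traceable, so it would use vectorized operations rather than Python loops; the loops here only make the mapping easy to follow.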
Showing 5 changed files with 786 additions and 297 deletions.