Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Fused attention unit tests fixes and refinements (#1352)
* Add util functions to attn_mask_type Signed-off-by: Reese Wang <[email protected]> * Add util functions to qkv_layout Signed-off-by: Reese Wang <[email protected]> * Fix THD cross reference code Signed-off-by: Reese Wang <[email protected]> * Remove explicit segment_pad, encoding it to segment_ids Signed-off-by: Reese Wang <[email protected]> * Add jax.jit, replace _token with segment_ids, rename bias shape enum Signed-off-by: Reese Wang <[email protected]> * Add comment for make_mask Signed-off-by: Reese Wang <[email protected]> * Clean code Signed-off-by: Reese Wang <[email protected]> * Add doc strings for the added functions Signed-off-by: Reese Wang <[email protected]> * Remove cache for fa deterministic which causes UT failed Signed-off-by: Reese Wang <[email protected]> * Rename fixture to avoid conflict Signed-off-by: Reese Wang <[email protected]> --------- Signed-off-by: Reese Wang <[email protected]>