Skip to content

Commit

Permalink
[TE/JAX] Disable FusedAttn with FFI by default (#1298)
Browse files Browse the repository at this point in the history
* disable fused attn with ffi

---------

Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng authored Oct 31, 2024
1 parent 9dddb36 commit 23caab3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def lowering(

wkspace_aval = ctx.avals_out[-1]

if is_ffi_enabled():
if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI", "0")):
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
Expand Down

0 comments on commit 23caab3

Please sign in to comment.