[TE/JAX] XLA FFI calls for Softmax and FusedAttnBackward (#1319)
* FFI for all softmax functions

Signed-off-by: Hua Huang <[email protected]>

* FFI for FusedAttnBackward and Dequantize

FusedAttnBackward passed all tests in test_fused_attn.py.
Dequantize is not currently used; it is implemented here for completeness.

Signed-off-by: Hua Huang <[email protected]>

* Fix FusedAttnBackward FFI pybind & simplify

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert changes to tests/jax/test_fused_attn.py

Signed-off-by: Hua Huang <[email protected]>

---------

Signed-off-by: Hua Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <[email protected]>
3 people authored Nov 12, 2024
1 parent 68adf45 commit 237b493
Showing 7 changed files with 561 additions and 302 deletions.
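Every lowering touched by this commit follows the same pattern: when the XLA FFI path is enabled, the rule calls jax.extend.ffi.ffi_lowering("<target>_ffi") and passes the descriptor fields (batch sizes, sequence lengths, scale factors, and so on) as FFI attributes; otherwise it falls back to the legacy path, packing an opaque descriptor with transformer_engine_jax.pack_*_descriptor and dispatching through custom_caller. The toggle itself, is_ffi_enabled(), lives in cpp_extensions/misc.py and is not part of this diff; the sketch below is only a hypothetical illustration of such a switch, and the environment-variable name and default are assumptions.

```python
import os


def is_ffi_enabled() -> bool:
    """Hypothetical sketch of an FFI on/off switch.

    Assumption: the real helper in transformer_engine/jax/cpp_extensions/misc.py
    may read a different environment variable, use a different default, or add
    version checks on the installed JAX/XLA.
    """
    return os.getenv("NVTE_JAX_WITH_FFI", "1") == "1"
```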
133 changes: 84 additions & 49 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -380,8 +380,6 @@ def lowering(
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)

wkspace_aval = ctx.avals_out[-1]

if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
@@ -433,6 +431,8 @@ def lowering(
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

wkspace_aval = ctx.avals_out[-1]

opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
@@ -725,28 +725,6 @@ def lowering(
"""
Fused attention bwd lowering rules
"""
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]

args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
@@ -761,33 +739,90 @@
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)

wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

wkspace_aval = ctx.avals_out[-1]

opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)

out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

return out
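For orientation, the quantities this backward primitive returns (dq, dk, dv, and optionally dbias) can be reproduced numerically with plain JAX autodiff over an unfused reference implementation. The sketch below is only a reference for the math (no dropout, bias, GQA, or variable-length handling), not the cuDNN kernel that the FFI call dispatches to, and all shapes and names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, scale):
    # q, k, v: [batch, seqlen, heads, head_dim]; unfused reference only.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


def reference_attention_grads(q, k, v, scale, dout):
    # jax.vjp yields (dq, dk, dv) for an upstream gradient `dout` -- the same
    # quantities the fused backward kernel produces in a single call.
    _, vjp = jax.vjp(lambda q_, k_, v_: reference_attention(q_, k_, v_, scale), q, k, v)
    return vjp(dout)


# Illustrative usage with small random shapes.
key = jax.random.PRNGKey(0)
kq, kk, kv, kd = jax.random.split(key, 4)
shape = (2, 8, 4, 16)  # [batch, seqlen, heads, head_dim]
q, k, v, dout = (jax.random.normal(s, shape) for s in (kq, kk, kv, kd))
dq, dk, dv = reference_attention_grads(q, k, v, 1.0 / jnp.sqrt(16.0), dout)
```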

182 changes: 97 additions & 85 deletions transformer_engine/jax/cpp_extensions/softmax.py
@@ -12,12 +12,13 @@
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi

from transformer_engine import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType


@@ -133,32 +134,36 @@ def forward_lowering(name, ctx, logits, *, scale_factor):
"""
softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
else:
(i_aval,) = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
pad_batch = batch
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]

out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits]
operand_shapes = [i_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor,
)

out = custom_caller(name, args, opaque, False)

return out
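Both branches of forward_lowering compute the same thing: a softmax over the last axis of a [..., heads, q_seqlen, k_seqlen] tensor whose logits are first multiplied by scale_factor. A pure-JAX reference of that semantic (a numerical sketch only, not the fused Transformer Engine kernel):

```python
import jax
import jax.numpy as jnp


def scaled_softmax_reference(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    # logits: [..., heads, q_seqlen, k_seqlen]; the softmax runs over k_seqlen.
    return jax.nn.softmax(logits * scale_factor, axis=-1)
```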

@@ -240,37 +245,41 @@ def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
"""
softmax_backward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
else:
dz_aval, _ = ctx.avals_in

dz_type = ir.RankedTensorType(dz.type)
dz_shape = dz_type.shape

# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, dz_shape[:-3])
pad_batch = batch # unused
heads = dz_shape[-3]
q_seqlen = dz_shape[-2]
k_seqlen = dz_shape[-1]

softmax_out_type = ir.RankedTensorType(softmax_out.type)
softmax_out_shape = softmax_out_type.shape

out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out]
operand_shapes = [dz_shape, softmax_out_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(dz_aval.dtype),
scale_factor,
)

out = custom_caller(name, args, opaque, False)

return out
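The backward rule consumes the upstream gradient dz together with the saved forward output softmax_out. For y = softmax(scale_factor * x), the gradient with respect to x is scale_factor * y * (dz - sum(dz * y)) along the softmax axis; the sketch below is a pure-JAX reference under that assumption (the fused kernel may organize the scale differently), not the kernel itself.

```python
import jax.numpy as jnp


def scaled_softmax_backward_reference(dz, softmax_out, scale_factor):
    # Gradient of y = softmax(scale_factor * x) with respect to x, given the
    # upstream gradient dz and the saved forward output softmax_out.
    inner = dz - jnp.sum(dz * softmax_out, axis=-1, keepdims=True)
    return scale_factor * softmax_out * inner
```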

@@ -577,36 +586,39 @@ def lowering(ctx, logits, mask, *, scale_factor):
"""
te_scaled_masked_softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = "te_scaled_masked_softmax_forward_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]

mask_type = ir.RankedTensorType(mask.type)
mask_shape = mask_type.shape
pad_batch = reduce(operator.mul, mask_shape[:-3])

out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits, mask]
operand_shapes = [i_shape, mask_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(logits_aval.dtype),
scale_factor,
)


out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

return out
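The masked variant additionally takes a mask whose batch dimensions may be smaller than the logits batch (pad_batch is derived from the mask shape, not the logits shape, so the mask can broadcast). A hedged pure-JAX reference, assuming the convention that mask entries equal to 1 mark positions to exclude; the actual kernel and its mask semantics are defined on the C++ side and are not part of this diff.

```python
import jax
import jax.numpy as jnp


def scaled_masked_softmax_reference(logits, mask, scale_factor):
    # logits: [..., heads, q_seqlen, k_seqlen]; mask broadcasts across the
    # leading batch dims. Masked positions get a large negative logit so they
    # end up with ~0 probability after the softmax.
    masked = jnp.where(mask.astype(bool), -1e10, logits * scale_factor)
    return jax.nn.softmax(masked, axis=-1)
```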
