Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions transformer_engine/jax/cpp_extensions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
NamedSharding,
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl, AmaxScope
from .quantization import quantize, AmaxScope
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp_tpsp,
Expand Down Expand Up @@ -945,7 +945,7 @@ def layernorm_fwd(
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
quantizer: Optional[Quantizer],
quantizer: Optional[Quantizer] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
Expand Down Expand Up @@ -975,7 +975,16 @@ def layernorm_fwd(
- Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
"""
if not NormFwdPrimitive.enabled():
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
if quantizer is not None:
output = quantize(
output,
quantizer,
flatten_axis=-1,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return (output, mu, rsigma)

# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
Expand Down Expand Up @@ -1029,7 +1038,7 @@ def layernorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out, _ = quantize(
out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
)
return out, mu, rsigma
Expand All @@ -1050,11 +1059,9 @@ def layernorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out = quantize(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
Expand Down Expand Up @@ -1219,7 +1226,16 @@ def rmsnorm_fwd(
Shape: (..., 1)
"""
if not NormFwdPrimitive.enabled():
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon)
if quantizer is not None:
output = quantize(
output,
quantizer,
flatten_axis=-1,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return (output, rsigma)

# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
Expand Down Expand Up @@ -1274,7 +1290,7 @@ def rmsnorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out = quantize(
out.data,
quantizer,
amax_scope=amax_scope,
Expand All @@ -1297,11 +1313,9 @@ def rmsnorm_fwd(
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out = quantize(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
Expand Down