diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 90ab5fb7fe..d09ce7ef74 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -27,7 +27,7 @@
     NamedSharding,
     get_cudnn_version,
 )
-from .quantization import _quantize_dbias_impl, AmaxScope
+from .quantization import quantize, AmaxScope
 from ..sharding import (
     all_reduce_max_along_all_axes_except_PP,
     all_reduce_sum_along_dp_fsdp_tpsp,
@@ -945,7 +945,7 @@ def layernorm_fwd(
     beta: jnp.ndarray,
     zero_centered_gamma: bool,
     epsilon: float,
-    quantizer: Optional[Quantizer],
+    quantizer: Optional[Quantizer] = None,
     amax_scope: AmaxScope = AmaxScope.LOCAL,
     transpose_batch_sequence: bool = False,
     output_amax_when_no_scaling: bool = False,
@@ -975,7 +975,16 @@
           - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
     """
     if not NormFwdPrimitive.enabled():
-        return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
+        output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
+        if quantizer is not None:
+            output = quantize(
+                output,
+                quantizer,
+                flatten_axis=-1,
+                amax_scope=amax_scope,
+                transpose_batch_sequence=transpose_batch_sequence,
+            )
+        return (output, mu, rsigma)
 
     # TE/common does not support normalization with colwise only quantization yet
     if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
@@ -1029,7 +1038,7 @@
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=False,
     )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
     )
     return out, mu, rsigma
@@ -1050,11 +1059,9 @@
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=True,
     )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out,
-        is_dbias=False,
         quantizer=quantizer,
-        dq_dtype=x.dtype,
         amax_scope=amax_scope,
         transpose_batch_sequence=transpose_batch_sequence,
     )
@@ -1219,7 +1226,16 @@ def rmsnorm_fwd(
           Shape: (..., 1)
     """
     if not NormFwdPrimitive.enabled():
-        return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
+        output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon)
+        if quantizer is not None:
+            output = quantize(
+                output,
+                quantizer,
+                flatten_axis=-1,
+                amax_scope=amax_scope,
+                transpose_batch_sequence=transpose_batch_sequence,
+            )
+        return (output, rsigma)
 
     # TE/common does not support normalization with colwise only quantization yet
     if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
@@ -1274,7 +1290,7 @@
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=False,
    )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out.data,
         quantizer,
         amax_scope=amax_scope,
@@ -1297,11 +1313,9 @@
         transpose_batch_sequence=transpose_batch_sequence,
         output_amax_when_no_scaling=True,
     )
-    out, _ = _quantize_dbias_impl(
+    out = quantize(
         out,
-        is_dbias=False,
         quantizer=quantizer,
-        dq_dtype=x.dtype,
         amax_scope=amax_scope,
         transpose_batch_sequence=transpose_batch_sequence,
     )