[JAX] Flax params initialization with weight_dtype (#1481)
* initialization with weight_dtype

Signed-off-by: Phuong Nguyen <[email protected]>
---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Feb 14, 2025
1 parent f0d22ca commit 24e4f95
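
The change splits what used to be a single dtype field into two: dtype now describes the computation precision, while the new weight_dtype controls the dtype the Flax parameters are allocated in (they are cast to dtype at use time). Below is a minimal usage sketch, assuming the public transformer_engine.jax.flax.DenseGeneral API shown in this diff; the feature size, shapes, and dtype choices are illustrative only.

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral

# Illustrative configuration: keep master weights in float32, compute in bfloat16.
layer = DenseGeneral(
    features=128,
    dtype=jnp.bfloat16,        # computation dtype
    weight_dtype=jnp.float32,  # dtype used to allocate the parameters
)

x = jnp.ones((4, 64), dtype=jnp.bfloat16)
variables = layer.init(jax.random.PRNGKey(0), x)

# With this change the stored kernel/bias should come out as float32,
# while the forward pass casts them to bfloat16 before the matmul.
print(jax.tree_util.tree_map(lambda a: a.dtype, variables["params"]))
y = layer.apply(variables, x)
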
Showing 2 changed files with 109 additions and 43 deletions.
104 changes: 71 additions & 33 deletions transformer_engine/jax/flax/module.py
@@ -8,8 +8,8 @@
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union

import jax.numpy as jnp
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
@@ -57,14 +57,18 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga


def _create_layernorm_parameters(
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
):
scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
scale = nn_partitioning.param_with_axes(
"scale", scale_init, shape, weight_dtype, axes=scale_axes
)
scale = scale.astype(dtype)

layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == "layernorm":
bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
bias = nn_partitioning.param_with_axes(
"ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
)
bias = bias.astype(dtype)
else:
assert layernorm_type == "rmsnorm"
@@ -256,8 +260,10 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -272,6 +278,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False

def __post_init__(self):
@@ -307,6 +314,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
self.bias_init,
self.bias_axes,
self.dtype,
self.weight_dtype,
)
return layernorm(
x,
@@ -399,8 +407,10 @@ class DenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -418,12 +428,13 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False

def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
)
super().__post_init__()

@@ -452,13 +463,13 @@ def __call__(self, inputs: Array) -> Array:
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
)
kernel = kernel.astype(self.dtype)

if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
"bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
else:
@@ -489,7 +500,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
self.dtype,
self.weight_dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -501,7 +512,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -594,8 +605,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -625,6 +638,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
@@ -633,7 +647,10 @@
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0,
"fan_in",
"truncated_normal",
dtype=self.weight_dtype,
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
@@ -683,6 +700,7 @@ def __call__(self, inputs: Array) -> Array:
self.ln_bias_init,
self.ln_bias_axes,
self.dtype,
self.weight_dtype,
)

if not fuse_layernorm:
@@ -712,7 +730,7 @@ def __call__(self, inputs: Array) -> Array:
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
)
kernel = kernel.astype(self.dtype)

@@ -757,7 +775,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
self.dtype,
self.weight_dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -769,7 +787,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -781,7 +799,7 @@ def __call__(self, inputs: Array) -> Array:
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
"bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)

@@ -896,8 +914,10 @@ class LayerNormMLP(TransformerEngineBase):
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -930,6 +950,7 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
@@ -938,7 +959,7 @@ def __post_init__(self):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
@@ -1015,6 +1036,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
self.ln_bias_init,
self.ln_bias_axes,
self.dtype,
self.weight_dtype,
)

if not fuse_layernorm:
@@ -1061,7 +1083,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
num_activations,
-2,
kernel_1_each_shape,
self.dtype,
self.weight_dtype,
axes=self.kernel_axes_1,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
@@ -1074,7 +1096,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wo_kernel",
self.kernel_init,
kernel_2_param_shape,
self.dtype,
self.weight_dtype,
axes=self.kernel_axes_2,
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
@@ -1090,13 +1112,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
if self.use_bias:
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
"wi_bias",
self.bias_init,
bias_1_shape,
self.weight_dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.astype(self.dtype)

bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
"wo_bias",
self.bias_init,
bias_2_shape,
self.weight_dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(self.dtype)
else:
@@ -1165,7 +1195,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
self.dtype,
self.weight_dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
@@ -1181,7 +1211,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wi_lora_b_kernel",
nn.initializers.zeros,
wi_lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=wi_lora_b_kernel_axes,
)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1198,7 +1228,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias_1 = None
if self.use_bias:
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
"wi_bias",
self.bias_init,
intermediate_dim,
self.weight_dtype,
axes=self.bias_axes_1,
)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
bias_1 = bias_1.astype(self.dtype)
@@ -1240,7 +1274,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wo_lora_a_kernel",
self.kernel_init,
wo_lora_a_kernel_shape,
self.dtype,
self.weight_dtype,
axes=wo_lora_a_kernel_axes,
)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1251,7 +1285,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wo_lora_b_kernel",
nn.initializers.zeros,
wo_lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=wo_lora_b_kernel_axes,
)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1268,7 +1302,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias_2 = None
if self.use_bias:
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
"wo_bias",
self.bias_init,
(hidden_size,),
self.weight_dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
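
The pattern applied to every parameter in this diff is the same: allocate the parameter with weight_dtype via param_with_axes, then immediately cast it to the compute dtype. The following standalone sketch illustrates that pattern with a made-up ToyDense module (not part of Transformer Engine), so it can run without the rest of the library; the axis names and shapes are arbitrary.

import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning

class ToyDense(nn.Module):  # hypothetical module, for illustration only
    features: int
    dtype: jnp.dtype = jnp.float32         # computation dtype
    weight_dtype: jnp.dtype = jnp.float32  # parameter allocation dtype

    @nn.compact
    def __call__(self, x):
        kernel = nn_partitioning.param_with_axes(
            "kernel",
            nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
            ),
            (x.shape[-1], self.features),
            self.weight_dtype,              # parameters live in weight_dtype
            axes=("embed", "mlp"),
        )
        kernel = kernel.astype(self.dtype)  # cast only for the computation
        return jnp.dot(x.astype(self.dtype), kernel)
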