[JAX] Flax module init with a given dtype (#1472)
* flax module to init params with given dtype

Signed-off-by: Phuong Nguyen <[email protected]>

* all tests passed

Signed-off-by: Phuong Nguyen <[email protected]>

* remove unnecessary reshape for kernel

Signed-off-by: Phuong Nguyen <[email protected]>

* remove casting output of dot

Signed-off-by: Phuong Nguyen <[email protected]>

* clean up

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Feb 11, 2025
1 parent 544dd14 commit b87e539
Showing 4 changed files with 70 additions and 45 deletions.
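
The common pattern across these files: parameters are created directly in the module's dtype (the dtype is passed to param_with_axes and baked into the default initializers) instead of being created in float32 and cast afterwards, and inputs are cast to that same dtype on entry to __call__. A minimal plain-Flax sketch of the pattern follows; TinyDense is a hypothetical stand-in for illustration, not one of the Transformer Engine modules changed in this commit.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyDense(nn.Module):
        """Hypothetical module showing init-in-dtype; not a TE class."""
        features: int
        dtype: jnp.dtype = jnp.bfloat16

        @nn.compact
        def __call__(self, x):
            # Cast activations to the module dtype on entry, as the TE modules now do.
            x = x.astype(self.dtype)
            # The initializer carries the dtype, so the parameter is created in
            # self.dtype rather than in float32 followed by a cast.
            kernel = self.param(
                "kernel",
                nn.initializers.variance_scaling(
                    1.0, "fan_in", "truncated_normal", dtype=self.dtype
                ),
                (x.shape[-1], self.features),
            )
            return jax.lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))

    params = TinyDense(features=8).init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
    print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # kernel: bfloat16
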
9 changes: 7 additions & 2 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -252,8 +252,13 @@ def abstract(
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
assert (
q_dtype == k_dtype == v_dtype == bias_dtype
), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
)

batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
2 changes: 1 addition & 1 deletion transformer_engine/jax/dot.py
@@ -25,7 +25,7 @@ def type_safe_dot_general(
"""

if fp8_meta_pkg is None:
kernel = jnp.asarray(kernel, x.dtype)
assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}"
return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))

amax_list = fp8_meta_pkg.amax_list
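
With the cast removed, the non-FP8 path of type_safe_dot_general no longer coerces the kernel to the input dtype; callers are expected to supply matching dtypes, which the dtype-aware parameter init above now provides. A hedged sketch of the caller-side contract (the commented TE call uses argument names taken from the surrounding diff and is an assumption, not a verified signature):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((4, 16), dtype=jnp.bfloat16)
    kernel = jnp.ones((16, 8), dtype=jnp.float32)

    # Previously the kernel was silently cast to x.dtype inside the helper; now the
    # dtypes must already agree, so cast at the call site (or, better, create the
    # kernel in the target dtype to begin with).
    kernel = kernel.astype(x.dtype)

    # out = type_safe_dot_general(x, kernel, fp8_meta_pkg=None, contracting_dims=((1,), (0,)))
    out = jax.lax.dot_general(x, kernel, (((1,), (0,)), ((), ())))  # the non-FP8 path shown above
    print(out.shape, out.dtype)  # (4, 8) bfloat16
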
88 changes: 50 additions & 38 deletions transformer_engine/jax/flax/module.py
@@ -59,17 +59,13 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
):
scale = nn_partitioning.param_with_axes(
"scale", scale_init, shape, jnp.float32, axes=scale_axes
)
scale = jnp.asarray(scale, dtype)
scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
scale = scale.astype(dtype)

layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == "layernorm":
bias = nn_partitioning.param_with_axes(
"ln_bias", bias_init, shape, jnp.float32, axes=bias_axes
)
bias = jnp.asarray(bias, dtype)
bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
bias = bias.astype(dtype)
else:
assert layernorm_type == "rmsnorm"
bias = None
@@ -280,7 +276,8 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods

def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma
self.scale_init,
self.zero_centered_gamma,
)
super().__post_init__()

@@ -299,6 +296,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
outputs : jax.numpy.ndarray
Output tensors.
"""
x = x.astype(self.dtype)

features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
@@ -424,7 +422,9 @@ class DenseGeneral(TransformerEngineBase):

def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()

@nn.compact
@@ -452,14 +452,13 @@ def __call__(self, inputs: Array) -> Array:
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)

kernel = jnp.reshape(kernel, kernel_shape)
kernel = kernel.astype(self.dtype)

if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
else:
@@ -490,7 +489,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -502,7 +501,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
self.dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -633,9 +632,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):

def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma
self.scale_init,
self.zero_centered_gamma,
)
super().__post_init__()

@@ -665,6 +667,7 @@ def __call__(self, inputs: Array) -> Array:
and not self.return_layernorm_output
and self.enable_layernorm
)
inputs = inputs.astype(self.dtype)

if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
@@ -709,10 +712,9 @@ def __call__(self, inputs: Array) -> Array:
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)

kernel = jnp.reshape(kernel, kernel_shape)
kernel = kernel.astype(self.dtype)

contract_ind = tuple(range(0, len(axis)))

@@ -755,7 +757,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -767,7 +769,7 @@ def __call__(self, inputs: Array) -> Array:
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
self.dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -779,7 +781,7 @@ def __call__(self, inputs: Array) -> Array:
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)

@@ -935,9 +937,12 @@ class LayerNormMLP(TransformerEngineBase):

def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init, self.zero_centered_gamma
self.scale_init,
self.zero_centered_gamma,
)
super().__post_init__()

@@ -970,6 +975,8 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
and self.enable_layernorm
)

inputs = inputs.astype(self.dtype)

gated_act_pool = [
("gelu", "linear"),
("silu", "linear"),
@@ -1033,7 +1040,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
for _ in range(num_kernels):
key, init_key = jax_random.split(key)
kernels.append(self.kernel_init(init_key, *init_args))
return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)
return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)

wi_fp8_meta_pkg = None
wo_fp8_meta_pkg = None
@@ -1054,10 +1061,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
num_activations,
-2,
kernel_1_each_shape,
jnp.float32,
self.dtype,
axes=self.kernel_axes_1,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
kernel_1 = kernel_1.astype(self.dtype)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
@@ -1066,10 +1074,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wo_kernel",
self.kernel_init,
kernel_2_param_shape,
jnp.float32,
self.dtype,
axes=self.kernel_axes_2,
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
kernel_2 = kernel_2.astype(self.dtype)
contract_ind = tuple(range(0, len(axis)))

ffn1_ckpt_name = "ffn1"
@@ -1081,13 +1090,13 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
if self.use_bias:
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1
"wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
)
bias_1 = bias_1.astype(self.dtype)

bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2
"wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
)
bias_2 = bias_2.astype(self.dtype)
else:
@@ -1156,7 +1165,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
jnp.float32,
self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
@@ -1172,7 +1181,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wi_lora_b_kernel",
nn.initializers.zeros,
wi_lora_b_kernel_shape,
jnp.float32,
self.dtype,
axes=wi_lora_b_kernel_axes,
)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1189,10 +1198,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias_1 = None
if self.use_bias:
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1
"wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
)
bias_1 = bias_1.astype(self.dtype)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
bias_1 = bias_1.astype(self.dtype)
x += jnp.reshape(bias_1, bias_1_shape)

x = checkpoint_name(x, ffn1_ckpt_name)
@@ -1207,6 +1216,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
z = functools.reduce(operator.mul, activations)
# Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = z.astype(self.dtype)
# import pdb; pdb.set_trace()

z = nn.Dropout(
rate=self.intermediate_dropout_rate,
@@ -1215,6 +1226,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
)(z, deterministic=deterministic)

z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
z = z.astype(self.dtype)

# DenseGeneral 2
out = type_safe_dot_general(
@@ -1228,7 +1240,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wo_lora_a_kernel",
self.kernel_init,
wo_lora_a_kernel_shape,
jnp.float32,
self.dtype,
axes=wo_lora_a_kernel_axes,
)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1239,7 +1251,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
"wo_lora_b_kernel",
nn.initializers.zeros,
wo_lora_b_kernel_shape,
jnp.float32,
self.dtype,
axes=wo_lora_b_kernel_axes,
)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1256,7 +1268,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias_2 = None
if self.use_bias:
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2
"wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
)
bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
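
Taken together, the module.py changes mean a module constructed with a low-precision dtype now holds its parameters in that dtype from initialization, and the kernel parameter is created directly in its final shape instead of a flattened shape followed by a reshape. A hedged end-to-end check, assuming DenseGeneral is importable from transformer_engine.jax.flax and that the package and its extensions are available in the environment:

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import DenseGeneral  # assumed import path

    layer = DenseGeneral(features=128, dtype=jnp.bfloat16)
    variables = layer.init(jax.random.PRNGKey(0), jnp.ones((4, 64), jnp.bfloat16))
    # Expected after this commit: the kernel (and bias, when use_bias is enabled)
    # is bfloat16 at init time, rather than a float32 param cast on every forward pass.
    print(jax.tree_util.tree_map(lambda p: p.dtype, variables["params"]))
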
16 changes: 12 additions & 4 deletions transformer_engine/jax/flax/transformer.py
@@ -976,7 +976,9 @@ def __post_init__(self):
)

if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@@ -1198,6 +1200,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
inputs_kv = ln_out

key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
key = key.astype(self.dtype)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
query = checkpoint_name(query, "query_proj")
key = checkpoint_name(key, "key_proj")
@@ -1437,7 +1440,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
"rel_embedding",
self.embedding_init,
(self.num_attention_heads, self.num_buckets),
jnp.float32,
self.dtype,
axes=self.embedding_axes,
)

@@ -1673,10 +1676,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods

def __post_init__(self):
if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
self.mha_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal"
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
@@ -1726,6 +1731,9 @@ def __call__(
outputs: jax.numpy.ndarray
Output tensors.
"""

inputs = inputs.astype(self.dtype)

assert (
self.layer_type in TransformerLayerType
), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
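
The transformer.py changes apply the same recipe to the attention and MLP kernel-init defaults: variance_scaling accepts a dtype argument, so the initializer itself emits weights in the module's dtype. A standalone check of just that piece, with no Transformer Engine dependency:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", dtype=jnp.bfloat16)
    w = init(jax.random.PRNGKey(0), (64, 64))
    print(w.dtype)  # bfloat16
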
