
Commit 6a920cd

reddragon authored and The gemma Authors committed
Scale and precision related changes for RMSNorm and Einsum.
PiperOrigin-RevId: 775113121
1 parent db6e197 commit 6a920cd

File tree: 2 files changed (+49, −6 lines)

gemma/gm/nn/_layers.py (25 additions, 6 deletions)

@@ -26,6 +26,7 @@ class Einsum(nn.Module):
   weight_name: str = 'w'
   initializer: nn.initializers.Initializer = nn.initializers.normal()
   dtype: jnp.dtype | None = None
+  w_scale: float | None = None
 
   @nn.compact
   def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
@@ -35,24 +36,42 @@ def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
         self.shape,
         self.dtype if self.dtype is not None else None,
     )
+    if self.w_scale:
+      w *= self.w_scale
     return jnp.einsum(eqn, x, w)
 
 
+def reduce_precision(x: jax.Array) -> jax.Array:
+  """Helper function to reduce the precision of a tensor."""
+  finfo = jnp.finfo(x.dtype)  # jnp important!
+  return jax.lax.reduce_precision(x, finfo.nexp, finfo.nmant)
+
+
 class RMSNorm(nn.Module):
   """RMSNorm layer."""
 
+  with_scale: bool = True
+  scale_init: nn.initializers.Initializer = nn.initializers.zeros_init()
+  scale_plus_one: bool = True
+  guard_against_excess_precision: bool = False
+
   @nn.compact
   def __call__(self, x):
-    scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1]))
+    if self.guard_against_excess_precision:
+      x = reduce_precision(x)
+
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
 
     # Jax.lax.rsqrt is used because it returns different floats than
     # jnp.reciprocal(jnp.sqrt(var + 1e-06))
     normed_inputs = x * jax.lax.rsqrt(var + 1e-06)
 
-    # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
-    # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
-    # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
-    scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
-    normed_inputs = normed_inputs * (1 + scale)
+    if self.with_scale:
+      scale = self.param('scale', self.scale_init, (x.shape[-1]))
+      # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale
+      # is a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
+      # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
+      scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
+      normed_inputs = normed_inputs * (
+          1. + scale if self.scale_plus_one else scale)
     return normed_inputs

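For context, a minimal usage sketch of the new options (not part of the commit). It assumes Einsum is exposed as gm.nn.Einsum and accepts a shape argument, as the surrounding code and tests suggest; treat the exact constructor arguments as illustrative.

import jax
import jax.numpy as jnp
from gemma import gm

x = jnp.ones((2, 4), dtype=jnp.bfloat16)

# RMSNorm that first rounds x down to its own dtype's exponent/mantissa
# budget (intended to discard any excess precision the compiler may have
# carried from upstream ops) and skips the learned scale entirely.
norm = gm.nn.RMSNorm(with_scale=False, guard_against_excess_precision=True)
y = norm.apply(norm.init(jax.random.PRNGKey(0), x), x)

# Einsum whose weights are multiplied by a constant factor at call time.
# The shape keyword is assumed from self.shape in the diff above.
einsum = gm.nn.Einsum(shape=(4, 8), w_scale=0.5)
params = einsum.init(jax.random.PRNGKey(1), 'bd,dh->bh', x)
z = einsum.apply(params, 'bd,dh->bh', x)
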
gemma/gm/nn/_layers_test.py (24 additions, 0 deletions)

@@ -14,12 +14,16 @@
 
 """Tests for transformer layers."""
 
+from flax import linen as nn
 from gemma import gm
 import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
 
+_ZERO_INIT = nn.initializers.zeros_init()
+_ONES_INIT = nn.initializers.ones_init()
+
 
 @pytest.mark.parametrize(
     'inputs_shape, params_shape, eqn, expected_shape',
@@ -48,3 +52,23 @@ def test_rmsnorm(x, expected):
   params = rmsnorm.init(jax.random.PRNGKey(0), x)
   output = rmsnorm.apply(params, x)
   np.testing.assert_array_equal(output, jnp.array([expected]))
+
+
+@pytest.mark.parametrize(
+    'x, expected, with_scale, scale_init',
+    [
+        # This is the default case.
+        ([0.1, 0.2], [0.6324429, 1.2648858], True, _ZERO_INIT),
+        # In this case, the output is simply scaled by (1 + scale).
+        ([0.1, 0.2], [1.2648858, 2.5297716], True, _ONES_INIT),
+        # When with_scale is False, the output is not scaled.
+        ([0.1, 0.2], [0.6324429, 1.2648858], False, _ZERO_INIT),
+        ([0.1, 0.2], [0.6324429, 1.2648858], False, _ONES_INIT),
+    ],
+)
+def test_rmsnorm_with_scale(x, expected, with_scale, scale_init):
+  x = jnp.array([x])
+  rmsnorm = gm.nn.RMSNorm(with_scale=with_scale, scale_init=scale_init)
+  params = rmsnorm.init(jax.random.PRNGKey(0), x)
+  output = rmsnorm.apply(params, x)
+  np.testing.assert_array_equal(output, jnp.array([expected]))

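As a quick sanity check (not part of the commit), the expected values in test_rmsnorm_with_scale follow directly from the RMSNorm formula in the layer diff above:

import jax
import jax.numpy as jnp

x = jnp.array([[0.1, 0.2]])
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)  # 0.025
normed = x * jax.lax.rsqrt(var + 1e-06)                # ~[0.6324429, 1.2648858]
# scale_init = ones with scale_plus_one=True multiplies by (1 + 1) = 2,
# giving ~[1.2648858, 2.5297716]; with_scale=False leaves normed unchanged.
print(normed, 2.0 * normed)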