@@ -26,6 +26,7 @@ class Einsum(nn.Module):
   weight_name: str = 'w'
   initializer: nn.initializers.Initializer = nn.initializers.normal()
   dtype: jnp.dtype | None = None
+  w_scale: float | None = None
 
   @nn.compact
   def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
@@ -35,24 +36,42 @@ def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
         self.shape,
         self.dtype if self.dtype is not None else None,
     )
+    if self.w_scale:
+      w *= self.w_scale
     return jnp.einsum(eqn, x, w)
 
 
+def reduce_precision(x: jax.Array) -> jax.Array:
+  """Helper function to reduce the precision of a tensor."""
+  finfo = jnp.finfo(x.dtype)  # jnp important!
+  return jax.lax.reduce_precision(x, finfo.nexp, finfo.nmant)
+
+
 class RMSNorm(nn.Module):
   """RMSNorm layer."""
 
+  with_scale: bool = True
+  scale_init: nn.initializers.Initializer = nn.initializers.zeros_init()
+  scale_plus_one: bool = True
+  guard_against_excess_precision: bool = False
+
   @nn.compact
   def __call__(self, x):
-    scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1]))
+    if self.guard_against_excess_precision:
+      x = reduce_precision(x)
+
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
 
     # jax.lax.rsqrt is used because it returns different floats than
     # jnp.reciprocal(jnp.sqrt(var + 1e-06))
     normed_inputs = x * jax.lax.rsqrt(var + 1e-06)
 
-    # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
-    # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
-    # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
-    scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
-    normed_inputs = normed_inputs * (1 + scale)
+    if self.with_scale:
+      scale = self.param('scale', self.scale_init, (x.shape[-1]))
+      # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale
+      # is a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
+      # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
+      scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
+      normed_inputs = normed_inputs * (
+          1. + scale if self.scale_plus_one else scale)
     return normed_inputs
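
For reference, a minimal standalone sketch (not part of the commit) of what the new reduce_precision helper does. jnp.finfo, unlike numpy.finfo, understands the ml_dtypes floats JAX uses, such as bfloat16, which is what the "jnp important!" comment is insisting on; jax.lax.reduce_precision then rounds each element to the given exponent/mantissa widths, discarding any excess precision the compiler may be carrying for a nominally lower-precision array:

import jax
import jax.numpy as jnp

def reduce_precision(x: jax.Array) -> jax.Array:
  # Round x to its own dtype's nominal widths; a numerical no-op unless the
  # values are being carried at higher precision than x.dtype claims.
  finfo = jnp.finfo(x.dtype)  # jnp.finfo handles bfloat16; numpy.finfo does not
  return jax.lax.reduce_precision(x, finfo.nexp, finfo.nmant)

finfo = jnp.finfo(jnp.bfloat16)
print(finfo.nexp, finfo.nmant)  # 8 7: bfloat16 keeps float32's exponent range
print(reduce_precision(jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16)))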
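And a hypothetical usage sketch, assuming the RMSNorm class from this diff is in scope, showing the new knobs together. With the default zeros initializer and scale_plus_one=True, the learned scale starts out as an identity multiplier (1 + 0), matching the old behaviour; scale_plus_one=False paired with a ones initializer would give a plain multiplicative scale instead:

import jax
import jax.numpy as jnp

norm = RMSNorm(
    with_scale=True,
    scale_plus_one=True,                  # multiply by (1 + scale); identity at init
    guard_against_excess_precision=True,  # round x to its nominal precision first
)
x = jnp.ones((2, 4), dtype=jnp.bfloat16)
params = norm.init(jax.random.PRNGKey(0), x)
y = norm.apply(params, x)  # same shape as x; unit RMS along the last axis

Note also that the Einsum guard "if self.w_scale:" is a truthiness test, so w_scale=0.0 is treated like None and applies no scaling.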