@@ -112,6 +112,11 @@ class Gemma2Config(PretrainedConfig):
             scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
             scaling factor when applying tanh softcapping on the attention scores.
+        use_post_attention_norm (`bool`, *optional*, defaults to `True`):
+            whether to apply a post-attention layer normalization.
+        use_post_feedforward_norm (`bool`, *optional*, defaults to `True`):
+            whether to apply a post-feedforward layer normalization.
+

     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -166,6 +171,8 @@ def __init__(
         layer_types=None,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        use_post_attention_norm=True,
+        use_post_feedforward_norm=True,
         **kwargs,
     ):
         super().__init__(
@@ -194,6 +201,8 @@ def __init__(
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.use_post_attention_norm = use_post_attention_norm
+        self.use_post_feedforward_norm = use_post_feedforward_norm
         self.layer_types = layer_types

         if self.layer_types is None:
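With these config fields in place, the post-norms become a user-facing switch. A minimal sketch of how the new flags would be used (assuming this patch is applied; the size values below are arbitrary and only keep the example light, and both flags default to `True`, which preserves the existing Gemma2 behaviour):

```python
from transformers import Gemma2Config

# Tiny config with the post-norms disabled via the flags added in this diff.
# All size-related values here are arbitrary illustration values.
config = Gemma2Config(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    use_post_attention_norm=False,
    use_post_feedforward_norm=False,
)

print(config.use_post_attention_norm)    # False
print(config.use_post_feedforward_norm)  # False
```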
@@ -313,10 +322,12 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.config.use_post_attention_norm:
+            self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.config.use_post_feedforward_norm:
+            self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -347,13 +358,15 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if self.config.use_post_attention_norm:
+            hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        if self.config.use_post_feedforward_norm:
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
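As a quick sanity check of the gating above, the sketch below builds a single decoder layer with both flags off and confirms the post-norm modules are never created. This is a hedged example: it assumes this patch is applied and that `Gemma2DecoderLayer` can still be instantiated directly from `transformers.models.gemma2.modeling_gemma2`, as it can today.

```python
from transformers import Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer

# Assumes this patch is applied; sizes are arbitrary illustration values.
config = Gemma2Config(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    use_post_attention_norm=False,
    use_post_feedforward_norm=False,
)
layer = Gemma2DecoderLayer(config, layer_idx=0)

# With the flags off, __init__ skips the module creation entirely,
# so the attributes should simply not exist on the layer.
print(hasattr(layer, "post_attention_layernorm"))    # False
print(hasattr(layer, "post_feedforward_layernorm"))  # False
```

Because `forward` gates on the same config flags, the missing modules are never referenced at run time, so disabling a norm removes both its parameters and its compute from the layer.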