
Commit f8f3264

Add ability to run Gemma 2 models without post attention norm and post feedforward norm
1 parent afd1393 · commit f8f3264

File tree: 5 files changed (+47, −8 lines)

src/transformers/models/gemma2/configuration_gemma2.py

Lines changed: 9 additions & 0 deletions

@@ -88,6 +88,11 @@ class Gemma2Config(PretrainedConfig):
             scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
             scaling factor when applying tanh softcapping on the attention scores.
+        use_post_attention_norm (`bool`, *optional*, defaults to `True`):
+            whether to use a post attention layer normalization layer.
+        use_post_feedforward_norm (`bool`, *optional*, defaults to `True`):
+            whether to use a post feedforward layer normalization layer.
+

     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -142,6 +147,8 @@ def __init__(
         layer_types=None,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        use_post_attention_norm=True,
+        use_post_feedforward_norm=True,
         **kwargs,
     ):
         super().__init__(
@@ -170,6 +177,8 @@ def __init__(
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.use_post_attention_norm = use_post_attention_norm
+        self.use_post_feedforward_norm = use_post_feedforward_norm
         self.layer_types = layer_types

         if self.layer_types is None:
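For context, a minimal usage sketch (not part of the commit) showing a config built with both new flags turned off; the small sizes and `vocab_size` below are placeholder values, not a real Gemma 2 checkpoint:

```python
from transformers import Gemma2Config

# Placeholder sizes for illustration only; the two use_post_*_norm flags are
# the only arguments specific to this commit.
config = Gemma2Config(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=64,
    use_post_attention_norm=False,    # skip the post-attention RMSNorm
    use_post_feedforward_norm=False,  # skip the post-feedforward RMSNorm
)
print(config.use_post_attention_norm, config.use_post_feedforward_norm)  # False False
```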

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 8 additions & 4 deletions

@@ -248,10 +248,12 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.config.use_post_attention_norm:
+            self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.config.use_post_feedforward_norm:
+            self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -282,13 +284,15 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if self.config.use_post_attention_norm:
+            hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        if self.config.use_post_feedforward_norm:
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
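To make the control flow above easier to follow, here is a standalone sketch (illustrative only, not the Transformers implementation) of the residual sub-layer pattern being made optional; `nn.LayerNorm` and `nn.Linear` stand in for `Gemma2RMSNorm` and the attention/MLP blocks:

```python
import torch
from torch import nn


class SketchSubLayer(nn.Module):
    """Pre-norm -> block -> (optional) post-norm -> residual add."""

    def __init__(self, hidden_size: int, use_post_norm: bool = True):
        super().__init__()
        self.pre_norm = nn.LayerNorm(hidden_size)         # stand-in for Gemma2RMSNorm
        self.block = nn.Linear(hidden_size, hidden_size)  # stand-in for attention or MLP
        self.use_post_norm = use_post_norm
        if use_post_norm:  # mirrors the guarded module creation in __init__ above
            self.post_norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        h = self.block(self.pre_norm(x))
        if self.use_post_norm:  # mirrors config.use_post_attention_norm / use_post_feedforward_norm
            h = self.post_norm(h)
        return residual + h


# With use_post_norm=False, no post-norm module is registered and the block
# output is added to the residual stream unnormalized.
layer = SketchSubLayer(hidden_size=16, use_post_norm=False)
out = layer(torch.randn(2, 4, 16))
```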

src/transformers/models/gemma2/modular_gemma2.py

Lines changed: 17 additions & 4 deletions

@@ -112,6 +112,11 @@ class Gemma2Config(PretrainedConfig):
             scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
             scaling factor when applying tanh softcapping on the attention scores.
+        use_post_attention_norm (`bool`, *optional*, defaults to `True`):
+            whether to use a post attention layer normalization layer.
+        use_post_feedforward_norm (`bool`, *optional*, defaults to `True`):
+            whether to use a post feedforward layer normalization layer.
+

     ```python
     >>> from transformers import Gemma2Model, Gemma2Config
@@ -166,6 +171,8 @@ def __init__(
         layer_types=None,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        use_post_attention_norm=True,
+        use_post_feedforward_norm=True,
         **kwargs,
     ):
         super().__init__(
@@ -194,6 +201,8 @@ def __init__(
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.use_post_attention_norm = use_post_attention_norm
+        self.use_post_feedforward_norm = use_post_feedforward_norm
         self.layer_types = layer_types

         if self.layer_types is None:
@@ -313,10 +322,12 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.config.use_post_attention_norm:
+            self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.config.use_post_feedforward_norm:
+            self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
     def forward(
@@ -347,13 +358,15 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        if self.config.use_post_attention_norm:
+            hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        if self.config.use_post_feedforward_norm:
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         outputs = (hidden_states,)
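As an illustrative sanity check (not part of the commit), instantiating a randomly initialized model from such a config should show that the guarded modules are simply never registered; the tiny sizes below are placeholders:

```python
from transformers import Gemma2Config, Gemma2Model

config = Gemma2Config(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=64,
    use_post_attention_norm=False,
    use_post_feedforward_norm=False,
)
model = Gemma2Model(config)

layer = model.layers[0]
print(hasattr(layer, "post_attention_layernorm"))    # expected: False
print(hasattr(layer, "post_feedforward_layernorm"))  # expected: False
```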

src/transformers/models/t5gemma/configuration_t5gemma.py

Lines changed: 9 additions & 0 deletions

@@ -90,6 +90,11 @@ class T5GemmaModuleConfig(PretrainedConfig):
             scaling factor when applying tanh softcapping on the logits.
         attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
             scaling factor when applying tanh softcapping on the attention scores.
+        use_post_attention_norm (`bool`, *optional*, defaults to `True`):
+            whether to use a post attention layer normalization layer.
+        use_post_feedforward_norm (`bool`, *optional*, defaults to `True`):
+            whether to use a post feedforward layer normalization layer.
+

     ```python
     >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig
@@ -144,6 +149,8 @@ def __init__(
         layer_types=None,
         final_logit_softcapping=30.0,
         attn_logit_softcapping=50.0,
+        use_post_attention_norm=True,
+        use_post_feedforward_norm=True,
         **kwargs,
     ):
         super().__init__(
@@ -172,6 +179,8 @@ def __init__(
         self.sliding_window = sliding_window
         self.final_logit_softcapping = final_logit_softcapping
         self.attn_logit_softcapping = attn_logit_softcapping
+        self.use_post_attention_norm = use_post_attention_norm
+        self.use_post_feedforward_norm = use_post_feedforward_norm
         self.layer_types = layer_types

         if self.layer_types is None:

src/transformers/models/t5gemma/modular_t5gemma.py

Lines changed: 4 additions & 0 deletions

@@ -166,13 +166,17 @@ def __init__(
         encoder.is_decoder = False
         encoder.dropout_rate = dropout_rate
         encoder.attention_dropout = attention_dropout
+        encoder.use_post_attention_norm = True
+        encoder.use_post_feedforward_norm = True
         self.encoder = encoder

         decoder.is_decoder = True
         decoder.use_cache = True
         decoder.dropout_rate = dropout_rate
         decoder.attention_dropout = attention_dropout
         decoder.cross_attention_hidden_size = encoder.hidden_size
+        decoder.use_post_attention_norm = True
+        decoder.use_post_feedforward_norm = True
         self.decoder = decoder

         for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]:
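The two overrides above mean T5Gemma keeps both post norms enabled on its encoder and decoder regardless of what the module configs request. An illustrative sketch, assuming `T5GemmaConfig` accepts `T5GemmaModuleConfig` instances for `encoder` and `decoder` (as the attribute assignments above suggest); the small sizes are placeholders:

```python
from transformers.models.t5gemma.configuration_t5gemma import (
    T5GemmaConfig,
    T5GemmaModuleConfig,
)


def tiny_module_config():
    # Placeholder sizes; both flags deliberately disabled to show the override.
    return T5GemmaModuleConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=64,
        use_post_attention_norm=False,
        use_post_feedforward_norm=False,
    )


config = T5GemmaConfig(encoder=tiny_module_config(), decoder=tiny_module_config())
print(config.encoder.use_post_attention_norm)    # expected: True (forced in __init__)
print(config.decoder.use_post_feedforward_norm)  # expected: True (forced in __init__)
```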
