Commit 034dea1

Author: Jesujoba Alabi (committed)

fixed neox attention_adapters
1 parent f920882 commit 034dea1

File tree

1 file changed: +4 -3 lines changed


src/transformers/models/gpt_neox/modeling_gpt_neox.py (+4 -3)
@@ -355,10 +355,11 @@ def forward(
         # pseudocode:
         # x = x + attn(ln1(x))
         # x = x + mlp(ln2(x))
-        attn_output = attn_output + hidden_states
-        mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
+        hidden_states = self.attention_adapters(attn_output, hidden_states, None)  # attn_output = attn_output + hidden_states
+        residual = hidden_states
+        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
         # residual connection
-        hidden_states = self.output_adapters(mlp_output, attn_output, None)
+        hidden_states = self.output_adapters(mlp_output, residual, None)
         #hidden_states = mlp_output + attn_output

         if use_cache:
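Note on the change: before this commit, the attention residual was added directly (attn_output = attn_output + hidden_states), so attention_adapters was never invoked in this block, and the later output_adapters call reused attn_output as its residual input. The fix routes the attention output through attention_adapters (which, per the inline comment, is expected to perform that residual add itself), saves the result as residual, and passes it to output_adapters after the MLP, matching the pseudocode x = x + attn(ln1(x)); x = x + mlp(ln2(x)). Below is a minimal, hypothetical sketch of the adapter-layer semantics assumed here; BottleneckAdapter and AdapterLayerSketch are illustrative names only, not the classes behind self.attention_adapters / self.output_adapters.

    # Hypothetical sketch of the adapter-layer call signature assumed in the diff:
    #   adapter_layer(sublayer_output, residual_input, layer_norm)
    import torch
    import torch.nn as nn


    class BottleneckAdapter(nn.Module):
        # Standard bottleneck adapter: down-project, non-linearity, up-project,
        # with an internal skip connection around the bottleneck.
        def __init__(self, hidden_size: int, bottleneck_size: int):
            super().__init__()
            self.down = nn.Linear(hidden_size, bottleneck_size)
            self.up = nn.Linear(bottleneck_size, hidden_size)
            self.act = nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.up(self.act(self.down(x)))


    class AdapterLayerSketch(nn.Module):
        # Applies the adapter to the sub-layer output, optionally normalizes,
        # then adds the residual, replacing the explicit
        # `attn_output + hidden_states` from the pre-fix code.
        def __init__(self, hidden_size: int, bottleneck_size: int = 64):
            super().__init__()
            self.adapter = BottleneckAdapter(hidden_size, bottleneck_size)

        def forward(self, hidden_states, residual_input, layer_norm=None):
            hidden_states = self.adapter(hidden_states)
            if layer_norm is not None:
                hidden_states = layer_norm(hidden_states)
            return hidden_states + residual_input

With that reading, the fixed block computes hidden_states = attention_adapters(attn_output, hidden_states, None), keeps that sum as residual, feeds it through post_attention_layernorm and the MLP, and finally applies output_adapters(mlp_output, residual, None).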
