Skip to content

Commit 13b4185

Browse files
committed
FIX
1 parent 0182598 commit 13b4185

File tree

1 file changed

+5
-5
lines changed

modules/commons/common_layers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,16 @@ def forward(self, x, key_padding_mask=None):
275275
# Project inputs to Q, K, V
276276
Q, K, V = torch.split(self.in_proj(x), self.embed_dim, dim=-1)
277277

278-
# Query-Key Normalization
279-
if self.use_qk_norm:
280-
Q = self.q_norm(Q)
281-
K = self.k_norm(K)
282-
283278
# Reshape Q, K, V for multi-head attention
284279
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
285280
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
286281
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
287282

283+
# Query-Key Normalization
284+
if self.use_qk_norm:
285+
Q = self.q_norm(Q)
286+
K = self.k_norm(K)
287+
288288
# Apply RoPE
289289
if self.rotary_embed is not None:
290290
Q = self.rotary_embed.rotate_queries_or_keys(Q)

0 commit comments

Comments
 (0)