Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a80273f

Browse files
T2T TeamRyan Sepassi
authored andcommitted
Added checks for making sure the key and value depths are divisible by the number of attention heads when doing multihead attention.
PiperOrigin-RevId: 162025615
1 parent 1b1d7ed commit a80273f

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tensor2tensor/models/common_attention.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,18 @@ def multihead_attention(query_antecedent,
374374
375375
Returns:
376376
A Tensor.
377+
378+
Raises:
379+
ValueError: if the key depth or value depth are not divisible by the
380+
number of attention heads.
377381
"""
382+
if total_key_depth % num_heads != 0:
383+
raise ValueError("Key depth (%d) must be divisible by the number of "
384+
"attention heads (%d)." % (total_key_depth, num_heads))
385+
if total_value_depth % num_heads != 0:
386+
raise ValueError("Value depth (%d) must be divisible by the number of "
387+
"attention heads (%d)." % (total_value_depth, num_heads))
388+
378389
with tf.variable_scope(
379390
name,
380391
default_name="multihead_attention",

0 commit comments

Comments
 (0)