Skip to content

Commit 3d6b521

Browse files
The reshape call has a numerical bug where the target reshape size is always higher than array size if num_devices is not a nonnegative power of two. To fix this, this commit adds padding
1 parent 0368421 commit 3d6b521

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

tests/ann_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,23 @@ def test_pmap(self, qy_shape, db_shape, dtype, k, recall):
125125
db_size = db.shape[0]
126126
gt_scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
127127
_, gt_args = lax.top_k(-gt_scores, k) # negate the score to get min-k
128-
db_per_device = db_size//num_devices
128+
db_per_device = math.ceil(db_size/num_devices)
129+
130+
# The reshape call has a numerical bug where the target reshape size is
131+
# always higher than array size if num_devices is not a nonnegative
132+
# power of two. To fix this, we can do padding.
133+
db_dim = db.shape[1]
134+
target_size = db_per_device * num_devices
135+
if db_size < target_size:
136+
pad_len = target_size - db_size
137+
pad_values = np.ones((pad_len, db_dim), dtype=db.dtype) * np.inf
138+
139+
# Pad with inf because we are running min-k and we do not want to
140+
# affect the result. Use concatenate so we have more control over
141+
# padding values and for readability. This also will avoid surprises
142+
# with pad modes.
143+
db = np.concatenate([db, pad_values], axis=0)
144+
129145
sharded_db = db.reshape(num_devices, db_per_device, 128)
130146
db_offsets = np.arange(num_devices, dtype=np.int32) * db_per_device
131147
def parallel_topk(qy, db, db_offset):

0 commit comments

Comments
 (0)