Skip to content

Commit d1d26ac

Browse files
Joey Yang and meta-codesync[bot]
authored and committed
Map hash_zch_identities to corresponding unique indices in TBE (#5077)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2082

Pull Request resolved: #5077

This change selects the `hash_zch_identities` that correspond to the unique indices during TBE prefetch. This is specifically required for MPZCH tables, which need both the slot index and the corresponding identities for correct lookup behavior. Without the identities, the inference side cannot correctly verify whether it is using the correct slot, leading to potential lookup errors.

Reviewed By: chouxi

Differential Revision: D85999577

fbshipit-source-id: 3c8a4add1dd112e9a746b334e7046bb442ea977b
1 parent ecf2ac9 commit d1d26ac

File tree

2 files changed

+219
-28
lines changed

2 files changed

+219
-28
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 85 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,13 @@ class RESParams:
200200
) # table sizes for the global rows the TBE holds
201201

202202

203+
@dataclass(frozen=True)
class PrefetchedInfo:
    """Immutable record of the tensors captured for one prefetch call.

    Instances are built by ``_get_prefetched_info``; the field comments
    below describe the values that helper stores.
    """

    # Deduplicated linear cache indices, narrowed to the valid prefix.
    linear_unique_indices: torch.Tensor
    # Tensor whose element ``[0]`` is the number of unique indices.
    linear_unique_indices_length: torch.Tensor
    # One identity row per unique index (moved to CPU by the helper), or
    # ``None`` when no identities were supplied for the prefetch.
    hash_zch_identities: Optional[torch.Tensor]
208+
209+
203210
def construct_split_state(
204211
embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]],
205212
rowwise: bool,
@@ -2100,6 +2107,12 @@ def forward( # noqa: C901
21002107
requires this information for allocating the weight gradient
21012108
tensor in the backward pass.
21022109
2110+
hash_zch_identities (Optional[Tensor]): The original raw IDs before
2111+
remapping to ZCH (Zero-Collision Hash) table slots. This tensor is
2112+
populated when using Multi-Probe Zero Collision Hash (MPZCH) modules
2113+
and is required for Raw Embedding Streaming (RES) to maintain
2114+
consistency between training and inference.
2115+
21032116
Returns:
21042117
A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` =
21052118
batch size and `total_D` = the sum of all embedding dimensions in the
@@ -2217,7 +2230,6 @@ def forward( # noqa: C901
22172230
# In forward, we don't enable multi-pass prefetch as we want the process
22182231
# to be as fast as possible and memory usage doesn't matter (will be recycled
22192232
# by dense fwd/bwd)
2220-
# TODO: Properly pass in the hash_zch_identities
22212233
self._prefetch(
22222234
indices,
22232235
offsets,
@@ -4140,6 +4152,60 @@ def raw_embedding_stream(self) -> None:
41404152
False, # blocking_tensor_copy
41414153
)
41424154

4155+
@staticmethod
@torch.jit.ignore
def _get_prefetched_info(
    linear_cache_indices_merged: torch.Tensor,
    total_cache_hash_size: int,
    hash_zch_identities: Optional[torch.Tensor],
) -> PrefetchedInfo:
    """Deduplicate prefetched cache indices and pick one identity per index.

    Args:
        linear_cache_indices_merged: Linear cache indices of the prefetch;
            may contain duplicates.
        total_cache_hash_size: Forwarded unchanged to
            ``torch.ops.fbgemm.get_unique_indices_with_inverse``.
        hash_zch_identities: Optional tensor selected along dim 0 by
            position in ``linear_cache_indices_merged`` (row ``i`` is the
            identity for element ``i``). ``None`` skips identity mapping.

    Returns:
        A ``PrefetchedInfo`` holding the unique indices narrowed to their
        valid length, the length tensor as returned by the op, and either
        the selected identities copied to CPU or ``None``.
    """
    # Counts and inverse indices are only needed to map identities, so
    # they are requested from the op only when identities were supplied.
    want_identities = hash_zch_identities is not None
    (
        linear_unique_indices,
        linear_unique_indices_length,
        linear_unique_indices_cnt,
        linear_unique_inverse_indices,
    ) = torch.ops.fbgemm.get_unique_indices_with_inverse(
        linear_cache_indices_merged,
        total_cache_hash_size,
        compute_count=want_identities,
        compute_inverse_indices=want_identities,
    )

    # Read the unique count to the host exactly once. Passing the
    # one-element tensor to every ``narrow`` call would convert it to a
    # Python int (and synchronize with the device) each time.
    num_unique = int(linear_unique_indices_length[0].item())

    # linear_unique_indices is the result after deduplication and sorting;
    # only the first ``num_unique`` entries are valid.
    linear_unique_indices = linear_unique_indices.narrow(0, 0, num_unique)

    if hash_zch_identities is None:
        return PrefetchedInfo(
            linear_unique_indices,
            linear_unique_indices_length,
            None,
        )

    # The cumulative sum of the per-unique counts gives, for each unique
    # index, the offset at which its group of occurrences starts.
    group_starts = torch.ops.fbgemm.asynchronous_complete_cumsum(
        linear_unique_indices_cnt
    ).narrow(0, 0, num_unique)

    # Take the entry at each group start: one input position per unique
    # index. NOTE(review): this presumes the op's inverse indices list
    # input positions grouped by unique value -- confirm against the op.
    representative_positions = linear_unique_inverse_indices.index_select(
        dim=0, index=group_starts
    )

    # Gather the identity row for each unique index and move it to CPU.
    hash_zch_identities_cpu = hash_zch_identities.index_select(
        dim=0, index=representative_positions
    ).to(device=torch.device("cpu"))

    return PrefetchedInfo(
        linear_unique_indices,
        linear_unique_indices_length,
        hash_zch_identities_cpu,
    )
4208+
41434209
@torch.jit.ignore
41444210
def _store_prefetched_tensors(
41454211
self,
@@ -4150,35 +4216,26 @@ def _store_prefetched_tensors(
41504216
NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional.
41514217
This function stores the prefetched tensors for the raw embedding streaming.
41524218
"""
4153-
if self.enable_raw_embedding_streaming:
4154-
with record_function(
4155-
"## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
4156-
):
4219+
if not self.enable_raw_embedding_streaming:
4220+
return
4221+
4222+
with record_function(
4223+
"## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
4224+
):
4225+
# Process hash_zch_identities using helper function
4226+
prefetched_info = self._get_prefetched_info(
4227+
linear_cache_indices_merged,
4228+
self.total_cache_hash_size,
4229+
hash_zch_identities,
4230+
)
4231+
4232+
self.prefetched_info.append(
41574233
(
4158-
linear_unique_indices,
4159-
linear_unique_indices_length,
4160-
_,
4161-
) = torch.ops.fbgemm.get_unique_indices(
4162-
linear_cache_indices_merged,
4163-
self.total_cache_hash_size,
4164-
compute_count=False,
4165-
)
4166-
linear_unique_indices = linear_unique_indices.narrow(
4167-
0, 0, linear_unique_indices_length[0]
4168-
)
4169-
self.prefetched_info.append(
4170-
(
4171-
linear_unique_indices,
4172-
linear_unique_indices_length,
4173-
(
4174-
hash_zch_identities.index_select(
4175-
dim=0, index=linear_unique_indices
4176-
).to(device=torch.device("cpu"))
4177-
if hash_zch_identities is not None
4178-
else None
4179-
),
4180-
)
4234+
prefetched_info.linear_unique_indices,
4235+
prefetched_info.linear_unique_indices_length,
4236+
prefetched_info.hash_zch_identities,
41814237
)
4238+
)
41824239

41834240
@torch.jit.ignore
41844241
def __report_input_params_factory(
Lines changed: 134 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,134 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
12+
import torch
13+
14+
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
15+
SplitTableBatchedEmbeddingBagsCodegen,
16+
)
17+
18+
from ..common import open_source
19+
20+
if open_source:
21+
# pyre-ignore[21]
22+
from test_utils import gpu_unavailable
23+
else:
24+
from fbgemm_gpu.test.test_utils import gpu_unavailable
25+
26+
27+
class StorePrefetchedTensorsTest(unittest.TestCase):
    """GPU tests for ``SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info``."""

    def _check_prefetched_info(
        self,
        cache_indices: list[int],
        identities: list[list[int]],
        expected_unique_indices: list[int],
        expected_identities: list[list[int]],
    ) -> None:
        """Run ``_get_prefetched_info`` on CUDA inputs and verify its output.

        Args:
            cache_indices: Linear cache indices, possibly with duplicates.
            identities: One identity row per entry of ``cache_indices``.
            expected_unique_indices: Expected deduplicated, sorted indices.
            expected_identities: Expected identity row for each entry of
                ``expected_unique_indices``.
        """
        device = torch.cuda.current_device()
        prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info(
            torch.tensor(cache_indices, device=device, dtype=torch.int64),
            100,  # total_cache_hash_size
            torch.tensor(identities, device=device, dtype=torch.int64),
        )

        expected_count = len(expected_unique_indices)
        self.assertEqual(
            expected_unique_indices,
            prefetched_info.linear_unique_indices.tolist(),
        )
        self.assertEqual(
            expected_count,
            prefetched_info.linear_unique_indices_length[0].item(),
        )

        selected_identities = prefetched_info.hash_zch_identities
        # Plain assert (rather than assertIsNotNone) so the type checker
        # narrows the Optional for the attribute accesses below.
        assert selected_identities is not None
        # The helper copies the selected identities to CPU.
        self.assertEqual("cpu", selected_identities.device.type)
        self.assertEqual(expected_count, selected_identities.shape[0])
        self.assertEqual(expected_identities, selected_identities.tolist())

    @unittest.skipIf(*gpu_unavailable)
    def test_get_prefetched_info(self) -> None:
        """Distinct indices come back sorted with their identities reordered to match."""
        self._check_prefetched_info(
            cache_indices=[54, 27, 43, 90],
            identities=[
                [3350213393928437575],  # for index 54
                [6548733451892409412],  # for index 27
                [4126118985661274454],  # for index 43
                [2565973416302224539],  # for index 90
            ],
            expected_unique_indices=[27, 43, 54, 90],
            expected_identities=[
                [6548733451892409412],  # for index 27
                [4126118985661274454],  # for index 43
                [3350213393928437575],  # for index 54
                [2565973416302224539],  # for index 90
            ],
        )

    @unittest.skipIf(*gpu_unavailable)
    def test_get_prefetched_info_with_duplicate_hash_zch_identities(self) -> None:
        """
        Test that duplicate cache indices are correctly deduplicated.
        When the same cache index appears multiple times with the same identity,
        exactly one (index, identity) pair should be kept in the output.
        """
        self._check_prefetched_info(
            # Duplicates: 54 appears 3x, 27 appears 2x
            cache_indices=[54, 27, 54, 43, 27, 54, 90],
            identities=[
                [3350213393928437575],  # for index 54 (first occurrence)
                [6548733451892409412],  # for index 27
                [3350213393928437575],  # for index 54 (duplicate - same identity)
                [4126118985661274454],  # for index 43
                [6548733451892409412],  # for index 27 (duplicate - same identity)
                [3350213393928437575],  # for index 54 (duplicate - same identity)
                [2565973416302224539],  # for index 90
            ],
            expected_unique_indices=[27, 43, 54, 90],
            expected_identities=[
                [6548733451892409412],  # for index 27
                [4126118985661274454],  # for index 43
                [3350213393928437575],  # for index 54
                [2565973416302224539],  # for index 90
            ],
        )

    @unittest.skipIf(*gpu_unavailable)
    def test_get_prefetched_info_without_hash_zch_identities(self) -> None:
        """Without identities, indices are still deduplicated and no identities are returned."""
        prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info(
            torch.tensor(
                [54, 27, 54, 43],
                device=torch.cuda.current_device(),
                dtype=torch.int64,
            ),
            100,  # total_cache_hash_size
            None,
        )

        self.assertEqual(
            [27, 43, 54],
            prefetched_info.linear_unique_indices.tolist(),
        )
        self.assertEqual(
            3,
            prefetched_info.linear_unique_indices_length[0].item(),
        )
        self.assertIsNone(prefetched_info.hash_zch_identities)
131+
132+
133+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)