Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
are not enabled by default.
- `use_rowwise_bias_correction` is used in Adam to enable rowwise
bias correction computation

embedding_table_index_type (torch.dtype = torch.int64): The data type of
the embedding table index tensor. Options are `torch.int32` and
`torch.int64`

embedding_table_offset_type (torch.dtype = torch.int64): The data type of
the embedding table offset tensor. Options are `torch.int32` and
`torch.int64`
"""

embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
Expand Down Expand Up @@ -654,6 +662,8 @@ def __init__( # noqa C901
uvm_host_mapped: bool = False,
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
embedding_table_index_type: torch.dtype = torch.int64,
embedding_table_offset_type: torch.dtype = torch.int64,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand Down Expand Up @@ -1343,6 +1353,17 @@ def __init__( # noqa C901
FeatureGateName.BOUNDS_CHECK_INDICES_V2
)

if embedding_table_index_type not in [torch.int32, torch.int64]:
raise ValueError(
f"embedding_table_index_type must be torch.int32 or torch.int64, but got {embedding_table_index_type}"
)
self.embedding_table_index_type: torch.dtype = embedding_table_index_type
if embedding_table_offset_type not in [torch.int32, torch.int64]:
raise ValueError(
f"embedding_table_offset_type must be torch.int32 or torch.int64, but got {embedding_table_offset_type}"
)
self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

@torch.jit.ignore
def log(self, msg: str) -> None:
"""
Expand Down Expand Up @@ -3409,6 +3430,15 @@ def prepare_inputs(
# NOTE: Force offsets to have the same dtype as indices since the
# kernels assume same dtype. We might need to revisit the assumption
# of same dtypes in the future.
if self.embedding_table_index_type == torch.int32:
self.log(
"Casting indices to int32 based on embedding_table_index_type input."
)
indices = indices.to(torch.int32)
if self.embedding_table_index_type != self.embedding_table_offset_type:
self.log(
f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
)
offsets = offsets.to(dtype=indices.dtype)

# Force casting per_sample_weights to float
Expand Down
Loading