diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 65a21c99c5..70693def2c 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -586,6 +586,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             are not enabled by default.
             - `use_rowwise_bias_correction` is used in Adam to enable rowwise
               bias correction computation
+
+        embedding_table_index_type (torch.dtype = torch.int64): The data type of
+            the embedding table index tensor. Options are `torch.int32` and
+            `torch.int64`
+
+        embedding_table_offset_type (torch.dtype = torch.int64): The data type of
+            the embedding table offset tensor. Options are `torch.int32` and
+            `torch.int64`
     """
 
     embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
@@ -654,6 +662,8 @@ def __init__(  # noqa C901
         uvm_host_mapped: bool = False,
         extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
         tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
+        embedding_table_index_type: torch.dtype = torch.int64,
+        embedding_table_offset_type: torch.dtype = torch.int64,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
 
@@ -1343,6 +1353,17 @@ def __init__(  # noqa C901
                 FeatureGateName.BOUNDS_CHECK_INDICES_V2
             )
 
+        if embedding_table_index_type not in [torch.int32, torch.int64]:
+            raise ValueError(
+                f"embedding_table_index_type must be torch.int32 or torch.int64, but got {embedding_table_index_type}"
+            )
+        self.embedding_table_index_type: torch.dtype = embedding_table_index_type
+        if embedding_table_offset_type not in [torch.int32, torch.int64]:
+            raise ValueError(
+                f"embedding_table_offset_type must be torch.int32 or torch.int64, but got {embedding_table_offset_type}"
+            )
+        self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -3409,6 +3430,15 @@ def prepare_inputs(
         # NOTE: Force offsets to have the same dtype as indices since the
         # kernels assume same dtype. We might need to revisit the assumption
         # of same dtypes in the future.
+        if self.embedding_table_index_type == torch.int32:
+            self.log(
+                "Casting indices to int32 based on embedding_table_index_type input."
+            )
+            indices = indices.to(torch.int32)
+        if self.embedding_table_index_type != self.embedding_table_offset_type:
+            self.log(
+                f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
+            )
         offsets = offsets.to(dtype=indices.dtype)
 
         # Force casting per_sample_weights to float
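
For reviewers, a minimal usage sketch of the new constructor arguments. Only `embedding_table_index_type` and `embedding_table_offset_type` come from this diff; the embedding specs, table sizes, CPU placement, and batch layout below are hypothetical placeholders chosen so the snippet runs without a GPU.

```python
import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Hypothetical specs: (num_embeddings, embedding_dim, location, compute_device).
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (1000, 64, EmbeddingLocation.HOST, ComputeDevice.CPU),
        (2000, 128, EmbeddingLocation.HOST, ComputeDevice.CPU),
    ],
    # New in this diff: request int32 indices and offsets. Any dtype other
    # than torch.int32/torch.int64 raises ValueError in __init__.
    embedding_table_index_type=torch.int32,
    embedding_table_offset_type=torch.int32,
)

# Callers may still pass int64 tensors: prepare_inputs() logs and casts
# indices to int32, then forces offsets to match the indices dtype.
indices = torch.tensor([1, 5, 7, 3], dtype=torch.int64)
offsets = torch.tensor([0, 2, 3, 4, 4], dtype=torch.int64)  # B * T + 1 entries
output = tbe(indices=indices, offsets=offsets)
```

Both arguments default to `torch.int64`, so existing callers that never pass them see no behavior change.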