From 50febbe2424f39c29b9b30691472c25ac6cd9867 Mon Sep 17 00:00:00 2001
From: Basil Wong
Date: Wed, 29 Jan 2025 08:55:26 -0800
Subject: [PATCH] Updating split_table_batched_embeddings_ops_training.py
 (#3613)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/690

After this diff stack:

EmbeddingKernelConfig now supports adding embedding_table_int32_index_type and
embedding_table_int32_offset_type to the fused_params. These are used
downstream to determine the indices and offsets dtypes for
split_table_batched_embeddings_ops_training.py.

Reviewed By: q10

Differential Revision: D68005929
---
 ...t_table_batched_embeddings_ops_training.py | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 65a21c99c5..70693def2c 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -586,6 +586,14 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             are not enabled by default.
             - `use_rowwise_bias_correction` is used in Adam to enable rowwise
               bias correction computation
+
+        embedding_table_index_type (torch.dtype = torch.int64): The data type of
+            the embedding table index tensor. Options are `torch.int32` and
+            `torch.int64`
+
+        embedding_table_offset_type (torch.dtype = torch.int64): The data type of
+            the embedding table offset tensor. Options are `torch.int32` and
+            `torch.int64`
     """

     embedding_specs: List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]
@@ -654,6 +662,8 @@ def __init__(  # noqa C901
         uvm_host_mapped: bool = False,
         extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
         tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
+        embedding_table_index_type: torch.dtype = torch.int64,
+        embedding_table_offset_type: torch.dtype = torch.int64,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

@@ -1343,6 +1353,17 @@ def __init__(  # noqa C901
             FeatureGateName.BOUNDS_CHECK_INDICES_V2
         )

+        if embedding_table_index_type not in [torch.int32, torch.int64]:
+            raise ValueError(
+                f"embedding_table_index_type must be torch.int32 or torch.int64, but got {embedding_table_index_type}"
+            )
+        self.embedding_table_index_type: torch.dtype = embedding_table_index_type
+        if embedding_table_offset_type not in [torch.int32, torch.int64]:
+            raise ValueError(
+                f"embedding_table_offset_type must be torch.int32 or torch.int64, but got {embedding_table_offset_type}"
+            )
+        self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -3409,6 +3430,15 @@ def prepare_inputs(
         # NOTE: Force offsets to have the same dtype as indices since the
         # kernels assume same dtype. We might need to revisit the assumption
         # of same dtypes in the future.
+        if self.embedding_table_index_type == torch.int32:
+            self.log(
+                "Casting indices to int32 based on embedding_table_index_type input."
+            )
+            indices = indices.to(torch.int32)
+        if self.embedding_table_index_type != self.embedding_table_offset_type:
+            self.log(
+                f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
+            )
         offsets = offsets.to(dtype=indices.dtype)

         # Force casting per_sample_weights to float
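
Usage sketch (not part of the patch): a minimal, hypothetical example of opting
into int32 indices and offsets through the new constructor arguments. The import
paths, the optimizer choice, the example table sizes, and the availability of a
CUDA device are assumptions about the surrounding fbgemm_gpu API, not something
this diff establishes.

# Hypothetical usage sketch; import paths and constructor defaults are assumed.
import torch

from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (1000, 64, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    # New arguments introduced by this diff: request int32 indices/offsets so
    # that prepare_inputs() down-casts int64 inputs before the lookup kernels run.
    embedding_table_index_type=torch.int32,
    embedding_table_offset_type=torch.int32,
)

# int64 inputs are still accepted; per the prepare_inputs() change above,
# indices are cast to int32 and offsets are forced to match the indices dtype.
indices = torch.tensor([1, 5, 7, 42], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 4], dtype=torch.int64, device="cuda")
pooled = tbe(indices=indices, offsets=offsets)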