@@ -688,241 +688,3 @@ def test_dynamically_switch_inference_training_mode(self) -> None:
688688 self .assertTrue (m ._is_inference )
689689 self .assertTrue (m ._eviction_policy_name is None )
690690 self .assertTrue (m ._eviction_module is None )
691-
692- # Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
693- @unittest .skipIf (
694- torch .cuda .device_count () < 1 ,
695- "Not enough GPUs, this test requires at least two GPUs" ,
696- )
697- def test_zch_hash_disable_fallback (self ) -> None :
698- m = HashZchManagedCollisionModule (
699- zch_size = 30 ,
700- device = torch .device ("cuda" ),
701- total_num_buckets = 2 ,
702- eviction_policy_name = HashZchEvictionPolicyName .SINGLE_TTL_EVICTION ,
703- eviction_config = HashZchEvictionConfig (
704- features = [],
705- single_ttl = 10 ,
706- ),
707- max_probe = 4 ,
708- disable_fallback = True ,
709- start_bucket = 1 ,
710- output_segments = [0 , 10 , 20 ],
711- )
712- jt = JaggedTensor (
713- values = torch .arange (0 , 4 , dtype = torch .int64 , device = "cuda" ),
714- lengths = torch .tensor ([1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
715- )
716- # Run once to insert ids
717- output0 = m .remap ({"test" : jt })
718- self .assertTrue (
719- torch .equal (
720- output0 ["test" ].values (),
721- torch .tensor ([8 , 15 , 11 ], dtype = torch .int64 , device = "cuda:0" ),
722- )
723- )
724- self .assertTrue (
725- torch .equal (
726- output0 ["test" ].lengths (),
727- torch .tensor ([1 , 1 , 0 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
728- )
729- )
730- m .reset_inference_mode ()
731- jt = JaggedTensor (
732- values = torch .tensor ([9 , 0 , 1 , 4 , 6 , 8 ], dtype = torch .int64 , device = "cuda" ),
733- lengths = torch .tensor ([1 , 1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
734- )
735- # Run again in inference mode and only values 0 and 1 exist.
736- output1 = m .remap ({"test" : jt })
737- self .assertTrue (
738- torch .equal (
739- output1 ["test" ].values (),
740- torch .tensor ([8 , 15 ], dtype = torch .int64 , device = "cuda:0" ),
741- )
742- )
743- self .assertTrue (
744- torch .equal (
745- output1 ["test" ].lengths (),
746- torch .tensor ([0 , 1 , 1 , 0 , 0 , 0 ], dtype = torch .int64 , device = "cuda:0" ),
747- )
748- )
749-
750- m = HashZchManagedCollisionModule (
751- zch_size = 10 ,
752- device = torch .device ("cuda" ),
753- total_num_buckets = 2 ,
754- eviction_policy_name = HashZchEvictionPolicyName .SINGLE_TTL_EVICTION ,
755- eviction_config = HashZchEvictionConfig (
756- features = [],
757- single_ttl = 10 ,
758- ),
759- max_probe = 4 ,
760- start_bucket = 0 ,
761- output_segments = None ,
762- disable_fallback = True ,
763- )
764- jt = JaggedTensor (
765- values = torch .arange (0 , 4 , dtype = torch .int64 , device = "cuda" ),
766- lengths = torch .tensor ([1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
767- )
768- # Run once to insert ids
769- output0 = m .remap ({"test" : jt })
770- self .assertTrue (
771- torch .equal (
772- output0 ["test" ].values (),
773- torch .tensor ([3 , 5 , 4 , 6 ], dtype = torch .int64 , device = "cuda:0" ),
774- )
775- )
776- self .assertTrue (
777- torch .equal (
778- output0 ["test" ].lengths (),
779- torch .tensor ([1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
780- )
781- )
782- m .reset_inference_mode ()
783- jt = JaggedTensor (
784- values = torch .tensor ([9 , 0 , 1 , 4 , 6 , 8 ], dtype = torch .int64 , device = "cuda" ),
785- lengths = torch .tensor ([1 , 1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda" ),
786- )
787- # Run again in inference mode and only values 0 and 1 exist.
788- output1 = m .remap ({"test" : jt })
789- self .assertTrue (
790- torch .equal (
791- output1 ["test" ].values (),
792- torch .tensor ([3 , 5 ], dtype = torch .int64 , device = "cuda:0" ),
793- )
794- )
795- self .assertTrue (
796- torch .equal (
797- output1 ["test" ].lengths (),
798- torch .tensor ([0 , 1 , 1 , 0 , 0 , 0 ], dtype = torch .int64 , device = "cuda:0" ),
799- )
800- )
801-
802- # Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
803- @unittest .skipIf (
804- torch .cuda .device_count () < 1 ,
805- "Not enough GPUs, this test requires at least two GPUs" ,
806- )
807- def test_zch_hash_zero_rows (self ) -> None :
808- # When disabling fallback, for missed ids we should return zero rows in output embeddings.
809- mc_emb_configs = [
810- EmbeddingBagConfig (
811- num_embeddings = 10 ,
812- embedding_dim = 3 ,
813- name = "table_0" ,
814- data_type = DataType .FP32 ,
815- feature_names = ["table_0" ],
816- pooling = PoolingType .SUM ,
817- weight_init_max = None ,
818- weight_init_min = None ,
819- init_fn = None ,
820- use_virtual_table = False ,
821- virtual_table_eviction_policy = None ,
822- total_num_buckets = 1 ,
823- )
824- ]
825- mc_modules : Dict [str , ManagedCollisionModule ] = {
826- "table_0" : HashZchManagedCollisionModule (
827- zch_size = 10 ,
828- device = torch .device ("cuda" ),
829- max_probe = 512 ,
830- tb_logging_frequency = 100 ,
831- name = "table_0" ,
832- total_num_buckets = 1 ,
833- eviction_config = None ,
834- eviction_policy_name = None ,
835- opt_in_prob = - 1 ,
836- percent_reserved_slots = 0 ,
837- disable_fallback = True ,
838- )
839- }
840- mcebc = ManagedCollisionEmbeddingBagCollection (
841- EmbeddingBagCollection (
842- device = torch .device ("cuda" ),
843- tables = mc_emb_configs ,
844- is_weighted = False ,
845- ),
846- ManagedCollisionCollection (
847- managed_collision_modules = mc_modules ,
848- embedding_configs = mc_emb_configs ,
849- ),
850- return_remapped_features = True ,
851- )
852- lengths = torch .tensor (
853- [1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = torch .device ("cuda" )
854- )
855- values = torch .tensor (
856- [3 , 4 , 5 , 6 , 8 ],
857- dtype = torch .int64 ,
858- device = torch .device ("cuda" ),
859- )
860- features = KeyedJaggedTensor (
861- keys = ["table_0" ],
862- values = values ,
863- lengths = lengths ,
864- )
865- # Run once to insert ids
866- res = mcebc .forward (features )
867- # Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
868- mask = torch .abs (res [0 ]["table_0" ]) == 0
869- # For each row, check if all elements are True (i.e., close to zero)
870- row_mask = mask .all (dim = 1 )
871- # Get indices of zero rows
872- self .assertEqual (torch .nonzero (row_mask , as_tuple = False ).squeeze ().numel (), 0 )
873- self .assertIsNotNone (res [1 ])
874- self .assertTrue (
875- torch .equal (
876- # Pyre-ignore [16]: Optional type has no attribute `__getitem__`.
877- res [1 ]["table_0" ].values (),
878- torch .tensor ([1 , 2 , 8 , 9 , 3 ], dtype = torch .int64 , device = "cuda:0" ),
879- )
880- )
881- self .assertTrue (
882- torch .equal (
883- res [1 ]["table_0" ].lengths (),
884- torch .tensor ([1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
885- )
886- )
887- # Pyre-ignore [29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function
888- mcebc ._managed_collision_collection ._managed_collision_modules [
889- "table_0"
890- ].reset_inference_mode ()
891- lengths = torch .tensor (
892- [1 , 1 , 1 , 1 , 1 , 1 ], dtype = torch .int64 , device = torch .device ("cuda" )
893- )
894- values = torch .tensor (
895- [0 , 4 , 5 , 1 , 2 , 8 ],
896- dtype = torch .int64 ,
897- device = torch .device ("cuda" ),
898- )
899- features = KeyedJaggedTensor (
900- keys = ["table_0" ],
901- values = values ,
902- lengths = lengths ,
903- )
904- # Run once to insert ids.
905- res = mcebc .forward (features )
906- self .assertTrue (
907- torch .equal (
908- res [1 ]["table_0" ].values (),
909- torch .tensor ([2 , 8 , 3 ], dtype = torch .int64 , device = "cuda:0" ),
910- )
911- )
912- self .assertTrue (
913- torch .equal (
914- res [1 ]["table_0" ].lengths (),
915- torch .tensor ([0 , 1 , 1 , 0 , 0 , 1 ], dtype = torch .int64 , device = "cuda:0" ),
916- )
917- )
918- # Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
919- mask = torch .abs (res [0 ]["table_0" ]) == 0
920- # For each row, check if all elements are True (i.e., close to zero)
921- row_mask = mask .all (dim = 1 )
922- # Get indices of zero rows
923- self .assertTrue (
924- torch .equal (
925- torch .tensor ([0 , 3 , 4 ], device = "cuda:0" ),
926- torch .nonzero (row_mask , as_tuple = False ).squeeze (),
927- )
928- )
0 commit comments