@@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator};
8
8
use core:: mem;
9
9
use core:: ops:: { BitAnd , BitOr , BitXor , Sub } ;
10
10
11
- use super :: map:: { self , ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner , HashMap , Keys } ;
11
+ use super :: map:: {
12
+ self , make_hash, make_insert_hash, ConsumeAllOnDrop , DefaultHashBuilder , DrainFilterInner ,
13
+ HashMap , Keys , RawEntryMut ,
14
+ } ;
12
15
use crate :: raw:: { Allocator , Global } ;
13
16
14
17
// Future Optimization (FIXME!)
@@ -953,6 +956,12 @@ where
953
956
/// Inserts a value computed from `f` into the set if the given `value` is
954
957
/// not present, then returns a reference to the value in the set.
955
958
///
959
+ /// # Panics
960
+ ///
961
+ /// Panics if the value from the function and the provided lookup value
962
+ /// are not equivalent or have different hashes. See [`Equivalent`]
963
+ /// and [`Hash`] for more information.
964
+ ///
956
965
/// # Examples
957
966
///
958
967
/// ```
@@ -967,20 +976,40 @@ where
967
976
/// assert_eq!(value, pet);
968
977
/// }
969
978
/// assert_eq!(set.len(), 4); // a new "fish" was inserted
979
+ /// assert!(set.contains("fish"));
970
980
/// ```
971
981
#[ cfg_attr( feature = "inline-more" , inline) ]
972
982
pub fn get_or_insert_with < Q : ?Sized , F > ( & mut self , value : & Q , f : F ) -> & T
973
983
where
974
984
Q : Hash + Equivalent < T > ,
975
985
F : FnOnce ( & Q ) -> T ,
976
986
{
987
+ #[ cold]
988
+ #[ inline( never) ]
989
+ fn assert_failed ( ) {
990
+ panic ! (
991
+ "the value from the function and the lookup value \
992
+ must be equivalent and have the same hash"
993
+ ) ;
994
+ }
995
+
977
996
// Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
978
997
// `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
979
- self . map
980
- . raw_entry_mut ( )
981
- . from_key ( value)
982
- . or_insert_with ( || ( f ( value) , ( ) ) )
983
- . 0
998
+ let hash = make_hash :: < Q , S > ( & self . map . hash_builder , value) ;
999
+ let raw_entry_builder = self . map . raw_entry_mut ( ) ;
1000
+ match raw_entry_builder. from_key_hashed_nocheck ( hash, value) {
1001
+ RawEntryMut :: Occupied ( entry) => entry. into_key ( ) ,
1002
+ RawEntryMut :: Vacant ( entry) => {
1003
+ let insert_value = f ( value) ;
1004
+ let insert_value_hash = make_insert_hash :: < T , S > ( entry. hasher ( ) , & insert_value) ;
1005
+ if !( hash == insert_value_hash && value. equivalent ( & insert_value) ) {
1006
+ assert_failed ( ) ;
1007
+ }
1008
+ entry
1009
+ . insert_hashed_nocheck ( insert_value_hash, insert_value, ( ) )
1010
+ . 0
1011
+ }
1012
+ }
984
1013
}
985
1014
986
1015
/// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2429,7 +2458,7 @@ fn assert_covariance() {
2429
2458
#[ cfg( test) ]
2430
2459
mod test_set {
2431
2460
use super :: super :: map:: DefaultHashBuilder ;
2432
- use super :: HashSet ;
2461
+ use super :: { make_hash , Equivalent , HashSet } ;
2433
2462
use std:: vec:: Vec ;
2434
2463
2435
2464
#[ test]
@@ -2886,4 +2915,100 @@ mod test_set {
2886
2915
set. insert ( i) ;
2887
2916
}
2888
2917
}
2918
+
2919
+ #[ test]
2920
+ fn duplicate_insert ( ) {
2921
+ let mut set = HashSet :: new ( ) ;
2922
+ set. insert ( 1 ) ;
2923
+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2924
+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2925
+ assert ! ( [ 1 ] . iter( ) . eq( set. iter( ) ) ) ;
2926
+ }
2927
+
2928
+ #[ test]
2929
+ #[ should_panic]
2930
+ fn some_invalid_hash ( ) {
2931
+ use core:: hash:: { Hash , Hasher } ;
2932
+ struct Invalid {
2933
+ count : u32 ,
2934
+ }
2935
+
2936
+ struct InvalidRef {
2937
+ count : u32 ,
2938
+ }
2939
+
2940
+ impl PartialEq for Invalid {
2941
+ fn eq ( & self , other : & Self ) -> bool {
2942
+ self . count == other. count
2943
+ }
2944
+ }
2945
+ impl Eq for Invalid { }
2946
+
2947
+ impl Equivalent < Invalid > for InvalidRef {
2948
+ fn equivalent ( & self , key : & Invalid ) -> bool {
2949
+ self . count == key. count
2950
+ }
2951
+ }
2952
+ impl Hash for Invalid {
2953
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2954
+ self . count . hash ( state) ;
2955
+ }
2956
+ }
2957
+ impl Hash for InvalidRef {
2958
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2959
+ let double = self . count * 2 ;
2960
+ double. hash ( state) ;
2961
+ }
2962
+ }
2963
+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
2964
+ let key = InvalidRef { count : 1 } ;
2965
+ let value = Invalid { count : 1 } ;
2966
+ if key. equivalent ( & value) {
2967
+ set. get_or_insert_with ( & key, |_| value) ;
2968
+ }
2969
+ }
2970
+
2971
+ #[ test]
2972
+ #[ should_panic]
2973
+ fn some_invalid_equivalent ( ) {
2974
+ use core:: hash:: { Hash , Hasher } ;
2975
+ struct Invalid {
2976
+ count : u32 ,
2977
+ other : u32 ,
2978
+ }
2979
+
2980
+ struct InvalidRef {
2981
+ count : u32 ,
2982
+ other : u32 ,
2983
+ }
2984
+
2985
+ impl PartialEq for Invalid {
2986
+ fn eq ( & self , other : & Self ) -> bool {
2987
+ self . count == other. count && self . other == other. other
2988
+ }
2989
+ }
2990
+ impl Eq for Invalid { }
2991
+
2992
+ impl Equivalent < Invalid > for InvalidRef {
2993
+ fn equivalent ( & self , key : & Invalid ) -> bool {
2994
+ self . count == key. count && self . other == key. other
2995
+ }
2996
+ }
2997
+ impl Hash for Invalid {
2998
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
2999
+ self . count . hash ( state) ;
3000
+ }
3001
+ }
3002
+ impl Hash for InvalidRef {
3003
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
3004
+ self . count . hash ( state) ;
3005
+ }
3006
+ }
3007
+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
3008
+ let key = InvalidRef { count : 1 , other : 1 } ;
3009
+ let value = Invalid { count : 1 , other : 2 } ;
3010
+ if make_hash ( set. hasher ( ) , & key) == make_hash ( set. hasher ( ) , & value) {
3011
+ set. get_or_insert_with ( & key, |_| value) ;
3012
+ }
3013
+ }
2889
3014
}
0 commit comments