Skip to content

Commit 7f88093

Browse files
committed
Fix HashSet::get_or_insert_with
1 parent 5e4a982 commit 7f88093

File tree

2 files changed

+137
-7
lines changed

2 files changed

+137
-7
lines changed

src/map.rs

+5
Original file line numberDiff line numberDiff line change
@@ -4132,6 +4132,11 @@ impl<'a, K, V, S, A: Allocator + Clone> RawVacantEntryMut<'a, K, V, S, A> {
41324132
hash_builder: self.hash_builder,
41334133
}
41344134
}
4135+
4136+
#[inline]
4137+
pub(crate) fn hasher(&self) -> &S {
4138+
self.hash_builder
4139+
}
41354140
}
41364141

41374142
impl<K, V, S, A: Allocator + Clone> Debug for RawEntryBuilderMut<'_, K, V, S, A> {

src/set.rs

+132-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ use core::iter::{Chain, FromIterator, FusedIterator};
88
use core::mem;
99
use core::ops::{BitAnd, BitOr, BitXor, Sub};
1010

11-
use super::map::{self, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, HashMap, Keys};
11+
use super::map::{
12+
self, make_hash, make_insert_hash, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner,
13+
HashMap, Keys, RawEntryMut,
14+
};
1215
use crate::raw::{Allocator, Global};
1316

1417
// Future Optimization (FIXME!)
@@ -953,6 +956,12 @@ where
953956
/// Inserts a value computed from `f` into the set if the given `value` is
954957
/// not present, then returns a reference to the value in the set.
955958
///
959+
/// # Panics
960+
///
961+
/// Panics if the value from the function and the provided lookup value
962+
/// are not equivalent or have different hashes. See [`Equivalent`]
963+
/// and [`Hash`] for more information.
964+
///
956965
/// # Examples
957966
///
958967
/// ```
@@ -967,20 +976,40 @@ where
967976
/// assert_eq!(value, pet);
968977
/// }
969978
/// assert_eq!(set.len(), 4); // a new "fish" was inserted
979+
/// assert!(set.contains("fish"));
970980
/// ```
971981
#[cfg_attr(feature = "inline-more", inline)]
972982
pub fn get_or_insert_with<Q: ?Sized, F>(&mut self, value: &Q, f: F) -> &T
973983
where
974984
Q: Hash + Equivalent<T>,
975985
F: FnOnce(&Q) -> T,
976986
{
987+
#[cold]
988+
#[inline(never)]
989+
fn assert_failed() {
990+
panic!(
991+
"the value from the function and the lookup value \
992+
must be equivalent and have the same hash"
993+
);
994+
}
995+
977996
// Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
978997
// `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
979-
self.map
980-
.raw_entry_mut()
981-
.from_key(value)
982-
.or_insert_with(|| (f(value), ()))
983-
.0
998+
let hash = make_hash::<Q, S>(&self.map.hash_builder, value);
999+
let raw_entry_builder = self.map.raw_entry_mut();
1000+
match raw_entry_builder.from_key_hashed_nocheck(hash, value) {
1001+
RawEntryMut::Occupied(entry) => entry.into_key(),
1002+
RawEntryMut::Vacant(entry) => {
1003+
let insert_value = f(value);
1004+
let insert_value_hash = make_insert_hash::<T, S>(entry.hasher(), &insert_value);
1005+
if !(hash == insert_value_hash && value.equivalent(&insert_value)) {
1006+
assert_failed();
1007+
}
1008+
entry
1009+
.insert_hashed_nocheck(insert_value_hash, insert_value, ())
1010+
.0
1011+
}
1012+
}
9841013
}
9851014

9861015
/// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2429,7 +2458,7 @@ fn assert_covariance() {
24292458
#[cfg(test)]
24302459
mod test_set {
24312460
use super::super::map::DefaultHashBuilder;
2432-
use super::HashSet;
2461+
use super::{make_hash, Equivalent, HashSet};
24332462
use std::vec::Vec;
24342463

24352464
#[test]
@@ -2886,4 +2915,100 @@ mod test_set {
28862915
set.insert(i);
28872916
}
28882917
}
2918+
2919+
#[test]
2920+
fn duplicate_insert() {
2921+
let mut set = HashSet::new();
2922+
set.insert(1);
2923+
set.get_or_insert_with(&1, |_| 1);
2924+
set.get_or_insert_with(&1, |_| 1);
2925+
assert!([1].iter().eq(set.iter()));
2926+
}
2927+
2928+
#[test]
2929+
#[should_panic]
2930+
fn some_invalid_hash() {
2931+
use core::hash::{Hash, Hasher};
2932+
struct Invalid {
2933+
count: u32,
2934+
}
2935+
2936+
struct InvalidRef {
2937+
count: u32,
2938+
}
2939+
2940+
impl PartialEq for Invalid {
2941+
fn eq(&self, other: &Self) -> bool {
2942+
self.count == other.count
2943+
}
2944+
}
2945+
impl Eq for Invalid {}
2946+
2947+
impl Equivalent<Invalid> for InvalidRef {
2948+
fn equivalent(&self, key: &Invalid) -> bool {
2949+
self.count == key.count
2950+
}
2951+
}
2952+
impl Hash for Invalid {
2953+
fn hash<H: Hasher>(&self, state: &mut H) {
2954+
self.count.hash(state);
2955+
}
2956+
}
2957+
impl Hash for InvalidRef {
2958+
fn hash<H: Hasher>(&self, state: &mut H) {
2959+
let double = self.count * 2;
2960+
double.hash(state);
2961+
}
2962+
}
2963+
let mut set: HashSet<Invalid> = HashSet::new();
2964+
let key = InvalidRef { count: 1 };
2965+
let value = Invalid { count: 1 };
2966+
if key.equivalent(&value) {
2967+
set.get_or_insert_with(&key, |_| value);
2968+
}
2969+
}
2970+
2971+
#[test]
2972+
#[should_panic]
2973+
fn some_invalid_equivalent() {
2974+
use core::hash::{Hash, Hasher};
2975+
struct Invalid {
2976+
count: u32,
2977+
other: u32,
2978+
}
2979+
2980+
struct InvalidRef {
2981+
count: u32,
2982+
other: u32,
2983+
}
2984+
2985+
impl PartialEq for Invalid {
2986+
fn eq(&self, other: &Self) -> bool {
2987+
self.count == other.count && self.other == other.other
2988+
}
2989+
}
2990+
impl Eq for Invalid {}
2991+
2992+
impl Equivalent<Invalid> for InvalidRef {
2993+
fn equivalent(&self, key: &Invalid) -> bool {
2994+
self.count == key.count && self.other == key.other
2995+
}
2996+
}
2997+
impl Hash for Invalid {
2998+
fn hash<H: Hasher>(&self, state: &mut H) {
2999+
self.count.hash(state);
3000+
}
3001+
}
3002+
impl Hash for InvalidRef {
3003+
fn hash<H: Hasher>(&self, state: &mut H) {
3004+
self.count.hash(state);
3005+
}
3006+
}
3007+
let mut set: HashSet<Invalid> = HashSet::new();
3008+
let key = InvalidRef { count: 1, other: 1 };
3009+
let value = Invalid { count: 1, other: 2 };
3010+
if make_hash(set.hasher(), &key) == make_hash(set.hasher(), &value) {
3011+
set.get_or_insert_with(&key, |_| value);
3012+
}
3013+
}
28893014
}

0 commit comments

Comments
 (0)