diff --git a/Cargo.toml b/Cargo.toml
index 45e433fff..484314994 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ rust-version = "1.65.0"
 foldhash = { version = "0.2.0", default-features = false, optional = true }
 
 # For external trait impls
+paralight = { version = "0.0.6", optional = true }
 rayon = { version = "1.9.0", optional = true }
 serde_core = { version = "1.0.221", default-features = false, optional = true }
 
@@ -85,5 +86,5 @@ default-hasher = ["dep:foldhash"]
 inline-more = []
 
 [package.metadata.docs.rs]
-features = ["nightly", "rayon", "serde", "raw-entry"]
+features = ["nightly", "paralight", "rayon", "serde", "raw-entry"]
 rustdoc-args = ["--generate-link-to-definition"]
diff --git a/src/external_trait_impls/mod.rs b/src/external_trait_impls/mod.rs
index ef497836c..cfea0ffef 100644
--- a/src/external_trait_impls/mod.rs
+++ b/src/external_trait_impls/mod.rs
@@ -1,3 +1,5 @@
+#[cfg(feature = "paralight")]
+mod paralight;
 #[cfg(feature = "rayon")]
 pub(crate) mod rayon;
 #[cfg(feature = "serde")]
diff --git a/src/external_trait_impls/paralight.rs b/src/external_trait_impls/paralight.rs
new file mode 100644
index 000000000..ce01ca507
--- /dev/null
+++ b/src/external_trait_impls/paralight.rs
@@ -0,0 +1,498 @@
+use crate::raw::{Allocator, RawTable};
+use crate::{HashMap, HashSet};
+use paralight::iter::{
+    IntoParallelRefMutSource, IntoParallelRefSource, IntoParallelSource, ParallelSource,
+    SourceCleanup, SourceDescriptor,
+};
+
+// HashSet.par_iter()
+impl<'data, T: Sync + 'data, S: 'data, A: Allocator + Sync + 'data> IntoParallelRefSource<'data>
+    for HashSet<T, S, A>
+{
+    type Item = Option<&'data T>;
+    type Source = HashSetRefParallelSource<'data, T, S, A>;
+
+    fn par_iter(&'data self) -> Self::Source {
+        HashSetRefParallelSource { hash_set: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashSetRefParallelSource<'data, T, S, A: Allocator> {
+    hash_set: &'data HashSet<T, S, A>,
+}
+
+impl<'data, T: Sync, S, A: Allocator + Sync> ParallelSource
+    for HashSetRefParallelSource<'data, T, S, A>
+{
+    type Item = Option<&'data T>;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashSetRefSourceDescriptor {
+            table: &self.hash_set.map.table,
+        }
+    }
+}
+
+struct HashSetRefSourceDescriptor<'data, T: Sync, A: Allocator> {
+    table: &'data RawTable<(T, ()), A>,
+}
+
+impl<T: Sync, A: Allocator> SourceCleanup for HashSetRefSourceDescriptor<'_, T, A> {
+    const NEEDS_CLEANUP: bool = false;
+
+    unsafe fn cleanup_item_range(&self, _range: core::ops::Range<usize>) {
+        // Nothing to cleanup
+    }
+}
+
+impl<'data, T: Sync, A: Allocator> SourceDescriptor for HashSetRefSourceDescriptor<'data, T, A> {
+    type Item = Option<&'data T>;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn fetch_item(&self, index: usize) -> Self::Item {
+        debug_assert!(index < self.len());
+        // SAFETY: TODO
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY: TODO
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: TODO
+            let (t, ()) = unsafe { bucket.as_ref() };
+            Some(t)
+        } else {
+            None
+        }
+    }
+}
+
+// HashMap.par_iter()
+impl<'data, K: Sync + 'data, V: Sync + 'data, S: 'data, A: Allocator + Sync + 'data>
+    IntoParallelRefSource<'data> for HashMap<K, V, S, A>
+{
+    type Item = Option<&'data (K, V)>;
+    type Source = HashMapRefParallelSource<'data, K, V, S, A>;
+
+    fn par_iter(&'data self) -> Self::Source {
+        HashMapRefParallelSource { hash_map: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashMapRefParallelSource<'data, K, V, S, A: Allocator> {
+    hash_map: &'data HashMap<K, V, S, A>,
+}
+
+impl<'data, K: Sync, V: Sync, S, A: Allocator + Sync> ParallelSource
+    for HashMapRefParallelSource<'data, K, V, S, A>
+{
+    type Item = Option<&'data (K, V)>;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashMapRefSourceDescriptor {
+            table: &self.hash_map.table,
+        }
+    }
+}
+
+struct HashMapRefSourceDescriptor<'data, K: Sync, V: Sync, A: Allocator> {
+    table: &'data RawTable<(K, V), A>,
+}
+
+impl<K: Sync, V: Sync, A: Allocator> SourceCleanup for HashMapRefSourceDescriptor<'_, K, V, A> {
+    const NEEDS_CLEANUP: bool = false;
+
+    unsafe fn cleanup_item_range(&self, _range: core::ops::Range<usize>) {
+        // Nothing to cleanup
+    }
+}
+
+impl<'data, K: Sync, V: Sync, A: Allocator> SourceDescriptor
+    for HashMapRefSourceDescriptor<'data, K, V, A>
+{
+    type Item = Option<&'data (K, V)>;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn fetch_item(&self, index: usize) -> Self::Item {
+        debug_assert!(index < self.len());
+        // SAFETY: TODO
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY: TODO
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: TODO
+            unsafe { Some(bucket.as_ref()) }
+        } else {
+            None
+        }
+    }
+}
+
+// HashMap.par_iter_mut()
+// TODO: Remove Sync requirement on V.
+impl<'data, K: Sync + 'data, V: Send + Sync + 'data, S: 'data, A: Allocator + Sync + 'data>
+    IntoParallelRefMutSource<'data> for HashMap<K, V, S, A>
+{
+    type Item = Option<(&'data K, &'data mut V)>;
+    type Source = HashMapRefMutParallelSource<'data, K, V, S, A>;
+
+    fn par_iter_mut(&'data mut self) -> Self::Source {
+        HashMapRefMutParallelSource { hash_map: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashMapRefMutParallelSource<'data, K, V, S, A: Allocator> {
+    hash_map: &'data mut HashMap<K, V, S, A>,
+}
+
+impl<'data, K: Sync, V: Send + Sync, S, A: Allocator + Sync> ParallelSource
+    for HashMapRefMutParallelSource<'data, K, V, S, A>
+{
+    type Item = Option<(&'data K, &'data mut V)>;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashMapRefMutSourceDescriptor {
+            table: &self.hash_map.table,
+        }
+    }
+}
+
+struct HashMapRefMutSourceDescriptor<'data, K: Sync, V: Send + Sync, A: Allocator> {
+    table: &'data RawTable<(K, V), A>,
+}
+
+impl<K: Sync, V: Send + Sync, A: Allocator> SourceCleanup
+    for HashMapRefMutSourceDescriptor<'_, K, V, A>
+{
+    const NEEDS_CLEANUP: bool = false;
+
+    unsafe fn cleanup_item_range(&self, _range: core::ops::Range<usize>) {
+        // Nothing to cleanup
+    }
+}
+
+impl<'data, K: Sync, V: Send + Sync, A: Allocator> SourceDescriptor
+    for HashMapRefMutSourceDescriptor<'data, K, V, A>
+{
+    type Item = Option<(&'data K, &'data mut V)>;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn fetch_item(&self, index: usize) -> Self::Item {
+        debug_assert!(index < self.len());
+        // SAFETY: TODO
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY: TODO
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: TODO
+            let (key, value) = unsafe { bucket.as_mut() };
+            Some((key, value))
+        } else {
+            None
+        }
+    }
+}
+
+// HashSet.into_par_iter()
+// TODO: Remove Sync requirement on T.
+impl<T: Send + Sync, S, A: Allocator + Sync> IntoParallelSource for HashSet<T, S, A> {
+    type Item = Option<T>;
+    type Source = HashSetParallelSource<T, S, A>;
+
+    fn into_par_iter(self) -> Self::Source {
+        HashSetParallelSource { hash_set: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashSetParallelSource<T, S, A: Allocator> {
+    hash_set: HashSet<T, S, A>,
+}
+
+impl<T: Send + Sync, S, A: Allocator + Sync> ParallelSource for HashSetParallelSource<T, S, A> {
+    type Item = Option<T>;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashSetSourceDescriptor {
+            table: self.hash_set.map.table,
+        }
+    }
+}
+
+struct HashSetSourceDescriptor<T: Send + Sync, A: Allocator> {
+    table: RawTable<(T, ()), A>,
+}
+
+impl<T: Send + Sync, A: Allocator> SourceCleanup for HashSetSourceDescriptor<T, A> {
+    const NEEDS_CLEANUP: bool = core::mem::needs_drop::<T>();
+
+    unsafe fn cleanup_item_range(&self, range: core::ops::Range<usize>) {
+        if Self::NEEDS_CLEANUP {
+            debug_assert!(range.start <= range.end);
+            debug_assert!(range.start <= self.len());
+            debug_assert!(range.end <= self.len());
+            for index in range {
+                // SAFETY: TODO
+                let full = unsafe { self.table.is_bucket_full(index) };
+                if full {
+                    // SAFETY: TODO
+                    let bucket = unsafe { self.table.bucket(index) };
+                    // SAFETY: TODO
+                    let (t, ()) = unsafe { bucket.read() };
+                    drop(t);
+                }
+            }
+        }
+    }
+}
+
+impl<T: Send + Sync, A: Allocator> SourceDescriptor for HashSetSourceDescriptor<T, A> {
+    type Item = Option<T>;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn fetch_item(&self, index: usize) -> Self::Item {
+        debug_assert!(index < self.len());
+        // SAFETY: TODO
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY: TODO
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: TODO
+            let (t, ()) = unsafe { bucket.read() };
+            Some(t)
+        } else {
+            None
+        }
+    }
+}
+
+impl<T: Send + Sync, A: Allocator> Drop for HashSetSourceDescriptor<T, A> {
+    fn drop(&mut self) {
+        // Paralight has already dropped, via calls to cleanup_item_range(), every full bucket
+        // that wasn't fetched, so we can simply mark all buckets as cleared and let the RawTable
+        // destructor do the rest.
+        // TODO: Optimize this to simply deallocate without touching the control bytes.
+        self.table.clear_no_drop();
+    }
+}
+
+// HashMap.into_par_iter()
+// TODO: Remove Sync requirement on K and V.
+impl<K: Send + Sync, V: Send + Sync, S, A: Allocator + Sync> IntoParallelSource
+    for HashMap<K, V, S, A>
+{
+    type Item = Option<(K, V)>;
+    type Source = HashMapParallelSource<K, V, S, A>;
+
+    fn into_par_iter(self) -> Self::Source {
+        HashMapParallelSource { hash_map: self }
+    }
+}
+
+#[must_use = "iterator adaptors are lazy"]
+pub struct HashMapParallelSource<K, V, S, A: Allocator> {
+    hash_map: HashMap<K, V, S, A>,
+}
+
+impl<K: Send + Sync, V: Send + Sync, S, A: Allocator + Sync> ParallelSource
+    for HashMapParallelSource<K, V, S, A>
+{
+    type Item = Option<(K, V)>;
+
+    fn descriptor(self) -> impl SourceDescriptor<Item = Self::Item> + Sync {
+        HashMapSourceDescriptor {
+            table: self.hash_map.table,
+        }
+    }
+}
+
+struct HashMapSourceDescriptor<K: Send + Sync, V: Send + Sync, A: Allocator> {
+    table: RawTable<(K, V), A>,
+}
+
+impl<K: Send + Sync, V: Send + Sync, A: Allocator> SourceCleanup
+    for HashMapSourceDescriptor<K, V, A>
+{
+    const NEEDS_CLEANUP: bool = core::mem::needs_drop::<(K, V)>();
+
+    unsafe fn cleanup_item_range(&self, range: core::ops::Range<usize>) {
+        if Self::NEEDS_CLEANUP {
+            debug_assert!(range.start <= range.end);
+            debug_assert!(range.start <= self.len());
+            debug_assert!(range.end <= self.len());
+            for index in range {
+                // SAFETY: TODO
+                let full = unsafe { self.table.is_bucket_full(index) };
+                if full {
+                    // SAFETY: TODO
+                    let bucket = unsafe { self.table.bucket(index) };
+                    // SAFETY: TODO
+                    let key_value = unsafe { bucket.read() };
+                    drop(key_value);
+                }
+            }
+        }
+    }
+}
+
+impl<K: Send + Sync, V: Send + Sync, A: Allocator> SourceDescriptor
+    for HashMapSourceDescriptor<K, V, A>
+{
+    type Item = Option<(K, V)>;
+
+    fn len(&self) -> usize {
+        self.table.buckets()
+    }
+
+    unsafe fn fetch_item(&self, index: usize) -> Self::Item {
+        debug_assert!(index < self.len());
+        // SAFETY: TODO
+        let full = unsafe { self.table.is_bucket_full(index) };
+        if full {
+            // SAFETY: TODO
+            let bucket = unsafe { self.table.bucket(index) };
+            // SAFETY: TODO
+            unsafe { Some(bucket.read()) }
+        } else {
+            None
+        }
+    }
+}
+
+impl<K: Send + Sync, V: Send + Sync, A: Allocator> Drop for HashMapSourceDescriptor<K, V, A> {
+    fn drop(&mut self) {
+        // Paralight has already dropped, via calls to cleanup_item_range(), every full bucket
+        // that wasn't fetched, so we can simply mark all buckets as cleared and let the RawTable
+        // destructor do the rest.
+        // TODO: Optimize this to simply deallocate without touching the control bytes.
+        self.table.clear_no_drop();
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use alloc::boxed::Box;
+    use core::ops::Deref;
+    use paralight::iter::{ParallelIteratorExt, ParallelSourceExt};
+    use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
+
+    #[test]
+    fn test_set_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut set = HashSet::new();
+        for i in 1..=42 {
+            set.insert(Box::new(i));
+        }
+
+        let sum = set
+            .par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .filter_map(|x| x.map(|y| y.deref()))
+            .sum::<i32>();
+        assert_eq!(sum, 21 * 43);
+    }
+
+    #[test]
+    fn test_set_into_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut set = HashSet::new();
+        for i in 1..=42 {
+            set.insert(Box::new(i));
+        }
+
+        let sum = set
+            .into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .filter_map(|x| x.map(|y| *y))
+            .sum::<i32>();
+        assert_eq!(sum, 21 * 43);
+    }
+
+    #[test]
+    fn test_map_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i * i));
+        }
+
+        map.par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .filter_map(|x| x)
+            .for_each(|(k, v)| assert_eq!(**k * **k, **v));
+    }
+
+    #[test]
+    fn test_map_par_iter_mut() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i));
+        }
+
+        map.par_iter_mut()
+            .with_thread_pool(&mut thread_pool)
+            .filter_map(|x| x)
+            .for_each(|(k, v)| **v *= **k);
+
+        for (k, v) in map.iter() {
+            assert_eq!(**k * **k, **v);
+        }
+    }
+
+    #[test]
+    fn test_map_into_par_iter() {
+        let mut thread_pool = ThreadPoolBuilder {
+            num_threads: ThreadCount::AvailableParallelism,
+            range_strategy: RangeStrategy::WorkStealing,
+            cpu_pinning: CpuPinningPolicy::No,
+        }
+        .build();
+
+        let mut map = HashMap::new();
+        for i in 1..=42 {
+            map.insert(Box::new(i), Box::new(i * i));
+        }
+
+        map.into_par_iter()
+            .with_thread_pool(&mut thread_pool)
+            .filter_map(|x| x)
+            .for_each(|(k, v)| assert_eq!(*k * *k, *v));
+    }
+}
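
Usage sketch (illustrative, not part of the patch): assuming a downstream crate enables the new
"paralight" cargo feature on hashbrown, the HashSet::par_iter() impl above could be driven roughly
as follows, mirroring the style of the tests. The HashSet<i32> contents and the sum::<i32>()
accumulator type are assumptions made for this example.

    use hashbrown::HashSet;
    use paralight::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};
    use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};

    fn main() {
        // Build a work-stealing thread pool, as in the tests above.
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::WorkStealing,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();

        let set: HashSet<i32> = (1..=1000).collect();

        // Each bucket is surfaced as Option<&i32>; empty buckets yield None and are filtered out.
        let sum = set
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .filter_map(|item| item.copied())
            .sum::<i32>();
        assert_eq!(sum, 500_500);
    }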