diff --git a/src/join/mod.rs b/src/join/mod.rs index 73d1cf3e9..fb26342ca 100644 --- a/src/join/mod.rs +++ b/src/join/mod.rs @@ -291,6 +291,15 @@ where } } +// SAFETY: This is safe as long as `T` implements `ParJoin` safely. `MaybeJoin` relies on `T as +// Join` for all storage access and safely wraps the inner `Join` API, so it should also be able to +// implement `ParJoin`. +#[cfg(feature = "parallel")] +unsafe impl ParJoin for MaybeJoin +where + T: ParJoin, +{} + /// `JoinIter` is an `Iterator` over a group of `Storages`. #[must_use] pub struct JoinIter { diff --git a/tests/tests.rs b/tests/tests.rs index 62cc0c2a3..b22be8c3e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -397,6 +397,77 @@ fn par_join_two_components() { ); } +#[test] +fn par_join_with_maybe() { + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Mutex, + }; + let mut world = create_world(); + world + .create_entity() + .with(CompInt(1)) + .with(CompBool(false)) + .build(); + world + .create_entity() + .with(CompInt(2)) + .with(CompBool(true)) + .build(); + world.create_entity().with(CompInt(3)).build(); + let first = AtomicBool::new(false); + let second = AtomicBool::new(false); + let third = AtomicBool::new(false); + let error = Mutex::new(None); + struct Iter<'a>( + &'a AtomicBool, + &'a AtomicBool, + &'a AtomicBool, + &'a Mutex)>>, + ); + impl<'a, 'b> System<'a> for Iter<'b> { + type SystemData = (ReadStorage<'a, CompInt>, ReadStorage<'a, CompBool>); + + fn run(&mut self, (int, boolean): Self::SystemData) { + use rayon::iter::ParallelIterator; + let Iter(first, second, third, error) = *self; + (&int, boolean.maybe()).par_join().for_each(|(int, boolean)| { + let boolean = boolean.map(|c| c.0); + if !first.load(Ordering::SeqCst) && int.0 == 1 && boolean == Some(false) { + first.store(true, Ordering::SeqCst); + } else if !second.load(Ordering::SeqCst) && int.0 == 2 && boolean == Some(true) { + second.store(true, Ordering::SeqCst); + } else if !third.load(Ordering::SeqCst) && int.0 == 3 && boolean == None { + third.store(true, Ordering::SeqCst); + } else { + *error.lock().unwrap() = Some((int.0, boolean)); + } + }); + } + } + let mut dispatcher = DispatcherBuilder::new() + .with(Iter(&first, &second, &third, &error), "iter", &[]) + .build(); + dispatcher.dispatch(&mut world); + assert_eq!( + *error.lock().unwrap(), + None, + "Entity shouldn't be in the join", + ); + assert!( + first.load(Ordering::SeqCst), + "There should be entity with CompInt(1) and CompBool(false)" + ); + assert!( + second.load(Ordering::SeqCst), + "There should be entity with CompInt(2) and CompBool(true)" + ); + assert!( + third.load(Ordering::SeqCst), + "There should be entity with CompInt(3) and no CompBool" + ); +} + #[test] fn par_join_many_entities_and_systems() { use rayon::iter::ParallelIterator;