Skip to content

Commit

Permalink
Implement ParJoin for MaybeJoin if the inner type is ParJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
kyren committed Nov 15, 2019
1 parent b1aae2d commit 82ae48c
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ where
}
}

// SAFETY: This is safe as long as `T` implements `ParJoin` safely. `MaybeJoin` relies on `T as
// Join` for all storage access and safely wraps the inner `Join` API, so it should also be able to
// implement `ParJoin`.
#[cfg(feature = "parallel")]
unsafe impl<T> ParJoin for MaybeJoin<T>
where
T: ParJoin,
{}

/// `JoinIter` is an `Iterator` over a group of `Storages`.
#[must_use]
pub struct JoinIter<J: Join> {
Expand Down
71 changes: 71 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,77 @@ fn par_join_two_components() {
);
}

#[test]
fn par_join_with_maybe() {
use std::sync::{
atomic::{AtomicBool, Ordering},
Mutex,
};
let mut world = create_world();
world
.create_entity()
.with(CompInt(1))
.with(CompBool(false))
.build();
world
.create_entity()
.with(CompInt(2))
.with(CompBool(true))
.build();
world.create_entity().with(CompInt(3)).build();
let first = AtomicBool::new(false);
let second = AtomicBool::new(false);
let third = AtomicBool::new(false);
let error = Mutex::new(None);
struct Iter<'a>(
&'a AtomicBool,
&'a AtomicBool,
&'a AtomicBool,
&'a Mutex<Option<(i8, Option<bool>)>>,
);
impl<'a, 'b> System<'a> for Iter<'b> {
type SystemData = (ReadStorage<'a, CompInt>, ReadStorage<'a, CompBool>);

fn run(&mut self, (int, boolean): Self::SystemData) {
use rayon::iter::ParallelIterator;
let Iter(first, second, third, error) = *self;
(&int, boolean.maybe()).par_join().for_each(|(int, boolean)| {
let boolean = boolean.map(|c| c.0);
if !first.load(Ordering::SeqCst) && int.0 == 1 && boolean == Some(false) {
first.store(true, Ordering::SeqCst);
} else if !second.load(Ordering::SeqCst) && int.0 == 2 && boolean == Some(true) {
second.store(true, Ordering::SeqCst);
} else if !third.load(Ordering::SeqCst) && int.0 == 3 && boolean == None {
third.store(true, Ordering::SeqCst);
} else {
*error.lock().unwrap() = Some((int.0, boolean));
}
});
}
}
let mut dispatcher = DispatcherBuilder::new()
.with(Iter(&first, &second, &third, &error), "iter", &[])
.build();
dispatcher.dispatch(&mut world);
assert_eq!(
*error.lock().unwrap(),
None,
"Entity shouldn't be in the join",
);
assert!(
first.load(Ordering::SeqCst),
"There should be entity with CompInt(1) and CompBool(false)"
);
assert!(
second.load(Ordering::SeqCst),
"There should be entity with CompInt(2) and CompBool(true)"
);
assert!(
third.load(Ordering::SeqCst),
"There should be entity with CompInt(3) and no CompBool"
);
}

#[test]
fn par_join_many_entities_and_systems() {
use rayon::iter::ParallelIterator;
Expand Down

0 comments on commit 82ae48c

Please sign in to comment.