Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Incorporate selectors into Shuttle #94

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ rand_pcg = "0.3.1"
scoped-tls = "1.0.0"
smallvec = "1.6.1"
tracing = { version = "0.1.21", default-features = false, features = ["std"] }
crossbeam-channel = "0.5.6"

[dev-dependencies]
criterion = { version = "0.4.0", features = ["html_reports"] }
Expand Down
34 changes: 33 additions & 1 deletion src/sync/mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::runtime::execution::ExecutionState;
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::{TaskId, DEFAULT_INLINE_TASKS};
use crate::runtime::thread;
use crate::sync::mpsc::selector::Selectable;
use smallvec::SmallVec;
use std::cell::RefCell;
use std::fmt::Debug;
Expand All @@ -14,6 +15,9 @@ use std::sync::Arc;
use std::time::Duration;
use tracing::trace;

mod selector;
pub use self::selector::Select;

// TODO
// * Add support for try_recv() and try_send()
// * Add support for iter() for receivers
Expand All @@ -32,6 +36,11 @@ pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
(sender, receiver)
}

/// Create an unbounded channel -- alias for channel<T>() to conform to crossbeam standards
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
channel()
}

/// Create a bounded channel
pub fn sync_channel<T>(bound: usize) -> (SyncSender<T>, Receiver<T>) {
let channel = Arc::new(Channel::new(Some(bound)));
Expand Down Expand Up @@ -355,7 +364,7 @@ unsafe impl<T: Send> Sync for Channel<T> {}

/// The receiving half of Rust's [`channel`] (or [`sync_channel`]) type.
/// This half can only be owned by one thread.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Receiver<T> {
inner: Arc<Channel<T>>,
}
Expand All @@ -375,6 +384,22 @@ impl<T> Receiver<T> {
}
}

impl<T> Selectable for Receiver<T> {
// Determines whether the channel has anything to be received.
// TODO (Finn) is this sufficient?
fn try_select(&self) -> bool {
self.inner.state.borrow().messages.len() > 0
}

fn add_waiting_receiver(&self, task: TaskId) {
self.inner.state.borrow_mut().waiting_receivers.push(task);
}

fn delete_waiting_receiver(&self, task: TaskId) {
self.inner.state.borrow_mut().waiting_receivers.retain(|v| *v != task)
}
}

impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
if ExecutionState::should_stop() {
Expand Down Expand Up @@ -405,6 +430,13 @@ impl<T> Sender<T> {
pub fn send(&self, t: T) -> Result<(), SendError<T>> {
self.inner.send(t)
}

/// Attempts to send a value on this channel, returning it back if it could
/// not be sent.
/// TODO: incorporate timeout
pub fn send_timeout(&self, t: T, _: Duration) -> Result<(), SendError<T>> {
self.send(t)
}
}

impl<T> Clone for Sender<T> {
Expand Down
117 changes: 117 additions & 0 deletions src/sync/mpsc/selector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//! Selector implementation for multi-producer, single-consumer channels.

use crate::{sync::mpsc::Receiver, runtime::{execution::ExecutionState, thread}};
use crossbeam_channel::{TrySelectError, SelectTimeoutError, RecvError};
use core::fmt::Debug;
use std::time::Duration;
use crate::runtime::task::TaskId;

/// Represents the return value of a selector; contains an index representing which of the selectables was ready.
#[derive(Debug)]
pub struct SelectedOperation {
/// the index representing which selectable became ready
pub index: usize,
}

impl SelectedOperation {
pub fn index(&self) -> usize {
self.index
}

/// Performs a receive on an arbitrary receiver which had been given to the selector that returned this SelectedOperation.
/// TODO: in crossbeam, this method panics if the receiver does not match the one added to the selector -- is this necessary?
pub fn recv<T>(&self, r: &Receiver<T>) -> Result<T, RecvError> {
r.recv().map_err(|_| RecvError)
}
}

/// Any object which is selectable -- typically used for a receiver.
pub trait Selectable {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this in the public API?

/// Attempts to select from the selectable, returning true if anything is present and false otherwise.
fn try_select(&self) -> bool;
/// Adds a queued receiver to the selectable (used when the selector containing the selectable is about to block).
fn add_waiting_receiver(&self, task: TaskId);
/// Removes all instances of a queued receiver from the selectable (used after the selector has been unblocked).
fn delete_waiting_receiver(&self, task: TaskId);
}

impl<'a> Debug for dyn Selectable + 'a {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Selectable")
}
}

fn try_select(handles: &mut [(&dyn Selectable, usize)]) -> Result<SelectedOperation, TrySelectError> {
for handle in handles {
if handle.0.try_select() {
return Ok(SelectedOperation{index: handle.1})
}
}
Err(TrySelectError{})
}

fn select(handles: &mut [(&dyn Selectable, usize)]) -> SelectedOperation {
SelectedOperation {
index: {
if let Ok(SelectedOperation{index: idx}) = try_select(handles) {
idx
} else {
let id = ExecutionState::me();

loop {
for handle in &mut *handles {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to add more than once

handle.0.add_waiting_receiver(id);
}

ExecutionState::with(|state| {
state.get_mut(id).block()
});
thread::switch();

if let Ok(SelectedOperation{index: idx}) = try_select(handles) {
for handle in &mut *handles {
handle.0.delete_waiting_receiver(id);
}
break idx;
}
}
}
},
}
}

/// A selector.
#[derive(Debug)]
pub struct Select<'a> {
handles: Vec<(&'a dyn Selectable, usize)>,
}

impl<'a> Select<'a> {
/// Creates a new instance of the selector with no selectables.
pub fn new() -> Self {
Self { handles: Vec::new() }
}

/// Adds a new receiving selectable which the selector will wait on.
pub fn recv<T>(&mut self, r: &'a Receiver<T>) -> usize {
self.handles.push((r, self.handles.len()));
self.handles.len() - 1
}

/// Attempts to receive from one of the added selectables, returning the index of the given channel if possible.
pub fn try_select(&mut self) -> Result<SelectedOperation, TrySelectError> {
try_select(&mut self.handles)
}

/// Blocks until a value can be retrieved from one of the given selectables.
pub fn select(&mut self) -> SelectedOperation {
select(&mut self.handles)
}

/// Blocks until a value can be retrieved from one of the given selectables, returning an error if no value is received
/// before the timeout.
/// TODO: actually enforce timeout
pub fn select_timeout(&mut self, _: Duration) -> Result<SelectedOperation, SelectTimeoutError> {
Ok(self.select())
}
}
1 change: 1 addition & 0 deletions tests/basic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod pct;
mod portfolio;
mod replay;
mod rwlock;
mod selector;
mod shrink;
mod thread;
mod timeout;
113 changes: 113 additions & 0 deletions tests/basic/selector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::time::Duration;

use shuttle::sync::mpsc::{Select, channel};
use shuttle::{check_dfs};
use test_log::test;
use shuttle::thread;

#[test]
fn selector_one_channel() {
check_dfs(
move || {
let (s, r) = channel();

let mut selector = Select::new();
selector.recv(&r);

s.send(5).unwrap();

let op = selector.select();
assert_eq!(op.index, 0);

let val = r.recv().unwrap();
assert_eq!(val, 5);
},
None,
);
}

#[test]
fn selector_multi_channel() {
check_dfs(
move || {
let (_, r1) = channel::<i32>();
let (s2, r2) = channel();

let mut selector = Select::new();
selector.recv(&r1);
selector.recv(&r2);

s2.send(81).unwrap();

let op = selector.select();
assert_eq!(op.index, 1);

let val = r2.recv().unwrap();
assert_eq!(val, 81);
},
None,
);
}

#[test]
fn try_select_empty_selector() {
check_dfs(
move || {
assert!(Select::new().try_select().is_err())
},
None,
);
}

#[test]
fn select_unused_channel_functional() {
check_dfs(
move || {
let (s1, r1) = channel();
let (s2, r2) = channel();

let mut selector = Select::new();
selector.recv(&r1);
selector.recv(&r2);

s2.send(81).unwrap();

let op = selector.select();
assert_eq!(op.index, 1);

let val = r2.recv().unwrap();
assert_eq!(val, 81);

s1.send(198).unwrap();
let val = r1.recv().unwrap();
assert_eq!(val, 198);
},
None,
);
}

#[test]
fn select_multi_threaded() {
check_dfs(
move || {
let (_, r1) = channel::<i32>();
let (s2, r2) = channel();

let mut selector = Select::new();
selector.recv(&r1);
selector.recv(&r2);

thread::spawn(move || {
thread::sleep(Duration::from_millis(15));
s2.send(5).unwrap();
});

let op = selector.select();
assert_eq!(op.index, 1);

let val = r2.recv().unwrap();
assert_eq!(val, 5);
},
None,
);
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add more tests: 3 threads