Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for typed task Labels. #138

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions src/current.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,22 @@
//! building tools that need to exploit Shuttle's total ordering of concurrent operations; for
//! example, a tool that wants to check linearizability might want access to a global timestamp for
//! events, which the [`context_switches`] function provides.
jorajeev marked this conversation as resolved.
Show resolved Hide resolved
//!
//! This module also provides functions to manage the assocation of `labels` to threads and async tasks.
//! Labels are typed values that can be associated with a task. They are useful for debugging: for
//! instance, the `TaskName` label can be set to assign names to tasks to make debug output easier to read.
//! Labels can also be used to build customized schedulers: for instance, they can be used to assign
//! numeric weights to tasks, which can be used to implement a priority-preemptive scheduler.

use crate::runtime::execution::{ExecutionState, TASK_ID_TO_TAGS};
#[allow(deprecated)]
use crate::runtime::execution::TASK_ID_TO_TAGS;
use crate::runtime::execution::{ExecutionState, LABELS};
use crate::runtime::task::clock::VectorClock;
pub use crate::runtime::task::{Tag, Taggable, TaskId};
pub use crate::runtime::task::labels::Labels;
pub use crate::runtime::task::{ChildLabelFn, TaskId, TaskName};
#[allow(deprecated)]
pub use crate::runtime::task::{Tag, Taggable};
use std::fmt::Debug;
use std::sync::Arc;

/// The number of context switches that happened so far in the current Shuttle execution.
Expand All @@ -34,23 +46,71 @@ pub fn clock_for(task_id: TaskId) -> VectorClock {
ExecutionState::with(|state| state.get_clock(task_id).clone())
}

/// Apply the given function to the Labels for the specified task
pub fn with_labels_for_task<F, T>(task_id: TaskId, f: F) -> T
where
F: FnOnce(&mut Labels) -> T,
{
LABELS.with(|cell| {
let mut map = cell.borrow_mut();
let m = map.entry(task_id).or_default();
f(m)
})
}

/// Get a label of the given type for the specified task, if any
pub fn get_label_for_task<T: Clone + Debug + 'static>(task_id: TaskId) -> Option<T> {
with_labels_for_task(task_id, |labels| labels.get().cloned())
}

/// Add the given label to the specified task, returning the old label for the type, if any
pub fn set_label_for_task<T: Clone + Debug + 'static>(task_id: TaskId, value: T) -> Option<T> {
with_labels_for_task(task_id, |labels| labels.insert(value))
}

/// Remove a label of the given type for the specified task, returning the old label for the type, if any
pub fn remove_label_for_task<T: Clone + Debug + 'static>(task_id: TaskId) -> Option<T> {
with_labels_for_task(task_id, |labels| labels.remove())
}

/// Get the debug name for a task
pub fn get_name_for_task(task_id: TaskId) -> Option<TaskName> {
get_label_for_task::<TaskName>(task_id)
}

/// Set the debug name for a task, returning the old name, if any
pub fn set_name_for_task(task_id: TaskId, task_name: impl Into<TaskName>) -> Option<TaskName> {
set_label_for_task::<TaskName>(task_id, task_name.into())
}

/// Gets the `TaskId` of the current task, or `None` if there is no current task.
pub fn get_current_task() -> Option<TaskId> {
ExecutionState::with(|s| Some(s.try_current()?.id()))
}

/// Get the `TaskId` of the current task. Panics if there is no current task.
pub fn me() -> TaskId {
get_current_task().unwrap()
}

/// Sets the `tag` field of the current task.
/// Returns the `tag` which was there previously.
#[deprecated]
#[allow(deprecated)]
jorajeev marked this conversation as resolved.
Show resolved Hide resolved
pub fn set_tag_for_current_task(tag: Arc<dyn Tag>) -> Option<Arc<dyn Tag>> {
ExecutionState::set_tag_for_current_task(tag)
}

/// Gets the `tag` field of the current task.
#[deprecated]
#[allow(deprecated)]
pub fn get_tag_for_current_task() -> Option<Arc<dyn Tag>> {
ExecutionState::get_tag_for_current_task()
}

/// Gets the `TaskId` of the current task, or `None` if there is no current task.
pub fn get_current_task() -> Option<TaskId> {
ExecutionState::with(|s| Some(s.try_current()?.id()))
}

/// Gets the `tag` field of the specified task.
#[deprecated]
#[allow(deprecated)]
pub fn get_tag_for_task(task_id: TaskId) -> Option<Arc<dyn Tag>> {
TASK_ID_TO_TAGS.with(|cell| {
let map = cell.borrow();
Expand All @@ -59,6 +119,8 @@ pub fn get_tag_for_task(task_id: TaskId) -> Option<Arc<dyn Tag>> {
}

/// Sets the `tag` field of the specified task.
#[deprecated]
#[allow(deprecated)]
pub fn set_tag_for_task(task: TaskId, tag: Arc<dyn Tag>) -> Option<Arc<dyn Tag>> {
ExecutionState::set_tag_for_task(task, tag)
}
50 changes: 49 additions & 1 deletion src/runtime/execution.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::runtime::failure::{init_panic_hook, persist_failure, persist_task_failure};
use crate::runtime::storage::{StorageKey, StorageMap};
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::{Task, TaskId, DEFAULT_INLINE_TASKS};
use crate::runtime::task::labels::Labels;
use crate::runtime::task::{ChildLabelFn, Task, TaskId, TaskName, DEFAULT_INLINE_TASKS};
use crate::runtime::thread::continuation::PooledContinuation;
use crate::scheduler::{Schedule, Scheduler};
use crate::thread::thread_fn;
Expand All @@ -18,6 +19,7 @@ use std::rc::Rc;
use std::sync::Arc;
use tracing::{trace, Span};

#[allow(deprecated)]
use super::task::Tag;

// We use this scoped TLS to smuggle the ExecutionState, which is not 'static, across tasks that
Expand All @@ -28,9 +30,14 @@ scoped_thread_local! {

thread_local! {
#[allow(clippy::complexity)]
#[allow(deprecated)]
pub(crate) static TASK_ID_TO_TAGS: RefCell<HashMap<TaskId, Arc<dyn Tag>>> = RefCell::new(HashMap::new());
}

thread_local! {
pub(crate) static LABELS: RefCell<HashMap<TaskId, Labels>> = RefCell::new(HashMap::new());
}

/// An `Execution` encapsulates a single run of a function under test against a chosen scheduler.
/// Its only useful method is `Execution::run`, which executes the function to completion.
///
Expand Down Expand Up @@ -327,6 +334,33 @@ impl ExecutionState {
Self::with(|s| s.current().id())
}

fn set_labels_for_new_task(state: &ExecutionState, task_id: TaskId, name: Option<String>) {
LABELS.with(|cell| {
let mut map = cell.borrow_mut();

// If parent has labels, inherit them
if let Some(parent_task_id) = state.try_current().map(|t| t.id()) {
let parent_map = map.get(&parent_task_id);
if let Some(parent_map) = parent_map {
let mut child_map = parent_map.clone();

// If the parent has a `ChildLabelFn` set, use that to update the child's Labels
if let Some(gen) = parent_map.get::<ChildLabelFn>() {
(gen.0)(task_id, &mut child_map);
}

map.insert(task_id, child_map);
}
}

// Add any name assigned to the task to its set of Labels
if let Some(name) = name {
let m = map.entry(task_id).or_default();
m.insert(TaskName::from(name));
}
});
}

/// Spawn a new task for a future. This doesn't create a yield point; the caller should do that
/// if it wants to give the new task a chance to run immediately.
pub(crate) fn spawn_future<F>(future: F, stack_size: usize, name: Option<String>) -> TaskId
Expand All @@ -339,6 +373,9 @@ impl ExecutionState {

let task_id = TaskId(state.tasks.len());
let tag = state.get_tag_or_default_for_current_task();

Self::set_labels_for_new_task(state, task_id, name.clone());

let clock = state.increment_clock_mut(); // Increment the parent's clock
clock.extend(task_id); // and extend it with an entry for the new task

Expand Down Expand Up @@ -372,6 +409,9 @@ impl ExecutionState {
let parent_span_id = state.top_level_span.id();
let task_id = TaskId(state.tasks.len());
let tag = state.get_tag_or_default_for_current_task();

Self::set_labels_for_new_task(state, task_id, name.clone());

let clock = if let Some(ref mut clock) = initial_clock {
clock
} else {
Expand Down Expand Up @@ -424,6 +464,7 @@ impl ExecutionState {
while Self::with(|state| state.storage.pop()).is_some() {}

TASK_ID_TO_TAGS.with(|cell| cell.borrow_mut().clear());
LABELS.with(|cell| cell.borrow_mut().clear());

#[cfg(debug_assertions)]
Self::with(|state| state.has_cleaned_up = true);
Expand Down Expand Up @@ -651,6 +692,9 @@ impl ExecutionState {
// 2) It creates a visual separation of scheduling decisions and `Task`-induced tracing.
// Note that there is a case to be made for not `in_scope`-ing it, as that makes seeing the context
// of the context switch clearer.
//
// Note also that changing this trace! statement requires changing the test `basic::labels::test_tracing_with_label_fn`
// which relies on this trace reporting the `runnable` tasks.
self.top_level_span
.in_scope(|| trace!(?runnable, next_task=?self.next_task));

Expand All @@ -669,18 +713,22 @@ impl ExecutionState {

// Sets the `tag` field of the current task.
// Returns the `tag` which was there previously.
#[allow(deprecated)]
pub(crate) fn set_tag_for_current_task(tag: Arc<dyn Tag>) -> Option<Arc<dyn Tag>> {
ExecutionState::with(|s| s.current_mut().set_tag(tag))
}

#[allow(deprecated)]
fn get_tag_or_default_for_current_task(&self) -> Option<Arc<dyn Tag>> {
self.try_current().and_then(|current| current.get_tag())
}

#[allow(deprecated)]
pub(crate) fn get_tag_for_current_task() -> Option<Arc<dyn Tag>> {
ExecutionState::with(|s| s.get_tag_or_default_for_current_task())
}

#[allow(deprecated)]
pub(crate) fn set_tag_for_task(task: TaskId, tag: Arc<dyn Tag>) -> Option<Arc<dyn Tag>> {
ExecutionState::with(|s| s.get_mut(task).set_tag(tag))
}
Expand Down
Loading
Loading