Skip to content

Commit 22339b2

Browse files
committed
fix: fix AtomicWaker implementation
1 parent 8c83d1c commit 22339b2

File tree

7 files changed

+169
-89
lines changed

7 files changed

+169
-89
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@ jobs:
1111
runs-on: ubuntu-latest
1212
steps:
1313
- uses: actions/checkout@v3
14-
- uses: actions-rs/toolchain@v1
15-
with:
16-
toolchain: nightly
17-
override: true
18-
components: rustfmt, clippy, miri
19-
- run: cargo +nightly fmt --check
20-
- run: cargo clippy --no-default-features -- -D warnings
21-
- run: cargo clippy --all-features --tests -- -D warnings
22-
- run: cargo test --all-features
23-
- run: cargo +nightly miri test --all-features
14+
- name: rustfmt
15+
run: cargo fmt -- --config "unstable_features=true,imports_granularity=Crate,group_imports=StdExternalCrate,format_code_in_doc_comments=true"
16+
- name: clippy no-default-feature
17+
run: cargo clippy --no-default-features -- -D warnings
18+
- name: clippy all-features
19+
run: cargo clippy --all-features --tests -- -D warnings
20+
- name: test
21+
run: cargo test --all-features
22+
- name: miri
23+
run: cargo +nightly miri test --all-features --all-seeds
24+
- name: loom
25+
run: cargo test --release --lib
26+
env:
27+
RUSTFLAGS: "--cfg loom"

Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@ tokio-test = "0.4"
3636
tokio = { version = "1", features = ["macros", "rt-multi-thread", "test-util", "time"] }
3737

3838
[target.'cfg(loom)'.dev-dependencies]
39-
loom = { version = "0.6.0", features = ["futures"] }
39+
loom = { version = "0.7", features = ["futures"] }
4040

41-
# See tokio Cargo.toml
41+
[lints.rust]
42+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(loom)'] }
43+
44+
# see https://users.rust-lang.org/t/how-to-document-optional-features-in-api-docs/64577/3
4245
[package.metadata.docs.rs]
4346
all-features = true
4447
rustdoc-args = [

src/loom.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#[cfg(not(all(loom, test)))]
22
mod without_loom {
3+
#[cfg(not(feature = "std"))]
4+
pub(crate) use core::sync;
5+
pub(crate) use core::{cell, hint};
36
#[cfg(feature = "std")]
4-
pub(crate) use std::thread;
5-
pub(crate) use std::{cell, sync};
7+
pub(crate) use std::{sync, thread};
68

79
pub(crate) const SPIN_LIMIT: usize = 64;
810
pub(crate) const BACKOFF_LIMIT: usize = 6;
@@ -11,6 +13,10 @@ mod without_loom {
1113
pub(crate) struct LoomUnsafeCell<T>(cell::UnsafeCell<T>);
1214

1315
impl<T> LoomUnsafeCell<T> {
16+
pub(crate) const fn new(data: T) -> Self {
17+
Self(cell::UnsafeCell::new(data))
18+
}
19+
1420
pub(crate) fn with<R>(&self, f: impl FnOnce(*const T) -> R) -> R {
1521
f(self.0.get())
1622
}
@@ -28,7 +34,7 @@ pub(crate) use without_loom::*;
2834
mod with_loom {
2935
#[cfg(feature = "std")]
3036
pub(crate) use loom::thread;
31-
pub(crate) use loom::{cell, sync};
37+
pub(crate) use loom::{cell, hint, sync};
3238

3339
pub(crate) const SPIN_LIMIT: usize = 1;
3440
pub(crate) const BACKOFF_LIMIT: usize = 1;

src/queue.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::{
66
buffer::{Buffer, BufferSlice, Drain, InsertIntoBuffer, Resize},
77
error::{TryDequeueError, TryEnqueueError},
88
loom::{
9+
hint,
910
sync::atomic::{AtomicUsize, Ordering},
1011
LoomUnsafeCell, BACKOFF_LIMIT, SPIN_LIMIT,
1112
},

src/synchronized.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99
use crate::{
1010
buffer::{Buffer, BufferSlice, Drain, InsertIntoBuffer},
1111
error::{DequeueError, EnqueueError, TryDequeueError, TryEnqueueError},
12-
loom::{thread, SPIN_LIMIT},
12+
loom::{hint, thread, SPIN_LIMIT},
1313
notify::Notify,
1414
synchronized::{atomic_waker::AtomicWaker, waker_list::WakerList},
1515
Queue,
@@ -429,7 +429,7 @@ where
429429
}
430430
res => return Ok(res),
431431
};
432-
std::hint::spin_loop();
432+
hint::spin_loop();
433433
}
434434
queue.notify().enqueuers.register(cx);
435435
match queue.try_enqueue(value) {
@@ -450,7 +450,7 @@ where
450450
Err(TryDequeueError::Empty | TryDequeueError::Pending) => {}
451451
res => return Some(res),
452452
}
453-
std::hint::spin_loop();
453+
hint::spin_loop();
454454
}
455455
queue.notify().dequeuer.register(cx);
456456
match queue.try_dequeue() {

src/synchronized/atomic_waker.rs

Lines changed: 127 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,159 @@
1-
use std::{mem::MaybeUninit, task};
1+
use std::task;
22

33
use crate::{
44
loom::{
5-
cell::Cell,
6-
sync::atomic::{AtomicU8, Ordering},
7-
thread,
5+
hint,
6+
sync::atomic::{AtomicUsize, Ordering},
7+
thread, LoomUnsafeCell,
88
},
99
synchronized::waker::Waker,
1010
};
1111

12-
const EMPTY: u8 = 0b000;
13-
const REGISTERING: u8 = 0b001;
14-
const REGISTERED: u8 = 0b010;
15-
const WAKING_FLAG: u8 = 0b100;
16-
12+
// I have to reimplement AtomicWaker because it doesn't use standard `Waker`,
13+
// see https://internals.rust-lang.org/t/thread-park-waker-a-waker-calling-thread-unpark/19114
1714
pub(super) struct AtomicWaker {
18-
waker: Cell<MaybeUninit<Waker>>,
19-
state: AtomicU8,
15+
state: AtomicUsize,
16+
waker: LoomUnsafeCell<Option<Waker>>,
2017
}
2118

22-
impl Default for AtomicWaker {
23-
fn default() -> Self {
19+
const WAITING: usize = 0;
20+
const REGISTERING: usize = 0b01;
21+
const WAKING: usize = 0b10;
22+
23+
impl AtomicWaker {
24+
pub fn new() -> Self {
2425
Self {
25-
waker: Cell::new(MaybeUninit::uninit()),
26-
state: AtomicU8::new(0),
26+
state: AtomicUsize::new(WAITING),
27+
waker: LoomUnsafeCell::new(None),
2728
}
2829
}
29-
}
30-
31-
// SAFETY: `AtomicWaker` waker `Cell` access is synchronized using state
32-
// (see `AtomicWaker::register`/`AtomicWaker::wake`)
33-
unsafe impl Send for AtomicWaker {}
34-
35-
// SAFETY: `AtomicWaker` waker `Cell` access is synchronized using state
36-
// (see `AtomicWaker::register`/`AtomicWaker::wake`)
37-
unsafe impl Sync for AtomicWaker {}
3830

39-
impl AtomicWaker {
4031
#[inline]
4132
pub(super) fn register(&self, cx: Option<&task::Context>) {
42-
let mut state = self.state.load(Ordering::Relaxed);
43-
loop {
44-
if state != EMPTY && state != REGISTERED {
33+
match self
34+
.state
35+
.compare_exchange(WAITING, REGISTERING, Ordering::Acquire, Ordering::Acquire)
36+
.unwrap_or_else(|x| x)
37+
{
38+
WAITING => {
39+
// SAFETY: see `futures::task::AtomicWaker` implementation
40+
unsafe {
41+
self.waker.with_mut(|w| match &mut *w {
42+
Some(old_waker) if old_waker.will_wake(cx) => (),
43+
_ => *w = Some(Waker::new(cx)),
44+
});
45+
let res = self.state.compare_exchange(
46+
REGISTERING,
47+
WAITING,
48+
Ordering::AcqRel,
49+
Ordering::Acquire,
50+
);
51+
52+
match res {
53+
Ok(_) => {}
54+
Err(actual) => {
55+
debug_assert_eq!(actual, REGISTERING | WAKING);
56+
let waker = self.waker.with_mut(|w| (*w).take()).unwrap();
57+
self.state.swap(WAITING, Ordering::AcqRel);
58+
waker.wake();
59+
}
60+
}
61+
}
62+
}
63+
WAKING => {
4564
match cx {
4665
Some(cx) => cx.waker().wake_by_ref(),
4766
None => thread::current().unpark(),
4867
}
49-
return;
68+
// SAFETY: see `AtomicWaker` implementation from tokio (it's needed by loom)
69+
hint::spin_loop();
5070
}
51-
match self.state.compare_exchange_weak(
52-
state,
53-
REGISTERING,
54-
Ordering::AcqRel,
55-
Ordering::Relaxed,
56-
) {
57-
Ok(_) => break,
58-
Err(s) => state = s,
71+
state => {
72+
debug_assert!(state == REGISTERING || state == REGISTERING | WAKING);
5973
}
6074
}
61-
let prev = self.waker.replace(MaybeUninit::new(Waker::new(cx)));
62-
if state == REGISTERED {
63-
// SAFETY: state was previously `REGISTERED`, so waker was initialized.
64-
unsafe { prev.assume_init() };
65-
}
66-
if let Err(state) = self.state.compare_exchange(
67-
REGISTERING,
68-
REGISTERED,
69-
Ordering::AcqRel,
70-
Ordering::Acquire,
71-
) {
72-
debug_assert_eq!(state, REGISTERING | WAKING_FLAG);
73-
// SAFETY: waker has been initialized a few lines above, and cannot be read
74-
// by `Self::wake` because state has not been set to `REGISTERED`.
75-
unsafe { self.waker.replace(MaybeUninit::uninit()).assume_init() }.wake();
76-
self.state.store(EMPTY, Ordering::Release);
77-
}
7875
}
7976

8077
#[inline]
8178
pub(super) fn wake(&self) {
82-
let mut state = self.state.load(Ordering::Relaxed);
83-
loop {
84-
if state != REGISTERING && state != REGISTERED {
85-
return;
79+
match self.state.fetch_or(WAKING, Ordering::AcqRel) {
80+
WAITING => {
81+
// SAFETY: see `futures::task::AtomicWaker` implementation
82+
let waker = unsafe { self.waker.with_mut(|w| (*w).take()) };
83+
self.state.fetch_and(!WAKING, Ordering::Release);
84+
if let Some(waker) = waker {
85+
waker.wake()
86+
}
8687
}
87-
match self.state.compare_exchange_weak(
88-
state,
89-
state | WAKING_FLAG,
90-
Ordering::AcqRel,
91-
Ordering::Relaxed,
92-
) {
93-
Ok(REGISTERED) => break,
94-
Ok(_) => return,
95-
Err(s) => state = s,
88+
state => {
89+
debug_assert!(
90+
state == REGISTERING || state == REGISTERING | WAKING || state == WAKING
91+
);
9692
}
9793
}
98-
// SAFETY: state was `REGISTERED`, so waker has been registered and can no more
99-
// be read/modified by `Self::register`
100-
unsafe { self.waker.replace(MaybeUninit::uninit()).assume_init() }.wake();
101-
self.state.store(EMPTY, Ordering::Release);
94+
}
95+
}
96+
97+
impl Default for AtomicWaker {
98+
fn default() -> Self {
99+
Self::new()
100+
}
101+
}
102+
103+
// SAFETY: see `futures::task::AtomicWaker` implementation
104+
unsafe impl Send for AtomicWaker {}
105+
106+
// SAFETY: see `futures::task::AtomicWaker` implementation
107+
unsafe impl Sync for AtomicWaker {}
108+
109+
#[cfg(all(test, loom))]
110+
mod tests {
111+
use std::{
112+
future::poll_fn,
113+
sync::Arc,
114+
task::Poll::{Pending, Ready},
115+
};
116+
117+
use loom::{
118+
future::block_on,
119+
sync::atomic::{AtomicUsize, Ordering},
120+
thread,
121+
};
122+
123+
use super::AtomicWaker;
124+
125+
struct Chan {
126+
num: AtomicUsize,
127+
task: AtomicWaker,
128+
}
129+
#[test]
130+
fn basic_notification() {
131+
const NUM_NOTIFY: usize = 2;
132+
133+
loom::model(|| {
134+
let chan = Arc::new(Chan {
135+
num: AtomicUsize::new(0),
136+
task: AtomicWaker::default(),
137+
});
138+
139+
for _ in 0..NUM_NOTIFY {
140+
let chan = chan.clone();
141+
142+
thread::spawn(move || {
143+
chan.num.fetch_add(1, Ordering::Relaxed);
144+
chan.task.wake();
145+
});
146+
}
147+
148+
block_on(poll_fn(move |cx| {
149+
chan.task.register(Some(cx));
150+
151+
if NUM_NOTIFY == chan.num.load(Ordering::Relaxed) {
152+
return Ready(());
153+
}
154+
155+
Pending
156+
}));
157+
});
102158
}
103159
}

src/synchronized/waker.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,23 @@ pub(super) enum Waker {
99
}
1010

1111
impl Waker {
12+
#[inline]
1213
pub(super) fn new(cx: Option<&task::Context>) -> Self {
1314
match cx {
1415
Some(cx) => Self::Async(cx.waker().clone()),
1516
None => Self::Sync(thread::current()),
1617
}
1718
}
1819

20+
#[inline]
21+
pub(super) fn will_wake(&self, cx: Option<&task::Context>) -> bool {
22+
match (self, cx) {
23+
(Self::Async(w), Some(cx)) => w.will_wake(cx.waker()),
24+
_ => false,
25+
}
26+
}
27+
28+
#[inline]
1929
pub(super) fn wake(self) {
2030
match self {
2131
Self::Async(waker) => waker.wake(),

0 commit comments

Comments
 (0)