|
1 |
| -use std::{mem::MaybeUninit, task}; |
| 1 | +use std::task; |
2 | 2 |
|
3 | 3 | use crate::{
|
4 | 4 | loom::{
|
5 |
| - cell::Cell, |
6 |
| - sync::atomic::{AtomicU8, Ordering}, |
7 |
| - thread, |
| 5 | + hint, |
| 6 | + sync::atomic::{AtomicUsize, Ordering}, |
| 7 | + thread, LoomUnsafeCell, |
8 | 8 | },
|
9 | 9 | synchronized::waker::Waker,
|
10 | 10 | };
|
11 | 11 |
|
12 |
| -const EMPTY: u8 = 0b000; |
13 |
| -const REGISTERING: u8 = 0b001; |
14 |
| -const REGISTERED: u8 = 0b010; |
15 |
| -const WAKING_FLAG: u8 = 0b100; |
16 |
| - |
| 12 | +// I have to reimplement AtomicWaker because it doesn't use standard `Waker`, |
| 13 | +// see https://internals.rust-lang.org/t/thread-park-waker-a-waker-calling-thread-unpark/19114 |
17 | 14 | pub(super) struct AtomicWaker {
|
18 |
| - waker: Cell<MaybeUninit<Waker>>, |
19 |
| - state: AtomicU8, |
| 15 | + state: AtomicUsize, |
| 16 | + waker: LoomUnsafeCell<Option<Waker>>, |
20 | 17 | }
|
21 | 18 |
|
22 |
| -impl Default for AtomicWaker { |
23 |
| - fn default() -> Self { |
| 19 | +const WAITING: usize = 0; |
| 20 | +const REGISTERING: usize = 0b01; |
| 21 | +const WAKING: usize = 0b10; |
| 22 | + |
| 23 | +impl AtomicWaker { |
| 24 | + pub fn new() -> Self { |
24 | 25 | Self {
|
25 |
| - waker: Cell::new(MaybeUninit::uninit()), |
26 |
| - state: AtomicU8::new(0), |
| 26 | + state: AtomicUsize::new(WAITING), |
| 27 | + waker: LoomUnsafeCell::new(None), |
27 | 28 | }
|
28 | 29 | }
|
29 |
| -} |
30 |
| - |
31 |
| -// SAFETY: `AtomicWaker` waker `Cell` access is synchronized using state |
32 |
| -// (see `AtomicWaker::register`/`AtomicWaker::wake`) |
33 |
| -unsafe impl Send for AtomicWaker {} |
34 |
| - |
35 |
| -// SAFETY: `AtomicWaker` waker `Cell` access is synchronized using state |
36 |
| -// (see `AtomicWaker::register`/`AtomicWaker::wake`) |
37 |
| -unsafe impl Sync for AtomicWaker {} |
38 | 30 |
|
39 |
| -impl AtomicWaker { |
40 | 31 | #[inline]
|
41 | 32 | pub(super) fn register(&self, cx: Option<&task::Context>) {
|
42 |
| - let mut state = self.state.load(Ordering::Relaxed); |
43 |
| - loop { |
44 |
| - if state != EMPTY && state != REGISTERED { |
| 33 | + match self |
| 34 | + .state |
| 35 | + .compare_exchange(WAITING, REGISTERING, Ordering::Acquire, Ordering::Acquire) |
| 36 | + .unwrap_or_else(|x| x) |
| 37 | + { |
| 38 | + WAITING => { |
| 39 | + // SAFETY: see `futures::task::AtomicWaker` implementation |
| 40 | + unsafe { |
| 41 | + self.waker.with_mut(|w| match &mut *w { |
| 42 | + Some(old_waker) if old_waker.will_wake(cx) => (), |
| 43 | + _ => *w = Some(Waker::new(cx)), |
| 44 | + }); |
| 45 | + let res = self.state.compare_exchange( |
| 46 | + REGISTERING, |
| 47 | + WAITING, |
| 48 | + Ordering::AcqRel, |
| 49 | + Ordering::Acquire, |
| 50 | + ); |
| 51 | + |
| 52 | + match res { |
| 53 | + Ok(_) => {} |
| 54 | + Err(actual) => { |
| 55 | + debug_assert_eq!(actual, REGISTERING | WAKING); |
| 56 | + let waker = self.waker.with_mut(|w| (*w).take()).unwrap(); |
| 57 | + self.state.swap(WAITING, Ordering::AcqRel); |
| 58 | + waker.wake(); |
| 59 | + } |
| 60 | + } |
| 61 | + } |
| 62 | + } |
| 63 | + WAKING => { |
45 | 64 | match cx {
|
46 | 65 | Some(cx) => cx.waker().wake_by_ref(),
|
47 | 66 | None => thread::current().unpark(),
|
48 | 67 | }
|
49 |
| - return; |
| 68 | + // SAFETY: see `AtomicWaker` implementation from tokio (it's needed by loom) |
| 69 | + hint::spin_loop(); |
50 | 70 | }
|
51 |
| - match self.state.compare_exchange_weak( |
52 |
| - state, |
53 |
| - REGISTERING, |
54 |
| - Ordering::AcqRel, |
55 |
| - Ordering::Relaxed, |
56 |
| - ) { |
57 |
| - Ok(_) => break, |
58 |
| - Err(s) => state = s, |
| 71 | + state => { |
| 72 | + debug_assert!(state == REGISTERING || state == REGISTERING | WAKING); |
59 | 73 | }
|
60 | 74 | }
|
61 |
| - let prev = self.waker.replace(MaybeUninit::new(Waker::new(cx))); |
62 |
| - if state == REGISTERED { |
63 |
| - // SAFETY: state was previously `REGISTERED`, so waker was initialized. |
64 |
| - unsafe { prev.assume_init() }; |
65 |
| - } |
66 |
| - if let Err(state) = self.state.compare_exchange( |
67 |
| - REGISTERING, |
68 |
| - REGISTERED, |
69 |
| - Ordering::AcqRel, |
70 |
| - Ordering::Acquire, |
71 |
| - ) { |
72 |
| - debug_assert_eq!(state, REGISTERING | WAKING_FLAG); |
73 |
| - // SAFETY: waker has been initialized a few lines above, and cannot be read |
74 |
| - // by `Self::wake` because state has not been set to `REGISTERED`. |
75 |
| - unsafe { self.waker.replace(MaybeUninit::uninit()).assume_init() }.wake(); |
76 |
| - self.state.store(EMPTY, Ordering::Release); |
77 |
| - } |
78 | 75 | }
|
79 | 76 |
|
80 | 77 | #[inline]
|
81 | 78 | pub(super) fn wake(&self) {
|
82 |
| - let mut state = self.state.load(Ordering::Relaxed); |
83 |
| - loop { |
84 |
| - if state != REGISTERING && state != REGISTERED { |
85 |
| - return; |
| 79 | + match self.state.fetch_or(WAKING, Ordering::AcqRel) { |
| 80 | + WAITING => { |
| 81 | + // SAFETY: see `futures::task::AtomicWaker` implementation |
| 82 | + let waker = unsafe { self.waker.with_mut(|w| (*w).take()) }; |
| 83 | + self.state.fetch_and(!WAKING, Ordering::Release); |
| 84 | + if let Some(waker) = waker { |
| 85 | + waker.wake() |
| 86 | + } |
86 | 87 | }
|
87 |
| - match self.state.compare_exchange_weak( |
88 |
| - state, |
89 |
| - state | WAKING_FLAG, |
90 |
| - Ordering::AcqRel, |
91 |
| - Ordering::Relaxed, |
92 |
| - ) { |
93 |
| - Ok(REGISTERED) => break, |
94 |
| - Ok(_) => return, |
95 |
| - Err(s) => state = s, |
| 88 | + state => { |
| 89 | + debug_assert!( |
| 90 | + state == REGISTERING || state == REGISTERING | WAKING || state == WAKING |
| 91 | + ); |
96 | 92 | }
|
97 | 93 | }
|
98 |
| - // SAFETY: state was `REGISTERED`, so waker has been registered and can no more |
99 |
| - // be read/modified by `Self::register` |
100 |
| - unsafe { self.waker.replace(MaybeUninit::uninit()).assume_init() }.wake(); |
101 |
| - self.state.store(EMPTY, Ordering::Release); |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +impl Default for AtomicWaker { |
| 98 | + fn default() -> Self { |
| 99 | + Self::new() |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +// SAFETY: see `futures::task::AtomicWaker` implementation |
| 104 | +unsafe impl Send for AtomicWaker {} |
| 105 | + |
| 106 | +// SAFETY: see `futures::task::AtomicWaker` implementation |
| 107 | +unsafe impl Sync for AtomicWaker {} |
| 108 | + |
| 109 | +#[cfg(all(test, loom))] |
| 110 | +mod tests { |
| 111 | + use std::{ |
| 112 | + future::poll_fn, |
| 113 | + sync::Arc, |
| 114 | + task::Poll::{Pending, Ready}, |
| 115 | + }; |
| 116 | + |
| 117 | + use loom::{ |
| 118 | + future::block_on, |
| 119 | + sync::atomic::{AtomicUsize, Ordering}, |
| 120 | + thread, |
| 121 | + }; |
| 122 | + |
| 123 | + use super::AtomicWaker; |
| 124 | + |
| 125 | + struct Chan { |
| 126 | + num: AtomicUsize, |
| 127 | + task: AtomicWaker, |
| 128 | + } |
| 129 | + #[test] |
| 130 | + fn basic_notification() { |
| 131 | + const NUM_NOTIFY: usize = 2; |
| 132 | + |
| 133 | + loom::model(|| { |
| 134 | + let chan = Arc::new(Chan { |
| 135 | + num: AtomicUsize::new(0), |
| 136 | + task: AtomicWaker::default(), |
| 137 | + }); |
| 138 | + |
| 139 | + for _ in 0..NUM_NOTIFY { |
| 140 | + let chan = chan.clone(); |
| 141 | + |
| 142 | + thread::spawn(move || { |
| 143 | + chan.num.fetch_add(1, Ordering::Relaxed); |
| 144 | + chan.task.wake(); |
| 145 | + }); |
| 146 | + } |
| 147 | + |
| 148 | + block_on(poll_fn(move |cx| { |
| 149 | + chan.task.register(Some(cx)); |
| 150 | + |
| 151 | + if NUM_NOTIFY == chan.num.load(Ordering::Relaxed) { |
| 152 | + return Ready(()); |
| 153 | + } |
| 154 | + |
| 155 | + Pending |
| 156 | + })); |
| 157 | + }); |
102 | 158 | }
|
103 | 159 | }
|
0 commit comments