Skip to content

Commit c00f142

Browse files
authored
fix: remove thread locals to fix unsoundness outlined in #1
2 parents 7abf348 + fd26250 commit c00f142

File tree

6 files changed

+171
-165
lines changed

6 files changed

+171
-165
lines changed

.github/workflows/test.yml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,3 @@ jobs:
2828
- uses: Swatinem/rust-cache@v1
2929
- name: Run tests
3030
run: cargo test
31-
test-no-std:
32-
name: Run tests (unstable-thread-local)
33-
runs-on: ubuntu-24.04
34-
steps:
35-
- uses: actions/checkout@v4
36-
- name: Install nightly Rust toolchain
37-
uses: dtolnay/rust-toolchain@nightly
38-
with:
39-
toolchain: nightly-2024-11-11
40-
- uses: Swatinem/rust-cache@v1
41-
- name: Run tests
42-
run: cargo test --features unstable-thread-local

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,3 @@ tokio = { version = "1", features = [ "full", "macros" ] }
1717
futures = "0.3"
1818

1919
[features]
20-
unstable-thread-local = []

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ use futures_util::{pin_mut, stream::StreamExt};
88

99
#[tokio::main]
1010
async fn main() {
11-
let stream = async_stream(|r#yield| async move {
11+
let stream = async_stream(|yielder| async move {
1212
for i in 0..3 {
13-
r#yield(i).await;
13+
yielder.r#yield(i).await;
1414
}
1515
});
1616
pin_mut!(stream);
@@ -21,4 +21,4 @@ async fn main() {
2121
```
2222

2323
## `#![no_std]` support
24-
`async-stream-lite` supports `#![no_std]` on nightly Rust (due to the usage of [the unstable `#[thread_local]` attribute](https://doc.rust-lang.org/beta/unstable-book/language-features/thread-local.html)). To enable `#![no_std]` support, enable the `unstable-thread-local` feature.
24+
`async-stream-lite` supports `#![no_std]`, but requires `alloc`.

src/lib.rs

Lines changed: 79 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
#![allow(clippy::tabs_in_doc_comments)]
2-
#![cfg_attr(feature = "unstable-thread-local", feature(thread_local))]
3-
#![cfg_attr(all(not(test), feature = "unstable-thread-local"), no_std)]
2+
#![cfg_attr(not(test), no_std)]
43

4+
extern crate alloc;
55
extern crate core;
66

7+
use alloc::sync::{Arc, Weak};
78
use core::{
89
cell::Cell,
910
future::Future,
1011
marker::PhantomData,
1112
pin::Pin,
12-
ptr,
13+
sync::atomic::{AtomicBool, Ordering},
1314
task::{Context, Poll}
1415
};
1516

@@ -19,91 +20,96 @@ use futures_core::stream::{FusedStream, Stream};
1920
mod tests;
2021
mod r#try;
2122

22-
#[cfg(not(feature = "unstable-thread-local"))]
23-
thread_local! {
24-
static STORE: Cell<*mut ()> = const { Cell::new(ptr::null_mut()) };
23+
pub(crate) struct SharedStore<T> {
24+
entered: AtomicBool,
25+
cell: Cell<Option<T>>
2526
}
26-
#[cfg(feature = "unstable-thread-local")]
27-
#[thread_local]
28-
static STORE: Cell<*mut ()> = Cell::new(ptr::null_mut());
2927

30-
pub(crate) fn r#yield<T>(value: T) -> YieldFut<T> {
31-
YieldFut { value: Some(value) }
28+
impl<T> Default for SharedStore<T> {
29+
fn default() -> Self {
30+
Self {
31+
entered: AtomicBool::new(false),
32+
cell: Cell::new(None)
33+
}
34+
}
35+
}
36+
37+
impl<T> SharedStore<T> {
38+
pub fn has_value(&self) -> bool {
39+
unsafe { &*self.cell.as_ptr() }.is_some()
40+
}
41+
}
42+
43+
unsafe impl<T> Sync for SharedStore<T> {}
44+
45+
pub struct Yielder<T> {
46+
pub(crate) store: Weak<SharedStore<T>>
47+
}
48+
49+
impl<T> Yielder<T> {
50+
pub fn r#yield(&self, value: T) -> YieldFut<'_, T> {
51+
#[cold]
52+
fn invalid_usage() -> ! {
53+
panic!("attempted to use async_stream_lite yielder outside of stream context or across threads")
54+
}
55+
56+
let Some(store) = self.store.upgrade() else {
57+
invalid_usage();
58+
};
59+
if !store.entered.load(Ordering::Relaxed) {
60+
invalid_usage();
61+
}
62+
63+
store.cell.replace(Some(value));
64+
65+
YieldFut { store, _p: PhantomData }
66+
}
3267
}
3368

3469
/// Future returned by an [`AsyncStream`]'s yield function.
3570
///
3671
/// This future must be `.await`ed inside the generator in order for the item to be yielded by the stream.
3772
#[must_use = "stream will not yield this item unless the future returned by yield is awaited"]
38-
pub struct YieldFut<T> {
39-
value: Option<T>
73+
pub struct YieldFut<'y, T> {
74+
store: Arc<SharedStore<T>>,
75+
_p: PhantomData<&'y ()>
4076
}
4177

42-
impl<T> Unpin for YieldFut<T> {}
78+
impl<T> Unpin for YieldFut<'_, T> {}
4379

44-
impl<T> Future for YieldFut<T> {
80+
impl<T> Future for YieldFut<'_, T> {
4581
type Output = ();
4682

47-
fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
48-
if self.value.is_none() {
83+
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
84+
if !self.store.has_value() {
4985
return Poll::Ready(());
5086
}
5187

52-
fn op<T>(cell: &Cell<*mut ()>, value: &mut Option<T>) {
53-
let ptr = cell.get().cast::<Option<T>>();
54-
let option_ref = unsafe { ptr.as_mut() }.expect("attempted to use async_stream yielder outside of stream context or across threads");
55-
if option_ref.is_none() {
56-
*option_ref = value.take();
57-
}
58-
}
59-
60-
#[cfg(not(feature = "unstable-thread-local"))]
61-
return STORE.with(|cell| {
62-
op(cell, &mut self.value);
63-
Poll::Pending
64-
});
65-
#[cfg(feature = "unstable-thread-local")]
66-
{
67-
op(&STORE, &mut self.value);
68-
Poll::Pending
69-
}
88+
Poll::Pending
7089
}
7190
}
7291

73-
struct Enter<'a, T> {
74-
_p: PhantomData<&'a T>,
75-
prev: *mut ()
92+
struct Enter<'s, T> {
93+
store: &'s SharedStore<T>
7694
}
7795

78-
fn enter<T>(dst: &mut Option<T>) -> Enter<'_, T> {
79-
fn op<T>(cell: &Cell<*mut ()>, dst: &mut Option<T>) -> *mut () {
80-
let prev = cell.get();
81-
cell.set((dst as *mut Option<T>).cast::<()>());
82-
prev
83-
}
84-
#[cfg(not(feature = "unstable-thread-local"))]
85-
let prev = STORE.with(|cell| op(cell, dst));
86-
#[cfg(feature = "unstable-thread-local")]
87-
let prev = op(&STORE, dst);
88-
Enter { _p: PhantomData, prev }
96+
fn enter<T>(store: &SharedStore<T>) -> Enter<'_, T> {
97+
store.entered.store(true, Ordering::Relaxed);
98+
Enter { store }
8999
}
90100

91101
impl<T> Drop for Enter<'_, T> {
92102
fn drop(&mut self) {
93-
#[cfg(not(feature = "unstable-thread-local"))]
94-
STORE.with(|cell| cell.set(self.prev));
95-
#[cfg(feature = "unstable-thread-local")]
96-
STORE.set(self.prev);
103+
self.store.entered.store(false, Ordering::Relaxed);
97104
}
98105
}
99106

100107
pin_project_lite::pin_project! {
101108
/// A [`Stream`] created from an asynchronous generator-like function.
102109
///
103110
/// To create an [`AsyncStream`], use the [`async_stream`] function.
104-
#[derive(Debug)]
105111
pub struct AsyncStream<T, U> {
106-
_p: PhantomData<T>,
112+
store: Arc<SharedStore<T>>,
107113
done: bool,
108114
#[pin]
109115
generator: U
@@ -131,16 +137,15 @@ where
131137
return Poll::Ready(None);
132138
}
133139

134-
let mut dst = None;
135140
let res = {
136-
let _enter = enter(&mut dst);
141+
let _enter = enter(&me.store);
137142
me.generator.poll(cx)
138143
};
139144

140145
*me.done = res.is_ready();
141146

142-
if dst.is_some() {
143-
return Poll::Ready(dst.take());
147+
if me.store.has_value() {
148+
return Poll::Ready(me.store.cell.take());
144149
}
145150

146151
if *me.done { Poll::Ready(None) } else { Poll::Pending }
@@ -153,16 +158,16 @@ where
153158

154159
/// Create an asynchronous [`Stream`] from an asynchronous generator function.
155160
///
156-
/// The provided function will be given a "yielder" function, which, when called, causes the stream to yield an item:
161+
/// The provided function will be given a [`Yielder`], which, when called, causes the stream to yield an item:
157162
/// ```
158163
/// use async_stream_lite::async_stream;
159164
/// use futures::{pin_mut, stream::StreamExt};
160165
///
161166
/// #[tokio::main]
162167
/// async fn main() {
163-
/// let stream = async_stream(|r#yield| async move {
168+
/// let stream = async_stream(|yielder| async move {
164169
/// for i in 0..3 {
165-
/// r#yield(i).await;
170+
/// yielder.r#yield(i).await;
166171
/// }
167172
/// });
168173
/// pin_mut!(stream);
@@ -181,9 +186,9 @@ where
181186
/// };
182187
///
183188
/// fn zero_to_three() -> impl Stream<Item = u32> {
184-
/// async_stream(|r#yield| async move {
189+
/// async_stream(|yielder| async move {
185190
/// for i in 0..3 {
186-
/// r#yield(i).await;
191+
/// yielder.r#yield(i).await;
187192
/// }
188193
/// })
189194
/// }
@@ -207,9 +212,9 @@ where
207212
/// };
208213
///
209214
/// fn zero_to_three() -> BoxStream<'static, u32> {
210-
/// Box::pin(async_stream(|r#yield| async move {
215+
/// Box::pin(async_stream(|yielder| async move {
211216
/// for i in 0..3 {
212-
/// r#yield(i).await;
217+
/// yielder.r#yield(i).await;
213218
/// }
214219
/// }))
215220
/// }
@@ -232,18 +237,18 @@ where
232237
/// };
233238
///
234239
/// fn zero_to_three() -> impl Stream<Item = u32> {
235-
/// async_stream(|r#yield| async move {
240+
/// async_stream(|yielder| async move {
236241
/// for i in 0..3 {
237-
/// r#yield(i).await;
242+
/// yielder.r#yield(i).await;
238243
/// }
239244
/// })
240245
/// }
241246
///
242247
/// fn double<S: Stream<Item = u32>>(input: S) -> impl Stream<Item = u32> {
243-
/// async_stream(|r#yield| async move {
248+
/// async_stream(|yielder| async move {
244249
/// pin_mut!(input);
245250
/// while let Some(value) = input.next().await {
246-
/// r#yield(value * 2).await;
251+
/// yielder.r#yield(value * 2).await;
247252
/// }
248253
/// })
249254
/// }
@@ -261,15 +266,12 @@ where
261266
/// See also [`try_async_stream`], a variant of [`async_stream`] which supports try notation (`?`).
262267
pub fn async_stream<T, F, U>(generator: F) -> AsyncStream<T, U>
263268
where
264-
F: FnOnce(fn(value: T) -> YieldFut<T>) -> U,
269+
F: FnOnce(Yielder<T>) -> U,
265270
U: Future<Output = ()>
266271
{
267-
let generator = generator(r#yield::<T>);
268-
AsyncStream {
269-
_p: PhantomData,
270-
done: false,
271-
generator
272-
}
272+
let store = Arc::new(SharedStore::default());
273+
let generator = generator(Yielder { store: Arc::downgrade(&store) });
274+
AsyncStream { store, done: false, generator }
273275
}
274276

275277
pub use self::r#try::{TryAsyncStream, try_async_stream};

0 commit comments

Comments
 (0)