Skip to content

Commit d3bb304

Browse files
authored
Merge pull request #171 from hyperware-ai/j/spawn-join
Feature: join handle + waker
2 parents 519ed27 + fadce45 commit d3bb304

File tree

3 files changed

+94
-17
lines changed

3 files changed

+94
-17
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ repository = "https://github.com/hyperware-ai/process_lib"
99
license = "Apache-2.0"
1010

1111
[features]
12-
hyperapp = ["dep:futures-util", "dep:uuid", "logging"]
12+
hyperapp = ["dep:futures-util", "dep:futures-channel", "dep:uuid", "logging"]
1313
logging = ["dep:color-eyre", "dep:tracing", "dep:tracing-error", "dep:tracing-subscriber"]
1414
hyperwallet = ["dep:hex", "dep:sha3"]
1515
simulation-mode = []
@@ -42,6 +42,7 @@ url = "2.4.1"
4242
wit-bindgen = "0.42.1"
4343

4444
futures-util = { version = "0.3", optional = true }
45+
futures-channel = { version = "0.3", optional = true }
4546
uuid = { version = "1.0", features = ["v4"], optional = true }
4647

4748
color-eyre = { version = "0.6", features = ["capture-spantrace"], optional = true }

src/hyperapp.rs

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
use std::cell::RefCell;
2-
use std::collections::HashMap;
2+
use std::collections::{HashMap, HashSet};
33
use std::future::Future;
44
use std::pin::Pin;
5+
use std::sync::{
6+
atomic::{AtomicBool, Ordering},
7+
Arc,
8+
};
59
use std::task::{Context, Poll};
610

711
use crate::{
@@ -10,21 +14,22 @@ use crate::{
1014
logging::{error, info},
1115
set_state, timer, Address, BuildError, LazyLoadBlob, Message, Request, SendError,
1216
};
13-
use futures_util::task::noop_waker_ref;
17+
use futures_channel::{mpsc, oneshot};
18+
use futures_util::task::{waker_ref, ArcWake};
1419
use serde::{Deserialize, Serialize};
1520
use thiserror::Error;
1621
use uuid::Uuid;
1722

1823
thread_local! {
1924
static SPAWN_QUEUE: RefCell<Vec<Pin<Box<dyn Future<Output = ()>>>>> = RefCell::new(Vec::new());
2025

21-
2226
pub static APP_CONTEXT: RefCell<AppContext> = RefCell::new(AppContext {
2327
hidden_state: None,
2428
executor: Executor::new(),
2529
});
2630

2731
pub static RESPONSE_REGISTRY: RefCell<HashMap<String, Vec<u8>>> = RefCell::new(HashMap::new());
32+
pub static CANCELLED_RESPONSES: RefCell<HashSet<String>> = RefCell::new(HashSet::new());
2833

2934
pub static APP_HELPERS: RefCell<AppHelpers> = RefCell::new(AppHelpers {
3035
current_server: None,
@@ -146,10 +151,53 @@ pub struct Executor {
146151
tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
147152
}
148153

149-
pub fn spawn(fut: impl Future<Output = ()> + 'static) {
154+
struct ExecutorWakeFlag {
155+
triggered: AtomicBool,
156+
}
157+
158+
impl ExecutorWakeFlag {
159+
fn new() -> Self {
160+
Self {
161+
triggered: AtomicBool::new(false),
162+
}
163+
}
164+
165+
fn take(&self) -> bool {
166+
self.triggered.swap(false, Ordering::SeqCst)
167+
}
168+
}
169+
170+
impl ArcWake for ExecutorWakeFlag {
171+
fn wake_by_ref(arc_self: &Arc<Self>) {
172+
arc_self.triggered.store(true, Ordering::SeqCst);
173+
}
174+
}
175+
176+
pub struct JoinHandle<T> {
177+
receiver: oneshot::Receiver<T>,
178+
}
179+
180+
impl<T> Future for JoinHandle<T> {
181+
type Output = Result<T, oneshot::Canceled>;
182+
183+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
184+
let receiver = &mut self.get_mut().receiver;
185+
Pin::new(receiver).poll(cx)
186+
}
187+
}
188+
189+
pub fn spawn<T>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T>
190+
where
191+
T: 'static,
192+
{
193+
let (sender, receiver) = oneshot::channel();
150194
SPAWN_QUEUE.with(|queue| {
151-
queue.borrow_mut().push(Box::pin(fut));
152-
})
195+
queue.borrow_mut().push(Box::pin(async move {
196+
let result = fut.await;
197+
let _ = sender.send(result);
198+
}));
199+
});
200+
JoinHandle { receiver }
153201
}
154202

155203
impl Executor {
@@ -158,19 +206,24 @@ impl Executor {
158206
}
159207

160208
pub fn poll_all_tasks(&mut self) {
209+
let wake_flag = Arc::new(ExecutorWakeFlag::new());
161210
loop {
162211
// Drain any newly spawned tasks into our task list
163212
SPAWN_QUEUE.with(|queue| {
164213
self.tasks.append(&mut queue.borrow_mut());
165214
});
166215

167-
// Poll all tasks, collecting completed ones
216+
// Poll all tasks, collecting completed ones.
217+
// Put waker into context so tasks can wake the executor if needed.
168218
let mut completed = Vec::new();
169-
let mut ctx = Context::from_waker(noop_waker_ref());
219+
{
220+
let waker = waker_ref(&wake_flag);
221+
let mut ctx = Context::from_waker(&waker);
170222

171-
for i in 0..self.tasks.len() {
172-
if let Poll::Ready(()) = self.tasks[i].as_mut().poll(&mut ctx) {
173-
completed.push(i);
223+
for i in 0..self.tasks.len() {
224+
if let Poll::Ready(()) = self.tasks[i].as_mut().poll(&mut ctx) {
225+
completed.push(i);
226+
}
174227
}
175228
}
176229

@@ -181,9 +234,10 @@ impl Executor {
181234

182235
// Check if there are new tasks spawned during polling
183236
let has_new_tasks = SPAWN_QUEUE.with(|queue| !queue.borrow().is_empty());
237+
// Check if any task woke the executor that needs to be re-polled
238+
let was_woken = wake_flag.take();
184239

185-
// Continue if new tasks were spawned, otherwise we're done
186-
if !has_new_tasks {
240+
if !has_new_tasks && !was_woken {
187241
break;
188242
}
189243
}
@@ -193,6 +247,7 @@ struct ResponseFuture {
193247
correlation_id: String,
194248
// Capture HTTP context at creation time
195249
http_context: Option<HttpRequestContext>,
250+
resolved: bool,
196251
}
197252

198253
impl ResponseFuture {
@@ -204,6 +259,7 @@ impl ResponseFuture {
204259
Self {
205260
correlation_id,
206261
http_context,
262+
resolved: false,
207263
}
208264
}
209265
}
@@ -212,16 +268,18 @@ impl Future for ResponseFuture {
212268
type Output = Vec<u8>;
213269

214270
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
215-
let correlation_id = &self.correlation_id;
271+
let this = self.get_mut();
216272

217273
let maybe_bytes = RESPONSE_REGISTRY.with(|registry| {
218274
let mut registry_mut = registry.borrow_mut();
219-
registry_mut.remove(correlation_id)
275+
registry_mut.remove(&this.correlation_id)
220276
});
221277

222278
if let Some(bytes) = maybe_bytes {
279+
this.resolved = true;
280+
223281
// Restore this future's captured context
224-
if let Some(ref context) = self.http_context {
282+
if let Some(ref context) = this.http_context {
225283
APP_HELPERS.with(|helpers| {
226284
helpers.borrow_mut().current_http_context = Some(context.clone());
227285
});
@@ -234,6 +292,23 @@ impl Future for ResponseFuture {
234292
}
235293
}
236294

295+
impl Drop for ResponseFuture {
296+
fn drop(&mut self) {
297+
// We want to avoid cleaning up after successful responses
298+
if self.resolved {
299+
return;
300+
}
301+
302+
RESPONSE_REGISTRY.with(|registry| {
303+
registry.borrow_mut().remove(&self.correlation_id);
304+
});
305+
306+
CANCELLED_RESPONSES.with(|set| {
307+
set.borrow_mut().insert(self.correlation_id.clone());
308+
});
309+
}
310+
}
311+
237312
#[derive(Debug, Clone, Serialize, Deserialize, Error)]
238313
pub enum AppSendError {
239314
#[error("SendError: {0}")]

0 commit comments

Comments
 (0)