Skip to content

Commit 876b0c2

Browse files
committed
implement join handle for tasks and implement waker
1 parent b21dd64 commit 876b0c2

File tree

3 files changed

+68
-13
lines changed

3 files changed

+68
-13
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ repository = "https://github.com/hyperware-ai/process_lib"
99
license = "Apache-2.0"
1010

1111
[features]
12-
hyperapp = ["dep:futures-util", "dep:uuid", "logging"]
12+
hyperapp = ["dep:futures-util", "dep:futures-channel", "dep:uuid", "logging"]
1313
logging = ["dep:color-eyre", "dep:tracing", "dep:tracing-error", "dep:tracing-subscriber"]
1414
hyperwallet = ["dep:hex", "dep:sha3"]
1515
simulation-mode = []
@@ -42,6 +42,7 @@ url = "2.4.1"
4242
wit-bindgen = "0.42.1"
4343

4444
futures-util = { version = "0.3", optional = true }
45+
futures-channel = { version = "0.3", optional = true }
4546
uuid = { version = "1.0", features = ["v4"], optional = true }
4647

4748
color-eyre = { version = "0.6", features = ["capture-spantrace"], optional = true }

src/hyperapp.rs

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ use std::cell::RefCell;
22
use std::collections::HashMap;
33
use std::future::Future;
44
use std::pin::Pin;
5+
use std::sync::{
6+
atomic::{AtomicBool, Ordering},
7+
Arc,
8+
};
59
use std::task::{Context, Poll};
610

711
use crate::{
@@ -10,15 +14,15 @@ use crate::{
1014
logging::{error, info},
1115
set_state, timer, Address, BuildError, LazyLoadBlob, Message, Request, SendError,
1216
};
13-
use futures_util::task::noop_waker_ref;
17+
use futures_util::task::{waker_ref, ArcWake};
18+
use futures_channel::{mpsc, oneshot};
1419
use serde::{Deserialize, Serialize};
1520
use thiserror::Error;
1621
use uuid::Uuid;
1722

1823
thread_local! {
1924
static SPAWN_QUEUE: RefCell<Vec<Pin<Box<dyn Future<Output = ()>>>>> = RefCell::new(Vec::new());
2025

21-
2226
pub static APP_CONTEXT: RefCell<AppContext> = RefCell::new(AppContext {
2327
hidden_state: None,
2428
executor: Executor::new(),
@@ -146,10 +150,53 @@ pub struct Executor {
146150
tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
147151
}
148152

149-
pub fn spawn(fut: impl Future<Output = ()> + 'static) {
153+
struct ExecutorWakeFlag {
154+
triggered: AtomicBool,
155+
}
156+
157+
impl ExecutorWakeFlag {
158+
fn new() -> Self {
159+
Self {
160+
triggered: AtomicBool::new(false),
161+
}
162+
}
163+
164+
fn take(&self) -> bool {
165+
self.triggered.swap(false, Ordering::SeqCst)
166+
}
167+
}
168+
169+
impl ArcWake for ExecutorWakeFlag {
170+
fn wake_by_ref(arc_self: &Arc<Self>) {
171+
arc_self.triggered.store(true, Ordering::SeqCst);
172+
}
173+
}
174+
175+
pub struct JoinHandle<T> {
176+
receiver: oneshot::Receiver<T>,
177+
}
178+
179+
impl<T> Future for JoinHandle<T> {
180+
type Output = Result<T, oneshot::Canceled>;
181+
182+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
183+
let receiver = &mut self.get_mut().receiver;
184+
Pin::new(receiver).poll(cx)
185+
}
186+
}
187+
188+
pub fn spawn<T>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T>
189+
where
190+
T: 'static,
191+
{
192+
let (sender, receiver) = oneshot::channel();
150193
SPAWN_QUEUE.with(|queue| {
151-
queue.borrow_mut().push(Box::pin(fut));
152-
})
194+
queue.borrow_mut().push(Box::pin(async move {
195+
let result = fut.await;
196+
let _ = sender.send(result);
197+
}));
198+
});
199+
JoinHandle { receiver }
153200
}
154201

155202
impl Executor {
@@ -158,19 +205,24 @@ impl Executor {
158205
}
159206

160207
pub fn poll_all_tasks(&mut self) {
208+
let wake_flag = Arc::new(ExecutorWakeFlag::new());
161209
loop {
162210
// Drain any newly spawned tasks into our task list
163211
SPAWN_QUEUE.with(|queue| {
164212
self.tasks.append(&mut queue.borrow_mut());
165213
});
166214

167-
// Poll all tasks, collecting completed ones
215+
// Poll all tasks, collecting completed ones.
216+
// Put waker into context so tasks can wake the executor if needed.
168217
let mut completed = Vec::new();
169-
let mut ctx = Context::from_waker(noop_waker_ref());
218+
{
219+
let waker = waker_ref(&wake_flag);
220+
let mut ctx = Context::from_waker(&waker);
170221

171-
for i in 0..self.tasks.len() {
172-
if let Poll::Ready(()) = self.tasks[i].as_mut().poll(&mut ctx) {
173-
completed.push(i);
222+
for i in 0..self.tasks.len() {
223+
if let Poll::Ready(()) = self.tasks[i].as_mut().poll(&mut ctx) {
224+
completed.push(i);
225+
}
174226
}
175227
}
176228

@@ -181,9 +233,10 @@ impl Executor {
181233

182234
// Check if there are new tasks spawned during polling
183235
let has_new_tasks = SPAWN_QUEUE.with(|queue| !queue.borrow().is_empty());
236+
// Check if any task woke the executor that needs to be re-polled
237+
let was_woken = wake_flag.take();
184238

185-
// Continue if new tasks were spawned, otherwise we're done
186-
if !has_new_tasks {
239+
if !has_new_tasks && !was_woken {
187240
break;
188241
}
189242
}

0 commit comments

Comments
 (0)