Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions src/agent-client-protocol/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ use std::{
any::Any,
borrow::Cow,
collections::HashMap,
future::Future,
marker::PhantomData,
pin::Pin,
rc::Rc,
sync::{
Arc, Mutex,
atomic::{AtomicI64, Ordering},
},
task::{Context, Poll},
};

use agent_client_protocol_schema::{
Expand All @@ -22,7 +26,6 @@ use futures::{
},
future::LocalBoxFuture,
io::BufReader,
select_biased,
};
use serde::{Deserialize, de::DeserializeOwned};
use serde_json::value::RawValue;
Expand All @@ -43,6 +46,43 @@ struct PendingResponse {
respond: oneshot::Sender<Result<Box<dyn Any + Send>>>,
}

pub(crate) struct PendingRequest<Out> {
id: RequestId,
pending_responses: Arc<Mutex<HashMap<RequestId, PendingResponse>>>,
rx: oneshot::Receiver<Result<Box<dyn Any + Send>>>,
_marker: PhantomData<Out>,
}

impl<Out> Future for PendingRequest<Out>
where
Out: Send + 'static,
{
type Output = Result<Out>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match Pin::new(&mut this.rx).poll(cx) {
Poll::Ready(result) => {
let result = result
.map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??;
let result = result
.downcast::<Out>()
.map_err(|_| Error::internal_error().data("failed to deserialize response"))?;
Poll::Ready(Ok(*result))
}
Poll::Pending => Poll::Pending,
}
}
}

impl<Out> Unpin for PendingRequest<Out> {}

impl<Out> Drop for PendingRequest<Out> {
fn drop(&mut self) {
drop(self.pending_responses.lock().unwrap().remove(&self.id));
}
}

impl<Local, Remote> RpcConnection<Local, Remote>
where
Local: Side + 'static,
Expand Down Expand Up @@ -113,7 +153,7 @@ where
&self,
method: impl Into<Arc<str>>,
params: Option<Remote::InRequest>,
) -> Result<impl Future<Output = Result<Out>>> {
) -> Result<PendingRequest<Out>> {
let (tx, rx) = oneshot::channel();
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let id = RequestId::Number(id);
Expand Down Expand Up @@ -143,14 +183,11 @@ where
Error::internal_error().data("connection closed before request could be sent")
);
}
Ok(async move {
let result = rx
.await
.map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??
.downcast::<Out>()
.map_err(|_| Error::internal_error().data("failed to deserialize response"))?;

Ok(*result)
Ok(PendingRequest {
id,
pending_responses: self.pending_responses.clone(),
rx,
_marker: PhantomData,
})
}

Expand All @@ -167,7 +204,7 @@ where
let mut outgoing_line = Vec::new();
let mut incoming_line = String::new();
loop {
select_biased! {
futures::select! {
message = outgoing_rx.next() => {
if let Some(message) = message {
outgoing_line.clear();
Expand Down Expand Up @@ -236,7 +273,9 @@ where
pending_response.respond.send(result).ok();
}
} else {
log::error!("received response for unknown request id: {id:?}");
log::debug!(
"received response for unknown request id: {id:?} (possibly cancelled)"
);
}
} else if let Some(method) = message.method {
// Notification
Expand Down Expand Up @@ -315,6 +354,14 @@ where
}
}

#[cfg(test)]
impl<Local: Side, Remote: Side> RpcConnection<Local, Remote> {
// Test-only visibility into pending request tracking for drop cleanup assertions.
pub(crate) fn pending_response_count(&self) -> usize {
self.pending_responses.lock().unwrap().len()
}
}

#[derive(Debug, Deserialize)]
pub struct RawIncomingMessage<'a> {
id: Option<RequestId>,
Expand Down
51 changes: 51 additions & 0 deletions src/agent-client-protocol/src/rpc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,54 @@ async fn test_set_session_config_option() {
})
.await;
}

#[tokio::test]
async fn test_pending_response_cleanup_on_drop() {
struct NoopHandler;

impl MessageHandler<ClientSide> for NoopHandler {
fn handle_request(
&self,
_request: AgentRequest,
) -> impl std::future::Future<Output = Result<ClientResponse>> {
async { Err(Error::internal_error()) }
}

fn handle_notification(
&self,
_notification: AgentNotification,
) -> impl std::future::Future<Output = Result<()>> {
async { Ok(()) }
}
}

let local_set = tokio::task::LocalSet::new();
local_set
.run_until(async {
let (_client_to_agent_rx, client_to_agent_tx) = piper::pipe(1024);
let (agent_to_client_rx, _agent_to_client_tx) = piper::pipe(1024);

let (conn, _io_task) = RpcConnection::<ClientSide, AgentSide>::new(
NoopHandler,
client_to_agent_tx,
agent_to_client_rx,
|fut| {
tokio::task::spawn_local(fut);
},
);

let pending = conn
.request::<InitializeResponse>(
AGENT_METHOD_NAMES.initialize,
Some(ClientRequest::InitializeRequest(InitializeRequest::new(
ProtocolVersion::LATEST,
))),
)
.expect("request should enqueue pending response");

assert_eq!(conn.pending_response_count(), 1);
drop(pending);
assert_eq!(conn.pending_response_count(), 0);
})
.await;
}