Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions crates/proc-macro-api/src/bidirectional_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::{
io::{self, BufRead, Write},
panic::{AssertUnwindSafe, catch_unwind},
sync::Arc,
};

Expand Down Expand Up @@ -56,9 +57,17 @@ pub fn run_conversation<C: Codec>(
return Ok(BidirectionalMessage::Response(response));
}
BidirectionalMessage::SubRequest(sr) => {
let resp = callback(sr)?;
let reply = BidirectionalMessage::SubResponse(resp);
let encoded = C::encode(&reply).map_err(wrap_encode)?;
let resp = match catch_unwind(AssertUnwindSafe(|| callback(sr))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would slightly prefer to have the callback type + UnwindSafe than this AssertUnwindSafe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tricky because the callback captures &dyn ExpandDatabase (later at implementation), which is not UnwindSafe. Requiring UnwindSafe on the callback would therefore force unwind-safety guarantees on ExpandDatabase.

Ok(Ok(resp)) => BidirectionalMessage::SubResponse(resp),
Ok(Err(err)) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
reason: err.to_string(),
}),
Err(_) => BidirectionalMessage::SubResponse(SubResponse::Cancel {
reason: "callback panicked or was cancelled".into(),
}),
};

let encoded = C::encode(&resp).map_err(wrap_encode)?;
C::write(writer, &encoded).map_err(wrap_io("failed to write sub-response"))?;
}
_ => {
Expand Down
3 changes: 3 additions & 0 deletions crates/proc-macro-api/src/bidirectional_protocol/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ pub enum SubResponse {
line: u32,
column: u32,
},
Cancel {
reason: String,
},
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
65 changes: 42 additions & 23 deletions crates/proc-macro-srv-cli/src/main_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ use proc_macro_api::{
version::CURRENT_API_VERSION,
};
use std::io;
use std::panic::panic_any;

use legacy::Message;

use proc_macro_srv::{EnvSnapshot, SpanId};
use proc_macro_srv::{EnvSnapshot, ProcMacroCancelMarker, ProcMacroClientError, SpanId};

use crate::ProtocolFormat;
struct SpanTrans;
Expand Down Expand Up @@ -170,27 +171,49 @@ impl<'a, C: Codec> ProcMacroClientHandle<'a, C> {
fn roundtrip(
&mut self,
req: bidirectional::SubRequest,
) -> Option<bidirectional::BidirectionalMessage> {
) -> Result<bidirectional::SubResponse, ProcMacroClientError> {
let msg = bidirectional::BidirectionalMessage::SubRequest(req);

if msg.write::<_, C>(&mut self.stdout.lock()).is_err() {
return None;
msg.write::<_, C>(&mut self.stdout.lock()).map_err(ProcMacroClientError::Io)?;

let msg =
bidirectional::BidirectionalMessage::read::<_, C>(&mut self.stdin.lock(), self.buf)
.map_err(ProcMacroClientError::Io)?
.ok_or(ProcMacroClientError::Eof)?;

match msg {
bidirectional::BidirectionalMessage::SubResponse(resp) => match resp {
bidirectional::SubResponse::Cancel { reason } => {
Err(ProcMacroClientError::Cancelled { reason })
}
other => Ok(other),
},
other => {
Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}")))
}
}
}
}

match bidirectional::BidirectionalMessage::read::<_, C>(&mut self.stdin.lock(), self.buf) {
Ok(Some(msg)) => Some(msg),
_ => None,
fn handle_failure(failure: Result<bidirectional::SubResponse, ProcMacroClientError>) -> ! {
match failure {
Err(ProcMacroClientError::Cancelled { reason }) => {
panic_any(ProcMacroCancelMarker { reason });
}
Err(err) => {
panic!("proc-macro IPC failed: {err:?}");
}
Ok(other) => {
panic!("unexpected SubResponse {other:?}");
}
}
}

impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_, C> {
fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String {
match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::FilePathResult { name },
)) => name,
_ => String::new(),
Ok(bidirectional::SubResponse::FilePathResult { name }) => name,
other => handle_failure(other),
}
}

Expand All @@ -204,20 +227,16 @@ impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl
start: range.start().into(),
end: range.end().into(),
}) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::SourceTextResult { text },
)) => text,
_ => None,
Ok(bidirectional::SubResponse::SourceTextResult { text }) => text,
other => handle_failure(other),
}
}

fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option<String> {
match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() })
{
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::LocalFilePathResult { name },
)) => name,
_ => None,
Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name,
other => handle_failure(other),
}
}

Expand All @@ -228,10 +247,10 @@ impl<C: Codec> proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl
ast_id: anchor.ast_id.into_raw(),
offset: range.start().into(),
}) {
Some(bidirectional::BidirectionalMessage::SubResponse(
bidirectional::SubResponse::LineColumnResult { line, column },
)) => Some((line, column)),
_ => None,
Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => {
Some((line, column))
}
other => handle_failure(other),
}
}
}
Expand Down
45 changes: 40 additions & 5 deletions crates/proc-macro-srv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ impl<'env> ProcMacroSrv<'env> {
}
}

#[derive(Debug)]
pub enum ProcMacroClientError {
Cancelled { reason: String },
Io(std::io::Error),
Protocol(String),
Eof,
}

#[derive(Debug)]
pub struct ProcMacroCancelMarker {
pub reason: String,
}

pub type ProcMacroClientHandle<'a> = &'a mut (dyn ProcMacroClientInterface + Sync + Send);

pub trait ProcMacroClientInterface {
Expand All @@ -104,6 +117,20 @@ pub trait ProcMacroClientInterface {

const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024;

pub enum ExpandError {
Panic(PanicMessage),
Cancelled { reason: Option<String> },
}

impl ExpandError {
pub fn into_string(self) -> Option<String> {
match self {
ExpandError::Panic(panic_message) => panic_message.into_string(),
ExpandError::Cancelled { reason } => reason,
}
}
}

impl ProcMacroSrv<'_> {
pub fn expand<S: ProcMacroSrvSpan>(
&self,
Expand All @@ -117,10 +144,12 @@ impl ProcMacroSrv<'_> {
call_site: S,
mixed_site: S,
callback: Option<ProcMacroClientHandle<'_>>,
) -> Result<token_stream::TokenStream<S>, PanicMessage> {
) -> Result<token_stream::TokenStream<S>, ExpandError> {
let snapped_env = self.env;
let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage {
message: Some(format!("failed to load macro: {err}")),
let expander = self.expander(lib.as_ref()).map_err(|err| {
ExpandError::Panic(PanicMessage {
message: Some(format!("failed to load macro: {err}")),
})
})?;

let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref));
Expand All @@ -138,8 +167,14 @@ impl ProcMacroSrv<'_> {
)
});
match thread.unwrap().join() {
Ok(res) => res,
Err(e) => std::panic::resume_unwind(e),
Ok(res) => res.map_err(ExpandError::Panic),

Err(payload) => {
if let Some(cancel) = payload.downcast_ref::<ProcMacroCancelMarker>() {
return Err(ExpandError::Cancelled { reason: Some(cancel.reason.clone()) });
}
std::panic::resume_unwind(payload)
}
}
});
prev_env.rollback();
Expand Down