diff --git a/crates/proc-macro-api/src/bidirectional_protocol.rs b/crates/proc-macro-api/src/bidirectional_protocol.rs index e44723a6a389..062de20c26e9 100644 --- a/crates/proc-macro-api/src/bidirectional_protocol.rs +++ b/crates/proc-macro-api/src/bidirectional_protocol.rs @@ -2,6 +2,7 @@ use std::{ io::{self, BufRead, Write}, + panic::{AssertUnwindSafe, catch_unwind}, sync::Arc, }; @@ -56,9 +57,19 @@ pub fn run_conversation( return Ok(BidirectionalMessage::Response(response)); } BidirectionalMessage::SubRequest(sr) => { - let resp = callback(sr)?; - let reply = BidirectionalMessage::SubResponse(resp); - let encoded = C::encode(&reply).map_err(wrap_encode)?; + // TODO: Avoid `AssertUnwindSafe` by making the callback `UnwindSafe` once `ExpandDatabase` + // becomes unwind-safe (currently blocked by `parking_lot::RwLock` in the VFS). + let resp = match catch_unwind(AssertUnwindSafe(|| callback(sr))) { + Ok(Ok(resp)) => BidirectionalMessage::SubResponse(resp), + Ok(Err(err)) => BidirectionalMessage::SubResponse(SubResponse::Cancel { + reason: err.to_string(), + }), + Err(_) => BidirectionalMessage::SubResponse(SubResponse::Cancel { + reason: "callback panicked or was cancelled".into(), + }), + }; + + let encoded = C::encode(&resp).map_err(wrap_encode)?; C::write(writer, &encoded).map_err(wrap_io("failed to write sub-response"))?; } _ => { diff --git a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs index 57e7b1ee8f68..f93d3cbfbc0b 100644 --- a/crates/proc-macro-api/src/bidirectional_protocol/msg.rs +++ b/crates/proc-macro-api/src/bidirectional_protocol/msg.rs @@ -38,6 +38,9 @@ pub enum SubResponse { ByteRangeResult { range: Range, }, + Cancel { + reason: String, + }, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs index 3beaeb0697ec..d7116b70c267 100644 --- a/crates/proc-macro-srv-cli/src/main_loop.rs +++ b/crates/proc-macro-srv-cli/src/main_loop.rs @@ -6,6 +6,7 @@ use proc_macro_api::{ transport::codec::{json::JsonProtocol, postcard::PostcardProtocol}, version::CURRENT_API_VERSION, }; +use std::panic::panic_any; use std::{ io::{self, BufRead, Write}, ops::Range, @@ -13,7 +14,7 @@ use std::{ use legacy::Message; -use proc_macro_srv::{EnvSnapshot, SpanId}; +use proc_macro_srv::{EnvSnapshot, ProcMacroClientError, ProcMacroPanicMarker, SpanId}; use crate::ProtocolFormat; struct SpanTrans; @@ -179,16 +180,43 @@ impl<'a, C: Codec> ProcMacroClientHandle<'a, C> { fn roundtrip( &mut self, req: bidirectional::SubRequest, - ) -> Option { + ) -> Result { let msg = bidirectional::BidirectionalMessage::SubRequest(req); - if msg.write::(&mut *self.stdout).is_err() { - return None; + msg.write::(&mut *self.stdout).map_err(ProcMacroClientError::Io)?; + + let msg = bidirectional::BidirectionalMessage::read::(&mut *self.stdin, self.buf) + .map_err(ProcMacroClientError::Io)? + .ok_or(ProcMacroClientError::Eof)?; + + match msg { + bidirectional::BidirectionalMessage::SubResponse(resp) => match resp { + bidirectional::SubResponse::Cancel { reason } => { + Err(ProcMacroClientError::Cancelled { reason }) + } + other => Ok(other), + }, + other => { + Err(ProcMacroClientError::Protocol(format!("expected SubResponse, got {other:?}"))) + } } + } +} - match bidirectional::BidirectionalMessage::read::(&mut *self.stdin, self.buf) { - Ok(Some(msg)) => Some(msg), - _ => None, +fn handle_failure(failure: Result) -> ! { + match failure { + Err(ProcMacroClientError::Cancelled { reason }) => { + panic_any(ProcMacroPanicMarker::Cancelled { reason }); + } + Err(err) => { + panic_any(ProcMacroPanicMarker::Internal { + reason: format!("proc-macro IPC error: {err:?}"), + }); + } + Ok(other) => { + panic_any(ProcMacroPanicMarker::Internal { + reason: format!("unexpected SubResponse {other:?}"), + }); } } } @@ -196,10 +224,8 @@ impl<'a, C: Codec> ProcMacroClientHandle<'a, C> { impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandle<'_, C> { fn file(&mut self, file_id: proc_macro_srv::span::FileId) -> String { match self.roundtrip(bidirectional::SubRequest::FilePath { file_id: file_id.index() }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::FilePathResult { name }, - )) => name, - _ => String::new(), + Ok(bidirectional::SubResponse::FilePathResult { name }) => name, + other => handle_failure(other), } } @@ -213,20 +239,16 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl start: range.start().into(), end: range.end().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::SourceTextResult { text }, - )) => text, - _ => None, + Ok(bidirectional::SubResponse::SourceTextResult { text }) => text, + other => handle_failure(other), } } fn local_file(&mut self, file_id: proc_macro_srv::span::FileId) -> Option { match self.roundtrip(bidirectional::SubRequest::LocalFilePath { file_id: file_id.index() }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::LocalFilePathResult { name }, - )) => name, - _ => None, + Ok(bidirectional::SubResponse::LocalFilePathResult { name }) => name, + other => handle_failure(other), } } @@ -237,10 +259,10 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl ast_id: anchor.ast_id.into_raw(), offset: range.start().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::LineColumnResult { line, column }, - )) => Some((line, column)), - _ => None, + Ok(bidirectional::SubResponse::LineColumnResult { line, column }) => { + Some((line, column)) + } + other => handle_failure(other), } } @@ -254,10 +276,8 @@ impl proc_macro_srv::ProcMacroClientInterface for ProcMacroClientHandl start: range.start().into(), end: range.end().into(), }) { - Some(bidirectional::BidirectionalMessage::SubResponse( - bidirectional::SubResponse::ByteRangeResult { range }, - )) => range, - _ => Range { start: range.start().into(), end: range.end().into() }, + Ok(bidirectional::SubResponse::ByteRangeResult { range }) => range, + other => handle_failure(other), } } } diff --git a/crates/proc-macro-srv/src/lib.rs b/crates/proc-macro-srv/src/lib.rs index ac9f89352c2b..fc905475ed12 100644 --- a/crates/proc-macro-srv/src/lib.rs +++ b/crates/proc-macro-srv/src/lib.rs @@ -93,6 +93,20 @@ impl<'env> ProcMacroSrv<'env> { } } +#[derive(Debug)] +pub enum ProcMacroClientError { + Cancelled { reason: String }, + Io(std::io::Error), + Protocol(String), + Eof, +} + +#[derive(Debug)] +pub enum ProcMacroPanicMarker { + Cancelled { reason: String }, + Internal { reason: String }, +} + pub type ProcMacroClientHandle<'a> = &'a mut (dyn ProcMacroClientInterface + Sync + Send); pub trait ProcMacroClientInterface { @@ -107,6 +121,22 @@ pub trait ProcMacroClientInterface { const EXPANDER_STACK_SIZE: usize = 8 * 1024 * 1024; +pub enum ExpandError { + Panic(PanicMessage), + Cancelled { reason: Option }, + Internal { reason: Option }, +} + +impl ExpandError { + pub fn into_string(self) -> Option { + match self { + ExpandError::Panic(panic_message) => panic_message.into_string(), + ExpandError::Cancelled { reason } => reason, + ExpandError::Internal { reason } => reason, + } + } +} + impl ProcMacroSrv<'_> { pub fn expand( &self, @@ -120,10 +150,10 @@ impl ProcMacroSrv<'_> { call_site: S, mixed_site: S, callback: Option>, - ) -> Result, PanicMessage> { + ) -> Result, ExpandError> { let snapped_env = self.env; - let expander = self.expander(lib.as_ref()).map_err(|err| PanicMessage { - message: Some(format!("failed to load macro: {err}")), + let expander = self.expander(lib.as_ref()).map_err(|err| ExpandError::Internal { + reason: Some(format!("failed to load macro: {err}")), })?; let prev_env = EnvChange::apply(snapped_env, env, current_dir.as_ref().map(<_>::as_ref)); @@ -141,8 +171,22 @@ impl ProcMacroSrv<'_> { ) }); match thread.unwrap().join() { - Ok(res) => res, - Err(e) => std::panic::resume_unwind(e), + Ok(res) => res.map_err(ExpandError::Panic), + + Err(payload) => { + if let Some(marker) = payload.downcast_ref::() { + return match marker { + ProcMacroPanicMarker::Cancelled { reason } => { + Err(ExpandError::Cancelled { reason: Some(reason.clone()) }) + } + ProcMacroPanicMarker::Internal { reason } => { + Err(ExpandError::Internal { reason: Some(reason.clone()) }) + } + }; + } + + std::panic::resume_unwind(payload) + } } }); prev_env.rollback();