add codec and framing to abstract encoding and decoding logic from run
diff --git a/crates/proc-macro-api/src/codec.rs b/crates/proc-macro-api/src/codec.rs
new file mode 100644
index 0000000..baccaa6
--- /dev/null
+++ b/crates/proc-macro-api/src/codec.rs
@@ -0,0 +1,12 @@
+//! Protocol codec
+
+use std::io;
+
+use serde::de::DeserializeOwned;
+
+use crate::framing::Framing;
+
+pub trait Codec: Framing {
+ fn encode<T: serde::Serialize>(msg: &T) -> io::Result<Self::Buf>;
+ fn decode<T: DeserializeOwned>(buf: &mut Self::Buf) -> io::Result<T>;
+}
diff --git a/crates/proc-macro-api/src/framing.rs b/crates/proc-macro-api/src/framing.rs
new file mode 100644
index 0000000..a1e6fc0
--- /dev/null
+++ b/crates/proc-macro-api/src/framing.rs
@@ -0,0 +1,14 @@
+//! Protocol framing
+
+use std::io::{self, BufRead, Write};
+
+pub trait Framing {
+ type Buf: Default;
+
+ fn read<'a, R: BufRead>(
+ inp: &mut R,
+ buf: &'a mut Self::Buf,
+ ) -> io::Result<Option<&'a mut Self::Buf>>;
+
+ fn write<W: Write>(out: &mut W, buf: &Self::Buf) -> io::Result<()>;
+}
diff --git a/crates/proc-macro-api/src/legacy_protocol.rs b/crates/proc-macro-api/src/legacy_protocol.rs
index 6d521d0..c2b132d 100644
--- a/crates/proc-macro-api/src/legacy_protocol.rs
+++ b/crates/proc-macro-api/src/legacy_protocol.rs
@@ -14,14 +14,15 @@
use crate::{
ProcMacro, ProcMacroKind, ServerError,
+ codec::Codec,
legacy_protocol::{
- json::{read_json, write_json},
+ json::JsonProtocol,
msg::{
ExpandMacro, ExpandMacroData, ExpnGlobals, FlatTree, Message, Request, Response,
ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
flat::serialize_span_data_index_map,
},
- postcard::{read_postcard, write_postcard},
+ postcard::PostcardProtocol,
},
process::ProcMacroServerProcess,
version,
@@ -154,42 +155,26 @@
}
if srv.use_postcard() {
- srv.send_task(send_request_postcard, req)
+ srv.send_task(send_request::<PostcardProtocol>, req)
} else {
- srv.send_task(send_request, req)
+ srv.send_task(send_request::<JsonProtocol>, req)
}
}
/// Sends a request to the server and reads the response.
-fn send_request(
+fn send_request<P: Codec>(
mut writer: &mut dyn Write,
mut reader: &mut dyn BufRead,
req: Request,
- buf: &mut String,
+ buf: &mut P::Buf,
) -> Result<Option<Response>, ServerError> {
- req.write(write_json, &mut writer).map_err(|err| ServerError {
+ req.write::<_, P>(&mut writer).map_err(|err| ServerError {
message: "failed to write request".into(),
io: Some(Arc::new(err)),
})?;
- let res = Response::read(read_json, &mut reader, buf).map_err(|err| ServerError {
+ let res = Response::read::<_, P>(&mut reader, buf).map_err(|err| ServerError {
message: "failed to read response".into(),
io: Some(Arc::new(err)),
})?;
Ok(res)
}
-
-fn send_request_postcard(
- mut writer: &mut dyn Write,
- mut reader: &mut dyn BufRead,
- req: Request,
- buf: &mut Vec<u8>,
-) -> Result<Option<Response>, ServerError> {
- req.write_postcard(write_postcard, &mut writer).map_err(|err| ServerError {
- message: "failed to write request".into(),
- io: Some(Arc::new(err)),
- })?;
- let res = Response::read_postcard(read_postcard, &mut reader, buf).map_err(|err| {
- ServerError { message: "failed to read response".into(), io: Some(Arc::new(err)) }
- })?;
- Ok(res)
-}
diff --git a/crates/proc-macro-api/src/legacy_protocol/json.rs b/crates/proc-macro-api/src/legacy_protocol/json.rs
index cf8535f..1359c05 100644
--- a/crates/proc-macro-api/src/legacy_protocol/json.rs
+++ b/crates/proc-macro-api/src/legacy_protocol/json.rs
@@ -1,36 +1,58 @@
//! Protocol functions for json.
use std::io::{self, BufRead, Write};
-/// Reads a JSON message from the input stream.
-pub fn read_json<'a>(
- inp: &mut impl BufRead,
- buf: &'a mut String,
-) -> io::Result<Option<&'a mut String>> {
- loop {
- buf.clear();
+use serde::{Serialize, de::DeserializeOwned};
- inp.read_line(buf)?;
- buf.pop(); // Remove trailing '\n'
+use crate::{codec::Codec, framing::Framing};
- if buf.is_empty() {
- return Ok(None);
+pub struct JsonProtocol;
+
+impl Framing for JsonProtocol {
+ type Buf = String;
+
+ fn read<'a, R: BufRead>(
+ inp: &mut R,
+ buf: &'a mut String,
+ ) -> io::Result<Option<&'a mut String>> {
+ loop {
+ buf.clear();
+
+ inp.read_line(buf)?;
+ buf.pop(); // Remove trailing '\n'
+
+ if buf.is_empty() {
+ return Ok(None);
+ }
+
+ // Some ill behaved macro try to use stdout for debugging
+ // We ignore it here
+ if !buf.starts_with('{') {
+ tracing::error!("proc-macro tried to print : {}", buf);
+ continue;
+ }
+
+ return Ok(Some(buf));
}
+ }
- // Some ill behaved macro try to use stdout for debugging
- // We ignore it here
- if !buf.starts_with('{') {
- tracing::error!("proc-macro tried to print : {}", buf);
- continue;
- }
-
- return Ok(Some(buf));
+ fn write<W: Write>(out: &mut W, buf: &String) -> io::Result<()> {
+ tracing::debug!("> {}", buf);
+ out.write_all(buf.as_bytes())?;
+ out.write_all(b"\n")?;
+ out.flush()
}
}
-/// Writes a JSON message to the output stream.
-pub fn write_json(out: &mut impl Write, msg: &String) -> io::Result<()> {
- tracing::debug!("> {}", msg);
- out.write_all(msg.as_bytes())?;
- out.write_all(b"\n")?;
- out.flush()
+impl Codec for JsonProtocol {
+ fn encode<T: Serialize>(msg: &T) -> io::Result<String> {
+ Ok(serde_json::to_string(msg)?)
+ }
+
+ fn decode<T: DeserializeOwned>(buf: &mut String) -> io::Result<T> {
+ let mut deserializer = serde_json::Deserializer::from_str(buf);
+ // Note that some proc-macro generate very deep syntax tree
+ // We have to disable the current limit of serde here
+ deserializer.disable_recursion_limit();
+ Ok(T::deserialize(&mut deserializer)?)
+ }
}
diff --git a/crates/proc-macro-api/src/legacy_protocol/msg.rs b/crates/proc-macro-api/src/legacy_protocol/msg.rs
index 6df1846..1c77863 100644
--- a/crates/proc-macro-api/src/legacy_protocol/msg.rs
+++ b/crates/proc-macro-api/src/legacy_protocol/msg.rs
@@ -8,10 +8,7 @@
use serde::de::DeserializeOwned;
use serde_derive::{Deserialize, Serialize};
-use crate::{
- ProcMacroKind,
- legacy_protocol::postcard::{decode_cobs, encode_cobs},
-};
+use crate::{ProcMacroKind, codec::Codec};
/// Represents requests sent from the client to the proc-macro-srv.
#[derive(Debug, Serialize, Deserialize)]
@@ -152,60 +149,21 @@
}
pub trait Message: serde::Serialize + DeserializeOwned {
- fn read<R: BufRead>(
- from_proto: ProtocolRead<R, String>,
- inp: &mut R,
- buf: &mut String,
- ) -> io::Result<Option<Self>> {
- Ok(match from_proto(inp, buf)? {
+ fn read<R: BufRead, C: Codec>(inp: &mut R, buf: &mut C::Buf) -> io::Result<Option<Self>> {
+ Ok(match C::read(inp, buf)? {
None => None,
- Some(text) => {
- let mut deserializer = serde_json::Deserializer::from_str(text);
- // Note that some proc-macro generate very deep syntax tree
- // We have to disable the current limit of serde here
- deserializer.disable_recursion_limit();
- Some(Self::deserialize(&mut deserializer)?)
- }
+ Some(buf) => C::decode(buf)?,
})
}
- fn write<W: Write>(self, to_proto: ProtocolWrite<W, String>, out: &mut W) -> io::Result<()> {
- let text = serde_json::to_string(&self)?;
- to_proto(out, &text)
- }
-
- fn read_postcard<R: BufRead>(
- from_proto: ProtocolRead<R, Vec<u8>>,
- inp: &mut R,
- buf: &mut Vec<u8>,
- ) -> io::Result<Option<Self>> {
- Ok(match from_proto(inp, buf)? {
- None => None,
- Some(buf) => Some(decode_cobs(buf)?),
- })
- }
-
- fn write_postcard<W: Write>(
- self,
- to_proto: ProtocolWrite<W, Vec<u8>>,
- out: &mut W,
- ) -> io::Result<()> {
- let buf = encode_cobs(&self)?;
- to_proto(out, &buf)
+ fn write<W: Write, C: Codec>(self, out: &mut W) -> io::Result<()> {
+ let value = C::encode(&self)?;
+ C::write(out, &value)
}
}
impl Message for Request {}
impl Message for Response {}
-/// Type alias for a function that reads protocol messages from a buffered input stream.
-#[allow(type_alias_bounds)]
-type ProtocolRead<R: BufRead, Buf> =
- for<'i, 'buf> fn(inp: &'i mut R, buf: &'buf mut Buf) -> io::Result<Option<&'buf mut Buf>>;
-/// Type alias for a function that writes protocol messages to an output stream.
-#[allow(type_alias_bounds)]
-type ProtocolWrite<W: Write, Buf> =
- for<'o, 'msg> fn(out: &'o mut W, msg: &'msg Buf) -> io::Result<()>;
-
#[cfg(test)]
mod tests {
use intern::{Symbol, sym};
diff --git a/crates/proc-macro-api/src/legacy_protocol/postcard.rs b/crates/proc-macro-api/src/legacy_protocol/postcard.rs
index 305e4de..c28a9bf 100644
--- a/crates/proc-macro-api/src/legacy_protocol/postcard.rs
+++ b/crates/proc-macro-api/src/legacy_protocol/postcard.rs
@@ -2,28 +2,39 @@
use std::io::{self, BufRead, Write};
-pub fn read_postcard<'a>(
- input: &mut impl BufRead,
- buf: &'a mut Vec<u8>,
-) -> io::Result<Option<&'a mut Vec<u8>>> {
- buf.clear();
- let n = input.read_until(0, buf)?;
- if n == 0 {
- return Ok(None);
+use serde::{Serialize, de::DeserializeOwned};
+
+use crate::{codec::Codec, framing::Framing};
+
+pub struct PostcardProtocol;
+
+impl Framing for PostcardProtocol {
+ type Buf = Vec<u8>;
+
+ fn read<'a, R: BufRead>(
+ inp: &mut R,
+ buf: &'a mut Vec<u8>,
+ ) -> io::Result<Option<&'a mut Vec<u8>>> {
+ buf.clear();
+ let n = inp.read_until(0, buf)?;
+ if n == 0 {
+ return Ok(None);
+ }
+ Ok(Some(buf))
}
- Ok(Some(buf))
+
+ fn write<W: Write>(out: &mut W, buf: &Vec<u8>) -> io::Result<()> {
+ out.write_all(buf)?;
+ out.flush()
+ }
}
-#[allow(clippy::ptr_arg)]
-pub fn write_postcard(out: &mut impl Write, msg: &Vec<u8>) -> io::Result<()> {
- out.write_all(msg)?;
- out.flush()
-}
+impl Codec for PostcardProtocol {
+ fn encode<T: Serialize>(msg: &T) -> io::Result<Vec<u8>> {
+ postcard::to_allocvec_cobs(msg).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
+ }
-pub fn encode_cobs<T: serde::Serialize>(value: &T) -> io::Result<Vec<u8>> {
- postcard::to_allocvec_cobs(value).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
-}
-
-pub fn decode_cobs<T: serde::de::DeserializeOwned>(bytes: &mut [u8]) -> io::Result<T> {
- postcard::from_bytes_cobs(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
+ fn decode<T: DeserializeOwned>(buf: &mut Self::Buf) -> io::Result<T> {
+ postcard::from_bytes_cobs(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
+ }
}
diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs
index 2cdb33f..a725b94 100644
--- a/crates/proc-macro-api/src/lib.rs
+++ b/crates/proc-macro-api/src/lib.rs
@@ -12,6 +12,8 @@
)]
#![allow(internal_features)]
+mod codec;
+mod framing;
pub mod legacy_protocol;
mod process;
@@ -19,7 +21,8 @@
use span::{ErasedFileAstId, FIXUP_ERASED_FILE_AST_ID_MARKER, Span};
use std::{fmt, io, sync::Arc, time::SystemTime};
-use crate::process::ProcMacroServerProcess;
+pub use crate::codec::Codec;
+use crate::{legacy_protocol::SpanMode, process::ProcMacroServerProcess};
/// The versions of the server protocol
pub mod version {
@@ -123,7 +126,11 @@
Item = (impl AsRef<std::ffi::OsStr>, &'a Option<impl 'a + AsRef<std::ffi::OsStr>>),
> + Clone,
) -> io::Result<ProcMacroClient> {
- let process = ProcMacroServerProcess::run(process_path, env, process::Protocol::default())?;
+ let process = ProcMacroServerProcess::run(
+ process_path,
+ env,
+ process::Protocol::Postcard { mode: SpanMode::Id },
+ )?;
Ok(ProcMacroClient { process: Arc::new(process), path: process_path.to_owned() })
}
diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs
index 7f0cd05..1365245 100644
--- a/crates/proc-macro-api/src/process.rs
+++ b/crates/proc-macro-api/src/process.rs
@@ -34,12 +34,6 @@
Postcard { mode: SpanMode },
}
-impl Default for Protocol {
- fn default() -> Self {
- Protocol::Postcard { mode: SpanMode::Id }
- }
-}
-
/// Maintains the state of the proc-macro server process.
#[derive(Debug)]
struct ProcessSrvState {
@@ -122,11 +116,10 @@
srv.version = version;
if version >= version::RUST_ANALYZER_SPAN_SUPPORT
- && let Ok(mode) = srv.enable_rust_analyzer_spans()
+ && let Ok(new_mode) = srv.enable_rust_analyzer_spans()
{
- srv.protocol = match protocol {
- Protocol::Postcard { .. } => Protocol::Postcard { mode },
- Protocol::LegacyJson { .. } => Protocol::LegacyJson { mode },
+ match &mut srv.protocol {
+ Protocol::Postcard { mode } | Protocol::LegacyJson { mode } => *mode = new_mode,
};
}
diff --git a/crates/proc-macro-srv-cli/Cargo.toml b/crates/proc-macro-srv-cli/Cargo.toml
index f6022cf..aa15389 100644
--- a/crates/proc-macro-srv-cli/Cargo.toml
+++ b/crates/proc-macro-srv-cli/Cargo.toml
@@ -18,7 +18,7 @@
clap = {version = "4.5.42", default-features = false, features = ["std"]}
[features]
-default = ["postcard"]
+default = []
sysroot-abi = ["proc-macro-srv/sysroot-abi", "proc-macro-api/sysroot-abi"]
in-rust-tree = ["proc-macro-srv/in-rust-tree", "sysroot-abi"]
diff --git a/crates/proc-macro-srv-cli/src/main_loop.rs b/crates/proc-macro-srv-cli/src/main_loop.rs
index b0e7108..029ab6e 100644
--- a/crates/proc-macro-srv-cli/src/main_loop.rs
+++ b/crates/proc-macro-srv-cli/src/main_loop.rs
@@ -2,13 +2,14 @@
use std::io;
use proc_macro_api::{
+ Codec,
legacy_protocol::{
- json::{read_json, write_json},
+ json::JsonProtocol,
msg::{
self, ExpandMacroData, ExpnGlobals, Message, SpanMode, SpanTransformer,
deserialize_span_data_index_map, serialize_span_data_index_map,
},
- postcard::{read_postcard, write_postcard},
+ postcard::PostcardProtocol,
},
version::CURRENT_API_VERSION,
};
@@ -36,12 +37,12 @@
pub(crate) fn run(format: ProtocolFormat) -> io::Result<()> {
match format {
- ProtocolFormat::Json => run_json(),
- ProtocolFormat::Postcard => run_postcard(),
+ ProtocolFormat::Json => run_::<JsonProtocol>(),
+ ProtocolFormat::Postcard => run_::<PostcardProtocol>(),
}
}
-fn run_json() -> io::Result<()> {
+fn run_<C: Codec>() -> io::Result<()> {
fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
match kind {
proc_macro_srv::ProcMacroKind::CustomDerive => {
@@ -52,9 +53,9 @@
}
}
- let mut buf = String::new();
- let mut read_request = || msg::Request::read(read_json, &mut io::stdin().lock(), &mut buf);
- let write_response = |msg: msg::Response| msg.write(write_json, &mut io::stdout().lock());
+ let mut buf = C::Buf::default();
+ let mut read_request = || msg::Request::read::<_, C>(&mut io::stdin().lock(), &mut buf);
+ let write_response = |msg: msg::Response| msg.write::<_, C>(&mut io::stdout().lock());
let env = EnvSnapshot::default();
let srv = proc_macro_srv::ProcMacroSrv::new(&env);
@@ -170,134 +171,3 @@
Ok(())
}
-
-fn run_postcard() -> io::Result<()> {
- fn macro_kind_to_api(kind: proc_macro_srv::ProcMacroKind) -> proc_macro_api::ProcMacroKind {
- match kind {
- proc_macro_srv::ProcMacroKind::CustomDerive => {
- proc_macro_api::ProcMacroKind::CustomDerive
- }
- proc_macro_srv::ProcMacroKind::Bang => proc_macro_api::ProcMacroKind::Bang,
- proc_macro_srv::ProcMacroKind::Attr => proc_macro_api::ProcMacroKind::Attr,
- }
- }
-
- let mut buf = Vec::new();
- let mut read_request =
- || msg::Request::read_postcard(read_postcard, &mut io::stdin().lock(), &mut buf);
- let write_response =
- |msg: msg::Response| msg.write_postcard(write_postcard, &mut io::stdout().lock());
-
- let env = proc_macro_srv::EnvSnapshot::default();
- let srv = proc_macro_srv::ProcMacroSrv::new(&env);
-
- let mut span_mode = msg::SpanMode::Id;
-
- while let Some(req) = read_request()? {
- let res = match req {
- msg::Request::ListMacros { dylib_path } => {
- msg::Response::ListMacros(srv.list_macros(&dylib_path).map(|macros| {
- macros.into_iter().map(|(name, kind)| (name, macro_kind_to_api(kind))).collect()
- }))
- }
- msg::Request::ExpandMacro(task) => {
- let msg::ExpandMacro {
- lib,
- env,
- current_dir,
- data:
- msg::ExpandMacroData {
- macro_body,
- macro_name,
- attributes,
- has_global_spans:
- msg::ExpnGlobals { serialize: _, def_site, call_site, mixed_site },
- span_data_table,
- },
- } = *task;
- match span_mode {
- msg::SpanMode::Id => msg::Response::ExpandMacro({
- let def_site = proc_macro_srv::SpanId(def_site as u32);
- let call_site = proc_macro_srv::SpanId(call_site as u32);
- let mixed_site = proc_macro_srv::SpanId(mixed_site as u32);
-
- let macro_body =
- macro_body.to_subtree_unresolved::<SpanTrans>(CURRENT_API_VERSION);
- let attributes = attributes
- .map(|it| it.to_subtree_unresolved::<SpanTrans>(CURRENT_API_VERSION));
-
- srv.expand(
- lib,
- &env,
- current_dir,
- ¯o_name,
- macro_body,
- attributes,
- def_site,
- call_site,
- mixed_site,
- )
- .map(|it| {
- msg::FlatTree::new_raw::<SpanTrans>(
- tt::SubtreeView::new(&it),
- CURRENT_API_VERSION,
- )
- })
- .map_err(|e| e.into_string().unwrap_or_default())
- .map_err(msg::PanicMessage)
- }),
- msg::SpanMode::RustAnalyzer => msg::Response::ExpandMacroExtended({
- let mut span_data_table =
- msg::deserialize_span_data_index_map(&span_data_table);
-
- let def_site = span_data_table[def_site];
- let call_site = span_data_table[call_site];
- let mixed_site = span_data_table[mixed_site];
-
- let macro_body =
- macro_body.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table);
- let attributes = attributes.map(|it| {
- it.to_subtree_resolved(CURRENT_API_VERSION, &span_data_table)
- });
- srv.expand(
- lib,
- &env,
- current_dir,
- ¯o_name,
- macro_body,
- attributes,
- def_site,
- call_site,
- mixed_site,
- )
- .map(|it| {
- (
- msg::FlatTree::new(
- tt::SubtreeView::new(&it),
- CURRENT_API_VERSION,
- &mut span_data_table,
- ),
- msg::serialize_span_data_index_map(&span_data_table),
- )
- })
- .map(|(tree, span_data_table)| msg::ExpandMacroExtended {
- tree,
- span_data_table,
- })
- .map_err(|e| e.into_string().unwrap_or_default())
- .map_err(msg::PanicMessage)
- }),
- }
- }
- msg::Request::ApiVersionCheck {} => msg::Response::ApiVersionCheck(CURRENT_API_VERSION),
- msg::Request::SetConfig(config) => {
- span_mode = config.span_mode;
- msg::Response::SetConfig(config)
- }
- };
-
- write_response(res)?;
- }
-
- Ok(())
-}