Skip to main content

proc_macro/bridge/
rpc.rs

1//! Serialization for client-server communication.
2
3use std::any::Any;
4use std::io::Write;
5use std::num::NonZero;
6
7use super::buffer::Buffer;
8
9pub(super) trait Encode<S>: Sized {
10    fn encode(self, w: &mut Buffer, s: &mut S);
11}
12
13pub(super) trait Decode<'a, 's, S>: Sized {
14    fn decode(r: &mut &'a [u8], s: &'s mut S) -> Self;
15}
16
/// Generates `Encode` and `Decode` impls for simple types.
///
/// Accepted forms:
/// - `le $ty` — a fixed-width integer written as its little-endian bytes.
/// - `struct Name<T, ..> { field, .. }` — the listed fields, encoded and
///   decoded in the order they are listed.
/// - `enum Name<T, ..> { Variant, Variant(binding), .. }` — a one-byte tag
///   equal to the variant's position in the list, followed by the payload of a
///   single-field tuple variant (`binding` only names it inside the impls).
macro_rules! rpc_encode_decode {
    (le $ty:ty) => {
        impl<S> Encode<S> for $ty {
            #[inline]
            fn encode(self, w: &mut Buffer, _: &mut S) {
                w.extend_from_array(&self.to_le_bytes());
            }
        }

        impl<S> Decode<'_, '_, S> for $ty {
            #[inline]
            fn decode(r: &mut &[u8], _: &mut S) -> Self {
                // Fully qualified so the expansion does not rely on `size_of`
                // being in the prelude.
                const N: usize = ::std::mem::size_of::<$ty>();

                let mut bytes = [0; N];
                bytes.copy_from_slice(&r[..N]);
                *r = &r[N..];

                Self::from_le_bytes(bytes)
            }
        }
    };
    (struct $name:ident $(<$($T:ident),+>)? { $($field:ident),* $(,)? }) => {
        impl<S, $($($T: Encode<S>),+)?> Encode<S> for $name $(<$($T),+>)? {
            #[inline]
            fn encode(self, w: &mut Buffer, s: &mut S) {
                $(self.$field.encode(w, s);)*
            }
        }

        impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
            for $name $(<$($T),+>)?
        {
            #[inline]
            fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
                // Struct-literal fields are evaluated in source order, which
                // matches the order used by `encode` above.
                $name {
                    $($field: Decode::decode(r, s)),*
                }
            }
        }
    };
    // Each variant carries at most one payload binding, hence `?` rather than
    // `*`: `V(a)(b)` is rejected here instead of expanding to invalid code.
    (enum $name:ident $(<$($T:ident),+>)? { $($variant:ident $(($field:ident))?),* $(,)? }) => {
        #[allow(non_upper_case_globals, non_camel_case_types)]
        const _: () = {
            // `Tag` numbers the variants 0, 1, 2, ... in list order; each
            // variant name is then bound to its tag as a `u8` constant so it
            // can be written directly and used as a match pattern when decoding.
            #[repr(u8)] enum Tag { $($variant),* }

            $(const $variant: u8 = Tag::$variant as u8;)*

            impl<S, $($($T: Encode<S>),+)?> Encode<S> for $name $(<$($T),+>)? {
                #[inline]
                fn encode(self, w: &mut Buffer, s: &mut S) {
                    match self {
                        $($name::$variant $(($field))? => {
                            $variant.encode(w, s);
                            $($field.encode(w, s);)?
                        })*
                    }
                }
            }

            impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
                for $name $(<$($T),+>)?
            {
                #[inline]
                fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
                    match u8::decode(r, s) {
                        $($variant => {
                            $(let $field = Decode::decode(r, s);)?
                            $name::$variant $(($field))?
                        })*
                        _ => unreachable!(),
                    }
                }
            }
        };
    }
}
93
94impl<S> Encode<S> for () {
95    #[inline]
96    fn encode(self, _: &mut Buffer, _: &mut S) {}
97}
98
99impl<S> Decode<'_, '_, S> for () {
100    #[inline]
101    fn decode(_: &mut &[u8], _: &mut S) -> Self {}
102}
103
104impl<S> Encode<S> for u8 {
105    #[inline]
106    fn encode(self, w: &mut Buffer, _: &mut S) {
107        w.push(self);
108    }
109}
110
111impl<S> Decode<'_, '_, S> for u8 {
112    #[inline]
113    fn decode(r: &mut &[u8], _: &mut S) -> Self {
114        let x = r[0];
115        *r = &r[1..];
116        x
117    }
118}
119
120rpc_encode_decode!(le u32);
121#[cfg(target_pointer_width = "64")]
122rpc_encode_decode!(le usize);
123
124#[cfg(not(target_pointer_width = "64"))]
125const MAX_USIZE_SIZE: usize = 8;
126
127#[cfg(not(target_pointer_width = "64"))]
128impl<S> Encode<S> for usize {
129    #[inline]
130    fn encode(self, w: &mut Buffer, _: &mut S) {
131        const N: usize = size_of::<usize>();
132
133        // We can pad with zeros without changing the value because of
134        // little endian encoding.
135        let mut bytes = [0; MAX_USIZE_SIZE];
136        bytes[..N].copy_from_slice(&self.to_le_bytes());
137
138        w.extend_from_array(&bytes);
139    }
140}
141
142#[cfg(not(target_pointer_width = "64"))]
143impl<S> Decode<'_, '_, S> for usize {
144    #[inline]
145    fn decode(r: &mut &[u8], _: &mut S) -> Self {
146        const N: usize = size_of::<usize>();
147        const {
148            assert!(N <= MAX_USIZE_SIZE);
149        }
150
151        let mut bytes = [0; N];
152        bytes.copy_from_slice(&r[..N]);
153        *r = &r[MAX_USIZE_SIZE..];
154
155        Self::from_le_bytes(bytes)
156    }
157}
158
159impl<S> Encode<S> for bool {
160    #[inline]
161    fn encode(self, w: &mut Buffer, s: &mut S) {
162        (self as u8).encode(w, s);
163    }
164}
165
166impl<S> Decode<'_, '_, S> for bool {
167    #[inline]
168    fn decode(r: &mut &[u8], s: &mut S) -> Self {
169        match u8::decode(r, s) {
170            0 => false,
171            1 => true,
172            _ => unreachable!(),
173        }
174    }
175}
176
177impl<S> Encode<S> for NonZero<u32> {
178    #[inline]
179    fn encode(self, w: &mut Buffer, s: &mut S) {
180        self.get().encode(w, s);
181    }
182}
183
184impl<S> Decode<'_, '_, S> for NonZero<u32> {
185    #[inline]
186    fn decode(r: &mut &[u8], s: &mut S) -> Self {
187        Self::new(u32::decode(r, s)).unwrap()
188    }
189}
190
191impl<S, A: Encode<S>, B: Encode<S>> Encode<S> for (A, B) {
192    #[inline]
193    fn encode(self, w: &mut Buffer, s: &mut S) {
194        self.0.encode(w, s);
195        self.1.encode(w, s);
196    }
197}
198
199impl<'a, S, A: for<'s> Decode<'a, 's, S>, B: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S>
200    for (A, B)
201{
202    #[inline]
203    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
204        (Decode::decode(r, s), Decode::decode(r, s))
205    }
206}
207
208impl<S> Encode<S> for &str {
209    #[inline]
210    fn encode(self, w: &mut Buffer, s: &mut S) {
211        let bytes = self.as_bytes();
212        bytes.len().encode(w, s);
213        w.write_all(bytes).unwrap();
214    }
215}
216
217impl<'a, S> Decode<'a, '_, S> for &'a str {
218    #[inline]
219    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
220        let len = usize::decode(r, s);
221        let xs = &r[..len];
222        *r = &r[len..];
223        str::from_utf8(xs).unwrap()
224    }
225}
226
227impl<S> Encode<S> for String {
228    #[inline]
229    fn encode(self, w: &mut Buffer, s: &mut S) {
230        self[..].encode(w, s);
231    }
232}
233
234impl<S> Decode<'_, '_, S> for String {
235    #[inline]
236    fn decode(r: &mut &[u8], s: &mut S) -> Self {
237        <&str>::decode(r, s).to_string()
238    }
239}
240
241impl<S, T: Encode<S>> Encode<S> for Vec<T> {
242    #[inline]
243    fn encode(self, w: &mut Buffer, s: &mut S) {
244        self.len().encode(w, s);
245        for x in self {
246            x.encode(w, s);
247        }
248    }
249}
250
251impl<'a, S, T: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S> for Vec<T> {
252    #[inline]
253    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
254        let len = usize::decode(r, s);
255        let mut vec = Vec::with_capacity(len);
256        for _ in 0..len {
257            vec.push(T::decode(r, s));
258        }
259        vec
260    }
261}
262
/// Simplified version of panic payloads, ignoring
/// types other than `&'static str` and `String`.
pub enum PanicMessage {
    /// The payload was a `&'static str`.
    StaticStr(&'static str),
    /// The payload was an owned `String`.
    String(String),
    /// The payload had some other type; its contents are not kept.
    Unknown,
}

/// Classifies a panic payload, keeping its text when it is a `&'static str`
/// or a `String` (the `String` is moved out of the box, not copied). Any
/// other payload type becomes [`PanicMessage::Unknown`].
impl From<Box<dyn Any + Send>> for PanicMessage {
    fn from(payload: Box<dyn Any + Send + 'static>) -> Self {
        match payload.downcast::<&'static str>() {
            Ok(text) => PanicMessage::StaticStr(*text),
            Err(payload) => match payload.downcast::<String>() {
                Ok(owned) => PanicMessage::String(*owned),
                Err(_) => PanicMessage::Unknown,
            },
        }
    }
}

/// Rebuilds a boxed panic payload from a [`PanicMessage`].
///
/// Text variants are boxed as-is; `Unknown` becomes a box holding a private
/// marker type, so no text can be recovered from it.
impl From<PanicMessage> for Box<dyn Any + Send> {
    fn from(val: PanicMessage) -> Self {
        // Marker payload used when the original payload's type was not kept.
        struct UnknownPanicMessage;

        match val {
            PanicMessage::Unknown => Box::new(UnknownPanicMessage),
            PanicMessage::String(owned) => Box::new(owned),
            PanicMessage::StaticStr(text) => Box::new(text),
        }
    }
}

impl PanicMessage {
    /// Borrows the message text, or returns `None` for [`PanicMessage::Unknown`].
    pub fn as_str(&self) -> Option<&str> {
        match *self {
            PanicMessage::StaticStr(text) => Some(text),
            PanicMessage::String(ref owned) => Some(owned.as_str()),
            PanicMessage::Unknown => None,
        }
    }

    /// Converts the message text into an owned `String` (allocating for
    /// `StaticStr`), or returns `None` for [`PanicMessage::Unknown`].
    pub fn into_string(self) -> Option<String> {
        match self {
            PanicMessage::StaticStr(text) => Some(text.to_owned()),
            PanicMessage::String(owned) => Some(owned),
            PanicMessage::Unknown => None,
        }
    }
}
313
314impl<S> Encode<S> for PanicMessage {
315    #[inline]
316    fn encode(self, w: &mut Buffer, s: &mut S) {
317        self.as_str().encode(w, s);
318    }
319}
320
321impl<S> Decode<'_, '_, S> for PanicMessage {
322    #[inline]
323    fn decode(r: &mut &[u8], s: &mut S) -> Self {
324        match Option::<String>::decode(r, s) {
325            Some(s) => PanicMessage::String(s),
326            None => PanicMessage::Unknown,
327        }
328    }
329}