Skip to main content

proc_macro/bridge/
server.rs

1//! Server-side traits.
2
3use std::cell::Cell;
4use std::hash::Hash;
5use std::ops::{Bound, Range};
6use std::sync::atomic::AtomicU32;
7use std::sync::mpsc;
8use std::{panic, thread};
9
10use crate::bridge::{
11    ApiTags, BridgeConfig, Buffer, Decode, Diagnostic, Encode, ExpnGlobals, Literal, Mark, Marked,
12    PanicMessage, TokenTree, client, handle,
13};
14
15pub(super) struct HandleStore<S: Server> {
16    token_stream: handle::OwnedStore<MarkedTokenStream<S>>,
17    span: handle::InternedStore<MarkedSpan<S>>,
18}
19
20impl<S: Server> HandleStore<S> {
21    fn new() -> Self {
22        static TOKEN_STREAM: AtomicU32 = AtomicU32::new(1);
23        static SPAN: AtomicU32 = AtomicU32::new(1);
24
25        HandleStore {
26            token_stream: handle::OwnedStore::new(&TOKEN_STREAM),
27            span: handle::InternedStore::new(&SPAN),
28        }
29    }
30}
31
/// The server's `TokenStream` type wrapped in `Marked`, tagged with the
/// client-side `TokenStream` type.
pub(super) type MarkedTokenStream<S> = Marked<<S as Server>::TokenStream, client::TokenStream>;
/// The server's `Span` type wrapped in `Marked`, tagged with the client-side
/// `Span` type.
pub(super) type MarkedSpan<S> = Marked<<S as Server>::Span, client::Span>;
/// The server's `Symbol` type wrapped in `Marked`, tagged with the client-side
/// `Symbol` type.
pub(super) type MarkedSymbol<S> = Marked<<S as Server>::Symbol, client::Symbol>;
35
36impl<S: Server> Encode<HandleStore<S>> for MarkedTokenStream<S> {
37    fn encode(self, w: &mut Buffer, s: &mut HandleStore<S>) {
38        s.token_stream.alloc(self).encode(w, s);
39    }
40}
41
42impl<S: Server> Decode<'_, '_, HandleStore<S>> for MarkedTokenStream<S> {
43    fn decode(r: &mut &[u8], s: &mut HandleStore<S>) -> Self {
44        s.token_stream.take(handle::Handle::decode(r, &mut ()))
45    }
46}
47
48impl<'s, S: Server> Decode<'_, 's, HandleStore<S>> for &'s MarkedTokenStream<S> {
49    fn decode(r: &mut &[u8], s: &'s mut HandleStore<S>) -> Self {
50        &s.token_stream[handle::Handle::decode(r, &mut ())]
51    }
52}
53
54impl<S: Server> Encode<HandleStore<S>> for MarkedSpan<S> {
55    fn encode(self, w: &mut Buffer, s: &mut HandleStore<S>) {
56        s.span.alloc(self).encode(w, s);
57    }
58}
59
60impl<S: Server> Decode<'_, '_, HandleStore<S>> for MarkedSpan<S> {
61    fn decode(r: &mut &[u8], s: &mut HandleStore<S>) -> Self {
62        s.span.copy(handle::Handle::decode(r, &mut ()))
63    }
64}
65
66macro_rules! define_server {
67    (
68        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
69    ) => {
70        pub trait Server {
71            type TokenStream: 'static + Clone + Default;
72            type Span: 'static + Copy + Eq + Hash;
73            type Symbol: 'static;
74
75            fn globals(&mut self) -> ExpnGlobals<Self::Span>;
76
77            /// Intern a symbol received from RPC
78            fn intern_symbol(ident: &str) -> Self::Symbol;
79
80            /// Recover the string value of a symbol, and invoke a callback with it.
81            fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str));
82
83            $(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)?;)*
84        }
85    }
86}
87with_api!(define_server, Self::TokenStream, Self::Span, Self::Symbol);
88
89// FIXME(eddyb) `pub` only for `ExecutionStrategy` below.
90pub struct Dispatcher<S: Server> {
91    handle_store: HandleStore<S>,
92    server: S,
93}
94
95macro_rules! define_dispatcher {
96    (
97        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
98    ) => {
99        impl<S: Server> Dispatcher<S> {
100            fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
101                let Dispatcher { handle_store, server } = self;
102
103                let mut reader = &buf[..];
104                match ApiTags::decode(&mut reader, &mut ()) {
105                    $(ApiTags::$method => {
106                        let mut call_method = || {
107                            $(let $arg = <$arg_ty>::decode(&mut reader, handle_store).unmark();)*
108                            let r = server.$method($($arg),*);
109                            $(let r: $ret_ty = Mark::mark(r);)?
110                            r
111                        };
112                        // HACK(eddyb) don't use `panic::catch_unwind` in a panic.
113                        // If client and server happen to use the same `std`,
114                        // `catch_unwind` asserts that the panic counter was 0,
115                        // even when the closure passed to it didn't panic.
116                        let r = if thread::panicking() {
117                            Ok(call_method())
118                        } else {
119                            panic::catch_unwind(panic::AssertUnwindSafe(call_method))
120                                .map_err(PanicMessage::from)
121                        };
122
123                        buf.clear();
124                        r.encode(&mut buf, handle_store);
125                    })*
126                }
127                buf
128            }
129        }
130    }
131}
132with_api!(define_dispatcher, MarkedTokenStream<S>, MarkedSpan<S>, MarkedSymbol<S>);
133
// This trait is currently only implemented and used once, inside of this crate.
// We keep it public to allow implementing more complex execution strategies in
// the future, such as wasm proc-macros.
/// Strategy for running the client function against a `Dispatcher`.
pub trait ExecutionStrategy {
    /// Runs `run_client` against `dispatcher` and returns the resulting
    /// `Buffer`.
    ///
    /// `input` and `force_show_panics` are handed to the implementation
    /// alongside `run_client`; the `MaybeCrossThread` impl in this file
    /// forwards both into the `BridgeConfig` it passes to `run_client`.
    fn run_bridge_and_client(
        &self,
        dispatcher: &mut Dispatcher<impl Server>,
        input: Buffer,
        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
        force_show_panics: bool,
    ) -> Buffer;
}
146
thread_local! {
    /// Set for the duration of a proc-macro run on the same-thread executor.
    /// While it is set, nested proc-macro invocations (e.g. due to
    /// `TokenStream::expand_expr`) are run using a cross-thread executor.
    ///
    /// This is required as the thread-local state in the proc_macro client does
    /// not handle being re-entered, and will invalidate all `Symbol`s when
    /// entering a nested macro.
    static ALREADY_RUNNING_SAME_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Scope guard that holds `ALREADY_RUNNING_SAME_THREAD` (see also its
/// documentation) at `true` for as long as it is alive, preventing
/// same-thread reentrance.
struct RunningSameThreadGuard(());

impl RunningSameThreadGuard {
    /// Raises the flag and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the flag was already raised on this thread.
    fn new() -> Self {
        let was_running = ALREADY_RUNNING_SAME_THREAD.with(|flag| flag.replace(true));
        assert!(
            !was_running,
            "same-thread nesting (\"reentrance\") of proc macro executions is not supported"
        );
        Self(())
    }
}

impl Drop for RunningSameThreadGuard {
    /// Lowers the flag again.
    fn drop(&mut self) {
        ALREADY_RUNNING_SAME_THREAD.with(|flag| flag.set(false));
    }
}
178
/// Execution strategy that runs the client either on the calling thread or
/// on a newly spawned thread.
pub struct MaybeCrossThread {
    /// When `true`, the client is always run on a spawned thread. When
    /// `false`, the calling thread is used unless a same-thread run is
    /// already in progress on it.
    pub cross_thread: bool,
}

/// Strategy value with `cross_thread` set to `false`.
pub const SAME_THREAD: MaybeCrossThread = MaybeCrossThread { cross_thread: false };
/// Strategy value with `cross_thread` set to `true`.
pub const CROSS_THREAD: MaybeCrossThread = MaybeCrossThread { cross_thread: true };
185
186impl ExecutionStrategy for MaybeCrossThread {
187    fn run_bridge_and_client(
188        &self,
189        dispatcher: &mut Dispatcher<impl Server>,
190        input: Buffer,
191        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
192        force_show_panics: bool,
193    ) -> Buffer {
194        if self.cross_thread || ALREADY_RUNNING_SAME_THREAD.get() {
195            let (mut server, mut client) = MessagePipe::new();
196
197            let join_handle = thread::spawn(move || {
198                let mut dispatch = |b: Buffer| -> Buffer {
199                    client.send(b);
200                    client.recv().expect("server died while client waiting for reply")
201                };
202
203                run_client(BridgeConfig {
204                    input,
205                    dispatch: (&mut dispatch).into(),
206                    force_show_panics,
207                })
208            });
209
210            while let Some(b) = server.recv() {
211                server.send(dispatcher.dispatch(b));
212            }
213
214            join_handle.join().unwrap()
215        } else {
216            let _guard = RunningSameThreadGuard::new();
217
218            let mut dispatch = |buf| dispatcher.dispatch(buf);
219
220            run_client(BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics })
221        }
222    }
223}
224
/// One endpoint of a two-way message pipe, used for communicating between
/// server and client threads.
struct MessagePipe<T> {
    tx: mpsc::SyncSender<T>,
    rx: mpsc::Receiver<T>,
}

impl<T> MessagePipe<T> {
    /// Creates both endpoints of a new pipe.
    ///
    /// Two `sync_channel(1)` channels are created, one per direction; each
    /// endpoint gets the sender of one and the receiver of the other.
    fn new() -> (Self, Self) {
        let (forward_tx, forward_rx) = mpsc::sync_channel(1);
        let (backward_tx, backward_rx) = mpsc::sync_channel(1);

        let first = Self { tx: forward_tx, rx: backward_rx };
        let second = Self { tx: backward_tx, rx: forward_rx };
        (first, second)
    }

    /// Sends `value` to the other endpoint of this pipe.
    ///
    /// # Panics
    ///
    /// Panics if the underlying channel send fails.
    fn send(&mut self, value: T) {
        let sent = self.tx.send(value);
        sent.unwrap();
    }

    /// Receives a message from the other endpoint of this pipe.
    ///
    /// Returns `None` if the other end of the pipe has been destroyed, and no
    /// message was received.
    fn recv(&mut self) -> Option<T> {
        match self.rx.recv() {
            Ok(message) => Some(message),
            Err(_) => None,
        }
    }
}
252
253fn run_server<
254    S: Server,
255    I: Encode<HandleStore<S>>,
256    O: for<'a, 's> Decode<'a, 's, HandleStore<S>>,
257>(
258    strategy: &impl ExecutionStrategy,
259    server: S,
260    input: I,
261    run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
262    force_show_panics: bool,
263) -> Result<O, PanicMessage> {
264    let mut dispatcher = Dispatcher { handle_store: HandleStore::new(), server };
265
266    let globals = dispatcher.server.globals();
267
268    let mut buf = Buffer::new();
269    (<ExpnGlobals<MarkedSpan<S>> as Mark>::mark(globals), input)
270        .encode(&mut buf, &mut dispatcher.handle_store);
271
272    buf = strategy.run_bridge_and_client(&mut dispatcher, buf, run_client, force_show_panics);
273
274    Result::decode(&mut &buf[..], &mut dispatcher.handle_store)
275}
276
277impl client::Client {
278    pub fn run1<S>(
279        &self,
280        strategy: &impl ExecutionStrategy,
281        server: S,
282        input: S::TokenStream,
283        force_show_panics: bool,
284    ) -> Result<S::TokenStream, PanicMessage>
285    where
286        S: Server,
287    {
288        let client::Client { run } = *self;
289        run_server(strategy, server, <MarkedTokenStream<S>>::mark(input), run, force_show_panics)
290            .map(|s| <Option<MarkedTokenStream<S>>>::unmark(s).unwrap_or_default())
291    }
292
293    pub fn run2<S>(
294        &self,
295        strategy: &impl ExecutionStrategy,
296        server: S,
297        input: S::TokenStream,
298        input2: S::TokenStream,
299        force_show_panics: bool,
300    ) -> Result<S::TokenStream, PanicMessage>
301    where
302        S: Server,
303    {
304        let client::Client { run } = *self;
305        run_server(
306            strategy,
307            server,
308            (<MarkedTokenStream<S>>::mark(input), <MarkedTokenStream<S>>::mark(input2)),
309            run,
310            force_show_panics,
311        )
312        .map(|s| <Option<MarkedTokenStream<S>>>::unmark(s).unwrap_or_default())
313    }
314}