// rustc_expand/proc_macro.rs

1use rustc_ast as ast;
2use rustc_ast::tokenstream::TokenStream;
3use rustc_data_structures::profiling::TimingGuard;
4use rustc_errors::ErrorGuaranteed;
5use rustc_middle::ty::{self, TyCtxt};
6use rustc_parse::parser::{AllowConstBlockItems, ForceCollect, Parser};
7use rustc_proc_macro as pm;
8use rustc_session::Session;
9use rustc_session::config::ProcMacroExecutionStrategy;
10use rustc_span::profiling::SpannedEventArgRecorder;
11use rustc_span::{LocalExpnId, Span};
12
13use crate::base::{self, *};
14use crate::{diagnostics, proc_macro_server};
15
16fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
17    pm::bridge::server::MaybeCrossThread {
18        cross_thread: sess.opts.unstable_opts.proc_macro_execution_strategy
19            == ProcMacroExecutionStrategy::CrossThread,
20    }
21}
22
23fn record_expand_proc_macro<'a>(
24    ecx: &ExtCtxt<'a>,
25    name: &'static str,
26    span: Span,
27) -> TimingGuard<'a> {
28    ecx.sess.prof.generic_activity_with_arg_recorder(name, |recorder| {
29        recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
30    })
31}
32
33pub struct BangProcMacro {
34    pub client: pm::bridge::client::Client,
35}
36
37impl base::BangProcMacro for BangProcMacro {
38    fn expand(
39        &self,
40        ecx: &mut ExtCtxt<'_>,
41        span: Span,
42        input: TokenStream,
43    ) -> Result<TokenStream, ErrorGuaranteed> {
44        let _timer = record_expand_proc_macro(ecx, "expand_proc_macro", span);
45
46        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
47        let strategy = exec_strategy(ecx.sess);
48        let server = proc_macro_server::Rustc::new(ecx);
49        self.client.run1(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
50            ecx.dcx().emit_err(diagnostics::ProcMacroPanicked {
51                span,
52                message: e
53                    .into_string()
54                    .map(|message| diagnostics::ProcMacroPanickedHelp { message }),
55            })
56        })
57    }
58}
59
60pub struct AttrProcMacro {
61    pub client: pm::bridge::client::Client,
62}
63
64impl base::AttrProcMacro for AttrProcMacro {
65    fn expand(
66        &self,
67        ecx: &mut ExtCtxt<'_>,
68        span: Span,
69        annotation: TokenStream,
70        annotated: TokenStream,
71    ) -> Result<TokenStream, ErrorGuaranteed> {
72        let _timer = record_expand_proc_macro(ecx, "expand_proc_macro", span);
73
74        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
75        let strategy = exec_strategy(ecx.sess);
76        let server = proc_macro_server::Rustc::new(ecx);
77        self.client.run2(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
78            |e| {
79                ecx.dcx().emit_err(diagnostics::CustomAttributePanicked {
80                    span,
81                    message: e
82                        .into_string()
83                        .map(|message| diagnostics::CustomAttributePanickedHelp { message }),
84                })
85            },
86        )
87    }
88}
89
90pub struct DeriveProcMacro {
91    pub client: DeriveClient,
92}
93
94impl MultiItemModifier for DeriveProcMacro {
95    fn expand(
96        &self,
97        ecx: &mut ExtCtxt<'_>,
98        span: Span,
99        _meta_item: &ast::MetaItem,
100        item: Annotatable,
101        _is_derive_const: bool,
102    ) -> ExpandResult<Vec<Annotatable>, Annotatable> {
103        let _timer = record_expand_proc_macro(ecx, "expand_derive_proc_macro_outer", span);
104
105        // We need special handling for statement items
106        // (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
107        let is_stmt = #[allow(non_exhaustive_omitted_patterns)] match item {
    Annotatable::Stmt(..) => true,
    _ => false,
}matches!(item, Annotatable::Stmt(..));
108
109        let input = item.to_tokens();
110
111        let invoc_id = ecx.current_expansion.id;
112
113        let res = if ecx.sess.opts.incremental.is_some()
114            && ecx.sess.opts.unstable_opts.cache_proc_macros
115        {
116            ty::tls::with(|tcx| {
117                let input = &*tcx.arena.alloc(input);
118                let key: (LocalExpnId, &TokenStream) = (invoc_id, input);
119
120                QueryDeriveExpandCtx::enter(ecx, self.client, move || {
121                    tcx.derive_macro_expansion(key).cloned()
122                })
123            })
124        } else {
125            expand_derive_macro(invoc_id, input, ecx, self.client)
126        };
127
128        let Ok(output) = res else {
129            // error will already have been emitted
130            return ExpandResult::Ready(::alloc::vec::Vec::new()vec![]);
131        };
132
133        let error_count_before = ecx.dcx().err_count();
134        let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
135        let mut items = ::alloc::vec::Vec::new()vec![];
136
137        loop {
138            match parser.parse_item(
139                ForceCollect::No,
140                if is_stmt { AllowConstBlockItems::No } else { AllowConstBlockItems::Yes },
141            ) {
142                Ok(None) => break,
143                Ok(Some(item)) => {
144                    if is_stmt {
145                        items.push(Annotatable::Stmt(Box::new(ecx.stmt_item(span, item))));
146                    } else {
147                        items.push(Annotatable::Item(item));
148                    }
149                }
150                Err(err) => {
151                    err.emit();
152                    break;
153                }
154            }
155        }
156
157        // fail if there have been errors emitted
158        if ecx.dcx().err_count() > error_count_before {
159            ecx.dcx().emit_err(diagnostics::ProcMacroDeriveTokens { span });
160        }
161
162        ExpandResult::Ready(items)
163    }
164}
165
166/// Provide a query for computing the output of a derive macro.
167pub(super) fn provide_derive_macro_expansion<'tcx>(
168    tcx: TyCtxt<'tcx>,
169    key: (LocalExpnId, &'tcx TokenStream),
170) -> Result<&'tcx TokenStream, ()> {
171    let (invoc_id, input) = key;
172
173    // Make sure that we invalidate the query when the crate defining the proc macro changes
174    let _ = tcx.crate_hash(invoc_id.expn_data().macro_def_id.unwrap().krate);
175
176    QueryDeriveExpandCtx::with(|ecx, client| {
177        expand_derive_macro(invoc_id, input.clone(), ecx, client).map(|ts| &*tcx.arena.alloc(ts))
178    })
179}
180
181type DeriveClient = pm::bridge::client::Client;
182
183fn expand_derive_macro(
184    invoc_id: LocalExpnId,
185    input: TokenStream,
186    ecx: &mut ExtCtxt<'_>,
187    client: DeriveClient,
188) -> Result<TokenStream, ()> {
189    let _timer =
190        ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
191            let invoc_expn_data = invoc_id.expn_data();
192            let span = invoc_expn_data.call_site;
193            let event_arg = invoc_expn_data.kind.descr();
194            recorder.record_arg_with_span(ecx.sess.source_map(), event_arg, span);
195        });
196
197    let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
198    let strategy = exec_strategy(ecx.sess);
199    let server = proc_macro_server::Rustc::new(ecx);
200
201    match client.run1(&strategy, server, input, proc_macro_backtrace) {
202        Ok(stream) => Ok(stream),
203        Err(e) => {
204            let invoc_expn_data = invoc_id.expn_data();
205            let span = invoc_expn_data.call_site;
206            ecx.dcx().emit_err({
207                diagnostics::ProcMacroDerivePanicked {
208                    span,
209                    message: e
210                        .into_string()
211                        .map(|message| diagnostics::ProcMacroDerivePanickedHelp { message }),
212                }
213            });
214            Err(())
215        }
216    }
217}
218
219/// Stores the context necessary to expand a derive proc macro via a query.
220struct QueryDeriveExpandCtx {
221    /// Type-erased version of `&mut ExtCtxt`
222    expansion_ctx: *mut (),
223    client: DeriveClient,
224}
225
226impl QueryDeriveExpandCtx {
227    /// Store the extension context and the client into the thread local value.
228    /// It will be accessible via the `with` method while `f` is active.
229    fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
230    where
231        F: FnOnce() -> R,
232    {
233        // We need erasure to get rid of the lifetime
234        let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
235        DERIVE_EXPAND_CTX.set(&ctx, f)
236    }
237
238    /// Accesses the thread local value of the derive expansion context.
239    /// Must be called while the `enter` function is active.
240    fn with<F, R>(f: F) -> R
241    where
242        F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
243    {
244        DERIVE_EXPAND_CTX.with(|ctx| {
245            let ectx = {
246                let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
247                // SAFETY: We can only get the value from `with` while the `enter` function
248                // is active (on the callstack), and that function's signature ensures that the
249                // lifetime is valid.
250                // If `with` is called at some other time, it will panic due to usage of
251                // `scoped_tls::with`.
252                unsafe { casted.as_mut().unwrap() }
253            };
254
255            f(ectx, ctx.client)
256        })
257    }
258}
259
260// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
261// context and derive Client. We do that using a thread-local.
262static DERIVE_EXPAND_CTX: ::scoped_tls::ScopedKey<QueryDeriveExpandCtx> =
    ::scoped_tls::ScopedKey {
        inner: {
            const FOO: ::std::thread::LocalKey<::std::cell::Cell<*const ()>> =
                {
                    const __RUST_STD_INTERNAL_INIT: ::std::cell::Cell<*const ()>
                        =
                        { ::std::cell::Cell::new(::std::ptr::null()) };
                    unsafe {
                        ::std::thread::LocalKey::new(const {
                                    if ::std::mem::needs_drop::<::std::cell::Cell<*const ()>>()
                                        {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL:
                                                    ::std::thread::local_impl::EagerStorage<::std::cell::Cell<*const ()>>
                                                    =
                                                    ::std::thread::local_impl::EagerStorage::new(__RUST_STD_INTERNAL_INIT);
                                                __RUST_STD_INTERNAL_VAL.get()
                                            }
                                    } else {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL: ::std::cell::Cell<*const ()>
                                                    =
                                                    __RUST_STD_INTERNAL_INIT;
                                                &__RUST_STD_INTERNAL_VAL
                                            }
                                    }
                                })
                    }
                };
            &FOO
        },
        _marker: ::std::marker::PhantomData,
    };scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);