Skip to main content

rustc_index_macros/
newtype.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::parse::*;
4use syn::*;
5
6// We parse the input and emit the output in a single step.
7// This field stores the final macro output
8struct Newtype(TokenStream);
9
10impl Parse for Newtype {
11    fn parse(input: ParseStream<'_>) -> Result<Self> {
12        let mut attrs = input.call(Attribute::parse_outer)?;
13        let vis: Visibility = input.parse()?;
14        input.parse::<Token![struct]>()?;
15        let name: Ident = input.parse()?;
16
17        let body;
18        braced!(body in input);
19
20        // Any additional `#[derive]` macro paths to apply
21        let mut debug_format: Option<Lit> = None;
22        let mut max = None;
23        let mut consts = Vec::new();
24        let mut encodable = false;
25        let mut ord = false;
26        let mut stable_hash = false;
27        let mut stable_hash_generic = false;
28        let mut stable_hash_no_context = false;
29        let mut gate_rustc_only = quote! {};
30        let mut gate_rustc_only_cfg = quote! { all() };
31
32        attrs.retain(|attr| match attr.path().get_ident() {
33            Some(ident) => match &*ident.to_string() {
34                "gate_rustc_only" => {
35                    gate_rustc_only = quote! { #[cfg(feature = "nightly")] };
36                    gate_rustc_only_cfg = quote! { feature = "nightly" };
37                    false
38                }
39                "encodable" => {
40                    encodable = true;
41                    false
42                }
43                "orderable" => {
44                    ord = true;
45                    false
46                }
47                "stable_hash" => {
48                    stable_hash = true;
49                    false
50                }
51                "stable_hash_generic" => {
52                    stable_hash_generic = true;
53                    false
54                }
55                "stable_hash_no_context" => {
56                    stable_hash_no_context = true;
57                    false
58                }
59                "max" => {
60                    let Meta::NameValue(MetaNameValue { value: Expr::Lit(lit), .. }) = &attr.meta
61                    else {
62                        panic!("#[max = NUMBER] attribute requires max value");
63                    };
64
65                    if let Some(old) = max.replace(lit.lit.clone()) {
66                        panic!("Specified multiple max: {old:?}");
67                    }
68
69                    false
70                }
71                "debug_format" => {
72                    let Meta::NameValue(MetaNameValue { value: Expr::Lit(lit), .. }) = &attr.meta
73                    else {
74                        panic!("#[debug_format = FMT] attribute requires a format");
75                    };
76
77                    if let Some(old) = debug_format.replace(lit.lit.clone()) {
78                        panic!("Specified multiple debug format options: {old:?}");
79                    }
80
81                    false
82                }
83                _ => true,
84            },
85            _ => true,
86        });
87
88        loop {
89            // We've parsed everything that the user provided, so we're done
90            if body.is_empty() {
91                break;
92            }
93
94            // Otherwise, we are parsing a user-defined constant
95            let const_attrs = body.call(Attribute::parse_outer)?;
96            body.parse::<Token![const]>()?;
97            let const_name: Ident = body.parse()?;
98            body.parse::<Token![=]>()?;
99            let const_val: Expr = body.parse()?;
100            body.parse::<Token![;]>()?;
101            consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
102        }
103
104        let debug_format =
105            debug_format.unwrap_or_else(|| Lit::Str(LitStr::new("{}", Span::call_site())));
106
107        // shave off 256 indices at the end to allow space for packing these indices into enums
108        let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
109
110        let encodable_impls = if encodable {
111            quote! {
112                #gate_rustc_only
113                impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
114                    fn decode(d: &mut D) -> Self {
115                        Self::from_u32(d.read_u32())
116                    }
117                }
118                #gate_rustc_only
119                impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
120                    fn encode(&self, e: &mut E) {
121                        e.emit_u32(self.as_u32());
122                    }
123                }
124            }
125        } else {
126            quote! {}
127        };
128        let step = if ord {
129            quote! {
130                #gate_rustc_only
131                impl ::std::iter::Step for #name {
132                    #[inline]
133                    fn steps_between(start: &Self, end: &Self) -> (usize, Option<usize>) {
134                        <usize as ::std::iter::Step>::steps_between(
135                            &Self::index(*start),
136                            &Self::index(*end),
137                        )
138                    }
139
140                    #[inline]
141                    fn forward_checked(start: Self, u: usize) -> Option<Self> {
142                        Self::index(start).checked_add(u).map(Self::from_usize)
143                    }
144
145                    #[inline]
146                    fn backward_checked(start: Self, u: usize) -> Option<Self> {
147                        Self::index(start).checked_sub(u).map(Self::from_usize)
148                    }
149                }
150                impl ::std::cmp::Ord for #name {
151                    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
152                        self.as_u32().cmp(&other.as_u32())
153                    }
154                }
155                impl ::std::cmp::PartialOrd for #name {
156                    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
157                        self.as_u32().partial_cmp(&other.as_u32())
158                    }
159                }
160            }
161        } else {
162            quote! {}
163        };
164
165        let hash_stable = if stable_hash {
166            quote! {
167                #gate_rustc_only
168                impl<'__ctx> ::rustc_data_structures::stable_hasher::HashStable<::rustc_middle::ich::StableHashingContext<'__ctx>> for #name {
169                    fn hash_stable(&self, hcx: &mut ::rustc_middle::ich::StableHashingContext<'__ctx>, hasher: &mut ::rustc_data_structures::stable_hasher::StableHasher) {
170                        self.as_u32().hash_stable(hcx, hasher)
171                    }
172                }
173            }
174        } else if stable_hash_generic || stable_hash_no_context {
175            quote! {
176                #gate_rustc_only
177                impl<Hcx> ::rustc_data_structures::stable_hasher::HashStable<Hcx> for #name {
178                    fn hash_stable(&self, hcx: &mut Hcx, hasher: &mut ::rustc_data_structures::stable_hasher::StableHasher) {
179                        self.as_u32().hash_stable(hcx, hasher)
180                    }
181                }
182            }
183        } else {
184            quote! {}
185        };
186
187        let debug_impl = quote! {
188            impl ::std::fmt::Debug for #name {
189                fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
190                    write!(fmt, #debug_format, self.as_u32())
191                }
192            }
193        };
194
195        Ok(Self(quote! {
196            #(#attrs)*
197            #[derive(Clone, Copy)]
198            #[cfg_attr(#gate_rustc_only_cfg, rustc_pass_by_value)]
199            #vis struct #name {
200                #[cfg(not(#gate_rustc_only_cfg))]
201                private_use_as_methods_instead: u32,
202                #[cfg(#gate_rustc_only_cfg)]
203                private_use_as_methods_instead: pattern_type!(u32 is 0..=#max),
204            }
205
206            #(#consts)*
207
208            impl #name {
209                /// Maximum value the index can take, as a `u32`.
210                #vis const MAX_AS_U32: u32  = #max;
211
212                /// Maximum value the index can take.
213                #vis const MAX: Self = Self::from_u32(#max);
214
215                /// Zero value of the index.
216                #vis const ZERO: Self = Self::from_u32(0);
217
218                /// Creates a new index from a given `usize`.
219                ///
220                /// # Panics
221                ///
222                /// Will panic if `value` exceeds `MAX`.
223                #[inline]
224                #vis const fn from_usize(value: usize) -> Self {
225                    assert!(value <= (#max as usize));
226                    // SAFETY: We just checked that `value <= max`.
227                    unsafe {
228                        Self::from_u32_unchecked(value as u32)
229                    }
230                }
231
232                /// Creates a new index from a given `u32`.
233                ///
234                /// # Panics
235                ///
236                /// Will panic if `value` exceeds `MAX`.
237                #[inline]
238                #vis const fn from_u32(value: u32) -> Self {
239                    assert!(value <= #max);
240                    // SAFETY: We just checked that `value <= max`.
241                    unsafe {
242                        Self::from_u32_unchecked(value)
243                    }
244                }
245
246                /// Creates a new index from a given `u16`.
247                ///
248                /// # Panics
249                ///
250                /// Will panic if `value` exceeds `MAX`.
251                #[inline]
252                #vis const fn from_u16(value: u16) -> Self {
253                    let value = value as u32;
254                    assert!(value <= #max);
255                    // SAFETY: We just checked that `value <= max`.
256                    unsafe {
257                        Self::from_u32_unchecked(value)
258                    }
259                }
260
261                /// Creates a new index from a given `u32`.
262                ///
263                /// # Safety
264                ///
265                /// The provided value must be less than or equal to the maximum value for the newtype.
266                /// Providing a value outside this range is undefined due to layout restrictions.
267                ///
268                /// Prefer using `from_u32`.
269                #[inline]
270                #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
271                    Self { private_use_as_methods_instead: unsafe { std::mem::transmute(value) } }
272                }
273
274                /// Extracts the value of this index as a `usize`.
275                #[inline]
276                #vis const fn index(self) -> usize {
277                    self.as_usize()
278                }
279
280                /// Extracts the value of this index as a `u32`.
281                #[inline]
282                #vis const fn as_u32(self) -> u32 {
283                    unsafe { std::mem::transmute(self.private_use_as_methods_instead) }
284                }
285
286                /// Extracts the value of this index as a `usize`.
287                #[inline]
288                #vis const fn as_usize(self) -> usize {
289                    self.as_u32() as usize
290                }
291            }
292
293            impl std::ops::Add<usize> for #name {
294                type Output = Self;
295
296                #[inline]
297                fn add(self, other: usize) -> Self {
298                    Self::from_usize(self.index() + other)
299                }
300            }
301
302            impl std::ops::AddAssign<usize> for #name {
303                #[inline]
304                fn add_assign(&mut self, other: usize) {
305                    *self = *self + other;
306                }
307            }
308
309            impl rustc_index::Idx for #name {
310                #[inline]
311                fn new(value: usize) -> Self {
312                    Self::from_usize(value)
313                }
314
315                #[inline]
316                fn index(self) -> usize {
317                    self.as_usize()
318                }
319            }
320
321            #step
322
323            #hash_stable
324
325            impl From<#name> for u32 {
326                #[inline]
327                fn from(v: #name) -> u32 {
328                    v.as_u32()
329                }
330            }
331
332            impl From<#name> for usize {
333                #[inline]
334                fn from(v: #name) -> usize {
335                    v.as_usize()
336                }
337            }
338
339            impl From<usize> for #name {
340                #[inline]
341                fn from(value: usize) -> Self {
342                    Self::from_usize(value)
343                }
344            }
345
346            impl From<u32> for #name {
347                #[inline]
348                fn from(value: u32) -> Self {
349                    Self::from_u32(value)
350                }
351            }
352
353            impl ::std::cmp::Eq for #name {}
354
355            impl ::std::cmp::PartialEq for #name {
356                fn eq(&self, other: &Self) -> bool {
357                    self.as_u32().eq(&other.as_u32())
358                }
359            }
360
361            #gate_rustc_only
362            impl ::std::marker::StructuralPartialEq for #name {}
363
364            impl ::std::hash::Hash for #name {
365                fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
366                    self.as_u32().hash(state)
367                }
368            }
369
370            #encodable_impls
371            #debug_impl
372        }))
373    }
374}
375
376pub(crate) fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
377    let input = parse_macro_input!(input as Newtype);
378    input.0.into()
379}