Skip to main content

rustc_index_macros/
newtype.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::parse::*;
4use syn::*;
5
// We parse the input and emit the output in a single step: the `Parse` impl below builds the
// complete expansion while it parses.
/// Holds the final macro output (the fully expanded items for one newtype definition).
struct Newtype(TokenStream);
9
10impl Parse for Newtype {
11    fn parse(input: ParseStream<'_>) -> Result<Self> {
12        let mut attrs = input.call(Attribute::parse_outer)?;
13        let vis: Visibility = input.parse()?;
14        input.parse::<Token![struct]>()?;
15        let name: Ident = input.parse()?;
16
17        let body;
18        braced!(body in input);
19
20        // Any additional `#[derive]` macro paths to apply
21        let mut debug_format: Option<Lit> = None;
22        let mut max = None;
23        let mut consts = Vec::new();
24        let mut encodable = false;
25        let mut ord = false;
26        let mut stable_hash = false;
27        let mut gate_rustc_only = quote! {};
28        let mut gate_rustc_only_cfg = quote! { all() };
29
30        attrs.retain(|attr| match attr.path().get_ident() {
31            Some(ident) => match &*ident.to_string() {
32                "gate_rustc_only" => {
33                    gate_rustc_only = quote! { #[cfg(feature = "nightly")] };
34                    gate_rustc_only_cfg = quote! { feature = "nightly" };
35                    false
36                }
37                "encodable" => {
38                    encodable = true;
39                    false
40                }
41                "orderable" => {
42                    ord = true;
43                    false
44                }
45                "stable_hash" => {
46                    stable_hash = true;
47                    false
48                }
49                "max" => {
50                    let Meta::NameValue(MetaNameValue { value: Expr::Lit(lit), .. }) = &attr.meta
51                    else {
52                        panic!("#[max = NUMBER] attribute requires max value");
53                    };
54
55                    if let Some(old) = max.replace(lit.lit.clone()) {
56                        panic!("Specified multiple max: {old:?}");
57                    }
58
59                    false
60                }
61                "debug_format" => {
62                    let Meta::NameValue(MetaNameValue { value: Expr::Lit(lit), .. }) = &attr.meta
63                    else {
64                        panic!("#[debug_format = FMT] attribute requires a format");
65                    };
66
67                    if let Some(old) = debug_format.replace(lit.lit.clone()) {
68                        panic!("Specified multiple debug format options: {old:?}");
69                    }
70
71                    false
72                }
73                _ => true,
74            },
75            _ => true,
76        });
77
78        loop {
79            // We've parsed everything that the user provided, so we're done
80            if body.is_empty() {
81                break;
82            }
83
84            // Otherwise, we are parsing a user-defined constant
85            let const_attrs = body.call(Attribute::parse_outer)?;
86            body.parse::<Token![const]>()?;
87            let const_name: Ident = body.parse()?;
88            body.parse::<Token![=]>()?;
89            let const_val: Expr = body.parse()?;
90            body.parse::<Token![;]>()?;
91            consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
92        }
93
94        let debug_format =
95            debug_format.unwrap_or_else(|| Lit::Str(LitStr::new("{}", Span::call_site())));
96
97        // shave off 256 indices at the end to allow space for packing these indices into enums
98        let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
99
100        let encodable_impls = if encodable {
101            quote! {
102                #gate_rustc_only
103                impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
104                    fn decode(d: &mut D) -> Self {
105                        Self::from_u32(d.read_u32())
106                    }
107                }
108                #gate_rustc_only
109                impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
110                    fn encode(&self, e: &mut E) {
111                        e.emit_u32(self.as_u32());
112                    }
113                }
114            }
115        } else {
116            quote! {}
117        };
118        let step = if ord {
119            quote! {
120                #gate_rustc_only
121                impl ::std::iter::Step for #name {
122                    #[inline]
123                    fn steps_between(start: &Self, end: &Self) -> (usize, Option<usize>) {
124                        <usize as ::std::iter::Step>::steps_between(
125                            &Self::index(*start),
126                            &Self::index(*end),
127                        )
128                    }
129
130                    #[inline]
131                    fn forward_checked(start: Self, u: usize) -> Option<Self> {
132                        Self::index(start).checked_add(u).map(Self::from_usize)
133                    }
134
135                    #[inline]
136                    fn backward_checked(start: Self, u: usize) -> Option<Self> {
137                        Self::index(start).checked_sub(u).map(Self::from_usize)
138                    }
139                }
140                impl ::std::cmp::Ord for #name {
141                    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
142                        self.as_u32().cmp(&other.as_u32())
143                    }
144                }
145                impl ::std::cmp::PartialOrd for #name {
146                    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
147                        self.as_u32().partial_cmp(&other.as_u32())
148                    }
149                }
150            }
151        } else {
152            quote! {}
153        };
154
155        let hash_stable = if stable_hash {
156            quote! {
157                #gate_rustc_only
158                impl ::rustc_data_structures::stable_hasher::HashStable for #name {
159                    fn hash_stable<
160                        __Hcx: ::rustc_data_structures::stable_hasher::HashStableContext
161                    >(
162                        &self,
163                        hcx: &mut __Hcx,
164                        hasher: &mut ::rustc_data_structures::stable_hasher::StableHasher
165                    ) {
166                        self.as_u32().hash_stable(hcx, hasher)
167                    }
168                }
169            }
170        } else {
171            quote! {}
172        };
173
174        let debug_impl = quote! {
175            impl ::std::fmt::Debug for #name {
176                fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
177                    write!(fmt, #debug_format, self.as_u32())
178                }
179            }
180        };
181
182        Ok(Self(quote! {
183            #(#attrs)*
184            #[derive(Clone, Copy)]
185            #[cfg_attr(#gate_rustc_only_cfg, rustc_pass_by_value)]
186            #vis struct #name {
187                #[cfg(not(#gate_rustc_only_cfg))]
188                private_use_as_methods_instead: u32,
189                #[cfg(#gate_rustc_only_cfg)]
190                private_use_as_methods_instead: pattern_type!(u32 is 0..=#max),
191            }
192
193            #(#consts)*
194
195            impl #name {
196                /// Maximum value the index can take, as a `u32`.
197                #vis const MAX_AS_U32: u32  = #max;
198
199                /// Maximum value the index can take.
200                #vis const MAX: Self = Self::from_u32(#max);
201
202                /// Zero value of the index.
203                #vis const ZERO: Self = Self::from_u32(0);
204
205                /// Creates a new index from a given `usize`.
206                ///
207                /// # Panics
208                ///
209                /// Will panic if `value` exceeds `MAX`.
210                #[inline]
211                #vis const fn from_usize(value: usize) -> Self {
212                    assert!(value <= (#max as usize));
213                    // SAFETY: We just checked that `value <= max`.
214                    unsafe {
215                        Self::from_u32_unchecked(value as u32)
216                    }
217                }
218
219                /// Creates a new index from a given `u32`.
220                ///
221                /// # Panics
222                ///
223                /// Will panic if `value` exceeds `MAX`.
224                #[inline]
225                #vis const fn from_u32(value: u32) -> Self {
226                    assert!(value <= #max);
227                    // SAFETY: We just checked that `value <= max`.
228                    unsafe {
229                        Self::from_u32_unchecked(value)
230                    }
231                }
232
233                /// Creates a new index from a given `u16`.
234                ///
235                /// # Panics
236                ///
237                /// Will panic if `value` exceeds `MAX`.
238                #[inline]
239                #vis const fn from_u16(value: u16) -> Self {
240                    let value = value as u32;
241                    assert!(value <= #max);
242                    // SAFETY: We just checked that `value <= max`.
243                    unsafe {
244                        Self::from_u32_unchecked(value)
245                    }
246                }
247
248                /// Creates a new index from a given `u32`.
249                ///
250                /// # Safety
251                ///
252                /// The provided value must be less than or equal to the maximum value for the newtype.
253                /// Providing a value outside this range is undefined due to layout restrictions.
254                ///
255                /// Prefer using `from_u32`.
256                #[inline]
257                #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
258                    Self { private_use_as_methods_instead: unsafe { std::mem::transmute(value) } }
259                }
260
261                /// Extracts the value of this index as a `usize`.
262                #[inline]
263                #vis const fn index(self) -> usize {
264                    self.as_usize()
265                }
266
267                /// Extracts the value of this index as a `u32`.
268                #[inline]
269                #vis const fn as_u32(self) -> u32 {
270                    unsafe { std::mem::transmute(self.private_use_as_methods_instead) }
271                }
272
273                /// Extracts the value of this index as a `usize`.
274                #[inline]
275                #vis const fn as_usize(self) -> usize {
276                    self.as_u32() as usize
277                }
278            }
279
280            impl std::ops::Add<usize> for #name {
281                type Output = Self;
282
283                #[inline]
284                fn add(self, other: usize) -> Self {
285                    Self::from_usize(self.index() + other)
286                }
287            }
288
289            impl std::ops::AddAssign<usize> for #name {
290                #[inline]
291                fn add_assign(&mut self, other: usize) {
292                    *self = *self + other;
293                }
294            }
295
296            impl rustc_index::Idx for #name {
297                #[inline]
298                fn new(value: usize) -> Self {
299                    Self::from_usize(value)
300                }
301
302                #[inline]
303                fn index(self) -> usize {
304                    self.as_usize()
305                }
306            }
307
308            #step
309
310            #hash_stable
311
312            impl From<#name> for u32 {
313                #[inline]
314                fn from(v: #name) -> u32 {
315                    v.as_u32()
316                }
317            }
318
319            impl From<#name> for usize {
320                #[inline]
321                fn from(v: #name) -> usize {
322                    v.as_usize()
323                }
324            }
325
326            impl From<usize> for #name {
327                #[inline]
328                fn from(value: usize) -> Self {
329                    Self::from_usize(value)
330                }
331            }
332
333            impl From<u32> for #name {
334                #[inline]
335                fn from(value: u32) -> Self {
336                    Self::from_u32(value)
337                }
338            }
339
340            impl ::std::cmp::Eq for #name {}
341
342            impl ::std::cmp::PartialEq for #name {
343                fn eq(&self, other: &Self) -> bool {
344                    self.as_u32().eq(&other.as_u32())
345                }
346            }
347
348            #gate_rustc_only
349            impl ::std::marker::StructuralPartialEq for #name {}
350
351            impl ::std::hash::Hash for #name {
352                fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
353                    self.as_u32().hash(state)
354                }
355            }
356
357            #encodable_impls
358            #debug_impl
359        }))
360    }
361}
362
363pub(crate) fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
364    let input = parse_macro_input!(input as Newtype);
365    input.0.into()
366}