Skip to main content

rustc_attr_parsing/attributes/
autodiff.rs

1use std::str::FromStr;
2
3use rustc_ast::LitKind;
4use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
5use rustc_feature::{AttributeTemplate, template};
6use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
7use rustc_hir::{MethodKind, Target};
8use rustc_span::{Symbol, sym};
9use thin_vec::ThinVec;
10
11use crate::attributes::prelude::Allow;
12use crate::attributes::{OnDuplicate, SingleAttributeParser};
13use crate::context::{AcceptContext, Stage};
14use crate::parser::{ArgParser, MetaItemOrLitParser};
15use crate::target_checking::AllowedTargets;
16
/// Attribute parser that turns `#[rustc_autodiff]` into `AttributeKind::RustcAutodiff`.
pub(crate) struct RustcAutodiffParser;
18
19impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
20    const PATH: &[Symbol] = &[sym::rustc_autodiff];
21    const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
22    const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
23        Allow(Target::Fn),
24        Allow(Target::Method(MethodKind::Inherent)),
25        Allow(Target::Method(MethodKind::Trait { body: true })),
26        Allow(Target::Method(MethodKind::Trait { body: false })),
27        Allow(Target::Method(MethodKind::TraitImpl)),
28    ]);
29    const TEMPLATE: AttributeTemplate = ::rustc_feature::AttributeTemplate {
    word: false,
    list: Some(&["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"]),
    one_of: &[],
    name_value_str: None,
    docs: Some("https://doc.rust-lang.org/std/autodiff/index.html"),
}template!(
30        List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
31        "https://doc.rust-lang.org/std/autodiff/index.html"
32    );
33
34    fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
35        let list = match args {
36            ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
37            ArgParser::List(list) => list,
38            ArgParser::NameValue(_) => {
39                let attr_span = cx.attr_span;
40                cx.adcx().expected_list_or_no_args(attr_span);
41                return None;
42            }
43        };
44
45        let mut items = list.mixed().peekable();
46
47        // Parse name
48        let Some(mode) = items.next() else {
49            cx.adcx().expected_at_least_one_argument(list.span);
50            return None;
51        };
52        let Some(mode) = mode.meta_item() else {
53            cx.adcx().expected_identifier(mode.span());
54            return None;
55        };
56        let Ok(()) = mode.args().no_args() else {
57            cx.adcx().expected_identifier(mode.span());
58            return None;
59        };
60        let Some(mode) = mode.path().word() else {
61            cx.adcx().expected_identifier(mode.span());
62            return None;
63        };
64        let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
65            cx.adcx().expected_specific_argument(mode.span, DiffMode::all_modes());
66            return None;
67        };
68
69        // Parse width
70        let width = if let Some(width) = items.peek()
71            && let MetaItemOrLitParser::Lit(width) = width
72            && let LitKind::Int(width, _) = width.kind
73            && let Ok(width) = width.0.try_into()
74        {
75            _ = items.next();
76            width
77        } else {
78            1
79        };
80
81        // Parse activities
82        let mut activities = ThinVec::new();
83        for activity in items {
84            let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
85                cx.adcx()
86                    .expected_specific_argument(activity.span(), DiffActivity::all_activities());
87                return None;
88            };
89            let Ok(()) = activity.args().no_args() else {
90                cx.adcx()
91                    .expected_specific_argument(activity.span(), DiffActivity::all_activities());
92                return None;
93            };
94            let Some(activity) = activity.path().word() else {
95                cx.adcx()
96                    .expected_specific_argument(activity.span(), DiffActivity::all_activities());
97                return None;
98            };
99            let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
100                cx.adcx().expected_specific_argument(activity.span, DiffActivity::all_activities());
101                return None;
102            };
103
104            activities.push(activity);
105        }
106        let Some(ret_activity) = activities.pop() else {
107            cx.adcx().expected_specific_argument(
108                list.span.with_lo(list.span.hi()),
109                DiffActivity::all_activities(),
110            );
111            return None;
112        };
113
114        Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
115            mode,
116            width,
117            input_activity: activities,
118            ret_activity,
119        }))))
120    }
121}