Skip to main content

rustc_mir_transform/
unreachable_enum_branching.rs

//! A pass that eliminates branches on uninhabited or unreachable enum variants.
2
3use rustc_abi::Variants;
4use rustc_data_structures::fx::FxHashSet;
5use rustc_middle::bug;
6use rustc_middle::mir::{
7    BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind,
8};
9use rustc_middle::ty::layout::TyAndLayout;
10use rustc_middle::ty::{Ty, TyCtxt};
11use tracing::trace;
12
13use crate::patch::MirPatch;
14
15pub(super) struct UnreachableEnumBranching;
16
17fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> {
18    if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator {
19        p.as_local()
20    } else {
21        None
22    }
23}
24
25/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the
26/// discriminant is read from. Otherwise, returns None.
27fn get_switched_on_type<'tcx>(
28    block_data: &BasicBlockData<'tcx>,
29    tcx: TyCtxt<'tcx>,
30    body: &Body<'tcx>,
31) -> Option<Ty<'tcx>> {
32    let terminator = block_data.terminator();
33
34    // Only bother checking blocks which terminate by switching on a local.
35    let local = get_discriminant_local(&terminator.kind)?;
36
37    let stmt_before_term = block_data.statements.last()?;
38
39    if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
40        && l.as_local() == Some(local)
41    {
42        let ty = place.ty(body, tcx).ty;
43        if ty.is_enum() {
44            return Some(ty);
45        }
46    }
47
48    None
49}
50
51fn variant_discriminants<'tcx>(
52    layout: &TyAndLayout<'tcx>,
53    ty: Ty<'tcx>,
54    tcx: TyCtxt<'tcx>,
55) -> FxHashSet<u128> {
56    match &layout.variants {
57        Variants::Empty => {
58            // Uninhabited, no valid discriminant.
59            FxHashSet::default()
60        }
61        Variants::Single { index } => {
62            let mut res = FxHashSet::default();
63            res.insert(
64                ty.discriminant_for_variant(tcx, *index)
65                    .map_or(index.as_u32() as u128, |discr| discr.val),
66            );
67            res
68        }
69        Variants::Multiple { variants, .. } => variants
70            .iter_enumerated()
71            .filter_map(|(idx, layout)| {
72                (!layout.is_uninhabited())
73                    .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val)
74            })
75            .collect(),
76    }
77}
78
79impl<'tcx> crate::MirPass<'tcx> for UnreachableEnumBranching {
80    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
81        sess.mir_opt_level() > 0
82    }
83
84    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
85        trace!("UnreachableEnumBranching starting for {:?}", body.source);
86
87        let mut unreachable_targets = Vec::new();
88        let mut patch = MirPatch::new(body);
89
90        for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
91            trace!("processing block {:?}", bb);
92
93            if bb_data.is_cleanup {
94                continue;
95            }
96
97            let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue };
98
99            let layout = tcx.layout_of(body.typing_env(tcx).as_query_input(discriminant_ty));
100
101            let mut allowed_variants = if let Ok(layout) = layout {
102                // Find allowed variants based on uninhabited.
103                variant_discriminants(&layout, discriminant_ty, tcx)
104            } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
105                // If there are some generics, we can still get the allowed variants.
106                variant_range
107                    .map(|variant| {
108                        discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
109                    })
110                    .collect()
111            } else {
112                continue;
113            };
114
115            trace!("allowed_variants = {:?}", allowed_variants);
116
117            unreachable_targets.clear();
118            let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
119                bug!()
120            };
121
122            for (index, (val, _)) in targets.iter().enumerate() {
123                if !allowed_variants.remove(&val) {
124                    unreachable_targets.push(index);
125                }
126            }
127
128            // If and only if there is a variant that does not have a branch set, change the
129            // current of otherwise as the variant branch and set otherwise to unreachable. It
130            // transforms the following code
131            // ```rust
132            // match c {
133            //     Ordering::Less => 1,
134            //     Ordering::Equal => 2,
135            //     _ => 3,
136            // }
137            // ```
138            // to
139            // ```rust
140            // match c {
141            //     Ordering::Less => 1,
142            //     Ordering::Equal => 2,
143            //     Ordering::Greater => 3,
144            // }
145            // ```
146            let replace_otherwise_to_unreachable = allowed_variants.len() <= 1
147                && !body.basic_blocks[targets.otherwise()].is_empty_unreachable();
148            if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
149                continue;
150            }
151
152            let unreachable_block = patch.unreachable_no_cleanup_block();
153            let mut targets = targets.clone();
154            if replace_otherwise_to_unreachable {
155                let otherwise_is_last_variant = allowed_variants.len() == 1;
156                if otherwise_is_last_variant {
157                    // We have checked that `allowed_variants` has only one element.
158                    #[allow(rustc::potential_query_instability)]
159                    let last_variant = *allowed_variants.iter().next().unwrap();
160                    targets.add_target(last_variant, targets.otherwise());
161                }
162                unreachable_targets.push(targets.iter().count());
163            }
164            for index in unreachable_targets.iter() {
165                targets.all_targets_mut()[*index] = unreachable_block;
166            }
167            patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
168        }
169
170        patch.apply(body);
171    }
172
173    fn is_required(&self) -> bool {
174        false
175    }
176}