rustc_mir_transform/
unreachable_enum_branching.rs1use rustc_abi::Variants;
4use rustc_data_structures::fx::FxHashSet;
5use rustc_middle::bug;
6use rustc_middle::mir::{
7 BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind,
8};
9use rustc_middle::ty::layout::TyAndLayout;
10use rustc_middle::ty::{Ty, TyCtxt};
11use tracing::trace;
12
13use crate::patch::MirPatch;
14
/// Unit type carrying the `crate::MirPass` implementation defined in this file.
///
/// The pass looks at blocks that end in a `SwitchInt` over an enum discriminant and
/// redirects switch targets whose value is not among the discriminants collected for
/// that enum to an unreachable block.
pub(super) struct UnreachableEnumBranching;
16
17fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> {
18 if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator {
19 p.as_local()
20 } else {
21 None
22 }
23}
24
25fn get_switched_on_type<'tcx>(
28 block_data: &BasicBlockData<'tcx>,
29 tcx: TyCtxt<'tcx>,
30 body: &Body<'tcx>,
31) -> Option<Ty<'tcx>> {
32 let terminator = block_data.terminator();
33
34 let local = get_discriminant_local(&terminator.kind)?;
36
37 let stmt_before_term = block_data.statements.last()?;
38
39 if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
40 && l.as_local() == Some(local)
41 {
42 let ty = place.ty(body, tcx).ty;
43 if ty.is_enum() {
44 return Some(ty);
45 }
46 }
47
48 None
49}
50
51fn variant_discriminants<'tcx>(
52 layout: &TyAndLayout<'tcx>,
53 ty: Ty<'tcx>,
54 tcx: TyCtxt<'tcx>,
55) -> FxHashSet<u128> {
56 match &layout.variants {
57 Variants::Empty => {
58 FxHashSet::default()
60 }
61 Variants::Single { index } => {
62 let mut res = FxHashSet::default();
63 res.insert(
64 ty.discriminant_for_variant(tcx, *index)
65 .map_or(index.as_u32() as u128, |discr| discr.val),
66 );
67 res
68 }
69 Variants::Multiple { variants, .. } => variants
70 .iter_enumerated()
71 .filter_map(|(idx, layout)| {
72 (!layout.is_uninhabited())
73 .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val)
74 })
75 .collect(),
76 }
77}
78
79impl<'tcx> crate::MirPass<'tcx> for UnreachableEnumBranching {
80 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
81 sess.mir_opt_level() > 0
82 }
83
84 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
85 trace!("UnreachableEnumBranching starting for {:?}", body.source);
86
87 let mut unreachable_targets = Vec::new();
88 let mut patch = MirPatch::new(body);
89
90 for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
91 trace!("processing block {:?}", bb);
92
93 if bb_data.is_cleanup {
94 continue;
95 }
96
97 let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue };
98
99 let layout = tcx.layout_of(body.typing_env(tcx).as_query_input(discriminant_ty));
100
101 let mut allowed_variants = if let Ok(layout) = layout {
102 variant_discriminants(&layout, discriminant_ty, tcx)
104 } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
105 variant_range
107 .map(|variant| {
108 discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
109 })
110 .collect()
111 } else {
112 continue;
113 };
114
115 trace!("allowed_variants = {:?}", allowed_variants);
116
117 unreachable_targets.clear();
118 let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
119 bug!()
120 };
121
122 for (index, (val, _)) in targets.iter().enumerate() {
123 if !allowed_variants.remove(&val) {
124 unreachable_targets.push(index);
125 }
126 }
127
128 let replace_otherwise_to_unreachable = allowed_variants.len() <= 1
147 && !body.basic_blocks[targets.otherwise()].is_empty_unreachable();
148 if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
149 continue;
150 }
151
152 let unreachable_block = patch.unreachable_no_cleanup_block();
153 let mut targets = targets.clone();
154 if replace_otherwise_to_unreachable {
155 let otherwise_is_last_variant = allowed_variants.len() == 1;
156 if otherwise_is_last_variant {
157 #[allow(rustc::potential_query_instability)]
159 let last_variant = *allowed_variants.iter().next().unwrap();
160 targets.add_target(last_variant, targets.otherwise());
161 }
162 unreachable_targets.push(targets.iter().count());
163 }
164 for index in unreachable_targets.iter() {
165 targets.all_targets_mut()[*index] = unreachable_block;
166 }
167 patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
168 }
169
170 patch.apply(body);
171 }
172
173 fn is_required(&self) -> bool {
174 false
175 }
176}