Skip to main content

rustc_mir_transform/
match_branches.rs

1use rustc_abi::Integer;
2use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
3use rustc_middle::mir::*;
4use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
5use rustc_middle::ty::util::Discr;
6use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
7
8use super::simplify::simplify_cfg;
9use crate::patch::MirPatch;
10use crate::unreachable_prop::remove_successors_from_switch;
11
12/// Unifies all targets into one basic block if each statement can have the same statement.
13pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17        // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable.
18        sess.mir_opt_level() >= 2
19    }
20
21    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22        let typing_env = body.typing_env(tcx);
23        let mut changed = false;
24        for bb in body.basic_blocks.indices() {
25            if !candidate_match(body, bb) {
26                continue;
27            };
28            changed |= simplify_match(tcx, typing_env, body, bb)
29        }
30
31        if changed {
32            simplify_cfg(tcx, body);
33        }
34    }
35
36    fn is_required(&self) -> bool {
37        false
38    }
39}
40
/// State shared by the statement-unification strategies while simplifying the `SwitchInt`
/// terminator of a single block.
struct SimplifyMatch<'tcx, 'a> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    /// Collects the changes to `body` (including the lazily created discriminant temporary);
    /// `simplify_match` applies it only on the paths that commit to a rewrite.
    patch: MirPatch<'tcx>,
    /// The body being analyzed; it is only read here, modifications go through `patch`.
    body: &'a Body<'tcx>,
    /// The block whose `SwitchInt` terminator is being simplified.
    switch_bb: BasicBlock,
    /// The discriminant operand of that `SwitchInt`.
    discr: &'a Operand<'tcx>,
    /// Temporary holding a copy of the discriminant; `None` until `discr_local()` creates it.
    discr_local: Option<Local>,
    /// The type of `discr`.
    discr_ty: Ty<'tcx>,
}
51
52impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> {
53    fn discr_local(&mut self) -> Local {
54        *self.discr_local.get_or_insert_with(|| {
55            // Introduce a temporary for the discriminant value.
56            let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info;
57            self.patch.new_temp(self.discr_ty, source_info.span)
58        })
59    }
60
61    /// Unifies the assignments if all rvalues are constants and equal.
62    fn unify_if_equal_const(
63        &self,
64        dest: Place<'tcx>,
65        consts: &[(u128, &ConstOperand<'tcx>)],
66        otherwise: Option<&ConstOperand<'tcx>>,
67    ) -> Option<StatementKind<'tcx>> {
68        let (_, first_const, mut others) = split_first_case(consts, otherwise);
69        let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?;
70        if others.all(|const_| {
71            const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int)
72        }) {
73            Some(StatementKind::Assign(Box::new((
74                dest,
75                // We didn't remember the `WithRetag` of the original assignments, so in case
76                // one of them had "no", we also have to use "no" here.
77                Rvalue::Use(Operand::Constant(Box::new(first_const.clone())), WithRetag::No),
78            ))))
79        } else {
80            None
81        }
82    }
83
84    /// If a source block is found that switches between two blocks that are exactly
85    /// the same modulo const bool assignments (e.g., one assigns true another false
86    /// to the same place), unify a target block statements into the source block,
87    /// using Eq / Ne comparison with switch value where const bools value differ.
88    ///
89    /// For example:
90    ///
91    /// ```ignore (MIR)
92    /// bb0: {
93    ///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
94    /// }
95    ///
96    /// bb1: {
97    ///     _2 = const true;
98    ///     goto -> bb3;
99    /// }
100    ///
101    /// bb2: {
102    ///     _2 = const false;
103    ///     goto -> bb3;
104    /// }
105    /// ```
106    ///
107    /// into:
108    ///
109    /// ```ignore (MIR)
110    /// bb0: {
111    ///    _2 = Eq(move _3, const 42_isize);
112    ///    goto -> bb3;
113    /// }
114    /// ```
115    fn unify_by_eq_op(
116        &mut self,
117        dest: Place<'tcx>,
118        consts: &[(u128, &ConstOperand<'tcx>)],
119        otherwise: Option<&ConstOperand<'tcx>>,
120    ) -> Option<StatementKind<'tcx>> {
121        // FIXME: extend to any case.
122        let (first_case, first_const, mut others) = split_first_case(consts, otherwise);
123        if !first_const.ty().is_bool() {
124            return None;
125        }
126        let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?;
127        if others.all(|const_| {
128            const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool)
129        }) {
130            // Make value conditional on switch condition.
131            let size =
132                self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size;
133            let const_cmp = Operand::const_from_scalar(
134                self.tcx,
135                self.discr_ty,
136                rustc_const_eval::interpret::Scalar::from_uint(first_case, size),
137                rustc_span::DUMMY_SP,
138            );
139            let op = if first_bool { BinOp::Eq } else { BinOp::Ne };
140            let rval = Rvalue::BinaryOp(
141                op,
142                Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)),
143            );
144            Some(StatementKind::Assign(Box::new((dest, rval))))
145        } else {
146            None
147        }
148    }
149
150    /// Unifies the assignments if all rvalues can be cast from the discriminant value by IntToInt.
151    ///
152    /// For example:
153    ///
154    /// ```ignore (MIR)
155    /// bb0: {
156    ///     switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
157    /// }
158    ///
159    /// bb1: {
160    ///     unreachable;
161    /// }
162    ///
163    /// bb2: {
164    ///     _0 = const 1_i16;
165    ///     goto -> bb5;
166    /// }
167    ///
168    /// bb3: {
169    ///     _0 = const 2_i16;
170    ///     goto -> bb5;
171    /// }
172    ///
173    /// bb4: {
174    ///     _0 = const 3_i16;
175    ///     goto -> bb5;
176    /// }
177    /// ```
178    ///
179    /// into:
180    ///
181    /// ```ignore (MIR)
182    /// bb0: {
183    ///    _0 = _1 as i16 (IntToInt);
184    ///    goto -> bb5;
185    /// }
186    /// ```
187    fn unify_by_int_to_int(
188        &mut self,
189        dest: Place<'tcx>,
190        consts: &[(u128, &ConstOperand<'tcx>)],
191    ) -> Option<StatementKind<'tcx>> {
192        let (_, first_const) = consts[0];
193        if !first_const.ty().is_integral() {
194            return None;
195        }
196        let discr_layout =
197            self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap();
198        if consts.iter().all(|&(case, const_)| {
199            let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env)
200            else {
201                return false;
202            };
203            can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int)
204        }) {
205            let operand = Operand::Copy(Place::from(self.discr_local()));
206            let rval = if first_const.ty() == self.discr_ty {
207                Rvalue::Use(operand, WithRetag::No)
208            } else {
209                Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty())
210            };
211            Some(StatementKind::Assign(Box::new((dest, rval))))
212        } else {
213            None
214        }
215    }
216
217    /// This is primarily used to unify these copy statements that simplified the canonical enum clone method by GVN.
218    /// The GVN simplified
219    /// ```ignore (syntax-highlighting-only)
220    /// match a {
221    ///     Foo::A(x) => Foo::A(*x),
222    ///     Foo::B => Foo::B
223    /// }
224    /// ```
225    /// to
226    /// ```ignore (syntax-highlighting-only)
227    /// match a {
228    ///     Foo::A(_x) => a, // copy a
229    ///     Foo::B => Foo::B
230    /// }
231    /// ```
232    /// This will simplify into a copy statement.
233    fn unify_by_copy(
234        &self,
235        dest: Place<'tcx>,
236        rvals: &[(u128, &Rvalue<'tcx>)],
237    ) -> Option<StatementKind<'tcx>> {
238        let bbs = &self.body.basic_blocks;
239        // Check if the copy source matches the following pattern.
240        // _2 = discriminant(*_1); // "*_1" is the expected the copy source.
241        // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
242        let &Statement {
243            kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))),
244            ..
245        } = bbs[self.switch_bb].statements.last()?
246        else {
247            return None;
248        };
249        if self.discr.place() != Some(discr_place) {
250            return None;
251        }
252        let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx);
253        if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
254            return None;
255        }
256        let dest_ty = dest.ty(self.body.local_decls(), self.tcx);
257        if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
258            return None;
259        }
260        let ty::Adt(def, _) = dest_ty.ty.kind() else {
261            return None;
262        };
263
264        for &(case, rvalue) in rvals.iter() {
265            match rvalue {
266                // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
267                Rvalue::Use(Operand::Constant(box constant), _)
268                    if let Const::Val(const_, ty) = constant.const_ =>
269                {
270                    let (ecx, op) = mk_eval_cx_for_const_val(
271                        self.tcx.at(constant.span),
272                        self.typing_env,
273                        const_,
274                        ty,
275                    )?;
276                    let variant = ecx.read_discriminant(&op).discard_err()?;
277                    if !def.variants()[variant].fields.is_empty() {
278                        return None;
279                    }
280                    let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?;
281                    if val != case {
282                        return None;
283                    }
284                }
285                Rvalue::Use(Operand::Copy(src_place), _) if *src_place == copy_src_place => {}
286                // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
287                Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
288                    if fields.is_empty()
289                        && let Some(Discr { val, .. }) =
290                            src_ty.ty.discriminant_for_variant(self.tcx, *variant_index)
291                        && val == case => {}
292                _ => return None,
293            }
294        }
295        // We didn't remember the `WithRetag` of the original assignments, so in case
296        // one of them had "no", we also have to use "no" here.
297        Some(StatementKind::Assign(Box::new((
298            dest,
299            Rvalue::Use(Operand::Copy(copy_src_place), WithRetag::No),
300        ))))
301    }
302
303    /// Returns a new statement if we can use the statement replace all statements.
304    fn try_unify_stmts(
305        &mut self,
306        index: usize,
307        stmts: &[(u128, &StatementKind<'tcx>)],
308        otherwise: Option<&StatementKind<'tcx>>,
309    ) -> Option<StatementKind<'tcx>> {
310        if let Some(new_stmt) = identical_stmts(stmts, otherwise) {
311            return Some(new_stmt);
312        }
313
314        let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?;
315        if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) {
316            if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) {
317                return Some(new_stmt);
318            }
319            if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) {
320                return Some(new_stmt);
321            }
322            // Requires the otherwise is unreachable.
323            if otherwise.is_none()
324                && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts)
325            {
326                return Some(new_stmt);
327            }
328        }
329
330        // We only know the first statement is safe to introduce new dereferences.
331        if index == 0
332            // We cannot create overlapping assignments.
333            && dest.is_stable_offset()
334            // Requires the otherwise is unreachable.
335            && otherwise.is_none()
336            && let Some(new_stmt) = self.unify_by_copy(dest, &rvals)
337        {
338            return Some(new_stmt);
339        }
340        None
341    }
342}
343
344/// Returns the first case target if all targets have an equal number of statements and identical destination.
345fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool {
346    use itertools::Itertools;
347    let targets = match &body.basic_blocks[switch_bb].terminator().kind {
348        TerminatorKind::SwitchInt {
349            discr: Operand::Copy(_) | Operand::Move(_), targets, ..
350        } => targets,
351        // Only optimize switch int statements
352        _ => return false,
353    };
354    // We require that the possible target blocks don't contain this block.
355    if targets.all_targets().contains(&switch_bb) {
356        return false;
357    }
358    // We require that the possible target blocks all be distinct.
359    if !targets.is_distinct() {
360        return false;
361    }
362    // Check that destinations are identical, and if not, then don't optimize this block
363    targets
364        .all_targets()
365        .iter()
366        .map(|&bb| &body.basic_blocks[bb])
367        .filter(|bb| !bb.is_empty_unreachable())
368        .map(|bb| (bb.statements.len(), &bb.terminator().kind))
369        .all_equal()
370}
371
/// Tries to replace the `SwitchInt` terminator of `switch_bb` with straight-line code.
///
/// `switch_bb` must end in a `SwitchInt`; any other terminator hits `unreachable!()`. Callers
/// are expected to have checked the block with `candidate_match`. Two outcomes modify `body`:
///
/// - If exactly one target is reachable, the switch is reduced by
///   `remove_successors_from_switch`, using "is an empty `unreachable` block" as the
///   predicate.
/// - Otherwise, for every statement index, the statements at that index in all reachable
///   targets are unified into a single statement by `SimplifyMatch::try_unify_stmts`. If every
///   index unifies, the new statements are added at the end of `switch_bb`. They are
///   bracketed by `StorageLive`/`StorageDead` of the discriminant temporary if one was
///   created. The switch terminator is then replaced by the first reachable target's
///   terminator.
///
/// Returns `true` if `body` was changed. When it returns `false`, the patch is dropped without
/// being applied, so `body` is left untouched.
fn simplify_match<'tcx>(
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    body: &mut Body<'tcx>,
    switch_bb: BasicBlock,
) -> bool {
    let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind {
        TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets),
        _ => unreachable!(),
    };
    let mut simplify_match = SimplifyMatch {
        tcx,
        typing_env,
        patch: MirPatch::new(body),
        body,
        switch_bb,
        discr,
        discr_local: None,
        discr_ty: discr.ty(body.local_decls(), tcx),
    };
    // Explicit cases whose target is not an empty `unreachable` block.
    let reachable_cases: Vec<_> =
        targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect();
    let mut new_stmts = Vec::new();
    // `None` when the otherwise target is an empty `unreachable` block.
    let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() {
        None
    } else {
        Some(targets.otherwise())
    };
    // We can patch the terminator to goto because there is a single target:
    // either one reachable case with an unreachable otherwise, or only the otherwise.
    match (reachable_cases.len(), otherwise.is_none()) {
        (1, true) | (0, false) => {
            let mut patch = simplify_match.patch;
            remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| {
                body.basic_blocks[bb].is_empty_unreachable()
            });
            patch.apply(body);
            return true;
        }
        _ => {}
    }
    // No reachable explicit case (and an unreachable otherwise): nothing to unify.
    let Some(&(_, first_case_bb)) = reachable_cases.first() else {
        return false;
    };
    let stmt_len = body.basic_blocks[first_case_bb].statements.len();
    // Reused buffer of `(case value, statement at `index`)` pairs.
    let mut cases = Vec::with_capacity(stmt_len);
    // Check at each position in the basic blocks whether these statements can be unified.
    // A failure at any position aborts without touching `body`.
    for index in 0..stmt_len {
        cases.clear();
        let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind);
        for &(case, bb) in &reachable_cases {
            cases.push((case, &body.basic_blocks[bb].statements[index].kind));
        }
        let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else {
            return false;
        };
        new_stmts.push(new_stmt);
    }
    // Take ownership of items now that we know we can optimize.
    let discr = discr.clone();

    let statement_index = body.basic_blocks[switch_bb].statements.len();
    let parent_end = Location { block: switch_bb, statement_index };
    let mut patch = simplify_match.patch;
    // If a strategy needed the discriminant's value, copy it into the temporary first...
    if let Some(discr_local) = simplify_match.discr_local {
        patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
        patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr, WithRetag::No));
    }
    for new_stmt in new_stmts {
        patch.add_statement(parent_end, new_stmt);
    }
    // ...and end the temporary's storage after the unified statements.
    if let Some(discr_local) = simplify_match.discr_local {
        patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
    }
    // Relies on the caller (via `candidate_match`) having checked that all reachable targets
    // share this terminator.
    patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone());
    patch.apply(body);
    true
}
449
450/// Check if the cast constant using `IntToInt` is equal to the target constant.
451fn can_cast(
452    tcx: TyCtxt<'_>,
453    src_val: impl Into<u128>,
454    src_layout: TyAndLayout<'_>,
455    cast_ty: Ty<'_>,
456    target_scalar: ScalarInt,
457) -> bool {
458    let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
459    let v = match src_layout.ty.kind() {
460        ty::Uint(_) => from_scalar.to_uint(src_layout.size),
461        ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
462        // We can also transform the values of other integer representations (such as char),
463        // although this may not be practical in real-world scenarios.
464        _ => return false,
465    };
466    let size = match *cast_ty.kind() {
467        ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
468        ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
469        _ => return false,
470    };
471    let v = size.truncate(v);
472    let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
473    cast_scalar == target_scalar
474}
475
476fn candidate_assign<'tcx, 'a>(
477    stmts: &'a [(u128, &'a StatementKind<'tcx>)],
478    otherwise: Option<&'a StatementKind<'tcx>>,
479) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> {
480    let (_, first_stmt) = stmts[0];
481    let (dest, _) = first_stmt.as_assign()?;
482    let otherwise = if let Some(otherwise) = otherwise {
483        let Some((otherwise_dest, rval)) = otherwise.as_assign() else {
484            return None;
485        };
486        if otherwise_dest != dest {
487            return None;
488        }
489        Some(rval)
490    } else {
491        None
492    };
493    let rvals = stmts
494        .into_iter()
495        .map(|&(case, stmt)| {
496            let (other_dest, rval) = stmt.as_assign()?;
497            if other_dest != dest {
498                return None;
499            }
500            Some((case, rval))
501        })
502        .try_collect()?;
503    Some((*dest, rvals, otherwise))
504}
505
506// Returns all ConstOperands if all Rvalues are ConstOperands.
507fn candidate_const<'tcx, 'a>(
508    rvals: &'a [(u128, &'a Rvalue<'tcx>)],
509    otherwise: Option<&'a Rvalue<'tcx>>,
510) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> {
511    // We ignore the retag mode here, which means the `Use` we insert later must be without retag.
512    let otherwise = if let Some(otherwise) = otherwise {
513        let Rvalue::Use(Operand::Constant(box const_), _) = otherwise else {
514            return None;
515        };
516        Some(const_)
517    } else {
518        None
519    };
520    let consts = rvals
521        .into_iter()
522        .map(|&(case, rval)| {
523            let Rvalue::Use(Operand::Constant(box const_), _) = rval else { return None };
524            Some((case, const_))
525        })
526        .try_collect()?;
527    Some((consts, otherwise))
528}
529
/// Splits `stmts` into three parts: the first case value, the first item, and an iterator.
///
/// The iterator yields the remaining items in order, followed by `otherwise` if it is
/// present.
///
/// # Panics
///
/// Panics if `stmts` is empty; callers always pass at least one reachable case.
fn split_first_case<'a, T>(
    stmts: &'a [(u128, &'a T)],
    otherwise: Option<&'a T>,
) -> (u128, &'a T, impl Iterator<Item = &'a T>) {
    let (&(first_case, first), rest) =
        stmts.split_first().expect("split_first_case requires at least one case");
    (first_case, first, rest.iter().map(|&(_, val)| val).chain(otherwise))
}
538
539// If all statements are identical, we can optimize.
540fn identical_stmts<'tcx>(
541    stmts: &[(u128, &StatementKind<'tcx>)],
542    otherwise: Option<&StatementKind<'tcx>>,
543) -> Option<StatementKind<'tcx>> {
544    use itertools::Itertools;
545    let (_, first_stmt, others) = split_first_case(stmts, otherwise);
546    if std::iter::once(first_stmt).chain(others).all_equal() {
547        return Some(first_stmt.clone());
548    }
549    None
550}