1use rustc_abi::Integer;
2use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
3use rustc_middle::mir::*;
4use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
5use rustc_middle::ty::util::Discr;
6use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
7
8use super::simplify::simplify_cfg;
9use crate::patch::MirPatch;
10use crate::unreachable_prop::remove_successors_from_switch;
11
12pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17 sess.mir_opt_level() >= 2
19 }
20
21 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22 let typing_env = body.typing_env(tcx);
23 let mut changed = false;
24 for bb in body.basic_blocks.indices() {
25 if !candidate_match(body, bb) {
26 continue;
27 };
28 changed |= simplify_match(tcx, typing_env, body, bb)
29 }
30
31 if changed {
32 simplify_cfg(tcx, body);
33 }
34 }
35
36 fn is_required(&self) -> bool {
37 false
38 }
39}
40
41struct SimplifyMatch<'tcx, 'a> {
42 tcx: TyCtxt<'tcx>,
43 typing_env: ty::TypingEnv<'tcx>,
44 patch: MirPatch<'tcx>,
45 body: &'a Body<'tcx>,
46 switch_bb: BasicBlock,
47 discr: &'a Operand<'tcx>,
48 discr_local: Option<Local>,
49 discr_ty: Ty<'tcx>,
50}
51
52impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> {
53 fn discr_local(&mut self) -> Local {
54 *self.discr_local.get_or_insert_with(|| {
55 let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info;
57 self.patch.new_temp(self.discr_ty, source_info.span)
58 })
59 }
60
61 fn unify_if_equal_const(
63 &self,
64 dest: Place<'tcx>,
65 consts: &[(u128, &ConstOperand<'tcx>)],
66 otherwise: Option<&ConstOperand<'tcx>>,
67 ) -> Option<StatementKind<'tcx>> {
68 let (_, first_const, mut others) = split_first_case(consts, otherwise);
69 let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?;
70 if others.all(|const_| {
71 const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int)
72 }) {
73 Some(StatementKind::Assign(Box::new((
74 dest,
75 Rvalue::Use(Operand::Constant(Box::new(first_const.clone())), WithRetag::No),
78 ))))
79 } else {
80 None
81 }
82 }
83
84 fn unify_by_eq_op(
116 &mut self,
117 dest: Place<'tcx>,
118 consts: &[(u128, &ConstOperand<'tcx>)],
119 otherwise: Option<&ConstOperand<'tcx>>,
120 ) -> Option<StatementKind<'tcx>> {
121 let (first_case, first_const, mut others) = split_first_case(consts, otherwise);
123 if !first_const.ty().is_bool() {
124 return None;
125 }
126 let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?;
127 if others.all(|const_| {
128 const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool)
129 }) {
130 let size =
132 self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size;
133 let const_cmp = Operand::const_from_scalar(
134 self.tcx,
135 self.discr_ty,
136 rustc_const_eval::interpret::Scalar::from_uint(first_case, size),
137 rustc_span::DUMMY_SP,
138 );
139 let op = if first_bool { BinOp::Eq } else { BinOp::Ne };
140 let rval = Rvalue::BinaryOp(
141 op,
142 Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)),
143 );
144 Some(StatementKind::Assign(Box::new((dest, rval))))
145 } else {
146 None
147 }
148 }
149
150 fn unify_by_int_to_int(
188 &mut self,
189 dest: Place<'tcx>,
190 consts: &[(u128, &ConstOperand<'tcx>)],
191 ) -> Option<StatementKind<'tcx>> {
192 let (_, first_const) = consts[0];
193 if !first_const.ty().is_integral() {
194 return None;
195 }
196 let discr_layout =
197 self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap();
198 if consts.iter().all(|&(case, const_)| {
199 let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env)
200 else {
201 return false;
202 };
203 can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int)
204 }) {
205 let operand = Operand::Copy(Place::from(self.discr_local()));
206 let rval = if first_const.ty() == self.discr_ty {
207 Rvalue::Use(operand, WithRetag::No)
208 } else {
209 Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty())
210 };
211 Some(StatementKind::Assign(Box::new((dest, rval))))
212 } else {
213 None
214 }
215 }
216
217 fn unify_by_copy(
234 &self,
235 dest: Place<'tcx>,
236 rvals: &[(u128, &Rvalue<'tcx>)],
237 ) -> Option<StatementKind<'tcx>> {
238 let bbs = &self.body.basic_blocks;
239 let &Statement {
243 kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))),
244 ..
245 } = bbs[self.switch_bb].statements.last()?
246 else {
247 return None;
248 };
249 if self.discr.place() != Some(discr_place) {
250 return None;
251 }
252 let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx);
253 if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
254 return None;
255 }
256 let dest_ty = dest.ty(self.body.local_decls(), self.tcx);
257 if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
258 return None;
259 }
260 let ty::Adt(def, _) = dest_ty.ty.kind() else {
261 return None;
262 };
263
264 for &(case, rvalue) in rvals.iter() {
265 match rvalue {
266 Rvalue::Use(Operand::Constant(box constant), _)
268 if let Const::Val(const_, ty) = constant.const_ =>
269 {
270 let (ecx, op) = mk_eval_cx_for_const_val(
271 self.tcx.at(constant.span),
272 self.typing_env,
273 const_,
274 ty,
275 )?;
276 let variant = ecx.read_discriminant(&op).discard_err()?;
277 if !def.variants()[variant].fields.is_empty() {
278 return None;
279 }
280 let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?;
281 if val != case {
282 return None;
283 }
284 }
285 Rvalue::Use(Operand::Copy(src_place), _) if *src_place == copy_src_place => {}
286 Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
288 if fields.is_empty()
289 && let Some(Discr { val, .. }) =
290 src_ty.ty.discriminant_for_variant(self.tcx, *variant_index)
291 && val == case => {}
292 _ => return None,
293 }
294 }
295 Some(StatementKind::Assign(Box::new((
298 dest,
299 Rvalue::Use(Operand::Copy(copy_src_place), WithRetag::No),
300 ))))
301 }
302
303 fn try_unify_stmts(
305 &mut self,
306 index: usize,
307 stmts: &[(u128, &StatementKind<'tcx>)],
308 otherwise: Option<&StatementKind<'tcx>>,
309 ) -> Option<StatementKind<'tcx>> {
310 if let Some(new_stmt) = identical_stmts(stmts, otherwise) {
311 return Some(new_stmt);
312 }
313
314 let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?;
315 if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) {
316 if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) {
317 return Some(new_stmt);
318 }
319 if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) {
320 return Some(new_stmt);
321 }
322 if otherwise.is_none()
324 && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts)
325 {
326 return Some(new_stmt);
327 }
328 }
329
330 if index == 0
332 && dest.is_stable_offset()
334 && otherwise.is_none()
336 && let Some(new_stmt) = self.unify_by_copy(dest, &rvals)
337 {
338 return Some(new_stmt);
339 }
340 None
341 }
342}
343
344fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool {
346 use itertools::Itertools;
347 let targets = match &body.basic_blocks[switch_bb].terminator().kind {
348 TerminatorKind::SwitchInt {
349 discr: Operand::Copy(_) | Operand::Move(_), targets, ..
350 } => targets,
351 _ => return false,
353 };
354 if targets.all_targets().contains(&switch_bb) {
356 return false;
357 }
358 if !targets.is_distinct() {
360 return false;
361 }
362 targets
364 .all_targets()
365 .iter()
366 .map(|&bb| &body.basic_blocks[bb])
367 .filter(|bb| !bb.is_empty_unreachable())
368 .map(|bb| (bb.statements.len(), &bb.terminator().kind))
369 .all_equal()
370}
371
372fn simplify_match<'tcx>(
373 tcx: TyCtxt<'tcx>,
374 typing_env: ty::TypingEnv<'tcx>,
375 body: &mut Body<'tcx>,
376 switch_bb: BasicBlock,
377) -> bool {
378 let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind {
379 TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets),
380 _ => unreachable!(),
381 };
382 let mut simplify_match = SimplifyMatch {
383 tcx,
384 typing_env,
385 patch: MirPatch::new(body),
386 body,
387 switch_bb,
388 discr,
389 discr_local: None,
390 discr_ty: discr.ty(body.local_decls(), tcx),
391 };
392 let reachable_cases: Vec<_> =
393 targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect();
394 let mut new_stmts = Vec::new();
395 let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() {
396 None
397 } else {
398 Some(targets.otherwise())
399 };
400 match (reachable_cases.len(), otherwise.is_none()) {
402 (1, true) | (0, false) => {
403 let mut patch = simplify_match.patch;
404 remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| {
405 body.basic_blocks[bb].is_empty_unreachable()
406 });
407 patch.apply(body);
408 return true;
409 }
410 _ => {}
411 }
412 let Some(&(_, first_case_bb)) = reachable_cases.first() else {
413 return false;
414 };
415 let stmt_len = body.basic_blocks[first_case_bb].statements.len();
416 let mut cases = Vec::with_capacity(stmt_len);
417 for index in 0..stmt_len {
419 cases.clear();
420 let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind);
421 for &(case, bb) in &reachable_cases {
422 cases.push((case, &body.basic_blocks[bb].statements[index].kind));
423 }
424 let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else {
425 return false;
426 };
427 new_stmts.push(new_stmt);
428 }
429 let discr = discr.clone();
431
432 let statement_index = body.basic_blocks[switch_bb].statements.len();
433 let parent_end = Location { block: switch_bb, statement_index };
434 let mut patch = simplify_match.patch;
435 if let Some(discr_local) = simplify_match.discr_local {
436 patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
437 patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr, WithRetag::No));
438 }
439 for new_stmt in new_stmts {
440 patch.add_statement(parent_end, new_stmt);
441 }
442 if let Some(discr_local) = simplify_match.discr_local {
443 patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
444 }
445 patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone());
446 patch.apply(body);
447 true
448}
449
450fn can_cast(
452 tcx: TyCtxt<'_>,
453 src_val: impl Into<u128>,
454 src_layout: TyAndLayout<'_>,
455 cast_ty: Ty<'_>,
456 target_scalar: ScalarInt,
457) -> bool {
458 let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
459 let v = match src_layout.ty.kind() {
460 ty::Uint(_) => from_scalar.to_uint(src_layout.size),
461 ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
462 _ => return false,
465 };
466 let size = match *cast_ty.kind() {
467 ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
468 ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
469 _ => return false,
470 };
471 let v = size.truncate(v);
472 let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
473 cast_scalar == target_scalar
474}
475
476fn candidate_assign<'tcx, 'a>(
477 stmts: &'a [(u128, &'a StatementKind<'tcx>)],
478 otherwise: Option<&'a StatementKind<'tcx>>,
479) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> {
480 let (_, first_stmt) = stmts[0];
481 let (dest, _) = first_stmt.as_assign()?;
482 let otherwise = if let Some(otherwise) = otherwise {
483 let Some((otherwise_dest, rval)) = otherwise.as_assign() else {
484 return None;
485 };
486 if otherwise_dest != dest {
487 return None;
488 }
489 Some(rval)
490 } else {
491 None
492 };
493 let rvals = stmts
494 .into_iter()
495 .map(|&(case, stmt)| {
496 let (other_dest, rval) = stmt.as_assign()?;
497 if other_dest != dest {
498 return None;
499 }
500 Some((case, rval))
501 })
502 .try_collect()?;
503 Some((*dest, rvals, otherwise))
504}
505
506fn candidate_const<'tcx, 'a>(
508 rvals: &'a [(u128, &'a Rvalue<'tcx>)],
509 otherwise: Option<&'a Rvalue<'tcx>>,
510) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> {
511 let otherwise = if let Some(otherwise) = otherwise {
513 let Rvalue::Use(Operand::Constant(box const_), _) = otherwise else {
514 return None;
515 };
516 Some(const_)
517 } else {
518 None
519 };
520 let consts = rvals
521 .into_iter()
522 .map(|&(case, rval)| {
523 let Rvalue::Use(Operand::Constant(box const_), _) = rval else { return None };
524 Some((case, const_))
525 })
526 .try_collect()?;
527 Some((consts, otherwise))
528}
529
/// Splits off the first case.
///
/// Returns the first case value, the first item, and an iterator over the items of
/// the remaining cases followed by `otherwise` (when present). Panics if `stmts` is
/// empty.
fn split_first_case<'a, T>(
    stmts: &'a [(u128, &'a T)],
    otherwise: Option<&'a T>,
) -> (u128, &'a T, impl Iterator<Item = &'a T>) {
    let (first_case, first) = stmts[0];
    let rest = stmts.iter().skip(1).map(|&(_, item)| item);
    (first_case, first, rest.chain(otherwise))
}
538
539fn identical_stmts<'tcx>(
541 stmts: &[(u128, &StatementKind<'tcx>)],
542 otherwise: Option<&StatementKind<'tcx>>,
543) -> Option<StatementKind<'tcx>> {
544 use itertools::Itertools;
545 let (_, first_stmt, others) = split_first_case(stmts, otherwise);
546 if std::iter::once(first_stmt).chain(others).all_equal() {
547 return Some(first_stmt.clone());
548 }
549 None
550}