1use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{
66 Map, PlaceCollectionMode, PlaceIndex, TrackElem, ValueIndex,
67};
68use rustc_span::DUMMY_SP;
69use tracing::{debug, instrument, trace};
70
71use crate::cost_checker::CostChecker;
72
/// MIR pass that threads jumps: when the outcome of a `SwitchInt` is known along some incoming
/// path, that path is redirected straight to the corresponding target, duplicating intermediate
/// blocks when needed.
pub(super) struct JumpThreading;

/// Duplication budget. Conditions whose accumulated cost (computed in `remove_costly_conditions`)
/// reaches this value are not threaded; per-block costs that do not fit in a `u8` are clamped to
/// this value.
const MAX_COST: u8 = 100;
76
77impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
78 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
79 if sess.target.is_like_gpu {
80 return false;
86 }
87 sess.mir_opt_level() >= 2
88 }
89
90 #[instrument(skip_all level = "debug")]
91 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
92 let def_id = body.source.def_id();
93 debug!(?def_id);
94
95 if tcx.is_coroutine(def_id) {
97 trace!("Skipped for coroutine {:?}", def_id);
98 return;
99 }
100
101 let typing_env = body.typing_env(tcx);
102 let mut finder = TOFinder {
103 tcx,
104 typing_env,
105 ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
106 body,
107 map: Map::new(tcx, body, PlaceCollectionMode::OnDemand),
108 maybe_loop_headers: maybe_loop_headers(body),
109 entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
110 };
111
112 for (bb, bbdata) in traversal::postorder(body) {
113 if bbdata.is_cleanup {
114 continue;
115 }
116
117 let mut state = finder.populate_from_outgoing_edges(bb);
118 trace!("output_states[{bb:?}] = {state:?}");
119
120 finder.process_terminator(bb, &mut state);
121 trace!("pre_terminator_states[{bb:?}] = {state:?}");
122
123 for stmt in bbdata.statements.iter().rev() {
124 if state.is_empty() {
125 break;
126 }
127
128 finder.process_statement(stmt, &mut state);
129
130 if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
135 finder.flood_state(lhs, tail, &mut state);
136 }
137 }
138
139 trace!("entry_states[{bb:?}] = {state:?}");
140 finder.entry_states[bb] = state;
141 }
142
143 let mut entry_states = finder.entry_states;
144 simplify_conditions(body, &mut entry_states);
145 remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
146
147 if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
148 opportunities.apply();
149 }
150 }
151
152 fn is_required(&self) -> bool {
153 false
154 }
155}
156
157struct TOFinder<'a, 'tcx> {
158 tcx: TyCtxt<'tcx>,
159 typing_env: ty::TypingEnv<'tcx>,
160 ecx: InterpCx<'tcx, DummyMachine>,
161 body: &'a Body<'tcx>,
162 map: Map<'tcx>,
163 maybe_loop_headers: DenseBitSet<BasicBlock>,
164 entry_states: IndexVec<BasicBlock, ConditionSet>,
169}
170
// Index of a condition within one `ConditionSet` (an index into its `targets`). Each block's
// `ConditionSet` has its own index space, which is why `EdgeEffect::Chain` pairs a
// `succ_condition` with its `succ_block`.
rustc_index::newtype_index! {
    #[orderable]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
176
177#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
180struct Condition {
181 place: ValueIndex,
182 value: ScalarInt,
183 polarity: Polarity,
184}
185
/// Whether a condition requires its place to differ from, or to equal, its value.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    /// `place != value`.
    Ne,
    /// `place == value`.
    Eq,
}
191
192impl Condition {
193 fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
194 self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
195 }
196}
197
/// What to do with the current block's outgoing edges once the condition owning this effect is
/// known to hold.
///
/// NOTE: the derived `Ord` compares the variant first, in declaration order, so every `Goto`
/// sorts before every `Chain`. Both `simplify_conditions` and `OpportunitySet::apply_once` sort
/// effects and process them in ascending order, relying on a `Goto` being handled before any
/// `Chain`. Do not reorder the variants.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum EdgeEffect {
    /// Replace the terminator with an unconditional jump to `target`.
    Goto { target: BasicBlock },
    /// The condition holds on entry to `succ_block`, where it is tracked as `succ_condition`.
    /// The edge may be redirected to a copy of `succ_block` in which that condition is fulfilled.
    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
}
206
207impl EdgeEffect {
208 fn block(self) -> BasicBlock {
209 match self {
210 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
211 }
212 }
213
214 fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
215 match self {
216 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
217 if *bb == target {
218 *bb = new_target
219 }
220 }
221 }
222 }
223}
224
225#[derive(Clone, Debug, Default)]
226struct ConditionSet {
227 active: Vec<(ConditionIndex, Condition)>,
228 fulfilled: Vec<ConditionIndex>,
229 targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
230}
231
232impl ConditionSet {
233 fn is_empty(&self) -> bool {
234 self.active.is_empty()
235 }
236
237 #[tracing::instrument(level = "trace", skip(self))]
238 fn push_condition(&mut self, c: Condition, target: BasicBlock) {
239 let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
240 self.active.push((index, c));
241 }
242
243 fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
245 self.active.retain(|&(index, condition)| {
246 let targets = &self.targets[index];
247 if f(condition, targets) {
248 trace!(?index, ?condition, "fulfill");
249 self.fulfilled.push(index);
250 false
251 } else {
252 true
253 }
254 })
255 }
256
257 fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
259 self.fulfill_if(|c, _| c.matches(place, value))
260 }
261
262 fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
263 self.active.retain(|&(_, c)| f(c))
264 }
265
266 fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
267 self.active.retain_mut(|(_, c)| {
268 if let Some(new) = f(*c) {
269 *c = new;
270 true
271 } else {
272 false
273 }
274 })
275 }
276
277 fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
278 for (_, c) in &mut self.active {
279 f(c)
280 }
281 }
282}
283
impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    /// Registers `place` (optionally extended by `tail`) in the map; `None` if it is not tracked.
    fn place(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<PlaceIndex> {
        self.map.register_place(self.tcx, self.body, place, tail)
    }

    /// Registers a scalar value for the tracked place `place`; `None` if it is not tracked.
    fn value(&mut self, place: PlaceIndex) -> Option<ValueIndex> {
        self.map.register_value(self.tcx, self.typing_env, place)
    }

    /// Shorthand for `place` followed by `value`.
    fn place_value(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<ValueIndex> {
        let place = self.place(place, tail)?;
        self.value(place)
    }

    /// Builds the condition set at the end of `bb` from the entry states of its successors.
    ///
    /// Every distinct active condition of a (non-loop-header) successor becomes one condition of
    /// the result, whose targets are `Chain` effects to each successor where it appears.
    /// Identical conditions coming from different successors share a single index.
    #[instrument(level = "trace", skip(self))]
    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
        let bbdata = &self.body[bb];

        // Each block is processed once.
        debug_assert!(self.entry_states[bb].is_empty());

        // Upper bound on the number of conditions (duplicates and loop headers included).
        let state_len =
            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
        let mut state = ConditionSet {
            active: Vec::with_capacity(state_len),
            targets: IndexVec::with_capacity(state_len),
            fulfilled: Vec::new(),
        };

        // Deduplicates conditions: the position in this set is the `ConditionIndex` in `state`.
        let mut known_conditions =
            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
        let mut insert = |condition, succ_block, succ_condition| {
            let (index, new) = known_conditions.insert_full(condition);
            let index = ConditionIndex::from_usize(index);
            if new {
                state.active.push((index, condition));
                let _index = state.targets.push(Vec::new());
                debug_assert_eq!(_index, index);
            }
            let target = EdgeEffect::Chain { succ_block, succ_condition };
            debug_assert!(
                !state.targets[index].contains(&target),
                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
                &state.targets[index],
            );
            state.targets[index].push(target);
        };

        // A terminator may list the same successor several times; visit each one once.
        let mut seen = FxHashSet::default();
        for succ in bbdata.terminator().successors() {
            if !seen.insert(succ) {
                continue;
            }

            // Never chain into a block that may be a loop header.
            if self.maybe_loop_headers.contains(succ) {
                continue;
            }

            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
                insert(cond, succ, succ_index);
            }
        }

        let num_conditions = known_conditions.len();
        debug_assert_eq!(num_conditions, state.active.len());
        debug_assert_eq!(num_conditions, state.targets.len());
        state.fulfilled.reserve(num_conditions);

        state
    }

    /// Drops every active condition on a value that may alias `place` (extended by
    /// `extra_elem`): such conditions cannot be established by code before the write.
    fn flood_state(
        &self,
        place: Place<'tcx>,
        extra_elem: Option<TrackElem>,
        state: &mut ConditionSet,
    ) {
        if state.is_empty() {
            return;
        }
        let mut places_to_exclude = FxHashSet::default();
        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
            places_to_exclude.insert(vi);
        });
        trace!(?places_to_exclude, "flood_state");
        if places_to_exclude.is_empty() {
            return;
        }
        state.retain(|c| !places_to_exclude.contains(&c.place));
    }

    /// Returns the place (and optional projection tail) written by `stmt`, if any.
    ///
    /// Assignments write their destination, `SetDiscriminant` writes the discriminant of its
    /// place, and `StorageLive`/`StorageDead` (re)set the whole local. Every other statement kind
    /// returns `None`.
    ///
    /// NOTE(review): `CopyNonOverlapping` also returns `None` although it writes through a
    /// pointer; presumably this is sound because the map does not track places reachable through
    /// pointers — confirm.
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign(box (place, _)) => Some((place, None)),
            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
                Some((place, Some(TrackElem::Discriminant)))
            }
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }

    /// Fulfills the conditions on `lhs` when `rhs` is a known integer scalar.
    #[instrument(level = "trace", skip(self, state))]
    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
        if let Some(lhs) = self.value(lhs)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            state.fulfill_matches(lhs, int)
        }
    }

    /// Handles `lhs = constant`: fulfills conditions on `lhs` and its tracked projections.
    ///
    /// The first closure projects the constant along one tracking element (field, variant,
    /// discriminant, or length behind a pointer). The second closure reads each projected value
    /// as an integer scalar and fulfills the matching conditions. Any evaluation failure just
    /// skips that projection.
    #[instrument(level = "trace", skip(self, state))]
    fn process_constant(
        &mut self,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut ConditionSet,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            &mut |place, op| {
                if let Some(place) = self.map.value(place)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    state.fulfill_matches(place, int)
                }
            },
        );
    }

    /// Handles `lhs = rhs` between tracked places.
    ///
    /// A condition about a value under `lhs` is a condition about the corresponding value under
    /// `rhs` before the copy, so it is renamed using the `(lhs value, rhs value)` pairs reported
    /// by `register_copy_tree`.
    #[instrument(level = "trace", skip(self, state))]
    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
        let mut renames = FxHashMap::default();
        self.map.register_copy_tree(
            lhs,
            rhs,
            &mut |lhs, rhs| {
                renames.insert(lhs, rhs);
            },
        );
        state.for_each_mut(|c| {
            if let Some(rhs) = renames.get(&c.place) {
                c.place = *rhs
            }
        });
    }

    /// Handles `lhs = rhs` for an operand: constants are evaluated, places are treated as copies,
    /// and runtime checks are ignored.
    #[instrument(level = "trace", skip(self, state))]
    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
        match rhs {
            Operand::Constant(constant) => {
                let Some(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
                else {
                    return;
                };
                self.process_constant(lhs, constant, state);
            }
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.place(*rhs, None) else { return };
                self.process_copy(lhs, rhs, state)
            }
            Operand::RuntimeChecks(_) => {}
        }
    }

    /// Handles `lhs_place = rvalue`: fulfills conditions on `lhs_place`, or rewrites them into
    /// equivalent conditions on the rvalue's inputs.
    #[instrument(level = "trace", skip(self, state))]
    fn process_assign(
        &mut self,
        lhs_place: &Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut ConditionSet,
    ) {
        let Some(lhs) = self.place(*lhs_place, None) else { return };
        match rvalue {
            Rvalue::Use(operand, _) => self.process_operand(lhs, operand, state),
            // `lhs = discriminant(rhs)` behaves like a copy from the discriminant of `rhs`.
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.place(*rhs, Some(TrackElem::Discriminant)) else { return };
                self.process_copy(lhs, rhs, state)
            }
            // Aggregates: treat each field as a separate assignment.
            Rvalue::Aggregate(box kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // Union aggregates (active field given) are not handled.
                    AggregateKind::Adt(.., Some(_)) => return,
                    // Enum: the discriminant becomes known; fields live under the variant.
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        let discr_ty = agg_ty.discriminant_ty(self.tcx);
                        let discr_target =
                            self.map.register_place_index(discr_ty, lhs, TrackElem::Discriminant);
                        if let Some(discr_value) =
                            self.ecx.discriminant_for_variant(agg_ty, *variant_index).discard_err()
                        {
                            self.process_immediate(discr_target, discr_value, state);
                        }
                        self.map.register_place_index(
                            agg_ty,
                            lhs,
                            TrackElem::Variant(*variant_index),
                        )
                    }
                    _ => lhs,
                };
                for (field_index, operand) in operands.iter_enumerated() {
                    let operand_ty = operand.ty(self.body, self.tcx);
                    let field = self.map.register_place_index(
                        operand_ty,
                        lhs,
                        TrackElem::Field(field_index),
                    );
                    self.process_operand(field, operand, state);
                }
            }
            // `lhs = !operand`: `lhs == v` (or `!= v`) holds iff `operand == !v` (or `!= !v`).
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                // Conditions whose negated value cannot be computed are dropped.
                state.retain_mut(|mut c| {
                    if c.place == lhs {
                        let value = self
                            .ecx
                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
                            .discard_err()?
                            .to_scalar_int()
                            .discard_err()?;
                        c.place = operand;
                        c.value = value;
                    }
                    Some(c)
                });
            }
            // `lhs = operand == value` / `lhs = operand != value`, with `value` a constant.
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
            ) => {
                // `equals` is the value of the boolean `lhs` when `operand == value`.
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                // Float comparisons are skipped, presumably because float `==` is not bitwise
                // equality (NaN, +0.0/-0.0).
                if value.const_.ty().is_floating_point() {
                    return;
                }
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                // `lhs` is a bool: a condition satisfied by `lhs == equals` holds iff
                // `operand == value`; any other condition on `lhs` holds iff `operand != value`.
                state.for_each_mut(|c| {
                    if c.place == lhs {
                        let polarity =
                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
                        c.place = operand;
                        c.value = value;
                        c.polarity = polarity;
                    }
                });
            }

            _ => {}
        }
    }

    /// Applies the effect of `stmt` on the conditions that hold right after it.
    #[instrument(level = "trace", skip(self, state))]
    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
        match &stmt.kind {
            // The discriminant of `place` becomes the discriminant of `variant_index`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.place(*place, Some(TrackElem::Discriminant)) else {
                    return;
                };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(discr_target, discr, state)
            }
            // After `assume(place)`, `place` is `true`.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(place) = self.place_value(*place, None) else { return };
                state.fulfill_matches(place, ScalarInt::TRUE);
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(lhs_place, rhs, state)
            }
            _ => {}
        }
    }

    /// Applies the effect of `bb`'s terminator to `state` (the conditions at the end of `bb`).
    ///
    /// `SwitchInt` may fulfill conditions and creates new ones; inline assembly drops all active
    /// conditions; terminators that write a place (drop place, call destination, return place
    /// for tail calls) flood it.
    #[instrument(level = "trace", skip(self, state))]
    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // These are not expected here (`Yield` only appears in coroutines, which are skipped).
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Conservatively assume inline assembly may change anything.
            TerminatorKind::InlineAsm { .. } => {
                state.active.clear();
                return;
            }
            TerminatorKind::SwitchInt { ref discr, ref targets } => {
                return self.process_switch_int(discr, targets, state);
            }
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop
            | TerminatorKind::Assert { .. }
            | TerminatorKind::Goto { .. } => None,
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
        };

        if let Some(place_to_flood) = place_to_flood {
            self.flood_state(place_to_flood, None, state);
        }
    }

    /// Handles `switchInt(discr)`.
    ///
    /// First, when the targets are distinct, active conditions on `discr` that hold on some
    /// outgoing edges are fulfilled, restricted to the `Chain` effects along those edges.
    /// Then, new conditions on `discr` are created whose fulfillment resolves the switch
    /// (`Goto` effects).
    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        state: &mut ConditionSet,
    ) {
        let Some(discr) = discr.place() else { return };
        let Some(discr_idx) = self.place_value(discr, None) else { return };

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        // Target blocks are used below to identify edges, which presumably requires each target
        // block (otherwise included) to appear only once — hence the `is_distinct` guard.
        if targets.is_distinct() {
            for &(index, c) in state.active.iter() {
                if c.place != discr_idx {
                    continue;
                }

                // Blocks `t` such that `c` holds on the edge `bb -> t`.
                let mut edges_fulfilling_condition = FxHashSet::default();

                // On the edge for `branch`, `discr == branch`.
                for (branch, tgt) in targets.iter() {
                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
                        && c.matches(discr_idx, branch)
                    {
                        edges_fulfilling_condition.insert(tgt);
                    }
                }

                // On the `otherwise` edge we only know that `discr` differs from every listed
                // value, so only `discr != v` with `v` among those values holds there.
                if c.polarity == Polarity::Ne
                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
                    && targets.all_values().contains(&value.into())
                {
                    edges_fulfilling_condition.insert(targets.otherwise());
                }

                let condition_targets = &state.targets[index];

                // Effects of `c` along edges where `c` actually holds.
                let new_edges: Vec<_> = condition_targets
                    .iter()
                    .copied()
                    .filter(|&target| match target {
                        EdgeEffect::Goto { .. } => false,
                        EdgeEffect::Chain { succ_block, .. } => {
                            edges_fulfilling_condition.contains(&succ_block)
                        }
                    })
                    .collect();

                if new_edges.len() == condition_targets.len() {
                    // `c` holds along all of its edges: reuse its index.
                    state.fulfilled.push(index);
                } else {
                    // Marking `index` fulfilled would also apply effects along edges where `c`
                    // does not hold, so fulfill a fresh index carrying only the valid effects.
                    let index = state.targets.push(new_edges);
                    state.fulfilled.push(index);
                }
            }
        }

        // New conditions at the end of `bb`: knowing one of them resolves the switch.
        let mut mk_condition = |value, polarity, target| {
            let c = Condition { place: discr_idx, value, polarity };
            state.push_condition(c, target);
        };
        if let Some((value, then_, else_)) = targets.as_static_if() {
            // Two-way switch: both `discr == value` and `discr != value` are useful.
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            mk_condition(value, Polarity::Eq, then_);
            mk_condition(value, Polarity::Ne, else_);
        } else {
            // Only equalities can select a single non-`otherwise` target.
            for (value, target) in targets.iter() {
                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
                    mk_condition(value, Polarity::Eq, target);
                }
            }
        }
    }
}
782
/// Marks as fulfilled the conditions that hold on every incoming edge of their block.
///
/// Condition `c` of block `bb` becomes fulfilled when, for each incoming edge (counted with
/// multiplicity, with the function entry counting as one extra edge into `START_BLOCK`), the
/// source block has a fulfilled condition that chains to `(bb, c)`. When a fulfilled `Goto`
/// collapses a block's terminator to a single edge, the predecessor counts are updated
/// accordingly, mirroring what `OpportunitySet::apply_once` does later.
#[instrument(level = "debug", skip(body, entry_states))]
fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
    let basic_blocks = &body.basic_blocks;
    let reverse_postorder = basic_blocks.reverse_postorder();

    // Number of incoming edges of each block, counting only edges out of reachable blocks.
    let mut predecessors = IndexVec::from_elem(0, &entry_states);
    predecessors[START_BLOCK] = 1; // Function entry.
    for &bb in reverse_postorder {
        let term = basic_blocks[bb].terminator();
        for s in term.successors() {
            predecessors[s] += 1;
        }
    }

    // `fulfill_in_pred_count[bb][c]`: number of incoming edges of `bb` along which condition `c`
    // of `bb` is known to hold.
    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    // Reverse postorder: all predecessors reached through non-back edges come before `bb`.
    for &bb in reverse_postorder {
        let preds = predecessors[bb];
        trace!(?bb, ?preds);

        // Every incoming edge was removed by a collapsed terminator above.
        if preds == 0 {
            continue;
        }

        let state = &mut entry_states[bb];
        trace!(?state);

        // Conditions that hold on all incoming edges are fulfilled on entry.
        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
            if cond_preds == preds {
                trace!(?condition);
                state.fulfilled.push(condition);
            }
        }

        // Effects of every fulfilled condition of `bb`, deduplicated.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| state.targets[index].iter().copied())
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Successor edges of `bb`, as they will be after the effects are applied.
        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();

        // Reversed so that `pop` yields effects in ascending order: every `Goto` first.
        targets.reverse();
        while let Some(target) = targets.pop() {
            match target {
                EdgeEffect::Goto { target } => {
                    // The terminator becomes a single edge to `target`.
                    predecessors[target] += 1;
                    for &s in successors.iter() {
                        predecessors[s] -= 1;
                    }
                    // Only effects towards `target` remain meaningful.
                    targets.retain(|t| t.block() == target);
                    successors.clear();
                    successors.push(target);
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    // The condition holds on every remaining edge from `bb` to `succ_block`.
                    let count = successors.iter().filter(|&&s| s == succ_block).count();
                    fulfill_in_pred_count[succ_block][succ_condition] += count;
                }
            }
        }
    }
}
867
/// Clears the targets of conditions whose threading would duplicate too much code.
///
/// The cost of a condition is the sum, over its `Chain` effects whose successor condition is not
/// already fulfilled, of the successor block's cost (as measured by `CostChecker`) plus the
/// successor condition's own cost. `Goto` effects and chains into already-fulfilled conditions
/// cost nothing. Costs saturate at `u8::MAX`; conditions of reachable blocks whose cost reaches
/// `MAX_COST` lose all their targets.
#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
fn remove_costly_conditions<'tcx>(
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    body: &Body<'tcx>,
    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
) {
    let basic_blocks = &body.basic_blocks;

    // Memoized per-block cost; costs that do not fit in a `u8` are clamped to `MAX_COST`.
    let mut costs = IndexVec::from_elem(None, basic_blocks);
    let mut cost = |bb: BasicBlock| -> u8 {
        let c = *costs[bb].get_or_insert_with(|| {
            let bbdata = &basic_blocks[bb];
            let mut cost = CostChecker::new(tcx, typing_env, None, body);
            cost.visit_basic_block_data(bb, bbdata);
            cost.cost().try_into().unwrap_or(MAX_COST)
        });
        trace!("cost[{bb:?}] = {c}");
        c
    };

    // `condition_cost[bb][c]`: cost of threading condition `c` of `bb`. Blocks not visited below
    // (unreachable ones) keep the initial `MAX_COST`.
    let mut condition_cost = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    let reverse_postorder = basic_blocks.reverse_postorder();

    // Postorder, so that successor condition costs are usually available when needed.
    for &bb in reverse_postorder.iter().rev() {
        let state = &entry_states[bb];
        trace!(?bb, ?state);

        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);

        for (condition, targets) in state.targets.iter_enumerated() {
            for &target in targets {
                match target {
                    // Rewriting the terminator duplicates nothing.
                    EdgeEffect::Goto { .. } => {}
                    // Already known on entry to `succ_block`: no copy will be made.
                    EdgeEffect::Chain { succ_block, succ_condition }
                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
                    EdgeEffect::Chain { succ_block, succ_condition } => {
                        // `succ_block` would be copied...
                        let duplication_cost = cost(succ_block);
                        // ...and the successor condition threaded in turn.
                        let target_cost =
                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
                        let cost = current_costs[condition]
                            .saturating_add(duplication_cost)
                            .saturating_add(target_cost);
                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
                        current_costs[condition] = cost;
                    }
                }
            }
        }

        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
        condition_cost[bb] = current_costs;
    }

    trace!(?condition_cost);

    for &bb in reverse_postorder {
        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
            if condition_cost[bb][index] >= MAX_COST {
                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
                targets.clear()
            }
        }
    }
}
943
944struct OpportunitySet<'a, 'tcx> {
945 basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
946 entry_states: IndexVec<BasicBlock, ConditionSet>,
947 duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
950}
951
952impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
953 fn new(
954 body: &'a mut Body<'tcx>,
955 mut entry_states: IndexVec<BasicBlock, ConditionSet>,
956 ) -> Option<OpportunitySet<'a, 'tcx>> {
957 trace!(def_id = ?body.source.def_id(), "apply");
958
959 if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
960 return None;
961 }
962
963 for state in entry_states.iter_mut() {
965 state.active = Default::default();
966 }
967 let duplicates = Default::default();
968 let basic_blocks = body.basic_blocks.as_mut();
969 Some(OpportunitySet { basic_blocks, entry_states, duplicates })
970 }
971
972 #[instrument(level = "debug", skip(self))]
974 fn apply(mut self) {
975 let mut worklist = Vec::with_capacity(self.basic_blocks.len());
976 worklist.push(START_BLOCK);
977
978 let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());
980
981 while let Some(bb) = worklist.pop() {
982 if !visited.insert(bb) {
983 continue;
984 }
985
986 self.apply_once(bb);
987
988 worklist.extend(self.basic_blocks[bb].terminator().successors());
991 }
992 }
993
994 #[instrument(level = "debug", skip(self))]
996 fn apply_once(&mut self, bb: BasicBlock) {
997 let state = &mut self.entry_states[bb];
998 trace!(?state);
999
1000 let mut targets: Vec<_> = state
1003 .fulfilled
1004 .iter()
1005 .flat_map(|&index| std::mem::take(&mut state.targets[index]))
1006 .collect();
1007 targets.sort();
1008 targets.dedup();
1009 trace!(?targets);
1010
1011 targets.reverse();
1013 while let Some(target) = targets.pop() {
1014 debug!(?target);
1015 trace!(term = ?self.basic_blocks[bb].terminator().kind);
1016
1017 debug_assert!(
1021 self.basic_blocks[bb].terminator().successors().contains(&target.block()),
1022 "missing {target:?} in successors for {bb:?}, term={:?}",
1023 self.basic_blocks[bb].terminator(),
1024 );
1025
1026 match target {
1027 EdgeEffect::Goto { target } => {
1028 self.apply_goto(bb, target);
1029
1030 targets.retain(|t| t.block() == target);
1032 for ts in self.entry_states[bb].targets.iter_mut() {
1034 ts.retain(|t| t.block() == target);
1035 }
1036 }
1037 EdgeEffect::Chain { succ_block, succ_condition } => {
1038 let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);
1039
1040 if let Some(new_succ_block) = new_succ_block {
1042 for t in targets.iter_mut() {
1043 t.replace_block(succ_block, new_succ_block)
1044 }
1045 for t in
1047 self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
1048 {
1049 t.replace_block(succ_block, new_succ_block)
1050 }
1051 }
1052 }
1053 }
1054
1055 trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
1056 }
1057 }
1058
1059 #[instrument(level = "debug", skip(self))]
1060 fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
1061 self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
1062 }
1063
1064 #[instrument(level = "debug", skip(self), ret)]
1065 fn apply_chain(
1066 &mut self,
1067 bb: BasicBlock,
1068 target: BasicBlock,
1069 condition: ConditionIndex,
1070 ) -> Option<BasicBlock> {
1071 if self.entry_states[target].fulfilled.contains(&condition) {
1072 trace!("fulfilled");
1074 return None;
1075 }
1076
1077 let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
1083 let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
1086 trace!(?target, ?new_target, ?condition, "clone");
1087
1088 let mut condition_set = self.entry_states[target].clone();
1091 condition_set.fulfilled.push(condition);
1092 let _new_target = self.entry_states.push(condition_set);
1093 debug_assert_eq!(new_target, _new_target);
1094
1095 new_target
1096 });
1097 trace!(?target, ?new_target, ?condition, "reuse");
1098
1099 self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
1102 if *s == target {
1103 *s = new_target;
1104 }
1105 });
1106
1107 Some(new_target)
1108 }
1109}
1110
1111fn maybe_loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
1117 let mut maybe_loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
1118 let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
1119 for (bb, bbdata) in traversal::postorder(body) {
1120 for succ in bbdata.terminator().successors() {
1123 if !visited.contains(succ) {
1124 maybe_loop_headers.insert(succ);
1125 }
1126 }
1127
1128 let _new = visited.insert(bb);
1131 debug_assert!(_new);
1132 }
1133
1134 maybe_loop_headers
1135}