Skip to main content

rustc_mir_transform/
jump_threading.rs

1//! A jump threading optimization.
2//!
3//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps
4//!    X = 0                                      X = 0
5//! ------------\      /--------              ------------
6//!    X = 1     X----X SwitchInt(X)     =>       X = 1
7//! ------------/      \--------              ------------
8//!
9//!
10//! This implementation is heavily inspired by the work outlined in [libfirm].
11//!
12//! The general algorithm proceeds in two phases: (1) walk the CFG backwards to construct a
13//! graph of threading conditions, and (2) propagate fulfilled conditions forward by duplicating
14//! blocks.
15//!
16//! # 1. Condition graph construction
17//!
18//! In this file, we denote as `place ?= value` the existence of a replacement condition
19//! on `place` with given `value`, irrespective of the polarity and target of that
20//! replacement condition.
21//!
22//! Inside a block, we associate with each condition `c` a set of targets:
23//! - `Goto(target)` if fulfilling `c` changes the terminator into a `Goto { target }`;
24//! - `Chain(target, c2)` if fulfilling `c` means that `c2` is fulfilled inside `target`.
25//!
//! Before walking a block `bb`, we construct the exit set of conditions from its successors.
//! For each condition `c` in a successor `s`, we record that fulfilling `c` in `bb` will fulfill
//! `c` in `s`, as a `Chain(s, c)` target.
29//!
//! When encountering a `switchInt(place) -> [value: bb...]` terminator, we also record a
//! `place == value` condition for each `value`, and associate a `Goto(target)` target with it.
32//!
33//! Then, we walk the statements backwards, transforming the set of conditions along the way,
34//! resulting in a set of conditions at the block entry.
35//!
36//! We try to avoid creating irreducible control-flow by not threading through a loop header.
37//!
38//! Applying the optimisation can create a lot of new MIR, so we bound the instruction
39//! cost by `MAX_COST`.
40//!
41//! # 2. Block duplication
42//!
43//! We now have the set of fulfilled conditions inside each block and their targets.
44//!
45//! For each block `bb` in reverse postorder, we apply in turn the target associated with each
46//! fulfilled condition:
47//! - for `Goto(target)`, change the terminator of `bb` into a `Goto { target }`;
48//! - for `Chain(target, cond)`, duplicate `target` into a new block which fulfills the same
49//! conditions and also fulfills `cond`. This is made efficient by maintaining a map of duplicates,
50//! `duplicate[(target, cond)]` to avoid cloning blocks multiple times.
51//!
52//! [libfirm]: <https://pp.ipd.kit.edu/uploads/publikationen/priesner17masterarbeit.pdf>
53
54use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{
66    Map, PlaceCollectionMode, PlaceIndex, TrackElem, ValueIndex,
67};
68use rustc_span::DUMMY_SP;
69use tracing::{debug, instrument, trace};
70
71use crate::cost_checker::CostChecker;
72
/// The jump threading MIR pass described in the module documentation.
pub(super) struct JumpThreading;
74
/// Bound on the instruction cost of the MIR this optimization may create
/// (see "Condition graph construction" in the module documentation).
// NOTE(review): the place where this bound is enforced is not in this part of the file;
// presumably `remove_costly_conditions` compares against it — confirm.
const MAX_COST: u8 = 100;
76
77impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
78    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
79        if sess.target.is_like_gpu {
80            // Jump threading can duplicate calls in control-flow.
81            // This leads to incorrect code when done for so called "convergent" operations on GPU
82            // targets, similar to how inline assembly cannot be duplicated on all targets.
83            // Conservatively prevent this by disabling the pass.
84            // See also issue #137086.
85            return false;
86        }
87        sess.mir_opt_level() >= 2
88    }
89
90    #[instrument(skip_all level = "debug")]
91    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
92        let def_id = body.source.def_id();
93        debug!(?def_id);
94
95        // Optimizing coroutines creates query cycles.
96        if tcx.is_coroutine(def_id) {
97            trace!("Skipped for coroutine {:?}", def_id);
98            return;
99        }
100
101        let typing_env = body.typing_env(tcx);
102        let mut finder = TOFinder {
103            tcx,
104            typing_env,
105            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
106            body,
107            map: Map::new(tcx, body, PlaceCollectionMode::OnDemand),
108            maybe_loop_headers: maybe_loop_headers(body),
109            entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
110        };
111
112        for (bb, bbdata) in traversal::postorder(body) {
113            if bbdata.is_cleanup {
114                continue;
115            }
116
117            let mut state = finder.populate_from_outgoing_edges(bb);
118            trace!("output_states[{bb:?}] = {state:?}");
119
120            finder.process_terminator(bb, &mut state);
121            trace!("pre_terminator_states[{bb:?}] = {state:?}");
122
123            for stmt in bbdata.statements.iter().rev() {
124                if state.is_empty() {
125                    break;
126                }
127
128                finder.process_statement(stmt, &mut state);
129
130                // When a statement mutates a place, assignments to that place that happen
131                // above the mutation cannot fulfill a condition.
132                //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
133                //   _1 = 6
134                if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
135                    finder.flood_state(lhs, tail, &mut state);
136                }
137            }
138
139            trace!("entry_states[{bb:?}] = {state:?}");
140            finder.entry_states[bb] = state;
141        }
142
143        let mut entry_states = finder.entry_states;
144        simplify_conditions(body, &mut entry_states);
145        remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
146
147        if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
148            opportunities.apply();
149        }
150    }
151
152    fn is_required(&self) -> bool {
153        false
154    }
155}
156
/// State of the backward walk over the CFG (the first phase described in the module
/// documentation), which computes the set of threading conditions at the entry of each block.
struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    /// Interpreter context, used to evaluate constants, discriminants, layouts and unary
    /// operations while processing statements.
    ecx: InterpCx<'tcx, DummyMachine>,
    /// The MIR body being analysed.
    body: &'a Body<'tcx>,
    /// Registry of the places and values that conditions refer to. It is created in
    /// `PlaceCollectionMode::OnDemand` mode and filled as the walk registers places.
    map: Map<'tcx>,
    /// Blocks flagged by `maybe_loop_headers`. Conditions of such a successor are not
    /// propagated to its predecessors.
    maybe_loop_headers: DenseBitSet<BasicBlock>,
    /// This stores the state of each visited block on entry,
    /// and the current state of the block being visited.
    // Invariant: for each `bb`, each condition in `entry_states[bb]` has a `chain` that
    // starts with `bb`.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
}
170
// Index of a condition inside one `ConditionSet`: it is the key of `ConditionSet::targets`,
// and it is how `EdgeEffect::Chain` names a condition of a successor block.
rustc_index::newtype_index! {
    #[orderable]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
176
/// Represent the following statement: `place` is equal to `value` (for `Polarity::Eq`), or
/// `place` is different from `value` (for `Polarity::Ne`).
///
/// What happens once the statement is proven is not stored here: `ConditionSet::targets`
/// associates a list of `EdgeEffect`s with the index of each condition.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
struct Condition {
    /// The tracked value that the condition is about.
    place: ValueIndex,
    /// The integer that `place` is compared with.
    value: ScalarInt,
    /// Whether the comparison is an equality or an inequality.
    polarity: Polarity,
}
185
/// Kind of comparison expressed by a `Condition`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    /// The condition holds when the place is different from the condition's value.
    Ne,
    /// The condition holds when the place is equal to the condition's value.
    Eq,
}
191
192impl Condition {
193    fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
194        self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
195    }
196}
197
/// Represent the effect of fulfilling a condition.
///
/// Each condition of a `ConditionSet` owns a list of such effects, stored in
/// `ConditionSet::targets`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum EdgeEffect {
    /// If the condition is fulfilled, replace the current block's terminator by a single goto.
    Goto { target: BasicBlock },
    /// If the condition is fulfilled, fulfill the condition `succ_condition` in `succ_block`.
    /// `succ_condition` is an index into the condition set of `succ_block`.
    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
}
206
207impl EdgeEffect {
208    fn block(self) -> BasicBlock {
209        match self {
210            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
211        }
212    }
213
214    fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
215        match self {
216            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
217                if *bb == target {
218                    *bb = new_target
219                }
220            }
221        }
222    }
223}
224
/// The conditions known at one program point, split between those still waiting to be
/// fulfilled and those already fulfilled.
#[derive(Clone, Debug, Default)]
struct ConditionSet {
    /// Conditions that are not fulfilled yet, each with its index into `targets`.
    active: Vec<(ConditionIndex, Condition)>,
    /// Indices of the conditions that have been registered as fulfilled.
    fulfilled: Vec<ConditionIndex>,
    /// For each condition index, the effects of fulfilling that condition.
    targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
}
231
232impl ConditionSet {
233    fn is_empty(&self) -> bool {
234        self.active.is_empty()
235    }
236
237    #[tracing::instrument(level = "trace", skip(self))]
238    fn push_condition(&mut self, c: Condition, target: BasicBlock) {
239        let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
240        self.active.push((index, c));
241    }
242
243    /// Register fulfilled condition and remove it from the set.
244    fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
245        self.active.retain(|&(index, condition)| {
246            let targets = &self.targets[index];
247            if f(condition, targets) {
248                trace!(?index, ?condition, "fulfill");
249                self.fulfilled.push(index);
250                false
251            } else {
252                true
253            }
254        })
255    }
256
257    /// Register fulfilled condition and remove them from the set.
258    fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
259        self.fulfill_if(|c, _| c.matches(place, value))
260    }
261
262    fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
263        self.active.retain(|&(_, c)| f(c))
264    }
265
266    fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
267        self.active.retain_mut(|(_, c)| {
268            if let Some(new) = f(*c) {
269                *c = new;
270                true
271            } else {
272                false
273            }
274        })
275    }
276
277    fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
278        for (_, c) in &mut self.active {
279            f(c)
280        }
281    }
282}
283
284impl<'a, 'tcx> TOFinder<'a, 'tcx> {
285    fn place(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<PlaceIndex> {
286        self.map.register_place(self.tcx, self.body, place, tail)
287    }
288
289    fn value(&mut self, place: PlaceIndex) -> Option<ValueIndex> {
290        self.map.register_value(self.tcx, self.typing_env, place)
291    }
292
293    fn place_value(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<ValueIndex> {
294        let place = self.place(place, tail)?;
295        self.value(place)
296    }
297
298    /// Construct the condition set for `bb` from the terminator, without executing its effect.
299    #[instrument(level = "trace", skip(self))]
300    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
301        let bbdata = &self.body[bb];
302
303        // This should be the first time we populate `entry_states[bb]`.
304        debug_assert!(self.entry_states[bb].is_empty());
305
306        let state_len =
307            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
308        let mut state = ConditionSet {
309            active: Vec::with_capacity(state_len),
310            targets: IndexVec::with_capacity(state_len),
311            fulfilled: Vec::new(),
312        };
313
314        // Use an index-set to deduplicate conditions coming from different successor blocks.
315        let mut known_conditions =
316            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
317        let mut insert = |condition, succ_block, succ_condition| {
318            let (index, new) = known_conditions.insert_full(condition);
319            let index = ConditionIndex::from_usize(index);
320            if new {
321                state.active.push((index, condition));
322                let _index = state.targets.push(Vec::new());
323                debug_assert_eq!(_index, index);
324            }
325            let target = EdgeEffect::Chain { succ_block, succ_condition };
326            debug_assert!(
327                !state.targets[index].contains(&target),
328                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
329                &state.targets[index],
330            );
331            state.targets[index].push(target);
332        };
333
334        // A given block may have several times the same successor.
335        let mut seen = FxHashSet::default();
336        for succ in bbdata.terminator().successors() {
337            if !seen.insert(succ) {
338                continue;
339            }
340
341            // Do not thread through loop headers.
342            if self.maybe_loop_headers.contains(succ) {
343                continue;
344            }
345
346            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
347                insert(cond, succ, succ_index);
348            }
349        }
350
351        let num_conditions = known_conditions.len();
352        debug_assert_eq!(num_conditions, state.active.len());
353        debug_assert_eq!(num_conditions, state.targets.len());
354        state.fulfilled.reserve(num_conditions);
355
356        state
357    }
358
359    /// Remove all conditions in the state that alias given place.
360    fn flood_state(
361        &self,
362        place: Place<'tcx>,
363        extra_elem: Option<TrackElem>,
364        state: &mut ConditionSet,
365    ) {
366        if state.is_empty() {
367            return;
368        }
369        let mut places_to_exclude = FxHashSet::default();
370        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
371            places_to_exclude.insert(vi);
372        });
373        trace!(?places_to_exclude, "flood_state");
374        if places_to_exclude.is_empty() {
375            return;
376        }
377        state.retain(|c| !places_to_exclude.contains(&c.place));
378    }
379
380    /// Extract the mutated place from a statement.
381    ///
382    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
383    ///     (_1 as Ok).0 = _5;
384    ///     (_1 as Err).0 = _6;
385    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
386    /// the value may have been mangled by the second assignment.
387    ///
388    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
389    /// stop at flooding the discriminant, and preserve the variant fields.
390    ///     (_1 as Some).0 = _6;
391    ///     SetDiscriminant(_1, 1);
392    ///     switchInt((_1 as Some).0)
393    #[instrument(level = "trace", skip(self), ret)]
394    fn mutated_statement(
395        &self,
396        stmt: &Statement<'tcx>,
397    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
398        match stmt.kind {
399            StatementKind::Assign(box (place, _)) => Some((place, None)),
400            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
401                Some((place, Some(TrackElem::Discriminant)))
402            }
403            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
404                Some((Place::from(local), None))
405            }
406            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
407            // copy_nonoverlapping takes pointers and mutated the pointed-to value.
408            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
409            | StatementKind::AscribeUserType(..)
410            | StatementKind::Coverage(..)
411            | StatementKind::FakeRead(..)
412            | StatementKind::ConstEvalCounter
413            | StatementKind::PlaceMention(..)
414            | StatementKind::BackwardIncompatibleDropHint { .. }
415            | StatementKind::Nop => None,
416        }
417    }
418
419    #[instrument(level = "trace", skip(self, state))]
420    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
421        if let Some(lhs) = self.value(lhs)
422            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
423        {
424            state.fulfill_matches(lhs, int)
425        }
426    }
427
    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
    ///
    /// The constant is explored through `Map::for_each_projection_value`, which is given two
    /// callbacks: one that projects an operand along a `TrackElem`, and one that receives a
    /// place together with the operand found for it.
    // NOTE(review): the exact traversal performed by `for_each_projection_value` is not
    // visible from this file; presumably it visits the places tracked below `lhs` — confirm.
    #[instrument(level = "trace", skip(self, state))]
    fn process_constant(
        &mut self,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut ConditionSet,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            // Projection callback: compute the part of `op` designated by `elem`.
            // Any interpreter error is discarded and reported as `None`.
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    // Read which variant `op` holds, then produce the discriminant value of
                    // that variant.
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    // Dereference the pointer and produce the length of the pointee as a
                    // `usize` immediate.
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            // Value callback: if `place` has a value in the map and `op` reads as a scalar
            // integer immediate, fulfill the conditions proven by `place == int`.
            &mut |place, op| {
                if let Some(place) = self.map.value(place)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    state.fulfill_matches(place, int)
                }
            },
        );
    }
466
467    #[instrument(level = "trace", skip(self, state))]
468    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
469        let mut renames = FxHashMap::default();
470        self.map.register_copy_tree(
471            lhs, // tree to copy
472            rhs, // tree to build
473            &mut |lhs, rhs| {
474                renames.insert(lhs, rhs);
475            },
476        );
477        state.for_each_mut(|c| {
478            if let Some(rhs) = renames.get(&c.place) {
479                c.place = *rhs
480            }
481        });
482    }
483
484    #[instrument(level = "trace", skip(self, state))]
485    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
486        match rhs {
487            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
488            Operand::Constant(constant) => {
489                let Some(constant) =
490                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
491                else {
492                    return;
493                };
494                self.process_constant(lhs, constant, state);
495            }
496            // Transfer the conditions on the copied rhs.
497            Operand::Move(rhs) | Operand::Copy(rhs) => {
498                let Some(rhs) = self.place(*rhs, None) else { return };
499                self.process_copy(lhs, rhs, state)
500            }
501            Operand::RuntimeChecks(_) => {}
502        }
503    }
504
    /// Handle the assignment `lhs_place = rvalue`.
    ///
    /// Depending on the rvalue, conditions on `lhs_place` are either fulfilled (the assigned
    /// value is known) or re-expressed as conditions on an operand of the rvalue. Nothing is
    /// done if `lhs_place` cannot be registered, or for rvalues not listed below.
    ///
    /// Conditions invalidated by the write are not removed here.
    #[instrument(level = "trace", skip(self, state))]
    fn process_assign(
        &mut self,
        lhs_place: &Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut ConditionSet,
    ) {
        let Some(lhs) = self.place(*lhs_place, None) else { return };
        match rvalue {
            Rvalue::Use(operand, _) => self.process_operand(lhs, operand, state),
            // Transfer the conditions on the copy rhs.
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.place(*rhs, Some(TrackElem::Discriminant)) else { return };
                self.process_copy(lhs, rhs, state)
            }
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Rvalue::Aggregate(box kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                // Place that holds the fields: the variant for an enum, `lhs` itself otherwise.
                let lhs = match kind {
                    // Do not support unions.
                    AggregateKind::Adt(.., Some(_)) => return,
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        // Building an enum aggregate also sets its discriminant: treat it as
                        // an assignment of the variant's discriminant value.
                        let discr_ty = agg_ty.discriminant_ty(self.tcx);
                        let discr_target =
                            self.map.register_place_index(discr_ty, lhs, TrackElem::Discriminant);
                        if let Some(discr_value) =
                            self.ecx.discriminant_for_variant(agg_ty, *variant_index).discard_err()
                        {
                            self.process_immediate(discr_target, discr_value, state);
                        }
                        self.map.register_place_index(
                            agg_ty,
                            lhs,
                            TrackElem::Variant(*variant_index),
                        )
                    }
                    _ => lhs,
                };
                // Each operand is assigned to the field of the same index.
                for (field_index, operand) in operands.iter_enumerated() {
                    let operand_ty = operand.ty(self.body, self.tcx);
                    let field = self.map.register_place_index(
                        operand_ty,
                        lhs,
                        TrackElem::Field(field_index),
                    );
                    self.process_operand(field, operand, state);
                }
            }
            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                state.retain_mut(|mut c| {
                    if c.place == lhs {
                        // `lhs ?= A` becomes `operand ?= !A`. The condition is dropped if the
                        // negated value cannot be computed.
                        let value = self
                            .ecx
                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
                            .discard_err()?
                            .to_scalar_int()
                            .discard_err()?;
                        c.place = operand;
                        c.value = value;
                    }
                    Some(c)
                });
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
            ) => {
                // Value taken by `lhs` when `operand` is equal to the constant.
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                state.for_each_mut(|c| {
                    if c.place == lhs {
                        // The condition on `lhs` is proven by `lhs == equals` exactly when it
                        // requires `operand == value`; otherwise it requires `operand != value`.
                        let polarity =
                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
                        c.place = operand;
                        c.value = value;
                        c.polarity = polarity;
                    }
                });
            }

            _ => {}
        }
    }
611
612    #[instrument(level = "trace", skip(self, state))]
613    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
614        // Below, `lhs` is the return value of `mutated_statement`,
615        // the place to which `conditions` apply.
616
617        match &stmt.kind {
618            // If we expect `discriminant(place) ?= A`,
619            // we have an opportunity if `variant_index ?= A`.
620            StatementKind::SetDiscriminant { box place, variant_index } => {
621                let Some(discr_target) = self.place(*place, Some(TrackElem::Discriminant)) else {
622                    return;
623                };
624                let enum_ty = place.ty(self.body, self.tcx).ty;
625                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
626                // Even if the discriminant write does nothing due to niches, it is UB to set the
627                // discriminant when the data does not encode the desired discriminant.
628                let Some(discr) =
629                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
630                else {
631                    return;
632                };
633                self.process_immediate(discr_target, discr, state)
634            }
635            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
636            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
637                Operand::Copy(place) | Operand::Move(place),
638            )) => {
639                let Some(place) = self.place_value(*place, None) else { return };
640                state.fulfill_matches(place, ScalarInt::TRUE);
641            }
642            StatementKind::Assign(box (lhs_place, rhs)) => {
643                self.process_assign(lhs_place, rhs, state)
644            }
645            _ => {}
646        }
647    }
648
649    /// Execute the terminator for block `bb` into state `entry_states[bb]`.
650    #[instrument(level = "trace", skip(self, state))]
651    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
652        let term = self.body.basic_blocks[bb].terminator();
653        let place_to_flood = match term.kind {
654            // Disallowed during optimizations.
655            TerminatorKind::FalseEdge { .. }
656            | TerminatorKind::FalseUnwind { .. }
657            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
658            // Cannot reason about inline asm.
659            TerminatorKind::InlineAsm { .. } => {
660                state.active.clear();
661                return;
662            }
663            // `SwitchInt` is handled specially.
664            TerminatorKind::SwitchInt { ref discr, ref targets } => {
665                return self.process_switch_int(discr, targets, state);
666            }
667            // These do not modify memory.
668            TerminatorKind::UnwindResume
669            | TerminatorKind::UnwindTerminate(_)
670            | TerminatorKind::Return
671            | TerminatorKind::Unreachable
672            | TerminatorKind::CoroutineDrop
673            // Assertions can be no-op at codegen time, so treat them as such.
674            | TerminatorKind::Assert { .. }
675            | TerminatorKind::Goto { .. } => None,
676            // Flood the overwritten place, and progress through.
677            TerminatorKind::Drop { place: destination, .. }
678            | TerminatorKind::Call { destination, .. } => Some(destination),
679            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
680        };
681
682        // This terminator modifies `place_to_flood`, cleanup the associated conditions.
683        if let Some(place_to_flood) = place_to_flood {
684            self.flood_state(place_to_flood, None, state);
685        }
686    }
687
688    #[instrument(level = "trace", skip(self))]
689    fn process_switch_int(
690        &mut self,
691        discr: &Operand<'tcx>,
692        targets: &SwitchTargets,
693        state: &mut ConditionSet,
694    ) {
695        let Some(discr) = discr.place() else { return };
696        let Some(discr_idx) = self.place_value(discr, None) else { return };
697
698        let discr_ty = discr.ty(self.body, self.tcx).ty;
699        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };
700
701        // Attempt to fulfill a condition using an outgoing branch's condition.
702        // Only support the case where there are no duplicated outgoing edges.
703        if targets.is_distinct() {
704            for &(index, c) in state.active.iter() {
705                if c.place != discr_idx {
706                    continue;
707                }
708
709                // Set of blocks `t` such that the edge `bb -> t` fulfills `c`.
710                let mut edges_fulfilling_condition = FxHashSet::default();
711
712                // On edge `bb -> tgt`, we know that `discr_idx == branch`.
713                for (branch, tgt) in targets.iter() {
714                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
715                        && c.matches(discr_idx, branch)
716                    {
717                        edges_fulfilling_condition.insert(tgt);
718                    }
719                }
720
721                // On edge `bb -> otherwise`, we only know that `discr` is different from all the
722                // constants in the switch. That's much weaker information than the equality we
723                // had in the previous arm. All we can conclude is that the replacement condition
724                // `discr != value` can be threaded, and nothing else.
725                if c.polarity == Polarity::Ne
726                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
727                    && targets.all_values().contains(&value.into())
728                {
729                    edges_fulfilling_condition.insert(targets.otherwise());
730                }
731
732                // Register that jumping to a `t` fulfills condition `c`.
733                // This does *not* mean that `c` is fulfilled in this block: inserting `index` in
734                // `fulfilled` is wrong if we have targets that jump to other blocks.
735                let condition_targets = &state.targets[index];
736
737                let new_edges: Vec<_> = condition_targets
738                    .iter()
739                    .copied()
740                    .filter(|&target| match target {
741                        EdgeEffect::Goto { .. } => false,
742                        EdgeEffect::Chain { succ_block, .. } => {
743                            edges_fulfilling_condition.contains(&succ_block)
744                        }
745                    })
746                    .collect();
747
748                if new_edges.len() == condition_targets.len() {
749                    // If `new_edges == condition_targets`, do not bother creating a new
750                    // `ConditionIndex`, we can use the existing one.
751                    state.fulfilled.push(index);
752                } else {
753                    // Fulfilling `index` may thread conditions that we do not want,
754                    // so create a brand new index to immediately mark fulfilled.
755                    let index = state.targets.push(new_edges);
756                    state.fulfilled.push(index);
757                }
758            }
759        }
760
761        // Introduce additional conditions of the form `discr ?= value` for each value in targets.
762        let mut mk_condition = |value, polarity, target| {
763            let c = Condition { place: discr_idx, value, polarity };
764            state.push_condition(c, target);
765        };
766        if let Some((value, then_, else_)) = targets.as_static_if() {
767            // We have an `if`, generate both `discr == value` and `discr != value`.
768            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
769            mk_condition(value, Polarity::Eq, then_);
770            mk_condition(value, Polarity::Ne, else_);
771        } else {
772            // We have a general switch and we cannot express `discr != value0 && discr != value1`,
773            // so we only generate equality predicates.
774            for (value, target) in targets.iter() {
775                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
776                    mk_condition(value, Polarity::Eq, target);
777                }
778            }
779        }
780    }
781}
782
/// Propagate fulfilled conditions forward in the CFG to reduce the amount of duplication.
///
/// A condition that is fulfilled on every incoming edge of a block is recorded as fulfilled
/// in that block's entry state. This function only appends to the `fulfilled` lists of
/// `entry_states`; `body` is only read.
///
/// Each block is visited once, in reverse post-order, so a fact is only propagated along an
/// edge whose source is visited before its destination.
#[instrument(level = "debug", skip(body, entry_states))]
fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
    let basic_blocks = &body.basic_blocks;
    let reverse_postorder = basic_blocks.reverse_postorder();

    // Start by computing the number of *incoming edges* for each block.
    // We do not use the cached `basic_blocks.predecessors` as we only want reachable predecessors.
    let mut predecessors = IndexVec::from_elem(0, &entry_states);
    predecessors[START_BLOCK] = 1; // Account for the implicit entry edge.
    for &bb in reverse_postorder {
        let term = basic_blocks[bb].terminator();
        for s in term.successors() {
            predecessors[s] += 1;
        }
    }

    // Compute the number of edges into each block that carry each condition.
    // `fulfill_in_pred_count[bb][c]` starts at 0 and has one slot per entry in
    // `entry_states[bb].targets`.
    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    // By traversing in RPO, we increase the likelihood to visit predecessors before successors.
    for &bb in reverse_postorder {
        let preds = predecessors[bb];
        trace!(?bb, ?preds);

        // We have removed all the input edges towards this block. Just skip visiting it.
        if preds == 0 {
            continue;
        }

        let state = &mut entry_states[bb];
        trace!(?state);

        // Conditions that are fulfilled in all the predecessors, are fulfilled in `bb`.
        // Both counters are numbers of *edges*, so they can be compared directly.
        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
            if cond_preds == preds {
                trace!(?condition);
                state.fulfilled.push(condition);
            }
        }

        // We want to count how many times each condition is fulfilled,
        // so ensure we are not counting the same edge twice.
        // `targets` gathers the effects of every fulfilled condition, sorted and deduplicated.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| state.targets[index].iter().copied())
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // We may modify the set of successors by applying edges, so track them here.
        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();

        // Reverse then pop from the back: effects are processed in ascending sorted order,
        // while still allowing `targets` to be pruned from inside the loop.
        targets.reverse();
        while let Some(target) = targets.pop() {
            match target {
                EdgeEffect::Goto { target } => {
                    // We update the count of predecessors. If target or any successor has not been
                    // processed yet, this increases the likelihood we find something relevant.
                    predecessors[target] += 1;
                    for &s in successors.iter() {
                        predecessors[s] -= 1;
                    }
                    // Only process edges that still exist.
                    targets.retain(|t| t.block() == target);
                    // From now on, `target` is the only successor of `bb`.
                    successors.clear();
                    successors.push(target);
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    // `predecessors` is the number of incoming *edges* in each block.
                    // Count the number of edges that apply `succ_condition` into `succ_block`.
                    let count = successors.iter().filter(|&&s| s == succ_block).count();
                    fulfill_in_pred_count[succ_block][succ_condition] += count;
                }
            }
        }
    }
}
867
868#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
869fn remove_costly_conditions<'tcx>(
870    tcx: TyCtxt<'tcx>,
871    typing_env: ty::TypingEnv<'tcx>,
872    body: &Body<'tcx>,
873    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
874) {
875    let basic_blocks = &body.basic_blocks;
876
877    let mut costs = IndexVec::from_elem(None, basic_blocks);
878    let mut cost = |bb: BasicBlock| -> u8 {
879        let c = *costs[bb].get_or_insert_with(|| {
880            let bbdata = &basic_blocks[bb];
881            let mut cost = CostChecker::new(tcx, typing_env, None, body);
882            cost.visit_basic_block_data(bb, bbdata);
883            cost.cost().try_into().unwrap_or(MAX_COST)
884        });
885        trace!("cost[{bb:?}] = {c}");
886        c
887    };
888
889    // Initialize costs with `MAX_COST`: if we have a cycle, the cyclic `bb` has infinite costs.
890    let mut condition_cost = IndexVec::from_fn_n(
891        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
892        entry_states.len(),
893    );
894
895    let reverse_postorder = basic_blocks.reverse_postorder();
896
897    for &bb in reverse_postorder.iter().rev() {
898        let state = &entry_states[bb];
899        trace!(?bb, ?state);
900
901        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
902
903        for (condition, targets) in state.targets.iter_enumerated() {
904            for &target in targets {
905                match target {
906                    // A `Goto` has cost 0.
907                    EdgeEffect::Goto { .. } => {}
908                    // Chaining into an already-fulfilled condition is nop.
909                    EdgeEffect::Chain { succ_block, succ_condition }
910                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
911                    // When chaining, use `cost[succ_block][succ_condition] + cost(succ_block)`.
912                    EdgeEffect::Chain { succ_block, succ_condition } => {
913                        // Cost associated with duplicating `succ_block`.
914                        let duplication_cost = cost(succ_block);
915                        // Cost associated with the rest of the chain.
916                        let target_cost =
917                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
918                        let cost = current_costs[condition]
919                            .saturating_add(duplication_cost)
920                            .saturating_add(target_cost);
921                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
922                        current_costs[condition] = cost;
923                    }
924                }
925            }
926        }
927
928        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
929        condition_cost[bb] = current_costs;
930    }
931
932    trace!(?condition_cost);
933
934    for &bb in reverse_postorder {
935        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
936            if condition_cost[bb][index] >= MAX_COST {
937                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
938                targets.clear()
939            }
940        }
941    }
942}
943
/// Working state used to apply the fulfilled conditions of `entry_states` to a body.
///
/// `basic_blocks` and `entry_states` are indexed by the same `BasicBlock`s: whenever a block
/// is duplicated, a condition set is pushed for the new block as well.
struct OpportunitySet<'a, 'tcx> {
    /// The blocks being rewritten in place; duplicated blocks are appended at the end.
    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
    /// The condition set holding at the entry of each block, duplicates included.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
    /// Cache duplicated block. When cloning a basic block `bb` to fulfill a condition `c`,
    /// record the target of this `bb with c` edge.
    duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
}
951
impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
    /// Prepare to apply the fulfilled conditions recorded in `entry_states` to `body`.
    ///
    /// Returns `None` when no block has any fulfilled condition, in which case there is
    /// nothing to apply. Otherwise the `active` list of every state is reset to its default
    /// value: the methods of this impl only read `targets` and `fulfilled`.
    fn new(
        body: &'a mut Body<'tcx>,
        mut entry_states: IndexVec<BasicBlock, ConditionSet>,
    ) -> Option<OpportunitySet<'a, 'tcx>> {
        trace!(def_id = ?body.source.def_id(), "apply");

        if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
            return None;
        }

        // Free some memory, because we will need to clone condition sets.
        for state in entry_states.iter_mut() {
            state.active = Default::default();
        }
        let duplicates = Default::default();
        let basic_blocks = body.basic_blocks.as_mut();
        Some(OpportunitySet { basic_blocks, entry_states, duplicates })
    }

    /// Apply the opportunities on the graph.
    ///
    /// Blocks are explored from `START_BLOCK` with a worklist, each block being processed at
    /// most once. Blocks created along the way are explored like any other successor.
    #[instrument(level = "debug", skip(self))]
    fn apply(mut self) {
        let mut worklist = Vec::with_capacity(self.basic_blocks.len());
        worklist.push(START_BLOCK);

        // Use a `GrowableBitSet` and not a `DenseBitSet` as we are adding blocks.
        let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());

        while let Some(bb) = worklist.pop() {
            // `insert` returns `false` when `bb` was already visited.
            if !visited.insert(bb) {
                continue;
            }

            self.apply_once(bb);

            // `apply_once` may have modified the terminator of `bb`.
            // Only visit actual successors.
            worklist.extend(self.basic_blocks[bb].terminator().successors());
        }
    }

    /// Apply the opportunities on `bb`.
    ///
    /// The effects of every fulfilled condition of `bb` are taken out of its state, then
    /// applied one by one to the terminator of `bb`.
    #[instrument(level = "debug", skip(self))]
    fn apply_once(&mut self, bb: BasicBlock) {
        let state = &mut self.entry_states[bb];
        trace!(?state);

        // We are modifying the `bb` in-place. Once a `EdgeEffect` has been applied,
        // it does not need to be applied again.
        // `mem::take` leaves an empty list behind for each fulfilled condition.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| std::mem::take(&mut state.targets[index]))
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Use a while-pop to allow modifying `targets` from inside the loop.
        // Reversing first makes the effects come out in ascending sorted order.
        targets.reverse();
        while let Some(target) = targets.pop() {
            debug!(?target);
            trace!(term = ?self.basic_blocks[bb].terminator().kind);

            // By construction, `target.block()` is a successor of `bb`.
            // When applying targets, we may change the set of successors.
            // The match below updates the set of targets for consistency.
            debug_assert!(
                self.basic_blocks[bb].terminator().successors().contains(&target.block()),
                "missing {target:?} in successors for {bb:?}, term={:?}",
                self.basic_blocks[bb].terminator(),
            );

            match target {
                EdgeEffect::Goto { target } => {
                    self.apply_goto(bb, target);

                    // We now have `target` as single successor. Drop all other target blocks.
                    targets.retain(|t| t.block() == target);
                    // Also do this on targets that may be applied by a duplicate of `bb`.
                    for ts in self.entry_states[bb].targets.iter_mut() {
                        ts.retain(|t| t.block() == target);
                    }
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);

                    // We have a new name for `target`, ensure it is correctly applied.
                    // `None` means the edge was left alone, so there is nothing to rename.
                    if let Some(new_succ_block) = new_succ_block {
                        for t in targets.iter_mut() {
                            t.replace_block(succ_block, new_succ_block)
                        }
                        // Also do this on targets that may be applied by a duplicate of `bb`.
                        for t in
                            self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
                        {
                            t.replace_block(succ_block, new_succ_block)
                        }
                    }
                }
            }

            trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
        }
    }

    /// Replace the terminator kind of `bb` by an unconditional `Goto { target }`.
    #[instrument(level = "debug", skip(self))]
    fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
        self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
    }

    /// Redirect the edges `bb -> target` towards a copy of `target` in which `condition` is
    /// fulfilled.
    ///
    /// Returns `None`, leaving `bb` untouched, when `target` already fulfills `condition`.
    /// Otherwise returns the block that replaced `target` in the terminator of `bb`.
    #[instrument(level = "debug", skip(self), ret)]
    fn apply_chain(
        &mut self,
        bb: BasicBlock,
        target: BasicBlock,
        condition: ConditionIndex,
    ) -> Option<BasicBlock> {
        if self.entry_states[target].fulfilled.contains(&condition) {
            // `target` already fulfills `condition`, so we do not need to thread anything.
            trace!("fulfilled");
            return None;
        }

        // We may be tempted to modify `target` in-place to avoid a clone. This is wrong.
        // We may still have edges from other blocks to `target` that have not been created yet.
        // For instance because we may be threading an edge coming from `bb`,
        // or `target` may be a block duplicate for which we may still create predecessors.

        // If we already have a duplicate of `target` which fulfills `condition`, reuse it.
        // Otherwise, the closure below clones a new bb to such ends and caches it.
        let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
            let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
            trace!(?target, ?new_target, ?condition, "clone");

            // By definition, `new_target` fulfills the same condition as `target`, with
            // `condition` added.
            let mut condition_set = self.entry_states[target].clone();
            condition_set.fulfilled.push(condition);
            // Keep `entry_states` in lockstep with `basic_blocks`: both pushes must yield
            // the same index.
            let _new_target = self.entry_states.push(condition_set);
            debug_assert_eq!(new_target, _new_target);

            new_target
        });
        // Logged on every call, whether the duplicate was just created or came from the cache.
        trace!(?target, ?new_target, ?condition, "reuse");

        // Replace `target` by `new_target` where it appears.
        // Every successor of `bb` equal to `target` is redirected, duplicated edges included.
        self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
            if *s == target {
                *s = new_target;
            }
        });

        Some(new_target)
    }
}
1110
1111/// Compute the set of loop headers in the given body. A loop header is usually defined as a block
1112/// which dominates one of its predecessors. This definition is only correct for reducible CFGs.
1113/// However, computing dominators is expensive, so we approximate according to the post-order
1114/// traversal order. A loop header for us is a block which is visited after its predecessor in
1115/// post-order. This is ok as we mostly need a heuristic.
1116fn maybe_loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
1117    let mut maybe_loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
1118    let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
1119    for (bb, bbdata) in traversal::postorder(body) {
1120        // Post-order means we visit successors before the block for acyclic CFGs.
1121        // If the successor is not visited yet, consider it a loop header.
1122        for succ in bbdata.terminator().successors() {
1123            if !visited.contains(succ) {
1124                maybe_loop_headers.insert(succ);
1125            }
1126        }
1127
1128        // Only mark `bb` as visited after we checked the successors, in case we have a self-loop.
1129        //     bb1: goto -> bb1;
1130        let _new = visited.insert(bb);
1131        debug_assert!(_new);
1132    }
1133
1134    maybe_loop_headers
1135}