// rustc_mir_transform/instsimplify.rs

1//! Performs various peephole optimizations.
2
3use rustc_abi::ExternAbi;
4use rustc_hir::{LangItem, find_attr};
5use rustc_middle::bug;
6use rustc_middle::mir::visit::MutVisitor;
7use rustc_middle::mir::*;
8use rustc_middle::ty::layout::ValidityRequirement;
9use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, layout};
10use rustc_span::{Symbol, sym};
11
12use crate::simplify::simplify_duplicate_switch_targets;
13
14pub(super) enum InstSimplify {
15    BeforeInline,
16    AfterSimplifyCfg,
17}
18
19impl<'tcx> crate::MirPass<'tcx> for InstSimplify {
20    fn name(&self) -> &'static str {
21        match self {
22            InstSimplify::BeforeInline => "InstSimplify-before-inline",
23            InstSimplify::AfterSimplifyCfg => "InstSimplify-after-simplifycfg",
24        }
25    }
26
27    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
28        sess.mir_opt_level() > 0
29    }
30
31    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
32        let preserve_ub_checks = find_attr!(tcx.hir_krate_attrs(), RustcPreserveUbChecks);
33        if !preserve_ub_checks {
34            SimplifyUbCheck { tcx }.visit_body(body);
35        }
36        let ctx = InstSimplifyContext {
37            tcx,
38            local_decls: &body.local_decls,
39            typing_env: body.typing_env(tcx),
40        };
41        for block in body.basic_blocks.as_mut() {
42            for statement in block.statements.iter_mut() {
43                let StatementKind::Assign(box (.., rvalue)) = &mut statement.kind else {
44                    continue;
45                };
46
47                ctx.simplify_bool_cmp(rvalue);
48                ctx.simplify_ref_deref(rvalue);
49                ctx.simplify_ptr_aggregate(rvalue);
50                ctx.simplify_cast(rvalue);
51                ctx.simplify_repeated_aggregate(rvalue);
52                ctx.simplify_repeat_once(rvalue);
53            }
54
55            let terminator = block.terminator.as_mut().unwrap();
56            ctx.simplify_primitive_clone(terminator, &mut block.statements);
57            ctx.simplify_size_or_align_of_val(terminator, &mut block.statements);
58            ctx.simplify_intrinsic_assert(terminator);
59            ctx.simplify_nounwind_call(terminator);
60            simplify_duplicate_switch_targets(terminator);
61        }
62    }
63
64    fn is_required(&self) -> bool {
65        false
66    }
67}
68
69struct InstSimplifyContext<'a, 'tcx> {
70    tcx: TyCtxt<'tcx>,
71    local_decls: &'a LocalDecls<'tcx>,
72    typing_env: ty::TypingEnv<'tcx>,
73}
74
75impl<'tcx> InstSimplifyContext<'_, 'tcx> {
76    /// Transform aggregates like [0, 0, 0, 0, 0] into [0; 5].
77    /// GVN can also do this optimization, but GVN is only run at mir-opt-level 2 so having this in
78    /// InstSimplify helps unoptimized builds.
79    fn simplify_repeated_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
80        let Rvalue::Aggregate(box AggregateKind::Array(_), fields) = &*rvalue else {
81            return;
82        };
83        if fields.len() < 5 {
84            return;
85        }
86        let (first, rest) = fields[..].split_first().unwrap();
87        let Operand::Constant(first) = first else {
88            return;
89        };
90        let Ok(first_val) = first.const_.eval(self.tcx, self.typing_env, first.span) else {
91            return;
92        };
93        if rest.iter().all(|field| {
94            let Operand::Constant(field) = field else {
95                return false;
96            };
97            let field = field.const_.eval(self.tcx, self.typing_env, field.span);
98            field == Ok(first_val)
99        }) {
100            let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap());
101            *rvalue = Rvalue::Repeat(Operand::Constant(first.clone()), len);
102        }
103    }
104
105    /// Transform boolean comparisons into logical operations.
106    fn simplify_bool_cmp(&self, rvalue: &mut Rvalue<'tcx>) {
107        let Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) = &*rvalue else { return };
108        *rvalue = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
109            // Transform "Eq(a, true)" ==> "a"
110            (BinOp::Eq, _, Some(true)) => Rvalue::Use(a.clone(), WithRetag::Yes),
111
112            // Transform "Ne(a, false)" ==> "a"
113            (BinOp::Ne, _, Some(false)) => Rvalue::Use(a.clone(), WithRetag::Yes),
114
115            // Transform "Eq(true, b)" ==> "b"
116            (BinOp::Eq, Some(true), _) => Rvalue::Use(b.clone(), WithRetag::Yes),
117
118            // Transform "Ne(false, b)" ==> "b"
119            (BinOp::Ne, Some(false), _) => Rvalue::Use(b.clone(), WithRetag::Yes),
120
121            // Transform "Eq(false, b)" ==> "Not(b)"
122            (BinOp::Eq, Some(false), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),
123
124            // Transform "Ne(true, b)" ==> "Not(b)"
125            (BinOp::Ne, Some(true), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),
126
127            // Transform "Eq(a, false)" ==> "Not(a)"
128            (BinOp::Eq, _, Some(false)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),
129
130            // Transform "Ne(a, true)" ==> "Not(a)"
131            (BinOp::Ne, _, Some(true)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),
132
133            _ => return,
134        };
135    }
136
137    fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> {
138        let a = a.constant()?;
139        if a.const_.ty().is_bool() { a.const_.try_to_bool() } else { None }
140    }
141
142    /// Transform `&(*a)` ==> `a`.
143    fn simplify_ref_deref(&self, rvalue: &mut Rvalue<'tcx>) {
144        if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue
145            && let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection()
146            && rvalue.ty(self.local_decls, self.tcx) == base.ty(self.local_decls, self.tcx).ty
147        {
148            *rvalue = Rvalue::Use(
149                Operand::Copy(Place {
150                    local: base.local,
151                    projection: self.tcx.mk_place_elems(base.projection),
152                }),
153                // This might have been a two-phase borrow, which we should not upgrade
154                // to a full `&mut` reborrow.
155                // FIXME: Once Stacked Borrows is fully removed, we can use `Yes` here as
156                // Tree Borrows treats two-phase and full borrows the same.
157                if matches!(
158                    rvalue,
159                    Rvalue::Ref(_, BorrowKind::Mut { kind: MutBorrowKind::TwoPhaseBorrow }, _)
160                ) {
161                    WithRetag::No
162                } else {
163                    WithRetag::Yes
164                },
165            );
166        }
167    }
168
169    /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`.
170    fn simplify_ptr_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
171        if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue
172            && let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx)
173            && meta_ty.is_unit()
174        {
175            // The mutable borrows we're holding prevent printing `rvalue` here
176            let mut fields = std::mem::take(fields);
177            let _meta = fields.pop().unwrap();
178            let data = fields.pop().unwrap();
179            let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
180            *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
181        }
182    }
183
184    fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) {
185        let Rvalue::Cast(kind, operand, cast_ty) = rvalue else { return };
186
187        let operand_ty = operand.ty(self.local_decls, self.tcx);
188        if operand_ty == *cast_ty {
189            *rvalue = Rvalue::Use(operand.clone(), WithRetag::Yes);
190        } else if *kind == CastKind::Transmute
191            // Transmuting an integer to another integer is just a signedness cast
192            && let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
193                (operand_ty.kind(), cast_ty.kind())
194            && int.bit_width() == uint.bit_width()
195        {
196            // The width check isn't strictly necessary, as different widths
197            // are UB and thus we'd be allowed to turn it into a cast anyway.
198            // But let's keep the UB around for codegen to exploit later.
199            // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
200            // then the width check is necessary for big-endian correctness.)
201            *kind = CastKind::IntToInt;
202        }
203    }
204
205    /// Simplify `[x; 1]` to just `[x]`.
206    fn simplify_repeat_once(&self, rvalue: &mut Rvalue<'tcx>) {
207        if let Rvalue::Repeat(operand, count) = rvalue
208            && let Some(1) = count.try_to_target_usize(self.tcx)
209        {
210            *rvalue = Rvalue::Aggregate(
211                Box::new(AggregateKind::Array(operand.ty(self.local_decls, self.tcx))),
212                [operand.clone()].into(),
213            );
214        }
215    }
216
217    fn simplify_primitive_clone(
218        &self,
219        terminator: &mut Terminator<'tcx>,
220        statements: &mut Vec<Statement<'tcx>>,
221    ) {
222        let TerminatorKind::Call {
223            func, args, destination, target: Some(destination_block), ..
224        } = &terminator.kind
225        else {
226            return;
227        };
228
229        // It's definitely not a clone if there are multiple arguments
230        let [arg] = &args[..] else { return };
231
232        // Only bother looking more if it's easy to know what we're calling
233        let Some((fn_def_id, ..)) = func.const_fn_def() else { return };
234
235        // These types are easily available from locals, so check that before
236        // doing DefId lookups to figure out what we're actually calling.
237        let arg_ty = arg.node.ty(self.local_decls, self.tcx);
238
239        let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() else { return };
240
241        if !self.tcx.is_lang_item(fn_def_id, LangItem::CloneFn)
242            || !inner_ty.is_trivially_pure_clone_copy()
243        {
244            return;
245        }
246
247        let Some(arg_place) = arg.node.place() else { return };
248
249        statements.push(Statement::new(
250            terminator.source_info,
251            StatementKind::Assign(Box::new((
252                *destination,
253                Rvalue::Use(
254                    Operand::Copy(arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx)),
255                    WithRetag::Yes,
256                ),
257            ))),
258        ));
259        terminator.kind = TerminatorKind::Goto { target: *destination_block };
260    }
261
262    /// Simplify `size_of_val` and `align_of_val` if we don't actually need
263    /// to look at the value in order to calculate the result:
264    /// - For `Sized` types we can always do this for both,
265    /// - For `align_of_val::<[T]>` we can return `align_of::<T>()`, since it
266    ///   doesn't depend on the slice's length and the elements are sized.
267    ///
268    /// This is here so it can run after inlining, where it's more useful.
269    /// (LowerIntrinsics is done in cleanup, before the optimization passes.)
270    ///
271    /// Note that we intentionally just produce the lang item constants so this
272    /// works on generic types and avoids any risk of layout calculation cycles.
273    fn simplify_size_or_align_of_val(
274        &self,
275        terminator: &mut Terminator<'tcx>,
276        statements: &mut Vec<Statement<'tcx>>,
277    ) {
278        let source_info = terminator.source_info;
279        if let TerminatorKind::Call {
280            func, args, destination, target: Some(destination_block), ..
281        } = &terminator.kind
282            && args.len() == 1
283            && let Some((fn_def_id, generics)) = func.const_fn_def()
284        {
285            let lang_item = if self.tcx.is_intrinsic(fn_def_id, sym::size_of_val) {
286                LangItem::SizeOf
287            } else if self.tcx.is_intrinsic(fn_def_id, sym::align_of_val) {
288                LangItem::AlignOf
289            } else {
290                return;
291            };
292            let generic_ty = generics.type_at(0);
293            let ty = if generic_ty.is_sized(self.tcx, self.typing_env) {
294                generic_ty
295            } else if let LangItem::AlignOf = lang_item
296                && let ty::Slice(elem_ty) = *generic_ty.kind()
297            {
298                elem_ty
299            } else {
300                return;
301            };
302
303            let const_def_id = self.tcx.require_lang_item(lang_item, source_info.span);
304            let const_op = Operand::unevaluated_constant(
305                self.tcx,
306                const_def_id,
307                &[ty.into()],
308                source_info.span,
309            );
310            statements.push(Statement::new(
311                source_info,
312                StatementKind::Assign(Box::new((
313                    *destination,
314                    Rvalue::Use(const_op, WithRetag::Yes),
315                ))),
316            ));
317            terminator.kind = TerminatorKind::Goto { target: *destination_block };
318        }
319    }
320
321    fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) {
322        let TerminatorKind::Call { ref func, ref mut unwind, .. } = terminator.kind else {
323            return;
324        };
325
326        let Some((def_id, _)) = func.const_fn_def() else {
327            return;
328        };
329
330        let body_ty = self.tcx.type_of(def_id).skip_binder();
331        let body_abi = match body_ty.kind() {
332            ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(),
333            ty::Closure(..) => ExternAbi::RustCall,
334            ty::Coroutine(..) => ExternAbi::Rust,
335            _ => bug!("unexpected body ty: {body_ty:?}"),
336        };
337
338        if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) {
339            *unwind = UnwindAction::Unreachable;
340        }
341    }
342
343    fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) {
344        let TerminatorKind::Call { ref func, target: ref mut target @ Some(target_block), .. } =
345            terminator.kind
346        else {
347            return;
348        };
349        let func_ty = func.ty(self.local_decls, self.tcx);
350        let Some((intrinsic_name, args)) = resolve_rust_intrinsic(self.tcx, func_ty) else {
351            return;
352        };
353        // The intrinsics we are interested in have one generic parameter
354        let [arg, ..] = args[..] else { return };
355
356        let known_is_valid =
357            intrinsic_assert_panics(self.tcx, self.typing_env, arg, intrinsic_name);
358        match known_is_valid {
359            // We don't know the layout or it's not validity assertion at all, don't touch it
360            None => {}
361            Some(true) => {
362                // If we know the assert panics, indicate to later opts that the call diverges
363                *target = None;
364            }
365            Some(false) => {
366                // If we know the assert does not panic, turn the call into a Goto
367                terminator.kind = TerminatorKind::Goto { target: target_block };
368            }
369        }
370    }
371}
372
373fn intrinsic_assert_panics<'tcx>(
374    tcx: TyCtxt<'tcx>,
375    typing_env: ty::TypingEnv<'tcx>,
376    arg: ty::GenericArg<'tcx>,
377    intrinsic_name: Symbol,
378) -> Option<bool> {
379    let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?;
380    let ty = arg.expect_ty();
381    Some(!tcx.check_validity_requirement((requirement, typing_env.as_query_input(ty))).ok()?)
382}
383
384fn resolve_rust_intrinsic<'tcx>(
385    tcx: TyCtxt<'tcx>,
386    func_ty: Ty<'tcx>,
387) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
388    let ty::FnDef(def_id, args) = *func_ty.kind() else { return None };
389    let intrinsic = tcx.intrinsic(def_id)?;
390    Some((intrinsic.name, args))
391}
392
393struct SimplifyUbCheck<'tcx> {
394    tcx: TyCtxt<'tcx>,
395}
396
397impl<'tcx> MutVisitor<'tcx> for SimplifyUbCheck<'tcx> {
398    fn tcx(&self) -> TyCtxt<'tcx> {
399        self.tcx
400    }
401
402    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _: Location) {
403        if let Operand::RuntimeChecks(RuntimeChecks::UbChecks) = operand {
404            *operand = Operand::Constant(Box::new(ConstOperand {
405                span: rustc_span::DUMMY_SP,
406                user_ty: None,
407                const_: Const::Val(
408                    ConstValue::from_bool(self.tcx.sess.ub_checks()),
409                    self.tcx.types.bool,
410                ),
411            }));
412        }
413    }
414}