Skip to main content

rustc_mir_transform/
large_enums.rs

1use rustc_abi::{HasDataLayout, Size, TagEncoding, Variants};
2use rustc_const_eval::interpret::{Scalar, alloc_range};
3use rustc_data_structures::fx::FxHashMap;
4use rustc_middle::mir::interpret::AllocId;
5use rustc_middle::mir::*;
6use rustc_middle::ty::util::IntTypeExt;
7use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
8use rustc_session::Session;
9
10use crate::patch::MirPatch;
11
12/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
13/// enough discrepancy between them.
14///
15/// i.e. If there are two variants:
16/// ```
17/// enum Example {
18///   Small,
19///   Large([u32; 1024]),
20/// }
21/// ```
22/// Instead of emitting moves of the large variant, perform a memcpy instead.
23/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
24///
25/// In summary, what this does is at runtime determine which enum variant is active,
26/// and instead of copying all the bytes of the largest possible variant,
27/// copy only the bytes for the currently active variant. The number of bytes to copy is determined
28/// by a lookup table: a discriminant-indexed array indicating the size of each variant.
29pub(super) struct EnumSizeOpt {
30    pub(crate) discrepancy: u64,
31}
32
impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
    /// The pass runs only when the unstable `unsound_mir_opts` option is set and the MIR
    /// optimization level is at least 3.
    fn is_enabled(&self, sess: &Session) -> bool {
        // There are some differences in behavior on wasm and ARM that are not properly
        // understood, so we conservatively treat this optimization as unsound:
        // https://github.com/rust-lang/rust/issues/154413
        sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() >= 3
    }

    /// Rewrites each `lhs = copy rhs` / `lhs = move rhs` assignment whose type is accepted by
    /// `EnumSizeOpt::candidate` so that only the bytes of the variant active in `rhs` are copied.
    ///
    /// For every such statement, the following is emitted at its location:
    /// 1. `StorageLive` of a fresh `[usize; N]` temporary, which is assigned the constant size
    ///    table that `candidate` returned;
    /// 2. the discriminant of `rhs`, read into a temporary and cast to `usize`;
    /// 3. the active variant's size, read from the table at that index;
    /// 4. raw pointers to `lhs` (`*mut`) and `rhs` (`*const`), each cast to a `u8` pointer;
    /// 5. a `CopyNonOverlapping` of that many bytes from `rhs` to `lhs`;
    /// 6. `StorageDead` of the table temporary.
    ///
    /// The original assignment is then made a `Nop`. All edits are recorded in a `MirPatch`
    /// and applied to `body` after every block has been scanned.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // NOTE: This pass may produce different MIR based on the alignment of the target
        // platform, but it will still be valid.

        // Size-table allocations, keyed by enum type; filled and reused by `candidate`.
        let mut alloc_cache = FxHashMap::default();
        let typing_env = body.typing_env(tcx);

        // New temporaries and statements are collected here while the blocks are iterated,
        // and written into `body` once at the end.
        let mut patch = MirPatch::new(body);

        for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
            for (statement_index, st) in data.statements.iter_mut().enumerate() {
                // Only plain `lhs = copy rhs` / `lhs = move rhs` assignments are considered;
                // both operand kinds are rewritten the same way.
                let StatementKind::Assign(box (
                    lhs,
                    Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs), _),
                )) = &st.kind
                else {
                    continue;
                };

                let location = Location { block, statement_index };

                let ty = lhs.ty(&body.local_decls, tcx).ty;

                // `num_variants` is the length of the size table, i.e. the discriminant count
                // that `candidate` computed for this enum.
                let Some((adt_def, num_variants, alloc_id)) =
                    self.candidate(tcx, typing_env, ty, &mut alloc_cache)
                else {
                    continue;
                };

                let span = st.source_info.span;

                // Temporary holding the `[usize; num_variants]` size table. It is the only
                // new temporary that gets `StorageLive`/`StorageDead` markers.
                let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
                let size_array_local = patch.new_temp(tmp_ty, span);

                let store_live = StatementKind::StorageLive(size_array_local);

                // Load the constant table from the allocation built by `candidate`.
                let place = Place::from(size_array_local);
                let constant_vals = ConstOperand {
                    span,
                    user_ty: None,
                    const_: Const::Val(
                        ConstValue::Indirect { alloc_id, offset: Size::ZERO },
                        tmp_ty,
                    ),
                };
                let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)), WithRetag::No);
                let const_assign = StatementKind::Assign(Box::new((place, rval)));

                // Read `rhs`'s discriminant into a temporary of the enum's `repr` discriminant
                // type.
                let discr_place =
                    Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
                let store_discr =
                    StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));

                // `PlaceElem::Index` takes a `Local`, so the discriminant is cast to `usize`
                // and stored in its own temporary before indexing the table.
                let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
                let cast_discr = StatementKind::Assign(Box::new((
                    discr_cast_place,
                    Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
                )));

                // `size_place = size_array[discr]`: the byte count to copy.
                let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
                let store_size = StatementKind::Assign(Box::new((
                    size_place,
                    Rvalue::Use(
                        Operand::Copy(Place {
                            local: size_array_local,
                            projection: tcx
                                .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
                        }),
                        WithRetag::No,
                    ),
                )));

                // Destination: `*mut T` to `lhs`, then cast to `*mut u8`.
                let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
                let dst_ptr =
                    StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));

                let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
                let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
                let dst_cast = StatementKind::Assign(Box::new((
                    dst_cast_place,
                    Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
                )));

                // Source: `*const T` to `rhs`, then cast to `*const u8`.
                let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
                let src_ptr =
                    StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));

                let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
                let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
                let src_cast = StatementKind::Assign(Box::new((
                    src_cast_place,
                    Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
                )));

                // Copy exactly `size_place` bytes from the source to the destination.
                let copy_bytes = StatementKind::Intrinsic(Box::new(
                    NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
                        src: Operand::Copy(src_cast_place),
                        dst: Operand::Copy(dst_cast_place),
                        count: Operand::Copy(size_place),
                    }),
                ));

                let store_dead = StatementKind::StorageDead(size_array_local);

                // Added at `location` in exactly this order.
                let stmts = [
                    store_live,
                    const_assign,
                    store_discr,
                    cast_discr,
                    store_size,
                    dst_ptr,
                    dst_cast,
                    src_ptr,
                    src_cast,
                    copy_bytes,
                    store_dead,
                ];
                for stmt in stmts {
                    patch.add_statement(location, stmt);
                }

                // The whole-value assignment is superseded by the sequence above.
                st.make_nop(true);
            }
        }

        patch.apply(body);
    }

    /// Returns `false`: this pass is an optional optimization.
    fn is_required(&self) -> bool {
        false
    }
}
173
174impl EnumSizeOpt {
175    fn candidate<'tcx>(
176        &self,
177        tcx: TyCtxt<'tcx>,
178        typing_env: ty::TypingEnv<'tcx>,
179        ty: Ty<'tcx>,
180        alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
181    ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
182        let adt_def = match ty.kind() {
183            ty::Adt(adt_def, _args) if adt_def.is_enum() => adt_def,
184            _ => return None,
185        };
186        let layout = tcx.layout_of(typing_env.as_query_input(ty)).ok()?;
187        let variants = match &layout.variants {
188            Variants::Single { .. } | Variants::Empty => return None,
189            Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None,
190
191            Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
192            Variants::Multiple { variants, .. } => variants,
193        };
194        let min = variants.iter().map(|v| v.size).min().unwrap();
195        let max = variants.iter().map(|v| v.size).max().unwrap();
196        if max.bytes() - min.bytes() < self.discrepancy {
197            return None;
198        }
199
200        let num_discrs = adt_def.discriminants(tcx).count();
201        if variants.iter_enumerated().any(|(var_idx, _)| {
202            let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
203            (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
204        }) {
205            return None;
206        }
207        if let Some(alloc_id) = alloc_cache.get(&ty) {
208            return Some((*adt_def, num_discrs, *alloc_id));
209        }
210
211        // Construct an in-memory array mapping discriminant idx to variant size.
212        let data_layout = tcx.data_layout();
213        let ptr_size = data_layout.pointer_size();
214        let mut alloc = interpret::Allocation::from_bytes(
215            vec![0; ptr_size.bytes_usize() * num_discrs],
216            tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
217            Mutability::Mut,
218            (),
219        );
220        for (var_idx, layout) in variants.iter_enumerated() {
221            let curr_idx = ptr_size * adt_def.discriminant_for_variant(tcx, var_idx).val as u64;
222            let val = Scalar::from_target_usize(layout.size.bytes(), &tcx);
223            alloc.write_scalar(&tcx, alloc_range(curr_idx, val.size()), val).unwrap();
224        }
225        alloc.mutability = Mutability::Not;
226        let alloc = tcx.reserve_and_set_memory_alloc(tcx.mk_const_alloc(alloc));
227
228        Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
229    }
230}