1use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
2use rustc_data_structures::thin_vec::ThinVec;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_middle::bug;
6use rustc_middle::mir::visit::Visitor;
7use rustc_middle::mir::*;
8use rustc_middle::ty::layout::PrimitiveExt;
9use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
10use rustc_session::Session;
11use tracing::debug;
12
/// MIR pass that inserts runtime checks validating the enum discriminant produced by
/// `Transmute` casts into enum types.
///
/// Checks are only inserted at transmute sites (not on every enum use), and the pass only runs
/// when the session has UB checks enabled.
pub(super) struct CheckEnums;
17
impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
    /// The pass is only enabled when the session has UB checks turned on.
    fn is_enabled(&self, sess: &Session) -> bool {
        sess.ub_checks()
    }

    /// Finds every `Transmute` into an enum type and, for each one, splits the containing
    /// block right before that statement and appends a discriminant check to the first half.
    /// On success the check continues into the second half.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // Every inserted check ends in an `Assert` terminator, i.e. a potential panic. If the
        // crate has no `panic_impl` lang item, skip the pass entirely rather than introducing
        // panics into otherwise-valid code.
        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
            return;
        }

        let typing_env = body.typing_env(tcx);
        let basic_blocks = body.basic_blocks.as_mut();
        let local_decls = &mut body.local_decls;

        // Inserting a check splits the current block: the statement being checked and everything
        // after it move into a freshly pushed block (see `split_block`). Walking blocks and
        // statements in reverse means every location still to be visited lies *before* the split
        // point, so its `Location` stays valid. Newly pushed blocks lie outside the index range
        // computed up front and are therefore never revisited.
        for block in basic_blocks.indices().rev() {
            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
                let location = Location { block, statement_index };
                let statement = &basic_blocks[block].statements[statement_index];
                let source_info = statement.source_info;

                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
                finder.visit_statement(statement, location);

                for check in finder.into_found_enums() {
                    debug!("Inserting enum check");
                    // `block` keeps the statements before `location` and has no terminator now;
                    // each arm below installs one that eventually continues to `new_block`.
                    let new_block = split_block(basic_blocks, location);

                    match check {
                        // A zero-sized operand cannot carry a tag for a multi-variant enum.
                        // Such code is presumably rejected elsewhere. Insert no check, just
                        // reconnect the two halves, and register a delayed bug so compilation
                        // cannot silently succeed if no error was reported.
                        EnumCheckType::Direct { op_size, .. }
                        | EnumCheckType::WithNiche { op_size, .. }
                            if op_size.bytes() == 0 =>
                        {
                            tcx.dcx().span_delayed_bug(
                                source_info.span,
                                "cannot build enum discriminant from zero-sized type",
                            );
                            basic_blocks[block].terminator = Some(Terminator {
                                source_info,
                                kind: TerminatorKind::Goto { target: new_block },
                                attributes: ThinVec::new(),
                            });
                        }
                        // The direct check pushes an extra "invalid discriminant" block, so it
                        // needs the whole block vector rather than just `block`.
                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
                            insert_direct_enum_check(
                                tcx,
                                local_decls,
                                basic_blocks,
                                block,
                                source_op,
                                discr,
                                op_size,
                                valid_discrs,
                                source_info,
                                new_block,
                            )
                        }
                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_info,
                            new_block,
                        ),
                        EnumCheckType::WithNiche {
                            source_op,
                            discr,
                            op_size,
                            offset,
                            valid_range,
                        } => insert_niche_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_op,
                            valid_range,
                            discr,
                            op_size,
                            offset,
                            source_info,
                            new_block,
                        ),
                    }
                }
            }
        }
    }

    /// The pass is required: it is not skipped based on MIR optimization level.
    fn is_required(&self) -> bool {
        true
    }
}
117
/// The runtime check to insert for a single `Transmute` into an enum type.
enum EnumCheckType<'tcx> {
    /// The target enum's layout has no variants (`Variants::Empty`) but the source operand is
    /// inhabited. Constructing such a value is never valid, so the inserted check always fails.
    Uninhabited,
    /// The target enum stores its tag directly (`TagEncoding::Direct`). The tag read from the
    /// operand must equal one of the declared discriminants.
    Direct {
        /// Copy of the operand being transmuted.
        source_op: Operand<'tcx>,
        /// Integer type and byte size of the tag.
        discr: TyAndSize<'tcx>,
        /// Size of the source operand's layout.
        op_size: Size,
        /// Discriminant values declared by the enum.
        valid_discrs: Vec<u128>,
    },
    /// The target enum is niche-encoded (`TagEncoding::Niche`). The tag read at `offset` must lie
    /// within the (possibly wrapping) `valid_range`.
    WithNiche {
        /// Copy of the operand being transmuted.
        source_op: Operand<'tcx>,
        /// Integer type and byte size of the tag.
        discr: TyAndSize<'tcx>,
        /// Size of the source operand's layout.
        op_size: Size,
        /// Byte offset of the tag field within the enum layout.
        offset: Size,
        /// Valid range of the tag scalar; may wrap around.
        valid_range: WrappingRange,
    },
}
139
/// The integer type of an enum tag together with its size in bytes.
#[derive(Debug, Copy, Clone)]
struct TyAndSize<'tcx> {
    /// Integer type the tag is read as (signed or unsigned, per the tag's primitive).
    pub ty: Ty<'tcx>,
    /// Size of the tag.
    pub size: Size,
}
145
/// A [`Visitor`] that inspects a statement for `Transmute` casts into enum types and records
/// which check each one needs.
struct EnumFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Local declarations of the body, used to compute the type of the transmuted operand.
    local_decls: &'a mut LocalDecls<'tcx>,
    /// Typing environment used for the layout queries.
    typing_env: TypingEnv<'tcx>,
    /// Checks collected so far.
    enums: Vec<EnumCheckType<'tcx>>,
}
154
155impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
156 fn new(
157 tcx: TyCtxt<'tcx>,
158 local_decls: &'a mut LocalDecls<'tcx>,
159 typing_env: TypingEnv<'tcx>,
160 ) -> Self {
161 EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
162 }
163
164 fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
166 self.enums
167 }
168}
169
170impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
171 fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
172 if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
173 let ty::Adt(adt_def, _) = ty.kind() else {
174 return;
175 };
176 if !adt_def.is_enum() {
177 return;
178 }
179
180 let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
181 return;
182 };
183 let Ok(op_layout) = self
184 .tcx
185 .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
186 else {
187 return;
188 };
189
190 match enum_layout.variants {
191 Variants::Empty if op_layout.is_uninhabited() => return,
192 Variants::Empty => {
195 self.enums.push(EnumCheckType::Uninhabited);
198 }
199 Variants::Single { .. } => {}
201 Variants::Multiple {
203 tag_encoding: TagEncoding::Direct,
204 tag: Scalar::Initialized { value, .. },
205 ..
206 } => {
207 let valid_discrs =
208 adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
209
210 let discr =
211 TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
212 self.enums.push(EnumCheckType::Direct {
213 source_op: op.to_copy(),
214 discr,
215 op_size: op_layout.size,
216 valid_discrs,
217 });
218 }
219 Variants::Multiple {
221 tag_encoding: TagEncoding::Niche { .. },
222 tag: Scalar::Initialized { value, valid_range, .. },
223 tag_field,
224 ..
225 } => {
226 let discr =
227 TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
228 self.enums.push(EnumCheckType::WithNiche {
229 source_op: op.to_copy(),
230 discr,
231 op_size: op_layout.size,
232 offset: enum_layout.fields.offset(tag_field.as_usize()),
233 valid_range,
234 });
235 }
236 _ => return,
237 }
238
239 self.super_rvalue(rvalue, location);
240 }
241 }
242}
243
244fn split_block(
245 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
246 location: Location,
247) -> BasicBlock {
248 let block_data = &mut basic_blocks[location.block];
249
250 let new_block = BasicBlockData::new_stmts(
252 block_data.statements.split_off(location.statement_index),
253 block_data.terminator.take(),
254 block_data.is_cleanup,
255 );
256
257 basic_blocks.push(new_block)
258}
259
260fn insert_discr_cast_to_u128<'tcx>(
262 tcx: TyCtxt<'tcx>,
263 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
264 block_data: &mut BasicBlockData<'tcx>,
265 source_op: Operand<'tcx>,
266 discr: TyAndSize<'tcx>,
267 op_size: Size,
268 offset: Option<Size>,
269 source_info: SourceInfo,
270) -> Place<'tcx> {
271 let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
272 match size.bytes() {
273 1 => tcx.types.u8,
274 2 => tcx.types.u16,
275 4 => tcx.types.u32,
276 8 => tcx.types.u64,
277 16 => tcx.types.u128,
278 invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
279 }
280 };
281
282 let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
283 let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
286 let array_len = op_size.bytes();
287 let mu_array_ty = Ty::new_array(tcx, mu, array_len);
288 let mu_array =
289 local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
290 let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
291 block_data
292 .statements
293 .push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));
294
295 let offset = offset.unwrap_or(Size::ZERO);
298 let smaller_mu_array = mu_array.project_deeper(
299 &[ProjectionElem::Subslice {
300 from: offset.bytes(),
301 to: offset.bytes() + discr.size.bytes(),
302 from_end: false,
303 }],
304 tcx,
305 );
306
307 (CastKind::Transmute, Operand::Copy(smaller_mu_array))
308 } else {
309 let operand_int_ty = get_ty_for_size(tcx, op_size);
310
311 let op_as_int =
312 local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
313 let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
314 block_data.statements.push(Statement::new(
315 source_info,
316 StatementKind::Assign(Box::new((op_as_int, rvalue))),
317 ));
318
319 (CastKind::IntToInt, Operand::Copy(op_as_int))
320 };
321
322 let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
324 let discr_in_discr_ty =
325 local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
326 block_data.statements.push(Statement::new(
327 source_info,
328 StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
329 ));
330
331 let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
333 let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
334 let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
335 block_data
336 .statements
337 .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));
338
339 discr
340}
341
342fn insert_direct_enum_check<'tcx>(
343 tcx: TyCtxt<'tcx>,
344 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
345 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
346 current_block: BasicBlock,
347 source_op: Operand<'tcx>,
348 discr: TyAndSize<'tcx>,
349 op_size: Size,
350 discriminants: Vec<u128>,
351 source_info: SourceInfo,
352 new_block: BasicBlock,
353) {
354 let invalid_discr_block_data = BasicBlockData::new(None, false);
356 let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
357 let block_data = &mut basic_blocks[current_block];
358 let discr_place = insert_discr_cast_to_u128(
359 tcx,
360 local_decls,
361 block_data,
362 source_op,
363 discr,
364 op_size,
365 None,
366 source_info,
367 );
368
369 let mask = discr.size.unsigned_int_max();
371 let discr_masked =
372 local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
373 let rvalue = Rvalue::BinaryOp(
374 BinOp::BitAnd,
375 Box::new((
376 Operand::Copy(discr_place),
377 Operand::Constant(Box::new(ConstOperand {
378 span: source_info.span,
379 user_ty: None,
380 const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
381 })),
382 )),
383 );
384 block_data
385 .statements
386 .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
387
388 block_data.terminator = Some(Terminator {
390 source_info,
391 kind: TerminatorKind::SwitchInt {
392 discr: Operand::Copy(discr_masked),
393 targets: SwitchTargets::new(
394 discriminants
395 .into_iter()
396 .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
397 invalid_discr_block,
398 ),
399 },
400 attributes: ThinVec::new(),
401 });
402
403 basic_blocks[invalid_discr_block].terminator = Some(Terminator {
405 source_info,
406 kind: TerminatorKind::Assert {
407 cond: Operand::Constant(Box::new(ConstOperand {
408 span: source_info.span,
409 user_ty: None,
410 const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
411 })),
412 expected: true,
413 target: new_block,
414 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
415 unwind: UnwindAction::Unreachable,
419 },
420 attributes: ThinVec::new(),
421 });
422}
423
424fn insert_uninhabited_enum_check<'tcx>(
425 tcx: TyCtxt<'tcx>,
426 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
427 block_data: &mut BasicBlockData<'tcx>,
428 source_info: SourceInfo,
429 new_block: BasicBlock,
430) {
431 let is_ok: Place<'_> =
432 local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
433 block_data.statements.push(Statement::new(
434 source_info,
435 StatementKind::Assign(Box::new((
436 is_ok,
437 Rvalue::Use(
438 Operand::Constant(Box::new(ConstOperand {
439 span: source_info.span,
440 user_ty: None,
441 const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
442 })),
443 WithRetag::Yes, ),
445 ))),
446 ));
447
448 block_data.terminator = Some(Terminator {
449 source_info,
450 kind: TerminatorKind::Assert {
451 cond: Operand::Copy(is_ok),
452 expected: true,
453 target: new_block,
454 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
455 ConstOperand {
456 span: source_info.span,
457 user_ty: None,
458 const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
459 },
460 )))),
461 unwind: UnwindAction::Unreachable,
465 },
466 attributes: ThinVec::new(),
467 });
468}
469
470fn insert_niche_check<'tcx>(
471 tcx: TyCtxt<'tcx>,
472 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
473 block_data: &mut BasicBlockData<'tcx>,
474 source_op: Operand<'tcx>,
475 valid_range: WrappingRange,
476 discr: TyAndSize<'tcx>,
477 op_size: Size,
478 offset: Size,
479 source_info: SourceInfo,
480 new_block: BasicBlock,
481) {
482 let discr = insert_discr_cast_to_u128(
483 tcx,
484 local_decls,
485 block_data,
486 source_op,
487 discr,
488 op_size,
489 Some(offset),
490 source_info,
491 );
492
493 let start_const = Operand::Constant(Box::new(ConstOperand {
495 span: source_info.span,
496 user_ty: None,
497 const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
498 }));
499 let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
500 span: source_info.span,
501 user_ty: None,
502 const_: Const::Val(
503 ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
504 tcx.types.u128,
505 ),
506 }));
507
508 let discr_diff: Place<'_> =
509 local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
510 block_data.statements.push(Statement::new(
511 source_info,
512 StatementKind::Assign(Box::new((
513 discr_diff,
514 Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
515 ))),
516 ));
517
518 let is_ok: Place<'_> =
519 local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
520 block_data.statements.push(Statement::new(
521 source_info,
522 StatementKind::Assign(Box::new((
523 is_ok,
524 Rvalue::BinaryOp(
525 BinOp::Le,
527 Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
528 ),
529 ))),
530 ));
531
532 block_data.terminator = Some(Terminator {
533 source_info,
534 kind: TerminatorKind::Assert {
535 cond: Operand::Copy(is_ok),
536 expected: true,
537 target: new_block,
538 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
539 unwind: UnwindAction::Unreachable,
543 },
544 attributes: ThinVec::new(),
545 });
546}