1mod by_move_body;
54mod drop;
55mod layout;
56
57pub(super) use by_move_body::coroutine_by_move_body_def_id;
58use drop::{
59 create_coroutine_drop_shim, create_coroutine_drop_shim_async,
60 create_coroutine_drop_shim_proxy_async, elaborate_coroutine_drops, has_async_drops,
61 insert_clean_drop,
62};
63pub(super) use layout::mir_coroutine_witnesses;
64use layout::{CoroutineSavedLocals, compute_layout, locals_live_across_suspend_points};
65use rustc_abi::{FieldIdx, VariantIdx};
66use rustc_hir::lang_items::LangItem;
67use rustc_hir::{self as hir, CoroutineDesugaring, CoroutineKind};
68use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
69use rustc_index::{Idx, IndexVec, indexvec};
70use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
71use rustc_middle::mir::*;
72use rustc_middle::ty::{
73 self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
74};
75use rustc_middle::{bug, span_bug};
76use rustc_mir_dataflow::impls::always_storage_live_locals;
77use rustc_span::def_id::DefId;
78use tracing::{debug, instrument};
79
80use crate::deref_separator::deref_finder;
81use crate::{abort_unwinding_calls, pass_manager as pm, simplify};
82
/// MIR pass that lowers a coroutine body into a state machine.
///
/// Locals that are live across suspension points are moved into the coroutine layout,
/// `yield`/`return` terminators become "write the new state and return", drop shim(s) are
/// created, and the body itself is turned into the resume function (see `run_pass`).
pub(super) struct StateTransform;
84
85struct RenameLocalVisitor<'tcx> {
86 from: Local,
87 to: Local,
88 tcx: TyCtxt<'tcx>,
89}
90
91impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
92 fn tcx(&self) -> TyCtxt<'tcx> {
93 self.tcx
94 }
95
96 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
97 if *local == self.from {
98 *local = self.to;
99 } else if *local == self.to {
100 *local = self.from;
101 }
102 }
103
104 fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
105 match terminator.kind {
106 TerminatorKind::Return => {
107 }
110 _ => self.super_terminator(terminator, location),
111 }
112 }
113}
114
115struct SelfArgVisitor<'tcx> {
116 tcx: TyCtxt<'tcx>,
117 new_base: Place<'tcx>,
118}
119
120impl<'tcx> SelfArgVisitor<'tcx> {
121 fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
122 Self { tcx, new_base }
123 }
124}
125
126impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
127 fn tcx(&self) -> TyCtxt<'tcx> {
128 self.tcx
129 }
130
131 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
132 assert_ne!(*local, SELF_ARG);
133 }
134
135 fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
136 if place.local == SELF_ARG {
137 replace_base(place, self.new_base, self.tcx);
138 }
139
140 for elem in place.projection.iter() {
141 if let PlaceElem::Index(local) = elem {
142 assert_ne!(local, SELF_ARG);
143 }
144 }
145 }
146}
147
148#[tracing::instrument(level = "trace", skip(tcx))]
149fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
150 place.local = new_base.local;
151
152 let mut new_projection = new_base.projection.to_vec();
153 new_projection.append(&mut place.projection.to_vec());
154
155 place.projection = tcx.mk_place_elems(&new_projection);
156 tracing::trace!(?place);
157}
158
/// Local of the first argument: the coroutine state. It is passed by value before the
/// transform and is retyped to `&mut _` or `Pin<&mut _>` for the resume function.
const SELF_ARG: Local = Local::arg(0);
/// Local of the second argument: the resume argument. For `async` coroutines this is the
/// task context (see `transform_async_context`).
pub(crate) const CTX_ARG: Local = Local::arg(1);
161
162struct SuspensionPoint<'tcx> {
164 state: usize,
166 resume: BasicBlock,
168 resume_arg: Place<'tcx>,
170 drop: Option<BasicBlock>,
172 storage_liveness: GrowableBitSet<Local>,
174}
175
176struct TransformVisitor<'tcx> {
177 tcx: TyCtxt<'tcx>,
178 coroutine_kind: hir::CoroutineKind,
179
180 discr_ty: Ty<'tcx>,
182
183 remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
185
186 storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
188
189 suspension_points: Vec<SuspensionPoint<'tcx>>,
191
192 always_live_locals: DenseBitSet<Local>,
194
195 new_ret_local: Local,
197
198 old_yield_ty: Ty<'tcx>,
199
200 old_ret_ty: Ty<'tcx>,
201}
202
203impl<'tcx> TransformVisitor<'tcx> {
204 fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
205 let block = body.basic_blocks.next_index();
206 let source_info = SourceInfo::outermost(body.span);
207
208 let none_value = match self.coroutine_kind {
209 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
210 span_bug!(body.span, "`Future`s are not fused inherently")
211 }
212 CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
213 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
215 let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span);
216 make_aggregate_adt(
217 option_def_id,
218 VariantIdx::ZERO,
219 self.tcx.mk_args(&[self.old_yield_ty.into()]),
220 IndexVec::new(),
221 )
222 }
223 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
225 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
226 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
227 let yield_ty = args.type_at(0);
228 Rvalue::Use(
229 Operand::Constant(Box::new(ConstOperand {
230 span: source_info.span,
231 const_: Const::Unevaluated(
232 UnevaluatedConst::new(
233 self.tcx.require_lang_item(LangItem::AsyncGenFinished, body.span),
234 self.tcx.mk_args(&[yield_ty.into()]),
235 ),
236 self.old_yield_ty,
237 ),
238 user_ty: None,
239 })),
240 WithRetag::Yes,
241 )
242 }
243 };
244
245 let statements = vec![Statement::new(
246 source_info,
247 StatementKind::Assign(Box::new((Place::return_place(), none_value))),
248 )];
249
250 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
251 statements,
252 Some(Terminator { source_info, kind: TerminatorKind::Return }),
253 false,
254 ));
255
256 block
257 }
258
259 #[tracing::instrument(level = "trace", skip(self, statements))]
265 fn make_state(
266 &self,
267 val: Operand<'tcx>,
268 source_info: SourceInfo,
269 is_return: bool,
270 statements: &mut Vec<Statement<'tcx>>,
271 ) {
272 const ZERO: VariantIdx = VariantIdx::ZERO;
273 const ONE: VariantIdx = VariantIdx::from_usize(1);
274 let rvalue = match self.coroutine_kind {
275 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
276 let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span);
277 let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
278 let (variant_idx, operands) = if is_return {
279 (ZERO, indexvec![val]) } else {
281 (ONE, IndexVec::new()) };
283 make_aggregate_adt(poll_def_id, variant_idx, args, operands)
284 }
285 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
286 let option_def_id = self.tcx.require_lang_item(LangItem::Option, source_info.span);
287 let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
288 let (variant_idx, operands) = if is_return {
289 (ZERO, IndexVec::new()) } else {
291 (ONE, indexvec![val]) };
293 make_aggregate_adt(option_def_id, variant_idx, args, operands)
294 }
295 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
296 if is_return {
297 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
298 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
299 let yield_ty = args.type_at(0);
300 Rvalue::Use(
301 Operand::Constant(Box::new(ConstOperand {
302 span: source_info.span,
303 const_: Const::Unevaluated(
304 UnevaluatedConst::new(
305 self.tcx.require_lang_item(
306 LangItem::AsyncGenFinished,
307 source_info.span,
308 ),
309 self.tcx.mk_args(&[yield_ty.into()]),
310 ),
311 self.old_yield_ty,
312 ),
313 user_ty: None,
314 })),
315 WithRetag::Yes,
316 )
317 } else {
318 Rvalue::Use(val, WithRetag::Yes)
319 }
320 }
321 CoroutineKind::Coroutine(_) => {
322 let coroutine_state_def_id =
323 self.tcx.require_lang_item(LangItem::CoroutineState, source_info.span);
324 let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
325 let variant_idx = if is_return {
326 ONE } else {
328 ZERO };
330 make_aggregate_adt(coroutine_state_def_id, variant_idx, args, indexvec![val])
331 }
332 };
333
334 statements.push(Statement::new(
336 source_info,
337 StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
338 ));
339 }
340
341 #[tracing::instrument(level = "trace", skip(self), ret)]
343 fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
344 let self_place = Place::from(SELF_ARG);
345 let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
346 let mut projection = base.projection.to_vec();
347 projection.push(ProjectionElem::Field(idx, ty));
348
349 Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
350 }
351
352 #[tracing::instrument(level = "trace", skip(self))]
354 fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
355 let self_place = Place::from(SELF_ARG);
356 Statement::new(
357 source_info,
358 StatementKind::SetDiscriminant {
359 place: Box::new(self_place),
360 variant_index: state_disc,
361 },
362 )
363 }
364
365 #[tracing::instrument(level = "trace", skip(self, body))]
367 fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
368 let temp_decl = LocalDecl::new(self.discr_ty, body.span);
369 let local_decls_len = body.local_decls.push(temp_decl);
370 let temp = Place::from(local_decls_len);
371
372 let self_place = Place::from(SELF_ARG);
373 let assign = Statement::new(
374 SourceInfo::outermost(body.span),
375 StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
376 );
377 (assign, temp)
378 }
379
380 #[tracing::instrument(level = "trace", skip(self, body))]
382 fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
383 body.local_decls.swap(old_local, new_local);
384
385 let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
386 visitor.visit_body(body);
387 for suspension in &mut self.suspension_points {
388 let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
389 let location = Location { block: START_BLOCK, statement_index: 0 };
390 visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
391 }
392 }
393}
394
// Rewrites the coroutine body in place. Places rooted at saved locals become fields of the
// coroutine state, storage markers of saved locals become no-ops, and `Return`/`Yield`
// terminators are replaced by "store the wrapped value, set the discriminant, return".
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    // `visit_place` below rewrites saved-local places without recursing. A saved local that
    // still reaches `visit_local` therefore indicates a use that was not rewritten.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
        assert!(!self.remap.contains(*local));
    }

    // Re-roots a place whose base is a saved local at that local's field in the coroutine
    // state; the place's own projections are kept after the new base.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
        if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
            replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
        }
    }

    // Storage markers of saved locals are removed (turned into nops); they live in the
    // coroutine state instead.
    #[tracing::instrument(level = "trace", skip(self, stmt), ret)]
    fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
        if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
            && self.remap.contains(l)
        {
            stmt.make_nop(true);
        }
        self.super_statement(stmt, location);
    }

    #[tracing::instrument(level = "trace", skip(self, term), ret)]
    fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
        if let TerminatorKind::Return = term.kind {
            // `Return` terminators, including those created by `visit_basic_block_data`,
            // implicitly read the return place and are deliberately not walked.
            return;
        }
        self.super_terminator(term, location);
    }

    #[tracing::instrument(level = "trace", skip(self, data), ret)]
    fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
        match data.terminator().kind {
            TerminatorKind::Return => {
                let source_info = data.terminator().source_info;
                // Wrap the returned value (see `make_state`) and store it in `new_ret_local`.
                self.make_state(
                    Operand::Move(Place::return_place()),
                    source_info,
                    true,
                    &mut data.statements,
                );
                // Mark the coroutine as returned.
                let state = VariantIdx::new(CoroutineArgs::RETURNED);
                data.statements.push(self.set_discr(state, source_info));
                data.terminator_mut().kind = TerminatorKind::Return;
            }
            TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
                let source_info = data.terminator().source_info;
                // Store the yielded value before the `StorageDead`s below, which could end
                // the storage of a local it reads.
                self.make_state(value.clone(), source_info, false, &mut data.statements);
                // Suspension points are numbered after the reserved states.
                let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();

                // The resume argument's destination may itself be rooted at a saved local;
                // redirect it into the coroutine state as well.
                if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
                    replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
                }

                // NOTE(review): unwrapped, so `storage_liveness` is presumably filled for
                // every block ending in `Yield`.
                let storage_liveness: GrowableBitSet<Local> =
                    self.storage_liveness[block].clone().unwrap().into();

                // End the storage of locals that are live here but neither saved in the layout
                // nor always live; `create_cases` re-issues `StorageLive` for them on re-entry.
                for i in 0..self.always_live_locals.domain_size() {
                    let l = Local::new(i);
                    let needs_storage_dead = storage_liveness.contains(l)
                        && !self.remap.contains(l)
                        && !self.always_live_locals.contains(l);
                    if needs_storage_dead {
                        data.statements
                            .push(Statement::new(source_info, StatementKind::StorageDead(l)));
                    }
                }

                self.suspension_points.push(SuspensionPoint {
                    state,
                    resume,
                    resume_arg,
                    drop,
                    storage_liveness,
                });

                // Record the suspension state and return to the caller.
                let state = VariantIdx::new(state);
                data.statements.push(self.set_discr(state, source_info));
                data.terminator_mut().kind = TerminatorKind::Return;
            }
            _ => {}
        }

        // Walk the (possibly rewritten) block so places inside it are remapped too.
        self.super_basic_block_data(block, data);
    }
}
496
497fn make_aggregate_adt<'tcx>(
498 def_id: DefId,
499 variant_idx: VariantIdx,
500 args: GenericArgsRef<'tcx>,
501 operands: IndexVec<FieldIdx, Operand<'tcx>>,
502) -> Rvalue<'tcx> {
503 Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
504}
505
506#[tracing::instrument(level = "trace", skip(tcx, body))]
507fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
508 let coroutine_ty = body.local_decls[SELF_ARG].ty;
509
510 let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
511
512 body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
514
515 SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
517}
518
519#[tracing::instrument(level = "trace", skip(tcx, body))]
520fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
521 let coroutine_ty = body.local_decls[SELF_ARG].ty;
522
523 let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
524
525 let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
526 let pin_adt_ref = tcx.adt_def(pin_did);
527 let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
528 let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
529
530 body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
532
533 let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
534
535 SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
537
538 let source_info = SourceInfo::outermost(body.span);
539 let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);
540
541 let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
542 statements.insert(
543 0,
544 Statement::new(
545 source_info,
546 StatementKind::Assign(Box::new((
547 unpinned_local.into(),
548 Rvalue::Use(Operand::Copy(pin_field), WithRetag::Yes),
549 ))),
550 ),
551 );
552}
553
554#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
572fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
573 let context_mut_ref = Ty::new_task_context(tcx);
574 let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
575 let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
576 ty::GenericArgs::empty(),
577 body.typing_env(tcx),
578 tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
579 );
580
581 let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
584 body.local_decls.swap(CTX_ARG, resume_local);
585 RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);
586
587 let source_info = SourceInfo::outermost(body.span);
591 let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
592 let nonnull_rhs =
593 Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
594 let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
595 let resume_rhs = Rvalue::Aggregate(
596 Box::new(AggregateKind::Adt(
597 resume_ty_def_id,
598 VariantIdx::ZERO,
599 ty::GenericArgs::empty(),
600 None,
601 None,
602 )),
603 indexvec![Operand::Move(nonnull_local.into())],
604 );
605 let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
606 body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
607 0..0,
608 [Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
609 );
610}
611
612fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
617 let context_mut_ref = Ty::new_task_context(tcx);
618 let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
619 let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
620 ty::GenericArgs::empty(),
621 body.typing_env(tcx),
622 tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
623 );
624
625 let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
626 for bb_data in body.basic_blocks.as_mut().iter_mut() {
627 if bb_data.is_cleanup {
628 continue;
629 }
630
631 let terminator = bb_data.terminator_mut();
632 if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
633 && let func_ty = func.ty(&body.local_decls, tcx)
634 && let ty::FnDef(def_id, _) = *func_ty.kind()
635 && def_id == get_context_def_id
636 && let [arg] = &**args
637 && let Some(place) = arg.node.place()
638 {
639 let arg =
640 Rvalue::Cast(
641 CastKind::Transmute,
642 Operand::Copy(place.project_deeper(
643 &[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
644 tcx,
645 )),
646 context_mut_ref,
647 );
648 let assign = Statement::new(
649 terminator.source_info,
650 StatementKind::Assign(Box::new((*destination, arg))),
651 );
652 terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
653 bb_data.statements.push(assign);
654 }
655 }
656}
657
658fn insert_switch<'tcx>(
663 body: &mut Body<'tcx>,
664 cases: Vec<(usize, BasicBlock)>,
665 transform: &TransformVisitor<'tcx>,
666 default_block: BasicBlock,
667) {
668 let (assign, discr) = transform.get_discr(body);
669
670 #[cfg(debug_assertions)]
672 for bb in body.basic_blocks.iter() {
673 for target in bb.terminator().successors() {
674 assert_ne!(target, START_BLOCK);
675 }
676 }
677
678 let former_entry = std::mem::replace(
680 &mut body.basic_blocks_mut()[START_BLOCK],
681 BasicBlockData::new_stmts(vec![assign], None, false),
682 );
683 let former_entry = body.basic_blocks_mut().push(former_entry);
684
685 let mut switch_targets =
687 SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
688 for bb in switch_targets.all_targets_mut() {
689 if *bb == START_BLOCK {
690 *bb = former_entry;
691 }
692 }
693
694 let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
695 body.basic_blocks_mut()[START_BLOCK].terminator =
696 Some(Terminator { source_info: SourceInfo::outermost(body.span), kind: switch });
697}
698
699fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
700 let source_info = SourceInfo::outermost(body.span);
701 body.basic_blocks_mut().push(BasicBlockData::new(Some(Terminator { source_info, kind }), false))
702}
703
704fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) -> Statement<'tcx> {
705 let poll_def_id = tcx.require_lang_item(LangItem::Poll, source_info.span);
707 let args = tcx.mk_args(&[tcx.types.unit.into()]);
708 let val = Operand::Constant(Box::new(ConstOperand {
709 span: source_info.span,
710 user_ty: None,
711 const_: Const::zero_sized(tcx.types.unit),
712 }));
713 let ready_val = Rvalue::Aggregate(
714 Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)),
715 indexvec![val],
716 );
717 Statement::new(source_info, StatementKind::Assign(Box::new((Place::return_place(), ready_val))))
718}
719
720fn insert_poll_ready_block<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> BasicBlock {
721 let source_info = SourceInfo::outermost(body.span);
722 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
723 [return_poll_ready_assign(tcx, source_info)].to_vec(),
724 Some(Terminator { source_info, kind: TerminatorKind::Return }),
725 false,
726 ))
727}
728
729fn insert_panic_block<'tcx>(
730 tcx: TyCtxt<'tcx>,
731 body: &mut Body<'tcx>,
732 message: AssertMessage<'tcx>,
733) -> BasicBlock {
734 let assert_block = body.basic_blocks.next_index();
735 let kind = TerminatorKind::Assert {
736 cond: Operand::Constant(Box::new(ConstOperand {
737 span: body.span,
738 user_ty: None,
739 const_: Const::from_bool(tcx, false),
740 })),
741 expected: true,
742 msg: Box::new(message),
743 target: assert_block,
744 unwind: UnwindAction::Continue,
745 };
746
747 insert_term_block(body, kind)
748}
749
750fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::TypingEnv<'tcx>) -> bool {
751 if body.return_ty().is_privately_uninhabited(tcx, typing_env) {
753 return false;
754 }
755
756 body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
758 }
760
761fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
762 if !tcx.sess.panic_strategy().unwinds() {
764 return false;
765 }
766
767 body.basic_blocks.iter().any(|block| block.terminator().unwind().is_some())
769}
770
771fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
773 transform: &TransformVisitor<'tcx>,
774 body: &mut Body<'tcx>,
775) {
776 let source_info = SourceInfo::outermost(body.span);
777 let poison_block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
778 vec![transform.set_discr(VariantIdx::new(CoroutineArgs::POISONED), source_info)],
779 Some(Terminator { source_info, kind: TerminatorKind::UnwindResume }),
780 true,
781 ));
782
783 for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
784 let source_info = block.terminator().source_info;
785
786 if let TerminatorKind::UnwindResume = block.terminator().kind {
787 if idx != poison_block {
790 *block.terminator_mut() =
791 Terminator { source_info, kind: TerminatorKind::Goto { target: poison_block } };
792 }
793 } else if !block.is_cleanup
794 && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
797 {
798 *unwind = UnwindAction::Cleanup(poison_block);
799 }
800 }
801}
802
/// Turns the transformed coroutine body into its resume function.
///
/// Steps:
/// - If the body can unwind, unwinds are routed through a poison block.
/// - `START_BLOCK` becomes a switch on the current state:
///   - unresumed goes to the original entry;
///   - poisoned or returned goes to a panic or "finished" block;
///   - each suspension point goes to its resume block;
///   - anything else is unreachable.
/// - The state argument becomes `Pin<&mut _>` or `&mut _`.
/// - Dead blocks are removed.
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
fn create_coroutine_resume_function<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: TransformVisitor<'tcx>,
    body: &mut Body<'tcx>,
    can_return: bool,
    can_unwind: bool,
) {
    if can_unwind {
        // Unwinding out of the resume function leaves the coroutine poisoned.
        generate_poison_block_and_redirect_unwinds_there(&transform, body);
    }

    let mut cases = create_cases(body, &transform, Operation::Resume);

    use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};

    // An unresumed coroutine starts at the original entry block (relocated by `insert_switch`).
    cases.insert(0, (CoroutineArgs::UNRESUMED, START_BLOCK));

    // Resuming a poisoned coroutine panics.
    if can_unwind {
        cases.insert(
            1,
            (
                CoroutineArgs::POISONED,
                insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind)),
            ),
        );
    }

    // Resuming a returned coroutine either panics or keeps producing a "finished" value,
    // depending on the coroutine kind.
    if can_return {
        let block = match transform.coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
            | CoroutineKind::Coroutine(_) => {
                // `async_drop_in_place` coroutines keep returning `Poll::Ready(())` instead of
                // panicking.
                if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) {
                    insert_poll_ready_block(tcx, body)
                } else {
                    insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
                }
            }
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
            | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                transform.insert_none_ret_block(body)
            }
        };
        cases.insert(1, (CoroutineArgs::RETURNED, block));
    }

    // Any other state value is impossible.
    let default_block = insert_term_block(body, TerminatorKind::Unreachable);
    insert_switch(body, cases, &transform, default_block);

    // `gen` coroutines take the state by `&mut`; all other kinds take `Pin<&mut _>`.
    match transform.coroutine_kind {
        CoroutineKind::Coroutine(_)
        | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
        {
            make_coroutine_state_argument_pinned(tcx, body);
        }
        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
            make_coroutine_state_argument_indirect(tcx, body);
        }
    }

    // Drop blocks that are no longer reachable from the new entry switch.
    simplify::remove_dead_blocks(body);

    pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

    deref_finder(tcx, body, false);

    // For async kinds, make the resume argument the task context itself.
    if transform.coroutine_kind.is_async_desugaring() {
        transform_async_context(tcx, body);
    }

    if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
        dumper.dump_mir(body);
    }
}
887
/// Which path through a suspension point an entry block built by `create_cases` leads to.
#[derive(PartialEq, Copy, Clone, Debug)]
enum Operation {
    /// Continue at the suspension point's `resume` block; the resume argument is moved into
    /// place.
    Resume,
    /// Continue at the suspension point's `drop` block, if any; no resume argument is moved.
    Drop,
    /// Continue at the suspension point's `drop` block, if any; the resume argument is moved
    /// into place.
    AsyncDrop,
}
895
896impl Operation {
897 fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
898 match self {
899 Operation::Resume => Some(point.resume),
900 Operation::Drop | Operation::AsyncDrop => point.drop,
901 }
902 }
903
904 fn resume_place<'tcx>(self, point: &SuspensionPoint<'tcx>) -> Option<Place<'tcx>> {
905 match self {
906 Operation::Resume | Operation::AsyncDrop => Some(point.resume_arg),
907 Operation::Drop => None,
908 }
909 }
910}
911
912#[tracing::instrument(level = "trace", skip(transform, body))]
913fn create_cases<'tcx>(
914 body: &mut Body<'tcx>,
915 transform: &TransformVisitor<'tcx>,
916 operation: Operation,
917) -> Vec<(usize, BasicBlock)> {
918 let source_info = SourceInfo::outermost(body.span);
919
920 transform
921 .suspension_points
922 .iter()
923 .filter_map(|point| {
924 operation.target_block(point).map(|target| {
926 let mut statements = Vec::new();
927
928 for l in body.local_decls.indices() {
930 let needs_storage_live = point.storage_liveness.contains(l)
931 && !transform.remap.contains(l)
932 && !transform.always_live_locals.contains(l);
933 if needs_storage_live {
934 statements.push(Statement::new(source_info, StatementKind::StorageLive(l)));
935 }
936 }
937
938 if let Some(resume_arg) = operation.resume_place(point)
940 && resume_arg != CTX_ARG.into()
941 {
942 statements.push(Statement::new(
943 source_info,
944 StatementKind::Assign(Box::new((
945 resume_arg,
946 Rvalue::Use(Operand::Move(CTX_ARG.into()), WithRetag::Yes),
947 ))),
948 ));
949 }
950
951 let block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
953 statements,
954 Some(Terminator { source_info, kind: TerminatorKind::Goto { target } }),
955 false,
956 ));
957
958 (point.state, block)
959 })
960 })
961 .collect()
962}
963
// Entry point of the coroutine lowering. In order, it:
//  1. computes the resume function's return type for the coroutine kind;
//  2. replaces `get_context` calls (async kinds only);
//  3. computes liveness across suspension points and the coroutine layout;
//  4. rewrites the body with `TransformVisitor` and swaps in the new return local;
//  5. elaborates drops and creates the drop shim(s);
//  6. turns the body itself into the resume function.
impl<'tcx> crate::MirPass<'tcx> for StateTransform {
    #[instrument(level = "debug", skip(self, tcx, body), ret)]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        debug!(def_id = ?body.source.def_id());

        // Bodies without a yield type are not coroutines and are left untouched.
        let Some(old_yield_ty) = body.yield_ty() else {
            return;
        };
        tracing::trace!(def_id = ?body.source.def_id());

        let old_ret_ty = body.return_ty();

        // The drop shims are created by this pass, so none may exist yet.
        assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_before", body) {
            dumper.dump_mir(body);
        }

        // Local 1 (`SELF_ARG`) is the coroutine itself, passed by value.
        let coroutine_ty = body.local_decls.raw[1].ty;
        let coroutine_kind = body.coroutine_kind().unwrap();

        // The discriminant type comes from the coroutine type's own arguments.
        let ty::Coroutine(_, args) = coroutine_ty.kind() else {
            tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
        };
        let discr_ty = args.as_coroutine().discr_ty(tcx);

        // Return type of the generated resume function.
        let new_ret_ty = match coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                // Poll<old_ret_ty>
                let poll_did = tcx.require_lang_item(LangItem::Poll, body.span);
                let poll_adt_ref = tcx.adt_def(poll_did);
                let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
                Ty::new_adt(tcx, poll_adt_ref, poll_args)
            }
            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                // Option<old_yield_ty>
                let option_did = tcx.require_lang_item(LangItem::Option, body.span);
                let option_adt_ref = tcx.adt_def(option_did);
                let option_args = tcx.mk_args(&[old_yield_ty.into()]);
                Ty::new_adt(tcx, option_adt_ref, option_args)
            }
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                // The yield type is used unchanged.
                old_yield_ty
            }
            CoroutineKind::Coroutine(_) => {
                // CoroutineState<old_yield_ty, old_ret_ty>
                let state_did = tcx.require_lang_item(LangItem::CoroutineState, body.span);
                let state_adt_ref = tcx.adt_def(state_did);
                let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
                Ty::new_adt(tcx, state_adt_ref, state_args)
            }
        };

        // Computed before any rewriting. It decides how the clean drop is inserted and whether
        // the async or the sync drop shim is created below.
        let has_async_drops = has_async_drops(body);

        if coroutine_kind.is_async_desugaring() {
            eliminate_get_context_calls(tcx, body);
        }

        let always_live_locals = always_storage_live_locals(body);
        let movable = coroutine_kind.movability() == hir::Movability::Movable;
        let liveness_info =
            locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);

        // Under `-Zvalidate-mir`, check that assignments between saved locals only happen
        // between locals whose storage is marked as conflicting.
        if tcx.sess.opts.unstable_opts.validate_mir {
            let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias {
                assigned_local: None,
                saved_locals: &liveness_info.saved_locals,
                storage_conflicts: &liveness_info.storage_conflicts,
            };

            vis.visit_body(body);
        }

        let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);

        let can_return = can_return(tcx, body, body.typing_env(tcx));

        // Receives the wrapped yield/return values; swapped with `RETURN_PLACE` below.
        let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
        tracing::trace!(?new_ret_local);

        // Rewrite saved locals into state fields and `return`/`yield` into state updates.
        let mut transform = TransformVisitor {
            tcx,
            coroutine_kind,
            remap,
            storage_liveness,
            always_live_locals,
            suspension_points: Vec::new(),
            discr_ty,
            new_ret_local,
            old_ret_ty,
            old_yield_ty,
        };
        transform.visit_body(body);

        transform.replace_local(RETURN_PLACE, new_ret_local, body);

        // Arguments saved in the layout are moved into their state fields at the very start
        // of the body.
        let source_info = SourceInfo::outermost(body.span);
        let args_iter = body.args_iter();
        body.basic_blocks.as_mut()[START_BLOCK].statements.splice(
            0..0,
            args_iter.filter_map(|local| {
                let (ty, variant_index, idx) = transform.remap[local]?;
                let lhs = transform.make_field(variant_index, idx, ty);
                let rhs = Rvalue::Use(Operand::Move(local.into()), WithRetag::Yes);
                let assign = StatementKind::Assign(Box::new((lhs, rhs)));
                Some(Statement::new(source_info, assign))
            }),
        );

        // For `gen` coroutines only the state argument is kept.
        if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
            body.arg_count = 1;
        }

        // Clear argument indices on all debuginfo entries.
        // NOTE(review): presumably because they no longer match the rewritten signature.
        for var in &mut body.var_debug_info {
            var.argument_index = None;
        }

        // The body no longer yields or takes a resume value directly; record the layout instead.
        body.coroutine.as_mut().unwrap().yield_ty = None;
        body.coroutine.as_mut().unwrap().resume_ty = None;
        body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);

        // `drop_clean` is passed to the drop-shim construction below.
        let drop_clean = insert_clean_drop(tcx, body, has_async_drops);

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_pre-elab", body) {
            dumper.dump_mir(body);
        }

        elaborate_coroutine_drops(tcx, body);

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_post-transform", body) {
            dumper.dump_mir(body);
        }

        let can_unwind = can_unwind(tcx, body);

        // Bodies with async drops get only the async drop shim. All others get the sync drop
        // shim plus the `coroutine_drop_proxy_async` shim.
        if has_async_drops {
            let drop_shim =
                create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind);
            body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim);
        } else {
            let drop_shim =
                create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
            body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

            let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
            body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
        }

        // Finally, turn `body` itself into the resume function.
        create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);
    }

    fn is_required(&self) -> bool {
        true
    }
}
1156
1157struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> {
1170 saved_locals: &'a CoroutineSavedLocals,
1171 storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
1172 assigned_local: Option<CoroutineSavedLocal>,
1173}
1174
1175impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1176 fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> {
1177 if place.is_indirect() {
1178 return None;
1179 }
1180
1181 self.saved_locals.get(place.local)
1182 }
1183
1184 fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1185 if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1186 assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1187
1188 self.assigned_local = Some(assigned_local);
1189 f(self);
1190 self.assigned_local = None;
1191 }
1192 }
1193}
1194
// Only assignments are inspected: statement assignments, `Call` destinations (when the call
// returns) and `Yield` resume arguments. For each, the places read on the right-hand side are
// checked against the assigned saved local.
impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
        let Some(lhs) = self.assigned_local else {
            // Outside of `check_assigned_place` only non-use contexts are expected.
            // NOTE(review): presumably places the default walk visits for e.g. debuginfo.
            assert!(!context.is_use());
            return;
        };

        let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };

        // Both sides are saved locals: their storage must be marked as conflicting.
        if !self.storage_conflicts.contains(lhs, rhs) {
            bug!(
                "Assignment between coroutine saved locals whose storage is not \
                 marked as conflicting: {:?}: {:?} = {:?}",
                location,
                lhs,
                rhs,
            );
        }
    }

    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
        match &statement.kind {
            StatementKind::Assign((lhs, rhs)) => {
                self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
            }

            // None of these statements assign one place from another.
            StatementKind::FakeRead(..)
            | StatementKind::SetDiscriminant { .. }
            | StatementKind::StorageLive(_)
            | StatementKind::StorageDead(_)
            | StatementKind::AscribeUserType(..)
            | StatementKind::PlaceMention(..)
            | StatementKind::Coverage(..)
            | StatementKind::Intrinsic(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => {}
        }
    }

    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
        match &terminator.kind {
            // A returning call writes its destination from the callee and arguments.
            TerminatorKind::Call {
                func,
                args,
                destination,
                target: Some(_),
                unwind: _,
                call_source: _,
                fn_span: _,
            } => {
                self.check_assigned_place(*destination, |this| {
                    this.visit_operand(func, location);
                    for arg in args {
                        this.visit_operand(&arg.node, location);
                    }
                });
            }

            // The resume argument is checked against the yielded value.
            TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
                self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
            }

            // Inline assembly is not checked.
            TerminatorKind::InlineAsm { .. } => {}

            TerminatorKind::Call { .. }
            | TerminatorKind::Goto { .. }
            | TerminatorKind::SwitchInt { .. }
            | TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::Drop { .. }
            | TerminatorKind::Assert { .. }
            | TerminatorKind::CoroutineDrop
            | TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. } => {}
        }
    }
}
1282}