1mod by_move_body;
54mod drop;
55mod layout;
56
57pub(super) use by_move_body::coroutine_by_move_body_def_id;
58use drop::{
59 create_coroutine_drop_shim, create_coroutine_drop_shim_async,
60 create_coroutine_drop_shim_proxy_async, elaborate_coroutine_drops, has_async_drops,
61 insert_clean_drop,
62};
63pub(super) use layout::mir_coroutine_witnesses;
64use layout::{CoroutineSavedLocals, compute_layout, locals_live_across_suspend_points};
65use rustc_abi::{FieldIdx, VariantIdx};
66use rustc_data_structures::thin_vec::ThinVec;
67use rustc_hir::lang_items::LangItem;
68use rustc_hir::{self as hir, CoroutineDesugaring, CoroutineKind};
69use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
70use rustc_index::{Idx, IndexVec, indexvec};
71use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
72use rustc_middle::mir::*;
73use rustc_middle::ty::{
74 self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
75};
76use rustc_middle::{bug, span_bug};
77use rustc_mir_dataflow::impls::always_storage_live_locals;
78use rustc_span::def_id::DefId;
79use tracing::{debug, instrument};
80
81use crate::deref_separator::deref_finder;
82use crate::{abort_unwinding_calls, pass_manager as pm, simplify};
83
/// MIR pass that lowers a coroutine body into a state machine (see `run_pass`).
///
/// It does the following:
/// - Locals that are live across suspension points are moved into the coroutine layout.
/// - `return` and `yield` are rewritten into state updates.
/// - The drop shim and the resume function are built from the transformed body.
pub(super) struct StateTransform;
85
/// Visitor that exchanges every occurrence of two locals, `from` and `to`, in a body.
///
/// `Return` terminators are skipped by its `visit_terminator` override.
struct RenameLocalVisitor<'tcx> {
    /// One of the two locals being exchanged.
    from: Local,
    /// The other local being exchanged.
    to: Local,
    tcx: TyCtxt<'tcx>,
}
91
92impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
93 fn tcx(&self) -> TyCtxt<'tcx> {
94 self.tcx
95 }
96
97 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
98 if *local == self.from {
99 *local = self.to;
100 } else if *local == self.to {
101 *local = self.from;
102 }
103 }
104
105 fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
106 match terminator.kind {
107 TerminatorKind::Return => {
108 }
111 _ => self.super_terminator(terminator, location),
112 }
113 }
114}
115
/// Visitor that rebases every place rooted at `SELF_ARG` onto `new_base`.
///
/// For example, it turns accesses to the by-value coroutine state into accesses through a
/// dereference of a reference.
struct SelfArgVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Place substituted for the `SELF_ARG` root of each rebased place.
    new_base: Place<'tcx>,
}
120
impl<'tcx> SelfArgVisitor<'tcx> {
    /// Creates a visitor that rebases `SELF_ARG`-rooted places onto `new_base`.
    fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
        Self { tcx, new_base }
    }
}
126
127impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
128 fn tcx(&self) -> TyCtxt<'tcx> {
129 self.tcx
130 }
131
132 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
133 assert_ne!(*local, SELF_ARG);
134 }
135
136 fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
137 if place.local == SELF_ARG {
138 replace_base(place, self.new_base, self.tcx);
139 }
140
141 for elem in place.projection.iter() {
142 if let PlaceElem::Index(local) = elem {
143 assert_ne!(local, SELF_ARG);
144 }
145 }
146 }
147}
148
149#[tracing::instrument(level = "trace", skip(tcx))]
150fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
151 place.local = new_base.local;
152
153 let mut new_projection = new_base.projection.to_vec();
154 new_projection.append(&mut place.projection.to_vec());
155
156 place.projection = tcx.mk_place_elems(&new_projection);
157 tracing::trace!(?place);
158}
159
/// Local of the first function argument: the coroutine state.
///
/// It is passed by value before this pass; the resume function later takes it by reference or
/// `Pin`.
const SELF_ARG: Local = Local::arg(0);
/// Local of the second function argument: the resume argument.
///
/// In async coroutines this is retyped to the task-context reference by
/// `transform_async_context`.
pub(crate) const CTX_ARG: Local = Local::arg(1);
162
/// A `yield` site recorded while `TransformVisitor` lowers the coroutine body.
struct SuspensionPoint<'tcx> {
    /// Discriminant value stored in the coroutine state while suspended here.
    state: usize,
    /// Block to continue at when the coroutine is resumed from this point.
    resume: BasicBlock,
    /// Place that receives the resume argument when resuming from this point.
    resume_arg: Place<'tcx>,
    /// Block to continue at when the coroutine is dropped while suspended here, if any.
    drop: Option<BasicBlock>,
    /// Locals whose storage is live at this point.
    ///
    /// Those that are neither remapped into the state nor always live receive `StorageDead`
    /// when suspending and `StorageLive` again when re-entering.
    storage_liveness: GrowableBitSet<Local>,
}
176
/// Visitor that rewrites a coroutine body into the body of its resume function.
///
/// Remapped locals become fields of the coroutine state. `return` and `yield` become a write of
/// the new return value plus a discriminant update, followed by a `Return`.
struct TransformVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    coroutine_kind: hir::CoroutineKind,

    /// Type of the coroutine state's discriminant.
    discr_ty: Ty<'tcx>,

    /// Mapping for each local stored inside the coroutine state: its type, the variant that
    /// holds it, and its field index in that variant.
    ///
    /// `None` for locals that remain ordinary locals.
    remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,

    /// Locals with live storage, keyed by the block containing each `yield`.
    storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,

    /// Suspension points collected so far, in the order their blocks were visited.
    suspension_points: Vec<SuspensionPoint<'tcx>>,

    /// Locals reported by `always_storage_live_locals` for this body.
    always_live_locals: DenseBitSet<Local>,

    /// Local holding the resume function's return value.
    ///
    /// It is swapped with `RETURN_PLACE` once the visit is done.
    new_ret_local: Local,

    /// Yield type of the coroutine before the transform.
    old_yield_ty: Ty<'tcx>,

    /// Return type of the coroutine before the transform.
    old_ret_ty: Ty<'tcx>,
}
203
204impl<'tcx> TransformVisitor<'tcx> {
205 fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
206 let block = body.basic_blocks.next_index();
207 let source_info = SourceInfo::outermost(body.span);
208
209 let none_value = match self.coroutine_kind {
210 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
211 span_bug!(body.span, "`Future`s are not fused inherently")
212 }
213 CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
214 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
216 let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span);
217 make_aggregate_adt(
218 option_def_id,
219 VariantIdx::ZERO,
220 self.tcx.mk_args(&[self.old_yield_ty.into()]),
221 IndexVec::new(),
222 )
223 }
224 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
226 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
227 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
228 let yield_ty = args.type_at(0);
229 Rvalue::Use(
230 Operand::Constant(Box::new(ConstOperand {
231 span: source_info.span,
232 const_: Const::Unevaluated(
233 UnevaluatedConst::new(
234 self.tcx.require_lang_item(LangItem::AsyncGenFinished, body.span),
235 self.tcx.mk_args(&[yield_ty.into()]),
236 ),
237 self.old_yield_ty,
238 ),
239 user_ty: None,
240 })),
241 WithRetag::Yes,
242 )
243 }
244 };
245
246 let statements = vec![Statement::new(
247 source_info,
248 StatementKind::Assign(Box::new((Place::return_place(), none_value))),
249 )];
250
251 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
252 statements,
253 Some(Terminator {
254 source_info,
255 kind: TerminatorKind::Return,
256 attributes: ThinVec::new(),
257 }),
258 false,
259 ));
260
261 block
262 }
263
264 #[tracing::instrument(level = "trace", skip(self, statements))]
270 fn make_state(
271 &self,
272 val: Operand<'tcx>,
273 source_info: SourceInfo,
274 is_return: bool,
275 statements: &mut Vec<Statement<'tcx>>,
276 ) {
277 const ZERO: VariantIdx = VariantIdx::ZERO;
278 const ONE: VariantIdx = VariantIdx::from_usize(1);
279 let rvalue = match self.coroutine_kind {
280 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
281 let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span);
282 let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
283 let (variant_idx, operands) = if is_return {
284 (ZERO, indexvec![val]) } else {
286 (ONE, IndexVec::new()) };
288 make_aggregate_adt(poll_def_id, variant_idx, args, operands)
289 }
290 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
291 let option_def_id = self.tcx.require_lang_item(LangItem::Option, source_info.span);
292 let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
293 let (variant_idx, operands) = if is_return {
294 (ZERO, IndexVec::new()) } else {
296 (ONE, indexvec![val]) };
298 make_aggregate_adt(option_def_id, variant_idx, args, operands)
299 }
300 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
301 if is_return {
302 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
303 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
304 let yield_ty = args.type_at(0);
305 Rvalue::Use(
306 Operand::Constant(Box::new(ConstOperand {
307 span: source_info.span,
308 const_: Const::Unevaluated(
309 UnevaluatedConst::new(
310 self.tcx.require_lang_item(
311 LangItem::AsyncGenFinished,
312 source_info.span,
313 ),
314 self.tcx.mk_args(&[yield_ty.into()]),
315 ),
316 self.old_yield_ty,
317 ),
318 user_ty: None,
319 })),
320 WithRetag::Yes,
321 )
322 } else {
323 Rvalue::Use(val, WithRetag::Yes)
324 }
325 }
326 CoroutineKind::Coroutine(_) => {
327 let coroutine_state_def_id =
328 self.tcx.require_lang_item(LangItem::CoroutineState, source_info.span);
329 let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
330 let variant_idx = if is_return {
331 ONE } else {
333 ZERO };
335 make_aggregate_adt(coroutine_state_def_id, variant_idx, args, indexvec![val])
336 }
337 };
338
339 statements.push(Statement::new(
341 source_info,
342 StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
343 ));
344 }
345
346 #[tracing::instrument(level = "trace", skip(self), ret)]
348 fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
349 let self_place = Place::from(SELF_ARG);
350 let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
351 let mut projection = base.projection.to_vec();
352 projection.push(ProjectionElem::Field(idx, ty));
353
354 Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
355 }
356
357 #[tracing::instrument(level = "trace", skip(self))]
359 fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
360 let self_place = Place::from(SELF_ARG);
361 Statement::new(
362 source_info,
363 StatementKind::SetDiscriminant {
364 place: Box::new(self_place),
365 variant_index: state_disc,
366 },
367 )
368 }
369
370 #[tracing::instrument(level = "trace", skip(self, body))]
372 fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
373 let temp_decl = LocalDecl::new(self.discr_ty, body.span);
374 let local_decls_len = body.local_decls.push(temp_decl);
375 let temp = Place::from(local_decls_len);
376
377 let self_place = Place::from(SELF_ARG);
378 let assign = Statement::new(
379 SourceInfo::outermost(body.span),
380 StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
381 );
382 (assign, temp)
383 }
384
385 #[tracing::instrument(level = "trace", skip(self, body))]
387 fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
388 body.local_decls.swap(old_local, new_local);
389
390 let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
391 visitor.visit_body(body);
392 for suspension in &mut self.suspension_points {
393 let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
394 let location = Location { block: START_BLOCK, statement_index: 0 };
395 visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
396 }
397 }
398}
399
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Remapped locals may only be reached through `visit_place`, which rewrites them into
    /// field accesses. Reaching one here means a use escaped that rewrite.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
        assert!(!self.remap.contains(*local));
    }

    /// Replaces a place rooted at a remapped local with the matching field of the coroutine
    /// state.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
        if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
            replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
        }
    }

    /// Turns `StorageLive`/`StorageDead` of remapped locals into no-ops, since their storage
    /// now lives in the coroutine state.
    #[tracing::instrument(level = "trace", skip(self, stmt), ret)]
    fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
        if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
            && self.remap.contains(l)
        {
            stmt.make_nop(true);
        }
        self.super_statement(stmt, location);
    }

    #[tracing::instrument(level = "trace", skip(self, term), ret)]
    fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
        // Every `Return` here was produced (or kept) by `visit_basic_block_data`. Its implicit
        // read of `RETURN_PLACE` must not be walked or rewritten.
        if let TerminatorKind::Return = term.kind {
            return;
        }
        self.super_terminator(term, location);
    }

    /// Lowers the exits of a block before visiting its contents.
    ///
    /// - `return`: writes the "completed" value into `new_ret_local`, sets the discriminant to
    ///   `RETURNED`, and keeps the `Return`.
    /// - `yield`: does the following, then replaces the terminator with `Return`.
    ///   - Writes the "yielded" value.
    ///   - Allocates the next suspension state.
    ///   - Emits `StorageDead` for live, non-remapped, non-always-live locals.
    ///   - Records a `SuspensionPoint`.
    ///   - Sets the discriminant to the new state.
    #[tracing::instrument(level = "trace", skip(self, data), ret)]
    fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
        match data.terminator().kind {
            TerminatorKind::Return => {
                let source_info = data.terminator().source_info;
                // Assign the returned value into the new return local first.
                self.make_state(
                    Operand::Move(Place::return_place()),
                    source_info,
                    true,
                    &mut data.statements,
                );
                // Mark the coroutine as returned.
                let state = VariantIdx::new(CoroutineArgs::RETURNED);
                data.statements.push(self.set_discr(state, source_info));
                data.terminator_mut().kind = TerminatorKind::Return;
            }
            TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
                let source_info = data.terminator().source_info;
                // Assign the yielded value before any `StorageDead` is emitted below.
                self.make_state(value.clone(), source_info, false, &mut data.statements);
                // Each suspension point gets its own state after the reserved ones.
                let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();

                // The resume argument's destination may itself live in the coroutine state.
                if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
                    replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
                }

                let storage_liveness: GrowableBitSet<Local> =
                    self.storage_liveness[block].clone().unwrap().into();

                // End the storage of plain locals that are live at this suspension point;
                // `create_cases` re-issues `StorageLive` for the same set on re-entry.
                for i in 0..self.always_live_locals.domain_size() {
                    let l = Local::new(i);
                    let needs_storage_dead = storage_liveness.contains(l)
                        && !self.remap.contains(l)
                        && !self.always_live_locals.contains(l);
                    if needs_storage_dead {
                        data.statements
                            .push(Statement::new(source_info, StatementKind::StorageDead(l)));
                    }
                }

                self.suspension_points.push(SuspensionPoint {
                    state,
                    resume,
                    resume_arg,
                    drop,
                    storage_liveness,
                });

                let state = VariantIdx::new(state);
                data.statements.push(self.set_discr(state, source_info));
                data.terminator_mut().kind = TerminatorKind::Return;
            }
            _ => {}
        }

        self.super_basic_block_data(block, data);
    }
}
501
502fn make_aggregate_adt<'tcx>(
503 def_id: DefId,
504 variant_idx: VariantIdx,
505 args: GenericArgsRef<'tcx>,
506 operands: IndexVec<FieldIdx, Operand<'tcx>>,
507) -> Rvalue<'tcx> {
508 Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
509}
510
511#[tracing::instrument(level = "trace", skip(tcx, body))]
512fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
513 let coroutine_ty = body.local_decls[SELF_ARG].ty;
514
515 let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
516
517 body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
519
520 SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
522}
523
524#[tracing::instrument(level = "trace", skip(tcx, body))]
525fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
526 let coroutine_ty = body.local_decls[SELF_ARG].ty;
527
528 let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
529
530 let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
531 let pin_adt_ref = tcx.adt_def(pin_did);
532 let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
533 let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
534
535 body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
537
538 let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
539
540 SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
542
543 let source_info = SourceInfo::outermost(body.span);
544 let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);
545
546 let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
547 statements.insert(
548 0,
549 Statement::new(
550 source_info,
551 StatementKind::Assign(Box::new((
552 unpinned_local.into(),
553 Rvalue::Use(Operand::Copy(pin_field), WithRetag::Yes),
554 ))),
555 ),
556 );
557}
558
/// Makes an `async` coroutine receive its resume argument as the task-context type from
/// `Ty::new_task_context` (presumably `&mut Context<'_>`). The rest of the body keeps seeing a
/// value of the original `ResumeTy` type.
///
/// Steps:
/// 1. A fresh local `resume_local` is declared, and its declaration is swapped with
///    `CTX_ARG`'s. The argument slot thus gets the context type, and `resume_local` gets the
///    old type.
/// 2. All existing uses of `CTX_ARG` are renamed to `resume_local`.
/// 3. Two statements are prepended to the entry block to rebuild the old value:
///    `nonnull = transmute(move CTX_ARG)`, then `resume_local = ResumeTy(move nonnull)`.
///    Here `nonnull` has the type of `ResumeTy`'s first field.
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
    let context_mut_ref = Ty::new_task_context(tcx);
    let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
    // Type of `ResumeTy`'s only (first) field, with regions erased.
    let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
        ty::GenericArgs::empty(),
        body.typing_env(tcx),
        tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
    );

    // After the swap, `CTX_ARG` has the context type and `resume_local` the original
    // resume type. The renaming moves all existing uses over to `resume_local`.
    let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
    body.local_decls.swap(CTX_ARG, resume_local);
    RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);

    let source_info = SourceInfo::outermost(body.span);
    // Prologue: `nonnull = transmute(move CTX_ARG); resume_local = ResumeTy(move nonnull);`
    let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
    let nonnull_rhs =
        Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
    let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
    let resume_rhs = Rvalue::Aggregate(
        Box::new(AggregateKind::Adt(
            resume_ty_def_id,
            VariantIdx::ZERO,
            ty::GenericArgs::empty(),
            None,
            None,
        )),
        indexvec![Operand::Move(nonnull_local.into())],
    );
    let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
    body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
        0..0,
        [Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
    );
}
616
/// Inlines calls to the `GetContext` lang item in non-cleanup blocks.
///
/// A call `dest = get_context(arg)` with exactly one place argument is rewritten as follows:
/// - The statement `dest = transmute(copy (arg.0))` is appended. This reinterprets the first
///   field of the `ResumeTy` as the task-context type.
/// - The call terminator is replaced by a `Goto` to the call's return target.
///
/// Calls whose argument is not a place are left untouched. A matching call without a return
/// target makes the `unwrap` panic.
fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
    let context_mut_ref = Ty::new_task_context(tcx);
    let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
    // Type of `ResumeTy`'s first field, with regions erased.
    let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
        ty::GenericArgs::empty(),
        body.typing_env(tcx),
        tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
    );

    let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
    for bb_data in body.basic_blocks.as_mut().iter_mut() {
        if bb_data.is_cleanup {
            continue;
        }

        let terminator = bb_data.terminator_mut();
        if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
            && let func_ty = func.ty(&body.local_decls, tcx)
            && let ty::FnDef(def_id, _) = *func_ty.kind()
            && def_id == get_context_def_id
            && let [arg] = &**args
            && let Some(place) = arg.node.place()
        {
            // `dest = transmute(copy (arg.0))`
            let arg =
                Rvalue::Cast(
                    CastKind::Transmute,
                    Operand::Copy(place.project_deeper(
                        &[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
                        tcx,
                    )),
                    context_mut_ref,
                );
            let assign = Statement::new(
                terminator.source_info,
                StatementKind::Assign(Box::new((*destination, arg))),
            );
            terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
            bb_data.statements.push(assign);
        }
    }
}
662
/// Makes the entry block dispatch on the coroutine state.
///
/// - The former contents of `START_BLOCK` move to a new block appended at the end.
/// - `START_BLOCK` becomes `tmp = discriminant(SELF_ARG); switchInt(move tmp)`. The switch
///   goes to the `cases` targets (state value -> block) and otherwise to `default_block`.
/// - A case targeting `START_BLOCK` is redirected to the moved former entry.
///
/// This is only sound if no other terminator targets `START_BLOCK`; debug builds assert it.
fn insert_switch<'tcx>(
    body: &mut Body<'tcx>,
    cases: Vec<(usize, BasicBlock)>,
    transform: &TransformVisitor<'tcx>,
    default_block: BasicBlock,
) {
    let (assign, discr) = transform.get_discr(body);

    #[cfg(debug_assertions)]
    for bb in body.basic_blocks.iter() {
        for target in bb.terminator().successors() {
            assert_ne!(target, START_BLOCK);
        }
    }

    // Move the old entry block to the end, leaving the discriminant read in its place.
    let former_entry = std::mem::replace(
        &mut body.basic_blocks_mut()[START_BLOCK],
        BasicBlockData::new_stmts(vec![assign], None, false),
    );
    let former_entry = body.basic_blocks_mut().push(former_entry);

    let mut switch_targets =
        SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
    // Cases that pointed at the entry (e.g. the unresumed state) now go to its new location.
    for bb in switch_targets.all_targets_mut() {
        if *bb == START_BLOCK {
            *bb = former_entry;
        }
    }

    let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
    body.basic_blocks_mut()[START_BLOCK].terminator = Some(Terminator {
        source_info: SourceInfo::outermost(body.span),
        kind: switch,
        attributes: ThinVec::new(),
    });
}
706
707fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
708 let source_info = SourceInfo::outermost(body.span);
709 body.basic_blocks_mut().push(BasicBlockData::new(
710 Some(Terminator { source_info, kind, attributes: ThinVec::new() }),
711 false,
712 ))
713}
714
715fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) -> Statement<'tcx> {
716 let poll_def_id = tcx.require_lang_item(LangItem::Poll, source_info.span);
718 let args = tcx.mk_args(&[tcx.types.unit.into()]);
719 let val = Operand::Constant(Box::new(ConstOperand {
720 span: source_info.span,
721 user_ty: None,
722 const_: Const::zero_sized(tcx.types.unit),
723 }));
724 let ready_val = Rvalue::Aggregate(
725 Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)),
726 indexvec![val],
727 );
728 Statement::new(source_info, StatementKind::Assign(Box::new((Place::return_place(), ready_val))))
729}
730
731fn insert_poll_ready_block<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> BasicBlock {
732 let source_info = SourceInfo::outermost(body.span);
733 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
734 [return_poll_ready_assign(tcx, source_info)].to_vec(),
735 Some(Terminator { source_info, kind: TerminatorKind::Return, attributes: ThinVec::new() }),
736 false,
737 ))
738}
739
740fn insert_panic_block<'tcx>(
741 tcx: TyCtxt<'tcx>,
742 body: &mut Body<'tcx>,
743 message: AssertMessage<'tcx>,
744) -> BasicBlock {
745 let assert_block = body.basic_blocks.next_index();
746 let kind = TerminatorKind::Assert {
747 cond: Operand::Constant(Box::new(ConstOperand {
748 span: body.span,
749 user_ty: None,
750 const_: Const::from_bool(tcx, false),
751 })),
752 expected: true,
753 msg: Box::new(message),
754 target: assert_block,
755 unwind: UnwindAction::Continue,
756 };
757
758 insert_term_block(body, kind)
759}
760
761fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::TypingEnv<'tcx>) -> bool {
762 if body.return_ty().is_privately_uninhabited(tcx, typing_env) {
764 return false;
765 }
766
767 body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
769 }
771
772fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
773 if !tcx.sess.panic_strategy().unwinds() {
775 return false;
776 }
777
778 body.basic_blocks.iter().any(|block| block.terminator().unwind().is_some())
780}
781
/// Adds a cleanup block that sets the coroutine state to `POISONED` and resumes unwinding.
/// All unwinding is then routed through that block:
///
/// - Every other `UnwindResume` terminator becomes a `Goto` to the poison block.
/// - Every terminator in a non-cleanup block whose unwind action is `Continue` gets the poison
///   block as its cleanup target.
fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
    transform: &TransformVisitor<'tcx>,
    body: &mut Body<'tcx>,
) {
    let source_info = SourceInfo::outermost(body.span);
    let poison_block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
        vec![transform.set_discr(VariantIdx::new(CoroutineArgs::POISONED), source_info)],
        Some(Terminator {
            source_info,
            kind: TerminatorKind::UnwindResume,
            attributes: ThinVec::new(),
        }),
        true,
    ));

    for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
        let source_info = block.terminator().source_info;

        if let TerminatorKind::UnwindResume = block.terminator().kind {
            // Redirect existing unwind exits through the poison block (but not the poison
            // block's own `UnwindResume`).
            if idx != poison_block {
                *block.terminator_mut() = Terminator {
                    source_info,
                    kind: TerminatorKind::Goto { target: poison_block },
                    attributes: ThinVec::new(),
                };
            }
        } else if !block.is_cleanup
            && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
        {
            // Terminators that would unwind straight out of the function unwind into the poison
            // block instead.
            *unwind = UnwindAction::Cleanup(poison_block);
        }
    }
}
822
/// Turns `body`, already rewritten by `TransformVisitor`, into the coroutine's resume function.
///
/// The entry block switches on the state discriminant:
/// - `UNRESUMED` goes to the original entry block.
/// - Each suspension point goes to a block built by `create_cases`, which continues at the
///   point's resume block.
/// - `POISONED` (only if the body can unwind) panics with `ResumedAfterPanic`.
/// - `RETURNED` (only if the body can return) depends on the kind:
///   - `async_drop_in_place` coroutines return `Poll::Ready(())`.
///   - Other `async` and plain coroutines panic with `ResumedAfterReturn`.
///   - `gen`/`async gen` return their "finished" value again.
/// - Any other value is `Unreachable`.
///
/// Afterwards:
/// - The state argument becomes `Pin<&mut _>` (`&mut _` for `gen`).
/// - Dead blocks are removed.
/// - `AbortUnwindingCalls` and `deref_finder` run.
/// - For async coroutines, the resume argument is retyped by `transform_async_context`.
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
fn create_coroutine_resume_function<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: TransformVisitor<'tcx>,
    body: &mut Body<'tcx>,
    can_return: bool,
    can_unwind: bool,
) {
    // Poison the coroutine when it unwinds.
    if can_unwind {
        generate_poison_block_and_redirect_unwinds_there(&transform, body);
    }

    let mut cases = create_cases(body, &transform, Operation::Resume);

    use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};

    // An unresumed coroutine starts at the original entry block.
    cases.insert(0, (CoroutineArgs::UNRESUMED, START_BLOCK));

    // Resuming a poisoned coroutine panics.
    if can_unwind {
        cases.insert(
            1,
            (
                CoroutineArgs::POISONED,
                insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind)),
            ),
        );
    }

    if can_return {
        let block = match transform.coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
            | CoroutineKind::Coroutine(_) => {
                // `async_drop_in_place` coroutines keep returning `Poll::Ready(())` instead of
                // panicking.
                if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) {
                    insert_poll_ready_block(tcx, body)
                } else {
                    insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
                }
            }
            // Fused kinds keep returning their "finished" value.
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
            | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                transform.insert_none_ret_block(body)
            }
        };
        cases.insert(1, (CoroutineArgs::RETURNED, block));
    }

    let default_block = insert_term_block(body, TerminatorKind::Unreachable);
    insert_switch(body, cases, &transform, default_block);

    match transform.coroutine_kind {
        CoroutineKind::Coroutine(_)
        | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
        {
            make_coroutine_state_argument_pinned(tcx, body);
        }
        // `gen` coroutines take their state by `&mut` rather than `Pin<&mut _>`.
        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
            make_coroutine_state_argument_indirect(tcx, body);
        }
    }

    // Drop blocks that are no longer reachable from the new entry switch.
    simplify::remove_dead_blocks(body);

    pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

    deref_finder(tcx, body, false);

    if transform.coroutine_kind.is_async_desugaring() {
        transform_async_context(tcx, body);
    }

    if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
        dumper.dump_mir(body);
    }
}
907
/// The kind of entry into a suspended coroutine that `create_cases` builds switch targets for.
#[derive(PartialEq, Copy, Clone, Debug)]
enum Operation {
    /// Resuming: continue at the suspension point's resume block.
    Resume,
    /// Dropping: continue at the suspension point's drop block; no resume argument is stored.
    Drop,
    /// Async dropping: like `Drop`, but the resume argument is stored as for `Resume`.
    AsyncDrop,
}
915
916impl Operation {
917 fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
918 match self {
919 Operation::Resume => Some(point.resume),
920 Operation::Drop | Operation::AsyncDrop => point.drop,
921 }
922 }
923
924 fn resume_place<'tcx>(self, point: &SuspensionPoint<'tcx>) -> Option<Place<'tcx>> {
925 match self {
926 Operation::Resume | Operation::AsyncDrop => Some(point.resume_arg),
927 Operation::Drop => None,
928 }
929 }
930}
931
/// Builds one entry block per suspension point that has a target for `operation`.
/// Returns `(state, block)` pairs suitable for `insert_switch`.
///
/// Each block does the following, in order:
/// 1. Re-issues `StorageLive` for the locals whose storage was live at the suspension point
///    (excluding remapped and always-live locals).
/// 2. If the operation has a resume place and it is not `CTX_ARG` itself, moves `CTX_ARG`
///    into it.
/// 3. Jumps to the operation's target block.
#[tracing::instrument(level = "trace", skip(transform, body))]
fn create_cases<'tcx>(
    body: &mut Body<'tcx>,
    transform: &TransformVisitor<'tcx>,
    operation: Operation,
) -> Vec<(usize, BasicBlock)> {
    let source_info = SourceInfo::outermost(body.span);

    transform
        .suspension_points
        .iter()
        .filter_map(|point| {
            // Points without a target for this operation get no case.
            operation.target_block(point).map(|target| {
                let mut statements = Vec::new();

                // Mirror the `StorageDead`s emitted when the coroutine suspended here.
                for l in body.local_decls.indices() {
                    let needs_storage_live = point.storage_liveness.contains(l)
                        && !transform.remap.contains(l)
                        && !transform.always_live_locals.contains(l);
                    if needs_storage_live {
                        statements.push(Statement::new(source_info, StatementKind::StorageLive(l)));
                    }
                }

                // Store the incoming resume argument where the `yield` expected it.
                if let Some(resume_arg) = operation.resume_place(point)
                    && resume_arg != CTX_ARG.into()
                {
                    statements.push(Statement::new(
                        source_info,
                        StatementKind::Assign(Box::new((
                            resume_arg,
                            Rvalue::Use(Operand::Move(CTX_ARG.into()), WithRetag::Yes),
                        ))),
                    ));
                }

                let block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
                    statements,
                    Some(Terminator {
                        source_info,
                        kind: TerminatorKind::Goto { target },
                        attributes: ThinVec::new(),
                    }),
                    false,
                ));

                (point.state, block)
            })
        })
        .collect()
}
988
989impl<'tcx> crate::MirPass<'tcx> for StateTransform {
990 #[instrument(level = "debug", skip(self, tcx, body), ret)]
991 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
992 debug!(def_id = ?body.source.def_id());
993
994 let Some(old_yield_ty) = body.yield_ty() else {
995 return;
997 };
998 tracing::trace!(def_id = ?body.source.def_id());
999
1000 let old_ret_ty = body.return_ty();
1001
1002 assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
1003
1004 if let Some(dumper) = MirDumper::new(tcx, "coroutine_before", body) {
1005 dumper.dump_mir(body);
1006 }
1007
1008 let coroutine_ty = body.local_decls.raw[1].ty;
1010 let coroutine_kind = body.coroutine_kind().unwrap();
1011
1012 let ty::Coroutine(_, args) = coroutine_ty.kind() else {
1014 tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
1015 };
1016 let discr_ty = args.as_coroutine().discr_ty(tcx);
1017
1018 let new_ret_ty = match coroutine_kind {
1019 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
1020 let poll_did = tcx.require_lang_item(LangItem::Poll, body.span);
1022 let poll_adt_ref = tcx.adt_def(poll_did);
1023 let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
1024 Ty::new_adt(tcx, poll_adt_ref, poll_args)
1025 }
1026 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1027 let option_did = tcx.require_lang_item(LangItem::Option, body.span);
1029 let option_adt_ref = tcx.adt_def(option_did);
1030 let option_args = tcx.mk_args(&[old_yield_ty.into()]);
1031 Ty::new_adt(tcx, option_adt_ref, option_args)
1032 }
1033 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
1034 old_yield_ty
1036 }
1037 CoroutineKind::Coroutine(_) => {
1038 let state_did = tcx.require_lang_item(LangItem::CoroutineState, body.span);
1040 let state_adt_ref = tcx.adt_def(state_did);
1041 let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
1042 Ty::new_adt(tcx, state_adt_ref, state_args)
1043 }
1044 };
1045
1046 let has_async_drops = has_async_drops(body);
1051
1052 if coroutine_kind.is_async_desugaring() {
1053 eliminate_get_context_calls(tcx, body);
1054 }
1055
1056 let always_live_locals = always_storage_live_locals(body);
1057 let movable = coroutine_kind.movability() == hir::Movability::Movable;
1058 let liveness_info =
1059 locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
1060
1061 if tcx.sess.opts.unstable_opts.validate_mir {
1062 let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias {
1063 assigned_local: None,
1064 saved_locals: &liveness_info.saved_locals,
1065 storage_conflicts: &liveness_info.storage_conflicts,
1066 };
1067
1068 vis.visit_body(body);
1069 }
1070
1071 let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
1075
1076 let can_return = can_return(tcx, body, body.typing_env(tcx));
1077
1078 let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
1081 tracing::trace!(?new_ret_local);
1082
1083 let mut transform = TransformVisitor {
1089 tcx,
1090 coroutine_kind,
1091 remap,
1092 storage_liveness,
1093 always_live_locals,
1094 suspension_points: Vec::new(),
1095 discr_ty,
1096 new_ret_local,
1097 old_ret_ty,
1098 old_yield_ty,
1099 };
1100 transform.visit_body(body);
1101
1102 transform.replace_local(RETURN_PLACE, new_ret_local, body);
1104
1105 let source_info = SourceInfo::outermost(body.span);
1108 let args_iter = body.args_iter();
1109 body.basic_blocks.as_mut()[START_BLOCK].statements.splice(
1110 0..0,
1111 args_iter.filter_map(|local| {
1112 let (ty, variant_index, idx) = transform.remap[local]?;
1113 let lhs = transform.make_field(variant_index, idx, ty);
1114 let rhs = Rvalue::Use(Operand::Move(local.into()), WithRetag::Yes);
1115 let assign = StatementKind::Assign(Box::new((lhs, rhs)));
1116 Some(Statement::new(source_info, assign))
1117 }),
1118 );
1119
1120 if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
1122 body.arg_count = 1;
1123 }
1124
1125 for var in &mut body.var_debug_info {
1129 var.argument_index = None;
1130 }
1131
1132 body.coroutine.as_mut().unwrap().yield_ty = None;
1133 body.coroutine.as_mut().unwrap().resume_ty = None;
1134 body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);
1135
1136 let drop_clean = insert_clean_drop(tcx, body, has_async_drops);
1140
1141 if let Some(dumper) = MirDumper::new(tcx, "coroutine_pre-elab", body) {
1142 dumper.dump_mir(body);
1143 }
1144
1145 elaborate_coroutine_drops(tcx, body);
1149
1150 if let Some(dumper) = MirDumper::new(tcx, "coroutine_post-transform", body) {
1151 dumper.dump_mir(body);
1152 }
1153
1154 let can_unwind = can_unwind(tcx, body);
1155
1156 if has_async_drops {
1158 let drop_shim =
1160 create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind);
1161 body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim);
1162 } else {
1163 let drop_shim =
1165 create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
1166 body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);
1167
1168 let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
1170 body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
1171 }
1172
1173 create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);
1175 }
1176
    /// Always reports this pass as required.
    ///
    /// NOTE(review): presumably returning `true` marks the coroutine state transform as a
    /// mandatory pass that cannot be skipped like an optional optimization — confirm against
    /// the pass-manager trait's `is_required` documentation.
    fn is_required(&self) -> bool {
        true
    }
1180}
1181
/// MIR visitor that checks assignments whose destination is a directly-addressed coroutine
/// saved local. For each such assignment it visits the right-hand side and calls `bug!` if
/// the right-hand side directly names another saved local whose storage is not marked as
/// conflicting with the destination's.
///
/// NOTE(review): the motivation is presumably that saved locals without a recorded storage
/// conflict may be placed in overlapping coroutine storage, so an assignment between them
/// could alias — confirm against the coroutine layout code.
struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> {
    /// Lookup from a MIR local to the coroutine saved local it corresponds to, if any
    /// (queried via `get(place.local)`).
    saved_locals: &'a CoroutineSavedLocals,
    /// Pairwise storage-conflict relation between saved locals, queried as
    /// `contains(lhs, rhs)`.
    storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
    /// Destination saved local of the assignment currently being checked. It is `Some` only
    /// while `check_assigned_place` is running its callback, and `None` otherwise.
    assigned_local: Option<CoroutineSavedLocal>,
}
1199
1200impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1201 fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> {
1202 if place.is_indirect() {
1203 return None;
1204 }
1205
1206 self.saved_locals.get(place.local)
1207 }
1208
1209 fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1210 if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1211 assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1212
1213 self.assigned_local = Some(assigned_local);
1214 f(self);
1215 self.assigned_local = None;
1216 }
1217 }
1218}
1219
1220impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1221 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
1222 let Some(lhs) = self.assigned_local else {
1223 assert!(!context.is_use());
1228 return;
1229 };
1230
1231 let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };
1232
1233 if !self.storage_conflicts.contains(lhs, rhs) {
1234 bug!(
1235 "Assignment between coroutine saved locals whose storage is not \
1236 marked as conflicting: {:?}: {:?} = {:?}",
1237 location,
1238 lhs,
1239 rhs,
1240 );
1241 }
1242 }
1243
1244 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
1245 match &statement.kind {
1246 StatementKind::Assign((lhs, rhs)) => {
1247 self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
1248 }
1249
1250 StatementKind::FakeRead(..)
1251 | StatementKind::SetDiscriminant { .. }
1252 | StatementKind::StorageLive(_)
1253 | StatementKind::StorageDead(_)
1254 | StatementKind::AscribeUserType(..)
1255 | StatementKind::PlaceMention(..)
1256 | StatementKind::Coverage(..)
1257 | StatementKind::Intrinsic(..)
1258 | StatementKind::ConstEvalCounter
1259 | StatementKind::BackwardIncompatibleDropHint { .. }
1260 | StatementKind::Nop => {}
1261 }
1262 }
1263
1264 fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
1265 match &terminator.kind {
1268 TerminatorKind::Call {
1269 func,
1270 args,
1271 destination,
1272 target: Some(_),
1273 unwind: _,
1274 call_source: _,
1275 fn_span: _,
1276 } => {
1277 self.check_assigned_place(*destination, |this| {
1278 this.visit_operand(func, location);
1279 for arg in args {
1280 this.visit_operand(&arg.node, location);
1281 }
1282 });
1283 }
1284
1285 TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
1286 self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
1287 }
1288
1289 TerminatorKind::InlineAsm { .. } => {}
1291
1292 TerminatorKind::Call { .. }
1293 | TerminatorKind::Goto { .. }
1294 | TerminatorKind::SwitchInt { .. }
1295 | TerminatorKind::UnwindResume
1296 | TerminatorKind::UnwindTerminate(_)
1297 | TerminatorKind::Return
1298 | TerminatorKind::TailCall { .. }
1299 | TerminatorKind::Unreachable
1300 | TerminatorKind::Drop { .. }
1301 | TerminatorKind::Assert { .. }
1302 | TerminatorKind::CoroutineDrop
1303 | TerminatorKind::FalseEdge { .. }
1304 | TerminatorKind::FalseUnwind { .. } => {}
1305 }
1306 }
1307}