Skip to main content

rustc_mir_transform/coroutine/
mod.rs

1//! This is the implementation of the pass which transforms coroutines into state machines.
2//!
3//! MIR generation for coroutines creates a function which has a self argument which
4//! passes by value. This argument is effectively a coroutine type which only contains upvars and
5//! is only used for this argument inside the MIR for the coroutine.
6//! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
7//! MIR before this pass and creates drop flags for MIR locals.
8//! It will also drop the coroutine argument (which only consists of upvars) if any of the upvars
9//! are moved out of. This pass elaborates the drops of upvars / coroutine argument in the case
10//! that none of the upvars were moved out of. This is because we cannot have any drops of this
11//! coroutine in the MIR, since it is used to create the drop glue for the coroutine. We'd get
12//! infinite recursion otherwise.
13//!
14//! This pass creates the implementation for either the `Coroutine::resume` or `Future::poll`
15//! function and the drop shim for the coroutine based on the MIR input.
16//! It converts the coroutine argument from Self to &mut Self adding derefs in the MIR as needed.
17//! It computes the final layout of the coroutine struct which looks like this:
18//!     First upvars are stored
19//!     It is followed by the coroutine state field.
20//!     Then finally the MIR locals which are live across a suspension point are stored.
21//!     ```ignore (illustrative)
22//!     struct Coroutine {
23//!         upvars...,
24//!         state: u32,
25//!         mir_locals...,
26//!     }
27//!     ```
28//! This pass computes the meaning of the state field and the MIR locals which are live
29//! across a suspension point. There are however three hardcoded coroutine states:
30//!     0 - Coroutine has not been resumed yet
31//!     1 - Coroutine has returned / is completed
32//!     2 - Coroutine has been poisoned
33//!
34//! It also rewrites `return x` and `yield y` as setting a new coroutine state and returning
35//! `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
36//! or `Poll::Ready(x)` and `Poll::Pending` respectively.
37//! MIR locals which are live across a suspension point are moved to the coroutine struct
38//! with references to them being updated with references to the coroutine struct.
39//!
40//! The pass creates two functions which have a switch on the coroutine state giving
41//! the action to take.
42//!
43//! One of them is the implementation of `Coroutine::resume` / `Future::poll`.
44//! For coroutines with state 0 (unresumed) it starts the execution of the coroutine.
45//! For coroutines with state 1 (returned) and state 2 (poisoned) it panics.
46//! Otherwise it continues the execution from the last suspension point.
47//!
48//! The other function is the drop glue for the coroutine.
49//! For coroutines with state 0 (unresumed) it drops the upvars of the coroutine.
50//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
51//! Otherwise it drops all the values in scope at the last suspension point.
52
53mod by_move_body;
54mod drop;
55mod layout;
56
57pub(super) use by_move_body::coroutine_by_move_body_def_id;
58use drop::{
59    create_coroutine_drop_shim, create_coroutine_drop_shim_async,
60    create_coroutine_drop_shim_proxy_async, elaborate_coroutine_drops, has_async_drops,
61    insert_clean_drop,
62};
63pub(super) use layout::mir_coroutine_witnesses;
64use layout::{CoroutineSavedLocals, compute_layout, locals_live_across_suspend_points};
65use rustc_abi::{FieldIdx, VariantIdx};
66use rustc_data_structures::thin_vec::ThinVec;
67use rustc_hir::lang_items::LangItem;
68use rustc_hir::{self as hir, CoroutineDesugaring, CoroutineKind};
69use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
70use rustc_index::{Idx, IndexVec, indexvec};
71use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
72use rustc_middle::mir::*;
73use rustc_middle::ty::{
74    self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
75};
76use rustc_middle::{bug, span_bug};
77use rustc_mir_dataflow::impls::always_storage_live_locals;
78use rustc_span::def_id::DefId;
79use tracing::{debug, instrument};
80
81use crate::deref_separator::deref_finder;
82use crate::{abort_unwinding_calls, pass_manager as pm, simplify};
83
84pub(super) struct StateTransform;
85
86struct RenameLocalVisitor<'tcx> {
87    from: Local,
88    to: Local,
89    tcx: TyCtxt<'tcx>,
90}
91
92impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
93    fn tcx(&self) -> TyCtxt<'tcx> {
94        self.tcx
95    }
96
97    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
98        if *local == self.from {
99            *local = self.to;
100        } else if *local == self.to {
101            *local = self.from;
102        }
103    }
104
105    fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
106        match terminator.kind {
107            TerminatorKind::Return => {
108                // Do not replace the implicit `_0` access here, as that's not possible. The
109                // transform already handles `return` correctly.
110            }
111            _ => self.super_terminator(terminator, location),
112        }
113    }
114}
115
116struct SelfArgVisitor<'tcx> {
117    tcx: TyCtxt<'tcx>,
118    new_base: Place<'tcx>,
119}
120
121impl<'tcx> SelfArgVisitor<'tcx> {
122    fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
123        Self { tcx, new_base }
124    }
125}
126
127impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
128    fn tcx(&self) -> TyCtxt<'tcx> {
129        self.tcx
130    }
131
132    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
133        assert_ne!(*local, SELF_ARG);
134    }
135
136    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
137        if place.local == SELF_ARG {
138            replace_base(place, self.new_base, self.tcx);
139        }
140
141        for elem in place.projection.iter() {
142            if let PlaceElem::Index(local) = elem {
143                assert_ne!(local, SELF_ARG);
144            }
145        }
146    }
147}
148
149#[tracing::instrument(level = "trace", skip(tcx))]
150fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
151    place.local = new_base.local;
152
153    let mut new_projection = new_base.projection.to_vec();
154    new_projection.append(&mut place.projection.to_vec());
155
156    place.projection = tcx.mk_place_elems(&new_projection);
157    tracing::trace!(?place);
158}
159
160const SELF_ARG: Local = Local::arg(0);
161pub(crate) const CTX_ARG: Local = Local::arg(1);
162
163/// A `yield` point in the coroutine.
164struct SuspensionPoint<'tcx> {
165    /// State discriminant used when suspending or resuming at this point.
166    state: usize,
167    /// The block to jump to after resumption.
168    resume: BasicBlock,
169    /// Where to move the resume argument after resumption.
170    resume_arg: Place<'tcx>,
171    /// Which block to jump to if the coroutine is dropped in this state.
172    drop: Option<BasicBlock>,
173    /// Set of locals that have live storage while at this suspension point.
174    storage_liveness: GrowableBitSet<Local>,
175}
176
177struct TransformVisitor<'tcx> {
178    tcx: TyCtxt<'tcx>,
179    coroutine_kind: hir::CoroutineKind,
180
181    // The type of the discriminant in the coroutine struct
182    discr_ty: Ty<'tcx>,
183
184    // Mapping from Local to (type of local, coroutine struct index)
185    remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
186
187    // A map from a suspension point in a block to the locals which have live storage at that point
188    storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
189
190    // A list of suspension points, generated during the transform
191    suspension_points: Vec<SuspensionPoint<'tcx>>,
192
193    // The set of locals that have no `StorageLive`/`StorageDead` annotations.
194    always_live_locals: DenseBitSet<Local>,
195
196    // New local we just create to hold the `CoroutineState` value.
197    new_ret_local: Local,
198
199    old_yield_ty: Ty<'tcx>,
200
201    old_ret_ty: Ty<'tcx>,
202}
203
204impl<'tcx> TransformVisitor<'tcx> {
205    fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
206        let block = body.basic_blocks.next_index();
207        let source_info = SourceInfo::outermost(body.span);
208
209        let none_value = match self.coroutine_kind {
210            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
211                span_bug!(body.span, "`Future`s are not fused inherently")
212            }
213            CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
214            // `gen` continues return `None`
215            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
216                let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span);
217                make_aggregate_adt(
218                    option_def_id,
219                    VariantIdx::ZERO,
220                    self.tcx.mk_args(&[self.old_yield_ty.into()]),
221                    IndexVec::new(),
222                )
223            }
224            // `async gen` continues to return `Poll::Ready(None)`
225            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
226                let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
227                let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
228                let yield_ty = args.type_at(0);
229                Rvalue::Use(
230                    Operand::Constant(Box::new(ConstOperand {
231                        span: source_info.span,
232                        const_: Const::Unevaluated(
233                            UnevaluatedConst::new(
234                                self.tcx.require_lang_item(LangItem::AsyncGenFinished, body.span),
235                                self.tcx.mk_args(&[yield_ty.into()]),
236                            ),
237                            self.old_yield_ty,
238                        ),
239                        user_ty: None,
240                    })),
241                    WithRetag::Yes,
242                )
243            }
244        };
245
246        let statements = vec![Statement::new(
247            source_info,
248            StatementKind::Assign(Box::new((Place::return_place(), none_value))),
249        )];
250
251        body.basic_blocks_mut().push(BasicBlockData::new_stmts(
252            statements,
253            Some(Terminator {
254                source_info,
255                kind: TerminatorKind::Return,
256                attributes: ThinVec::new(),
257            }),
258            false,
259        ));
260
261        block
262    }
263
264    // Make a `CoroutineState` or `Poll` variant assignment.
265    //
266    // `core::ops::CoroutineState` only has single element tuple variants,
267    // so we can just write to the downcasted first field and then set the
268    // discriminant to the appropriate variant.
269    #[tracing::instrument(level = "trace", skip(self, statements))]
270    fn make_state(
271        &self,
272        val: Operand<'tcx>,
273        source_info: SourceInfo,
274        is_return: bool,
275        statements: &mut Vec<Statement<'tcx>>,
276    ) {
277        const ZERO: VariantIdx = VariantIdx::ZERO;
278        const ONE: VariantIdx = VariantIdx::from_usize(1);
279        let rvalue = match self.coroutine_kind {
280            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
281                let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span);
282                let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
283                let (variant_idx, operands) = if is_return {
284                    (ZERO, indexvec![val]) // Poll::Ready(val)
285                } else {
286                    (ONE, IndexVec::new()) // Poll::Pending
287                };
288                make_aggregate_adt(poll_def_id, variant_idx, args, operands)
289            }
290            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
291                let option_def_id = self.tcx.require_lang_item(LangItem::Option, source_info.span);
292                let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
293                let (variant_idx, operands) = if is_return {
294                    (ZERO, IndexVec::new()) // None
295                } else {
296                    (ONE, indexvec![val]) // Some(val)
297                };
298                make_aggregate_adt(option_def_id, variant_idx, args, operands)
299            }
300            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
301                if is_return {
302                    let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
303                    let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
304                    let yield_ty = args.type_at(0);
305                    Rvalue::Use(
306                        Operand::Constant(Box::new(ConstOperand {
307                            span: source_info.span,
308                            const_: Const::Unevaluated(
309                                UnevaluatedConst::new(
310                                    self.tcx.require_lang_item(
311                                        LangItem::AsyncGenFinished,
312                                        source_info.span,
313                                    ),
314                                    self.tcx.mk_args(&[yield_ty.into()]),
315                                ),
316                                self.old_yield_ty,
317                            ),
318                            user_ty: None,
319                        })),
320                        WithRetag::Yes,
321                    )
322                } else {
323                    Rvalue::Use(val, WithRetag::Yes)
324                }
325            }
326            CoroutineKind::Coroutine(_) => {
327                let coroutine_state_def_id =
328                    self.tcx.require_lang_item(LangItem::CoroutineState, source_info.span);
329                let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
330                let variant_idx = if is_return {
331                    ONE // CoroutineState::Complete(val)
332                } else {
333                    ZERO // CoroutineState::Yielded(val)
334                };
335                make_aggregate_adt(coroutine_state_def_id, variant_idx, args, indexvec![val])
336            }
337        };
338
339        // Assign to `new_ret_local`, which will be replaced by `RETURN_PLACE` later.
340        statements.push(Statement::new(
341            source_info,
342            StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
343        ));
344    }
345
346    // Create a Place referencing a coroutine struct field
347    #[tracing::instrument(level = "trace", skip(self), ret)]
348    fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
349        let self_place = Place::from(SELF_ARG);
350        let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
351        let mut projection = base.projection.to_vec();
352        projection.push(ProjectionElem::Field(idx, ty));
353
354        Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
355    }
356
357    // Create a statement which changes the discriminant
358    #[tracing::instrument(level = "trace", skip(self))]
359    fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
360        let self_place = Place::from(SELF_ARG);
361        Statement::new(
362            source_info,
363            StatementKind::SetDiscriminant {
364                place: Box::new(self_place),
365                variant_index: state_disc,
366            },
367        )
368    }
369
370    // Create a statement which reads the discriminant into a temporary
371    #[tracing::instrument(level = "trace", skip(self, body))]
372    fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
373        let temp_decl = LocalDecl::new(self.discr_ty, body.span);
374        let local_decls_len = body.local_decls.push(temp_decl);
375        let temp = Place::from(local_decls_len);
376
377        let self_place = Place::from(SELF_ARG);
378        let assign = Statement::new(
379            SourceInfo::outermost(body.span),
380            StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
381        );
382        (assign, temp)
383    }
384
385    /// Swaps all references of `old_local` and `new_local`.
386    #[tracing::instrument(level = "trace", skip(self, body))]
387    fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
388        body.local_decls.swap(old_local, new_local);
389
390        let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
391        visitor.visit_body(body);
392        for suspension in &mut self.suspension_points {
393            let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
394            let location = Location { block: START_BLOCK, statement_index: 0 };
395            visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
396        }
397    }
398}
399
400impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
401    fn tcx(&self) -> TyCtxt<'tcx> {
402        self.tcx
403    }
404
405    #[tracing::instrument(level = "trace", skip(self), ret)]
406    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
407        assert!(!self.remap.contains(*local));
408    }
409
410    #[tracing::instrument(level = "trace", skip(self), ret)]
411    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
412        // Replace an Local in the remap with a coroutine struct access
413        if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
414            replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
415        }
416    }
417
418    #[tracing::instrument(level = "trace", skip(self, stmt), ret)]
419    fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
420        // Remove StorageLive and StorageDead statements for remapped locals
421        if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
422            && self.remap.contains(l)
423        {
424            stmt.make_nop(true);
425        }
426        self.super_statement(stmt, location);
427    }
428
429    #[tracing::instrument(level = "trace", skip(self, term), ret)]
430    fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
431        if let TerminatorKind::Return = term.kind {
432            // `visit_basic_block_data` introduces `Return` terminators which read `RETURN_PLACE`.
433            // But this `RETURN_PLACE` is already remapped, so we should not touch it again.
434            return;
435        }
436        self.super_terminator(term, location);
437    }
438
439    #[tracing::instrument(level = "trace", skip(self, data), ret)]
440    fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
441        match data.terminator().kind {
442            TerminatorKind::Return => {
443                let source_info = data.terminator().source_info;
444                // We must assign the value first in case it gets declared dead below
445                self.make_state(
446                    Operand::Move(Place::return_place()),
447                    source_info,
448                    true,
449                    &mut data.statements,
450                );
451                // Return state.
452                let state = VariantIdx::new(CoroutineArgs::RETURNED);
453                data.statements.push(self.set_discr(state, source_info));
454                data.terminator_mut().kind = TerminatorKind::Return;
455            }
456            TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
457                let source_info = data.terminator().source_info;
458                // We must assign the value first in case it gets declared dead below
459                self.make_state(value.clone(), source_info, false, &mut data.statements);
460                // Yield state.
461                let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
462
463                // The resume arg target location might itself be remapped if its base local is
464                // live across a yield.
465                if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
466                    replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
467                }
468
469                let storage_liveness: GrowableBitSet<Local> =
470                    self.storage_liveness[block].clone().unwrap().into();
471
472                for i in 0..self.always_live_locals.domain_size() {
473                    let l = Local::new(i);
474                    let needs_storage_dead = storage_liveness.contains(l)
475                        && !self.remap.contains(l)
476                        && !self.always_live_locals.contains(l);
477                    if needs_storage_dead {
478                        data.statements
479                            .push(Statement::new(source_info, StatementKind::StorageDead(l)));
480                    }
481                }
482
483                self.suspension_points.push(SuspensionPoint {
484                    state,
485                    resume,
486                    resume_arg,
487                    drop,
488                    storage_liveness,
489                });
490
491                let state = VariantIdx::new(state);
492                data.statements.push(self.set_discr(state, source_info));
493                data.terminator_mut().kind = TerminatorKind::Return;
494            }
495            _ => {}
496        }
497
498        self.super_basic_block_data(block, data);
499    }
500}
501
502fn make_aggregate_adt<'tcx>(
503    def_id: DefId,
504    variant_idx: VariantIdx,
505    args: GenericArgsRef<'tcx>,
506    operands: IndexVec<FieldIdx, Operand<'tcx>>,
507) -> Rvalue<'tcx> {
508    Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
509}
510
511#[tracing::instrument(level = "trace", skip(tcx, body))]
512fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
513    let coroutine_ty = body.local_decls[SELF_ARG].ty;
514
515    let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
516
517    // Replace the by value coroutine argument
518    body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
519
520    // Add a deref to accesses of the coroutine state
521    SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
522}
523
524#[tracing::instrument(level = "trace", skip(tcx, body))]
525fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
526    let coroutine_ty = body.local_decls[SELF_ARG].ty;
527
528    let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
529
530    let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
531    let pin_adt_ref = tcx.adt_def(pin_did);
532    let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
533    let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
534
535    // Replace the by ref coroutine argument
536    body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
537
538    let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
539
540    // Add the Pin field access to accesses of the coroutine state
541    SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
542
543    let source_info = SourceInfo::outermost(body.span);
544    let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);
545
546    let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
547    statements.insert(
548        0,
549        Statement::new(
550            source_info,
551            StatementKind::Assign(Box::new((
552                unpinned_local.into(),
553                Rvalue::Use(Operand::Copy(pin_field), WithRetag::Yes),
554            ))),
555        ),
556    );
557}
558
559/// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
560/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
561/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
562///
563/// The actual should be `&mut Context<'_>`. This performs the substitution:
564/// - create a new local `_r` of type `ResumeTy`;
565/// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
566/// - let all the code use `_r` instead of `_2`.
567///
568/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
569/// but rather directly use `&mut Context<'_>`, however that would currently
570/// lead to higher-kinded lifetime errors.
571/// See <https://github.com/rust-lang/rust/issues/105501>.
572///
573/// The async lowering step and the type / lifetime inference / checking are
574/// still using the `ResumeTy` indirection for the time being, and that indirection
575/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
576#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
577fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
578    let context_mut_ref = Ty::new_task_context(tcx);
579    let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
580    let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
581        ty::GenericArgs::empty(),
582        body.typing_env(tcx),
583        tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
584    );
585
586    // Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
587    // and set `CTX_ARG: &mut Context<'_>`.
588    let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
589    body.local_decls.swap(CTX_ARG, resume_local);
590    RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);
591
592    // Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
593    // Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
594    // at the function entry to make the bridge.
595    let source_info = SourceInfo::outermost(body.span);
596    let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
597    let nonnull_rhs =
598        Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
599    let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
600    let resume_rhs = Rvalue::Aggregate(
601        Box::new(AggregateKind::Adt(
602            resume_ty_def_id,
603            VariantIdx::ZERO,
604            ty::GenericArgs::empty(),
605            None,
606            None,
607        )),
608        indexvec![Operand::Move(nonnull_local.into())],
609    );
610    let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
611    body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
612        0..0,
613        [Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
614    );
615}
616
617/// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
618/// Both types are just a single pointer, but liveness analysis does not know that and
619/// supposes that the operand and the destination are live at the same time.
620/// Forcibly inline those calls to avoid this.
621fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
622    let context_mut_ref = Ty::new_task_context(tcx);
623    let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
624    let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
625        ty::GenericArgs::empty(),
626        body.typing_env(tcx),
627        tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
628    );
629
630    let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
631    for bb_data in body.basic_blocks.as_mut().iter_mut() {
632        if bb_data.is_cleanup {
633            continue;
634        }
635
636        let terminator = bb_data.terminator_mut();
637        if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
638            && let func_ty = func.ty(&body.local_decls, tcx)
639            && let ty::FnDef(def_id, _) = *func_ty.kind()
640            && def_id == get_context_def_id
641            && let [arg] = &**args
642            && let Some(place) = arg.node.place()
643        {
644            let arg =
645                Rvalue::Cast(
646                    CastKind::Transmute,
647                    Operand::Copy(place.project_deeper(
648                        &[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
649                        tcx,
650                    )),
651                    context_mut_ref,
652                );
653            let assign = Statement::new(
654                terminator.source_info,
655                StatementKind::Assign(Box::new((*destination, arg))),
656            );
657            terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
658            bb_data.statements.push(assign);
659        }
660    }
661}
662
663/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and
664/// dispatches to blocks according to `cases`.
665///
666/// After this function, the former entry point of the function will be the last block.
667fn insert_switch<'tcx>(
668    body: &mut Body<'tcx>,
669    cases: Vec<(usize, BasicBlock)>,
670    transform: &TransformVisitor<'tcx>,
671    default_block: BasicBlock,
672) {
673    let (assign, discr) = transform.get_discr(body);
674
675    // MIR validation ensures that no block targets `ENTRY_BLOCK`.
676    #[cfg(debug_assertions)]
677    for bb in body.basic_blocks.iter() {
678        for target in bb.terminator().successors() {
679            assert_ne!(target, START_BLOCK);
680        }
681    }
682
683    // Add the switch as entry block, and put the former entry block at the end.
684    let former_entry = std::mem::replace(
685        &mut body.basic_blocks_mut()[START_BLOCK],
686        BasicBlockData::new_stmts(vec![assign], None, false),
687    );
688    let former_entry = body.basic_blocks_mut().push(former_entry);
689
690    // We may point to `START_BLOCK` in our `cases`, replace it with `former_entry`.
691    let mut switch_targets =
692        SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
693    for bb in switch_targets.all_targets_mut() {
694        if *bb == START_BLOCK {
695            *bb = former_entry;
696        }
697    }
698
699    let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
700    body.basic_blocks_mut()[START_BLOCK].terminator = Some(Terminator {
701        source_info: SourceInfo::outermost(body.span),
702        kind: switch,
703        attributes: ThinVec::new(),
704    });
705}
706
707fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
708    let source_info = SourceInfo::outermost(body.span);
709    body.basic_blocks_mut().push(BasicBlockData::new(
710        Some(Terminator { source_info, kind, attributes: ThinVec::new() }),
711        false,
712    ))
713}
714
715fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) -> Statement<'tcx> {
716    // Poll::Ready(())
717    let poll_def_id = tcx.require_lang_item(LangItem::Poll, source_info.span);
718    let args = tcx.mk_args(&[tcx.types.unit.into()]);
719    let val = Operand::Constant(Box::new(ConstOperand {
720        span: source_info.span,
721        user_ty: None,
722        const_: Const::zero_sized(tcx.types.unit),
723    }));
724    let ready_val = Rvalue::Aggregate(
725        Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)),
726        indexvec![val],
727    );
728    Statement::new(source_info, StatementKind::Assign(Box::new((Place::return_place(), ready_val))))
729}
730
731fn insert_poll_ready_block<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> BasicBlock {
732    let source_info = SourceInfo::outermost(body.span);
733    body.basic_blocks_mut().push(BasicBlockData::new_stmts(
734        [return_poll_ready_assign(tcx, source_info)].to_vec(),
735        Some(Terminator { source_info, kind: TerminatorKind::Return, attributes: ThinVec::new() }),
736        false,
737    ))
738}
739
740fn insert_panic_block<'tcx>(
741    tcx: TyCtxt<'tcx>,
742    body: &mut Body<'tcx>,
743    message: AssertMessage<'tcx>,
744) -> BasicBlock {
745    let assert_block = body.basic_blocks.next_index();
746    let kind = TerminatorKind::Assert {
747        cond: Operand::Constant(Box::new(ConstOperand {
748            span: body.span,
749            user_ty: None,
750            const_: Const::from_bool(tcx, false),
751        })),
752        expected: true,
753        msg: Box::new(message),
754        target: assert_block,
755        unwind: UnwindAction::Continue,
756    };
757
758    insert_term_block(body, kind)
759}
760
761fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::TypingEnv<'tcx>) -> bool {
762    // Returning from a function with an uninhabited return type is undefined behavior.
763    if body.return_ty().is_privately_uninhabited(tcx, typing_env) {
764        return false;
765    }
766
767    // If there's a return terminator the function may return.
768    body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
769    // Otherwise the function can't return.
770}
771
772fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
773    // Nothing can unwind when landing pads are off.
774    if !tcx.sess.panic_strategy().unwinds() {
775        return false;
776    }
777
778    // If we don't find an unwinding terminator, the function cannot unwind.
779    body.basic_blocks.iter().any(|block| block.terminator().unwind().is_some())
780}
781
782// Poison the coroutine when it unwinds
783fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
784    transform: &TransformVisitor<'tcx>,
785    body: &mut Body<'tcx>,
786) {
787    let source_info = SourceInfo::outermost(body.span);
788    let poison_block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
789        vec![transform.set_discr(VariantIdx::new(CoroutineArgs::POISONED), source_info)],
790        Some(Terminator {
791            source_info,
792            kind: TerminatorKind::UnwindResume,
793
794            attributes: ThinVec::new(),
795        }),
796        true,
797    ));
798
799    for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
800        let source_info = block.terminator().source_info;
801
802        if let TerminatorKind::UnwindResume = block.terminator().kind {
803            // An existing `Resume` terminator is redirected to jump to our dedicated
804            // "poisoning block" above.
805            if idx != poison_block {
806                *block.terminator_mut() = Terminator {
807                    source_info,
808                    kind: TerminatorKind::Goto { target: poison_block },
809
810                    attributes: ThinVec::new(),
811                };
812            }
813        } else if !block.is_cleanup
814            // Any terminators that *can* unwind but don't have an unwind target set are also
815            // pointed at our poisoning block (unless they're part of the cleanup path).
816            && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
817        {
818            *unwind = UnwindAction::Cleanup(poison_block);
819        }
820    }
821}
822
/// Turns `body` into the coroutine's resume function (`Coroutine::resume`, `Future::poll`,
/// `Iterator::next`, ... depending on the coroutine kind).
///
/// The new entry block switches on the coroutine discriminant:
/// - `UNRESUMED` jumps to the original entry point;
/// - each suspension point's state jumps to a dispatch block built by `create_cases`, which
///   re-opens storage and moves the resume argument before continuing after the suspension;
/// - `RETURNED` (only when `can_return`) panics with `ResumedAfterReturn`. Two exceptions:
///   `async_drop_in_place` coroutines return `Poll::Ready(())`, and `gen` / `async gen`
///   coroutines use the block from `transform.insert_none_ret_block`;
/// - `POISONED` (only when `can_unwind`) panics with `ResumedAfterPanic`;
/// - any other value leads to an `Unreachable` block.
///
/// Afterwards the self argument is adjusted: `make_coroutine_state_argument_pinned` for all
/// kinds except `gen`, which uses `make_coroutine_state_argument_indirect` because
/// `Iterator::next` takes an unpinned argument. Then dead blocks are removed and follow-up
/// passes run.
///
/// * `can_return` - whether the body can return at all (see `can_return`).
/// * `can_unwind` - whether the body may unwind (see `can_unwind`); also controls whether a
///   poison block is generated.
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
fn create_coroutine_resume_function<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: TransformVisitor<'tcx>,
    body: &mut Body<'tcx>,
    can_return: bool,
    can_unwind: bool,
) {
    // Poison the coroutine when it unwinds. This must happen before the dispatch blocks are
    // created, so that the redirected unwind edges are part of the body being switched into.
    if can_unwind {
        generate_poison_block_and_redirect_unwinds_there(&transform, body);
    }

    // One `(state, dispatch block)` pair per suspension point, in suspension-point order.
    let mut cases = create_cases(body, &transform, Operation::Resume);

    use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};

    // Jump to the original entry point when the coroutine has not been resumed yet.
    // `insert_switch` redirects this `START_BLOCK` target to the relocated entry block.
    cases.insert(0, (CoroutineArgs::UNRESUMED, START_BLOCK));

    // Panic when resumed on the returned or poisoned state.
    // Both inserts below use index 1, so the final order is UNRESUMED, RETURNED (if any),
    // POISONED (if any), then the suspension-point states.
    if can_unwind {
        cases.insert(
            1,
            (
                CoroutineArgs::POISONED,
                insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind)),
            ),
        );
    }

    if can_return {
        let block = match transform.coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
            | CoroutineKind::Coroutine(_) => {
                // For `async_drop_in_place<T>::{closure}` we just keep return Poll::Ready,
                // because async drop of such coroutine keeps polling original coroutine
                if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) {
                    insert_poll_ready_block(tcx, body)
                } else {
                    insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
                }
            }
            // `gen` / `async gen`: resuming after completion goes to the block built by
            // `insert_none_ret_block` (by its name, a `None` return - confirm in its definition).
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
            | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                transform.insert_none_ret_block(body)
            }
        };
        cases.insert(1, (CoroutineArgs::RETURNED, block));
    }

    // Discriminant values without a case are impossible; route them to `Unreachable`.
    let default_block = insert_term_block(body, TerminatorKind::Unreachable);
    insert_switch(body, cases, &transform, default_block);

    match transform.coroutine_kind {
        CoroutineKind::Coroutine(_)
        | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
        {
            make_coroutine_state_argument_pinned(tcx, body);
        }
        // Iterator::next doesn't accept a pinned argument,
        // unlike for all other coroutine kinds.
        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
            make_coroutine_state_argument_indirect(tcx, body);
        }
    }

    // Make sure we remove dead blocks to remove
    // unrelated code from the drop part of the function
    simplify::remove_dead_blocks(body);

    // Run `AbortUnwindingCalls` on the resume function (validation is not run here).
    pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

    // Run derefer to fix Derefs that are not in the first place
    deref_finder(tcx, body, false);

    // Async-desugared bodies get their context handling rewritten last.
    if transform.coroutine_kind.is_async_desugaring() {
        transform_async_context(tcx, body);
    }

    if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
        dumper.dump_mir(body);
    }
}
907
908/// An operation that can be performed on a coroutine.
909#[derive(PartialEq, Copy, Clone, Debug)]
910enum Operation {
911    Resume,
912    Drop,
913    AsyncDrop,
914}
915
916impl Operation {
917    fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
918        match self {
919            Operation::Resume => Some(point.resume),
920            Operation::Drop | Operation::AsyncDrop => point.drop,
921        }
922    }
923
924    fn resume_place<'tcx>(self, point: &SuspensionPoint<'tcx>) -> Option<Place<'tcx>> {
925        match self {
926            Operation::Resume | Operation::AsyncDrop => Some(point.resume_arg),
927            Operation::Drop => None,
928        }
929    }
930}
931
932#[tracing::instrument(level = "trace", skip(transform, body))]
933fn create_cases<'tcx>(
934    body: &mut Body<'tcx>,
935    transform: &TransformVisitor<'tcx>,
936    operation: Operation,
937) -> Vec<(usize, BasicBlock)> {
938    let source_info = SourceInfo::outermost(body.span);
939
940    transform
941        .suspension_points
942        .iter()
943        .filter_map(|point| {
944            // Find the target for this suspension point, if applicable
945            operation.target_block(point).map(|target| {
946                let mut statements = Vec::new();
947
948                // Create StorageLive instructions for locals with live storage
949                for l in body.local_decls.indices() {
950                    let needs_storage_live = point.storage_liveness.contains(l)
951                        && !transform.remap.contains(l)
952                        && !transform.always_live_locals.contains(l);
953                    if needs_storage_live {
954                        statements.push(Statement::new(source_info, StatementKind::StorageLive(l)));
955                    }
956                }
957
958                // Move the resume argument to the destination place of the `Yield` terminator
959                if let Some(resume_arg) = operation.resume_place(point)
960                    && resume_arg != CTX_ARG.into()
961                {
962                    statements.push(Statement::new(
963                        source_info,
964                        StatementKind::Assign(Box::new((
965                            resume_arg,
966                            Rvalue::Use(Operand::Move(CTX_ARG.into()), WithRetag::Yes),
967                        ))),
968                    ));
969                }
970
971                // Then jump to the real target
972                let block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
973                    statements,
974                    Some(Terminator {
975                        source_info,
976                        kind: TerminatorKind::Goto { target },
977
978                        attributes: ThinVec::new(),
979                    }),
980                    false,
981                ));
982
983                (point.state, block)
984            })
985        })
986        .collect()
987}
988
impl<'tcx> crate::MirPass<'tcx> for StateTransform {
    /// Lowers a coroutine body into its resume function and attaches its drop shim(s).
    ///
    /// Bodies without a yield type are not coroutines and are returned from untouched.
    /// Otherwise, in order, the pass:
    /// 1. computes the resume function's return type from the coroutine kind;
    /// 2. computes which locals live across suspension points and lays them out in the
    ///    coroutine struct;
    /// 3. rewrites the body with `TransformVisitor`;
    /// 4. stores remapped arguments into the struct;
    /// 5. inserts and elaborates the "clean" drop used for the unresumed state;
    /// 6. builds either an async drop shim, or a sync drop shim plus an async proxy shim;
    /// 7. finally turns `body` itself into the resume function.
    #[instrument(level = "debug", skip(self, tcx, body), ret)]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        debug!(def_id = ?body.source.def_id());

        let Some(old_yield_ty) = body.yield_ty() else {
            // This only applies to coroutines
            return;
        };
        tracing::trace!(def_id = ?body.source.def_id());

        let old_ret_ty = body.return_ty();

        // The drop shims are produced by this pass, so none may exist yet.
        assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_before", body) {
            dumper.dump_mir(body);
        }

        // The first argument is the coroutine type passed by value
        let coroutine_ty = body.local_decls.raw[1].ty;
        let coroutine_kind = body.coroutine_kind().unwrap();

        // Get the discriminant type and args which typeck computed
        let ty::Coroutine(_, args) = coroutine_ty.kind() else {
            tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
        };
        let discr_ty = args.as_coroutine().discr_ty(tcx);

        // Return type of the generated resume function, per coroutine kind.
        let new_ret_ty = match coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                // Compute Poll<return_ty>
                let poll_did = tcx.require_lang_item(LangItem::Poll, body.span);
                let poll_adt_ref = tcx.adt_def(poll_did);
                let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
                Ty::new_adt(tcx, poll_adt_ref, poll_args)
            }
            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                // Compute Option<yield_ty>
                let option_did = tcx.require_lang_item(LangItem::Option, body.span);
                let option_adt_ref = tcx.adt_def(option_did);
                let option_args = tcx.mk_args(&[old_yield_ty.into()]);
                Ty::new_adt(tcx, option_adt_ref, option_args)
            }
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                // The yield ty is already `Poll<Option<yield_ty>>`
                old_yield_ty
            }
            CoroutineKind::Coroutine(_) => {
                // Compute CoroutineState<yield_ty, return_ty>
                let state_did = tcx.require_lang_item(LangItem::CoroutineState, body.span);
                let state_adt_ref = tcx.adt_def(state_did);
                let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
                Ty::new_adt(tcx, state_adt_ref, state_args)
            }
        };

        // We need to insert clean drop for unresumed state and perform drop elaboration
        // (finally in open_drop_for_tuple) before async drop expansion.
        // Async drops, produced by this drop elaboration, will be expanded,
        // and corresponding futures kept in layout.
        // `has_async_drops` is computed up front. It selects the async vs. sync drop shim below
        // and is passed to `insert_clean_drop`.
        let has_async_drops = has_async_drops(body);

        if coroutine_kind.is_async_desugaring() {
            eliminate_get_context_calls(tcx, body);
        }

        let always_live_locals = always_storage_live_locals(body);
        let movable = coroutine_kind.movability() == hir::Movability::Movable;
        let liveness_info =
            locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);

        // Under `-Zvalidate-mir`, check that no assignment between two saved locals lacks a
        // recorded storage conflict. This must run before `TransformVisitor` rewrites the body.
        if tcx.sess.opts.unstable_opts.validate_mir {
            let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias {
                assigned_local: None,
                saved_locals: &liveness_info.saved_locals,
                storage_conflicts: &liveness_info.storage_conflicts,
            };

            vis.visit_body(body);
        }

        // Extract locals which are live across suspension point into `layout`
        // `remap` gives a mapping from local indices onto coroutine struct indices
        // `storage_liveness` tells us which locals have live storage at suspension points
        let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);

        // Evaluated on the body as it is before the transform rewrites it below.
        let can_return = can_return(tcx, body, body.typing_env(tcx));

        // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
        // RETURN_PLACE then is a fresh unused local with type ret_ty.
        let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
        tracing::trace!(?new_ret_local);

        // Run the transformation which converts Places from Local to coroutine struct
        // accesses for locals in `remap`.
        // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
        // either `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
        // or `Poll::Ready(x)` and `Poll::Pending` respectively depending on the coroutine kind.
        let mut transform = TransformVisitor {
            tcx,
            coroutine_kind,
            remap,
            storage_liveness,
            always_live_locals,
            suspension_points: Vec::new(),
            discr_ty,
            new_ret_local,
            old_ret_ty,
            old_yield_ty,
        };
        transform.visit_body(body);

        // Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
        transform.replace_local(RETURN_PLACE, new_ret_local, body);

        // MIR parameters are not explicitly assigned-to when entering the MIR body.
        // If we want to save their values inside the coroutine state, we need to do so explicitly.
        // The moves are spliced in at the very start of `START_BLOCK`, for remapped arguments only.
        let source_info = SourceInfo::outermost(body.span);
        let args_iter = body.args_iter();
        body.basic_blocks.as_mut()[START_BLOCK].statements.splice(
            0..0,
            args_iter.filter_map(|local| {
                let (ty, variant_index, idx) = transform.remap[local]?;
                let lhs = transform.make_field(variant_index, idx, ty);
                let rhs = Rvalue::Use(Operand::Move(local.into()), WithRetag::Yes);
                let assign = StatementKind::Assign(Box::new((lhs, rhs)));
                Some(Statement::new(source_info, assign))
            }),
        );

        // Remove the context argument within generator bodies.
        if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
            body.arg_count = 1;
        }

        // The original arguments to the function are no longer arguments, mark them as such.
        // Otherwise they'll conflict with our new arguments, which although they don't have
        // argument_index set, will get emitted as unnamed arguments.
        for var in &mut body.var_debug_info {
            var.argument_index = None;
        }

        // The body no longer yields or takes a resume value directly; record the computed layout.
        body.coroutine.as_mut().unwrap().yield_ty = None;
        body.coroutine.as_mut().unwrap().resume_ty = None;
        body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);

        // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
        // the unresumed state.
        // This is expanded to a drop ladder in `elaborate_coroutine_drops`.
        let drop_clean = insert_clean_drop(tcx, body, has_async_drops);

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_pre-elab", body) {
            dumper.dump_mir(body);
        }

        // Expand `drop(coroutine_struct)` to a drop ladder which destroys upvars.
        // If any upvars are moved out of, drop elaboration will handle upvar destruction.
        // However we need to also elaborate the code generated by `insert_clean_drop`.
        elaborate_coroutine_drops(tcx, body);

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_post-transform", body) {
            dumper.dump_mir(body);
        }

        // Computed after drop elaboration, so terminators added by it are taken into account.
        let can_unwind = can_unwind(tcx, body);

        // Create a copy of our MIR and use it to create the drop shim for the coroutine
        if has_async_drops {
            // If coroutine has async drops, generating async drop shim
            let drop_shim =
                create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind);
            body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim);
        } else {
            // If coroutine has no async drops, generating sync drop shim
            let drop_shim =
                create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
            body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

            // For coroutine with sync drop, generating async proxy for `future_drop_poll` call
            let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
            body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
        }

        // Create the Coroutine::resume / Future::poll function
        create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);
    }

    /// This pass always runs; it is not gated on the optimization level.
    fn is_required(&self) -> bool {
        true
    }
}
1181
1182/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
1183/// in the coroutine state machine but whose storage is not marked as conflicting
1184///
1185/// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after.
1186///
1187/// This condition would arise when the assignment is the last use of `_5` but the initial
1188/// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as
1189/// conflicting. Non-conflicting coroutine saved locals may be stored at the same location within
1190/// the coroutine state machine, which would result in ill-formed MIR: the left-hand and right-hand
1191/// sides of an assignment may not alias. This caused a miscompilation in [#73137].
1192///
1193/// [#73137]: https://github.com/rust-lang/rust/issues/73137
1194struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> {
1195    saved_locals: &'a CoroutineSavedLocals,
1196    storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
1197    assigned_local: Option<CoroutineSavedLocal>,
1198}
1199
1200impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1201    fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> {
1202        if place.is_indirect() {
1203            return None;
1204        }
1205
1206        self.saved_locals.get(place.local)
1207    }
1208
1209    fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1210        if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1211            assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1212
1213            self.assigned_local = Some(assigned_local);
1214            f(self);
1215            self.assigned_local = None;
1216        }
1217    }
1218}
1219
impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
    /// Checks a right-hand-side place against the local currently being assigned.
    ///
    /// Calls `bug!` if both are directly-accessed saved locals whose storage is not recorded as
    /// conflicting.
    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
        let Some(lhs) = self.assigned_local else {
            // This visitor only invokes `visit_place` for the right-hand side of an assignment
            // and only after setting `self.assigned_local`. However, the default impl of
            // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places
            // with debuginfo. Ignore them here.
            assert!(!context.is_use());
            return;
        };

        let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };

        if !self.storage_conflicts.contains(lhs, rhs) {
            bug!(
                "Assignment between coroutine saved locals whose storage is not \
                    marked as conflicting: {:?}: {:?} = {:?}",
                location,
                lhs,
                rhs,
            );
        }
    }

    /// Only `Assign` statements are checked. The assigned place becomes `assigned_local` while
    /// the rvalue is visited. All other statement kinds are ignored.
    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
        match &statement.kind {
            StatementKind::Assign((lhs, rhs)) => {
                self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
            }

            StatementKind::FakeRead(..)
            | StatementKind::SetDiscriminant { .. }
            | StatementKind::StorageLive(_)
            | StatementKind::StorageDead(_)
            | StatementKind::AscribeUserType(..)
            | StatementKind::PlaceMention(..)
            | StatementKind::Coverage(..)
            | StatementKind::Intrinsic(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => {}
        }
    }

    /// Checks the terminators that write a place from their operands:
    /// - a `Call` that has a return target: its destination against the callee and arguments;
    /// - a `Yield`: its resume place against the yielded value.
    ///
    /// All other terminators, including calls without a return target, are not checked.
    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
        // Checking for aliasing in terminators is probably overkill, but until we have actual
        // semantics, we should be conservative here.
        match &terminator.kind {
            TerminatorKind::Call {
                func,
                args,
                destination,
                target: Some(_),
                unwind: _,
                call_source: _,
                fn_span: _,
            } => {
                self.check_assigned_place(*destination, |this| {
                    this.visit_operand(func, location);
                    for arg in args {
                        this.visit_operand(&arg.node, location);
                    }
                });
            }

            TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
                self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
            }

            // FIXME: Does `asm!` have any aliasing requirements?
            TerminatorKind::InlineAsm { .. } => {}

            // Note: a `Call` reaching this arm has `target: None` (the arm above takes the rest).
            TerminatorKind::Call { .. }
            | TerminatorKind::Goto { .. }
            | TerminatorKind::SwitchInt { .. }
            | TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::Drop { .. }
            | TerminatorKind::Assert { .. }
            | TerminatorKind::CoroutineDrop
            | TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. } => {}
        }
    }
}