Skip to main content

cargo/util/
local_poll_adapter.rs

1use futures::{FutureExt, future::LocalBoxFuture, stream::FuturesUnordered};
2use std::{collections::HashMap, hash::Hash, ops::Deref, task::Poll};
3
4/// A local (!Send) adapter for caching and executing an async method
5/// from a non-async context.
6///
7/// The `self_parameter`, `key`, and successful (Ok) results must all be cheap to `clone`.
8///
9/// Ensures at most one in-flight computation per key. Results are:
10/// - cached on success
11/// - not retained on error
12pub struct LocalPollAdapter<'a, S, K, R> {
13    pool: FuturesUnordered<LocalBoxFuture<'a, (K, R)>>,
14    cache: HashMap<K, Poll<R>>,
15    self_parameter: S,
16}
17
18impl<'a, S, K, V, E> LocalPollAdapter<'a, S, K, Result<V, E>>
19where
20    S: Clone + Deref + 'a,
21    K: Clone + Hash + Eq + 'a,
22    V: Clone,
23{
24    pub fn new(self_parameter: S) -> Self {
25        Self {
26            pool: FuturesUnordered::new(),
27            cache: HashMap::new(),
28            self_parameter,
29        }
30    }
31
32    /// Polls the result for `key`, spawning work if needed.
33    ///
34    /// If this function returns [`Poll::Pending`], call [`LocalPollAdapter::wait`]
35    /// to execute the work, then call this function again with the same key
36    /// to pick up the result.
37    ///
38    /// Futures that complete immediately are not queued.
39    pub fn poll<F>(&mut self, f: F, key: K) -> Poll<Result<V, E>>
40    where
41        F: AsyncFn(&S::Target, &K) -> Result<V, E> + 'a,
42    {
43        match self.cache.get(&key) {
44            // We have a cached success value, clone it and return.
45            Some(Poll::Ready(Ok(v))) => return Poll::Ready(Ok(v.clone())),
46            // We have a cached error value, remove it and return.
47            // Errors are not Clone, so they are only stored once.
48            Some(Poll::Ready(Err(_))) => return self.cache.remove(&key).unwrap(),
49            // This key is already pending.
50            Some(Poll::Pending) => return Poll::Pending,
51            // Looks like we have work to do!
52            None => {}
53        }
54
55        // Created a pinned future that executes the function,
56        // returning the key and the result.
57        let mut future = {
58            let key = key.clone();
59            let self_parameter = self.self_parameter.clone();
60            async move {
61                let v = f(self_parameter.deref(), &key).await;
62                (key, v)
63            }
64            .boxed_local()
65        };
66
67        // Attempt to run the future immediately. If it has no `await` yields,
68        // it will return here.
69        if let Some((k, v)) = (&mut future).now_or_never() {
70            if let Ok(success) = &v {
71                // Only cache successful results.
72                self.cache.insert(k, Poll::Ready(Ok(success.clone())));
73            }
74            return Poll::Ready(v);
75        }
76
77        // Insert Pending into the cache so we avoid queuing the same future twice.
78        self.cache.insert(key.clone(), Poll::Pending);
79
80        // Add the future to the pending queue.
81        self.pool.push(future);
82        Poll::Pending
83    }
84
85    /// Returns the number of pending futures.
86    pub fn pending_count(&self) -> usize {
87        self.pool.len()
88    }
89
90    /// Run all pending futures. Returns true if there was no work to do.
91    pub fn wait(&mut self) -> bool {
92        let is_empty = self.pool.is_empty();
93        for (k, v) in crate::util::block_on_stream(&mut self.pool) {
94            *self
95                .cache
96                .get_mut(&k)
97                .expect("all pending work is in the cache") = Poll::Ready(v);
98        }
99        is_empty
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::LocalPollAdapter;
    use std::{rc::Rc, task::Poll};

    /// Future that yields once (waking itself) before completing.
    fn yield_once() -> impl std::future::Future<Output = ()> {
        let mut yielded = false;

        std::future::poll_fn(move |cx| {
            if yielded {
                Poll::Ready(())
            } else {
                yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
    }

    struct Thing;

    impl Thing {
        async fn widen(&self, i: &i32) -> Result<i64, ()> {
            if *i > 10 {
                // Big numbers take longer to process (need to test futures that yield).
                yield_once().await;
            }
            if *i % 2 != 0 {
                // Odd numbers are not supported (need to test errors).
                return Err(());
            }
            Ok(i64::from(*i))
        }
    }

    /// Poll wrapper around `Thing`
    struct PolledThing<'a> {
        poller: LocalPollAdapter<'a, Rc<Thing>, i32, Result<i64, ()>>,
    }

    impl<'a> PolledThing<'a> {
        fn new() -> Self {
            Self {
                poller: LocalPollAdapter::new(Rc::new(Thing)),
            }
        }

        // Non-async version of the widen method.
        fn widen(&mut self, i: &i32) -> Poll<Result<i64, ()>> {
            self.poller.poll(Thing::widen, *i)
        }

        fn wait(&mut self) -> bool {
            self.poller.wait()
        }
    }

    #[test]
    fn immediate_success() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&2), Poll::Ready(Ok(2)));
        assert!(p.wait());
    }

    #[test]
    fn immediate_error() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&1), Poll::Ready(Err(())));
        assert!(p.wait());
    }

    #[test]
    fn deferred_error() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&1001), Poll::Pending);
        assert!(!p.wait());
        assert_eq!(p.widen(&1001), Poll::Ready(Err(())));
        assert!(p.wait());
        // Errors are not cached
        assert_eq!(p.widen(&1001), Poll::Pending);
        assert!(!p.wait());
        assert_eq!(p.widen(&1001), Poll::Ready(Err(())));
        assert!(p.wait());
    }

    #[test]
    fn deferred_success() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&50), Poll::Pending);
        assert!(!p.wait());
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
        assert!(p.wait());
        // Success is cached.
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
        assert!(p.wait());
    }

    #[test]
    fn pending_key_is_not_queued_twice() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&50), Poll::Pending);
        // A second poll while the first is in flight must not queue more work.
        assert_eq!(p.widen(&50), Poll::Pending);
        assert_eq!(p.poller.pending_count(), 1);
        assert!(!p.wait());
        assert_eq!(p.poller.pending_count(), 0);
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
    }

    #[test]
    fn independent_keys_are_queued_together() {
        let mut p = PolledThing::new();
        assert_eq!(p.widen(&50), Poll::Pending);
        assert_eq!(p.widen(&1001), Poll::Pending);
        // Immediate results never enter the queue.
        assert_eq!(p.widen(&2), Poll::Ready(Ok(2)));
        assert_eq!(p.poller.pending_count(), 2);
        assert!(!p.wait());
        assert_eq!(p.widen(&1001), Poll::Ready(Err(())));
        assert_eq!(p.widen(&50), Poll::Ready(Ok(50)));
        assert!(p.wait());
    }
}