// cargo/util/local_poll_adapter.rs
use futures::{FutureExt, future::LocalBoxFuture, stream::FuturesUnordered};
use std::{collections::HashMap, hash::Hash, ops::Deref, task::Poll};
3
4pub struct LocalPollAdapter<'a, S, K, R> {
13 pool: FuturesUnordered<LocalBoxFuture<'a, (K, R)>>,
14 cache: HashMap<K, Poll<R>>,
15 self_parameter: S,
16}
17
18impl<'a, S, K, V, E> LocalPollAdapter<'a, S, K, Result<V, E>>
19where
20 S: Clone + Deref + 'a,
21 K: Clone + Hash + Eq + 'a,
22 V: Clone,
23{
24 pub fn new(self_parameter: S) -> Self {
25 Self {
26 pool: FuturesUnordered::new(),
27 cache: HashMap::new(),
28 self_parameter,
29 }
30 }
31
32 pub fn poll<F>(&mut self, f: F, key: K) -> Poll<Result<V, E>>
40 where
41 F: AsyncFn(&S::Target, &K) -> Result<V, E> + 'a,
42 {
43 match self.cache.get(&key) {
44 Some(Poll::Ready(Ok(v))) => return Poll::Ready(Ok(v.clone())),
46 Some(Poll::Ready(Err(_))) => return self.cache.remove(&key).unwrap(),
49 Some(Poll::Pending) => return Poll::Pending,
51 None => {}
53 }
54
55 let mut future = {
58 let key = key.clone();
59 let self_parameter = self.self_parameter.clone();
60 async move {
61 let v = f(self_parameter.deref(), &key).await;
62 (key, v)
63 }
64 .boxed_local()
65 };
66
67 if let Some((k, v)) = (&mut future).now_or_never() {
70 if let Ok(success) = &v {
71 self.cache.insert(k, Poll::Ready(Ok(success.clone())));
73 }
74 return Poll::Ready(v);
75 }
76
77 self.cache.insert(key.clone(), Poll::Pending);
79
80 self.pool.push(future);
82 Poll::Pending
83 }
84
85 pub fn pending_count(&self) -> usize {
87 self.pool.len()
88 }
89
90 pub fn wait(&mut self) -> bool {
92 let is_empty = self.pool.is_empty();
93 for (k, v) in crate::util::block_on_stream(&mut self.pool) {
94 *self
95 .cache
96 .get_mut(&k)
97 .expect("all pending work is in the cache") = Poll::Ready(v);
98 }
99 is_empty
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::LocalPollAdapter;
    use std::{rc::Rc, task::Poll};

    /// Builds a future that is `Pending` the first time it is polled (after
    /// waking the task) and `Ready` on every later poll.
    fn yield_once() -> impl std::future::Future<Output = ()> {
        let mut has_yielded = false;

        std::future::poll_fn(move |cx| {
            if !has_yielded {
                has_yielded = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        })
    }

    /// Type whose async method is driven through the adapter.
    struct Thing {}

    impl Thing {
        /// Widens `i` to `i64`.
        ///
        /// Inputs above 10 suspend once before producing a result; odd inputs
        /// produce `Err(())`.
        async fn widen(&self, i: &i32) -> Result<i64, ()> {
            let n = *i;
            if n > 10 {
                yield_once().await;
            }
            if n % 2 == 0 {
                Ok(i64::from(n))
            } else {
                Err(())
            }
        }
    }

    /// Poll-style wrapper around [`Thing`] built on the adapter.
    struct PolledThing<'a> {
        poller: LocalPollAdapter<'a, Rc<Thing>, i32, Result<i64, ()>>,
    }

    impl<'a> PolledThing<'a> {
        fn new() -> Self {
            let thing = Rc::new(Thing {});
            Self {
                poller: LocalPollAdapter::new(thing),
            }
        }

        /// Poll counterpart of [`Thing::widen`].
        fn widen(&mut self, i: &i32) -> Poll<Result<i64, ()>> {
            self.poller.poll(Thing::widen, *i)
        }

        fn wait(&mut self) -> bool {
            self.poller.wait()
        }
    }

    #[test]
    fn immediate_success() {
        let mut polled = PolledThing::new();
        assert_eq!(polled.widen(&2), Poll::Ready(Ok(2)));
        // Nothing was queued, so there is nothing to wait for.
        assert!(polled.wait());
    }

    #[test]
    fn immediate_error() {
        let mut polled = PolledThing::new();
        assert_eq!(polled.widen(&1), Poll::Ready(Err(())));
        // Nothing was queued, so there is nothing to wait for.
        assert!(polled.wait());
    }

    #[test]
    fn deferred_error() {
        let mut polled = PolledThing::new();
        // The error is handed out once and not kept, so the second round
        // goes through the same pending -> wait -> error sequence as the first.
        for _round in 0..2 {
            assert_eq!(polled.widen(&1001), Poll::Pending);
            assert!(!polled.wait());
            assert_eq!(polled.widen(&1001), Poll::Ready(Err(())));
            assert!(polled.wait());
        }
    }

    #[test]
    fn deferred_success() {
        let mut polled = PolledThing::new();
        assert_eq!(polled.widen(&50), Poll::Pending);
        assert!(!polled.wait());
        // The success stays cached: repeated polls return it without queueing
        // any new work.
        for _repeat in 0..2 {
            assert_eq!(polled.widen(&50), Poll::Ready(Ok(50)));
            assert!(polled.wait());
        }
    }
}