// cargo/util/network/http_async.rs
//! Async wrapper around cURL for managing HTTP requests.
//!
//! Requests are executed in parallel using cURL [`Multi`] on
//! a worker thread that is owned by the [`Client`].

6use std::collections::HashMap;
7use std::io::Cursor;
8use std::io::Read;
9use std::str::FromStr;
10use std::sync::Arc;
11use std::sync::atomic::Ordering;
12use std::sync::mpsc;
13use std::sync::mpsc::Receiver;
14use std::sync::mpsc::Sender;
15use std::thread::JoinHandle;
16use std::time::Duration;
17use std::time::Instant;
18
19use curl::easy::Easy2;
20use curl::easy::Handler;
21use curl::easy::InfoType;
22use curl::easy::WriteError;
23use curl::multi::Easy2Handle;
24use curl::multi::Multi;
25use futures::channel::oneshot;
26use portable_atomic::AtomicI64;
27use portable_atomic::AtomicU64;
28use tracing::{debug, error, trace, warn};
29
30use crate::util::network::http::HandleConfiguration;
31use crate::util::network::http::HttpTimeout;
32
/// An HTTP response whose body is a raw byte vector.
type Response = http::Response<Vec<u8>>;
/// An HTTP request whose body is a raw byte vector.
type Request = http::Request<Vec<u8>>;
/// Result type for the fallible operations in this module.
type HttpResult<T> = std::result::Result<T, Error>;
36
/// Errors produced by the async HTTP client.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// An error reported by the cURL multi interface.
    #[error(transparent)]
    Multi(#[from] curl::MultiError),

    /// An error reported by an individual cURL transfer.
    #[error(transparent)]
    Easy(#[from] curl::Error),

    /// The client-level low-speed check fired: fewer than `low_speed_limit`
    /// bytes were transferred across all active requests in one window.
    #[error(
        "transfer too slow: failed to transfer more than {low_speed_limit} bytes in {}s (transferred {transferred} bytes)",
        timeout_dur.as_secs()
    )]
    TooSlow {
        // Minimum number of bytes expected within `timeout_dur`.
        low_speed_limit: u32,
        // Length of the measurement window.
        timeout_dur: Duration,
        // Bytes actually transferred within the window.
        transferred: u64,
    },

    /// A request header value could not be represented as a UTF-8 string.
    #[error("failed to convert header value of `{name}` to string: {bytes:?}")]
    BadHeader { name: String, bytes: Vec<u8> },
}
59
/// A unit of work sent from [`Client`] to the worker thread.
struct Message {
    /// The fully-configured transfer to execute.
    easy: Easy2<Collector>,
    /// Channel on which the worker returns the result to the requester.
    sender: oneshot::Sender<HttpResult<Response>>,
}
64
/// Transfer statistics shared between a [`Client`] and its worker thread.
#[derive(Default)]
struct Stats {
    // Bytes still expected across all active downloads. Signed because it is
    // updated with positive and negative deltas from progress callbacks.
    dl_remaining: AtomicI64,
    // Total bytes downloaded so far across all transfers.
    dl_transferred: AtomicU64,
}
70
/// HTTP Client. Creating a new client spawns a cURL `Multi` and
/// thread that is used for all HTTP requests by this client.
pub struct Client {
    /// Sends work to the worker thread; `None` only while dropping.
    channel: Option<Sender<Message>>,
    /// Join handle for the worker thread; `None` only while dropping.
    thread_handle: Option<JoinHandle<()>>,
    /// Settings applied to every `Easy2` handle this client creates.
    handle_config: HandleConfiguration,
    /// Transfer statistics shared with the worker thread.
    stats: Arc<Stats>,
}
79
80impl Client {
81    /// Spawns a new worker thread where HTTP request execute.
82    pub fn new(handle_config: HandleConfiguration) -> Client {
83        let (tx, rx) = mpsc::channel();
84        let stats = Arc::new(Stats::default());
85        let timeout = handle_config.timeout.clone();
86        let worker_stats = stats.clone();
87        let handle = std::thread::spawn(move || {
88            WorkerServer::run(rx, handle_config.multiplexing, timeout, worker_stats)
89        });
90        Client {
91            channel: Some(tx),
92            thread_handle: Some(handle),
93            handle_config,
94            stats,
95        }
96    }
97
98    /// Perform a blocking HTTP request using this client.
99    /// Does not start an async executor.
100    pub fn request_blocking(&self, request: Request) -> HttpResult<Response> {
101        let mut handle = self.request_helper(request)?;
102        // Configure the handle timeout since we're blocking here and not using the
103        // client-level timeout.
104        self.handle_config.timeout.configure2(&mut handle)?;
105        handle.perform()?;
106        Ok(WorkerServer::process_response(handle))
107    }
108
109    /// Perform an HTTP request using this client.
110    pub async fn request(&self, request: Request) -> HttpResult<Response> {
111        let handle = self.request_helper(request)?;
112        let (sender, receiver) = oneshot::channel();
113        let req = Message {
114            easy: handle,
115            sender,
116        };
117        self.channel.as_ref().unwrap().send(req).unwrap();
118        receiver.await.unwrap()
119    }
120
121    fn request_helper(&self, request: Request) -> HttpResult<Easy2<Collector>> {
122        let url = request.uri().to_string();
123        debug!(target: "network::fetch", url);
124        let mut collector = Collector::new(self.stats.clone());
125        let (parts, body) = request.into_parts();
126        let body_len = body.len();
127        collector.request_body = Cursor::new(body);
128        collector.debug = self.handle_config.verbose;
129        let mut handle = curl::easy::Easy2::new(collector);
130        self.handle_config.configure2(&mut handle)?;
131
132        handle.url(&url)?;
133        handle.follow_location(true)?;
134        handle.progress(true)?;
135
136        match parts.method {
137            http::Method::HEAD => handle.nobody(true)?,
138            http::Method::GET => handle.get(true)?,
139            http::Method::POST => {
140                handle.post_field_size(body_len as u64)?;
141                handle.post(true)?;
142            }
143            http::Method::PUT => {
144                handle.in_filesize(body_len as u64)?;
145                handle.put(true)?;
146            }
147            method => {
148                if body_len > 0 {
149                    handle.upload(true)?;
150                    handle.in_filesize(body_len as u64)?;
151                }
152                handle.custom_request(method.as_str())?;
153            }
154        }
155
156        let mut headers = curl::easy::List::new();
157        for (name, value) in parts.headers {
158            if let Some(name) = name {
159                let value: &str = value.to_str().map_err(|_| Error::BadHeader {
160                    name: name.to_string(),
161                    bytes: value.as_bytes().to_owned(),
162                })?;
163                headers.append(&format!("{}: {}", name, value))?;
164            }
165        }
166        handle.http_headers(headers)?;
167
168        Ok(handle)
169    }
170
171    /// Returns the number pending bytes across all active transfers.
172    pub fn bytes_pending(&self) -> u64 {
173        self.stats
174            .dl_remaining
175            .load(Ordering::Acquire)
176            .try_into()
177            .unwrap()
178    }
179}
180
181impl Drop for Client {
182    fn drop(&mut self) {
183        // Close the channel
184        drop(self.channel.take().unwrap());
185        // Join the thread
186        let _ = self.thread_handle.take().unwrap().join();
187    }
188}
189
190impl std::fmt::Debug for Client {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        f.debug_struct("http_async::Client").finish()
193    }
194}
195
/// Manages the cURL `Multi`. Processes incoming work sent over the
/// channel, and returns responses.
struct WorkerServer {
    /// Channel to receive new work from the [`Client`].
    incoming_work: Receiver<Message>,
    /// curl multi interface driving all active transfers.
    multi: Multi,
    /// Map of token to curl handle and response channel for each in-flight
    /// transfer.
    handles: HashMap<
        usize,
        (
            Easy2Handle<Collector>,
            oneshot::Sender<HttpResult<Response>>,
        ),
    >,
    /// Last token handed out; incremented (wrapping) per enqueued request.
    token: usize,
    /// Global timeout configuration for the client-level low-speed check.
    timeout: HttpTimeout,
    /// Global transfer statistics shared with the [`Client`].
    stats: Arc<Stats>,
    /// Instant when the current low speed window started.
    low_speed_window_start: Instant,
    /// Amount of total bytes transferred when the current low speed window started.
    low_speed_window_initial: u64,
}
222
impl WorkerServer {
    /// Builds the worker state around a fresh `Multi` and runs the event
    /// loop until the sending side of `incoming_work` is closed.
    fn run(
        incoming_work: Receiver<Message>,
        multiplex: bool,
        timeout: HttpTimeout,
        stats: Arc<Stats>,
    ) {
        let mut multi = Multi::new();
        // let's not flood the server with connections
        if let Err(e) = multi.set_max_host_connections(2) {
            error!("failed to set max host connections in curl: {e}");
        }
        if let Err(e) = multi.pipelining(false, multiplex) {
            error!("failed to enable multiplexing/pipelining in curl: {e}");
        }

        let mut worker = Self {
            incoming_work,
            multi,
            handles: HashMap::new(),
            token: 0,
            timeout,
            stats,
            low_speed_window_start: Instant::now(),
            low_speed_window_initial: 0,
        };
        worker.worker_loop();
    }

    /// Fails every outstanding request with a clone of `e` and empties the
    /// handle table. Used when the `Multi` itself errors or the global
    /// low-speed timeout fires.
    fn fail_and_drain(&mut self, e: &Error) {
        warn!(
            target: "network",
            "failing all outstanding HTTP requests: {e}"
        );
        for (_token, (_handle, sender)) in self.handles.drain() {
            // The requester may have gone away; ignore send failures.
            let _ = sender.send(Err(e.clone()));
        }
    }

    /// Converts a completed `Easy2` transfer into an `http::Response`,
    /// attaching curl metadata (primary IP, effective URL) as [`Extensions`].
    fn process_response(mut easy: Easy2<Collector>) -> Response {
        // Take the accumulated body/headers out of the collector, leaving an
        // empty placeholder response behind.
        let mut response =
            std::mem::replace(&mut easy.get_mut().response, Response::new(Vec::new()));
        // A response code of 0 means curl has no status to report; keep the
        // default status in that case.
        if let Ok(status) = easy.response_code()
            && status != 0
            && let Ok(status) = http::StatusCode::from_u16(status as u16)
        {
            *response.status_mut() = status;
        }
        // Would be nice to set HTTP version via `response.version_mut()`, but `curl` doesn't have it exposed.
        let extensions = Extensions {
            client_ip: easy.primary_ip().ok().flatten().map(str::to_string),
            effective_url: easy.effective_url().ok().flatten().map(str::to_string),
        };
        response.extensions_mut().insert(extensions);
        response
    }

    /// Marks the start of a new timeout window.
    fn reset_low_speed_timeout(&mut self) {
        self.low_speed_window_start = Instant::now();
        self.low_speed_window_initial = self.stats.dl_transferred.load(Ordering::Acquire);
    }

    /// Return an error if we're at the end of a timeout window and we
    /// haven't made enough progress.
    fn check_low_speed_timeout(&mut self) -> Option<Error> {
        // Make sure we've waited for the timeout duration
        if Instant::now().duration_since(self.low_speed_window_start) < self.timeout.dur {
            return None;
        }

        // Calculate how much we've transferred since the last check.
        let current = self.stats.dl_transferred.load(Ordering::Acquire);
        let transferred = current.saturating_sub(self.low_speed_window_initial);
        // Start the next window regardless of the outcome below.
        self.reset_low_speed_timeout();
        if transferred < self.timeout.low_speed_limit.into() {
            Some(Error::TooSlow {
                low_speed_limit: self.timeout.low_speed_limit,
                timeout_dur: self.timeout.dur,
                transferred,
            })
        } else {
            None
        }
    }

    /// The worker's main event loop: pulls new requests off the channel,
    /// drives the `Multi`, dispatches finished transfers, and enforces the
    /// client-level low-speed timeout. Returns when the channel closes.
    fn worker_loop(&mut self) {
        const INITIAL_DELAY: Duration = Duration::from_millis(1);
        let mut wait_backoff = INITIAL_DELAY;
        loop {
            // Start any pending work.
            while let Ok(msg) = self.incoming_work.try_recv() {
                self.enqueue_request(msg);
                wait_backoff = INITIAL_DELAY;
            }

            match self.multi.perform() {
                Err(e) if e.is_call_perform() => {
                    // cURL states if you receive `is_call_perform`, this means that you should call `perform` again.
                }
                Err(e) => {
                    self.fail_and_drain(&Error::Multi(e));
                }
                Ok(running) => {
                    // Dispatch results for any transfers that just finished.
                    self.multi.messages(|msg| {
                        let t = msg.token().expect("all handles have tokens");
                        trace!(token = t, "finish");
                        let Some((handle, sender)) = self.handles.remove(&t) else {
                            error!("missing entry {t} in handle table");
                            return;
                        };
                        let result = msg.result_for2(&handle).expect("handle must have a result");
                        let easy = self.multi.remove2(handle).expect("handle must be in multi");
                        let response = Self::process_response(easy);
                        // The requester may have given up; ignore send failures.
                        let _ = sender.send(result.map(|()| response).map_err(Into::into));
                    });

                    if running > 0 {
                        // Check for low speed timeout.
                        if let Some(timeout_error) = self.check_low_speed_timeout() {
                            self.fail_and_drain(&timeout_error);
                            continue;
                        }

                        let max_timeout = Duration::from_millis(1000);
                        let mut timeout = self
                            .multi
                            .get_timeout()
                            .ok()
                            .flatten()
                            .unwrap_or(max_timeout)
                            .min(max_timeout);
                        if timeout.is_zero() {
                            // curl said not to wait.
                            continue;
                        }
                        // Ideally we would use `Multi::poll` + a `MultiWaker` instead of `Multi::wait`
                        // to wake the thread when new work is queued. But it requires curl 7.68+,
                        // which is not available everywhere we support.
                        //
                        // Instead, we use an exponential backoff approach so that as long as requests
                        // are being queued, we poll quickly to allow the requests to be added sooner.
                        // Without this, we end up sitting in `Multi::wait` too long while new work is
                        // added to the channel.
                        //
                        // `get_timeout` says we should wait *at most* the timeout amount, so reducing
                        // the wait time is fine.
                        if wait_backoff < timeout {
                            wait_backoff *= 2;
                            timeout = wait_backoff
                        }
                        trace!(
                            pending = self.handles.len(),
                            timeout = timeout.as_millis(),
                            "curl wait"
                        );
                        if let Err(e) = self.multi.wait(&mut [], timeout) {
                            self.fail_and_drain(&Error::Multi(e));
                        }
                    } else {
                        // Block, waiting for more work
                        trace!("all work completed");
                        match self.incoming_work.recv() {
                            Ok(msg) => {
                                trace!("resuming work");
                                // Idle time spent blocked on the channel must
                                // not count against the low-speed window.
                                self.reset_low_speed_timeout();
                                self.enqueue_request(msg);
                                wait_backoff = INITIAL_DELAY;
                            }
                            Err(_) => {
                                // The sending channel is closed. Shut down the worker.
                                break;
                            }
                        }
                    }
                }
            }
        }
    }

    /// Adds the request to the `Multi`, or send an error back through the channel.
    fn enqueue_request(&mut self, message: Message) {
        match self.multi.add2(message.easy) {
            Ok(mut handle) => {
                // Tokens are how `messages()` callbacks identify the transfer.
                self.token = self.token.wrapping_add(1);
                handle.set_token(self.token).ok();
                self.handles.insert(self.token, (handle, message.sender));
            }
            Err(e) => {
                let _ = message.sender.send(Err(e.into()));
            }
        }
    }
}
417
/// Interface that cURL (`Easy2`) uses to make progress.
struct Collector {
    /// The response being built up from header/body callbacks.
    response: Response,
    /// The request body to transmit, consumed by `read` callbacks.
    request_body: Cursor<Vec<u8>>,
    /// Whether we're in debug mode (forwards curl's verbose output).
    debug: bool,
    /// Global transfer statistics.
    global_stats: Arc<Stats>,
    /// How much has this particular transfer added to global `dl_remaining`
    /// stats. Subtracted back out on `Drop` so finished or aborted transfers
    /// don't linger in the count.
    dl_remaining_delta: i64,
}
431
432impl Collector {
433    fn new(stats: Arc<Stats>) -> Self {
434        Collector {
435            response: Response::new(Vec::new()),
436            request_body: Cursor::new(Vec::new()),
437            debug: false,
438            global_stats: stats,
439            dl_remaining_delta: 0,
440        }
441    }
442}
443
impl Handler for Collector {
    /// Called by curl with each chunk of response body: append it to the
    /// response and credit the global downloaded-bytes counter.
    fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
        self.response.body_mut().extend_from_slice(data);
        self.global_stats
            .dl_transferred
            .fetch_add(data.len() as u64, Ordering::Release);
        Ok(data.len())
    }

    /// Called by curl with each raw response header line; parses it and
    /// appends it to the response headers. Lines that fail to parse are
    /// silently skipped. Always returns `true` to continue the transfer.
    fn header(&mut self, data: &[u8]) -> bool {
        if let Some((name, value)) = handle_http_header(data)
            && let Ok(name) = http::HeaderName::from_str(name)
            && let Ok(value) = http::HeaderValue::from_str(value)
        {
            self.response.headers_mut().append(name, value);
        }
        true
    }

    /// Called by curl to pull request-body bytes to upload.
    fn read(&mut self, data: &mut [u8]) -> Result<usize, curl::easy::ReadError> {
        // Reading from an in-memory Cursor cannot fail.
        Ok(self.request_body.read(data).unwrap())
    }

    /// Forwards curl's verbose debug info when debug mode is enabled.
    fn debug(&mut self, kind: InfoType, data: &[u8]) {
        if self.debug {
            super::http::debug(kind, data);
        }
    }

    /// Progress callback: keeps the shared `dl_remaining` counter equal to
    /// the sum of every transfer's latest remaining-bytes estimate.
    /// Returning `true` lets the transfer continue.
    fn progress(&mut self, dltotal: f64, dlnow: f64, _ultotal: f64, _ulnow: f64) -> bool {
        // Skip reports where downloaded exceeds the reported total (e.g.
        // when the total isn't known yet) to avoid a negative "remaining".
        if dlnow > dltotal {
            return true;
        }
        let dl_total = dltotal as i64;
        let dl_current = dlnow as i64;

        let remaining = dl_total - dl_current;

        // Apply only the change since the last callback, so this transfer's
        // net contribution to the shared counter is always `remaining`.
        self.global_stats
            .dl_remaining
            .fetch_add(remaining - self.dl_remaining_delta, Ordering::Release);
        self.dl_remaining_delta = remaining;
        true
    }
}
489
490impl Drop for Collector {
491    fn drop(&mut self) {
492        // Zero out this transfer's contribution to the global dl_remaining.
493        self.global_stats
494            .dl_remaining
495            .fetch_add(-self.dl_remaining_delta, Ordering::Release);
496    }
497}
498
/// Additional fields on an [`http::Response`].
///
/// Inserted into the response's extension map when a transfer completes and
/// read back through [`ResponsePartsExtensions`].
#[derive(Clone)]
struct Extensions {
    // Populated from curl's `primary_ip()`, when available.
    client_ip: Option<String>,
    // Populated from curl's `effective_url()`, when available.
    effective_url: Option<String>,
}
505
/// Read access to the extra curl-provided metadata attached to responses.
pub trait ResponsePartsExtensions {
    /// IP address curl reported for the connection, if any.
    fn client_ip(&self) -> Option<&str>;
    /// Final URL curl used for the transfer, if reported.
    fn effective_url(&self) -> Option<&str>;
}
510
511impl ResponsePartsExtensions for http::response::Parts {
512    fn client_ip(&self) -> Option<&str> {
513        self.extensions
514            .get::<Extensions>()
515            .and_then(|extensions| extensions.client_ip.as_deref())
516    }
517
518    fn effective_url(&self) -> Option<&str> {
519        self.extensions
520            .get::<Extensions>()
521            .and_then(|extensions| extensions.effective_url.as_deref())
522    }
523}
524
525impl ResponsePartsExtensions for Response {
526    fn client_ip(&self) -> Option<&str> {
527        self.extensions()
528            .get::<Extensions>()
529            .and_then(|extensions| extensions.client_ip.as_deref())
530    }
531
532    fn effective_url(&self) -> Option<&str> {
533        self.extensions()
534            .get::<Extensions>()
535            .and_then(|extensions| extensions.effective_url.as_deref())
536    }
537}
538
/// Splits an HTTP `HEADER: VALUE` line into a `(name, value)` tuple.
///
/// Returns `None` for empty input, non-UTF-8 bytes, lines containing an
/// embedded newline, or lines without a `:` separator. Trailing whitespace
/// (e.g. `\r\n`) is stripped, and the value is trimmed on both sides.
fn handle_http_header(buf: &[u8]) -> Option<(&str, &str)> {
    if buf.is_empty() {
        return None;
    }
    let line = std::str::from_utf8(buf).ok()?.trim_end();
    // Don't let a server sneak extra header lines through one callback.
    if line.contains('\n') {
        return None;
    }
    let (name, value) = line.split_once(':')?;
    Some((name, value.trim()))
}