Skip to main content

saluki_io/net/client/http/
conn.rs

1use std::{
2    future::Future,
3    io,
4    path::PathBuf,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use hickory_resolver::net::NetError;
12use http::{Extensions, Uri};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
14use hyper_util::{
15    client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
16    rt::TokioIo,
17};
18use metrics::Counter;
19use pin_project_lite::pin_project;
20use rustls::ClientConfig;
21use saluki_error::{ErrorContext as _, GenericError};
22use tokio::net::TcpStream;
23use tower::{BoxError, Service};
24use tracing::debug;
25
26use super::telemetry::HttpTransactionErrorTelemetry;
27use crate::net::dns::{HickoryHttpConnector, HickoryResolver};
28
/// Tracks how old a connection is relative to a configured maximum age.
///
/// Keeping a connection around forever, even when it could in principle keep being reused, is often undesirable: it
/// makes infrastructure maintenance harder, since the expectation that old connections are eventually closed and
/// replaced no longer holds.
///
/// An instance is created when the connection's metadata is produced (`Connection::connected`), which stamps the
/// starting point of the age calculation. Callers can then ask whether the limit has been reached and decide how to
/// react, for example by closing the connection.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Maximum age before the connection counts as expired.
    limit: Duration,
    /// The instant this tracker was created.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Starts tracking from the current instant, expiring once `limit` has elapsed.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once at least `limit` has elapsed since this tracker was created.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
56
57/// An inner transport that abstracts over TCP and Unix domain socket connections.
58///
59/// This allows using a single monomorphization of the HTTP/2 and TLS stacks regardless of the
60/// underlying transport, avoiding duplicate code generation for each transport type.
61enum Transport {
62    Tcp(TokioIo<TcpStream>),
63    #[cfg(unix)]
64    Unix(TokioIo<tokio::net::UnixStream>),
65}
66
67impl Connection for Transport {
68    fn connected(&self) -> Connected {
69        match self {
70            Self::Tcp(s) => s.connected(),
71            #[cfg(unix)]
72            Self::Unix(_) => Connected::new(),
73        }
74    }
75}
76
77impl hyper::rt::Read for Transport {
78    fn poll_read(
79        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
80    ) -> Poll<io::Result<()>> {
81        match Pin::get_mut(self) {
82            Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
83            #[cfg(unix)]
84            Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
85        }
86    }
87}
88
89impl hyper::rt::Write for Transport {
90    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
91        match Pin::get_mut(self) {
92            Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
93            #[cfg(unix)]
94            Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
95        }
96    }
97
98    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99        match Pin::get_mut(self) {
100            Self::Tcp(s) => Pin::new(s).poll_flush(cx),
101            #[cfg(unix)]
102            Self::Unix(s) => Pin::new(s).poll_flush(cx),
103        }
104    }
105
106    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107        match Pin::get_mut(self) {
108            Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
109            #[cfg(unix)]
110            Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
111        }
112    }
113
114    fn is_write_vectored(&self) -> bool {
115        match self {
116            Self::Tcp(s) => s.is_write_vectored(),
117            #[cfg(unix)]
118            Self::Unix(s) => s.is_write_vectored(),
119        }
120    }
121
122    fn poll_write_vectored(
123        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
124    ) -> Poll<io::Result<usize>> {
125        match Pin::get_mut(self) {
126            Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
127            #[cfg(unix)]
128            Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
129        }
130    }
131}
132
133pin_project! {
134    /// A connection that supports both HTTP and HTTPS.
135    pub struct HttpsCapableConnection {
136        #[pin]
137        inner: MaybeHttpsStream<Transport>,
138        bytes_sent: Option<Counter>,
139        error_telemetry: Option<HttpTransactionErrorTelemetry>,
140        conn_age_limit: Option<Duration>,
141    }
142}
143
144impl Connection for HttpsCapableConnection {
145    fn connected(&self) -> Connected {
146        let connected = self.inner.connected();
147
148        if let Some(conn_age_limit) = self.conn_age_limit {
149            debug!("setting connection age limit to {:?}", conn_age_limit);
150            connected.extra(ConnectionAgeLimit::new(conn_age_limit))
151        } else {
152            connected
153        }
154    }
155}
156
157impl hyper::rt::Read for HttpsCapableConnection {
158    fn poll_read(
159        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
160    ) -> Poll<io::Result<()>> {
161        let this = self.project();
162        this.inner.poll_read(cx, buf)
163    }
164}
165
166impl hyper::rt::Write for HttpsCapableConnection {
167    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
168        let this = self.project();
169        match this.inner.poll_write(cx, buf) {
170            Poll::Ready(Ok(n)) => {
171                if let Some(bytes_sent) = this.bytes_sent {
172                    bytes_sent.increment(n as u64);
173                }
174                Poll::Ready(Ok(n))
175            }
176            Poll::Ready(Err(error)) => {
177                if let Some(error_telemetry) = this.error_telemetry.as_ref() {
178                    error_telemetry.increment_wrote_request_error();
179                }
180                Poll::Ready(Err(error))
181            }
182            other => other,
183        }
184    }
185
186    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
187        let this = self.project();
188        match this.inner.poll_flush(cx) {
189            Poll::Ready(Err(error)) => {
190                if let Some(error_telemetry) = this.error_telemetry.as_ref() {
191                    error_telemetry.increment_wrote_request_error();
192                }
193                Poll::Ready(Err(error))
194            }
195            other => other,
196        }
197    }
198
199    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200        let this = self.project();
201        this.inner.poll_shutdown(cx)
202    }
203
204    fn is_write_vectored(&self) -> bool {
205        self.inner.is_write_vectored()
206    }
207
208    fn poll_write_vectored(
209        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
210    ) -> Poll<io::Result<usize>> {
211        let this = self.project();
212        match this.inner.poll_write_vectored(cx, bufs) {
213            Poll::Ready(Ok(n)) => {
214                if let Some(bytes_sent) = this.bytes_sent {
215                    bytes_sent.increment(n as u64);
216                }
217                Poll::Ready(Ok(n))
218            }
219            Poll::Ready(Err(error)) => {
220                if let Some(error_telemetry) = this.error_telemetry.as_ref() {
221                    error_telemetry.increment_wrote_request_error();
222                }
223                Poll::Ready(Err(error))
224            }
225            other => other,
226        }
227    }
228}
229
230/// An inner connector that routes to either TCP (via DNS) or a Unix domain socket.
231///
232/// When a Unix socket path is configured, all connections are routed through that socket regardless
233/// of the URI host. Otherwise, connections are routed via the standard DNS + TCP path.
234#[derive(Clone)]
235struct InnerConnector {
236    http: HickoryHttpConnector,
237    connect_timeout: Duration,
238    error_telemetry: Option<HttpTransactionErrorTelemetry>,
239    #[cfg(unix)]
240    unix_socket_path: Option<Arc<std::path::Path>>,
241}
242
243impl Service<Uri> for InnerConnector {
244    type Response = Transport;
245    type Error = BoxError;
246    type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
247
248    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
249        // When routing via a Unix domain socket, the TCP/DNS connector is not used, so we consider
250        // the service immediately ready.
251        #[cfg(unix)]
252        if self.unix_socket_path.is_some() {
253            return Poll::Ready(Ok(()));
254        }
255
256        self.http.poll_ready(cx).map_err(Into::into)
257    }
258
259    fn call(&mut self, dst: Uri) -> Self::Future {
260        #[cfg(unix)]
261        if let Some(path) = self.unix_socket_path.clone() {
262            let connect_timeout = self.connect_timeout;
263            let error_telemetry = self.error_telemetry.clone();
264            return Box::pin(async move {
265                let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
266                    .await
267                    .map_err(|_| -> BoxError {
268                        if let Some(error_telemetry) = &error_telemetry {
269                            error_telemetry.increment_connection_error();
270                        }
271                        Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
272                    })?
273                    .map_err(|e| -> BoxError {
274                        if let Some(error_telemetry) = &error_telemetry {
275                            error_telemetry.increment_connection_error();
276                        }
277                        Box::new(e)
278                    })?;
279                Ok(Transport::Unix(TokioIo::new(stream)))
280            });
281        }
282
283        let fut = self.http.call(dst);
284        let error_telemetry = self.error_telemetry.clone();
285        Box::pin(async move {
286            let tcp = fut.await.map_err(|error| {
287                if !is_dns_error(&error) {
288                    if let Some(error_telemetry) = &error_telemetry {
289                        error_telemetry.increment_connection_error();
290                    }
291                }
292                BoxError::from(error)
293            })?;
294            Ok(Transport::Tcp(tcp))
295        })
296    }
297}
298
/// Which HTTP protocol versions client connections may use.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum HttpProtocol {
    /// Negotiate HTTP/2 automatically, falling back to HTTP/1.1.
    #[default]
    Auto,

    /// Restrict connections to HTTP/1.1.
    Http1,
}
309
310/// A connector that supports HTTP or HTTPS.
311#[derive(Clone)]
312pub struct HttpsCapableConnector {
313    inner: HttpsConnector<InnerConnector>,
314    bytes_sent: Option<Counter>,
315    error_telemetry: Option<HttpTransactionErrorTelemetry>,
316    conn_age_limit: Option<Duration>,
317}
318
319impl Service<Uri> for HttpsCapableConnector {
320    type Response = HttpsCapableConnection;
321    type Error = BoxError;
322    type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
323
324    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
325        self.inner.poll_ready(cx)
326    }
327
328    fn call(&mut self, dst: Uri) -> Self::Future {
329        let inner = self.inner.call(dst);
330        let bytes_sent = self.bytes_sent.clone();
331        let error_telemetry = self.error_telemetry.clone();
332        let conn_age_limit = self.conn_age_limit;
333        Box::pin(async move {
334            match inner.await {
335                Ok(inner) => Ok(HttpsCapableConnection {
336                    inner,
337                    bytes_sent,
338                    error_telemetry,
339                    conn_age_limit,
340                }),
341                Err(error) => {
342                    if is_tls_error(error.as_ref()) {
343                        if let Some(error_telemetry) = &error_telemetry {
344                            error_telemetry.increment_tls_error();
345                        }
346                    }
347                    Err(error)
348                }
349            }
350        })
351    }
352}
353
354/// A builder for `HttpsCapableConnector`.
355#[derive(Default)]
356pub struct HttpsCapableConnectorBuilder {
357    connect_timeout: Option<Duration>,
358    bytes_sent: Option<Counter>,
359    error_telemetry: Option<HttpTransactionErrorTelemetry>,
360    conn_age_limit: Option<Duration>,
361    http_protocol: HttpProtocol,
362    #[cfg(unix)]
363    unix_socket_path: Option<PathBuf>,
364}
365
366impl HttpsCapableConnectorBuilder {
367    /// Sets the timeout when connecting to the remote host.
368    ///
369    /// Defaults to 30 seconds.
370    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
371        self.connect_timeout = Some(timeout);
372        self
373    }
374
375    /// Sets the HTTP protocol selection for client connections.
376    ///
377    /// Defaults to [`HttpProtocol::Auto`].
378    pub fn with_http_protocol(mut self, protocol: HttpProtocol) -> Self {
379        self.http_protocol = protocol;
380        self
381    }
382
383    /// Sets the maximum age of a connection before it's closed.
384    ///
385    /// This is distinct from the maximum idle time: if any connection's age exceeds `limit`, it will be closed rather
386    /// than being reused and added to the idle connection pool.
387    ///
388    /// Defaults to no limit.
389    pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
390    where
391        L: Into<Option<Duration>>,
392    {
393        self.conn_age_limit = limit.into();
394        self
395    }
396
397    /// Sets a counter that gets incremented with the number of bytes sent over the connection.
398    ///
399    /// This tracks bytes sent at the HTTP client level, which includes headers and body but doesn't include underlying
400    /// transport overhead, such as TLS handshaking, and so on.
401    ///
402    /// Defaults to unset.
403    pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
404        self.bytes_sent = Some(counter);
405        self
406    }
407
408    /// Sets the telemetry counters used to track HTTP request lifecycle failures.
409    pub(super) fn with_error_telemetry(mut self, error_telemetry: HttpTransactionErrorTelemetry) -> Self {
410        self.error_telemetry = Some(error_telemetry);
411        self
412    }
413
414    /// Sets a Unix domain socket path to route all connections through.
415    ///
416    /// When set, the connector will connect to this Unix socket instead of performing DNS resolution
417    /// and TCP connection. The URI host is ignored in this case—all requests are sent through the
418    /// configured socket.
419    ///
420    /// Defaults to unset (TCP connections via DNS).
421    #[cfg(unix)]
422    pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
423        self.unix_socket_path = Some(path.into());
424        self
425    }
426
427    /// Builds the `HttpsCapableConnector` from the given TLS configuration.
428    pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
429        let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
430
431        let mut hickory_resolver = HickoryResolver::from_system_conf()
432            .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
433        if let Some(error_telemetry) = &self.error_telemetry {
434            hickory_resolver = hickory_resolver.with_lookup_errors_counter(error_telemetry.dns_errors());
435        }
436
437        // Create the HTTP connector, and ensure that we don't enforce _only_ HTTP, since that will break being able to
438        // wrap this in an HTTPS connector.
439        let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
440        http_connector.set_connect_timeout(Some(connect_timeout));
441        http_connector.enforce_http(false);
442
443        let inner_connector = InnerConnector {
444            http: http_connector,
445            connect_timeout,
446            error_telemetry: self.error_telemetry.clone(),
447            #[cfg(unix)]
448            unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
449        };
450
451        // Create the HTTPS connector.
452        let https_connector_builder = HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http();
453        let https_connector = match self.http_protocol {
454            HttpProtocol::Auto => https_connector_builder
455                .enable_all_versions()
456                .wrap_connector(inner_connector),
457            HttpProtocol::Http1 => https_connector_builder.enable_http1().wrap_connector(inner_connector),
458        };
459
460        Ok(HttpsCapableConnector {
461            inner: https_connector,
462            bytes_sent: self.bytes_sent,
463            error_telemetry: self.error_telemetry,
464            conn_age_limit: self.conn_age_limit,
465        })
466    }
467}
468
#[cfg(test)]
fn configure_tls_alpn_for_http_protocol(mut tls_config: ClientConfig, protocol: HttpProtocol) -> ClientConfig {
    // `Auto` offers HTTP/2 first with HTTP/1.1 as the fallback. HTTP/1.1-only advertises no ALPN protocols at all.
    tls_config.alpn_protocols = match protocol {
        HttpProtocol::Http1 => Vec::new(),
        HttpProtocol::Auto => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
    };
    tls_config
}
482
483fn is_tls_error(error: &(dyn std::error::Error + 'static)) -> bool {
484    let mut current = Some(error);
485    while let Some(error) = current {
486        if error.downcast_ref::<rustls::Error>().is_some() {
487            return true;
488        }
489        current = error.source();
490    }
491    false
492}
493
494fn is_dns_error(error: &(dyn std::error::Error + 'static)) -> bool {
495    let mut current = Some(error);
496    while let Some(error) = current {
497        if error.downcast_ref::<NetError>().is_some() {
498            return true;
499        }
500        current = error.source();
501    }
502    false
503}
504
505pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
506    let maybe_conn_metadata = captured_conn.connection_metadata();
507    if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
508        let mut extensions = Extensions::new();
509        conn_metadata.get_extras(&mut extensions);
510
511        // If the connection has an age limit, check to see if the connection is expired (i.e. too old) and "poison"
512        // it if so. Poisoning indicates to `hyper` that the connection should be closed/dropped instead of
513        // returning it back to the idle connection pool.
514        if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
515            if conn_age_limit.is_expired() {
516                debug!("connection is expired; poisoning it");
517                conn_metadata.poison();
518            }
519        }
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::{configure_tls_alpn_for_http_protocol, HttpProtocol};

    /// Builds a minimal client config: AWS-LC provider, safe default protocol versions, an empty root store, and no
    /// client authentication.
    fn empty_tls_config() -> rustls::ClientConfig {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        rustls::ClientConfig::builder_with_provider(provider.into())
            .with_safe_default_protocol_versions()
            .expect("AWS-LC default protocol versions should be valid")
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth()
    }

    #[test]
    fn auto_protocol_advertises_h2_and_http1_alpn() {
        let expected = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        let configured = configure_tls_alpn_for_http_protocol(empty_tls_config(), HttpProtocol::Auto);

        assert_eq!(configured.alpn_protocols, expected);
    }

    #[test]
    fn http1_protocol_leaves_alpn_empty() {
        let configured = configure_tls_alpn_for_http_protocol(empty_tls_config(), HttpProtocol::Http1);

        assert!(configured.alpn_protocols.is_empty());
    }
}