Skip to main content

saluki_io/net/client/http/
conn.rs

1use std::{
2    future::Future,
3    io,
4    path::PathBuf,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use http::{Extensions, Uri};
12use hyper_hickory::{TokioHickoryHttpConnector, TokioHickoryResolver};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
14use hyper_util::{
15    client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
16    rt::TokioIo,
17};
18use metrics::Counter;
19use pin_project_lite::pin_project;
20use rustls::ClientConfig;
21use saluki_error::{ErrorContext as _, GenericError};
22use tokio::net::TcpStream;
23use tower::{BoxError, Service};
24use tracing::debug;
25
/// Tracks how long a connection has existed and whether it has outlived a configured maximum age.
///
/// Holding onto a connection indefinitely, even when it could be reused, makes infrastructure maintenance harder:
/// operators generally rely on old connections eventually being closed and replaced.
///
/// The age clock starts when a value is constructed via [`ConnectionAgeLimit::new`]. [`ConnectionAgeLimit::is_expired`]
/// only reports whether the limit has been reached; acting on it (for example, by closing the connection) is left to
/// the caller.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Maximum age before the connection is considered expired.
    limit: Duration,
    /// The moment this tracker was constructed.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Starts tracking age from "now", with `limit` as the maximum allowed age.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once the time elapsed since construction is at least `limit`.
    ///
    /// Because the comparison is inclusive, a zero `limit` is expired immediately.
    fn is_expired(&self) -> bool {
        let age = Instant::now().duration_since(self.created);
        age >= self.limit
    }
}
53
54/// An inner transport that abstracts over TCP and Unix domain socket connections.
55///
56/// This allows using a single monomorphization of the HTTP/2 and TLS stacks regardless of the
57/// underlying transport, avoiding duplicate code generation for each transport type.
58enum Transport {
59    Tcp(TokioIo<TcpStream>),
60    #[cfg(unix)]
61    Unix(TokioIo<tokio::net::UnixStream>),
62}
63
64impl Connection for Transport {
65    fn connected(&self) -> Connected {
66        match self {
67            Self::Tcp(s) => s.connected(),
68            #[cfg(unix)]
69            Self::Unix(_) => Connected::new(),
70        }
71    }
72}
73
74impl hyper::rt::Read for Transport {
75    fn poll_read(
76        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
77    ) -> Poll<io::Result<()>> {
78        match Pin::get_mut(self) {
79            Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
80            #[cfg(unix)]
81            Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
82        }
83    }
84}
85
86impl hyper::rt::Write for Transport {
87    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
88        match Pin::get_mut(self) {
89            Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
90            #[cfg(unix)]
91            Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
92        }
93    }
94
95    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96        match Pin::get_mut(self) {
97            Self::Tcp(s) => Pin::new(s).poll_flush(cx),
98            #[cfg(unix)]
99            Self::Unix(s) => Pin::new(s).poll_flush(cx),
100        }
101    }
102
103    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
104        match Pin::get_mut(self) {
105            Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
106            #[cfg(unix)]
107            Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
108        }
109    }
110
111    fn is_write_vectored(&self) -> bool {
112        match self {
113            Self::Tcp(s) => s.is_write_vectored(),
114            #[cfg(unix)]
115            Self::Unix(s) => s.is_write_vectored(),
116        }
117    }
118
119    fn poll_write_vectored(
120        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
121    ) -> Poll<io::Result<usize>> {
122        match Pin::get_mut(self) {
123            Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
124            #[cfg(unix)]
125            Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
126        }
127    }
128}
129
130pin_project! {
131    /// A connection that supports both HTTP and HTTPS.
132    pub struct HttpsCapableConnection {
133        #[pin]
134        inner: MaybeHttpsStream<Transport>,
135        bytes_sent: Option<Counter>,
136        conn_age_limit: Option<Duration>,
137    }
138}
139
140impl Connection for HttpsCapableConnection {
141    fn connected(&self) -> Connected {
142        let connected = self.inner.connected();
143
144        if let Some(conn_age_limit) = self.conn_age_limit {
145            debug!("setting connection age limit to {:?}", conn_age_limit);
146            connected.extra(ConnectionAgeLimit::new(conn_age_limit))
147        } else {
148            connected
149        }
150    }
151}
152
153impl hyper::rt::Read for HttpsCapableConnection {
154    fn poll_read(
155        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
156    ) -> Poll<io::Result<()>> {
157        let this = self.project();
158        this.inner.poll_read(cx, buf)
159    }
160}
161
162impl hyper::rt::Write for HttpsCapableConnection {
163    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
164        let this = self.project();
165        match this.inner.poll_write(cx, buf) {
166            Poll::Ready(Ok(n)) => {
167                if let Some(bytes_sent) = this.bytes_sent {
168                    bytes_sent.increment(n as u64);
169                }
170                Poll::Ready(Ok(n))
171            }
172            other => other,
173        }
174    }
175
176    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
177        let this = self.project();
178        this.inner.poll_flush(cx)
179    }
180
181    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182        let this = self.project();
183        this.inner.poll_shutdown(cx)
184    }
185
186    fn is_write_vectored(&self) -> bool {
187        self.inner.is_write_vectored()
188    }
189
190    fn poll_write_vectored(
191        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
192    ) -> Poll<io::Result<usize>> {
193        let this = self.project();
194        match this.inner.poll_write_vectored(cx, bufs) {
195            Poll::Ready(Ok(n)) => {
196                if let Some(bytes_sent) = this.bytes_sent {
197                    bytes_sent.increment(n as u64);
198                }
199                Poll::Ready(Ok(n))
200            }
201            other => other,
202        }
203    }
204}
205
206/// An inner connector that routes to either TCP (via DNS) or a Unix domain socket.
207///
208/// When a Unix socket path is configured, all connections are routed through that socket regardless
209/// of the URI host. Otherwise, connections are routed via the standard DNS + TCP path.
210#[derive(Clone)]
211struct InnerConnector {
212    http: TokioHickoryHttpConnector,
213    connect_timeout: Duration,
214    #[cfg(unix)]
215    unix_socket_path: Option<Arc<std::path::Path>>,
216}
217
218impl Service<Uri> for InnerConnector {
219    type Response = Transport;
220    type Error = BoxError;
221    type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
222
223    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
224        // When routing via a Unix domain socket, the TCP/DNS connector is not used, so we consider
225        // the service immediately ready.
226        #[cfg(unix)]
227        if self.unix_socket_path.is_some() {
228            return Poll::Ready(Ok(()));
229        }
230
231        self.http.poll_ready(cx).map_err(Into::into)
232    }
233
234    fn call(&mut self, dst: Uri) -> Self::Future {
235        #[cfg(unix)]
236        if let Some(path) = self.unix_socket_path.clone() {
237            let connect_timeout = self.connect_timeout;
238            return Box::pin(async move {
239                let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
240                    .await
241                    .map_err(|_| -> BoxError {
242                        Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
243                    })?
244                    .map_err(|e| -> BoxError { Box::new(e) })?;
245                Ok(Transport::Unix(TokioIo::new(stream)))
246            });
247        }
248
249        let fut = self.http.call(dst);
250        Box::pin(async move {
251            let tcp = fut.await.map_err(BoxError::from)?;
252            Ok(Transport::Tcp(tcp))
253        })
254    }
255}
256
257/// A connector that supports HTTP or HTTPS.
258#[derive(Clone)]
259pub struct HttpsCapableConnector {
260    inner: HttpsConnector<InnerConnector>,
261    bytes_sent: Option<Counter>,
262    conn_age_limit: Option<Duration>,
263}
264
265impl Service<Uri> for HttpsCapableConnector {
266    type Response = HttpsCapableConnection;
267    type Error = BoxError;
268    type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
269
270    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
271        self.inner.poll_ready(cx)
272    }
273
274    fn call(&mut self, dst: Uri) -> Self::Future {
275        let inner = self.inner.call(dst);
276        let bytes_sent = self.bytes_sent.clone();
277        let conn_age_limit = self.conn_age_limit;
278        Box::pin(async move {
279            inner.await.map(|inner| HttpsCapableConnection {
280                inner,
281                bytes_sent,
282                conn_age_limit,
283            })
284        })
285    }
286}
287
288/// A builder for `HttpsCapableConnector`.
289#[derive(Default)]
290pub struct HttpsCapableConnectorBuilder {
291    connect_timeout: Option<Duration>,
292    bytes_sent: Option<Counter>,
293    conn_age_limit: Option<Duration>,
294    #[cfg(unix)]
295    unix_socket_path: Option<PathBuf>,
296}
297
298impl HttpsCapableConnectorBuilder {
299    /// Sets the timeout when connecting to the remote host.
300    ///
301    /// Defaults to 30 seconds.
302    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
303        self.connect_timeout = Some(timeout);
304        self
305    }
306
307    /// Sets the maximum age of a connection before it is closed.
308    ///
309    /// This is distinct from the maximum idle time: if any connection's age exceeds `limit`, it will be closed rather
310    /// than being reused and added to the idle connection pool.
311    ///
312    /// Defaults to no limit.
313    pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
314    where
315        L: Into<Option<Duration>>,
316    {
317        self.conn_age_limit = limit.into();
318        self
319    }
320
321    /// Sets a counter that gets incremented with the number of bytes sent over the connection.
322    ///
323    /// This tracks bytes sent at the HTTP client level, which includes headers and body but does not include underlying
324    /// transport overhead, such as TLS handshaking, and so on.
325    ///
326    /// Defaults to unset.
327    pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
328        self.bytes_sent = Some(counter);
329        self
330    }
331
332    /// Sets a Unix domain socket path to route all connections through.
333    ///
334    /// When set, the connector will connect to this Unix socket instead of performing DNS resolution
335    /// and TCP connection. The URI host is ignored in this case — all requests are sent through the
336    /// configured socket.
337    ///
338    /// Defaults to unset (TCP connections via DNS).
339    #[cfg(unix)]
340    pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
341        self.unix_socket_path = Some(path.into());
342        self
343    }
344
345    /// Builds the `HttpsCapableConnector` from the given TLS configuration.
346    pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
347        let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
348
349        let hickory_resolver = TokioHickoryResolver::from_system_conf()
350            .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
351
352        // Create the HTTP connector, and ensure that we don't enforce _only_ HTTP, since that will break being able to
353        // wrap this in an HTTPS connector.
354        let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
355        http_connector.set_connect_timeout(Some(connect_timeout));
356        http_connector.enforce_http(false);
357
358        let inner_connector = InnerConnector {
359            http: http_connector,
360            connect_timeout,
361            #[cfg(unix)]
362            unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
363        };
364
365        // Create the HTTPS connector.
366        let https_connector = HttpsConnectorBuilder::new()
367            .with_tls_config(tls_config)
368            .https_or_http()
369            .enable_all_versions()
370            .wrap_connector(inner_connector);
371
372        Ok(HttpsCapableConnector {
373            inner: https_connector,
374            bytes_sent: self.bytes_sent,
375            conn_age_limit: self.conn_age_limit,
376        })
377    }
378}
379
380pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
381    let maybe_conn_metadata = captured_conn.connection_metadata();
382    if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
383        let mut extensions = Extensions::new();
384        conn_metadata.get_extras(&mut extensions);
385
386        // If the connection has an age limit, check to see if the connection is expired (i.e. too old) and "poison"
387        // it if so. Poisoning indicates to `hyper` that the connection should be closed/dropped instead of
388        // returning it back to the idle connection pool.
389        if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
390            if conn_age_limit.is_expired() {
391                debug!("connection is expired; poisoning it");
392                conn_metadata.poison();
393            }
394        }
395    }
396}