Skip to main content

saluki_io/net/client/http/
conn.rs

1use std::{
2    future::Future,
3    io,
4    path::PathBuf,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use http::{Extensions, Uri};
12use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
13use hyper_util::{
14    client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
15    rt::TokioIo,
16};
17use metrics::Counter;
18use pin_project_lite::pin_project;
19use rustls::ClientConfig;
20use saluki_error::{ErrorContext as _, GenericError};
21use tokio::net::TcpStream;
22use tower::{BoxError, Service};
23use tracing::debug;
24
25use crate::net::dns::{HickoryHttpConnector, HickoryResolver};
26
/// Records when a connection was established and the maximum age it is allowed to reach.
///
/// Reusing a connection forever is often undesirable even when it is technically possible: infrastructure maintenance
/// tends to rely on old connections eventually being closed and replaced, and a connection that never goes away breaks
/// that expectation.
///
/// The age is measured from the moment this value is constructed via [`ConnectionAgeLimit::new`]. Callers query
/// [`ConnectionAgeLimit::is_expired`] and decide for themselves what to do with an expired connection, such as closing
/// it.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Maximum age the connection may reach before it is considered expired.
    limit: Duration,
    /// Point in time at which this tracker was created.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Creates a tracker with the given age limit, starting the clock at the current instant.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once the time elapsed since creation is equal to or greater than the configured limit.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
54
55/// An inner transport that abstracts over TCP and Unix domain socket connections.
56///
57/// This allows using a single monomorphization of the HTTP/2 and TLS stacks regardless of the
58/// underlying transport, avoiding duplicate code generation for each transport type.
59enum Transport {
60    Tcp(TokioIo<TcpStream>),
61    #[cfg(unix)]
62    Unix(TokioIo<tokio::net::UnixStream>),
63}
64
65impl Connection for Transport {
66    fn connected(&self) -> Connected {
67        match self {
68            Self::Tcp(s) => s.connected(),
69            #[cfg(unix)]
70            Self::Unix(_) => Connected::new(),
71        }
72    }
73}
74
75impl hyper::rt::Read for Transport {
76    fn poll_read(
77        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
78    ) -> Poll<io::Result<()>> {
79        match Pin::get_mut(self) {
80            Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
81            #[cfg(unix)]
82            Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
83        }
84    }
85}
86
87impl hyper::rt::Write for Transport {
88    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
89        match Pin::get_mut(self) {
90            Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
91            #[cfg(unix)]
92            Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
93        }
94    }
95
96    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97        match Pin::get_mut(self) {
98            Self::Tcp(s) => Pin::new(s).poll_flush(cx),
99            #[cfg(unix)]
100            Self::Unix(s) => Pin::new(s).poll_flush(cx),
101        }
102    }
103
104    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
105        match Pin::get_mut(self) {
106            Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
107            #[cfg(unix)]
108            Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
109        }
110    }
111
112    fn is_write_vectored(&self) -> bool {
113        match self {
114            Self::Tcp(s) => s.is_write_vectored(),
115            #[cfg(unix)]
116            Self::Unix(s) => s.is_write_vectored(),
117        }
118    }
119
120    fn poll_write_vectored(
121        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
122    ) -> Poll<io::Result<usize>> {
123        match Pin::get_mut(self) {
124            Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
125            #[cfg(unix)]
126            Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
127        }
128    }
129}
130
131pin_project! {
132    /// A connection that supports both HTTP and HTTPS.
133    pub struct HttpsCapableConnection {
134        #[pin]
135        inner: MaybeHttpsStream<Transport>,
136        bytes_sent: Option<Counter>,
137        conn_age_limit: Option<Duration>,
138    }
139}
140
141impl Connection for HttpsCapableConnection {
142    fn connected(&self) -> Connected {
143        let connected = self.inner.connected();
144
145        if let Some(conn_age_limit) = self.conn_age_limit {
146            debug!("setting connection age limit to {:?}", conn_age_limit);
147            connected.extra(ConnectionAgeLimit::new(conn_age_limit))
148        } else {
149            connected
150        }
151    }
152}
153
154impl hyper::rt::Read for HttpsCapableConnection {
155    fn poll_read(
156        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
157    ) -> Poll<io::Result<()>> {
158        let this = self.project();
159        this.inner.poll_read(cx, buf)
160    }
161}
162
163impl hyper::rt::Write for HttpsCapableConnection {
164    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
165        let this = self.project();
166        match this.inner.poll_write(cx, buf) {
167            Poll::Ready(Ok(n)) => {
168                if let Some(bytes_sent) = this.bytes_sent {
169                    bytes_sent.increment(n as u64);
170                }
171                Poll::Ready(Ok(n))
172            }
173            other => other,
174        }
175    }
176
177    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178        let this = self.project();
179        this.inner.poll_flush(cx)
180    }
181
182    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
183        let this = self.project();
184        this.inner.poll_shutdown(cx)
185    }
186
187    fn is_write_vectored(&self) -> bool {
188        self.inner.is_write_vectored()
189    }
190
191    fn poll_write_vectored(
192        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
193    ) -> Poll<io::Result<usize>> {
194        let this = self.project();
195        match this.inner.poll_write_vectored(cx, bufs) {
196            Poll::Ready(Ok(n)) => {
197                if let Some(bytes_sent) = this.bytes_sent {
198                    bytes_sent.increment(n as u64);
199                }
200                Poll::Ready(Ok(n))
201            }
202            other => other,
203        }
204    }
205}
206
207/// An inner connector that routes to either TCP (via DNS) or a Unix domain socket.
208///
209/// When a Unix socket path is configured, all connections are routed through that socket regardless
210/// of the URI host. Otherwise, connections are routed via the standard DNS + TCP path.
211#[derive(Clone)]
212struct InnerConnector {
213    http: HickoryHttpConnector,
214    connect_timeout: Duration,
215    #[cfg(unix)]
216    unix_socket_path: Option<Arc<std::path::Path>>,
217}
218
219impl Service<Uri> for InnerConnector {
220    type Response = Transport;
221    type Error = BoxError;
222    type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
223
224    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225        // When routing via a Unix domain socket, the TCP/DNS connector is not used, so we consider
226        // the service immediately ready.
227        #[cfg(unix)]
228        if self.unix_socket_path.is_some() {
229            return Poll::Ready(Ok(()));
230        }
231
232        self.http.poll_ready(cx).map_err(Into::into)
233    }
234
235    fn call(&mut self, dst: Uri) -> Self::Future {
236        #[cfg(unix)]
237        if let Some(path) = self.unix_socket_path.clone() {
238            let connect_timeout = self.connect_timeout;
239            return Box::pin(async move {
240                let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
241                    .await
242                    .map_err(|_| -> BoxError {
243                        Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
244                    })?
245                    .map_err(|e| -> BoxError { Box::new(e) })?;
246                Ok(Transport::Unix(TokioIo::new(stream)))
247            });
248        }
249
250        let fut = self.http.call(dst);
251        Box::pin(async move {
252            let tcp = fut.await.map_err(BoxError::from)?;
253            Ok(Transport::Tcp(tcp))
254        })
255    }
256}
257
258/// A connector that supports HTTP or HTTPS.
259#[derive(Clone)]
260pub struct HttpsCapableConnector {
261    inner: HttpsConnector<InnerConnector>,
262    bytes_sent: Option<Counter>,
263    conn_age_limit: Option<Duration>,
264}
265
266impl Service<Uri> for HttpsCapableConnector {
267    type Response = HttpsCapableConnection;
268    type Error = BoxError;
269    type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
270
271    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272        self.inner.poll_ready(cx)
273    }
274
275    fn call(&mut self, dst: Uri) -> Self::Future {
276        let inner = self.inner.call(dst);
277        let bytes_sent = self.bytes_sent.clone();
278        let conn_age_limit = self.conn_age_limit;
279        Box::pin(async move {
280            inner.await.map(|inner| HttpsCapableConnection {
281                inner,
282                bytes_sent,
283                conn_age_limit,
284            })
285        })
286    }
287}
288
289/// A builder for `HttpsCapableConnector`.
290#[derive(Default)]
291pub struct HttpsCapableConnectorBuilder {
292    connect_timeout: Option<Duration>,
293    bytes_sent: Option<Counter>,
294    conn_age_limit: Option<Duration>,
295    #[cfg(unix)]
296    unix_socket_path: Option<PathBuf>,
297}
298
299impl HttpsCapableConnectorBuilder {
300    /// Sets the timeout when connecting to the remote host.
301    ///
302    /// Defaults to 30 seconds.
303    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
304        self.connect_timeout = Some(timeout);
305        self
306    }
307
308    /// Sets the maximum age of a connection before it is closed.
309    ///
310    /// This is distinct from the maximum idle time: if any connection's age exceeds `limit`, it will be closed rather
311    /// than being reused and added to the idle connection pool.
312    ///
313    /// Defaults to no limit.
314    pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
315    where
316        L: Into<Option<Duration>>,
317    {
318        self.conn_age_limit = limit.into();
319        self
320    }
321
322    /// Sets a counter that gets incremented with the number of bytes sent over the connection.
323    ///
324    /// This tracks bytes sent at the HTTP client level, which includes headers and body but does not include underlying
325    /// transport overhead, such as TLS handshaking, and so on.
326    ///
327    /// Defaults to unset.
328    pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
329        self.bytes_sent = Some(counter);
330        self
331    }
332
333    /// Sets a Unix domain socket path to route all connections through.
334    ///
335    /// When set, the connector will connect to this Unix socket instead of performing DNS resolution
336    /// and TCP connection. The URI host is ignored in this case — all requests are sent through the
337    /// configured socket.
338    ///
339    /// Defaults to unset (TCP connections via DNS).
340    #[cfg(unix)]
341    pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
342        self.unix_socket_path = Some(path.into());
343        self
344    }
345
346    /// Builds the `HttpsCapableConnector` from the given TLS configuration.
347    pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
348        let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
349
350        let hickory_resolver = HickoryResolver::from_system_conf()
351            .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
352
353        // Create the HTTP connector, and ensure that we don't enforce _only_ HTTP, since that will break being able to
354        // wrap this in an HTTPS connector.
355        let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
356        http_connector.set_connect_timeout(Some(connect_timeout));
357        http_connector.enforce_http(false);
358
359        let inner_connector = InnerConnector {
360            http: http_connector,
361            connect_timeout,
362            #[cfg(unix)]
363            unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
364        };
365
366        // Create the HTTPS connector.
367        let https_connector = HttpsConnectorBuilder::new()
368            .with_tls_config(tls_config)
369            .https_or_http()
370            .enable_all_versions()
371            .wrap_connector(inner_connector);
372
373        Ok(HttpsCapableConnector {
374            inner: https_connector,
375            bytes_sent: self.bytes_sent,
376            conn_age_limit: self.conn_age_limit,
377        })
378    }
379}
380
381pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
382    let maybe_conn_metadata = captured_conn.connection_metadata();
383    if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
384        let mut extensions = Extensions::new();
385        conn_metadata.get_extras(&mut extensions);
386
387        // If the connection has an age limit, check to see if the connection is expired (i.e. too old) and "poison"
388        // it if so. Poisoning indicates to `hyper` that the connection should be closed/dropped instead of
389        // returning it back to the idle connection pool.
390        if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
391            if conn_age_limit.is_expired() {
392                debug!("connection is expired; poisoning it");
393                conn_metadata.poison();
394            }
395        }
396    }
397}