saluki_io/net/client/http/conn.rs

1use std::{
2    future::Future,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6    time::{Duration, Instant},
7};
8
9use http::{Extensions, Uri};
10use hyper_hickory::{TokioHickoryHttpConnector, TokioHickoryResolver};
11use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
12use hyper_util::{
13    client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
14    rt::TokioIo,
15};
16use metrics::Counter;
17use pin_project_lite::pin_project;
18use rustls::ClientConfig;
19use saluki_error::{ErrorContext as _, GenericError};
20use tokio::net::TcpStream;
21use tower::{BoxError, Service};
22use tracing::debug;
23
/// Tracks how long a connection has been alive and whether it has outlived its allowed age.
///
/// Holding on to a connection forever, even when it could technically keep being reused, is often undesirable: it
/// makes infrastructure maintenance harder, because nothing guarantees that old connections are eventually closed
/// and replaced.
///
/// An instance records the moment it was constructed as the connection's birth time. `is_expired` reports whether
/// the time elapsed since then has reached the configured limit. It is up to the caller to act on that, for example
/// by closing the connection.
#[derive(Clone)]
struct ConnectionAgeLimit {
    // Maximum age allowed before the connection counts as expired.
    limit: Duration,
    // Moment this tracker was constructed; treated as the connection's birth time.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Starts tracking age from now, with `limit` as the maximum allowed age.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once at least `limit` has elapsed since this tracker was created.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        self.limit <= age
    }
}
51
52pin_project! {
53    /// A connection that supports both HTTP and HTTPS.
54    pub struct HttpsCapableConnection {
55        #[pin]
56        inner: MaybeHttpsStream<TokioIo<TcpStream>>,
57        bytes_sent: Option<Counter>,
58        conn_age_limit: Option<Duration>,
59    }
60}
61
62impl Connection for HttpsCapableConnection {
63    fn connected(&self) -> Connected {
64        let connected = self.inner.connected();
65
66        if let Some(conn_age_limit) = self.conn_age_limit {
67            debug!("setting connection age limit to {:?}", conn_age_limit);
68            connected.extra(ConnectionAgeLimit::new(conn_age_limit))
69        } else {
70            connected
71        }
72    }
73}
74
75impl hyper::rt::Read for HttpsCapableConnection {
76    fn poll_read(
77        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
78    ) -> Poll<io::Result<()>> {
79        let this = self.project();
80        this.inner.poll_read(cx, buf)
81    }
82}
83
84impl hyper::rt::Write for HttpsCapableConnection {
85    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
86        let this = self.project();
87        match this.inner.poll_write(cx, buf) {
88            Poll::Ready(Ok(n)) => {
89                if let Some(bytes_sent) = this.bytes_sent {
90                    bytes_sent.increment(n as u64);
91                }
92                Poll::Ready(Ok(n))
93            }
94            other => other,
95        }
96    }
97
98    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99        let this = self.project();
100        this.inner.poll_flush(cx)
101    }
102
103    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
104        let this = self.project();
105        this.inner.poll_shutdown(cx)
106    }
107
108    fn is_write_vectored(&self) -> bool {
109        self.inner.is_write_vectored()
110    }
111
112    fn poll_write_vectored(
113        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
114    ) -> Poll<io::Result<usize>> {
115        let this = self.project();
116        match this.inner.poll_write_vectored(cx, bufs) {
117            Poll::Ready(Ok(n)) => {
118                if let Some(bytes_sent) = this.bytes_sent {
119                    bytes_sent.increment(n as u64);
120                }
121                Poll::Ready(Ok(n))
122            }
123            other => other,
124        }
125    }
126}
127
128/// A connector that supports HTTP or HTTPS.
129#[derive(Clone)]
130pub struct HttpsCapableConnector {
131    inner: HttpsConnector<TokioHickoryHttpConnector>,
132    bytes_sent: Option<Counter>,
133    conn_age_limit: Option<Duration>,
134}
135
136impl Service<Uri> for HttpsCapableConnector {
137    type Response = HttpsCapableConnection;
138    type Error = BoxError;
139    type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
140
141    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142        self.inner.poll_ready(cx)
143    }
144
145    fn call(&mut self, dst: Uri) -> Self::Future {
146        let inner = self.inner.call(dst);
147        let bytes_sent = self.bytes_sent.clone();
148        let conn_age_limit = self.conn_age_limit;
149        Box::pin(async move {
150            inner.await.map(|inner| HttpsCapableConnection {
151                inner,
152                bytes_sent,
153                conn_age_limit,
154            })
155        })
156    }
157}
158
159/// A builder for `HttpsCapableConnector`.
160#[derive(Default)]
161pub struct HttpsCapableConnectorBuilder {
162    connect_timeout: Option<Duration>,
163    bytes_sent: Option<Counter>,
164    conn_age_limit: Option<Duration>,
165}
166
167impl HttpsCapableConnectorBuilder {
168    /// Sets the timeout when connecting to the remote host.
169    ///
170    /// Defaults to 30 seconds.
171    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
172        self.connect_timeout = Some(timeout);
173        self
174    }
175
176    /// Sets the maximum age of a connection before it is closed.
177    ///
178    /// This is distinct from the maximum idle time: if any connection's age exceeds `limit`, it will be closed rather
179    /// than being reused and added to the idle connection pool.
180    ///
181    /// Defaults to no limit.
182    pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
183    where
184        L: Into<Option<Duration>>,
185    {
186        self.conn_age_limit = limit.into();
187        self
188    }
189
190    /// Sets a counter that gets incremented with the number of bytes sent over the connection.
191    ///
192    /// This tracks bytes sent at the HTTP client level, which includes headers and body but does not include underlying
193    /// transport overhead, such as TLS handshaking, and so on.
194    ///
195    /// Defaults to unset.
196    pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
197        self.bytes_sent = Some(counter);
198        self
199    }
200
201    /// Builds the `HttpsCapableConnector` from the given TLS configuration.
202    pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
203        let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
204
205        let hickory_resolver = TokioHickoryResolver::from_system_conf()
206            .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
207
208        // Create the HTTP connector, and ensure that we don't enforce _only_ HTTP, since that will break being able to
209        // wrap this in an HTTPS connector.
210        let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
211        http_connector.set_connect_timeout(Some(connect_timeout));
212        http_connector.enforce_http(false);
213
214        // Create the HTTPS connector.
215        let https_connector = HttpsConnectorBuilder::new()
216            .with_tls_config(tls_config)
217            .https_or_http()
218            .enable_all_versions()
219            .wrap_connector(http_connector);
220
221        Ok(HttpsCapableConnector {
222            inner: https_connector,
223            bytes_sent: self.bytes_sent,
224            conn_age_limit: self.conn_age_limit,
225        })
226    }
227}
228
229pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
230    let maybe_conn_metadata = captured_conn.connection_metadata();
231    if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
232        let mut extensions = Extensions::new();
233        conn_metadata.get_extras(&mut extensions);
234
235        // If the connection has an age limit, check to see if the connection is expired (i.e. too old) and "poison"
236        // it if so. Poisoning indicates to `hyper` that the connection should be closed/dropped instead of
237        // returning it back to the idle connection pool.
238        if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
239            if conn_age_limit.is_expired() {
240                debug!("connection is expired; poisoning it");
241                conn_metadata.poison();
242            }
243        }
244    }
245}