saluki_io/net/client/http/
conn.rs

1use std::{
2    future::Future,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6    time::{Duration, Instant},
7};
8
9use http::{Extensions, Uri};
10use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
11use hyper_util::{
12    client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
13    rt::TokioIo,
14};
15use metrics::Counter;
16use pin_project_lite::pin_project;
17use rustls::ClientConfig;
18use tokio::net::TcpStream;
19use tower::{BoxError, Service};
20use tracing::debug;
21
/// Tracks how old a connection is relative to a maximum allowed age.
///
/// Holding onto a reusable connection forever is often undesirable: infrastructure maintenance generally relies on
/// old connections eventually being closed and replaced, and a connection that never goes away breaks that
/// expectation.
///
/// A value of this type records the instant at which it was constructed together with the configured limit, so that
/// callers can ask whether the limit has been reached and then decide how to react (for example, by closing the
/// connection).
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Maximum age that may be reached before the connection counts as expired.
    limit: Duration,

    /// Moment at which this tracker was constructed; age is measured from here.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Creates a tracker for the given `limit`, starting the clock at the time of the call.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once the time elapsed since construction is at least the configured limit.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
49
pin_project! {
    /// A connection that supports both HTTP and HTTPS.
    ///
    /// Wraps the underlying stream and carries two optional per-connection settings alongside it: a counter that is
    /// incremented with the number of bytes written, and a maximum connection age that is attached to the
    /// connection metadata.
    pub struct HttpsCapableConnection {
        // The wrapped stream. Pinned so that reads and writes can be forwarded to it through a pinned projection.
        #[pin]
        inner: MaybeHttpsStream<TokioIo<TcpStream>>,
        // When set, incremented by the byte count of every successful write.
        bytes_sent: Option<Counter>,
        // When set, the age limit reported through the connection metadata.
        conn_age_limit: Option<Duration>,
    }
}
59
60impl Connection for HttpsCapableConnection {
61    fn connected(&self) -> Connected {
62        let connected = self.inner.connected();
63
64        if let Some(conn_age_limit) = self.conn_age_limit {
65            debug!("setting connection age limit to {:?}", conn_age_limit);
66            connected.extra(ConnectionAgeLimit::new(conn_age_limit))
67        } else {
68            connected
69        }
70    }
71}
72
73impl hyper::rt::Read for HttpsCapableConnection {
74    fn poll_read(
75        self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
76    ) -> Poll<io::Result<()>> {
77        let this = self.project();
78        this.inner.poll_read(cx, buf)
79    }
80}
81
82impl hyper::rt::Write for HttpsCapableConnection {
83    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
84        let this = self.project();
85        match this.inner.poll_write(cx, buf) {
86            Poll::Ready(Ok(n)) => {
87                if let Some(bytes_sent) = this.bytes_sent {
88                    bytes_sent.increment(n as u64);
89                }
90                Poll::Ready(Ok(n))
91            }
92            other => other,
93        }
94    }
95
96    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97        let this = self.project();
98        this.inner.poll_flush(cx)
99    }
100
101    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
102        let this = self.project();
103        this.inner.poll_shutdown(cx)
104    }
105
106    fn is_write_vectored(&self) -> bool {
107        self.inner.is_write_vectored()
108    }
109
110    fn poll_write_vectored(
111        self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
112    ) -> Poll<io::Result<usize>> {
113        let this = self.project();
114        match this.inner.poll_write_vectored(cx, bufs) {
115            Poll::Ready(Ok(n)) => {
116                if let Some(bytes_sent) = this.bytes_sent {
117                    bytes_sent.increment(n as u64);
118                }
119                Poll::Ready(Ok(n))
120            }
121            other => other,
122        }
123    }
124}
125
/// A connector that supports HTTP or HTTPS.
///
/// Every connection it establishes is wrapped in an [`HttpsCapableConnection`] that is handed this connector's
/// bytes-sent counter and connection age limit. Use [`HttpsCapableConnectorBuilder`] to construct one.
#[derive(Clone)]
pub struct HttpsCapableConnector {
    /// Underlying connector that performs the actual connection establishment.
    inner: HttpsConnector<HttpConnector>,
    /// Optional counter handed to every connection created by this connector.
    bytes_sent: Option<Counter>,
    /// Optional maximum age handed to every connection created by this connector.
    conn_age_limit: Option<Duration>,
}
133
134impl Service<Uri> for HttpsCapableConnector {
135    type Response = HttpsCapableConnection;
136    type Error = BoxError;
137    type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
138
139    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140        self.inner.poll_ready(cx)
141    }
142
143    fn call(&mut self, dst: Uri) -> Self::Future {
144        let inner = self.inner.call(dst);
145        let bytes_sent = self.bytes_sent.clone();
146        let conn_age_limit = self.conn_age_limit;
147        Box::pin(async move {
148            inner.await.map(|inner| HttpsCapableConnection {
149                inner,
150                bytes_sent,
151                conn_age_limit,
152            })
153        })
154    }
155}
156
/// A builder for `HttpsCapableConnector`.
///
/// All settings are optional; a default-constructed builder has none of them set.
#[derive(Default)]
pub struct HttpsCapableConnectorBuilder {
    /// Timeout applied when connecting to the remote host; a 30 second default is used at build time when unset.
    connect_timeout: Option<Duration>,
    /// Counter to increment with the number of bytes sent over each connection, if any.
    bytes_sent: Option<Counter>,
    /// Maximum age of a connection, if any.
    conn_age_limit: Option<Duration>,
}
164
165impl HttpsCapableConnectorBuilder {
166    /// Sets the timeout when connecting to the remote host.
167    ///
168    /// Defaults to 30 seconds.
169    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
170        self.connect_timeout = Some(timeout);
171        self
172    }
173
174    /// Sets the maximum age of a connection before it is closed.
175    ///
176    /// This is distinct from the maximum idle time: if any connection's age exceeds `limit`, it will be closed rather
177    /// than being reused and added to the idle connection pool.
178    ///
179    /// Defaults to no limit.
180    pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
181    where
182        L: Into<Option<Duration>>,
183    {
184        self.conn_age_limit = limit.into();
185        self
186    }
187
188    /// Sets a counter that gets incremented with the number of bytes sent over the connection.
189    ///
190    /// This tracks bytes sent at the HTTP client level, which includes headers and body but does not include underlying
191    /// transport overhead, such as TLS handshaking, and so on.
192    ///
193    /// Defaults to unset.
194    pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
195        self.bytes_sent = Some(counter);
196        self
197    }
198
199    /// Builds the `HttpsCapableConnector` from the given TLS configuration.
200    pub fn build(self, tls_config: ClientConfig) -> HttpsCapableConnector {
201        let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
202
203        // Create the HTTP connector, and ensure that we don't enforce _only_ HTTP, since that will break being able to
204        // wrap this in an HTTPS connector.
205        let mut http_connector = HttpConnector::new();
206        http_connector.set_connect_timeout(Some(connect_timeout));
207        http_connector.enforce_http(false);
208
209        // Create the HTTPS connector.
210        let https_connector = HttpsConnectorBuilder::new()
211            .with_tls_config(tls_config)
212            .https_or_http()
213            .enable_all_versions()
214            .wrap_connector(http_connector);
215
216        HttpsCapableConnector {
217            inner: https_connector,
218            bytes_sent: self.bytes_sent,
219            conn_age_limit: self.conn_age_limit,
220        }
221    }
222}
223
224pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
225    let maybe_conn_metadata = captured_conn.connection_metadata();
226    if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
227        let mut extensions = Extensions::new();
228        conn_metadata.get_extras(&mut extensions);
229
230        // If the connection has an age limit, check to see if the connection is expired (i.e. too old) and "poison"
231        // it if so. Poisoning indicates to `hyper` that the connection should be closed/dropped instead of
232        // returning it back to the idle connection pool.
233        if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
234            if conn_age_limit.is_expired() {
235                debug!("connection is expired; poisoning it");
236                conn_metadata.poison();
237            }
238        }
239    }
240}