saluki_io/net/client/http/
conn.rs1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6 time::{Duration, Instant},
7};
8
9use http::{Extensions, Uri};
10use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
11use hyper_util::{
12 client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
13 rt::TokioIo,
14};
15use metrics::Counter;
16use pin_project_lite::pin_project;
17use rustls::ClientConfig;
18use tokio::net::TcpStream;
19use tower::{BoxError, Service};
20use tracing::debug;
21
/// Maximum age for a connection, stored in the connection's metadata as an "extra".
///
/// The clock starts when the value is constructed. Once at least `limit` has elapsed since then,
/// [`is_expired`](Self::is_expired) reports `true`. Cloning preserves the original start time.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Maximum age before the connection counts as expired.
    limit: Duration,
    /// Moment from which the age is measured.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Creates an age limit of `limit` whose clock starts now.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once at least `limit` has elapsed since construction.
    ///
    /// A zero `limit` is therefore expired immediately.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
49
50pin_project! {
51 pub struct HttpsCapableConnection {
53 #[pin]
54 inner: MaybeHttpsStream<TokioIo<TcpStream>>,
55 bytes_sent: Option<Counter>,
56 conn_age_limit: Option<Duration>,
57 }
58}
59
60impl Connection for HttpsCapableConnection {
61 fn connected(&self) -> Connected {
62 let connected = self.inner.connected();
63
64 if let Some(conn_age_limit) = self.conn_age_limit {
65 debug!("setting connection age limit to {:?}", conn_age_limit);
66 connected.extra(ConnectionAgeLimit::new(conn_age_limit))
67 } else {
68 connected
69 }
70 }
71}
72
73impl hyper::rt::Read for HttpsCapableConnection {
74 fn poll_read(
75 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
76 ) -> Poll<io::Result<()>> {
77 let this = self.project();
78 this.inner.poll_read(cx, buf)
79 }
80}
81
82impl hyper::rt::Write for HttpsCapableConnection {
83 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
84 let this = self.project();
85 match this.inner.poll_write(cx, buf) {
86 Poll::Ready(Ok(n)) => {
87 if let Some(bytes_sent) = this.bytes_sent {
88 bytes_sent.increment(n as u64);
89 }
90 Poll::Ready(Ok(n))
91 }
92 other => other,
93 }
94 }
95
96 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97 let this = self.project();
98 this.inner.poll_flush(cx)
99 }
100
101 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
102 let this = self.project();
103 this.inner.poll_shutdown(cx)
104 }
105
106 fn is_write_vectored(&self) -> bool {
107 self.inner.is_write_vectored()
108 }
109
110 fn poll_write_vectored(
111 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
112 ) -> Poll<io::Result<usize>> {
113 let this = self.project();
114 match this.inner.poll_write_vectored(cx, bufs) {
115 Poll::Ready(Ok(n)) => {
116 if let Some(bytes_sent) = this.bytes_sent {
117 bytes_sent.increment(n as u64);
118 }
119 Poll::Ready(Ok(n))
120 }
121 other => other,
122 }
123 }
124}
125
126#[derive(Clone)]
128pub struct HttpsCapableConnector {
129 inner: HttpsConnector<HttpConnector>,
130 bytes_sent: Option<Counter>,
131 conn_age_limit: Option<Duration>,
132}
133
134impl Service<Uri> for HttpsCapableConnector {
135 type Response = HttpsCapableConnection;
136 type Error = BoxError;
137 type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
138
139 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 self.inner.poll_ready(cx)
141 }
142
143 fn call(&mut self, dst: Uri) -> Self::Future {
144 let inner = self.inner.call(dst);
145 let bytes_sent = self.bytes_sent.clone();
146 let conn_age_limit = self.conn_age_limit;
147 Box::pin(async move {
148 inner.await.map(|inner| HttpsCapableConnection {
149 inner,
150 bytes_sent,
151 conn_age_limit,
152 })
153 })
154 }
155}
156
157#[derive(Default)]
159pub struct HttpsCapableConnectorBuilder {
160 connect_timeout: Option<Duration>,
161 bytes_sent: Option<Counter>,
162 conn_age_limit: Option<Duration>,
163}
164
165impl HttpsCapableConnectorBuilder {
166 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
170 self.connect_timeout = Some(timeout);
171 self
172 }
173
174 pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
181 where
182 L: Into<Option<Duration>>,
183 {
184 self.conn_age_limit = limit.into();
185 self
186 }
187
188 pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
195 self.bytes_sent = Some(counter);
196 self
197 }
198
199 pub fn build(self, tls_config: ClientConfig) -> HttpsCapableConnector {
201 let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
202
203 let mut http_connector = HttpConnector::new();
206 http_connector.set_connect_timeout(Some(connect_timeout));
207 http_connector.enforce_http(false);
208
209 let https_connector = HttpsConnectorBuilder::new()
211 .with_tls_config(tls_config)
212 .https_or_http()
213 .enable_all_versions()
214 .wrap_connector(http_connector);
215
216 HttpsCapableConnector {
217 inner: https_connector,
218 bytes_sent: self.bytes_sent,
219 conn_age_limit: self.conn_age_limit,
220 }
221 }
222}
223
224pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
225 let maybe_conn_metadata = captured_conn.connection_metadata();
226 if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
227 let mut extensions = Extensions::new();
228 conn_metadata.get_extras(&mut extensions);
229
230 if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
234 if conn_age_limit.is_expired() {
235 debug!("connection is expired; poisoning it");
236 conn_metadata.poison();
237 }
238 }
239 }
240}