saluki_io/net/client/http/
conn.rs1use std::{
2 future::Future,
3 io,
4 path::PathBuf,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use http::{Extensions, Uri};
12use hyper_hickory::{TokioHickoryHttpConnector, TokioHickoryResolver};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
14use hyper_util::{
15 client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
16 rt::TokioIo,
17};
18use metrics::Counter;
19use pin_project_lite::pin_project;
20use rustls::ClientConfig;
21use saluki_error::{ErrorContext as _, GenericError};
22use tokio::net::TcpStream;
23use tower::{BoxError, Service};
24use tracing::debug;
25
/// Per-connection metadata recording when a connection was stamped and how long it may live.
///
/// Attached to a connection's `Connected` info as an "extra" so it can later be read back out of captured connection
/// metadata.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Moment the limit was stamped onto the connection.
    created: Instant,
    /// Maximum age before the connection is considered expired.
    limit: Duration,
}
40
41impl ConnectionAgeLimit {
42 fn new(limit: Duration) -> Self {
43 ConnectionAgeLimit {
44 limit,
45 created: Instant::now(),
46 }
47 }
48
49 fn is_expired(&self) -> bool {
50 self.created.elapsed() >= self.limit
51 }
52}
53
54enum Transport {
59 Tcp(TokioIo<TcpStream>),
60 #[cfg(unix)]
61 Unix(TokioIo<tokio::net::UnixStream>),
62}
63
64impl Connection for Transport {
65 fn connected(&self) -> Connected {
66 match self {
67 Self::Tcp(s) => s.connected(),
68 #[cfg(unix)]
69 Self::Unix(_) => Connected::new(),
70 }
71 }
72}
73
74impl hyper::rt::Read for Transport {
75 fn poll_read(
76 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
77 ) -> Poll<io::Result<()>> {
78 match Pin::get_mut(self) {
79 Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
80 #[cfg(unix)]
81 Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
82 }
83 }
84}
85
86impl hyper::rt::Write for Transport {
87 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
88 match Pin::get_mut(self) {
89 Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
90 #[cfg(unix)]
91 Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
92 }
93 }
94
95 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96 match Pin::get_mut(self) {
97 Self::Tcp(s) => Pin::new(s).poll_flush(cx),
98 #[cfg(unix)]
99 Self::Unix(s) => Pin::new(s).poll_flush(cx),
100 }
101 }
102
103 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
104 match Pin::get_mut(self) {
105 Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
106 #[cfg(unix)]
107 Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
108 }
109 }
110
111 fn is_write_vectored(&self) -> bool {
112 match self {
113 Self::Tcp(s) => s.is_write_vectored(),
114 #[cfg(unix)]
115 Self::Unix(s) => s.is_write_vectored(),
116 }
117 }
118
119 fn poll_write_vectored(
120 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
121 ) -> Poll<io::Result<usize>> {
122 match Pin::get_mut(self) {
123 Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
124 #[cfg(unix)]
125 Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
126 }
127 }
128}
129
130pin_project! {
131 pub struct HttpsCapableConnection {
133 #[pin]
134 inner: MaybeHttpsStream<Transport>,
135 bytes_sent: Option<Counter>,
136 conn_age_limit: Option<Duration>,
137 }
138}
139
140impl Connection for HttpsCapableConnection {
141 fn connected(&self) -> Connected {
142 let connected = self.inner.connected();
143
144 if let Some(conn_age_limit) = self.conn_age_limit {
145 debug!("setting connection age limit to {:?}", conn_age_limit);
146 connected.extra(ConnectionAgeLimit::new(conn_age_limit))
147 } else {
148 connected
149 }
150 }
151}
152
153impl hyper::rt::Read for HttpsCapableConnection {
154 fn poll_read(
155 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
156 ) -> Poll<io::Result<()>> {
157 let this = self.project();
158 this.inner.poll_read(cx, buf)
159 }
160}
161
162impl hyper::rt::Write for HttpsCapableConnection {
163 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
164 let this = self.project();
165 match this.inner.poll_write(cx, buf) {
166 Poll::Ready(Ok(n)) => {
167 if let Some(bytes_sent) = this.bytes_sent {
168 bytes_sent.increment(n as u64);
169 }
170 Poll::Ready(Ok(n))
171 }
172 other => other,
173 }
174 }
175
176 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
177 let this = self.project();
178 this.inner.poll_flush(cx)
179 }
180
181 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182 let this = self.project();
183 this.inner.poll_shutdown(cx)
184 }
185
186 fn is_write_vectored(&self) -> bool {
187 self.inner.is_write_vectored()
188 }
189
190 fn poll_write_vectored(
191 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
192 ) -> Poll<io::Result<usize>> {
193 let this = self.project();
194 match this.inner.poll_write_vectored(cx, bufs) {
195 Poll::Ready(Ok(n)) => {
196 if let Some(bytes_sent) = this.bytes_sent {
197 bytes_sent.increment(n as u64);
198 }
199 Poll::Ready(Ok(n))
200 }
201 other => other,
202 }
203 }
204}
205
206#[derive(Clone)]
211struct InnerConnector {
212 http: TokioHickoryHttpConnector,
213 connect_timeout: Duration,
214 #[cfg(unix)]
215 unix_socket_path: Option<Arc<std::path::Path>>,
216}
217
218impl Service<Uri> for InnerConnector {
219 type Response = Transport;
220 type Error = BoxError;
221 type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
222
223 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
224 #[cfg(unix)]
227 if self.unix_socket_path.is_some() {
228 return Poll::Ready(Ok(()));
229 }
230
231 self.http.poll_ready(cx).map_err(Into::into)
232 }
233
234 fn call(&mut self, dst: Uri) -> Self::Future {
235 #[cfg(unix)]
236 if let Some(path) = self.unix_socket_path.clone() {
237 let connect_timeout = self.connect_timeout;
238 return Box::pin(async move {
239 let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
240 .await
241 .map_err(|_| -> BoxError {
242 Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
243 })?
244 .map_err(|e| -> BoxError { Box::new(e) })?;
245 Ok(Transport::Unix(TokioIo::new(stream)))
246 });
247 }
248
249 let fut = self.http.call(dst);
250 Box::pin(async move {
251 let tcp = fut.await.map_err(BoxError::from)?;
252 Ok(Transport::Tcp(tcp))
253 })
254 }
255}
256
257#[derive(Clone)]
259pub struct HttpsCapableConnector {
260 inner: HttpsConnector<InnerConnector>,
261 bytes_sent: Option<Counter>,
262 conn_age_limit: Option<Duration>,
263}
264
265impl Service<Uri> for HttpsCapableConnector {
266 type Response = HttpsCapableConnection;
267 type Error = BoxError;
268 type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
269
270 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
271 self.inner.poll_ready(cx)
272 }
273
274 fn call(&mut self, dst: Uri) -> Self::Future {
275 let inner = self.inner.call(dst);
276 let bytes_sent = self.bytes_sent.clone();
277 let conn_age_limit = self.conn_age_limit;
278 Box::pin(async move {
279 inner.await.map(|inner| HttpsCapableConnection {
280 inner,
281 bytes_sent,
282 conn_age_limit,
283 })
284 })
285 }
286}
287
288#[derive(Default)]
290pub struct HttpsCapableConnectorBuilder {
291 connect_timeout: Option<Duration>,
292 bytes_sent: Option<Counter>,
293 conn_age_limit: Option<Duration>,
294 #[cfg(unix)]
295 unix_socket_path: Option<PathBuf>,
296}
297
298impl HttpsCapableConnectorBuilder {
299 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
303 self.connect_timeout = Some(timeout);
304 self
305 }
306
307 pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
314 where
315 L: Into<Option<Duration>>,
316 {
317 self.conn_age_limit = limit.into();
318 self
319 }
320
321 pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
328 self.bytes_sent = Some(counter);
329 self
330 }
331
332 #[cfg(unix)]
340 pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
341 self.unix_socket_path = Some(path.into());
342 self
343 }
344
345 pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
347 let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
348
349 let hickory_resolver = TokioHickoryResolver::from_system_conf()
350 .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
351
352 let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
355 http_connector.set_connect_timeout(Some(connect_timeout));
356 http_connector.enforce_http(false);
357
358 let inner_connector = InnerConnector {
359 http: http_connector,
360 connect_timeout,
361 #[cfg(unix)]
362 unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
363 };
364
365 let https_connector = HttpsConnectorBuilder::new()
367 .with_tls_config(tls_config)
368 .https_or_http()
369 .enable_all_versions()
370 .wrap_connector(inner_connector);
371
372 Ok(HttpsCapableConnector {
373 inner: https_connector,
374 bytes_sent: self.bytes_sent,
375 conn_age_limit: self.conn_age_limit,
376 })
377 }
378}
379
380pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
381 let maybe_conn_metadata = captured_conn.connection_metadata();
382 if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
383 let mut extensions = Extensions::new();
384 conn_metadata.get_extras(&mut extensions);
385
386 if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
390 if conn_age_limit.is_expired() {
391 debug!("connection is expired; poisoning it");
392 conn_metadata.poison();
393 }
394 }
395 }
396}