// saluki_io/net/client/http/conn.rs
use std::{
    future::Future,
    io,
    path::PathBuf,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::{Duration, Instant},
};

use http::{Extensions, Uri};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
use hyper_util::{
    client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
    rt::TokioIo,
};
use metrics::Counter;
use pin_project_lite::pin_project;
use rustls::ClientConfig;
use saluki_error::{ErrorContext as _, GenericError};
use tokio::net::TcpStream;
use tower::{BoxError, Service};
use tracing::debug;

use crate::net::dns::{HickoryHttpConnector, HickoryResolver};

/// Maximum lifetime of a connection, measured from the moment this value is created.
///
/// Attached to a connection's metadata by [`HttpsCapableConnection`] and later read back by
/// [`check_connection_state`], which poisons the connection once the limit has elapsed.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// How long the connection is allowed to live.
    limit: Duration,
    /// Starting point from which the connection's age is measured.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Creates an age limit of `limit`, starting the clock now.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once at least `limit` has passed since this value was created.
    fn is_expired(&self) -> bool {
        let age = Instant::now().duration_since(self.created);
        age >= self.limit
    }
}

55enum Transport {
60 Tcp(TokioIo<TcpStream>),
61 #[cfg(unix)]
62 Unix(TokioIo<tokio::net::UnixStream>),
63}
64
65impl Connection for Transport {
66 fn connected(&self) -> Connected {
67 match self {
68 Self::Tcp(s) => s.connected(),
69 #[cfg(unix)]
70 Self::Unix(_) => Connected::new(),
71 }
72 }
73}
74
75impl hyper::rt::Read for Transport {
76 fn poll_read(
77 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
78 ) -> Poll<io::Result<()>> {
79 match Pin::get_mut(self) {
80 Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
81 #[cfg(unix)]
82 Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
83 }
84 }
85}
86
87impl hyper::rt::Write for Transport {
88 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
89 match Pin::get_mut(self) {
90 Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
91 #[cfg(unix)]
92 Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
93 }
94 }
95
96 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
97 match Pin::get_mut(self) {
98 Self::Tcp(s) => Pin::new(s).poll_flush(cx),
99 #[cfg(unix)]
100 Self::Unix(s) => Pin::new(s).poll_flush(cx),
101 }
102 }
103
104 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
105 match Pin::get_mut(self) {
106 Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
107 #[cfg(unix)]
108 Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
109 }
110 }
111
112 fn is_write_vectored(&self) -> bool {
113 match self {
114 Self::Tcp(s) => s.is_write_vectored(),
115 #[cfg(unix)]
116 Self::Unix(s) => s.is_write_vectored(),
117 }
118 }
119
120 fn poll_write_vectored(
121 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
122 ) -> Poll<io::Result<usize>> {
123 match Pin::get_mut(self) {
124 Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
125 #[cfg(unix)]
126 Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
127 }
128 }
129}
130
131pin_project! {
132 pub struct HttpsCapableConnection {
134 #[pin]
135 inner: MaybeHttpsStream<Transport>,
136 bytes_sent: Option<Counter>,
137 conn_age_limit: Option<Duration>,
138 }
139}
140
141impl Connection for HttpsCapableConnection {
142 fn connected(&self) -> Connected {
143 let connected = self.inner.connected();
144
145 if let Some(conn_age_limit) = self.conn_age_limit {
146 debug!("setting connection age limit to {:?}", conn_age_limit);
147 connected.extra(ConnectionAgeLimit::new(conn_age_limit))
148 } else {
149 connected
150 }
151 }
152}
153
154impl hyper::rt::Read for HttpsCapableConnection {
155 fn poll_read(
156 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
157 ) -> Poll<io::Result<()>> {
158 let this = self.project();
159 this.inner.poll_read(cx, buf)
160 }
161}
162
163impl hyper::rt::Write for HttpsCapableConnection {
164 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
165 let this = self.project();
166 match this.inner.poll_write(cx, buf) {
167 Poll::Ready(Ok(n)) => {
168 if let Some(bytes_sent) = this.bytes_sent {
169 bytes_sent.increment(n as u64);
170 }
171 Poll::Ready(Ok(n))
172 }
173 other => other,
174 }
175 }
176
177 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 let this = self.project();
179 this.inner.poll_flush(cx)
180 }
181
182 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
183 let this = self.project();
184 this.inner.poll_shutdown(cx)
185 }
186
187 fn is_write_vectored(&self) -> bool {
188 self.inner.is_write_vectored()
189 }
190
191 fn poll_write_vectored(
192 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
193 ) -> Poll<io::Result<usize>> {
194 let this = self.project();
195 match this.inner.poll_write_vectored(cx, bufs) {
196 Poll::Ready(Ok(n)) => {
197 if let Some(bytes_sent) = this.bytes_sent {
198 bytes_sent.increment(n as u64);
199 }
200 Poll::Ready(Ok(n))
201 }
202 other => other,
203 }
204 }
205}
206
207#[derive(Clone)]
212struct InnerConnector {
213 http: HickoryHttpConnector,
214 connect_timeout: Duration,
215 #[cfg(unix)]
216 unix_socket_path: Option<Arc<std::path::Path>>,
217}
218
219impl Service<Uri> for InnerConnector {
220 type Response = Transport;
221 type Error = BoxError;
222 type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
223
224 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225 #[cfg(unix)]
228 if self.unix_socket_path.is_some() {
229 return Poll::Ready(Ok(()));
230 }
231
232 self.http.poll_ready(cx).map_err(Into::into)
233 }
234
235 fn call(&mut self, dst: Uri) -> Self::Future {
236 #[cfg(unix)]
237 if let Some(path) = self.unix_socket_path.clone() {
238 let connect_timeout = self.connect_timeout;
239 return Box::pin(async move {
240 let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
241 .await
242 .map_err(|_| -> BoxError {
243 Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
244 })?
245 .map_err(|e| -> BoxError { Box::new(e) })?;
246 Ok(Transport::Unix(TokioIo::new(stream)))
247 });
248 }
249
250 let fut = self.http.call(dst);
251 Box::pin(async move {
252 let tcp = fut.await.map_err(BoxError::from)?;
253 Ok(Transport::Tcp(tcp))
254 })
255 }
256}
257
258#[derive(Clone)]
260pub struct HttpsCapableConnector {
261 inner: HttpsConnector<InnerConnector>,
262 bytes_sent: Option<Counter>,
263 conn_age_limit: Option<Duration>,
264}
265
266impl Service<Uri> for HttpsCapableConnector {
267 type Response = HttpsCapableConnection;
268 type Error = BoxError;
269 type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
270
271 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272 self.inner.poll_ready(cx)
273 }
274
275 fn call(&mut self, dst: Uri) -> Self::Future {
276 let inner = self.inner.call(dst);
277 let bytes_sent = self.bytes_sent.clone();
278 let conn_age_limit = self.conn_age_limit;
279 Box::pin(async move {
280 inner.await.map(|inner| HttpsCapableConnection {
281 inner,
282 bytes_sent,
283 conn_age_limit,
284 })
285 })
286 }
287}
288
289#[derive(Default)]
291pub struct HttpsCapableConnectorBuilder {
292 connect_timeout: Option<Duration>,
293 bytes_sent: Option<Counter>,
294 conn_age_limit: Option<Duration>,
295 #[cfg(unix)]
296 unix_socket_path: Option<PathBuf>,
297}
298
299impl HttpsCapableConnectorBuilder {
300 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
304 self.connect_timeout = Some(timeout);
305 self
306 }
307
308 pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
315 where
316 L: Into<Option<Duration>>,
317 {
318 self.conn_age_limit = limit.into();
319 self
320 }
321
322 pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
329 self.bytes_sent = Some(counter);
330 self
331 }
332
333 #[cfg(unix)]
341 pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
342 self.unix_socket_path = Some(path.into());
343 self
344 }
345
346 pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
348 let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
349
350 let hickory_resolver = HickoryResolver::from_system_conf()
351 .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
352
353 let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
356 http_connector.set_connect_timeout(Some(connect_timeout));
357 http_connector.enforce_http(false);
358
359 let inner_connector = InnerConnector {
360 http: http_connector,
361 connect_timeout,
362 #[cfg(unix)]
363 unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
364 };
365
366 let https_connector = HttpsConnectorBuilder::new()
368 .with_tls_config(tls_config)
369 .https_or_http()
370 .enable_all_versions()
371 .wrap_connector(inner_connector);
372
373 Ok(HttpsCapableConnector {
374 inner: https_connector,
375 bytes_sent: self.bytes_sent,
376 conn_age_limit: self.conn_age_limit,
377 })
378 }
379}
380
381pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
382 let maybe_conn_metadata = captured_conn.connection_metadata();
383 if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
384 let mut extensions = Extensions::new();
385 conn_metadata.get_extras(&mut extensions);
386
387 if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
391 if conn_age_limit.is_expired() {
392 debug!("connection is expired; poisoning it");
393 conn_metadata.poison();
394 }
395 }
396 }
397}