1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6 time::{Duration, Instant},
7};
8#[cfg(unix)]
9use std::{path::PathBuf, sync::Arc};
10
11use hickory_resolver::net::NetError;
12use http::{Extensions, Uri};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
14use hyper_util::{
15 client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
16 rt::TokioIo,
17};
18use metrics::Counter;
19use pin_project_lite::pin_project;
20use rustls::ClientConfig;
21use saluki_error::{ErrorContext as _, GenericError};
22use tokio::net::TcpStream;
23#[cfg(target_os = "linux")]
24use tokio_vsock::{VsockAddr, VsockStream};
25use tower::{BoxError, Service};
26use tracing::debug;
27
28use super::telemetry::HttpTransactionErrorTelemetry;
29use crate::net::dns::{HickoryHttpConnector, HickoryResolver};
30
/// Connection "extra" carrying a maximum connection age and the instant the age clock started.
///
/// Attached to connection metadata via `Connected::extra` and read back in
/// `check_connection_state`, which poisons connections whose limit has elapsed.
#[derive(Clone)]
struct ConnectionAgeLimit {
    limit: Duration,
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Creates an age limit of `limit`, starting the clock now.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once at least `limit` has elapsed since this value was created.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
58
59enum Transport {
64 Tcp(TokioIo<TcpStream>),
65 #[cfg(unix)]
66 Unix(TokioIo<tokio::net::UnixStream>),
67 #[cfg(target_os = "linux")]
68 Vsock(TokioIo<VsockStream>),
69}
70
71impl Connection for Transport {
72 fn connected(&self) -> Connected {
73 match self {
74 Self::Tcp(s) => s.connected(),
75 #[cfg(unix)]
76 Self::Unix(_) => Connected::new(),
77 #[cfg(target_os = "linux")]
78 Self::Vsock(_) => Connected::new(),
79 }
80 }
81}
82
83impl hyper::rt::Read for Transport {
84 fn poll_read(
85 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
86 ) -> Poll<io::Result<()>> {
87 match Pin::get_mut(self) {
88 Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
89 #[cfg(unix)]
90 Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
91 #[cfg(target_os = "linux")]
92 Self::Vsock(s) => Pin::new(s).poll_read(cx, buf),
93 }
94 }
95}
96
97impl hyper::rt::Write for Transport {
98 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
99 match Pin::get_mut(self) {
100 Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
101 #[cfg(unix)]
102 Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
103 #[cfg(target_os = "linux")]
104 Self::Vsock(s) => Pin::new(s).poll_write(cx, buf),
105 }
106 }
107
108 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
109 match Pin::get_mut(self) {
110 Self::Tcp(s) => Pin::new(s).poll_flush(cx),
111 #[cfg(unix)]
112 Self::Unix(s) => Pin::new(s).poll_flush(cx),
113 #[cfg(target_os = "linux")]
114 Self::Vsock(s) => Pin::new(s).poll_flush(cx),
115 }
116 }
117
118 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119 match Pin::get_mut(self) {
120 Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
121 #[cfg(unix)]
122 Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
123 #[cfg(target_os = "linux")]
124 Self::Vsock(s) => Pin::new(s).poll_shutdown(cx),
125 }
126 }
127
128 fn is_write_vectored(&self) -> bool {
129 match self {
130 Self::Tcp(s) => s.is_write_vectored(),
131 #[cfg(unix)]
132 Self::Unix(s) => s.is_write_vectored(),
133 #[cfg(target_os = "linux")]
134 Self::Vsock(s) => s.is_write_vectored(),
135 }
136 }
137
138 fn poll_write_vectored(
139 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
140 ) -> Poll<io::Result<usize>> {
141 match Pin::get_mut(self) {
142 Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
143 #[cfg(unix)]
144 Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
145 #[cfg(target_os = "linux")]
146 Self::Vsock(s) => Pin::new(s).poll_write_vectored(cx, bufs),
147 }
148 }
149}
150
151pin_project! {
152 pub struct HttpsCapableConnection {
154 #[pin]
155 inner: MaybeHttpsStream<Transport>,
156 bytes_sent: Option<Counter>,
157 error_telemetry: Option<HttpTransactionErrorTelemetry>,
158 conn_age_limit: Option<Duration>,
159 }
160}
161
162impl Connection for HttpsCapableConnection {
163 fn connected(&self) -> Connected {
164 let connected = self.inner.connected();
165
166 if let Some(conn_age_limit) = self.conn_age_limit {
167 debug!("setting connection age limit to {:?}", conn_age_limit);
168 connected.extra(ConnectionAgeLimit::new(conn_age_limit))
169 } else {
170 connected
171 }
172 }
173}
174
175impl hyper::rt::Read for HttpsCapableConnection {
176 fn poll_read(
177 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
178 ) -> Poll<io::Result<()>> {
179 let this = self.project();
180 this.inner.poll_read(cx, buf)
181 }
182}
183
184impl hyper::rt::Write for HttpsCapableConnection {
185 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
186 let this = self.project();
187 match this.inner.poll_write(cx, buf) {
188 Poll::Ready(Ok(n)) => {
189 if let Some(bytes_sent) = this.bytes_sent {
190 bytes_sent.increment(n as u64);
191 }
192 Poll::Ready(Ok(n))
193 }
194 Poll::Ready(Err(error)) => {
195 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
196 error_telemetry.increment_wrote_request_error();
197 }
198 Poll::Ready(Err(error))
199 }
200 other => other,
201 }
202 }
203
204 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205 let this = self.project();
206 match this.inner.poll_flush(cx) {
207 Poll::Ready(Err(error)) => {
208 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
209 error_telemetry.increment_wrote_request_error();
210 }
211 Poll::Ready(Err(error))
212 }
213 other => other,
214 }
215 }
216
217 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
218 let this = self.project();
219 this.inner.poll_shutdown(cx)
220 }
221
222 fn is_write_vectored(&self) -> bool {
223 self.inner.is_write_vectored()
224 }
225
226 fn poll_write_vectored(
227 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
228 ) -> Poll<io::Result<usize>> {
229 let this = self.project();
230 match this.inner.poll_write_vectored(cx, bufs) {
231 Poll::Ready(Ok(n)) => {
232 if let Some(bytes_sent) = this.bytes_sent {
233 bytes_sent.increment(n as u64);
234 }
235 Poll::Ready(Ok(n))
236 }
237 Poll::Ready(Err(error)) => {
238 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
239 error_telemetry.increment_wrote_request_error();
240 }
241 Poll::Ready(Err(error))
242 }
243 other => other,
244 }
245 }
246}
247
248#[derive(Clone)]
255struct InnerConnector {
256 http: HickoryHttpConnector,
257 #[cfg(unix)]
258 connect_timeout: Duration,
259 error_telemetry: Option<HttpTransactionErrorTelemetry>,
260 #[cfg(unix)]
261 unix_socket_path: Option<Arc<std::path::Path>>,
262 #[cfg(target_os = "linux")]
263 vsock_addr: Option<VsockAddr>,
264}
265
266impl Service<Uri> for InnerConnector {
267 type Response = Transport;
268 type Error = BoxError;
269 type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
270
271 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272 #[cfg(target_os = "linux")]
276 if self.vsock_addr.is_some() {
277 return Poll::Ready(Ok(()));
278 }
279
280 #[cfg(unix)]
281 if self.unix_socket_path.is_some() {
282 return Poll::Ready(Ok(()));
283 }
284
285 self.http.poll_ready(cx).map_err(Into::into)
286 }
287
288 fn call(&mut self, dst: Uri) -> Self::Future {
289 #[cfg(target_os = "linux")]
290 if let Some(addr) = self.vsock_addr {
291 let connect_timeout = self.connect_timeout;
292 let error_telemetry = self.error_telemetry.clone();
293 return Box::pin(async move {
294 let stream = tokio::time::timeout(connect_timeout, VsockStream::connect(addr))
295 .await
296 .map_err(|_| -> BoxError {
297 if let Some(error_telemetry) = &error_telemetry {
298 error_telemetry.increment_connection_error();
299 }
300 Box::new(io::Error::new(io::ErrorKind::TimedOut, "vsock connect timed out"))
301 })?
302 .map_err(|e| -> BoxError {
303 if let Some(error_telemetry) = &error_telemetry {
304 error_telemetry.increment_connection_error();
305 }
306 Box::new(e)
307 })?;
308 Ok(Transport::Vsock(TokioIo::new(stream)))
309 });
310 }
311
312 #[cfg(unix)]
313 if let Some(path) = self.unix_socket_path.clone() {
314 let connect_timeout = self.connect_timeout;
315 let error_telemetry = self.error_telemetry.clone();
316 return Box::pin(async move {
317 let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
318 .await
319 .map_err(|_| -> BoxError {
320 if let Some(error_telemetry) = &error_telemetry {
321 error_telemetry.increment_connection_error();
322 }
323 Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
324 })?
325 .map_err(|e| -> BoxError {
326 if let Some(error_telemetry) = &error_telemetry {
327 error_telemetry.increment_connection_error();
328 }
329 Box::new(e)
330 })?;
331 Ok(Transport::Unix(TokioIo::new(stream)))
332 });
333 }
334
335 let fut = self.http.call(dst);
336 let error_telemetry = self.error_telemetry.clone();
337 Box::pin(async move {
338 let tcp = fut.await.map_err(|error| {
339 if !is_dns_error(&error) {
340 if let Some(error_telemetry) = &error_telemetry {
341 error_telemetry.increment_connection_error();
342 }
343 }
344 BoxError::from(error)
345 })?;
346 Ok(Transport::Tcp(tcp))
347 })
348 }
349}
350
/// HTTP protocol version(s) the connector is allowed to use.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum HttpProtocol {
    /// Enable every HTTP version the HTTPS connector supports (the default).
    #[default]
    Auto,

    /// Restrict connections to HTTP/1.
    Http1,
}
361
362#[derive(Clone)]
364pub struct HttpsCapableConnector {
365 inner: HttpsConnector<InnerConnector>,
366 bytes_sent: Option<Counter>,
367 error_telemetry: Option<HttpTransactionErrorTelemetry>,
368 conn_age_limit: Option<Duration>,
369}
370
371impl Service<Uri> for HttpsCapableConnector {
372 type Response = HttpsCapableConnection;
373 type Error = BoxError;
374 type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
375
376 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
377 self.inner.poll_ready(cx)
378 }
379
380 fn call(&mut self, dst: Uri) -> Self::Future {
381 let inner = self.inner.call(dst);
382 let bytes_sent = self.bytes_sent.clone();
383 let error_telemetry = self.error_telemetry.clone();
384 let conn_age_limit = self.conn_age_limit;
385 Box::pin(async move {
386 match inner.await {
387 Ok(inner) => Ok(HttpsCapableConnection {
388 inner,
389 bytes_sent,
390 error_telemetry,
391 conn_age_limit,
392 }),
393 Err(error) => {
394 if is_tls_error(error.as_ref()) {
395 if let Some(error_telemetry) = &error_telemetry {
396 error_telemetry.increment_tls_error();
397 }
398 }
399 Err(error)
400 }
401 }
402 })
403 }
404}
405
406fn build_dns_resolver(
407 error_telemetry: &Option<HttpTransactionErrorTelemetry>,
408) -> Result<HickoryResolver, GenericError> {
409 let mut r = HickoryResolver::from_system_conf()
410 .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
411 if let Some(et) = error_telemetry {
412 r = r.with_lookup_errors_counter(et.dns_errors());
413 }
414 Ok(r)
415}
416
417#[derive(Default)]
419pub struct HttpsCapableConnectorBuilder {
420 connect_timeout: Option<Duration>,
421 bytes_sent: Option<Counter>,
422 error_telemetry: Option<HttpTransactionErrorTelemetry>,
423 conn_age_limit: Option<Duration>,
424 http_protocol: HttpProtocol,
425 #[cfg(unix)]
426 unix_socket_path: Option<PathBuf>,
427 #[cfg(target_os = "linux")]
428 vsock_addr: Option<VsockAddr>,
429}
430
431impl HttpsCapableConnectorBuilder {
432 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
436 self.connect_timeout = Some(timeout);
437 self
438 }
439
440 pub fn with_http_protocol(mut self, protocol: HttpProtocol) -> Self {
444 self.http_protocol = protocol;
445 self
446 }
447
448 pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
455 where
456 L: Into<Option<Duration>>,
457 {
458 self.conn_age_limit = limit.into();
459 self
460 }
461
462 pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
469 self.bytes_sent = Some(counter);
470 self
471 }
472
473 pub(super) fn with_error_telemetry(mut self, error_telemetry: HttpTransactionErrorTelemetry) -> Self {
475 self.error_telemetry = Some(error_telemetry);
476 self
477 }
478
479 #[cfg(unix)]
487 pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
488 self.unix_socket_path = Some(path.into());
489 self
490 }
491
492 #[cfg(target_os = "linux")]
500 pub fn with_vsock_addr(mut self, addr: VsockAddr) -> Self {
501 self.vsock_addr = Some(addr);
502 self
503 }
504
505 pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
507 let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
508
509 #[cfg(target_os = "linux")]
513 let vsock_only = self.vsock_addr.is_some();
514 #[cfg(not(target_os = "linux"))]
515 let vsock_only = false;
516
517 let hickory_resolver = if vsock_only {
518 HickoryResolver::noop()
519 } else {
520 build_dns_resolver(&self.error_telemetry)?
521 };
522
523 let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
526 http_connector.set_connect_timeout(Some(connect_timeout));
527 http_connector.enforce_http(false);
528
529 let inner_connector = InnerConnector {
530 http: http_connector,
531 #[cfg(unix)]
532 connect_timeout,
533 error_telemetry: self.error_telemetry.clone(),
534 #[cfg(unix)]
535 unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
536 #[cfg(target_os = "linux")]
537 vsock_addr: self.vsock_addr,
538 };
539
540 let https_connector_builder = HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http();
542 let https_connector = match self.http_protocol {
543 HttpProtocol::Auto => https_connector_builder
544 .enable_all_versions()
545 .wrap_connector(inner_connector),
546 HttpProtocol::Http1 => https_connector_builder.enable_http1().wrap_connector(inner_connector),
547 };
548
549 Ok(HttpsCapableConnector {
550 inner: https_connector,
551 bytes_sent: self.bytes_sent,
552 error_telemetry: self.error_telemetry,
553 conn_age_limit: self.conn_age_limit,
554 })
555 }
556}
557
/// Returns `tls_config` with its ALPN protocol list set to match `protocol`.
///
/// `HttpProtocol::Auto` advertises `h2` followed by `http/1.1`. `HttpProtocol::Http1` advertises
/// nothing (an empty list).
#[cfg(test)]
fn configure_tls_alpn_for_http_protocol(mut tls_config: ClientConfig, protocol: HttpProtocol) -> ClientConfig {
    tls_config.alpn_protocols = match protocol {
        HttpProtocol::Auto => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
        HttpProtocol::Http1 => Vec::new(),
    };
    tls_config
}
571
572fn is_tls_error(error: &(dyn std::error::Error + 'static)) -> bool {
573 let mut current = Some(error);
574 while let Some(error) = current {
575 if error.downcast_ref::<rustls::Error>().is_some() {
576 return true;
577 }
578 current = error.source();
579 }
580 false
581}
582
583fn is_dns_error(error: &(dyn std::error::Error + 'static)) -> bool {
584 let mut current = Some(error);
585 while let Some(error) = current {
586 if error.downcast_ref::<NetError>().is_some() {
587 return true;
588 }
589 current = error.source();
590 }
591 false
592}
593
594pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
595 let maybe_conn_metadata = captured_conn.connection_metadata();
596 if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
597 let mut extensions = Extensions::new();
598 conn_metadata.get_extras(&mut extensions);
599
600 if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
604 if conn_age_limit.is_expired() {
605 debug!("connection is expired; poisoning it");
606 conn_metadata.poison();
607 }
608 }
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::{configure_tls_alpn_for_http_protocol, HttpProtocol};

    /// Builds a minimal client config to feed into the ALPN helper.
    ///
    /// It uses the AWS-LC provider, safe default protocol versions, no root certificates, and no
    /// client auth.
    fn empty_tls_config() -> rustls::ClientConfig {
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        rustls::ClientConfig::builder_with_provider(provider.into())
            .with_safe_default_protocol_versions()
            .expect("AWS-LC default protocol versions should be valid")
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth()
    }

    #[test]
    fn auto_protocol_advertises_h2_and_http1_alpn() {
        let expected_alpn = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        let tls_config = configure_tls_alpn_for_http_protocol(empty_tls_config(), HttpProtocol::Auto);

        assert_eq!(tls_config.alpn_protocols, expected_alpn);
    }

    #[test]
    fn http1_protocol_leaves_alpn_empty() {
        let alpn = configure_tls_alpn_for_http_protocol(empty_tls_config(), HttpProtocol::Http1).alpn_protocols;

        assert!(alpn.is_empty());
    }

    /// With both a vsock address and a Unix socket path configured, `InnerConnector` must attempt
    /// the vsock connection. The test expects that attempt to fail, and the resulting error must
    /// not come from the Unix socket path.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn vsock_takes_priority_over_unix_when_both_set() {
        use std::{path::Path, sync::Arc, time::Duration};

        use tower::Service as _;

        use super::{InnerConnector, VsockAddr};
        use crate::net::dns::HickoryResolver;

        let unix_socket_path: Arc<Path> = Arc::from(Path::new("/tmp/test.sock"));
        let mut connector = InnerConnector {
            http: HickoryResolver::noop().into_http_connector(),
            connect_timeout: Duration::from_secs(1),
            error_telemetry: None,
            unix_socket_path: Some(unix_socket_path),
            vsock_addr: Some(VsockAddr::new(2, 5001)),
        };

        let uri: http::Uri = "https://127.0.0.1:5001/".parse().unwrap();
        let err = connector.call(uri).await.err().expect("expected a connection error");
        assert!(
            !err.to_string().contains("unix"),
            "expected vsock error (not unix socket error), got: {err}"
        );
    }
}
668}