saluki_io/net/client/http/
conn.rs1use std::{
2 future::Future,
3 io,
4 path::PathBuf,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use hickory_resolver::net::NetError;
12use http::{Extensions, Uri};
13use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
14use hyper_util::{
15 client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
16 rt::TokioIo,
17};
18use metrics::Counter;
19use pin_project_lite::pin_project;
20use rustls::ClientConfig;
21use saluki_error::{ErrorContext as _, GenericError};
22use tokio::net::TcpStream;
23use tower::{BoxError, Service};
24use tracing::debug;
25
26use super::telemetry::HttpTransactionErrorTelemetry;
27use crate::net::dns::{HickoryHttpConnector, HickoryResolver};
28
/// Connection metadata pairing a creation timestamp with the maximum lifetime of the connection.
///
/// Stored in a connection's `Connected` extras so that `check_connection_state` can detect connections
/// that have outlived their limit.
#[derive(Clone)]
struct ConnectionAgeLimit {
    created: Instant,
    limit: Duration,
}

impl ConnectionAgeLimit {
    /// Starts the age clock now, with `limit` as the maximum allowed age.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { created, limit }
    }

    /// Returns `true` once at least `limit` has elapsed since this value was created.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
56
57enum Transport {
62 Tcp(TokioIo<TcpStream>),
63 #[cfg(unix)]
64 Unix(TokioIo<tokio::net::UnixStream>),
65}
66
67impl Connection for Transport {
68 fn connected(&self) -> Connected {
69 match self {
70 Self::Tcp(s) => s.connected(),
71 #[cfg(unix)]
72 Self::Unix(_) => Connected::new(),
73 }
74 }
75}
76
77impl hyper::rt::Read for Transport {
78 fn poll_read(
79 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
80 ) -> Poll<io::Result<()>> {
81 match Pin::get_mut(self) {
82 Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
83 #[cfg(unix)]
84 Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
85 }
86 }
87}
88
89impl hyper::rt::Write for Transport {
90 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
91 match Pin::get_mut(self) {
92 Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
93 #[cfg(unix)]
94 Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
95 }
96 }
97
98 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99 match Pin::get_mut(self) {
100 Self::Tcp(s) => Pin::new(s).poll_flush(cx),
101 #[cfg(unix)]
102 Self::Unix(s) => Pin::new(s).poll_flush(cx),
103 }
104 }
105
106 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107 match Pin::get_mut(self) {
108 Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
109 #[cfg(unix)]
110 Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
111 }
112 }
113
114 fn is_write_vectored(&self) -> bool {
115 match self {
116 Self::Tcp(s) => s.is_write_vectored(),
117 #[cfg(unix)]
118 Self::Unix(s) => s.is_write_vectored(),
119 }
120 }
121
122 fn poll_write_vectored(
123 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
124 ) -> Poll<io::Result<usize>> {
125 match Pin::get_mut(self) {
126 Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
127 #[cfg(unix)]
128 Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
129 }
130 }
131}
132
133pin_project! {
134 pub struct HttpsCapableConnection {
136 #[pin]
137 inner: MaybeHttpsStream<Transport>,
138 bytes_sent: Option<Counter>,
139 error_telemetry: Option<HttpTransactionErrorTelemetry>,
140 conn_age_limit: Option<Duration>,
141 }
142}
143
144impl Connection for HttpsCapableConnection {
145 fn connected(&self) -> Connected {
146 let connected = self.inner.connected();
147
148 if let Some(conn_age_limit) = self.conn_age_limit {
149 debug!("setting connection age limit to {:?}", conn_age_limit);
150 connected.extra(ConnectionAgeLimit::new(conn_age_limit))
151 } else {
152 connected
153 }
154 }
155}
156
157impl hyper::rt::Read for HttpsCapableConnection {
158 fn poll_read(
159 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
160 ) -> Poll<io::Result<()>> {
161 let this = self.project();
162 this.inner.poll_read(cx, buf)
163 }
164}
165
166impl hyper::rt::Write for HttpsCapableConnection {
167 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
168 let this = self.project();
169 match this.inner.poll_write(cx, buf) {
170 Poll::Ready(Ok(n)) => {
171 if let Some(bytes_sent) = this.bytes_sent {
172 bytes_sent.increment(n as u64);
173 }
174 Poll::Ready(Ok(n))
175 }
176 Poll::Ready(Err(error)) => {
177 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
178 error_telemetry.increment_wrote_request_error();
179 }
180 Poll::Ready(Err(error))
181 }
182 other => other,
183 }
184 }
185
186 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
187 let this = self.project();
188 match this.inner.poll_flush(cx) {
189 Poll::Ready(Err(error)) => {
190 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
191 error_telemetry.increment_wrote_request_error();
192 }
193 Poll::Ready(Err(error))
194 }
195 other => other,
196 }
197 }
198
199 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200 let this = self.project();
201 this.inner.poll_shutdown(cx)
202 }
203
204 fn is_write_vectored(&self) -> bool {
205 self.inner.is_write_vectored()
206 }
207
208 fn poll_write_vectored(
209 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
210 ) -> Poll<io::Result<usize>> {
211 let this = self.project();
212 match this.inner.poll_write_vectored(cx, bufs) {
213 Poll::Ready(Ok(n)) => {
214 if let Some(bytes_sent) = this.bytes_sent {
215 bytes_sent.increment(n as u64);
216 }
217 Poll::Ready(Ok(n))
218 }
219 Poll::Ready(Err(error)) => {
220 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
221 error_telemetry.increment_wrote_request_error();
222 }
223 Poll::Ready(Err(error))
224 }
225 other => other,
226 }
227 }
228}
229
230#[derive(Clone)]
235struct InnerConnector {
236 http: HickoryHttpConnector,
237 connect_timeout: Duration,
238 error_telemetry: Option<HttpTransactionErrorTelemetry>,
239 #[cfg(unix)]
240 unix_socket_path: Option<Arc<std::path::Path>>,
241}
242
243impl Service<Uri> for InnerConnector {
244 type Response = Transport;
245 type Error = BoxError;
246 type Future = Pin<Box<dyn Future<Output = Result<Transport, BoxError>> + Send>>;
247
248 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
249 #[cfg(unix)]
252 if self.unix_socket_path.is_some() {
253 return Poll::Ready(Ok(()));
254 }
255
256 self.http.poll_ready(cx).map_err(Into::into)
257 }
258
259 fn call(&mut self, dst: Uri) -> Self::Future {
260 #[cfg(unix)]
261 if let Some(path) = self.unix_socket_path.clone() {
262 let connect_timeout = self.connect_timeout;
263 let error_telemetry = self.error_telemetry.clone();
264 return Box::pin(async move {
265 let stream = tokio::time::timeout(connect_timeout, tokio::net::UnixStream::connect(&*path))
266 .await
267 .map_err(|_| -> BoxError {
268 if let Some(error_telemetry) = &error_telemetry {
269 error_telemetry.increment_connection_error();
270 }
271 Box::new(io::Error::new(io::ErrorKind::TimedOut, "unix socket connect timed out"))
272 })?
273 .map_err(|e| -> BoxError {
274 if let Some(error_telemetry) = &error_telemetry {
275 error_telemetry.increment_connection_error();
276 }
277 Box::new(e)
278 })?;
279 Ok(Transport::Unix(TokioIo::new(stream)))
280 });
281 }
282
283 let fut = self.http.call(dst);
284 let error_telemetry = self.error_telemetry.clone();
285 Box::pin(async move {
286 let tcp = fut.await.map_err(|error| {
287 if !is_dns_error(&error) {
288 if let Some(error_telemetry) = &error_telemetry {
289 error_telemetry.increment_connection_error();
290 }
291 }
292 BoxError::from(error)
293 })?;
294 Ok(Transport::Tcp(tcp))
295 })
296 }
297}
298
/// Which HTTP versions a connector may use.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HttpProtocol {
    /// Advertise both HTTP/2 and HTTP/1.1 over TLS ALPN, preferring HTTP/2. This is the default.
    #[default]
    Auto,

    /// Only use HTTP/1.x.
    Http1,
}
309
310#[derive(Clone)]
312pub struct HttpsCapableConnector {
313 inner: HttpsConnector<InnerConnector>,
314 bytes_sent: Option<Counter>,
315 error_telemetry: Option<HttpTransactionErrorTelemetry>,
316 conn_age_limit: Option<Duration>,
317}
318
319impl Service<Uri> for HttpsCapableConnector {
320 type Response = HttpsCapableConnection;
321 type Error = BoxError;
322 type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
323
324 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
325 self.inner.poll_ready(cx)
326 }
327
328 fn call(&mut self, dst: Uri) -> Self::Future {
329 let inner = self.inner.call(dst);
330 let bytes_sent = self.bytes_sent.clone();
331 let error_telemetry = self.error_telemetry.clone();
332 let conn_age_limit = self.conn_age_limit;
333 Box::pin(async move {
334 match inner.await {
335 Ok(inner) => Ok(HttpsCapableConnection {
336 inner,
337 bytes_sent,
338 error_telemetry,
339 conn_age_limit,
340 }),
341 Err(error) => {
342 if is_tls_error(error.as_ref()) {
343 if let Some(error_telemetry) = &error_telemetry {
344 error_telemetry.increment_tls_error();
345 }
346 }
347 Err(error)
348 }
349 }
350 })
351 }
352}
353
354#[derive(Default)]
356pub struct HttpsCapableConnectorBuilder {
357 connect_timeout: Option<Duration>,
358 bytes_sent: Option<Counter>,
359 error_telemetry: Option<HttpTransactionErrorTelemetry>,
360 conn_age_limit: Option<Duration>,
361 http_protocol: HttpProtocol,
362 #[cfg(unix)]
363 unix_socket_path: Option<PathBuf>,
364}
365
366impl HttpsCapableConnectorBuilder {
367 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
371 self.connect_timeout = Some(timeout);
372 self
373 }
374
375 pub fn with_http_protocol(mut self, protocol: HttpProtocol) -> Self {
379 self.http_protocol = protocol;
380 self
381 }
382
383 pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
390 where
391 L: Into<Option<Duration>>,
392 {
393 self.conn_age_limit = limit.into();
394 self
395 }
396
397 pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
404 self.bytes_sent = Some(counter);
405 self
406 }
407
408 pub(super) fn with_error_telemetry(mut self, error_telemetry: HttpTransactionErrorTelemetry) -> Self {
410 self.error_telemetry = Some(error_telemetry);
411 self
412 }
413
414 #[cfg(unix)]
422 pub fn with_unix_socket_path<P: Into<PathBuf>>(mut self, path: P) -> Self {
423 self.unix_socket_path = Some(path.into());
424 self
425 }
426
427 pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
429 let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
430
431 let mut hickory_resolver = HickoryResolver::from_system_conf()
432 .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
433 if let Some(error_telemetry) = &self.error_telemetry {
434 hickory_resolver = hickory_resolver.with_lookup_errors_counter(error_telemetry.dns_errors());
435 }
436
437 let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
440 http_connector.set_connect_timeout(Some(connect_timeout));
441 http_connector.enforce_http(false);
442
443 let inner_connector = InnerConnector {
444 http: http_connector,
445 connect_timeout,
446 error_telemetry: self.error_telemetry.clone(),
447 #[cfg(unix)]
448 unix_socket_path: self.unix_socket_path.map(PathBuf::into_boxed_path).map(Arc::from),
449 };
450
451 let https_connector_builder = HttpsConnectorBuilder::new().with_tls_config(tls_config).https_or_http();
453 let https_connector = match self.http_protocol {
454 HttpProtocol::Auto => https_connector_builder
455 .enable_all_versions()
456 .wrap_connector(inner_connector),
457 HttpProtocol::Http1 => https_connector_builder.enable_http1().wrap_connector(inner_connector),
458 };
459
460 Ok(HttpsCapableConnector {
461 inner: https_connector,
462 bytes_sent: self.bytes_sent,
463 error_telemetry: self.error_telemetry,
464 conn_age_limit: self.conn_age_limit,
465 })
466 }
467}
468
469#[cfg(test)]
470fn configure_tls_alpn_for_http_protocol(mut tls_config: ClientConfig, protocol: HttpProtocol) -> ClientConfig {
471 match protocol {
472 HttpProtocol::Auto => {
473 tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
474 }
475 HttpProtocol::Http1 => {
476 tls_config.alpn_protocols.clear();
477 }
478 }
479
480 tls_config
481}
482
483fn is_tls_error(error: &(dyn std::error::Error + 'static)) -> bool {
484 let mut current = Some(error);
485 while let Some(error) = current {
486 if error.downcast_ref::<rustls::Error>().is_some() {
487 return true;
488 }
489 current = error.source();
490 }
491 false
492}
493
494fn is_dns_error(error: &(dyn std::error::Error + 'static)) -> bool {
495 let mut current = Some(error);
496 while let Some(error) = current {
497 if error.downcast_ref::<NetError>().is_some() {
498 return true;
499 }
500 current = error.source();
501 }
502 false
503}
504
505pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
506 let maybe_conn_metadata = captured_conn.connection_metadata();
507 if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
508 let mut extensions = Extensions::new();
509 conn_metadata.get_extras(&mut extensions);
510
511 if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
515 if conn_age_limit.is_expired() {
516 debug!("connection is expired; poisoning it");
517 conn_metadata.poison();
518 }
519 }
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::{configure_tls_alpn_for_http_protocol, HttpProtocol};

    /// A client config with no trust roots and (by rustls default) no ALPN protocols.
    fn empty_tls_config() -> rustls::ClientConfig {
        rustls::ClientConfig::builder_with_provider(rustls::crypto::aws_lc_rs::default_provider().into())
            .with_safe_default_protocol_versions()
            .expect("AWS-LC default protocol versions should be valid")
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth()
    }

    /// Like `empty_tls_config`, but with `alpn` pre-populated, so tests can check it gets overwritten.
    fn tls_config_with_alpn(alpn: Vec<Vec<u8>>) -> rustls::ClientConfig {
        let mut tls_config = empty_tls_config();
        tls_config.alpn_protocols = alpn;
        tls_config
    }

    #[test]
    fn auto_protocol_advertises_h2_and_http1_alpn() {
        let tls_config = configure_tls_alpn_for_http_protocol(empty_tls_config(), HttpProtocol::Auto);

        assert_eq!(tls_config.alpn_protocols, vec![b"h2".to_vec(), b"http/1.1".to_vec()]);
    }

    #[test]
    fn auto_protocol_replaces_existing_alpn() {
        let seeded = tls_config_with_alpn(vec![b"http/1.1".to_vec()]);
        let tls_config = configure_tls_alpn_for_http_protocol(seeded, HttpProtocol::Auto);

        assert_eq!(tls_config.alpn_protocols, vec![b"h2".to_vec(), b"http/1.1".to_vec()]);
    }

    #[test]
    fn http1_protocol_leaves_alpn_empty() {
        let tls_config = configure_tls_alpn_for_http_protocol(empty_tls_config(), HttpProtocol::Http1);

        assert!(tls_config.alpn_protocols.is_empty());
    }

    #[test]
    fn http1_protocol_clears_preexisting_h2_alpn() {
        // Starting from a non-empty list proves the Http1 branch actively removes `h2`, rather than
        // merely leaving an already-empty list alone.
        let seeded = tls_config_with_alpn(vec![b"h2".to_vec(), b"http/1.1".to_vec()]);
        let tls_config = configure_tls_alpn_for_http_protocol(seeded, HttpProtocol::Http1);

        assert!(tls_config.alpn_protocols.is_empty());
    }
}