saluki_io/net/client/http/
conn.rs1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6 time::{Duration, Instant},
7};
8
9use http::{Extensions, Uri};
10use hyper_hickory::{TokioHickoryHttpConnector, TokioHickoryResolver};
11use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder, MaybeHttpsStream};
12use hyper_util::{
13 client::legacy::connect::{CaptureConnection, Connected, Connection, HttpConnector},
14 rt::TokioIo,
15};
16use metrics::Counter;
17use pin_project_lite::pin_project;
18use rustls::ClientConfig;
19use saluki_error::{ErrorContext as _, GenericError};
20use tokio::net::TcpStream;
21use tower::{BoxError, Service};
22use tracing::debug;
23
/// Age limit for a single connection, measured from the moment this value is constructed.
///
/// A value of this type is attached to a connection's metadata as an "extra" (see the `Connection`
/// implementation for `HttpsCapableConnection`) and read back by `check_connection_state` to decide
/// whether the connection has outlived its configured limit.
#[derive(Clone)]
struct ConnectionAgeLimit {
    /// Maximum age allowed before the connection is considered expired.
    limit: Duration,
    /// Instant at which this value was created; ages are measured relative to it.
    created: Instant,
}

impl ConnectionAgeLimit {
    /// Creates a new `ConnectionAgeLimit` with the given limit, starting the clock immediately.
    fn new(limit: Duration) -> Self {
        let created = Instant::now();
        Self { limit, created }
    }

    /// Returns `true` once at least `limit` has elapsed since this value was created.
    fn is_expired(&self) -> bool {
        let age = self.created.elapsed();
        age >= self.limit
    }
}
51
52pin_project! {
53 pub struct HttpsCapableConnection {
55 #[pin]
56 inner: MaybeHttpsStream<TokioIo<TcpStream>>,
57 bytes_sent: Option<Counter>,
58 conn_age_limit: Option<Duration>,
59 }
60}
61
62impl Connection for HttpsCapableConnection {
63 fn connected(&self) -> Connected {
64 let connected = self.inner.connected();
65
66 if let Some(conn_age_limit) = self.conn_age_limit {
67 debug!("setting connection age limit to {:?}", conn_age_limit);
68 connected.extra(ConnectionAgeLimit::new(conn_age_limit))
69 } else {
70 connected
71 }
72 }
73}
74
75impl hyper::rt::Read for HttpsCapableConnection {
76 fn poll_read(
77 self: Pin<&mut Self>, cx: &mut Context<'_>, buf: hyper::rt::ReadBufCursor<'_>,
78 ) -> Poll<io::Result<()>> {
79 let this = self.project();
80 this.inner.poll_read(cx, buf)
81 }
82}
83
84impl hyper::rt::Write for HttpsCapableConnection {
85 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
86 let this = self.project();
87 match this.inner.poll_write(cx, buf) {
88 Poll::Ready(Ok(n)) => {
89 if let Some(bytes_sent) = this.bytes_sent {
90 bytes_sent.increment(n as u64);
91 }
92 Poll::Ready(Ok(n))
93 }
94 other => other,
95 }
96 }
97
98 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99 let this = self.project();
100 this.inner.poll_flush(cx)
101 }
102
103 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
104 let this = self.project();
105 this.inner.poll_shutdown(cx)
106 }
107
108 fn is_write_vectored(&self) -> bool {
109 self.inner.is_write_vectored()
110 }
111
112 fn poll_write_vectored(
113 self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>],
114 ) -> Poll<io::Result<usize>> {
115 let this = self.project();
116 match this.inner.poll_write_vectored(cx, bufs) {
117 Poll::Ready(Ok(n)) => {
118 if let Some(bytes_sent) = this.bytes_sent {
119 bytes_sent.increment(n as u64);
120 }
121 Poll::Ready(Ok(n))
122 }
123 other => other,
124 }
125 }
126}
127
128#[derive(Clone)]
130pub struct HttpsCapableConnector {
131 inner: HttpsConnector<TokioHickoryHttpConnector>,
132 bytes_sent: Option<Counter>,
133 conn_age_limit: Option<Duration>,
134}
135
136impl Service<Uri> for HttpsCapableConnector {
137 type Response = HttpsCapableConnection;
138 type Error = BoxError;
139 type Future = Pin<Box<dyn Future<Output = Result<HttpsCapableConnection, BoxError>> + Send>>;
140
141 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142 self.inner.poll_ready(cx)
143 }
144
145 fn call(&mut self, dst: Uri) -> Self::Future {
146 let inner = self.inner.call(dst);
147 let bytes_sent = self.bytes_sent.clone();
148 let conn_age_limit = self.conn_age_limit;
149 Box::pin(async move {
150 inner.await.map(|inner| HttpsCapableConnection {
151 inner,
152 bytes_sent,
153 conn_age_limit,
154 })
155 })
156 }
157}
158
159#[derive(Default)]
161pub struct HttpsCapableConnectorBuilder {
162 connect_timeout: Option<Duration>,
163 bytes_sent: Option<Counter>,
164 conn_age_limit: Option<Duration>,
165}
166
167impl HttpsCapableConnectorBuilder {
168 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
172 self.connect_timeout = Some(timeout);
173 self
174 }
175
176 pub fn with_connection_age_limit<L>(mut self, limit: L) -> Self
183 where
184 L: Into<Option<Duration>>,
185 {
186 self.conn_age_limit = limit.into();
187 self
188 }
189
190 pub fn with_bytes_sent_counter(mut self, counter: Counter) -> Self {
197 self.bytes_sent = Some(counter);
198 self
199 }
200
201 pub fn build(self, tls_config: ClientConfig) -> Result<HttpsCapableConnector, GenericError> {
203 let connect_timeout = self.connect_timeout.unwrap_or(Duration::from_secs(30));
204
205 let hickory_resolver = TokioHickoryResolver::from_system_conf()
206 .error_context("Failed to load system DNS configuration when creating DNS resolver for HTTP client.")?;
207
208 let mut http_connector = HttpConnector::new_with_resolver(hickory_resolver);
211 http_connector.set_connect_timeout(Some(connect_timeout));
212 http_connector.enforce_http(false);
213
214 let https_connector = HttpsConnectorBuilder::new()
216 .with_tls_config(tls_config)
217 .https_or_http()
218 .enable_all_versions()
219 .wrap_connector(http_connector);
220
221 Ok(HttpsCapableConnector {
222 inner: https_connector,
223 bytes_sent: self.bytes_sent,
224 conn_age_limit: self.conn_age_limit,
225 })
226 }
227}
228
229pub(super) fn check_connection_state(captured_conn: CaptureConnection) {
230 let maybe_conn_metadata = captured_conn.connection_metadata();
231 if let Some(conn_metadata) = maybe_conn_metadata.as_ref() {
232 let mut extensions = Extensions::new();
233 conn_metadata.get_extras(&mut extensions);
234
235 if let Some(conn_age_limit) = extensions.get::<ConnectionAgeLimit>() {
239 if conn_age_limit.is_expired() {
240 debug!("connection is expired; poisoning it");
241 conn_metadata.poison();
242 }
243 }
244 }
245}