saluki_io/net/client/http/
telemetry.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{ready, Context, Poll},
7};
8
9use http::{uri::Authority, Request, Response, StatusCode, Uri};
10use http_body::Body;
11use metrics::Counter;
12use pin_project_lite::pin_project;
13use saluki_metrics::MetricsBuilder;
14use stringtheory::MetaString;
15use tower::{Layer, Service};
16
17pub type EndpointNameFn = dyn Fn(&Uri) -> Option<MetaString> + Send + Sync;
18
19#[derive(Clone, Default)]
56pub struct EndpointTelemetryLayer {
57 builder: MetricsBuilder,
58 endpoint_name_fn: Option<Arc<EndpointNameFn>>,
59}
60
61impl EndpointTelemetryLayer {
62 pub fn with_metrics_builder(mut self, builder: MetricsBuilder) -> Self {
67 self.builder = builder;
68 self
69 }
70
71 pub fn with_endpoint_name_fn<F>(mut self, endpoint_name_fn: F) -> Self
81 where
82 F: Fn(&Uri) -> Option<MetaString> + Send + Sync + 'static,
83 {
84 self.endpoint_name_fn = Some(Arc::new(endpoint_name_fn));
85 self
86 }
87}
88
89impl<S> Layer<S> for EndpointTelemetryLayer {
90 type Service = EndpointTelemetry<S>;
91
92 fn layer(&self, service: S) -> Self::Service {
93 EndpointTelemetry {
94 service,
95 builder: self.builder.clone(),
96 endpoint_name_fn: self.endpoint_name_fn.clone(),
97 domains: HashMap::new(),
98 endpoint_name_cache: HashMap::new(),
99 }
100 }
101}
102
103struct PerEndpointTelemetry {
104 builder: MetricsBuilder,
105 dropped: Counter,
106 success: Counter,
107 success_bytes: Counter,
108 errors_map: Mutex<HashMap<&'static str, Counter>>,
109 http_errors_map: Mutex<HashMap<StatusCode, Counter>>,
110}
111
112impl PerEndpointTelemetry {
113 fn new(builder: MetricsBuilder, uri: &Uri, endpoint_name: &str) -> Self {
114 let mut domain = format!("{}://{}", uri.scheme_str().unwrap(), uri.host().unwrap());
116 if let Some(port) = uri.port() {
117 domain.push(':');
118 domain.push_str(port.as_str());
119 }
120
121 let builder = builder
122 .add_default_tag(("domain", domain))
123 .add_default_tag(("endpoint", endpoint_name.to_string()));
124
125 let dropped = builder.register_debug_counter("network_http_requests_failed_total");
126 let success = builder.register_debug_counter("network_http_requests_success_total");
127 let success_bytes = builder.register_debug_counter("network_http_requests_success_sent_bytes_total");
128 let errors_map = Mutex::new(HashMap::new());
129 let http_errors_map = Mutex::new(HashMap::new());
130
131 Self {
132 builder,
133 dropped,
134 success,
135 success_bytes,
136 errors_map,
137 http_errors_map,
138 }
139 }
140
141 fn increment_dropped(&self) {
142 self.dropped.increment(1);
143 }
144
145 fn increment_success(&self) {
146 self.success.increment(1);
147 }
148
149 fn increment_success_bytes(&self, len: u64) {
150 self.success_bytes.increment(len);
151 }
152
153 fn increment_error(&self, error_type: &'static str) {
154 let mut errors_map = self.errors_map.lock().unwrap();
155 let counter = errors_map.entry(error_type).or_insert_with(|| {
156 self.builder
157 .register_debug_counter_with_tags("network_http_requests_errors_total", [("error_type", error_type)])
158 });
159 counter.increment(1);
160 }
161
162 fn increment_http_error(&self, status: StatusCode) {
163 let mut http_errors_map = self.http_errors_map.lock().unwrap();
164 let counter = http_errors_map.entry(status).or_insert_with(move || {
165 self.builder.register_debug_counter_with_tags(
166 "network_http_requests_errors_total",
167 [
168 ("error_type", "client_error".to_string()),
169 ("code", status.as_str().to_string()),
170 ],
171 )
172 });
173 counter.increment(1);
174 }
175}
176
177#[derive(Clone)]
179pub struct EndpointTelemetry<S> {
180 service: S,
181 builder: MetricsBuilder,
182 endpoint_name_fn: Option<Arc<EndpointNameFn>>,
183 domains: HashMap<Authority, HashMap<MetaString, Arc<PerEndpointTelemetry>>>,
184 endpoint_name_cache: HashMap<Uri, MetaString>,
185}
186
187impl<S> EndpointTelemetry<S> {
188 fn get_telemetry_handle<B>(&mut self, req: &Request<B>) -> Option<Arc<PerEndpointTelemetry>>
189 where
190 B: Body,
191 {
192 if req.uri().scheme().is_none() || req.uri().host().is_none() {
194 return None;
195 }
196
197 let authority = req.uri().authority()?.clone();
199 let domain = self.domains.entry(authority).or_default();
200
201 let endpoint_telemetry = match self.endpoint_name_cache.get(req.uri()) {
206 Some(endpoint_name) => domain
207 .get(endpoint_name)
208 .expect("per-endpoint telemetry must exist if name is cached"),
209 None => {
210 let endpoint_name = self
212 .endpoint_name_fn
213 .as_ref()
214 .and_then(|f| f(req.uri()))
215 .map(sanitize_endpoint_name)
216 .unwrap_or_else(|| sanitize_endpoint_name(req.uri().path().into()));
217
218 self.endpoint_name_cache
219 .insert(req.uri().clone(), endpoint_name.clone());
220
221 domain.entry(endpoint_name).or_insert_with_key(|endpoint_name| {
223 Arc::new(PerEndpointTelemetry::new(
224 self.builder.clone(),
225 req.uri(),
226 endpoint_name,
227 ))
228 })
229 }
230 };
231
232 Some(Arc::clone(endpoint_telemetry))
233 }
234}
235
236impl<B, B2, S> Service<Request<B>> for EndpointTelemetry<S>
237where
238 S: Service<Request<B>, Response = http::Response<B2>>,
239 B: Body,
240 B2: Body,
241{
242 type Response = S::Response;
243 type Error = S::Error;
244 type Future = EndpointTelemetryFuture<S::Future>;
245
246 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
247 self.service.poll_ready(cx)
248 }
249
250 fn call(&mut self, req: Request<B>) -> Self::Future {
251 let maybe_body_len = req.body().size_hint().exact();
252 let per_endpoint = self.get_telemetry_handle(&req);
253 let fut = self.service.call(req);
254
255 EndpointTelemetryFuture {
256 per_endpoint,
257 maybe_body_len,
258 fut,
259 }
260 }
261}
262
263pin_project! {
264 pub struct EndpointTelemetryFuture<F> {
266 per_endpoint: Option<Arc<PerEndpointTelemetry>>,
267 maybe_body_len: Option<u64>,
268
269 #[pin]
270 fut: F,
271 }
272}
273
274impl<F, B, E> Future for EndpointTelemetryFuture<F>
275where
276 F: Future<Output = Result<Response<B>, E>>,
277 B: Body,
278{
279 type Output = Result<Response<B>, E>;
280
281 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
282 let this = self.project();
283 match ready!(this.fut.poll(cx)) {
284 Ok(response) => {
285 if let Some(per_endpoint) = this.per_endpoint.as_ref() {
286 let status = response.status();
287 if status.is_client_error() || status.is_server_error() {
288 per_endpoint.increment_http_error(status);
290
291 let status_code = status.as_u16();
292 if status_code == 400 || status_code == 403 || status_code == 413 {
293 per_endpoint.increment_dropped()
296 }
297 } else {
298 per_endpoint.increment_success();
299 if let Some(body_len) = this.maybe_body_len {
300 per_endpoint.increment_success_bytes(*body_len);
301 }
302 }
303 }
304
305 Poll::Ready(Ok(response))
306 }
307 Err(e) => {
308 if let Some(per_endpoint) = this.per_endpoint.as_ref() {
309 per_endpoint.increment_error("send_failed");
310 }
311
312 Poll::Ready(Err(e))
313 }
314 }
315 }
316}
317
/// Returns `true` when `s` consists solely of ASCII lowercase letters, ASCII digits, and the
/// characters `_`, `-`, `/`, `\` and `.`.
///
/// An empty string is considered sanitized.
fn is_sanitized_endpoint_name(s: &str) -> bool {
    // Every permitted character is single-byte ASCII, so a byte-wise scan is equivalent to a
    // char-wise one: any multi-byte character contributes bytes >= 0x80, which are rejected.
    s.bytes()
        .all(|b| matches!(b, b'a'..=b'z' | b'0'..=b'9' | b'_' | b'-' | b'/' | b'\\' | b'.'))
}
323
324fn sanitize_endpoint_name(endpoint_name: MetaString) -> MetaString {
325 if is_sanitized_endpoint_name(&endpoint_name) {
327 return endpoint_name;
328 }
329
330 endpoint_name
331 .chars()
332 .map(|c| {
333 if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '/' || c == '\\' || c == '.' {
334 c.to_ascii_lowercase()
335 } else {
336 '_'
337 }
338 })
339 .collect::<String>()
340 .into()
341}
342
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        // Whatever string goes in, the sanitized result must pass the sanitization predicate.
        #[test]
        fn property_test_sanitize_endpoint_name(input in ".*") {
            let output = sanitize_endpoint_name(input.into());
            let output_is_clean = is_sanitized_endpoint_name(&output);
            prop_assert!(output_is_clean);
        }
    }
}