saluki_io/net/client/http/
telemetry.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{ready, Context, Poll},
7};
8
9use http::{uri::Authority, Request, Response, StatusCode, Uri};
10use http_body::Body;
11use metrics::Counter;
12use pin_project_lite::pin_project;
13use saluki_metrics::MetricsBuilder;
14use stringtheory::MetaString;
15use tower::{Layer, Service};
16
17pub type EndpointNameFn = dyn Fn(&Uri) -> Option<MetaString> + Send + Sync;
18
19const ERROR_TYPE_CLIENT: &str = "client_error";
20pub(super) const ERROR_TYPE_CONNECTION: &str = "connection_error";
21pub(super) const ERROR_TYPE_DNS: &str = "dns_error";
22pub(super) const ERROR_TYPE_TLS: &str = "tls_error";
23pub(super) const ERROR_TYPE_WROTE_REQUEST: &str = "wrote_request_error";
24const ERROR_TYPE_CANT_SEND: &str = "cant_send";
25const ERROR_TYPE_GT_400: &str = "gt_400";
26const ERROR_SCOPE_PHASE: &str = "phase";
27const ERROR_SCOPE_TRANSACTION: &str = "transaction";
28
29#[derive(Clone)]
31pub(crate) struct HttpTransactionErrorTelemetry {
32 dns_errors: Counter,
33 connection_errors: Counter,
34 tls_errors: Counter,
35 wrote_request_errors: Counter,
36 send_errors: Counter,
37 http_errors: Counter,
38}
39
40impl HttpTransactionErrorTelemetry {
41 pub(crate) fn from_builder(builder: &MetricsBuilder) -> Self {
43 Self {
46 dns_errors: register_scoped_error(builder, ERROR_TYPE_DNS, ERROR_SCOPE_PHASE),
47 connection_errors: register_scoped_error(builder, ERROR_TYPE_CONNECTION, ERROR_SCOPE_PHASE),
48 tls_errors: register_scoped_error(builder, ERROR_TYPE_TLS, ERROR_SCOPE_PHASE),
49 wrote_request_errors: register_scoped_error(builder, ERROR_TYPE_WROTE_REQUEST, ERROR_SCOPE_PHASE),
50 send_errors: register_scoped_error(builder, ERROR_TYPE_CANT_SEND, ERROR_SCOPE_TRANSACTION),
51 http_errors: register_scoped_error(builder, ERROR_TYPE_GT_400, ERROR_SCOPE_TRANSACTION),
52 }
53 }
54
55 pub(crate) fn dns_errors(&self) -> Counter {
56 self.dns_errors.clone()
57 }
58
59 pub(crate) fn increment_connection_error(&self) {
60 self.connection_errors.increment(1);
61 }
62
63 pub(crate) fn increment_tls_error(&self) {
64 self.tls_errors.increment(1);
65 }
66
67 pub(crate) fn increment_wrote_request_error(&self) {
68 self.wrote_request_errors.increment(1);
69 }
70
71 fn increment_send_error(&self) {
72 self.send_errors.increment(1);
73 }
74
75 fn increment_http_error(&self) {
76 self.http_errors.increment(1);
77 }
78}
79
80fn register_scoped_error(builder: &MetricsBuilder, error_type: &'static str, error_scope: &'static str) -> Counter {
81 builder.register_counter_with_tags(
82 "network_http_requests_errors_total",
83 [("error_type", error_type), ("error_scope", error_scope)],
84 )
85}
86
87#[derive(Clone, Default)]
124pub struct EndpointTelemetryLayer {
125 builder: MetricsBuilder,
126 endpoint_name_fn: Option<Arc<EndpointNameFn>>,
127 error_telemetry: Option<HttpTransactionErrorTelemetry>,
128}
129
130impl EndpointTelemetryLayer {
131 pub fn with_metrics_builder(mut self, builder: MetricsBuilder) -> Self {
136 self.builder = builder;
137 self
138 }
139
140 pub(super) fn with_error_telemetry(mut self, error_telemetry: HttpTransactionErrorTelemetry) -> Self {
141 self.error_telemetry = Some(error_telemetry);
142 self
143 }
144
145 pub fn with_endpoint_name_fn<F>(mut self, endpoint_name_fn: F) -> Self
155 where
156 F: Fn(&Uri) -> Option<MetaString> + Send + Sync + 'static,
157 {
158 self.endpoint_name_fn = Some(Arc::new(endpoint_name_fn));
159 self
160 }
161}
162
163impl<S> Layer<S> for EndpointTelemetryLayer {
164 type Service = EndpointTelemetry<S>;
165
166 fn layer(&self, service: S) -> Self::Service {
167 EndpointTelemetry {
168 service,
169 builder: self.builder.clone(),
170 endpoint_name_fn: self.endpoint_name_fn.clone(),
171 error_telemetry: self.error_telemetry.clone(),
172 domains: HashMap::new(),
173 endpoint_name_cache: HashMap::new(),
174 }
175 }
176}
177
178struct PerEndpointTelemetry {
179 builder: MetricsBuilder,
180 dropped: Counter,
181 success: Counter,
182 success_bytes: Counter,
183 http_errors_map: Mutex<HashMap<StatusCode, Counter>>,
184}
185
186impl PerEndpointTelemetry {
187 fn new(builder: MetricsBuilder, uri: &Uri, endpoint_name: &str) -> Self {
188 let mut domain = format!("{}://{}", uri.scheme_str().unwrap(), uri.host().unwrap());
190 if let Some(port) = uri.port() {
191 domain.push(':');
192 domain.push_str(port.as_str());
193 }
194
195 let builder = builder
196 .add_default_tag(("domain", domain))
197 .add_default_tag(("endpoint", endpoint_name.to_string()));
198
199 let dropped = builder.register_counter("network_http_requests_failed_total");
200 let success = builder.register_counter("network_http_requests_success_total");
201 let success_bytes = builder.register_counter("network_http_requests_success_sent_bytes_total");
202 let http_errors_map = Mutex::new(HashMap::new());
203
204 Self {
205 builder,
206 dropped,
207 success,
208 success_bytes,
209 http_errors_map,
210 }
211 }
212
213 fn increment_dropped(&self) {
214 self.dropped.increment(1);
215 }
216
217 fn increment_success(&self) {
218 self.success.increment(1);
219 }
220
221 fn increment_success_bytes(&self, len: u64) {
222 self.success_bytes.increment(len);
223 }
224
225 fn increment_http_error(&self, status: StatusCode) {
226 let mut http_errors_map = self.http_errors_map.lock().unwrap();
227 let counter = http_errors_map.entry(status).or_insert_with(move || {
228 self.builder.register_counter_with_tags(
229 "network_http_requests_errors_total",
230 [
231 ("error_type", ERROR_TYPE_CLIENT.to_string()),
232 ("code", status.as_str().to_string()),
233 ],
234 )
235 });
236 counter.increment(1);
237 }
238}
239
240#[derive(Clone)]
242pub struct EndpointTelemetry<S> {
243 service: S,
244 builder: MetricsBuilder,
245 endpoint_name_fn: Option<Arc<EndpointNameFn>>,
246 error_telemetry: Option<HttpTransactionErrorTelemetry>,
247 domains: HashMap<Authority, HashMap<MetaString, Arc<PerEndpointTelemetry>>>,
248 endpoint_name_cache: HashMap<Uri, MetaString>,
249}
250
251impl<S> EndpointTelemetry<S> {
252 fn get_telemetry_handle<B>(&mut self, req: &Request<B>) -> Option<Arc<PerEndpointTelemetry>>
253 where
254 B: Body,
255 {
256 if req.uri().scheme().is_none() || req.uri().host().is_none() {
258 return None;
259 }
260
261 let authority = req.uri().authority()?.clone();
263 let domain = self.domains.entry(authority).or_default();
264
265 let endpoint_telemetry = match self.endpoint_name_cache.get(req.uri()) {
270 Some(endpoint_name) => domain
271 .get(endpoint_name)
272 .expect("per-endpoint telemetry must exist if name is cached"),
273 None => {
274 let endpoint_name = self
276 .endpoint_name_fn
277 .as_ref()
278 .and_then(|f| f(req.uri()))
279 .map(sanitize_endpoint_name)
280 .unwrap_or_else(|| sanitize_endpoint_name(req.uri().path().into()));
281
282 self.endpoint_name_cache
283 .insert(req.uri().clone(), endpoint_name.clone());
284
285 domain.entry(endpoint_name).or_insert_with_key(|endpoint_name| {
287 Arc::new(PerEndpointTelemetry::new(
288 self.builder.clone(),
289 req.uri(),
290 endpoint_name,
291 ))
292 })
293 }
294 };
295
296 Some(Arc::clone(endpoint_telemetry))
297 }
298}
299
300impl<B, B2, S> Service<Request<B>> for EndpointTelemetry<S>
301where
302 S: Service<Request<B>, Response = http::Response<B2>>,
303 B: Body,
304 B2: Body,
305{
306 type Response = S::Response;
307 type Error = S::Error;
308 type Future = EndpointTelemetryFuture<S::Future>;
309
310 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
311 self.service.poll_ready(cx)
312 }
313
314 fn call(&mut self, req: Request<B>) -> Self::Future {
315 let maybe_body_len = req.body().size_hint().exact();
316 let per_endpoint = self.get_telemetry_handle(&req);
317 let fut = self.service.call(req);
318
319 EndpointTelemetryFuture {
320 per_endpoint,
321 error_telemetry: self.error_telemetry.clone(),
322 maybe_body_len,
323 fut,
324 }
325 }
326}
327
pin_project! {
    /// Response future for `EndpointTelemetry`.
    ///
    /// Wraps the inner service's future and records per-endpoint and transaction-level telemetry
    /// once it resolves.
    pub struct EndpointTelemetryFuture<F> {
        // Telemetry handle for the request's endpoint; `None` when the request URI had no scheme or
        // no host.
        per_endpoint: Option<Arc<PerEndpointTelemetry>>,
        // Transaction-level error counters, if the layer was configured with them.
        error_telemetry: Option<HttpTransactionErrorTelemetry>,
        // Request body length when the body's size hint was exact; added to the success byte
        // counter on a successful response.
        maybe_body_len: Option<u64>,

        // The wrapped service's response future.
        #[pin]
        fut: F,
    }
}
339
340impl<F, B, E> Future for EndpointTelemetryFuture<F>
341where
342 F: Future<Output = Result<Response<B>, E>>,
343 B: Body,
344{
345 type Output = Result<Response<B>, E>;
346
347 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
348 let this = self.project();
349 match ready!(this.fut.poll(cx)) {
350 Ok(response) => {
351 if let Some(per_endpoint) = this.per_endpoint.as_ref() {
352 let status = response.status();
353 if status.is_client_error() || status.is_server_error() {
354 per_endpoint.increment_http_error(status);
356
357 let status_code = status.as_u16();
358 if status_code == 400 || status_code == 403 || status_code == 413 {
359 per_endpoint.increment_dropped()
362 } else if let Some(error_telemetry) = this.error_telemetry.as_ref() {
363 error_telemetry.increment_http_error();
364 }
365 } else {
366 per_endpoint.increment_success();
367 if let Some(body_len) = this.maybe_body_len {
368 per_endpoint.increment_success_bytes(*body_len);
369 }
370 }
371 }
372
373 Poll::Ready(Ok(response))
374 }
375 Err(e) => {
376 if let Some(error_telemetry) = this.error_telemetry.as_ref() {
377 error_telemetry.increment_send_error();
378 }
379
380 Poll::Ready(Err(e))
381 }
382 }
383 }
384}
385
/// Returns `true` if every character of `s` is one that `sanitize_endpoint_name` leaves unchanged.
///
/// Those characters are lowercase ASCII letters, ASCII digits, `_`, `-`, `/`, `\` and `.`. An
/// empty string counts as sanitized.
fn is_sanitized_endpoint_name(s: &str) -> bool {
    s.chars()
        .all(|ch| matches!(ch, 'a'..='z' | '0'..='9' | '_' | '-' | '/' | '\\' | '.'))
}
391
392fn sanitize_endpoint_name(endpoint_name: MetaString) -> MetaString {
393 if is_sanitized_endpoint_name(&endpoint_name) {
395 return endpoint_name;
396 }
397
398 endpoint_name
399 .chars()
400 .map(|c| {
401 if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '/' || c == '\\' || c == '.' {
402 c.to_ascii_lowercase()
403 } else {
404 '_'
405 }
406 })
407 .collect::<String>()
408 .into()
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // For arbitrary input, the sanitizer's output must satisfy the sanitization predicate.
        #[test]
        fn property_test_sanitize_endpoint_name(raw in ".*") {
            let endpoint_name: MetaString = raw.into();
            let sanitized = sanitize_endpoint_name(endpoint_name);
            prop_assert!(is_sanitized_endpoint_name(&sanitized));
        }
    }
}