Skip to main content

saluki_io/net/client/http/
telemetry.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    pin::Pin,
5    sync::{Arc, Mutex},
6    task::{ready, Context, Poll},
7};
8
9use http::{uri::Authority, Request, Response, StatusCode, Uri};
10use http_body::Body;
11use metrics::Counter;
12use pin_project_lite::pin_project;
13use saluki_metrics::MetricsBuilder;
14use stringtheory::MetaString;
15use tower::{Layer, Service};
16
17pub type EndpointNameFn = dyn Fn(&Uri) -> Option<MetaString> + Send + Sync;
18
19const ERROR_TYPE_CLIENT: &str = "client_error";
20pub(super) const ERROR_TYPE_CONNECTION: &str = "connection_error";
21pub(super) const ERROR_TYPE_DNS: &str = "dns_error";
22pub(super) const ERROR_TYPE_TLS: &str = "tls_error";
23pub(super) const ERROR_TYPE_WROTE_REQUEST: &str = "wrote_request_error";
24const ERROR_TYPE_CANT_SEND: &str = "cant_send";
25const ERROR_TYPE_GT_400: &str = "gt_400";
26const ERROR_SCOPE_PHASE: &str = "phase";
27const ERROR_SCOPE_TRANSACTION: &str = "transaction";
28
29/// Emits lifecycle and transaction error telemetry for HTTP requests.
30#[derive(Clone)]
31pub(crate) struct HttpTransactionErrorTelemetry {
32    dns_errors: Counter,
33    connection_errors: Counter,
34    tls_errors: Counter,
35    wrote_request_errors: Counter,
36    send_errors: Counter,
37    http_errors: Counter,
38}
39
40impl HttpTransactionErrorTelemetry {
41    /// Creates a new `HttpTransactionErrorTelemetry` from a metrics builder.
42    pub(crate) fn from_builder(builder: &MetricsBuilder) -> Self {
43        // Mirror Core Agent forwarder buckets by counting lifecycle failures at their source. See
44        // datadog-agent/comp/forwarder/defaultforwarder/transaction/transaction.go::GetClientTrace.
45        Self {
46            dns_errors: register_scoped_error(builder, ERROR_TYPE_DNS, ERROR_SCOPE_PHASE),
47            connection_errors: register_scoped_error(builder, ERROR_TYPE_CONNECTION, ERROR_SCOPE_PHASE),
48            tls_errors: register_scoped_error(builder, ERROR_TYPE_TLS, ERROR_SCOPE_PHASE),
49            wrote_request_errors: register_scoped_error(builder, ERROR_TYPE_WROTE_REQUEST, ERROR_SCOPE_PHASE),
50            send_errors: register_scoped_error(builder, ERROR_TYPE_CANT_SEND, ERROR_SCOPE_TRANSACTION),
51            http_errors: register_scoped_error(builder, ERROR_TYPE_GT_400, ERROR_SCOPE_TRANSACTION),
52        }
53    }
54
55    pub(crate) fn dns_errors(&self) -> Counter {
56        self.dns_errors.clone()
57    }
58
59    pub(crate) fn increment_connection_error(&self) {
60        self.connection_errors.increment(1);
61    }
62
63    pub(crate) fn increment_tls_error(&self) {
64        self.tls_errors.increment(1);
65    }
66
67    pub(crate) fn increment_wrote_request_error(&self) {
68        self.wrote_request_errors.increment(1);
69    }
70
71    fn increment_send_error(&self) {
72        self.send_errors.increment(1);
73    }
74
75    fn increment_http_error(&self) {
76        self.http_errors.increment(1);
77    }
78}
79
80fn register_scoped_error(builder: &MetricsBuilder, error_type: &'static str, error_scope: &'static str) -> Counter {
81    builder.register_counter_with_tags(
82        "network_http_requests_errors_total",
83        [("error_type", error_type), ("error_scope", error_scope)],
84    )
85}
86
87/// Emit telemetry about the status of HTTP transactions.
88///
89/// This layer can be used with services that deal with `http::Request` and `http::Response`, and wraps them to provide
90/// telemetry about the status of an HTTP "transaction": a full round-trip of request and response.
91///
92/// ## Metrics
93///
94/// The following metrics are emitted:
95///
96/// - `network_http_requests_failed_total`: The total number of HTTP requests that failed with a status code of 400,
97///   403, or 413.
98/// - `network_http_requests_success_total`: The total number of successful HTTP requests. (any response with a
99///   non-4xx/5xx status code)
100/// - `network_http_requests_success_sent_bytes_total`: The total number of body bytes sent in successful HTTP requests.
101///   (see note below on how this is calculated)
102/// - `network_http_requests_errors_total`: The total number of HTTP requests that had an error, either during the
103///   sending of the request or in the response. This is further broken down by the `error_type` label.
104///   - For all responses with a status code greater than 400, `error_type` will be `client_error` and `code` will be
105///     the string version of the status code.
106///   - When there is an error during the sending of the request, `error_type` classifies the request failure.
107///
108/// All metrics are emitted with two base tags:
109///
110/// - `domain`: The full domain of the request, including scheme and port, but excluding any credentials.
111/// - `endpoint`: The endpoint name, which is derived from the URI path by default but can be customized. (See
112///   [`EndpointTelemetryLayer::with_endpoint_name_fn`] for information on customization and how the endpoint name,
113///   overall, is sanitized.)
114///
115/// ### Success bytes calculation
116///
117/// We calculate the number of bytes sent by examining the body length itself, which is done via [`Body::size_hint`].
118/// This requires that an exact body size is known, which isn't always the case. If the body size isn't known, this
119/// metric won't be emitted on a successful response.
120///
121/// For common body types, like [`FrozenChunkedBytesBuffer`][saluki_common::buf::FrozenChunkedBytesBuffer], the size
122/// hint is always exact and so this functionality should work as intended.
123#[derive(Clone, Default)]
124pub struct EndpointTelemetryLayer {
125    builder: MetricsBuilder,
126    endpoint_name_fn: Option<Arc<EndpointNameFn>>,
127    error_telemetry: Option<HttpTransactionErrorTelemetry>,
128}
129
130impl EndpointTelemetryLayer {
131    /// Create a new `EndpointTelemetryLayer` with the given `ComponentContext`.
132    ///
133    /// The component context is used when creating metrics, which ensures they're tagged in a consistent way that
134    /// attributes the metrics to the component issuing the HTTP requests.
135    pub fn with_metrics_builder(mut self, builder: MetricsBuilder) -> Self {
136        self.builder = builder;
137        self
138    }
139
140    pub(super) fn with_error_telemetry(mut self, error_telemetry: HttpTransactionErrorTelemetry) -> Self {
141        self.error_telemetry = Some(error_telemetry);
142        self
143    }
144
145    /// Sets the function used to extract the "endpoint name" from a URI.
146    ///
147    /// The value returned by this function will be sanitized to ensure it can be used as a tag value, and is limited
148    /// to: ASCII alphanumerics, hyphens, underscores, slashes, and periods. Any non-conforming character will be
149    /// replaced with an underscore. Characters will be converted to lowercase.
150    ///
151    /// The value returned by this function is also cached for the given URI, and so the function shouldn't rely on
152    /// non-deterministic behavior, or state, that could change the generated endpoint name for subsequent calls with
153    /// the same input URI.
154    pub fn with_endpoint_name_fn<F>(mut self, endpoint_name_fn: F) -> Self
155    where
156        F: Fn(&Uri) -> Option<MetaString> + Send + Sync + 'static,
157    {
158        self.endpoint_name_fn = Some(Arc::new(endpoint_name_fn));
159        self
160    }
161}
162
163impl<S> Layer<S> for EndpointTelemetryLayer {
164    type Service = EndpointTelemetry<S>;
165
166    fn layer(&self, service: S) -> Self::Service {
167        EndpointTelemetry {
168            service,
169            builder: self.builder.clone(),
170            endpoint_name_fn: self.endpoint_name_fn.clone(),
171            error_telemetry: self.error_telemetry.clone(),
172            domains: HashMap::new(),
173            endpoint_name_cache: HashMap::new(),
174        }
175    }
176}
177
178struct PerEndpointTelemetry {
179    builder: MetricsBuilder,
180    dropped: Counter,
181    success: Counter,
182    success_bytes: Counter,
183    http_errors_map: Mutex<HashMap<StatusCode, Counter>>,
184}
185
186impl PerEndpointTelemetry {
187    fn new(builder: MetricsBuilder, uri: &Uri, endpoint_name: &str) -> Self {
188        // Reconstruct the full domain from the URI, including scheme and port, but leaving out any credentials.
189        let mut domain = format!("{}://{}", uri.scheme_str().unwrap(), uri.host().unwrap());
190        if let Some(port) = uri.port() {
191            domain.push(':');
192            domain.push_str(port.as_str());
193        }
194
195        let builder = builder
196            .add_default_tag(("domain", domain))
197            .add_default_tag(("endpoint", endpoint_name.to_string()));
198
199        let dropped = builder.register_counter("network_http_requests_failed_total");
200        let success = builder.register_counter("network_http_requests_success_total");
201        let success_bytes = builder.register_counter("network_http_requests_success_sent_bytes_total");
202        let http_errors_map = Mutex::new(HashMap::new());
203
204        Self {
205            builder,
206            dropped,
207            success,
208            success_bytes,
209            http_errors_map,
210        }
211    }
212
213    fn increment_dropped(&self) {
214        self.dropped.increment(1);
215    }
216
217    fn increment_success(&self) {
218        self.success.increment(1);
219    }
220
221    fn increment_success_bytes(&self, len: u64) {
222        self.success_bytes.increment(len);
223    }
224
225    fn increment_http_error(&self, status: StatusCode) {
226        let mut http_errors_map = self.http_errors_map.lock().unwrap();
227        let counter = http_errors_map.entry(status).or_insert_with(move || {
228            self.builder.register_counter_with_tags(
229                "network_http_requests_errors_total",
230                [
231                    ("error_type", ERROR_TYPE_CLIENT.to_string()),
232                    ("code", status.as_str().to_string()),
233                ],
234            )
235        });
236        counter.increment(1);
237    }
238}
239
240/// Emit telemetry about the status of HTTP transactions.
241#[derive(Clone)]
242pub struct EndpointTelemetry<S> {
243    service: S,
244    builder: MetricsBuilder,
245    endpoint_name_fn: Option<Arc<EndpointNameFn>>,
246    error_telemetry: Option<HttpTransactionErrorTelemetry>,
247    domains: HashMap<Authority, HashMap<MetaString, Arc<PerEndpointTelemetry>>>,
248    endpoint_name_cache: HashMap<Uri, MetaString>,
249}
250
251impl<S> EndpointTelemetry<S> {
252    fn get_telemetry_handle<B>(&mut self, req: &Request<B>) -> Option<Arc<PerEndpointTelemetry>>
253    where
254        B: Body,
255    {
256        // We require a scheme and a host in the URI to emit telemetry.
257        if req.uri().scheme().is_none() || req.uri().host().is_none() {
258            return None;
259        }
260
261        // `Authority` is underpinned by `Bytes` so cloning is cheap.
262        let authority = req.uri().authority()?.clone();
263        let domain = self.domains.entry(authority).or_default();
264
265        // Look up the per-endpoint telemetry handle, or create a new one if it doesn't exist.
266        //
267        // We do some caching of the endpoint name to avoid repeatedly calling the endpoint name function, which could
268        // be expensive due to the sanitization we perform on it.
269        let endpoint_telemetry = match self.endpoint_name_cache.get(req.uri()) {
270            Some(endpoint_name) => domain
271                .get(endpoint_name)
272                .expect("per-endpoint telemetry must exist if name is cached"),
273            None => {
274                // Generate our endpoint name, and then cache it.
275                let endpoint_name = self
276                    .endpoint_name_fn
277                    .as_ref()
278                    .and_then(|f| f(req.uri()))
279                    .map(sanitize_endpoint_name)
280                    .unwrap_or_else(|| sanitize_endpoint_name(req.uri().path().into()));
281
282                self.endpoint_name_cache
283                    .insert(req.uri().clone(), endpoint_name.clone());
284
285                // Now we'll create the per-endpoint telemetry.
286                domain.entry(endpoint_name).or_insert_with_key(|endpoint_name| {
287                    Arc::new(PerEndpointTelemetry::new(
288                        self.builder.clone(),
289                        req.uri(),
290                        endpoint_name,
291                    ))
292                })
293            }
294        };
295
296        Some(Arc::clone(endpoint_telemetry))
297    }
298}
299
300impl<B, B2, S> Service<Request<B>> for EndpointTelemetry<S>
301where
302    S: Service<Request<B>, Response = http::Response<B2>>,
303    B: Body,
304    B2: Body,
305{
306    type Response = S::Response;
307    type Error = S::Error;
308    type Future = EndpointTelemetryFuture<S::Future>;
309
310    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
311        self.service.poll_ready(cx)
312    }
313
314    fn call(&mut self, req: Request<B>) -> Self::Future {
315        let maybe_body_len = req.body().size_hint().exact();
316        let per_endpoint = self.get_telemetry_handle(&req);
317        let fut = self.service.call(req);
318
319        EndpointTelemetryFuture {
320            per_endpoint,
321            error_telemetry: self.error_telemetry.clone(),
322            maybe_body_len,
323            fut,
324        }
325    }
326}
327
pin_project! {
    /// Response future from [`EndpointTelemetry`] services.
    ///
    /// Records per-endpoint and transaction error telemetry once the wrapped future resolves.
    pub struct EndpointTelemetryFuture<F> {
        // Per-endpoint telemetry handle, or `None` if the request URI lacked a scheme or a host.
        per_endpoint: Option<Arc<PerEndpointTelemetry>>,
        // Transaction-scope error counters, if configured on the layer.
        error_telemetry: Option<HttpTransactionErrorTelemetry>,
        // Exact request body length captured before the request was sent, if the size hint was exact.
        maybe_body_len: Option<u64>,

        // The wrapped service's response future.
        #[pin]
        fut: F,
    }
}
339
340impl<F, B, E> Future for EndpointTelemetryFuture<F>
341where
342    F: Future<Output = Result<Response<B>, E>>,
343    B: Body,
344{
345    type Output = Result<Response<B>, E>;
346
347    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
348        let this = self.project();
349        match ready!(this.fut.poll(cx)) {
350            Ok(response) => {
351                if let Some(per_endpoint) = this.per_endpoint.as_ref() {
352                    let status = response.status();
353                    if status.is_client_error() || status.is_server_error() {
354                        // Always increment the HTTP error total by grouped over the actual status code.
355                        per_endpoint.increment_http_error(status);
356
357                        let status_code = status.as_u16();
358                        if status_code == 400 || status_code == 403 || status_code == 413 {
359                            // There's some specific errors where we're not going to retry them, so we can reasonable
360                            // classify these requests as being dropped: they won't be retried, etc.
361                            per_endpoint.increment_dropped()
362                        } else if let Some(error_telemetry) = this.error_telemetry.as_ref() {
363                            error_telemetry.increment_http_error();
364                        }
365                    } else {
366                        per_endpoint.increment_success();
367                        if let Some(body_len) = this.maybe_body_len {
368                            per_endpoint.increment_success_bytes(*body_len);
369                        }
370                    }
371                }
372
373                Poll::Ready(Ok(response))
374            }
375            Err(e) => {
376                if let Some(error_telemetry) = this.error_telemetry.as_ref() {
377                    error_telemetry.increment_send_error();
378                }
379
380                Poll::Ready(Err(e))
381            }
382        }
383    }
384}
385
/// Returns `true` if every character of `s` is already in the sanitized endpoint-name alphabet: lowercase ASCII
/// letters, ASCII digits, `_`, `-`, `/`, `\`, and `.`.
fn is_sanitized_endpoint_name(s: &str) -> bool {
    s.chars()
        .all(|ch| matches!(ch, 'a'..='z' | '0'..='9' | '_' | '-' | '/' | '\\' | '.'))
}
391
392fn sanitize_endpoint_name(endpoint_name: MetaString) -> MetaString {
393    // Check if the endpoint name is already sanitized, and if so, just return it as-is.
394    if is_sanitized_endpoint_name(&endpoint_name) {
395        return endpoint_name;
396    }
397
398    endpoint_name
399        .chars()
400        .map(|c| {
401            if c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '/' || c == '\\' || c == '.' {
402                c.to_ascii_lowercase()
403            } else {
404                '_'
405            }
406        })
407        .collect::<String>()
408        .into()
409}
410
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// Sanitizes `input` and returns the result as an owned `String` for easy comparison.
    fn sanitize_str(input: &str) -> String {
        let output = sanitize_endpoint_name(input.into());
        let output: &str = &output;
        output.to_string()
    }

    #[test]
    fn sanitize_endpoint_name_examples() {
        // Already-sanitized names pass through unchanged.
        assert_eq!(sanitize_str("/api/v2/series"), "/api/v2/series");
        assert_eq!(sanitize_str(""), "");
        assert_eq!(sanitize_str("a-b_c.d\\e"), "a-b_c.d\\e");

        // ASCII letters are lowercased; everything outside the allowed set becomes an underscore.
        assert_eq!(sanitize_str("/API/V1/Check Run"), "/api/v1/check_run");
        assert_eq!(sanitize_str("/intake/?x=1"), "/intake/_x_1");

        // Non-ASCII characters are replaced one-for-one.
        assert_eq!(sanitize_str("/caf\u{e9}"), "/caf_");
    }

    proptest! {
        #[test]
        fn property_test_sanitize_endpoint_name(input in ".*") {
            let sanitized = sanitize_endpoint_name(input.into());
            prop_assert!(is_sanitized_endpoint_name(&sanitized));
        }

        #[test]
        fn property_test_sanitize_endpoint_name_idempotent_and_length_preserving(input in ".*") {
            // Each input character maps to exactly one output character.
            let input_chars = input.chars().count();
            let once = sanitize_str(&input);
            prop_assert_eq!(once.chars().count(), input_chars);

            // Sanitizing an already-sanitized name must not change it.
            prop_assert_eq!(sanitize_str(&once), once);
        }
    }
}