1use std::{
8 convert::Infallible,
9 num::NonZeroUsize,
10 sync::{Arc, LazyLock},
11};
12
13use async_trait::async_trait;
14use ddsketch::DDSketch;
15use http::{Request, Response, StatusCode};
16use hyper::{body::Incoming, service::service_fn};
17use prometheus_exposition::{MetricType, PrometheusRenderer};
18use resource_accounting::{MemoryBounds, MemoryBoundsBuilder};
19use saluki_common::{collections::FastIndexMap, iter::ReusableDeduplicator};
20use saluki_context::{tags::Tag, Context};
21use saluki_core::components::{destinations::*, ComponentContext};
22use saluki_core::data_model::event::{
23 metric::{Histogram, Metric, MetricValues},
24 EventType,
25};
26use saluki_error::GenericError;
27use saluki_io::net::{
28 listener::ConnectionOrientedListener,
29 server::http::{ErrorHandle, HttpServer, ShutdownHandle},
30 ListenAddress,
31};
32use serde::Deserialize;
33use stringtheory::{
34 interning::{FixedSizeInterner, Interner as _},
35 MetaString,
36};
37use tokio::{select, sync::RwLock};
38use tracing::debug;
39
/// Maximum number of distinct contexts (series) tracked at once; metrics for new contexts beyond this are skipped.
const CONTEXT_LIMIT: usize = 10_000;
/// Once the rendered output exceeds this many bytes after writing a metric group, the remaining groups are skipped.
/// The check runs after each group, so the payload can overshoot by up to one group.
const PAYLOAD_SIZE_LIMIT_BYTES: usize = 1024 * 1024;
/// Per-series budget for rendered labels (`name="value",`); `collect_tags` returns `None` for a series over budget.
const TAGS_BUFFER_SIZE_LIMIT_BYTES: usize = 2048;
/// Path serving the aggregated payload.
const RAW_METRICS_PATH: &str = "/metrics";
/// Legacy root path; serves the same payload as `RAW_METRICS_PATH`.
const LEGACY_RAW_METRICS_PATH: &str = "/";

/// Number of buckets used for histograms whose metric name ends in `_seconds`.
const TIME_HISTOGRAM_BUCKET_COUNT: usize = 30;
/// Log-linear bucket bounds for time histograms: powers of 4 starting at 0.000000128 (128ns), each followed by the
/// midpoint to the next power. Built once; the rendered labels are leaked to get `'static` strings.
static TIME_HISTOGRAM_BUCKETS: LazyLock<[(f64, &'static str); TIME_HISTOGRAM_BUCKET_COUNT]> =
    LazyLock::new(|| histogram_buckets::<TIME_HISTOGRAM_BUCKET_COUNT>(0.000000128, 4.0));

/// Number of buckets used for all other histograms.
const NON_TIME_HISTOGRAM_BUCKET_COUNT: usize = 30;
/// Log-linear bucket bounds for non-time histograms: powers of 2 starting at 1, each followed by the midpoint to the
/// next power.
static NON_TIME_HISTOGRAM_BUCKETS: LazyLock<[(f64, &'static str); NON_TIME_HISTOGRAM_BUCKET_COUNT]> =
    LazyLock::new(|| histogram_buckets::<NON_TIME_HISTOGRAM_BUCKET_COUNT>(1.0, 2.0));

/// Capacity, in bytes, of the interner holding normalized metric names. If interning a name fails (presumably once
/// this capacity is exhausted — confirm against `FixedSizeInterner`), the metric is skipped.
const METRIC_NAME_STRING_INTERNER_BYTES: NonZeroUsize = NonZeroUsize::new(65536).unwrap();
57
/// Produces the response body for an additional route served by the Prometheus scrape endpoint.
///
/// `render_payload` is invoked once per request to the associated route, so it should be cheap enough to call on
/// every scrape.
pub trait PrometheusPayloadProvider: Send + Sync {
    /// Renders the full response body for one request to the associated route.
    fn render_payload(&self) -> String;
}

/// Any thread-safe closure returning a `String` can be used directly as a payload provider.
impl<F> PrometheusPayloadProvider for F
where
    F: Fn() -> String + Send + Sync,
{
    fn render_payload(&self) -> String {
        self()
    }
}
72
/// An extra HTTP route served by the scrape server alongside the main metrics payload.
#[derive(Clone)]
struct PrometheusAdditionalRoute {
    /// Exact request path this route answers (compared with `==`; no prefix or pattern matching).
    path: String,
    /// Produces the response body for requests to `path`.
    provider: Arc<dyn PrometheusPayloadProvider>,
}
78
/// Configuration for the Prometheus scrape destination.
#[derive(Deserialize)]
pub struct PrometheusConfiguration {
    /// Address the scrape HTTP server listens on (configuration key `prometheus_listen_addr`).
    #[serde(rename = "prometheus_listen_addr")]
    listen_addr: ListenAddress,

    /// Extra routes registered programmatically via `with_additional_route`; never read from configuration.
    #[serde(skip)]
    additional_routes: Vec<PrometheusAdditionalRoute>,
}
104
105impl PrometheusConfiguration {
106 pub fn from_listen_address(listen_addr: ListenAddress) -> Self {
108 Self {
109 listen_addr,
110 additional_routes: Vec::new(),
111 }
112 }
113
114 pub fn with_additional_route(
116 mut self, path: impl Into<String>, provider: Arc<dyn PrometheusPayloadProvider>,
117 ) -> Self {
118 self.additional_routes.push(PrometheusAdditionalRoute {
119 path: path.into(),
120 provider,
121 });
122 self
123 }
124}
125
126#[async_trait]
127impl DestinationBuilder for PrometheusConfiguration {
128 fn input_event_type(&self) -> EventType {
129 EventType::Metric
130 }
131
132 async fn build(&self, _context: ComponentContext) -> Result<Box<dyn Destination + Send>, GenericError> {
133 Ok(Box::new(Prometheus {
134 listener: ConnectionOrientedListener::from_listen_address(self.listen_addr.clone()).await?,
135 additional_routes: self.additional_routes.clone(),
136 metrics: FastIndexMap::default(),
137 payload: Arc::new(RwLock::new(String::new())),
138 renderer: PrometheusRenderer::new(),
139 interner: FixedSizeInterner::new(METRIC_NAME_STRING_INTERNER_BYTES),
140 }))
141 }
142}
143
144impl MemoryBounds for PrometheusConfiguration {
145 fn specify_bounds(&self, builder: &mut MemoryBoundsBuilder) {
146 builder
147 .minimum()
148 .with_single_value::<Prometheus>("component struct");
150
151 builder
152 .firm()
153 .with_map::<Context, PrometheusValue>("state map", CONTEXT_LIMIT)
157 .with_fixed_amount("payload size", PAYLOAD_SIZE_LIMIT_BYTES)
158 .with_fixed_amount("tags buffer", TAGS_BUFFER_SIZE_LIMIT_BYTES);
159 }
160}
161
/// Running state of the Prometheus destination.
struct Prometheus {
    /// Listener handed to the scrape HTTP server when the destination starts.
    listener: ConnectionOrientedListener,
    /// Extra routes served by the scrape HTTP server.
    additional_routes: Vec<PrometheusAdditionalRoute>,
    /// Aggregated series, grouped by (normalized name, Prometheus type) and then by original context.
    metrics: FastIndexMap<PrometheusContext, FastIndexMap<Context, PrometheusValue>>,
    /// Most recently rendered payload, shared with the HTTP server.
    payload: Arc<RwLock<String>>,
    /// Used both to normalize metric names and to render the payload.
    renderer: PrometheusRenderer,
    /// Interner for normalized metric names.
    interner: FixedSizeInterner<1>,
}
170
171#[async_trait]
172impl Destination for Prometheus {
173 async fn run(mut self: Box<Self>, mut context: DestinationContext) -> Result<(), GenericError> {
174 let Self {
175 listener,
176 additional_routes,
177 mut metrics,
178 payload,
179 mut renderer,
180 interner,
181 } = *self;
182
183 let mut health = context.take_health_handle();
184
185 let (http_shutdown, mut http_error) =
186 spawn_prom_scrape_service(listener, Arc::clone(&payload), additional_routes);
187 health.mark_ready();
188
189 debug!("Prometheus destination started.");
190
191 let mut contexts = 0;
192 let mut tags_deduplicator = ReusableDeduplicator::new();
193
194 loop {
195 select! {
196 _ = health.live() => continue,
197 maybe_events = context.events().next() => match maybe_events {
198 Some(events) => {
199 for event in events {
202 if let Some(metric) = event.try_into_metric() {
203 let prom_context = match into_prometheus_metric(&metric, &mut renderer, &interner) {
207 Some(prom_context) => prom_context,
208 None => continue,
209 };
210
211 let (context, values, _) = metric.into_parts();
212
213 let existing_contexts = metrics.entry(prom_context.clone()).or_default();
215 match existing_contexts.get_mut(&context) {
216 Some(existing_prom_value) => merge_metric_values_with_prom_value(values, existing_prom_value),
217 None => {
218 if contexts >= CONTEXT_LIMIT {
219 debug!("Prometheus destination reached context limit. Skipping metric '{}'.", context.name());
220 continue
221 }
222
223 let mut new_prom_value = get_prom_value_for_prom_context(&prom_context);
224 merge_metric_values_with_prom_value(values, &mut new_prom_value);
225
226 existing_contexts.insert(context, new_prom_value);
227 contexts += 1;
228 }
229 }
230 }
231 }
232
233 regenerate_payload(&metrics, &payload, &mut renderer, &mut tags_deduplicator).await;
235 },
236 None => break,
237 },
238 error = &mut http_error => {
239 if let Some(error) = error {
240 debug!(%error, "HTTP server error.");
241 }
242 break;
243 },
244 }
245 }
246
247 http_shutdown.shutdown();
250
251 debug!("Prometheus destination stopped.");
252
253 Ok(())
254 }
255}
256
257fn spawn_prom_scrape_service(
258 listener: ConnectionOrientedListener, payload: Arc<RwLock<String>>,
259 additional_routes: Vec<PrometheusAdditionalRoute>,
260) -> (ShutdownHandle, ErrorHandle) {
261 let additional_routes = Arc::new(additional_routes);
262 let service = service_fn(move |req: Request<Incoming>| {
263 let payload = Arc::clone(&payload);
264 let additional_routes = Arc::clone(&additional_routes);
265 async move {
266 Ok::<_, Infallible>(build_scrape_response(req.uri().path(), &payload, additional_routes.as_ref()).await)
267 }
268 });
269
270 let http_server = HttpServer::from_listener(listener, service);
271 http_server.listen()
272}
273
274async fn build_scrape_response(
275 path: &str, payload: &Arc<RwLock<String>>, additional_routes: &[PrometheusAdditionalRoute],
276) -> Response<axum::body::Body> {
277 if path == RAW_METRICS_PATH || path == LEGACY_RAW_METRICS_PATH {
278 let payload = payload.read().await;
279 return Response::new(axum::body::Body::from(payload.to_string()));
280 }
281
282 if let Some(route) = additional_routes.iter().find(|route| route.path == path) {
283 return Response::new(axum::body::Body::from(route.provider.render_payload()));
284 }
285
286 Response::builder()
287 .status(StatusCode::NOT_FOUND)
288 .body(axum::body::Body::empty())
289 .expect("response builder should accept static status and empty body")
290}
291
292#[allow(clippy::mutable_key_type)]
293async fn regenerate_payload(
294 metrics: &FastIndexMap<PrometheusContext, FastIndexMap<Context, PrometheusValue>>, payload: &Arc<RwLock<String>>,
295 renderer: &mut PrometheusRenderer, tags_deduplicator: &mut ReusableDeduplicator<Tag>,
296) {
297 renderer.clear();
298
299 for (prom_context, contexts) in metrics {
300 if !write_metrics(renderer, prom_context, contexts, tags_deduplicator) {
301 debug!("Failed to write metric to payload. Continuing...");
302 continue;
303 }
304
305 if renderer.output().len() > PAYLOAD_SIZE_LIMIT_BYTES {
306 debug!(
307 payload_len = renderer.output().len(),
308 "Payload size limit exceeded. Skipping remaining metrics."
309 );
310 break;
311 }
312 }
313
314 let mut payload = payload.write().await;
315 payload.clear();
316 payload.push_str(renderer.output());
317}
318
319fn write_metrics(
320 renderer: &mut PrometheusRenderer, prom_context: &PrometheusContext,
321 contexts: &FastIndexMap<Context, PrometheusValue>, tags_deduplicator: &mut ReusableDeduplicator<Tag>,
322) -> bool {
323 if contexts.is_empty() {
324 debug!("No contexts for metric '{}'. Skipping.", prom_context.metric_name);
325 return true;
326 }
327
328 renderer.begin_group(&prom_context.metric_name, prom_context.metric_type, None);
329
330 for (context, values) in contexts {
331 let labels = match collect_tags(context, tags_deduplicator) {
332 Some(labels) => labels,
333 None => return false,
334 };
335
336 match values {
337 PrometheusValue::Counter(value) | PrometheusValue::Gauge(value) => {
338 renderer.write_gauge_or_counter_series(labels, *value);
339 }
340 PrometheusValue::Histogram(histogram) => {
341 renderer.write_histogram_series(labels, histogram.buckets(), histogram.sum, histogram.count);
342 }
343 PrometheusValue::Summary(sketch) => {
344 let quantiles = [0.1, 0.25, 0.5, 0.95, 0.99, 0.999]
345 .into_iter()
346 .map(|q| (q, sketch.quantile(q).unwrap_or_default()));
347
348 renderer.write_summary_series(labels, quantiles, sketch.sum().unwrap_or_default(), sketch.count());
349 }
350 }
351 }
352
353 renderer.finish_group();
354 true
355}
356
357fn collect_tags<'a>(
359 context: &'a Context, tags_deduplicator: &mut ReusableDeduplicator<Tag>,
360) -> Option<Vec<(&'a str, &'a str)>> {
361 let mut labels = Vec::new();
362 let mut total_bytes = 0;
363
364 let chained_tags = context.tags().into_iter().chain(context.origin_tags());
365 let deduplicated_tags = tags_deduplicator.deduplicated(chained_tags);
366
367 for tag in deduplicated_tags {
368 let tag_name = tag.name();
369 let tag_value = match tag.value() {
370 Some(value) => value,
371 None => {
372 debug!("Skipping bare tag.");
373 continue;
374 }
375 };
376
377 total_bytes += tag_name.len() + tag_value.len() + 4;
380 if total_bytes > TAGS_BUFFER_SIZE_LIMIT_BYTES {
381 debug!("Tags buffer size limit exceeded. Tags may be missing from this metric.");
382 return None;
383 }
384
385 labels.push((tag_name, tag_value));
386 }
387
388 Some(labels)
389}
390
/// Key for one Prometheus metric group: the normalized metric name plus the type it is rendered as.
///
/// Field order matters: it determines the derived `PartialOrd`/`Ord`.
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct PrometheusContext {
    /// Metric name after `PrometheusRenderer::normalize_metric_name`, interned.
    metric_name: MetaString,
    /// Prometheus type the metric's values are rendered as.
    metric_type: MetricType,
}
396
/// Aggregated state for a single series, one variant per supported Prometheus type.
enum PrometheusValue {
    /// Running sum of all counter values received.
    Counter(f64),
    /// Value of the latest-timestamped point from the most recent gauge or set update.
    Gauge(f64),
    /// Cumulative bucketed histogram built from histogram samples.
    Histogram(PrometheusHistogram),
    /// Sketch of merged distribution values, rendered as a summary.
    Summary(DDSketch),
}
403
404fn into_prometheus_metric(
405 metric: &Metric, renderer: &mut PrometheusRenderer, interner: &FixedSizeInterner<1>,
406) -> Option<PrometheusContext> {
407 let normalized = renderer.normalize_metric_name(metric.context().name());
409 let metric_name = match interner.try_intern(normalized).map(MetaString::from) {
410 Some(name) => name,
411 None => {
412 debug!(
413 "Failed to intern normalized metric name. Skipping metric '{}'.",
414 metric.context().name()
415 );
416 return None;
417 }
418 };
419
420 let metric_type = match metric.values() {
421 MetricValues::Counter(_) => MetricType::Counter,
422 MetricValues::Gauge(_) | MetricValues::Set(_) => MetricType::Gauge,
423 MetricValues::Histogram(_) => MetricType::Histogram,
424 MetricValues::Distribution(_) => MetricType::Summary,
425 _ => return None,
426 };
427
428 Some(PrometheusContext {
429 metric_name,
430 metric_type,
431 })
432}
433
434fn get_prom_value_for_prom_context(prom_context: &PrometheusContext) -> PrometheusValue {
435 match prom_context.metric_type {
436 MetricType::Counter => PrometheusValue::Counter(0.0),
437 MetricType::Gauge => PrometheusValue::Gauge(0.0),
438 MetricType::Histogram => PrometheusValue::Histogram(PrometheusHistogram::new(&prom_context.metric_name)),
439 MetricType::Summary => PrometheusValue::Summary(DDSketch::default()),
440 }
441}
442
443fn merge_metric_values_with_prom_value(values: MetricValues, prom_value: &mut PrometheusValue) {
444 match (values, prom_value) {
445 (MetricValues::Counter(counter_values), PrometheusValue::Counter(prom_counter)) => {
446 for (_, value) in counter_values {
447 *prom_counter += value;
448 }
449 }
450 (MetricValues::Gauge(gauge_values), PrometheusValue::Gauge(prom_gauge)) => {
451 let latest_value = gauge_values
452 .into_iter()
453 .max_by_key(|(ts, _)| ts.map(|v| v.get()).unwrap_or_default())
454 .map(|(_, value)| value)
455 .unwrap_or_default();
456 *prom_gauge = latest_value;
457 }
458 (MetricValues::Set(set_values), PrometheusValue::Gauge(prom_gauge)) => {
459 let latest_value = set_values
460 .into_iter()
461 .max_by_key(|(ts, _)| ts.map(|v| v.get()).unwrap_or_default())
462 .map(|(_, value)| value)
463 .unwrap_or_default();
464 *prom_gauge = latest_value;
465 }
466 (MetricValues::Histogram(histogram_values), PrometheusValue::Histogram(prom_histogram)) => {
467 for (_, value) in histogram_values {
468 prom_histogram.merge_histogram(&value);
469 }
470 }
471 (MetricValues::Distribution(distribution_values), PrometheusValue::Summary(prom_summary)) => {
472 for (_, value) in distribution_values {
473 prom_summary.merge(&value);
474 }
475 }
476 _ => panic!("Mismatched metric types"),
477 }
478}
479
/// Prometheus-style cumulative histogram.
#[derive(Clone)]
struct PrometheusHistogram {
    /// Weighted sum of all sample values (`value * weight`).
    sum: f64,
    /// Total weight of all samples.
    count: u64,
    /// `(upper bound, rendered upper bound, count)` per bucket. Counts are cumulative: each bucket includes every
    /// sample whose value is `<=` its upper bound.
    buckets: Vec<(f64, &'static str, u64)>,
}
486
487impl PrometheusHistogram {
488 fn new(metric_name: &str) -> Self {
489 let base_buckets = if metric_name.ends_with("_seconds") {
491 &TIME_HISTOGRAM_BUCKETS[..]
492 } else {
493 &NON_TIME_HISTOGRAM_BUCKETS[..]
494 };
495
496 let buckets = base_buckets
497 .iter()
498 .map(|(upper_bound, upper_bound_str)| (*upper_bound, *upper_bound_str, 0))
499 .collect();
500
501 Self {
502 sum: 0.0,
503 count: 0,
504 buckets,
505 }
506 }
507
508 fn merge_histogram(&mut self, histogram: &Histogram) {
509 for sample in histogram.samples() {
510 self.add_sample(sample.value.into_inner(), sample.weight.0 as u64);
511 }
512 }
513
514 fn add_sample(&mut self, value: f64, weight: u64) {
515 self.sum += value * weight as f64;
516 self.count += weight;
517
518 for (upper_bound, _, count) in &mut self.buckets {
520 if value <= *upper_bound {
521 *count += weight;
522 }
523 }
524 }
525
526 fn buckets(&self) -> impl Iterator<Item = (&'static str, u64)> + '_ {
527 self.buckets
528 .iter()
529 .map(|(_, upper_bound_str, count)| (*upper_bound_str, *count))
530 }
531}
532
/// Generates `N` log-linear bucket upper bounds with their rendered labels.
///
/// The sequence is `base * scale^k` for `k = 0, 1, 2, …`, each followed by the midpoint to the next power. Every
/// label is the bound's `Display` form, leaked to obtain a `'static` string; call this only from one-time
/// initialization.
fn histogram_buckets<const N: usize>(base: f64, scale: f64) -> [(f64, &'static str); N] {
    let boundaries = (0usize..).flat_map(move |step| {
        let lower = base * scale.powf(step as f64);
        let upper = base * scale.powf((step + 1) as f64);
        [lower, (lower + upper) / 2.0]
    });

    let mut buckets = [(0.0, ""); N];
    for (slot, boundary) in buckets.iter_mut().zip(boundaries) {
        let label: &'static str = boundary.to_string().leak();
        *slot = (boundary, label);
    }

    buckets
}
564
#[cfg(test)]
mod tests {
    use http_body_util::BodyExt as _;

    use super::*;

    #[test]
    fn histogram_bucket_bounds_are_increasing_and_labels_round_trip() {
        for buckets in [&TIME_HISTOGRAM_BUCKETS[..], &NON_TIME_HISTOGRAM_BUCKETS[..]] {
            // Cumulative bucket counting only makes sense with strictly increasing upper bounds.
            assert!(
                buckets.windows(2).all(|pair| pair[0].0 < pair[1].0),
                "bucket bounds must be strictly increasing: {:?}",
                buckets
            );

            // The rendered `le` label must describe exactly the bound used when counting samples.
            for (bound, label) in buckets {
                let parsed: f64 = label.parse().expect("bucket label should parse as f64");
                assert_eq!(parsed, *bound);
            }
        }
    }

    #[test]
    fn prom_histogram_add_sample() {
        let sample1 = (0.25, 1);
        let sample2 = (1.0, 2);
        let sample3 = (2.0, 3);

        let mut histogram = PrometheusHistogram::new("time_metric_seconds");
        histogram.add_sample(sample1.0, sample1.1);
        histogram.add_sample(sample2.0, sample2.1);
        histogram.add_sample(sample3.0, sample3.1);

        let sample1_weighted_value = sample1.0 * sample1.1 as f64;
        let sample2_weighted_value = sample2.0 * sample2.1 as f64;
        let sample3_weighted_value = sample3.0 * sample3.1 as f64;
        let expected_sum = sample1_weighted_value + sample2_weighted_value + sample3_weighted_value;
        let expected_count = sample1.1 + sample2.1 + sample3.1;
        assert_eq!(histogram.sum, expected_sum);
        assert_eq!(histogram.count, expected_count);

        // Samples are added in ascending order, so any bucket that holds a sample must also hold the weight of
        // every earlier (smaller) sample.
        let mut expected_bucket_count = 0;
        for sample in [sample1, sample2, sample3] {
            for bucket in &histogram.buckets {
                if sample.0 <= bucket.0 {
                    assert!(bucket.2 >= expected_bucket_count + sample.1);
                }
            }

            expected_bucket_count += sample.1;
        }
    }

    #[tokio::test]
    async fn scrape_routes_serve_raw_compat_and_404() {
        let payload = Arc::new(RwLock::new("raw".to_string()));
        let routes = vec![PrometheusAdditionalRoute {
            path: "/compat/metrics".to_string(),
            provider: Arc::new(|| "compat".to_string()),
        }];

        let raw_response = build_scrape_response("/metrics", &payload, &routes).await;
        assert_eq!(raw_response.status(), StatusCode::OK);
        let raw_body = raw_response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        assert_eq!(&raw_body[..], b"raw");

        let legacy_response = build_scrape_response("/", &payload, &routes).await;
        assert_eq!(legacy_response.status(), StatusCode::OK);
        let legacy_body = legacy_response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        assert_eq!(&legacy_body[..], b"raw");

        let compat_response = build_scrape_response("/compat/metrics", &payload, &routes).await;
        assert_eq!(compat_response.status(), StatusCode::OK);
        let compat_body = compat_response
            .into_body()
            .collect()
            .await
            .expect("body should collect")
            .to_bytes();
        assert_eq!(&compat_body[..], b"compat");

        let missing_response = build_scrape_response("/missing", &payload, &routes).await;
        assert_eq!(missing_response.status(), StatusCode::NOT_FOUND);
    }
}
653}