Skip to main content

saluki_components/transforms/mrf_gateway/
mod.rs

1//! MRF metrics gateway transform.
2
3use std::collections::HashSet;
4
5use async_trait::async_trait;
6use resource_accounting::{MemoryBounds, MemoryBoundsBuilder};
7use saluki_config::GenericConfiguration;
8use saluki_core::{
9    components::{
10        transforms::{Transform, TransformBuilder, TransformContext},
11        ComponentContext,
12    },
13    data_model::event::{Event, EventType},
14    topology::{EventsBuffer, OutputDefinition},
15};
16use saluki_error::GenericError;
17use tokio::select;
18use tracing::{debug, error};
19
20use crate::config::MrfConfiguration;
21
22/// MRF metrics gateway transform configuration.
23///
24/// This transform sits between the enrichment stage and the MRF-specific encoder/forwarder. It owns
25/// all routing and filtering decisions for the MRF metrics pipeline:
26///
27/// - When MRF is disabled, all events are dropped.
28/// - When MRF is enabled with no allowlist, all events are forwarded.
29/// - When MRF is enabled with an allowlist, only matching events are forwarded.
30///
31/// The transform reads static MRF configuration from a snapshot taken at build time, and watches
32/// `multi_region_failover.failover_metrics` and `multi_region_failover.metric_allowlist` for
33/// dynamic updates.
34pub struct MrfMetricsGatewayConfiguration {
35    mrf_config: MrfConfiguration,
36    configuration: GenericConfiguration,
37}
38
39impl MrfMetricsGatewayConfiguration {
40    /// Creates a new `MrfMetricsGatewayConfiguration` from the given [`MrfConfiguration`].
41    pub fn new(mrf_config: MrfConfiguration, configuration: GenericConfiguration) -> Self {
42        Self {
43            mrf_config,
44            configuration,
45        }
46    }
47}
48
/// The routing/filtering decision currently applied by the MRF metrics gateway.
#[derive(Debug)]
enum GatewayMode {
    /// Nothing is forwarded: MRF is disabled or improperly configured.
    Inactive,
    /// Everything is forwarded: MRF is active and no allowlist is configured.
    ForwardAll,
    /// Only matching events are forwarded: MRF is active and an allowlist is configured.
    FilteredForward {
        // Names that are allowed through while this mode is active.
        allowlist: HashSet<String>,
    },
}
59
60/// MRF metrics gateway transform.
61pub struct MrfMetricsGateway {
62    mrf_config: MrfConfiguration,
63    mode: GatewayMode,
64    configuration: GenericConfiguration,
65}
66
67impl MrfMetricsGateway {
68    fn new(mrf_config: MrfConfiguration, configuration: GenericConfiguration) -> Self {
69        let mode = Self::mode_for_config(&mrf_config);
70
71        Self {
72            mrf_config,
73            mode,
74            configuration,
75        }
76    }
77
78    fn mode_for_config(mrf_config: &MrfConfiguration) -> GatewayMode {
79        if !mrf_config.is_metrics_forwarding_requested() {
80            GatewayMode::Inactive
81        } else if mrf_config.metric_allowlist().is_empty() {
82            GatewayMode::ForwardAll
83        } else {
84            GatewayMode::FilteredForward {
85                allowlist: mrf_config.metric_allowlist().iter().cloned().collect(),
86            }
87        }
88    }
89
90    fn update_failover_metrics(&mut self, failover_metrics: bool) {
91        self.mrf_config.set_failover_metrics(failover_metrics);
92        self.mode = Self::mode_for_config(&self.mrf_config);
93    }
94
95    fn update_metric_allowlist(&mut self, metric_allowlist: Vec<String>) {
96        self.mrf_config.set_metric_allowlist(metric_allowlist);
97        self.mode = Self::mode_for_config(&self.mrf_config);
98    }
99
100    fn should_forward(&self, event: &Event) -> bool {
101        match &self.mode {
102            GatewayMode::Inactive => false,
103            GatewayMode::ForwardAll => true,
104            GatewayMode::FilteredForward { allowlist } => {
105                let Event::Metric(metric) = event else {
106                    return false;
107                };
108                allowlist.contains(metric.context().name().as_ref())
109            }
110        }
111    }
112
113    async fn process_event_batch(
114        &self, mut events: EventsBuffer, context: &mut TransformContext,
115    ) -> Result<(), GenericError> {
116        let input_count = events.len();
117        events.remove_if(|event| !self.should_forward(event));
118        let forwarded_count = events.len();
119        let dropped_count = input_count.saturating_sub(forwarded_count);
120
121        let sent_count = context.dispatcher().buffered()?.send_all(events).await?;
122        debug!(
123            forwarded_events = sent_count,
124            dropped_events = dropped_count,
125            "MRF metrics gateway processed event batch."
126        );
127
128        Ok(())
129    }
130}
131
132#[async_trait]
133impl TransformBuilder for MrfMetricsGatewayConfiguration {
134    async fn build(&self, _context: ComponentContext) -> Result<Box<dyn Transform + Send>, GenericError> {
135        Ok(Box::new(MrfMetricsGateway::new(
136            self.mrf_config.clone(),
137            self.configuration.clone(),
138        )))
139    }
140
141    fn input_event_type(&self) -> EventType {
142        EventType::Metric
143    }
144
145    fn outputs(&self) -> &[OutputDefinition<EventType>] {
146        static OUTPUTS: &[OutputDefinition<EventType>] = &[OutputDefinition::default_output(EventType::Metric)];
147        OUTPUTS
148    }
149}
150
151impl MemoryBounds for MrfMetricsGatewayConfiguration {
152    fn specify_bounds(&self, builder: &mut MemoryBoundsBuilder) {
153        let allowlist = self.mrf_config.metric_allowlist();
154        builder
155            .minimum()
156            .with_single_value::<MrfMetricsGateway>("component struct")
157            .with_fixed_amount("hashset overhead", std::mem::size_of::<HashSet<String>>())
158            .with_fixed_amount(
159                "allowlist strings",
160                allowlist
161                    .iter()
162                    .map(|name| name.len() + std::mem::size_of::<String>())
163                    .sum::<usize>(),
164            )
165            .with_fixed_amount(
166                "hashset buckets",
167                allowlist.len() * std::mem::size_of::<Option<String>>() * 2,
168            );
169    }
170}
171
172#[async_trait]
173impl Transform for MrfMetricsGateway {
174    async fn run(mut self: Box<Self>, mut context: TransformContext) -> Result<(), GenericError> {
175        let mut health = context.take_health_handle();
176        let mut failover_metrics_watcher = self
177            .configuration
178            .watch_for_updates("multi_region_failover.failover_metrics");
179        let mut metric_allowlist_watcher = self
180            .configuration
181            .watch_for_updates("multi_region_failover.metric_allowlist");
182
183        health.mark_ready();
184        debug!(mode = ?self.mode, "MRF metrics gateway transform started.");
185
186        loop {
187            select! {
188                _ = health.live() => continue,
189                maybe_events = context.events().next() => match maybe_events {
190                    Some(events) => {
191                        if let Err(e) = self.process_event_batch(events, &mut context).await {
192                            error!(error = %e, "MRF metrics gateway failed to process event batch.");
193                        }
194                    }
195                    None => {
196                        debug!("Event stream terminated, shutting down MRF metrics gateway transform.");
197                        break;
198                    }
199                },
200                (_, maybe_failover_metrics) = failover_metrics_watcher.changed::<bool>() => {
201                    if let Some(failover_metrics) = maybe_failover_metrics {
202                        self.update_failover_metrics(failover_metrics);
203                    }
204                },
205                (_, maybe_metric_allowlist) = metric_allowlist_watcher.changed::<Vec<String>>() => {
206                    if let Some(metric_allowlist) = maybe_metric_allowlist {
207                        self.update_metric_allowlist(metric_allowlist);
208                    }
209                },
210            }
211        }
212
213        debug!("MRF metrics gateway transform stopped.");
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use saluki_config::{dynamic::ConfigUpdate, ConfigurationLoader};
    use saluki_core::data_model::event::{metric::Metric, Event};
    use serde_json::json;

    use super::*;

    const FAILOVER_METRICS_KEY: &str = "multi_region_failover.failover_metrics";
    const METRIC_ALLOWLIST_KEY: &str = "multi_region_failover.metric_allowlist";

    // Upper bound on how long a test waits for a watcher to observe an update.
    const UPDATE_TIMEOUT: Duration = Duration::from_secs(2);

    /// Builds a counter metric event with the given name and a value of 1.0.
    fn counter(name: &'static str) -> Event {
        Event::Metric(Metric::counter(name, 1.0))
    }

    /// Sends a partial configuration update for `key`, panicking if it cannot be sent.
    async fn push_partial(sender: &tokio::sync::mpsc::Sender<ConfigUpdate>, key: &str, value: serde_json::Value) {
        sender
            .send(ConfigUpdate::Partial {
                key: key.to_string(),
                value,
            })
            .await
            .expect("dynamic update should be sent");
    }

    /// Builds a gateway on top of a dynamic test configuration seeded with `value`, returning it
    /// together with the sender used to push configuration updates.
    async fn dynamic_gateway_from_config(
        value: serde_json::Value,
    ) -> (MrfMetricsGateway, tokio::sync::mpsc::Sender<ConfigUpdate>) {
        let (config, maybe_sender) = ConfigurationLoader::for_tests(Some(value), None, true).await;
        let sender = maybe_sender.expect("dynamic sender should exist");

        // An initial (empty) snapshot is sent before waiting for the configuration to be ready.
        sender
            .send(ConfigUpdate::Snapshot(json!({})))
            .await
            .expect("initial dynamic snapshot should be sent");
        config.ready().await;

        let mrf_config = MrfConfiguration::from_configuration(&config).expect("MRF configuration should deserialize");
        let gateway = MrfMetricsGateway::new(mrf_config, config);
        (gateway, sender)
    }

    #[tokio::test]
    async fn failover_metrics_dynamic_update_toggles_forwarding() {
        let initial = json!({
            "multi_region_failover": {
                "enabled": true,
                "failover_metrics": false,
                "api_key": "mrf-api-key",
                "dd_url": "https://mrf.example.com"
            }
        });
        let (mut gw, sender) = dynamic_gateway_from_config(initial).await;
        let mut watcher = gw.configuration.watch_for_updates(FAILOVER_METRICS_KEY);

        assert!(!gw.should_forward(&counter("any.metric")));

        // Switch forwarding on, then off again; the gateway must follow each update.
        for enabled in [true, false] {
            push_partial(&sender, FAILOVER_METRICS_KEY, json!(enabled)).await;

            let (_, update) = tokio::time::timeout(UPDATE_TIMEOUT, watcher.changed::<bool>())
                .await
                .expect("failover metrics update should be received");
            gw.update_failover_metrics(update.expect("update should have a new value"));

            assert_eq!(gw.should_forward(&counter("any.metric")), enabled);
        }
    }

    #[tokio::test]
    async fn metric_allowlist_dynamic_update_changes_filtering() {
        let initial = json!({
            "multi_region_failover": {
                "enabled": true,
                "failover_metrics": true,
                "api_key": "mrf-api-key",
                "dd_url": "https://mrf.example.com"
            }
        });
        let (mut gw, sender) = dynamic_gateway_from_config(initial).await;
        let mut watcher = gw.configuration.watch_for_updates(METRIC_ALLOWLIST_KEY);

        // No allowlist yet: everything is forwarded.
        for name in ["allowed.metric", "also.allowed"] {
            assert!(gw.should_forward(&counter(name)));
        }

        push_partial(&sender, METRIC_ALLOWLIST_KEY, json!(["also.allowed"])).await;
        let (_, update) = tokio::time::timeout(UPDATE_TIMEOUT, watcher.changed::<Vec<String>>())
            .await
            .expect("metric allowlist update should be received");
        gw.update_metric_allowlist(update.expect("update should have a new value"));

        // Only the allowlisted name passes now.
        assert!(!gw.should_forward(&counter("allowed.metric")));
        assert!(gw.should_forward(&counter("also.allowed")));
        assert!(!gw.should_forward(&counter("blocked.metric")));
    }
}