// saluki_core/topology/interconnect/consumer.rs

1use saluki_metrics::static_metrics;
2use tokio::sync::mpsc;
3
4use super::Dispatchable;
5use crate::components::ComponentContext;
6
// Per-component metrics for the receiving end of an interconnect.
//
// Generates `ConsumerMetrics`, whose metrics carry the `component_id` and `component_type` labels
// (populated by `ConsumerMetrics::from_component_context`). `Consumer::next` increments
// `events_received_total` by each received value's `item_count()` and records that same count in
// `events_received_size`.
//
// NOTE(review): the `component` prefix is presumably prepended to each metric name by the macro,
// and `trace_histogram` presumably only records when trace-level telemetry is enabled — confirm
// against `saluki_metrics::static_metrics!`.
static_metrics!(
    name => ConsumerMetrics,
    prefix => component,
    labels => [component_id: String, component_type: &'static str],
    metrics => [
        // Running total of items received, summed over the `item_count()` of each received value.
        counter(events_received_total),
        // Distribution of `item_count()` per received value.
        trace_histogram(events_received_size),
    ],
);
16
17impl ConsumerMetrics {
18    fn from_component_context(context: ComponentContext) -> Self {
19        Self::new(context.component_id().to_string(), context.component_type().as_str())
20    }
21}
22
23/// A stream of items sent to a component.
24///
25/// This represents the receiving end of a component interconnect, where the sending end is [`Dispatcher<T>`][super::Dispatcher].
26pub struct Consumer<T> {
27    inner: mpsc::Receiver<T>,
28    metrics: ConsumerMetrics,
29}
30
31impl<T> Consumer<T>
32where
33    T: Dispatchable,
34{
35    /// Create a new `Consumer` for the given component context and inner receiver.
36    pub fn new(context: ComponentContext, inner: mpsc::Receiver<T>) -> Self {
37        Self {
38            inner,
39            metrics: ConsumerMetrics::from_component_context(context),
40        }
41    }
42
43    /// Gets the next item in the stream.
44    ///
45    /// If the component (or components) connected to this consumer have stopped, `None` is returned.
46    pub async fn next(&mut self) -> Option<T> {
47        match self.inner.recv().await {
48            Some(item) => {
49                self.metrics.events_received_total().increment(item.item_count() as u64);
50                self.metrics.events_received_size().record(item.item_count() as f64);
51                Some(item)
52            }
53            None => None,
54        }
55    }
56}
57
#[cfg(test)]
mod tests {
    use metrics::{Key, Label};
    use metrics_util::{
        debugging::{DebugValue, DebuggingRecorder, Snapshotter},
        CompositeKey, MetricKind,
    };
    use ordered_float::OrderedFloat;

    use super::*;
    use crate::topology::ComponentId;

    /// Minimal `Dispatchable` used to drive the consumer, carrying an explicit item count.
    #[derive(Clone, Debug, Eq, PartialEq)]
    struct DispatchableEvent<T> {
        item_count: usize,
        data: T,
    }

    impl<T: Clone> DispatchableEvent<T> {
        /// Creates an event that represents a single item.
        fn new(data: T) -> Self {
            Self::with_item_count(1, data)
        }

        /// Creates an event that represents `item_count` items.
        fn with_item_count(item_count: usize, data: T) -> Self {
            Self { item_count, data }
        }
    }

    impl<T: Clone> Dispatchable for DispatchableEvent<T> {
        fn item_count(&self) -> usize {
            self.item_count
        }
    }

    /// Creates a consumer for a source component with ID `consumer_test`.
    ///
    /// Returns the consumer alongside the sending half of its channel.
    fn create_consumer<T: Clone>(
        channel_size: usize,
    ) -> (Consumer<DispatchableEvent<T>>, mpsc::Sender<DispatchableEvent<T>>) {
        let component_id = ComponentId::try_from("consumer_test").expect("component ID should never be invalid");
        let component_context = ComponentContext::source(component_id);

        let (sender, receiver) = mpsc::channel(channel_size);
        (Consumer::new(component_context, receiver), sender)
    }

    /// Builds the composite key for a consumer metric, using exactly the labels that `create_consumer` yields.
    fn get_consumer_metric_composite_key(kind: MetricKind, name: &'static str) -> CompositeKey {
        static LABELS: &[Label] = &[
            Label::from_static_parts("component_id", "consumer_test"),
            Label::from_static_parts("component_type", "source"),
        ];
        CompositeKey::new(kind, Key::from_static_parts(name, LABELS))
    }

    /// Takes a fresh snapshot and pulls out the "events received" counter and histogram values.
    ///
    /// Panics if either metric is absent from the snapshot.
    fn events_received_metrics(
        snapshotter: &Snapshotter, total_key: &CompositeKey, size_key: &CompositeKey,
    ) -> (DebugValue, DebugValue) {
        // TODO: This API for querying the metrics really sucks... and we need something better.
        let mut current_metrics = snapshotter.snapshot().into_hashmap();
        let (_, _, total) = current_metrics
            .remove(total_key)
            .expect("should have events received metric");
        let (_, _, size) = current_metrics
            .remove(size_key)
            .expect("should have events received size metric");
        (total, size)
    }

    #[tokio::test]
    async fn next() {
        let (mut consumer, tx) = create_consumer(1);

        // A sent item should come back out of the consumer unchanged:
        let input_item = DispatchableEvent::new("hello world");
        tx.send(input_item.clone()).await.expect("should not fail to send item");
        assert_eq!(consumer.next().await.expect("should receive item"), input_item);

        // Dropping the only sender closes the channel, so the consumer should then signal the end of the stream:
        drop(tx);
        assert!(consumer.next().await.is_none());
    }

    #[tokio::test]
    async fn metrics() {
        let events_received_key =
            get_consumer_metric_composite_key(MetricKind::Counter, ConsumerMetrics::events_received_total_name());
        let events_received_size_key =
            get_consumer_metric_composite_key(MetricKind::Histogram, ConsumerMetrics::events_received_size_name());

        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();
        let (mut consumer, tx) = metrics::with_local_recorder(&recorder, || create_consumer(1));

        // First a single-item event, then a 42-item event. Each must round-trip unchanged, and the
        // metrics must reflect that event's item count.
        for (item_count, data) in [(1, "single item"), (42, "multiple_items")] {
            let sent_item = DispatchableEvent::with_item_count(item_count, data);
            tx.send(sent_item.clone())
                .await
                .expect("should not fail to send item");

            let output_item = consumer.next().await.expect("should receive item");
            assert_eq!(output_item, sent_item);

            let (events_received, events_received_size) =
                events_received_metrics(&snapshotter, &events_received_key, &events_received_size_key);
            // NOTE(review): expecting only this round's count (rather than a running total) assumes
            // `Snapshotter::snapshot` resets counters between snapshots — confirm against the
            // metrics-util version in use.
            assert_eq!(events_received, DebugValue::Counter(item_count as u64));
            assert_eq!(
                events_received_size,
                DebugValue::Histogram(vec![OrderedFloat(item_count as f64)])
            );
        }
    }
}