saluki_core/topology/
built.rs

1use std::{collections::HashMap, future::Future, num::NonZeroUsize};
2
3use memory_accounting::{
4    allocator::{AllocationGroupToken, Tracked},
5    MemoryLimiter,
6};
7use saluki_common::task::JoinSetExt as _;
8use saluki_error::{generic_error, ErrorContext as _, GenericError};
9use saluki_health::HealthRegistry;
10use tokio::{
11    sync::mpsc,
12    task::{AbortHandle, JoinSet},
13};
14use tracing::{debug, error_span};
15
16use super::{
17    graph::Graph, running::RunningTopology, shutdown::ComponentShutdownCoordinator, ComponentId, EventsBuffer,
18    EventsConsumer, OutputName, PayloadsConsumer, RegisteredComponent, TypedComponentId,
19};
20use crate::{
21    components::{
22        destinations::{Destination, DestinationContext},
23        encoders::{Encoder, EncoderContext},
24        forwarders::{Forwarder, ForwarderContext},
25        sources::{Source, SourceContext},
26        transforms::{Transform, TransformContext},
27        ComponentContext, ComponentType,
28    },
29    topology::{context::TopologyContext, EventsDispatcher, PayloadsBuffer, PayloadsDispatcher},
30};
31
32/// A built topology.
33///
34/// Built topologies represent a topology blueprint where each configured component, along with their associated
35/// connections to other components, was validated and built successfully.
36///
37/// A built topology must be spawned via [`spawn`][Self::spawn].
38pub struct BuiltTopology {
39    name: String,
40    graph: Graph,
41    sources: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Source + Send>>>>,
42    transforms: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Transform + Send>>>>,
43    destinations: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Destination + Send>>>>,
44    encoders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Encoder + Send>>>>,
45    forwarders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Forwarder + Send>>>>,
46    component_token: AllocationGroupToken,
47    interconnect_capacity: NonZeroUsize,
48}
49
50impl BuiltTopology {
51    #[allow(clippy::too_many_arguments)]
52    pub(crate) fn from_parts(
53        name: String, graph: Graph,
54        sources: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Source + Send>>>>,
55        transforms: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Transform + Send>>>>,
56        destinations: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Destination + Send>>>>,
57        encoders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Encoder + Send>>>>,
58        forwarders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Forwarder + Send>>>>,
59        component_token: AllocationGroupToken, interconnect_capacity: NonZeroUsize,
60    ) -> Self {
61        Self {
62            name,
63            graph,
64            sources,
65            transforms,
66            destinations,
67            encoders,
68            forwarders,
69            component_token,
70            interconnect_capacity,
71        }
72    }
73
74    /// Spawns the topology.
75    ///
76    /// A handle is returned that can be used to trigger the topology to shutdown.
77    ///
78    /// ## Errors
79    ///
80    /// If an error occurs while spawning the topology, an error is returned.
81    pub async fn spawn(
82        self, health_registry: &HealthRegistry, memory_limiter: MemoryLimiter,
83    ) -> Result<RunningTopology, GenericError> {
84        let root_component_name = format!("topology.{}", self.name);
85
86        let _guard = self.component_token.enter();
87
88        let thread_pool = tokio::runtime::Builder::new_multi_thread()
89            .worker_threads(8)
90            .enable_all()
91            .build()
92            .error_context("Failed to build asynchronous thread pool runtime.")?;
93        let thread_pool_handle = thread_pool.handle().clone();
94
95        let topology_context = TopologyContext::new(memory_limiter, health_registry.clone(), thread_pool_handle);
96
97        let mut component_tasks = JoinSet::new();
98        let mut component_task_map = HashMap::new();
99
100        // Build our interconnects, which we'll grab from piecemeal as we spawn our components.
101        let mut interconnects = ComponentInterconnects::from_graph(self.interconnect_capacity, &self.graph)
102            .error_context("Failed to build component interconnects.")?;
103
104        let mut shutdown_coordinator = ComponentShutdownCoordinator::default();
105
106        // Spawn our sources.
107        for (component_id, source) in self.sources {
108            let (source, component_registry) = source.into_parts();
109
110            let dispatcher = interconnects
111                .take_source_dispatcher(&component_id)
112                .ok_or_else(|| generic_error!("No events dispatcher found for source component '{}'", component_id))?;
113
114            let shutdown_handle = shutdown_coordinator.register();
115            let health_handle = health_registry
116                .register_component(format!("{}.sources.{}", root_component_name, component_id))
117                .expect("duplicate source component ID in health registry");
118
119            let component_context = ComponentContext::source(component_id.clone());
120            let context = SourceContext::new(
121                &topology_context,
122                &component_context,
123                component_registry,
124                shutdown_handle,
125                health_handle,
126                dispatcher,
127            );
128
129            let (alloc_group, source) = source.into_parts();
130            let task_handle = spawn_component(
131                &mut component_tasks,
132                component_context,
133                alloc_group,
134                source.run(context),
135            );
136            component_task_map.insert(task_handle.id(), component_id);
137        }
138
139        // Spawn our transforms.
140        for (component_id, transform) in self.transforms {
141            let (transform, component_registry) = transform.into_parts();
142
143            let dispatcher = interconnects.take_transform_dispatcher(&component_id).ok_or_else(|| {
144                generic_error!("No events dispatcher found for transform component '{}'", component_id)
145            })?;
146
147            let consumer = interconnects
148                .take_transform_consumer(&component_id)
149                .ok_or_else(|| generic_error!("No events consumer found for transform component '{}'", component_id))?;
150
151            let health_handle = health_registry
152                .register_component(format!("{}.transforms.{}", root_component_name, component_id))
153                .expect("duplicate transform component ID in health registry");
154
155            let component_context = ComponentContext::transform(component_id.clone());
156            let context = TransformContext::new(
157                &topology_context,
158                &component_context,
159                component_registry,
160                health_handle,
161                dispatcher,
162                consumer,
163            );
164
165            let (alloc_group, transform) = transform.into_parts();
166            let task_handle = spawn_component(
167                &mut component_tasks,
168                component_context,
169                alloc_group,
170                transform.run(context),
171            );
172            component_task_map.insert(task_handle.id(), component_id);
173        }
174
175        // Spawn our destinations.
176        for (component_id, destination) in self.destinations {
177            let (destination, component_registry) = destination.into_parts();
178
179            let consumer = interconnects.take_destination_consumer(&component_id).ok_or_else(|| {
180                generic_error!("No events consumer found for destination component '{}'", component_id)
181            })?;
182
183            let health_handle = health_registry
184                .register_component(format!("{}.destinations.{}", root_component_name, component_id))
185                .expect("duplicate destination component ID in health registry");
186
187            let component_context = ComponentContext::destination(component_id.clone());
188            let context = DestinationContext::new(
189                &topology_context,
190                &component_context,
191                component_registry,
192                health_handle,
193                consumer,
194            );
195
196            let (alloc_group, destination) = destination.into_parts();
197            let task_handle = spawn_component(
198                &mut component_tasks,
199                component_context,
200                alloc_group,
201                destination.run(context),
202            );
203            component_task_map.insert(task_handle.id(), component_id);
204        }
205
206        // Spawn our encoders.
207        for (component_id, encoder) in self.encoders {
208            let (encoder, component_registry) = encoder.into_parts();
209
210            let dispatcher = interconnects.take_encoder_dispatcher(&component_id).ok_or_else(|| {
211                generic_error!("No payloads dispatcher found for encoder component '{}'", component_id)
212            })?;
213
214            let consumer = interconnects
215                .take_encoder_consumer(&component_id)
216                .ok_or_else(|| generic_error!("No events consumer found for encoder component '{}'", component_id))?;
217
218            let health_handle = health_registry
219                .register_component(format!("{}.encoders.{}", root_component_name, component_id))
220                .expect("duplicate encoder component ID in health registry");
221
222            let component_context = ComponentContext::encoder(component_id.clone());
223            let context = EncoderContext::new(
224                &topology_context,
225                &component_context,
226                component_registry,
227                health_handle,
228                dispatcher,
229                consumer,
230            );
231
232            let (alloc_group, encoder) = encoder.into_parts();
233            let task_handle = spawn_component(
234                &mut component_tasks,
235                component_context,
236                alloc_group,
237                encoder.run(context),
238            );
239            component_task_map.insert(task_handle.id(), component_id);
240        }
241
242        // Spawn our forwarders.
243        for (component_id, forwarder) in self.forwarders {
244            let (forwarder, component_registry) = forwarder.into_parts();
245
246            let consumer = interconnects.take_forwarder_consumer(&component_id).ok_or_else(|| {
247                generic_error!("No payloads consumer found for forwarder component '{}'", component_id)
248            })?;
249
250            let health_handle = health_registry
251                .register_component(format!("{}.forwarders.{}", root_component_name, component_id))
252                .expect("duplicate forwarder component ID in health registry");
253
254            let component_context = ComponentContext::forwarder(component_id.clone());
255            let context = ForwarderContext::new(
256                &topology_context,
257                &component_context,
258                component_registry,
259                health_handle,
260                consumer,
261            );
262
263            let (alloc_group, forwarder) = forwarder.into_parts();
264            let task_handle = spawn_component(
265                &mut component_tasks,
266                component_context,
267                alloc_group,
268                forwarder.run(context),
269            );
270            component_task_map.insert(task_handle.id(), component_id);
271        }
272
273        Ok(RunningTopology::from_parts(
274            thread_pool,
275            shutdown_coordinator,
276            component_tasks,
277            component_task_map,
278        ))
279    }
280}
281
282struct ComponentInterconnects {
283    interconnect_capacity: NonZeroUsize,
284    source_dispatchers: HashMap<ComponentId, EventsDispatcher>,
285    transform_consumers: HashMap<ComponentId, (mpsc::Sender<EventsBuffer>, EventsConsumer)>,
286    transform_dispatchers: HashMap<ComponentId, EventsDispatcher>,
287    destination_consumers: HashMap<ComponentId, (mpsc::Sender<EventsBuffer>, EventsConsumer)>,
288    encoder_consumers: HashMap<ComponentId, (mpsc::Sender<EventsBuffer>, EventsConsumer)>,
289    encoder_dispatchers: HashMap<ComponentId, PayloadsDispatcher>,
290    forwarder_consumers: HashMap<ComponentId, (mpsc::Sender<PayloadsBuffer>, PayloadsConsumer)>,
291}
292
293impl ComponentInterconnects {
294    fn from_graph(interconnect_capacity: NonZeroUsize, graph: &Graph) -> Result<Self, GenericError> {
295        let mut interconnects = Self {
296            interconnect_capacity,
297            source_dispatchers: HashMap::new(),
298            transform_consumers: HashMap::new(),
299            transform_dispatchers: HashMap::new(),
300            destination_consumers: HashMap::new(),
301            encoder_consumers: HashMap::new(),
302            encoder_dispatchers: HashMap::new(),
303            forwarder_consumers: HashMap::new(),
304        };
305
306        interconnects.generate_interconnects(graph)?;
307        Ok(interconnects)
308    }
309
310    fn take_source_dispatcher(&mut self, component_id: &ComponentId) -> Option<EventsDispatcher> {
311        self.source_dispatchers.remove(component_id)
312    }
313
314    fn take_transform_dispatcher(&mut self, component_id: &ComponentId) -> Option<EventsDispatcher> {
315        self.transform_dispatchers.remove(component_id)
316    }
317
318    fn take_encoder_dispatcher(&mut self, component_id: &ComponentId) -> Option<PayloadsDispatcher> {
319        self.encoder_dispatchers.remove(component_id)
320    }
321
322    fn take_transform_consumer(&mut self, component_id: &ComponentId) -> Option<EventsConsumer> {
323        self.transform_consumers
324            .remove(component_id)
325            .map(|(_, consumer)| consumer)
326    }
327
328    fn take_destination_consumer(&mut self, component_id: &ComponentId) -> Option<EventsConsumer> {
329        self.destination_consumers
330            .remove(component_id)
331            .map(|(_, consumer)| consumer)
332    }
333
334    fn take_encoder_consumer(&mut self, component_id: &ComponentId) -> Option<EventsConsumer> {
335        self.encoder_consumers
336            .remove(component_id)
337            .map(|(_, consumer)| consumer)
338    }
339
340    fn take_forwarder_consumer(&mut self, component_id: &ComponentId) -> Option<PayloadsConsumer> {
341        self.forwarder_consumers
342            .remove(component_id)
343            .map(|(_, consumer)| consumer)
344    }
345
346    fn generate_interconnects(&mut self, graph: &Graph) -> Result<(), GenericError> {
347        // Collect and iterate over each outbound edge in the topology graph.
348        //
349        // For each upstream component ("from" side of the edge), we attach each downstream component ("to" side of the edge) to it,
350        // creating the relevant dispatcher or consumer if necessary.
351        let outbound_edges = graph.get_outbound_directed_edges();
352        for (upstream_id, output_map) in outbound_edges {
353            match upstream_id.component_type() {
354                ComponentType::Source | ComponentType::Transform => {
355                    self.generate_event_interconnect(upstream_id, output_map)?;
356                }
357                ComponentType::Encoder => self.generate_payload_interconnect(upstream_id, output_map)?,
358                _ => panic!(
359                    "Only sources, transforms, and encoders can dispatch events/payloads to downstream components."
360                ),
361            }
362        }
363
364        Ok(())
365    }
366
367    fn generate_event_interconnect(
368        &mut self, upstream_id: TypedComponentId, output_map: HashMap<OutputName, Vec<TypedComponentId>>,
369    ) -> Result<(), GenericError> {
370        for (upstream_output_id, downstream_ids) in output_map {
371            let mut senders = Vec::new();
372            for downstream_id in downstream_ids {
373                debug!(upstream_id = %upstream_id.component_id(), %upstream_output_id, downstream_id = %downstream_id.component_id(), "Adding dispatcher output.");
374                let sender = self.get_or_create_events_sender(downstream_id);
375                senders.push(sender);
376            }
377
378            let dispatcher = self.get_or_create_events_dispatcher(upstream_id.clone());
379            dispatcher.add_output(upstream_output_id.clone())?;
380
381            for sender in senders {
382                dispatcher.attach_sender_to_output(&upstream_output_id, sender)?;
383            }
384        }
385
386        Ok(())
387    }
388
389    fn generate_payload_interconnect(
390        &mut self, upstream_id: TypedComponentId, output_map: HashMap<OutputName, Vec<TypedComponentId>>,
391    ) -> Result<(), GenericError> {
392        for (upstream_output_id, downstream_ids) in output_map {
393            let mut senders = Vec::new();
394            for downstream_id in downstream_ids {
395                debug!(upstream_id = %upstream_id.component_id(), %upstream_output_id, downstream_id = %downstream_id.component_id(), "Adding dispatcher output.");
396                let sender = self.get_or_create_payloads_sender(downstream_id);
397                senders.push(sender);
398            }
399
400            let dispatcher = self.get_or_create_payloads_dispatcher(upstream_id.clone());
401            dispatcher.add_output(upstream_output_id.clone())?;
402
403            for sender in senders {
404                dispatcher.attach_sender_to_output(&upstream_output_id, sender)?;
405            }
406        }
407
408        Ok(())
409    }
410
411    fn get_or_create_events_dispatcher(&mut self, component_id: TypedComponentId) -> &mut EventsDispatcher {
412        let (component_id, component_type, component_context) = component_id.into_parts();
413
414        match component_type {
415            ComponentType::Source => self
416                .source_dispatchers
417                .entry(component_id)
418                .or_insert_with(|| EventsDispatcher::new(component_context)),
419            ComponentType::Transform => self
420                .transform_dispatchers
421                .entry(component_id)
422                .or_insert_with(|| EventsDispatcher::new(component_context)),
423            _ => {
424                panic!("Only sources and transforms can dispatch events to downstream components.")
425            }
426        }
427    }
428
429    fn get_or_create_events_sender(&mut self, component_id: TypedComponentId) -> mpsc::Sender<EventsBuffer> {
430        let (component_id, component_type, component_context) = component_id.into_parts();
431        let interconnect_capacity = self.interconnect_capacity;
432
433        let (sender, _) = match component_type {
434            ComponentType::Transform => self
435                .transform_consumers
436                .entry(component_id)
437                .or_insert_with(|| build_events_consumer_pair(component_context, interconnect_capacity)),
438            ComponentType::Destination => self
439                .destination_consumers
440                .entry(component_id)
441                .or_insert_with(|| build_events_consumer_pair(component_context, interconnect_capacity)),
442            ComponentType::Encoder => self
443                .encoder_consumers
444                .entry(component_id)
445                .or_insert_with(|| build_events_consumer_pair(component_context, interconnect_capacity)),
446            _ => panic!("Only transforms, destinations, and encoders can consume events."),
447        };
448
449        sender.clone()
450    }
451
452    fn get_or_create_payloads_dispatcher(&mut self, component_id: TypedComponentId) -> &mut PayloadsDispatcher {
453        let (component_id, component_type, component_context) = component_id.into_parts();
454
455        match component_type {
456            ComponentType::Encoder => self
457                .encoder_dispatchers
458                .entry(component_id)
459                .or_insert_with(|| PayloadsDispatcher::new(component_context)),
460            _ => {
461                panic!("Only encoders can dispatch payloads to downstream components.")
462            }
463        }
464    }
465
466    fn get_or_create_payloads_sender(&mut self, component_id: TypedComponentId) -> mpsc::Sender<PayloadsBuffer> {
467        let (component_id, component_type, component_context) = component_id.into_parts();
468        let interconnect_capacity = self.interconnect_capacity;
469
470        let (sender, _) = match component_type {
471            ComponentType::Forwarder => self
472                .forwarder_consumers
473                .entry(component_id)
474                .or_insert_with(|| build_payloads_consumer_pair(component_context, interconnect_capacity)),
475            _ => panic!("Only forwarders can consume payloads."),
476        };
477
478        sender.clone()
479    }
480}
481
482fn build_events_consumer_pair(
483    component_context: ComponentContext, interconnect_capacity: NonZeroUsize,
484) -> (mpsc::Sender<EventsBuffer>, EventsConsumer) {
485    let (sender, receiver) = mpsc::channel(interconnect_capacity.get());
486    let consumer = EventsConsumer::new(component_context, receiver);
487    (sender, consumer)
488}
489
490fn build_payloads_consumer_pair(
491    component_context: ComponentContext, interconnect_capacity: NonZeroUsize,
492) -> (mpsc::Sender<PayloadsBuffer>, PayloadsConsumer) {
493    let (sender, receiver) = mpsc::channel(interconnect_capacity.get());
494    let consumer = PayloadsConsumer::new(component_context, receiver);
495    (sender, consumer)
496}
497
498fn spawn_component<F>(
499    join_set: &mut JoinSet<Result<(), GenericError>>, context: ComponentContext,
500    allocation_group_token: AllocationGroupToken, component_future: F,
501) -> AbortHandle
502where
503    F: Future<Output = Result<(), GenericError>> + Send + 'static,
504{
505    let component_span = error_span!(
506        "component",
507        "type" = context.component_type().as_str(),
508        id = %context.component_id(),
509    );
510
511    let _span = component_span.enter();
512    let _guard = allocation_group_token.enter();
513
514    let component_task_name = format!(
515        "topology-{}-{}",
516        context.component_type().as_str(),
517        context.component_id()
518    );
519    join_set.spawn_traced_named(component_task_name, component_future)
520}
521
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;
    use crate::data_model::event::EventType;
    use crate::data_model::payload::PayloadType;
    use crate::topology::graph::Graph;

    /// Building interconnects only succeeds if each dispatcher output is added before senders are attached to it.
    #[test]
    fn component_interconnects_adds_output_before_attaching() {
        let capacity = NonZeroUsize::new(10).unwrap();

        // A small topology that uses every component type, where `transform1` fans out to both an encoder and a
        // destination.
        let mut topology_graph = Graph::default();
        topology_graph
            .with_source("source1", EventType::EventD)
            .with_transform("transform1", EventType::EventD, EventType::EventD)
            .with_encoder("encoder1", EventType::EventD, PayloadType::Raw)
            .with_forwarder("forwarder1", PayloadType::Raw)
            .with_destination("dest1", EventType::EventD)
            .with_edge("source1", "transform1")
            .with_edge("transform1", "encoder1")
            .with_edge("encoder1", "forwarder1")
            .with_edge("transform1", "dest1");

        // Attaching a sender to an output that hasn't been added yet would surface as an error here.
        ComponentInterconnects::from_graph(capacity, &topology_graph)
            .expect("should build interconnects successfully");
    }
}
553}