1use std::{collections::HashMap, future::Future, num::NonZeroUsize};
2
3use memory_accounting::{
4 allocator::{AllocationGroupToken, Tracked},
5 MemoryLimiter,
6};
7use saluki_common::task::JoinSetExt as _;
8use saluki_error::{generic_error, ErrorContext as _, GenericError};
9use saluki_health::HealthRegistry;
10use tokio::{
11 sync::mpsc,
12 task::{AbortHandle, JoinSet},
13};
14use tracing::{debug, error_span};
15
16use super::{
17 graph::Graph, running::RunningTopology, shutdown::ComponentShutdownCoordinator, ComponentId, EventsBuffer,
18 EventsConsumer, OutputName, PayloadsConsumer, RegisteredComponent, TypedComponentId,
19};
20use crate::{
21 components::{
22 destinations::{Destination, DestinationContext},
23 encoders::{Encoder, EncoderContext},
24 forwarders::{Forwarder, ForwarderContext},
25 sources::{Source, SourceContext},
26 transforms::{Transform, TransformContext},
27 ComponentContext, ComponentType,
28 },
29 topology::{context::TopologyContext, EventsDispatcher, PayloadsBuffer, PayloadsDispatcher},
30};
31
32pub struct BuiltTopology {
39 name: String,
40 graph: Graph,
41 sources: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Source + Send>>>>,
42 transforms: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Transform + Send>>>>,
43 destinations: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Destination + Send>>>>,
44 encoders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Encoder + Send>>>>,
45 forwarders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Forwarder + Send>>>>,
46 component_token: AllocationGroupToken,
47 interconnect_capacity: NonZeroUsize,
48}
49
50impl BuiltTopology {
51 #[allow(clippy::too_many_arguments)]
52 pub(crate) fn from_parts(
53 name: String, graph: Graph,
54 sources: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Source + Send>>>>,
55 transforms: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Transform + Send>>>>,
56 destinations: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Destination + Send>>>>,
57 encoders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Encoder + Send>>>>,
58 forwarders: HashMap<ComponentId, RegisteredComponent<Tracked<Box<dyn Forwarder + Send>>>>,
59 component_token: AllocationGroupToken, interconnect_capacity: NonZeroUsize,
60 ) -> Self {
61 Self {
62 name,
63 graph,
64 sources,
65 transforms,
66 destinations,
67 encoders,
68 forwarders,
69 component_token,
70 interconnect_capacity,
71 }
72 }
73
74 pub async fn spawn(
82 self, health_registry: &HealthRegistry, memory_limiter: MemoryLimiter,
83 ) -> Result<RunningTopology, GenericError> {
84 let root_component_name = format!("topology.{}", self.name);
85
86 let _guard = self.component_token.enter();
87
88 let thread_pool = tokio::runtime::Builder::new_multi_thread()
89 .worker_threads(8)
90 .enable_all()
91 .build()
92 .error_context("Failed to build asynchronous thread pool runtime.")?;
93 let thread_pool_handle = thread_pool.handle().clone();
94
95 let topology_context = TopologyContext::new(memory_limiter, health_registry.clone(), thread_pool_handle);
96
97 let mut component_tasks = JoinSet::new();
98 let mut component_task_map = HashMap::new();
99
100 let mut interconnects = ComponentInterconnects::from_graph(self.interconnect_capacity, &self.graph)
102 .error_context("Failed to build component interconnects.")?;
103
104 let mut shutdown_coordinator = ComponentShutdownCoordinator::default();
105
106 for (component_id, source) in self.sources {
108 let (source, component_registry) = source.into_parts();
109
110 let dispatcher = interconnects
111 .take_source_dispatcher(&component_id)
112 .ok_or_else(|| generic_error!("No events dispatcher found for source component '{}'", component_id))?;
113
114 let shutdown_handle = shutdown_coordinator.register();
115 let health_handle = health_registry
116 .register_component(format!("{}.sources.{}", root_component_name, component_id))
117 .expect("duplicate source component ID in health registry");
118
119 let component_context = ComponentContext::source(component_id.clone());
120 let context = SourceContext::new(
121 &topology_context,
122 &component_context,
123 component_registry,
124 shutdown_handle,
125 health_handle,
126 dispatcher,
127 );
128
129 let (alloc_group, source) = source.into_parts();
130 let task_handle = spawn_component(
131 &mut component_tasks,
132 component_context,
133 alloc_group,
134 source.run(context),
135 );
136 component_task_map.insert(task_handle.id(), component_id);
137 }
138
139 for (component_id, transform) in self.transforms {
141 let (transform, component_registry) = transform.into_parts();
142
143 let dispatcher = interconnects.take_transform_dispatcher(&component_id).ok_or_else(|| {
144 generic_error!("No events dispatcher found for transform component '{}'", component_id)
145 })?;
146
147 let consumer = interconnects
148 .take_transform_consumer(&component_id)
149 .ok_or_else(|| generic_error!("No events consumer found for transform component '{}'", component_id))?;
150
151 let health_handle = health_registry
152 .register_component(format!("{}.transforms.{}", root_component_name, component_id))
153 .expect("duplicate transform component ID in health registry");
154
155 let component_context = ComponentContext::transform(component_id.clone());
156 let context = TransformContext::new(
157 &topology_context,
158 &component_context,
159 component_registry,
160 health_handle,
161 dispatcher,
162 consumer,
163 );
164
165 let (alloc_group, transform) = transform.into_parts();
166 let task_handle = spawn_component(
167 &mut component_tasks,
168 component_context,
169 alloc_group,
170 transform.run(context),
171 );
172 component_task_map.insert(task_handle.id(), component_id);
173 }
174
175 for (component_id, destination) in self.destinations {
177 let (destination, component_registry) = destination.into_parts();
178
179 let consumer = interconnects.take_destination_consumer(&component_id).ok_or_else(|| {
180 generic_error!("No events consumer found for destination component '{}'", component_id)
181 })?;
182
183 let health_handle = health_registry
184 .register_component(format!("{}.destinations.{}", root_component_name, component_id))
185 .expect("duplicate destination component ID in health registry");
186
187 let component_context = ComponentContext::destination(component_id.clone());
188 let context = DestinationContext::new(
189 &topology_context,
190 &component_context,
191 component_registry,
192 health_handle,
193 consumer,
194 );
195
196 let (alloc_group, destination) = destination.into_parts();
197 let task_handle = spawn_component(
198 &mut component_tasks,
199 component_context,
200 alloc_group,
201 destination.run(context),
202 );
203 component_task_map.insert(task_handle.id(), component_id);
204 }
205
206 for (component_id, encoder) in self.encoders {
208 let (encoder, component_registry) = encoder.into_parts();
209
210 let dispatcher = interconnects.take_encoder_dispatcher(&component_id).ok_or_else(|| {
211 generic_error!("No payloads dispatcher found for encoder component '{}'", component_id)
212 })?;
213
214 let consumer = interconnects
215 .take_encoder_consumer(&component_id)
216 .ok_or_else(|| generic_error!("No events consumer found for encoder component '{}'", component_id))?;
217
218 let health_handle = health_registry
219 .register_component(format!("{}.encoders.{}", root_component_name, component_id))
220 .expect("duplicate encoder component ID in health registry");
221
222 let component_context = ComponentContext::encoder(component_id.clone());
223 let context = EncoderContext::new(
224 &topology_context,
225 &component_context,
226 component_registry,
227 health_handle,
228 dispatcher,
229 consumer,
230 );
231
232 let (alloc_group, encoder) = encoder.into_parts();
233 let task_handle = spawn_component(
234 &mut component_tasks,
235 component_context,
236 alloc_group,
237 encoder.run(context),
238 );
239 component_task_map.insert(task_handle.id(), component_id);
240 }
241
242 for (component_id, forwarder) in self.forwarders {
244 let (forwarder, component_registry) = forwarder.into_parts();
245
246 let consumer = interconnects.take_forwarder_consumer(&component_id).ok_or_else(|| {
247 generic_error!("No payloads consumer found for forwarder component '{}'", component_id)
248 })?;
249
250 let health_handle = health_registry
251 .register_component(format!("{}.forwarders.{}", root_component_name, component_id))
252 .expect("duplicate forwarder component ID in health registry");
253
254 let component_context = ComponentContext::forwarder(component_id.clone());
255 let context = ForwarderContext::new(
256 &topology_context,
257 &component_context,
258 component_registry,
259 health_handle,
260 consumer,
261 );
262
263 let (alloc_group, forwarder) = forwarder.into_parts();
264 let task_handle = spawn_component(
265 &mut component_tasks,
266 component_context,
267 alloc_group,
268 forwarder.run(context),
269 );
270 component_task_map.insert(task_handle.id(), component_id);
271 }
272
273 Ok(RunningTopology::from_parts(
274 thread_pool,
275 shutdown_coordinator,
276 component_tasks,
277 component_task_map,
278 ))
279 }
280}
281
282struct ComponentInterconnects {
283 interconnect_capacity: NonZeroUsize,
284 source_dispatchers: HashMap<ComponentId, EventsDispatcher>,
285 transform_consumers: HashMap<ComponentId, (mpsc::Sender<EventsBuffer>, EventsConsumer)>,
286 transform_dispatchers: HashMap<ComponentId, EventsDispatcher>,
287 destination_consumers: HashMap<ComponentId, (mpsc::Sender<EventsBuffer>, EventsConsumer)>,
288 encoder_consumers: HashMap<ComponentId, (mpsc::Sender<EventsBuffer>, EventsConsumer)>,
289 encoder_dispatchers: HashMap<ComponentId, PayloadsDispatcher>,
290 forwarder_consumers: HashMap<ComponentId, (mpsc::Sender<PayloadsBuffer>, PayloadsConsumer)>,
291}
292
293impl ComponentInterconnects {
294 fn from_graph(interconnect_capacity: NonZeroUsize, graph: &Graph) -> Result<Self, GenericError> {
295 let mut interconnects = Self {
296 interconnect_capacity,
297 source_dispatchers: HashMap::new(),
298 transform_consumers: HashMap::new(),
299 transform_dispatchers: HashMap::new(),
300 destination_consumers: HashMap::new(),
301 encoder_consumers: HashMap::new(),
302 encoder_dispatchers: HashMap::new(),
303 forwarder_consumers: HashMap::new(),
304 };
305
306 interconnects.generate_interconnects(graph)?;
307 Ok(interconnects)
308 }
309
310 fn take_source_dispatcher(&mut self, component_id: &ComponentId) -> Option<EventsDispatcher> {
311 self.source_dispatchers.remove(component_id)
312 }
313
314 fn take_transform_dispatcher(&mut self, component_id: &ComponentId) -> Option<EventsDispatcher> {
315 self.transform_dispatchers.remove(component_id)
316 }
317
318 fn take_encoder_dispatcher(&mut self, component_id: &ComponentId) -> Option<PayloadsDispatcher> {
319 self.encoder_dispatchers.remove(component_id)
320 }
321
322 fn take_transform_consumer(&mut self, component_id: &ComponentId) -> Option<EventsConsumer> {
323 self.transform_consumers
324 .remove(component_id)
325 .map(|(_, consumer)| consumer)
326 }
327
328 fn take_destination_consumer(&mut self, component_id: &ComponentId) -> Option<EventsConsumer> {
329 self.destination_consumers
330 .remove(component_id)
331 .map(|(_, consumer)| consumer)
332 }
333
334 fn take_encoder_consumer(&mut self, component_id: &ComponentId) -> Option<EventsConsumer> {
335 self.encoder_consumers
336 .remove(component_id)
337 .map(|(_, consumer)| consumer)
338 }
339
340 fn take_forwarder_consumer(&mut self, component_id: &ComponentId) -> Option<PayloadsConsumer> {
341 self.forwarder_consumers
342 .remove(component_id)
343 .map(|(_, consumer)| consumer)
344 }
345
346 fn generate_interconnects(&mut self, graph: &Graph) -> Result<(), GenericError> {
347 let outbound_edges = graph.get_outbound_directed_edges();
352 for (upstream_id, output_map) in outbound_edges {
353 match upstream_id.component_type() {
354 ComponentType::Source | ComponentType::Transform => {
355 self.generate_event_interconnect(upstream_id, output_map)?;
356 }
357 ComponentType::Encoder => self.generate_payload_interconnect(upstream_id, output_map)?,
358 _ => panic!(
359 "Only sources, transforms, and encoders can dispatch events/payloads to downstream components."
360 ),
361 }
362 }
363
364 Ok(())
365 }
366
367 fn generate_event_interconnect(
368 &mut self, upstream_id: TypedComponentId, output_map: HashMap<OutputName, Vec<TypedComponentId>>,
369 ) -> Result<(), GenericError> {
370 for (upstream_output_id, downstream_ids) in output_map {
371 let mut senders = Vec::new();
372 for downstream_id in downstream_ids {
373 debug!(upstream_id = %upstream_id.component_id(), %upstream_output_id, downstream_id = %downstream_id.component_id(), "Adding dispatcher output.");
374 let sender = self.get_or_create_events_sender(downstream_id);
375 senders.push(sender);
376 }
377
378 let dispatcher = self.get_or_create_events_dispatcher(upstream_id.clone());
379 dispatcher.add_output(upstream_output_id.clone())?;
380
381 for sender in senders {
382 dispatcher.attach_sender_to_output(&upstream_output_id, sender)?;
383 }
384 }
385
386 Ok(())
387 }
388
389 fn generate_payload_interconnect(
390 &mut self, upstream_id: TypedComponentId, output_map: HashMap<OutputName, Vec<TypedComponentId>>,
391 ) -> Result<(), GenericError> {
392 for (upstream_output_id, downstream_ids) in output_map {
393 let mut senders = Vec::new();
394 for downstream_id in downstream_ids {
395 debug!(upstream_id = %upstream_id.component_id(), %upstream_output_id, downstream_id = %downstream_id.component_id(), "Adding dispatcher output.");
396 let sender = self.get_or_create_payloads_sender(downstream_id);
397 senders.push(sender);
398 }
399
400 let dispatcher = self.get_or_create_payloads_dispatcher(upstream_id.clone());
401 dispatcher.add_output(upstream_output_id.clone())?;
402
403 for sender in senders {
404 dispatcher.attach_sender_to_output(&upstream_output_id, sender)?;
405 }
406 }
407
408 Ok(())
409 }
410
411 fn get_or_create_events_dispatcher(&mut self, component_id: TypedComponentId) -> &mut EventsDispatcher {
412 let (component_id, component_type, component_context) = component_id.into_parts();
413
414 match component_type {
415 ComponentType::Source => self
416 .source_dispatchers
417 .entry(component_id)
418 .or_insert_with(|| EventsDispatcher::new(component_context)),
419 ComponentType::Transform => self
420 .transform_dispatchers
421 .entry(component_id)
422 .or_insert_with(|| EventsDispatcher::new(component_context)),
423 _ => {
424 panic!("Only sources and transforms can dispatch events to downstream components.")
425 }
426 }
427 }
428
429 fn get_or_create_events_sender(&mut self, component_id: TypedComponentId) -> mpsc::Sender<EventsBuffer> {
430 let (component_id, component_type, component_context) = component_id.into_parts();
431 let interconnect_capacity = self.interconnect_capacity;
432
433 let (sender, _) = match component_type {
434 ComponentType::Transform => self
435 .transform_consumers
436 .entry(component_id)
437 .or_insert_with(|| build_events_consumer_pair(component_context, interconnect_capacity)),
438 ComponentType::Destination => self
439 .destination_consumers
440 .entry(component_id)
441 .or_insert_with(|| build_events_consumer_pair(component_context, interconnect_capacity)),
442 ComponentType::Encoder => self
443 .encoder_consumers
444 .entry(component_id)
445 .or_insert_with(|| build_events_consumer_pair(component_context, interconnect_capacity)),
446 _ => panic!("Only transforms, destinations, and encoders can consume events."),
447 };
448
449 sender.clone()
450 }
451
452 fn get_or_create_payloads_dispatcher(&mut self, component_id: TypedComponentId) -> &mut PayloadsDispatcher {
453 let (component_id, component_type, component_context) = component_id.into_parts();
454
455 match component_type {
456 ComponentType::Encoder => self
457 .encoder_dispatchers
458 .entry(component_id)
459 .or_insert_with(|| PayloadsDispatcher::new(component_context)),
460 _ => {
461 panic!("Only encoders can dispatch payloads to downstream components.")
462 }
463 }
464 }
465
466 fn get_or_create_payloads_sender(&mut self, component_id: TypedComponentId) -> mpsc::Sender<PayloadsBuffer> {
467 let (component_id, component_type, component_context) = component_id.into_parts();
468 let interconnect_capacity = self.interconnect_capacity;
469
470 let (sender, _) = match component_type {
471 ComponentType::Forwarder => self
472 .forwarder_consumers
473 .entry(component_id)
474 .or_insert_with(|| build_payloads_consumer_pair(component_context, interconnect_capacity)),
475 _ => panic!("Only forwarders can consume payloads."),
476 };
477
478 sender.clone()
479 }
480}
481
482fn build_events_consumer_pair(
483 component_context: ComponentContext, interconnect_capacity: NonZeroUsize,
484) -> (mpsc::Sender<EventsBuffer>, EventsConsumer) {
485 let (sender, receiver) = mpsc::channel(interconnect_capacity.get());
486 let consumer = EventsConsumer::new(component_context, receiver);
487 (sender, consumer)
488}
489
490fn build_payloads_consumer_pair(
491 component_context: ComponentContext, interconnect_capacity: NonZeroUsize,
492) -> (mpsc::Sender<PayloadsBuffer>, PayloadsConsumer) {
493 let (sender, receiver) = mpsc::channel(interconnect_capacity.get());
494 let consumer = PayloadsConsumer::new(component_context, receiver);
495 (sender, consumer)
496}
497
498fn spawn_component<F>(
499 join_set: &mut JoinSet<Result<(), GenericError>>, context: ComponentContext,
500 allocation_group_token: AllocationGroupToken, component_future: F,
501) -> AbortHandle
502where
503 F: Future<Output = Result<(), GenericError>> + Send + 'static,
504{
505 let component_span = error_span!(
506 "component",
507 "type" = context.component_type().as_str(),
508 id = %context.component_id(),
509 );
510
511 let _span = component_span.enter();
512 let _guard = allocation_group_token.enter();
513
514 let component_task_name = format!(
515 "topology-{}-{}",
516 context.component_type().as_str(),
517 context.component_id()
518 );
519 join_set.spawn_traced_named(component_task_name, component_future)
520}
521
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;
    use crate::data_model::{event::EventType, payload::PayloadType};
    use crate::topology::graph::Graph;

    /// Building interconnects for a graph containing every component kind must succeed, which requires each
    /// dispatcher output to be added before senders are attached to it.
    #[test]
    fn component_interconnects_adds_output_before_attaching() {
        let capacity = NonZeroUsize::new(10).unwrap();

        // source1 -> transform1 -> { encoder1 -> forwarder1, dest1 }
        let mut topology_graph = Graph::default();
        topology_graph
            .with_source("source1", EventType::EventD)
            .with_transform("transform1", EventType::EventD, EventType::EventD)
            .with_encoder("encoder1", EventType::EventD, PayloadType::Raw)
            .with_forwarder("forwarder1", PayloadType::Raw)
            .with_destination("dest1", EventType::EventD)
            .with_edge("source1", "transform1")
            .with_edge("transform1", "encoder1")
            .with_edge("encoder1", "forwarder1")
            .with_edge("transform1", "dest1");

        let _interconnects = ComponentInterconnects::from_graph(capacity, &topology_graph)
            .expect("should build interconnects successfully");
    }
}