saluki_core/topology/
running.rs1use std::{
2 collections::HashMap,
3 time::{Duration, Instant},
4};
5
6use saluki_error::{generic_error, GenericError};
7use tokio::{
8 pin,
9 runtime::Runtime,
10 select,
11 task::{Id, JoinError, JoinSet},
12 time::{interval, sleep},
13};
14use tracing::{debug, error, info, warn};
15
16use super::{shutdown::ComponentShutdownCoordinator, ComponentId};
17
18pub struct RunningTopology {
20 thread_pool: Runtime,
21 shutdown_coordinator: ComponentShutdownCoordinator,
22 component_tasks: JoinSet<Result<(), GenericError>>,
23 component_task_map: HashMap<Id, ComponentId>,
24}
25
26impl RunningTopology {
27 pub(super) fn from_parts(
29 thread_pool: Runtime, shutdown_coordinator: ComponentShutdownCoordinator,
30 component_tasks: JoinSet<Result<(), GenericError>>, component_task_map: HashMap<Id, ComponentId>,
31 ) -> Self {
32 Self {
33 thread_pool,
34 shutdown_coordinator,
35 component_tasks,
36 component_task_map,
37 }
38 }
39
40 pub async fn wait_for_unexpected_finish(&mut self) {
45 let task_result = self
46 .component_tasks
47 .join_next_with_id()
48 .await
49 .expect("no components to wait for");
50
51 handle_task_result(&mut self.component_task_map, task_result, true);
55 }
56
57 pub async fn shutdown(self) -> Result<(), GenericError> {
66 self.shutdown_with_timeout(Duration::MAX).await
67 }
68
69 pub async fn shutdown_with_timeout(mut self, timeout: Duration) -> Result<(), GenericError> {
76 let shutdown_deadline = Instant::now() + timeout;
77
78 let shutdown_timeout = sleep(timeout);
79 pin!(shutdown_timeout);
80
81 let mut progress_interval = interval(Duration::from_secs(5));
82 progress_interval.tick().await;
83
84 self.shutdown_coordinator.shutdown();
87
88 let mut stopped_cleanly = true;
89
90 loop {
91 select! {
92 maybe_task_result = self.component_tasks.join_next_with_id() => match maybe_task_result {
94 None => {
95 info!("All components stopped.");
96 break;
97 },
98 Some(component_result) => if !handle_task_result(&mut self.component_task_map, component_result, false) {
99 stopped_cleanly = false;
100 },
101 },
102
103 _ = progress_interval.tick() => {
105 let mut remaining_components = self.component_task_map.values()
106 .map(|id| id.to_string())
107 .collect::<Vec<_>>();
108 remaining_components.sort();
109 let remaining_time = shutdown_deadline.saturating_duration_since(Instant::now());
110
111 info!("Waiting for the remaining component(s) to stop: {}. {} seconds remaining.", remaining_components.join(", "), remaining_time.as_secs_f64().round() as u64);
112 },
113
114 _ = &mut shutdown_timeout => {
116 warn!("Forcefully stopping topology after shutdown grace period.");
117 stopped_cleanly = false;
118 break;
119 },
120 }
121 }
122
123 self.thread_pool.shutdown_background();
127
128 if stopped_cleanly {
129 Ok(())
130 } else {
131 Err(generic_error!("Topology failed to shutdown cleanly."))
132 }
133 }
134}
135
136fn handle_task_result(
140 component_task_map: &mut HashMap<Id, ComponentId>, task_result: Result<(Id, Result<(), GenericError>), JoinError>,
141 unexpected: bool,
142) -> bool {
143 let (task_id, stopped_successfully) = match task_result {
144 Ok((id, component_result)) => {
145 let component_id = component_task_map.get(&id).expect("component ID not found");
146 match component_result {
147 Ok(()) => {
148 if unexpected {
149 warn!(%component_id, "Component unexpectedly finished.");
150 } else {
151 debug!(%component_id, "Component stopped.");
152 }
153 (id, true)
154 }
155 Err(e) => {
156 error!(%component_id, error = %e, "Component stopped with error.");
157 (id, false)
158 }
159 }
160 }
161 Err(e) => {
162 let id = e.id();
163 let component_id = component_task_map.get(&id).expect("component ID not found");
164 error!(%component_id, error = %e, "Component task failed unexpectedly.");
165 (id, false)
166 }
167 };
168
169 component_task_map.remove(&task_id);
170 stopped_successfully
171}