1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use saluki_common::collections::FastIndexMap;
4use saluki_error::{ErrorContext as _, GenericError};
5use snafu::{OptionExt as _, Snafu};
6use tokio::{
7 pin, select,
8 task::{AbortHandle, Id, JoinSet},
9};
10use tracing::{debug, error, warn};
11
12use super::{
13 restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
14 shutdown::{ProcessShutdown, ShutdownHandle},
15};
16use crate::runtime::process::{Process, ProcessExt as _};
17
/// A boxed, pinned future representing the running body of a supervised process.
///
/// Produced by [`Supervisable::initialize`] for workers and by nested supervisors;
/// resolves to `Ok(())` on a clean exit or to an error describing the failure.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;
20
/// Describes how a supervised child process terminated.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The supervisor forcefully cancelled the child's task.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child's task panicked while running.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child ran to completion but returned an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The underlying error returned by the child.
        source: GenericError,
    },
}
39
/// Controls how a child process is stopped when its supervisor shuts it down.
pub enum ShutdownStrategy {
    /// Signal the child to shut down cooperatively, waiting up to the given duration
    /// before forcefully aborting its task.
    Graceful(Duration),

    /// Abort the child's task immediately, without any shutdown signal.
    Brutal,
}
48
/// A process that can be run and managed by a [`Supervisor`].
pub trait Supervisable: Send + Sync {
    /// Returns the name of this process, used for logging and process identity.
    fn name(&self) -> &str;

    /// Returns the shutdown strategy to use when stopping this process.
    ///
    /// Defaults to a graceful shutdown with a five-second timeout.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Builds the future that runs this process, or `None` if initialization failed.
    ///
    /// The given [`ProcessShutdown`] is triggered by the supervisor when it wants the
    /// process to stop cooperatively.
    fn initialize(&self, process_shutdown: ProcessShutdown) -> Option<SupervisorFuture>;
}
67
/// Errors that can occur while building or running a supervisor.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// A supervisor or worker name failed process-name validation (e.g. it is empty).
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The offending name.
        name: String,
    },

    /// The supervisor was started without any child specifications.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process returned `None` from its `initialize` call.
    #[snafu(display("Child process failed to initialize."))]
    FailedToInitialize,

    /// The restart policy's limits were exceeded, forcing the supervisor to stop.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))]
    Shutdown,
}
91
/// Specification for a child managed by a [`Supervisor`]: either a leaf worker
/// process or a nested supervisor with its own children.
pub enum ChildSpecification {
    /// A leaf worker process.
    Worker(Arc<dyn Supervisable>),
    /// A nested supervisor, run as a child of this supervisor.
    Supervisor(Supervisor),
}
104
105impl ChildSpecification {
106 fn process_type(&self) -> &'static str {
107 match self {
108 Self::Worker(_) => "worker",
109 Self::Supervisor(_) => "supervisor",
110 }
111 }
112
113 fn name(&self) -> &str {
114 match self {
115 Self::Worker(worker) => worker.name(),
116 Self::Supervisor(supervisor) => &supervisor.supervisor_id,
117 }
118 }
119
120 fn shutdown_strategy(&self) -> ShutdownStrategy {
121 match self {
122 Self::Worker(worker) => worker.shutdown_strategy(),
123
124 Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
127 }
128 }
129
130 fn initialize(
131 &self, parent_process: &Process, process_shutdown: ProcessShutdown,
132 ) -> Result<Option<(Process, SupervisorFuture)>, SupervisorError> {
133 match self {
134 Self::Worker(worker) => {
135 let process = Process::worker(worker.name(), parent_process).context(InvalidName {
136 name: worker.name().to_string(),
137 })?;
138 Ok(worker.initialize(process_shutdown).map(|future| (process, future)))
139 }
140 Self::Supervisor(sup) => {
141 let process = Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
142 name: sup.supervisor_id.to_string(),
143 })?;
144 Ok(Some((
145 process.clone(),
146 sup.as_nested_process(process, process_shutdown),
147 )))
148 }
149 }
150 }
151}
152
153impl Clone for ChildSpecification {
154 fn clone(&self) -> Self {
155 match self {
156 Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
157 Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
158 }
159 }
160}
161
162impl From<Supervisor> for ChildSpecification {
163 fn from(supervisor: Supervisor) -> Self {
164 Self::Supervisor(supervisor)
165 }
166}
167
168impl<T> From<T> for ChildSpecification
169where
170 T: Supervisable + 'static,
171{
172 fn from(worker: T) -> Self {
173 Self::Worker(Arc::new(worker))
174 }
175}
176
/// Runs a set of static child processes (workers and/or nested supervisors),
/// restarting them according to a configurable restart strategy.
pub struct Supervisor {
    /// Name of this supervisor; shared cheaply between clones.
    supervisor_id: Arc<str>,
    /// Static specifications of the children to spawn, in registration order.
    child_specs: Vec<ChildSpecification>,
    /// Policy deciding whether/how children are restarted when they terminate.
    restart_strategy: RestartStrategy,
}
208
impl Supervisor {
    /// Creates a new supervisor with the given identifier.
    ///
    /// # Errors
    ///
    /// Returns `SupervisorError::InvalidName` if `supervisor_id` is empty.
    pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
        if supervisor_id.as_ref().is_empty() {
            return Err(SupervisorError::InvalidName {
                name: supervisor_id.as_ref().to_string(),
            });
        }

        Ok(Self {
            supervisor_id: supervisor_id.as_ref().into(),
            child_specs: Vec::new(),
            restart_strategy: RestartStrategy::default(),
        })
    }

    /// Sets the restart strategy applied when child processes terminate.
    pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
        self.restart_strategy = strategy;
        self
    }

    /// Adds a static child process (worker or nested supervisor) to this supervisor.
    ///
    /// Children are spawned in the order they were added.
    pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
        let child_spec = process.into();
        debug!(
            supervisor_id = %self.supervisor_id,
            "Adding new static child process #{}. ({}, {})",
            self.child_specs.len(),
            child_spec.process_type(),
            child_spec.name(),
        );
        self.child_specs.push(child_spec);
    }

    /// Returns the child specification at the given index.
    ///
    /// Indices come from this supervisor's own bookkeeping, so an out-of-bounds index
    /// is an internal bug and panics via `unreachable!`.
    fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
        match self.child_specs.get(child_spec_idx) {
            Some(child_spec) => child_spec,
            None => unreachable!("child spec index should never be out of bounds"),
        }
    }

    /// Spawns (or respawns) the child at the given spec index into `worker_state`.
    fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        let child_spec = self.get_child_spec(child_spec_idx);
        debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
        worker_state.add_worker(child_spec_idx, child_spec)
    }

    /// Spawns every configured child, in registration order.
    fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
        for child_spec_idx in 0..self.child_specs.len() {
            self.spawn_child(child_spec_idx, worker_state)?;
        }

        Ok(())
    }

    /// Core supervision loop: spawns all children, then reacts to terminations —
    /// restarting per the restart strategy — until `process_shutdown` triggers or
    /// restart limits are exceeded.
    ///
    /// # Errors
    ///
    /// Returns `NoChildren` if no child specs were registered, `Shutdown` if the
    /// restart policy's limits were exceeded, or any spawn-time error from the
    /// children themselves.
    async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        self.spawn_all_children(&mut worker_state)?;

        let shutdown = process_shutdown.wait_for_shutdown();
        pin!(shutdown);

        loop {
            select! {
                _ = &mut shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                // NOTE: every child termination — clean exit or failure — is fed into
                // the restart policy below.
                worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
                    Some((child_spec_idx, worker_result)) => {
                        let child_spec = self.get_child_spec(child_spec_idx);
                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                // One-for-one: only the terminated child is respawned.
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                // One-for-all: stop every child, then respawn them all.
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    worker_state.shutdown_workers().await;
                                    self.spawn_all_children(&mut worker_state)?;
                                }
                            },
                            // Restart budget exhausted: tear everything down and bail.
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    },
                    None => unreachable!("should not have empty worker joinset prior to shutdown"),
                }
            }
        }

        Ok(())
    }

    /// Returns a future that runs a clone of this supervisor as a nested child process
    /// under the given process identity.
    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> SupervisorFuture {
        debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");

        // Clone so the returned future is independent of `self`'s lifetime.
        let sup = self.inner_clone();

        Box::pin(async move {
            sup.run_inner(process, process_shutdown)
                .await
                .error_context("Nested supervisor failed to exit cleanly.")
        })
    }

    /// Runs the supervisor to completion with no external shutdown signal wired up
    /// (the shutdown handle is a no-op), instrumented under its own process identity.
    pub async fn run(&mut self) -> Result<(), SupervisorError> {
        let process_shutdown = ProcessShutdown::noop();
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_instrumented(process)
            .await
    }

    /// Runs the supervisor until the given `shutdown` future completes or the
    /// supervisor itself fails, instrumented under its own process identity.
    pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
        let process_shutdown = ProcessShutdown::wrapped(shutdown);
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_instrumented(process)
            .await
    }

    /// Creates a copy of this supervisor for nested execution: the identifier `Arc`
    /// is shared and the child specifications are cloned.
    fn inner_clone(&self) -> Self {
        Self {
            supervisor_id: Arc::clone(&self.supervisor_id),
            child_specs: self.child_specs.clone(),
            restart_strategy: self.restart_strategy,
        }
    }
}
396
/// Bookkeeping for a single spawned child: its spec index plus the handles needed to
/// shut it down or abort it.
struct ProcessState {
    /// Index of the child specification this task was spawned from.
    worker_id: usize,
    /// How this child should be stopped during supervisor shutdown.
    shutdown_strategy: ShutdownStrategy,
    /// Handle for triggering the child's cooperative shutdown signal.
    shutdown_handle: ShutdownHandle,
    /// Handle for forcefully aborting the child's task.
    abort_handle: AbortHandle,
}
403
/// Tracks the set of currently-running child tasks for a supervisor.
struct WorkerState {
    /// The supervisor's own process identity; parent of every spawned child.
    process: Process,
    /// Join set holding all spawned child tasks.
    worker_tasks: JoinSet<Result<(), GenericError>>,
    /// Per-task state, keyed by the tokio task ID assigned at spawn time.
    worker_map: FastIndexMap<Id, ProcessState>,
}
409
impl WorkerState {
    /// Creates empty worker state rooted at the given supervisor process.
    fn new(process: Process) -> Self {
        Self {
            process,
            worker_tasks: JoinSet::new(),
            worker_map: FastIndexMap::default(),
        }
    }

    /// Initializes and spawns a child task, recording its shutdown/abort handles keyed
    /// by the tokio task ID.
    ///
    /// # Errors
    ///
    /// Returns `FailedToInitialize` if the child declines to produce a future, or an
    /// initialization error from the child spec itself.
    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
        // Paired shutdown: the child gets the receiver, we keep the trigger handle.
        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
        match child_spec.initialize(&self.process, process_shutdown)? {
            Some((process, worker)) => {
                let shutdown_strategy = child_spec.shutdown_strategy();

                let abort_handle = self.worker_tasks.spawn(worker.into_instrumented(process));
                // Keyed by task ID so completions from the join set can be mapped back
                // to their child spec.
                self.worker_map.insert(
                    abort_handle.id(),
                    ProcessState {
                        worker_id,
                        shutdown_strategy,
                        shutdown_handle,
                        abort_handle,
                    },
                );

                Ok(())
            }
            None => Err(SupervisorError::FailedToInitialize),
        }
    }

    /// Waits for the next child task to complete, returning its child-spec index and
    /// the outcome mapped into [`ProcessError`] (terminated-with-error, aborted, or
    /// panicked).
    ///
    /// Returns `None` when no child tasks remain in the join set.
    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), ProcessError>)> {
        debug!("Waiting for next process to complete.");

        match self.worker_tasks.join_next_with_id().await {
            // Task ran to completion; its own result may still be an error.
            Some(Ok((worker_task_id, worker_result))) => {
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                Some((
                    process_state.worker_id,
                    worker_result.map_err(|e| ProcessError::Terminated { source: e }),
                ))
            }
            // Join error: the task was either cancelled (aborted) or panicked.
            Some(Err(e)) => {
                let worker_task_id = e.id();
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                let e = if e.is_cancelled() {
                    ProcessError::Aborted
                } else {
                    ProcessError::Panicked
                };
                Some((process_state.worker_id, Err(e)))
            }
            None => None,
        }
    }

    /// Shuts down all tracked child tasks, one at a time, honoring each child's
    /// shutdown strategy: graceful children get a shutdown signal plus a timeout
    /// before being aborted, brutal children are aborted immediately.
    async fn shutdown_workers(&mut self) {
        debug!("Shutting down all processes.");

        // Drain tracked children one at a time. `pop` presumably removes the
        // most-recently-inserted entry (index-map semantics) — TODO confirm
        // `FastIndexMap` ordering if shutdown order ever matters.
        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
            let ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            } = process_state;

            let shutdown_deadline = match shutdown_strategy {
                ShutdownStrategy::Graceful(timeout) => {
                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
                    shutdown_handle.trigger();

                    tokio::time::sleep(timeout)
                }
                ShutdownStrategy::Brutal => {
                    debug!(worker_id, "Forcefully aborting process.");
                    abort_handle.abort();

                    // Already aborted: this effectively-infinite sleep should never
                    // fire; we just wait for the join set to reap the cancelled task.
                    // (tokio caps oversized sleep durations internally.)
                    tokio::time::sleep(Duration::MAX)
                }
            };
            pin!(shutdown_deadline);

            // Wait for the *target* task to leave the join set. Other tasks may
            // complete in the meantime; they are removed from the map so the outer
            // loop does not process them again.
            loop {
                select! {
                    worker_result = self.worker_tasks.join_next_with_id() => {
                        match worker_result {
                            Some(Ok((worker_task_id, _))) => {
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited successfully.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            },
                            Some(Err(e)) => {
                                let worker_task_id = e.id();
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited with error.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            }
                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
                        }
                    },
                    // Graceful timeout expired: escalate to a forceful abort. The
                    // abort surfaces as a cancellation join error handled above.
                    _ = &mut shutdown_deadline => {
                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
                        abort_handle.abort();
                    }
                }
            }
        }

        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
        debug_assert!(
            self.worker_tasks.is_empty(),
            "worker tasks should be empty after shutdown"
        );
    }
}