// saluki_core/runtime/supervisor.rs
1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use saluki_common::collections::FastIndexMap;
5use saluki_error::{ErrorContext as _, GenericError};
6use snafu::{OptionExt as _, Snafu};
7use tokio::{
8    pin, select,
9    task::{AbortHandle, Id, JoinSet},
10};
11use tracing::{debug, error, warn};
12
13use super::{
14    dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
15    restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
16    shutdown::{ProcessShutdown, ShutdownHandle},
17};
18use crate::runtime::process::{Process, ProcessExt as _};
19
/// A `Future` that represents the execution of a supervised process.
///
/// This is the "runtime phase" future returned by [`Supervisable::initialize`]: it runs until the
/// process exits (cleanly or with an error), after initialization has already completed.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;
22
/// A `Future` that represents the full lifecycle of a worker, including initialization.
///
/// Unlike [`SupervisorFuture`], which only represents the runtime phase, this future first performs async
/// initialization and then runs the worker. This allows initialization to happen concurrently when multiple workers are
/// spawned, and keeps the supervisor loop responsive to shutdown signals during initialization.
///
/// The error type is [`WorkerError`], so the supervisor can tell initialization failures apart from
/// runtime failures when deciding whether to restart.
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;
29
/// Worker lifecycle errors.
///
/// Distinguishes between initialization failures (which should NOT trigger restart logic) and runtime failures (which
/// are eligible for restart).
#[derive(Debug)]
enum WorkerError {
    /// The worker failed during async initialization.
    ///
    /// Propagates immediately and fails the supervisor (see `Supervisor::run_inner`).
    Initialization(InitializationError),

    /// The worker failed during runtime execution.
    ///
    /// Eligible for restart according to the supervisor's configured [`RestartStrategy`].
    Runtime(GenericError),
}
42
/// Process errors.
///
/// Describes how a supervised child process terminated; used when evaluating restart behavior and
/// when logging terminations.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child process was aborted by the supervisor.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child process panicked.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child process terminated with an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error that caused the termination.
        source: GenericError,
    },
}
61
/// Initialization errors.
///
/// Initialization errors are distinct from runtime errors: they indicate that a process could not be started at all
/// (e.g., failed to bind a port, missing configuration). These errors do NOT trigger restart logic; instead, they
/// immediately propagate up and fail the supervisor.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    /// The process could not be initialized due to an error.
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        /// The underlying error that caused initialization to fail.
        source: GenericError,
    },

    /// The process is permanently unavailable and cannot be initialized.
    ///
    /// This is for cases where initialization is structurally impossible, not due to a transient error.
    #[snafu(display("Process is permanently unavailable"))]
    PermanentlyUnavailable,
}
83
/// Strategy for shutting down a process.
///
/// Consumed by the supervisor when tearing down children (see `WorkerState::shutdown_workers`).
pub enum ShutdownStrategy {
    /// Waits for the configured duration for the process to exit, and then forcefully aborts it otherwise.
    Graceful(Duration),

    /// Forcefully aborts the process without waiting.
    Brutal,
}
92
/// A supervisable process.
///
/// Implementors describe a worker's identity, how it should be shut down, and how to initialize the
/// future that runs it. Any `Supervisable` can be added to a [`Supervisor`] as a child.
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of the process.
    fn name(&self) -> &str;

    /// Returns the shutdown strategy for the process.
    ///
    /// Defaults to a graceful shutdown with a five-second grace period before the process is
    /// forcefully aborted.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process asynchronously.
    ///
    /// During initialization, any resources or configuration for the process can be created asynchronously, and the
    /// same runtime that is used for running the process is used for initialization. The resulting future is expected
    /// to complete as soon as reasonably possible after `process_shutdown` resolves.
    ///
    /// # Errors
    ///
    /// If the process cannot be initialized, an error is returned.
    async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError>;
}
115
/// Supervisor errors.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// Supervisor or worker name is invalid.
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The invalid name that was rejected.
        name: String,
    },

    /// The supervisor has no child processes.
    ///
    /// Returned when `run`/`run_with_shutdown` is called before any child was added.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed to initialize.
    ///
    /// This error indicates that a child could not complete its async initialization. This is distinct from runtime
    /// failures and does NOT trigger restart logic.
    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        /// The name of the child that failed to initialize.
        child_name: String,

        /// The underlying initialization error.
        source: InitializationError,
    },

    /// The supervisor exceeded its restart limits and was forced to shutdown.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))]
    Shutdown,
}
148
/// A child process specification.
///
/// All workers added to a [`Supervisor`] must be specified as a `ChildSpecification`. This acts as a template for how
/// the supervisor should create the underlying future that represents the process, as well as information about the
/// process, such as its name, shutdown strategy, and more.
///
/// A child process specification can be created implicitly from an existing [`Supervisor`], or any type that implements
/// [`Supervisable`].
pub enum ChildSpecification {
    /// A leaf worker process, driven by a [`Supervisable`] implementation.
    Worker(Arc<dyn Supervisable>),

    /// A nested supervisor, forming a subtree of the overall supervision tree.
    Supervisor(Supervisor),
}
161
162impl ChildSpecification {
163    fn process_type(&self) -> &'static str {
164        match self {
165            Self::Worker(_) => "worker",
166            Self::Supervisor(_) => "supervisor",
167        }
168    }
169
170    fn name(&self) -> &str {
171        match self {
172            Self::Worker(worker) => worker.name(),
173            Self::Supervisor(supervisor) => &supervisor.supervisor_id,
174        }
175    }
176
177    fn shutdown_strategy(&self) -> ShutdownStrategy {
178        match self {
179            Self::Worker(worker) => worker.shutdown_strategy(),
180
181            // Supervisors should always be given as much time as necessary shutdown down gracefully to ensure that the
182            // entire supervision subtree can be shutdown cleanly.
183            Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
184        }
185    }
186
187    fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
188        match self {
189            Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
190                name: worker.name().to_string(),
191            }),
192            Self::Supervisor(sup) => {
193                Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
194                    name: sup.supervisor_id.to_string(),
195                })
196            }
197        }
198    }
199
200    fn create_worker_future(
201        &self, process: Process, process_shutdown: ProcessShutdown,
202    ) -> Result<WorkerFuture, SupervisorError> {
203        match self {
204            Self::Worker(worker) => {
205                let worker = Arc::clone(worker);
206                Ok(Box::pin(async move {
207                    let run_future = worker
208                        .initialize(process_shutdown)
209                        .await
210                        .map_err(WorkerError::Initialization)?;
211                    run_future.await.map_err(WorkerError::Runtime)
212                }))
213            }
214            Self::Supervisor(sup) => {
215                match sup.runtime_mode() {
216                    RuntimeMode::Ambient => {
217                        // Run on the parent's ambient runtime.
218                        Ok(sup.as_nested_process(process, process_shutdown))
219                    }
220                    RuntimeMode::Dedicated(config) => {
221                        // Spawn in a dedicated runtime on a new OS thread.
222                        let child_name = sup.supervisor_id.to_string();
223                        let handle = spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown)
224                            .map_err(|e| SupervisorError::FailedToInitialize {
225                                child_name,
226                                source: InitializationError::Failed { source: e },
227                            })?;
228
229                        Ok(Box::pin(async move { handle.await.map_err(WorkerError::Runtime) }))
230                    }
231                }
232            }
233        }
234    }
235}
236
237impl Clone for ChildSpecification {
238    fn clone(&self) -> Self {
239        match self {
240            Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
241            Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
242        }
243    }
244}
245
246impl From<Supervisor> for ChildSpecification {
247    fn from(supervisor: Supervisor) -> Self {
248        Self::Supervisor(supervisor)
249    }
250}
251
252impl<T> From<T> for ChildSpecification
253where
254    T: Supervisable + 'static,
255{
256    fn from(worker: T) -> Self {
257        Self::Worker(Arc::new(worker))
258    }
259}
260
/// Supervises a set of workers.
///
/// # Workers
///
/// All workers are defined through implementation of the [`Supervisable`] trait, which provides the logic for both
/// creating the underlying worker future that is spawned, as well as other metadata, such as the worker's name, how the
/// worker should be shutdown, and so on.
///
/// Supervisors also (indirectly) implement the [`Supervisable`] trait, allowing them to be supervised by other
/// supervisors in order to construct _supervision trees_.
///
/// # Instrumentation
///
/// Supervisors automatically create their own allocation group
/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory
/// usage of the supervisor itself and its children. Additionally, individual worker processes are wrapped in a
/// dedicated [`tracing::Span`] to allow tracing the causal relationship between arbitrary code and the worker executing
/// it.
///
/// # Restart Strategies
///
/// As the main purpose of a supervisor, restart behavior is fully configurable. A number of restart strategies are
/// available, which generally relate to the purpose of the supervisor: whether the workers being managed are
/// independent or interdependent.
///
/// All restart strategies are configured through [`RestartStrategy`], which has more information on the available
/// strategies and configuration settings.
pub struct Supervisor {
    // Identifier for this supervisor; `Arc<str>` so internal clones share it cheaply.
    supervisor_id: Arc<str>,
    // Static set of child specifications, spawned (and respawned) in insertion order.
    child_specs: Vec<ChildSpecification>,
    // Policy controlling whether and how children are restarted after termination.
    restart_strategy: RestartStrategy,
    // Whether this supervisor runs on the parent's ambient runtime or a dedicated one.
    runtime_mode: RuntimeMode,
}
294
impl Supervisor {
    /// Creates an empty `Supervisor` with the default restart strategy.
    ///
    /// # Errors
    ///
    /// Returns `SupervisorError::InvalidName` if `supervisor_id` is empty.
    pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
        // We try to throw an error about invalid names as early as possible. This is a manual check, so we might still
        // encounter an error later when actually running the supervisor, but this is a good first step to catch the
        // bulk of invalid names.
        if supervisor_id.as_ref().is_empty() {
            return Err(SupervisorError::InvalidName {
                name: supervisor_id.as_ref().to_string(),
            });
        }

        Ok(Self {
            supervisor_id: supervisor_id.as_ref().into(),
            child_specs: Vec::new(),
            restart_strategy: RestartStrategy::default(),
            runtime_mode: RuntimeMode::default(),
        })
    }

    /// Returns the supervisor's ID.
    pub fn id(&self) -> &str {
        &self.supervisor_id
    }

    /// Sets the restart strategy for the supervisor.
    pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
        self.restart_strategy = strategy;
        self
    }

    /// Configures this supervisor to run in a dedicated runtime.
    ///
    /// When this supervisor is added as a child to another supervisor, it will spawn its own OS thread(s) and Tokio
    /// runtime instead of running on the parent's ambient runtime.
    ///
    /// This provides runtime isolation, which can be useful for:
    /// - CPU-bound work that shouldn't block the parent's runtime
    /// - Isolating failures in one part of the system
    /// - Using different runtime configurations (e.g., single-threaded vs multi-threaded)
    pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
        self.runtime_mode = RuntimeMode::Dedicated(config);
        self
    }

    /// Returns the runtime mode for this supervisor.
    pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
        &self.runtime_mode
    }

    /// Adds a worker to the supervisor.
    ///
    /// A worker can be anything that implements the [`Supervisable`] trait. A [`Supervisor`] can also be added as a
    /// worker and managed in a nested fashion, known as a supervision tree.
    pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
        let child_spec = process.into();
        debug!(
            supervisor_id = %self.supervisor_id,
            "Adding new static child process #{}. ({}, {})",
            self.child_specs.len(),
            child_spec.process_type(),
            child_spec.name(),
        );
        self.child_specs.push(child_spec);
    }

    // Looks up a child specification by index.
    //
    // Indices handed around internally always originate from `self.child_specs`, so an
    // out-of-bounds index indicates a supervisor bug, not a recoverable condition.
    fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
        match self.child_specs.get(child_spec_idx) {
            Some(child_spec) => child_spec,
            None => unreachable!("child spec index should never be out of bounds"),
        }
    }

    // Spawns the child at `child_spec_idx` into `worker_state`.
    fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        let child_spec = self.get_child_spec(child_spec_idx);
        debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
        worker_state.add_worker(child_spec_idx, child_spec)
    }

    // Spawns every configured child, in insertion order.
    fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
        for child_spec_idx in 0..self.child_specs.len() {
            self.spawn_child(child_spec_idx, worker_state)?;
        }

        Ok(())
    }

    // Core supervision loop: spawns all children, then reacts to shutdown and child terminations
    // until either shutdown is triggered (Ok) or an unrecoverable error occurs (Err).
    async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        // Spawn all child processes. Since initialization is folded into each worker's task, this returns immediately
        // after spawning -- children initialize concurrently in the background.
        self.spawn_all_children(&mut worker_state)?;

        // Now we supervise.
        let shutdown = process_shutdown.wait_for_shutdown();
        pin!(shutdown);

        loop {
            select! {
                // Shutdown has been triggered.
                //
                // Propagate shutdown to all child processes and wait for them to exit.
                _ = &mut shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
                    // TODO: Erlang/OTP defaults to always trying to restart a process, even if it doesn't terminate due
                    // to a legitimate failure. It does allow configuring this behavior on a per-process basis, however.
                    // We don't support dynamically adding child processes, which is the only real use case I can think
                    // of for having non-long-lived child processes... so I think for now, we're OK just always try to
                    // restart.
                    Some((child_spec_idx, worker_result)) =>  {
                        let child_spec = self.get_child_spec(child_spec_idx);

                        // Initialization failures are not eligible for restart -- they propagate immediately.
                        if let Err(WorkerError::Initialization(e)) = worker_result {
                            error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), "Child process failed to initialize: {}", e);
                            worker_state.shutdown_workers().await;
                            return Err(SupervisorError::FailedToInitialize {
                                child_name: child_spec.name().to_string(),
                                source: e,
                            });
                        }

                        // Convert the worker result to a process error for restart evaluation.
                        //
                        // NOTE: this converted result is currently only used for logging below;
                        // `evaluate_restart` is consulted unconditionally, even on clean exits
                        // (see TODO above).
                        let worker_result = worker_result
                            .map_err(|e| match e {
                                WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
                                WorkerError::Initialization(_) => unreachable!("handled above"),
                            });

                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    worker_state.shutdown_workers().await;
                                    self.spawn_all_children(&mut worker_state)?;
                                }
                            },
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    },
                    None => unreachable!("should not have empty worker joinset prior to shutdown"),
                }
            }
        }

        Ok(())
    }

    // Adapts this supervisor into a `WorkerFuture` so it can run as a child of another supervisor.
    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
        // Simple wrapper around `run_inner` to satisfy the return type signature needed when running the supervisor as
        // a nested child process in another supervisor.
        debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");

        // Create a standalone clone of ourselves so we can fulfill the future signature.
        let sup = self.inner_clone();

        Box::pin(async move {
            sup.run_inner(process, process_shutdown)
                .await
                .error_context("Nested supervisor failed to exit cleanly.")
                .map_err(WorkerError::Runtime)
        })
    }

    /// Runs the supervisor forever.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
    pub async fn run(&mut self) -> Result<(), SupervisorError> {
        // Create a no-op `ProcessShutdown` to satisfy the `run_inner` function. This is never used since we want to run
        // forever, but we need to satisfy the signature.
        let process_shutdown = ProcessShutdown::noop();
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_instrumented(process)
            .await
    }

    /// Runs the supervisor until shutdown is triggered.
    ///
    /// When `shutdown` resolves, the supervisor will shutdown all child processes according to their shutdown strategy,
    /// and then return.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
    pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
        let process_shutdown = ProcessShutdown::wrapped(shutdown);
        self.run_with_process_shutdown(process_shutdown).await
    }

    /// Runs the supervisor until the given `ProcessShutdown` signal is received.
    ///
    /// This is an internal variant of `run_with_shutdown` that takes a `ProcessShutdown` directly, used when spawning
    /// supervisors in dedicated runtimes where the shutdown signal is already wrapped in a `ProcessShutdown`.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
    pub(crate) async fn run_with_process_shutdown(
        &mut self, process_shutdown: ProcessShutdown,
    ) -> Result<(), SupervisorError> {
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_instrumented(process)
            .await
    }

    fn inner_clone(&self) -> Self {
        // This is no different than if we just implemented `Clone` directly, but it allows us to avoid exposing a
        // _public_ implementation of `Clone`, which we don't want normal users to be able to do. We only need this
        // internally to support nested supervisors.
        Self {
            supervisor_id: Arc::clone(&self.supervisor_id),
            child_specs: self.child_specs.clone(),
            restart_strategy: self.restart_strategy,
            runtime_mode: self.runtime_mode.clone(),
        }
    }
}
543
// Per-child bookkeeping kept by `WorkerState` for each spawned task.
struct ProcessState {
    // Index of the child specification this task was spawned from.
    worker_id: usize,
    // How this child should be shut down (captured at spawn time).
    shutdown_strategy: ShutdownStrategy,
    // Handle used to trigger a graceful shutdown of the child.
    shutdown_handle: ShutdownHandle,
    // Handle used to forcefully abort the child's task.
    abort_handle: AbortHandle,
}
550
// Tracks all currently-running child tasks for a supervisor.
struct WorkerState {
    // The supervisor's own process handle; parent of all spawned children.
    process: Process,
    // Join set driving every spawned child task.
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    // Maps a task ID in `worker_tasks` to its bookkeeping state; insertion-ordered so that
    // shutdown can proceed in reverse spawn order.
    worker_map: FastIndexMap<Id, ProcessState>,
}
556
impl WorkerState {
    // Creates an empty worker state rooted at the given supervisor process.
    fn new(process: Process) -> Self {
        Self {
            process,
            worker_tasks: JoinSet::new(),
            worker_map: FastIndexMap::default(),
        }
    }

    // Spawns a child from its specification and records its bookkeeping state.
    //
    // `worker_id` is the child's index in the supervisor's `child_specs`, and is echoed back by
    // `wait_for_next_worker` so the supervisor can map completions to specifications.
    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
        let process = child_spec.create_process(&self.process)?;
        let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
        let shutdown_strategy = child_spec.shutdown_strategy();
        let abort_handle = self.worker_tasks.spawn(worker_future.into_instrumented(process));
        self.worker_map.insert(
            abort_handle.id(),
            ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            },
        );
        Ok(())
    }

    // Waits for the next child task to complete, returning its spec index and result.
    //
    // Join errors (cancellation/panic) are normalized into `WorkerError::Runtime` wrapping the
    // appropriate `ProcessError`. Returns `None` only when no tasks remain in the join set.
    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
        debug!("Waiting for next process to complete.");

        match self.worker_tasks.join_next_with_id().await {
            Some(Ok((worker_task_id, worker_result))) => {
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                Some((process_state.worker_id, worker_result))
            }
            Some(Err(e)) => {
                let worker_task_id = e.id();
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                // Distinguish a supervisor-initiated abort from an actual panic in the task.
                let e = if e.is_cancelled() {
                    ProcessError::Aborted
                } else {
                    ProcessError::Panicked
                };
                Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
            }
            None => None,
        }
    }

    // Shuts down all children, one at a time, honoring each child's shutdown strategy.
    async fn shutdown_workers(&mut self) {
        debug!("Shutting down all processes.");

        // Pop entries from the worker map, which grabs us workers in the reverse order they were added. This lets us
        // ensure we're shutting down any _dependent_ processes (processes which depend on previously-started processes)
        // first.
        //
        // For each entry, we trigger shutdown in whatever way necessary, and then wait for the process to exit by
        // driving the `JoinSet`. If other workers complete while we're waiting, we'll simply remove them from the
        // worker map and continue waiting for the current worker we're shutting down.
        //
        // We do this until the worker map is empty, at which point we can be sure that all processes have exited.
        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
            let ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            } = process_state;

            // Trigger the process to shutdown based on the configured shutdown strategy.
            let shutdown_deadline = match shutdown_strategy {
                ShutdownStrategy::Graceful(timeout) => {
                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
                    shutdown_handle.trigger();

                    tokio::time::sleep(timeout)
                }
                ShutdownStrategy::Brutal => {
                    debug!(worker_id, "Forcefully aborting process.");
                    abort_handle.abort();

                    // We have to return a future that never resolves, since we're already aborting it. This is a little
                    // hacky but it's also difficult to do an optional future, so this is what we're going with for now.
                    tokio::time::sleep(Duration::MAX)
                }
            };
            pin!(shutdown_deadline);

            // Wait for the process to exit by driving the `JoinSet`. If other workers complete while we're waiting,
            // we'll simply remove them from the worker map and continue waiting.
            loop {
                select! {
                    worker_result = self.worker_tasks.join_next_with_id() => {
                        match worker_result {
                            Some(Ok((worker_task_id, _))) => {
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited successfully.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            },
                            Some(Err(e)) => {
                                let worker_task_id = e.id();
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited with error.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            }
                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
                        }
                    },
                    // We've exceeded the shutdown timeout, so we need to abort the process.
                    //
                    // NOTE(review): once the sleep has elapsed, this branch stays ready and may be
                    // selected repeatedly until the aborted task actually exits. `abort()` is
                    // idempotent so this looks benign, but confirm it cannot starve the join arm.
                    _ = &mut shutdown_deadline => {
                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
                        abort_handle.abort();
                    }
                }
            }
        }

        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
        debug_assert!(
            self.worker_tasks.is_empty(),
            "worker tasks should be empty after shutdown"
        );
    }
}
695
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;
    use tokio::{
        sync::oneshot,
        task::JoinHandle,
        time::{sleep, timeout},
    };

    use super::*;

    /// How a mock worker behaves while initializing.
    #[derive(Clone)]
    enum InitBehavior {
        /// Completes successfully right away.
        Instant,

        /// Succeeds, but only after waiting for the given duration.
        Slow(Duration),

        /// Fails outright with the given message.
        Fail(&'static str),
    }

    /// How a mock worker behaves once it is running (after initialization).
    #[derive(Clone)]
    enum RunBehavior {
        /// Keeps running until the shutdown signal arrives.
        UntilShutdown,

        /// Errors out with the given message once the delay has elapsed.
        FailAfter(Duration, &'static str),
    }

    /// A mock worker whose init/run behavior is configured per test, and which counts how many
    /// times its runtime phase has been entered (i.e. how many times it was (re)started).
    struct MockWorker {
        name: &'static str,
        init_behavior: InitBehavior,
        run_behavior: RunBehavior,
        start_count: Arc<AtomicUsize>,
    }

    impl MockWorker {
        /// Builds a worker from the given behaviors, with a zeroed start counter.
        fn with_behaviors(name: &'static str, init_behavior: InitBehavior, run_behavior: RunBehavior) -> Self {
            Self {
                name,
                init_behavior,
                run_behavior,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker that runs until shutdown.
        fn long_running(name: &'static str) -> Self {
            Self::with_behaviors(name, InitBehavior::Instant, RunBehavior::UntilShutdown)
        }

        /// Creates a worker that fails after the given delay.
        fn failing(name: &'static str, delay: Duration) -> Self {
            Self::with_behaviors(name, InitBehavior::Instant, RunBehavior::FailAfter(delay, "worker failed"))
        }

        /// Creates a worker that fails during initialization.
        fn init_failure(name: &'static str) -> Self {
            Self::with_behaviors(name, InitBehavior::Fail("init failed"), RunBehavior::UntilShutdown)
        }

        /// Creates a worker with slow initialization.
        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
            Self::with_behaviors(name, InitBehavior::Slow(init_delay), RunBehavior::UntilShutdown)
        }

        /// Returns a shared handle to the start count for this worker.
        fn start_count(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.start_count)
        }
    }

    #[async_trait]
    impl Supervisable for MockWorker {
        fn name(&self) -> &str {
            self.name
        }

        fn shutdown_strategy(&self) -> ShutdownStrategy {
            ShutdownStrategy::Graceful(Duration::from_millis(500))
        }

        async fn initialize(
            &self, mut process_shutdown: ProcessShutdown,
        ) -> Result<SupervisorFuture, InitializationError> {
            // Simulate the configured initialization behavior before handing back the run future.
            if let InitBehavior::Fail(msg) = &self.init_behavior {
                return Err(InitializationError::Failed {
                    source: GenericError::msg(*msg),
                });
            }

            if let InitBehavior::Slow(delay) = &self.init_behavior {
                sleep(*delay).await;
            }

            let counter = Arc::clone(&self.start_count);
            let behavior = self.run_behavior.clone();

            Ok(Box::pin(async move {
                // Each entry into the runtime phase counts as one (re)start.
                counter.fetch_add(1, Ordering::SeqCst);

                match behavior {
                    RunBehavior::UntilShutdown => {
                        process_shutdown.wait_for_shutdown().await;
                        Ok(())
                    }
                    RunBehavior::FailAfter(delay, msg) => {
                        select! {
                            _ = sleep(delay) => Err(GenericError::msg(msg)),
                            _ = process_shutdown.wait_for_shutdown() => Ok(()),
                        }
                    }
                }
            }))
        }
    }

    /// Helper: spawns the supervisor on a background task wired to a oneshot-based shutdown
    /// trigger, and waits briefly so the supervisor has a chance to spawn its children.
    /// Returns the shutdown sender and the join handle for the supervisor task.
    async fn run_supervisor_with_trigger(
        mut supervisor: Supervisor,
    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
        let (trigger_tx, trigger_rx) = oneshot::channel();
        let task = tokio::spawn(async move { supervisor.run_with_shutdown(trigger_rx).await });

        // Give the supervisor a moment to start and spawn children.
        sleep(Duration::from_millis(50)).await;

        (trigger_tx, task)
    }

    // -- Supervisor run mode tests ---------------------------------------------------------

    #[tokio::test]
    async fn standalone_supervisor_shuts_down_cleanly() {
        let mut supervisor = Supervisor::new("test-sup").unwrap();
        supervisor.add_worker(MockWorker::long_running("worker1"));
        supervisor.add_worker(MockWorker::long_running("worker2"));

        let (shutdown_tx, task) = run_supervisor_with_trigger(supervisor).await;
        shutdown_tx.send(()).unwrap();

        let outcome = timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn nested_supervisor_shuts_down_cleanly() {
        // A supervisor can itself be supervised: nest one inside another.
        let mut inner_supervisor = Supervisor::new("child-sup").unwrap();
        inner_supervisor.add_worker(MockWorker::long_running("inner-worker"));

        let mut outer_supervisor = Supervisor::new("parent-sup").unwrap();
        outer_supervisor.add_worker(MockWorker::long_running("outer-worker"));
        outer_supervisor.add_worker(inner_supervisor);

        let (shutdown_tx, task) = run_supervisor_with_trigger(outer_supervisor).await;
        shutdown_tx.send(()).unwrap();

        let outcome = timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn supervisor_with_no_children_returns_error() {
        let mut supervisor = Supervisor::new("empty-sup").unwrap();

        // No children were added, so running the supervisor should fail immediately.
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let outcome = supervisor.run_with_shutdown(shutdown_rx).await;
        drop(shutdown_tx);

        assert!(matches!(outcome, Err(SupervisorError::NoChildren)));
    }

    // -- Child restart behavior tests ------------------------------------------------------

    #[tokio::test]
    async fn one_for_one_restarts_only_failed_child() {
        let flaky = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let flaky_starts = flaky.start_count();

        let steady = MockWorker::long_running("stable-worker");
        let steady_starts = steady.start_count();

        // Generous restart budget so the failing child can restart several times.
        let strategy = RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10));
        let mut supervisor = Supervisor::new("test-sup").unwrap().with_restart_strategy(strategy);
        supervisor.add_worker(steady);
        supervisor.add_worker(flaky);

        let (shutdown_tx, task) = run_supervisor_with_trigger(supervisor).await;

        // Wait for a few restarts to happen.
        sleep(Duration::from_millis(300)).await;
        let _ = shutdown_tx.send(());

        let outcome = timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
        assert!(outcome.is_ok());

        // The failing worker should have been started multiple times.
        assert!(
            flaky_starts.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        // The stable worker should only have been started once (never restarted).
        assert_eq!(
            steady_starts.load(Ordering::SeqCst),
            1,
            "stable worker should not have been restarted"
        );
    }

    #[tokio::test]
    async fn one_for_all_restarts_all_children() {
        let flaky = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let flaky_starts = flaky.start_count();

        let steady = MockWorker::long_running("stable-worker");
        let steady_starts = steady.start_count();

        // Under one-for-all, a single failure restarts every sibling as well.
        let strategy = RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10));
        let mut supervisor = Supervisor::new("test-sup").unwrap().with_restart_strategy(strategy);
        supervisor.add_worker(steady);
        supervisor.add_worker(flaky);

        let (shutdown_tx, task) = run_supervisor_with_trigger(supervisor).await;

        // Wait for at least one restart cycle.
        sleep(Duration::from_millis(300)).await;
        let _ = shutdown_tx.send(());

        let outcome = timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
        assert!(outcome.is_ok());

        // Both workers should have been started multiple times.
        assert!(
            flaky_starts.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert!(
            steady_starts.load(Ordering::SeqCst) >= 2,
            "stable worker should also have been restarted"
        );
    }

    #[tokio::test]
    async fn restart_limit_exceeded_shuts_down_supervisor() {
        // A tiny restart budget: one restart allowed per ten-second window.
        let strategy = RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10));
        let mut supervisor = Supervisor::new("test-sup").unwrap().with_restart_strategy(strategy);
        // This worker fails immediately, which will exhaust the restart budget quickly.
        supervisor.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));

        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let task = tokio::spawn(async move { supervisor.run_with_shutdown(shutdown_rx).await });

        let outcome = timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
        drop(shutdown_tx);

        assert!(matches!(outcome, Err(SupervisorError::Shutdown)));
    }

    // -- Initialization failure tests ------------------------------------------------------

    #[tokio::test]
    async fn init_failure_propagates_with_child_name() {
        let mut supervisor = Supervisor::new("test-sup").unwrap();
        supervisor.add_worker(MockWorker::long_running("good-worker"));
        supervisor.add_worker(MockWorker::init_failure("bad-worker"));

        let (_shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let outcome = timeout(Duration::from_secs(2), supervisor.run_with_shutdown(shutdown_rx))
            .await
            .unwrap();

        // The error should identify exactly which child failed to initialize.
        match outcome {
            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
                assert_eq!(child_name, "bad-worker");
            }
            other => panic!("expected FailedToInitialize, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn init_failure_does_not_trigger_restart() {
        let broken = MockWorker::init_failure("bad-worker");
        let broken_starts = broken.start_count();

        let strategy = RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10));
        let mut supervisor = Supervisor::new("test-sup").unwrap().with_restart_strategy(strategy);
        supervisor.add_worker(broken);

        let (_shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let outcome = timeout(Duration::from_secs(2), supervisor.run_with_shutdown(shutdown_rx))
            .await
            .unwrap();

        assert!(matches!(outcome, Err(SupervisorError::FailedToInitialize { .. })));
        // The worker never got past init, so start_count should be 0.
        assert_eq!(broken_starts.load(Ordering::SeqCst), 0);
    }

    // -- Shutdown responsiveness tests -----------------------------------------------------

    #[tokio::test]
    async fn shutdown_completes_promptly_in_steady_state() {
        let mut supervisor = Supervisor::new("test-sup").unwrap();
        supervisor.add_worker(MockWorker::long_running("worker1"));
        supervisor.add_worker(MockWorker::long_running("worker2"));

        let (shutdown_tx, task) = run_supervisor_with_trigger(supervisor).await;
        shutdown_tx.send(()).unwrap();

        // Shutdown should complete well within 1 second (workers respond to shutdown signal immediately).
        let outcome = timeout(Duration::from_secs(1), task).await;
        assert!(outcome.is_ok(), "shutdown should complete promptly");
    }

    #[tokio::test]
    async fn shutdown_during_slow_init_completes_promptly() {
        let mut supervisor = Supervisor::new("test-sup").unwrap();
        // This worker takes 30 seconds to initialize — but we'll trigger shutdown immediately.
        supervisor.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));

        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let task = tokio::spawn(async move { supervisor.run_with_shutdown(shutdown_rx).await });

        // Give the supervisor just enough time to spawn the task, then trigger shutdown.
        sleep(Duration::from_millis(20)).await;
        shutdown_tx.send(()).unwrap();

        // Shutdown should complete quickly even though the worker hasn't finished initializing.
        // The supervisor loop sees the shutdown signal and aborts the still-initializing task.
        let outcome = timeout(Duration::from_secs(2), task).await;
        assert!(outcome.is_ok(), "shutdown during slow init should complete promptly");
    }
}