Skip to main content

saluki_core/runtime/
supervisor.rs

1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use saluki_common::collections::FastIndexMap;
5use saluki_error::GenericError;
6use snafu::{OptionExt as _, Snafu};
7use tokio::{
8    pin, select,
9    task::{AbortHandle, Id, JoinSet},
10};
11use tracing::{debug, error, warn};
12
13use super::{
14    dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
15    restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
16    shutdown::{ProcessShutdown, ShutdownHandle},
17};
18use crate::runtime::{
19    process::{Process, ProcessExt as _},
20    state::DataspaceRegistry,
21};
22
/// A `Future` that represents the execution of a supervised process.
///
/// This is the boxed, pinned, `Send` future handed back by a successful [`Supervisable::initialize`] call. It resolves
/// to `Ok(())` when the process finishes cleanly, or to a [`GenericError`] when it fails at runtime.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;
25
/// A `Future` that represents the full lifecycle of a worker, including initialization.
///
/// Unlike [`SupervisorFuture`], which only represents the runtime phase, this future first performs async
/// initialization and then runs the worker. This allows initialization to happen concurrently when multiple workers are
/// spawned, and keeps the supervisor loop responsive to shutdown signals during initialization.
///
/// The error type is [`WorkerError`], so the supervisor can tell an initialization failure apart from a runtime
/// failure when the future completes.
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;
32
/// Worker lifecycle errors.
///
/// Distinguishes between initialization failures (which shouldn't trigger restart logic) and runtime failures (which
/// are eligible for restart).
///
/// This type is private to the supervisor: it is the error half of [`WorkerFuture`], and the supervision loop
/// translates it into either a [`SupervisorError::FailedToInitialize`] or a [`ProcessError`].
#[derive(Debug)]
enum WorkerError {
    /// The worker failed during async initialization.
    ///
    /// The optional `child_name` carries the name of the original failing child when the error originates from a
    /// nested supervisor. This allows the parent to include it in its own `FailedToInitialize` error for better
    /// diagnostics across supervision tree levels.
    ///
    /// For a plain worker, `child_name` is `None` and the parent fills in the name from the child specification.
    Initialization {
        child_name: Option<String>,
        source: InitializationError,
    },

    /// The worker failed during runtime execution.
    ///
    /// Task-level failures (abort, panic) are also reported through this variant, wrapping a [`ProcessError`].
    Runtime(GenericError),
}
52
53impl From<SupervisorError> for WorkerError {
54    fn from(err: SupervisorError) -> Self {
55        match err {
56            // Propagate initialization failures so the parent supervisor does NOT attempt to restart.
57            // Preserve the original child name so the parent can include it in diagnostics.
58            SupervisorError::FailedToInitialize { child_name, source } => WorkerError::Initialization {
59                child_name: Some(child_name),
60                source,
61            },
62            // All other supervisor errors (shutdown, no children, invalid name) are runtime-level.
63            other => WorkerError::Runtime(other.into()),
64        }
65    }
66}
67
/// Process errors.
///
/// Describes why a child process stopped running, as observed by its supervisor.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child process was aborted by the supervisor.
    ///
    /// Reported when the child's task completes with a cancellation error.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child process panicked.
    ///
    /// Reported when the child's task fails to join for any reason other than cancellation.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child process terminated with an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error that caused the termination.
        source: GenericError,
    },
}
86
/// Initialization errors.
///
/// Initialization errors are distinct from runtime errors: they indicate that a process couldn't be started at all
/// (for example, failed to bind a port, missing configuration). These errors don't trigger restart logic; instead, they
/// immediately propagate up and fail the supervisor.
///
/// Any [`GenericError`] converts into this type via `From`, so `?` can be used directly inside
/// [`Supervisable::initialize`] implementations.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    /// The process couldn't be initialized due to an error.
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        /// The underlying error that caused initialization to fail.
        source: GenericError,
    },
}
102
103impl From<GenericError> for InitializationError {
104    fn from(source: GenericError) -> Self {
105        Self::Failed { source }
106    }
107}
108
/// Strategy for shutting down a process.
///
/// The supervisor consults this when it needs to stop a child, either because the supervisor itself is shutting down
/// or because a restart requires the child to be stopped first.
///
/// The type is a small value (`Copy`) that can be logged and compared, which keeps it convenient to use in
/// diagnostics and tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownStrategy {
    /// Waits for the configured duration for the process to exit, and then forcefully aborts it otherwise.
    Graceful(Duration),

    /// Forcefully aborts the process without waiting.
    Brutal,
}
117
/// A supervisable process.
///
/// Implementors describe a worker that a [`Supervisor`] can start, stop, and restart: its name, how it should be shut
/// down, and how to asynchronously build the future that runs it.
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of the process.
    fn name(&self) -> &str;

    /// Returns the shutdown strategy for the process.
    ///
    /// The default is a graceful shutdown with a five second timeout.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process asynchronously.
    ///
    /// During initialization, any resources or configuration for the process can be created asynchronously, and the
    /// same runtime that's used for running the process is used for initialization. The resulting future is expected
    /// to complete as soon as reasonably possible after `process_shutdown` resolves.
    ///
    /// **Important:** The `process_shutdown` signal must be moved into the returned [`SupervisorFuture`] so the
    /// worker can respond to supervisor-initiated shutdown. If `process_shutdown` is dropped during initialization,
    /// the worker will be unable to shut down gracefully and will be forcefully aborted after the shutdown timeout.
    ///
    /// This method is called again every time the supervisor (re)starts the process, so it must be able to produce a
    /// fresh future on each call.
    ///
    /// # Errors
    ///
    /// If the process can't be initialized, an error is returned.
    async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError>;
}
144
/// Supervisor errors.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// Supervisor or worker name is invalid.
    ///
    /// Returned when a supervisor is created with an empty name, or when a process can't be created for a supervisor
    /// or worker under the given name.
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The supervisor or worker name that was rejected.
        name: String,
    },

    /// The supervisor has no child processes.
    ///
    /// A supervisor must have at least one child before it can be run.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed to initialize.
    ///
    /// This error indicates that a child couldn't complete its async initialization. This is distinct from runtime
    /// failures and doesn't trigger restart logic.
    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        /// The name of the child that failed to initialize.
        ///
        /// For failures inside nested supervisors, this is a `/`-separated path from the reporting supervisor's
        /// direct child down to the child that actually failed.
        child_name: String,

        /// The underlying initialization error.
        source: InitializationError,
    },

    /// The supervisor exceeded its restart limits and was forced to shutdown.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))]
    Shutdown,
}
177
/// A child process specification.
///
/// All workers added to a [`Supervisor`] must be specified as a `ChildSpecification`. This acts as a template for how
/// the supervisor should create the underlying future that represents the process, as well as information about the
/// process, such as its name, shutdown strategy, and more.
///
/// A child process specification can be created implicitly from an existing [`Supervisor`], or any type that implements
/// [`Supervisable`].
pub enum ChildSpecification {
    /// A leaf worker, shared behind an `Arc` so that it can be re-initialized on every (re)start.
    Worker(Arc<dyn Supervisable>),

    /// A nested supervisor, forming a subtree of the supervision tree.
    Supervisor(Supervisor),
}
190
191impl ChildSpecification {
192    fn process_type(&self) -> &'static str {
193        match self {
194            Self::Worker(_) => "worker",
195            Self::Supervisor(_) => "supervisor",
196        }
197    }
198
199    fn name(&self) -> &str {
200        match self {
201            Self::Worker(worker) => worker.name(),
202            Self::Supervisor(supervisor) => &supervisor.supervisor_id,
203        }
204    }
205
206    fn shutdown_strategy(&self) -> ShutdownStrategy {
207        match self {
208            Self::Worker(worker) => worker.shutdown_strategy(),
209
210            // Supervisors should always be given as much time as necessary shutdown down gracefully to ensure that the
211            // entire supervision subtree can be shutdown cleanly.
212            Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
213        }
214    }
215
216    fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
217        match self {
218            Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
219                name: worker.name().to_string(),
220            }),
221            Self::Supervisor(sup) => {
222                Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
223                    name: sup.supervisor_id.to_string(),
224                })
225            }
226        }
227    }
228
229    fn create_worker_future(
230        &self, process: Process, process_shutdown: ProcessShutdown,
231    ) -> Result<WorkerFuture, SupervisorError> {
232        match self {
233            Self::Worker(worker) => {
234                let worker = Arc::clone(worker);
235                Ok(Box::pin(async move {
236                    let run_future =
237                        worker
238                            .initialize(process_shutdown)
239                            .await
240                            .map_err(|source| WorkerError::Initialization {
241                                child_name: None,
242                                source,
243                            })?;
244                    run_future.await.map_err(WorkerError::Runtime)
245                }))
246            }
247            Self::Supervisor(sup) => {
248                match sup.runtime_mode() {
249                    RuntimeMode::Ambient => {
250                        // Run on the parent's ambient runtime.
251                        Ok(sup.as_nested_process(process, process_shutdown))
252                    }
253                    RuntimeMode::Dedicated(config) => {
254                        // Spawn in a dedicated runtime on a new OS thread, passing the parent's
255                        // dataspace so the nested supervisor inherits it across the thread boundary.
256                        let child_name = sup.supervisor_id.to_string();
257                        let dataspace = process.dataspace().clone();
258                        let handle =
259                            spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown, dataspace)
260                                .map_err(|e| SupervisorError::FailedToInitialize {
261                                    child_name,
262                                    source: e.into(),
263                                })?;
264
265                        Ok(Box::pin(async move { handle.await.map_err(WorkerError::from) }))
266                    }
267                }
268            }
269        }
270    }
271}
272
273impl Clone for ChildSpecification {
274    fn clone(&self) -> Self {
275        match self {
276            Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
277            Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
278        }
279    }
280}
281
282impl From<Supervisor> for ChildSpecification {
283    fn from(supervisor: Supervisor) -> Self {
284        Self::Supervisor(supervisor)
285    }
286}
287
288impl<T> From<T> for ChildSpecification
289where
290    T: Supervisable + 'static,
291{
292    fn from(worker: T) -> Self {
293        Self::Worker(Arc::new(worker))
294    }
295}
296
/// Supervises a set of workers.
///
/// # Workers
///
/// All workers are defined through implementation of the [`Supervisable`] trait, which provides the logic for both
/// creating the underlying worker future that's spawned, as well as other metadata, such as the worker's name, how the
/// worker should be shutdown, and so on.
///
/// Supervisors can themselves be added as children of other supervisors -- they convert into a
/// [`ChildSpecification`] just like any [`Supervisable`] does -- allowing _supervision trees_ to be constructed.
///
/// # Instrumentation
///
/// Supervisors automatically create their own allocation group
/// ([`TrackingAllocator`][resource_accounting::TrackingAllocator]), which is used to track both the memory usage of the
/// supervisor itself and its children. Additionally, individual worker processes are wrapped in a dedicated
/// [`tracing::Span`] to allow tracing the causal relationship between arbitrary code and the worker executing it.
///
/// # Restart Strategies
///
/// As the main purpose of a supervisor, restart behavior is fully configurable. A number of restart strategies are
/// available, which generally relate to the purpose of the supervisor: whether the workers being managed are
/// independent or interdependent.
///
/// All restart strategies are configured through [`RestartStrategy`], which has more information on the available
/// strategies and configuration settings.
pub struct Supervisor {
    // Non-empty identifier; shared cheaply between internal clones of the supervisor.
    supervisor_id: Arc<str>,
    // Static children, in the order they were added. A child's index in this list is how it's referred to internally.
    child_specs: Vec<ChildSpecification>,
    // How child terminations are handled.
    restart_strategy: RestartStrategy,
    // Whether this supervisor runs on its parent's runtime or on a dedicated one when nested.
    runtime_mode: RuntimeMode,
}
329
330impl Supervisor {
331    /// Creates an empty `Supervisor` with the default restart strategy.
332    pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
333        // We try to throw an error about invalid names as early as possible. This is a manual check, so we might still
334        // encounter an error later when actually running the supervisor, but this is a good first step to catch the
335        // bulk of invalid names.
336        if supervisor_id.as_ref().is_empty() {
337            return Err(SupervisorError::InvalidName {
338                name: supervisor_id.as_ref().to_string(),
339            });
340        }
341
342        Ok(Self {
343            supervisor_id: supervisor_id.as_ref().into(),
344            child_specs: Vec::new(),
345            restart_strategy: RestartStrategy::default(),
346            runtime_mode: RuntimeMode::default(),
347        })
348    }
349
350    /// Returns the supervisor's ID.
351    pub fn id(&self) -> &str {
352        &self.supervisor_id
353    }
354
355    /// Sets the restart strategy for the supervisor.
356    pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
357        self.restart_strategy = strategy;
358        self
359    }
360
361    /// Configures this supervisor to run in a dedicated runtime.
362    ///
363    /// When this supervisor is added as a child to another supervisor, it will spawn its own OS threads and Tokio
364    /// runtime instead of running on the parent's ambient runtime.
365    ///
366    /// This provides runtime isolation, which can be useful for:
367    /// - CPU-bound work that shouldn't block the parent's runtime
368    /// - Isolating failures in one part of the system
369    /// - Using different runtime configurations (for example, single-threaded vs multi-threaded)
370    pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
371        self.runtime_mode = RuntimeMode::Dedicated(config);
372        self
373    }
374
375    /// Returns the runtime mode for this supervisor.
376    pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
377        &self.runtime_mode
378    }
379
380    /// Adds a worker to the supervisor.
381    ///
382    /// A worker can be anything that implements the [`Supervisable`] trait. A [`Supervisor`] can also be added as a
383    /// worker and managed in a nested fashion, known as a supervision tree.
384    pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
385        let child_spec = process.into();
386        debug!(
387            supervisor_id = %self.supervisor_id,
388            "Adding new static child process #{}. ({}, {})",
389            self.child_specs.len(),
390            child_spec.process_type(),
391            child_spec.name(),
392        );
393        self.child_specs.push(child_spec);
394    }
395
396    fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
397        match self.child_specs.get(child_spec_idx) {
398            Some(child_spec) => child_spec,
399            None => unreachable!("child spec index should never be out of bounds"),
400        }
401    }
402
403    fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
404        let child_spec = self.get_child_spec(child_spec_idx);
405        debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
406        worker_state.add_worker(child_spec_idx, child_spec)
407    }
408
409    fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
410        debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
411        for child_spec_idx in 0..self.child_specs.len() {
412            self.spawn_child(child_spec_idx, worker_state)?;
413        }
414
415        Ok(())
416    }
417
    /// Core supervision loop shared by all of the public `run*` entry points and by nested supervisors.
    ///
    /// Spawns every child, then repeatedly waits for either the shutdown signal or the next child to complete:
    ///
    /// - On shutdown, all children are shut down and `Ok(())` is returned.
    /// - When a child fails to initialize, all children are shut down and the failure is returned without consulting
    ///   the restart strategy.
    /// - When a child completes for any other reason (including a clean exit), the restart strategy decides whether
    ///   to restart just that child, restart all children, or give up.
    ///
    /// # Errors
    ///
    /// - `SupervisorError::NoChildren` if no children were added.
    /// - `SupervisorError::FailedToInitialize` if a child fails to initialize.
    /// - `SupervisorError::Shutdown` if the restart strategy decides the supervisor must shut down.
    /// - Any error from spawning (or respawning) a child.
    async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        // Spawn all child processes. Since initialization is folded into each worker's task, this returns immediately
        // after spawning -- children initialize concurrently in the background.
        self.spawn_all_children(&mut worker_state)?;

        // Now we supervise.
        //
        // The shutdown future is created once and pinned so that it can be polled by reference on every loop
        // iteration.
        let shutdown = process_shutdown.wait_for_shutdown();
        pin!(shutdown);

        loop {
            select! {
                // Shutdown has been triggered.
                //
                // Propagate shutdown to all child processes and wait for them to exit.
                _ = &mut shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
                    // TODO: Erlang/OTP defaults to always trying to restart a process, even if it doesn't terminate due
                    // to a legitimate failure. It does allow configuring this behavior on a per-process basis, however.
                    // We don't support dynamically adding child processes, which is the only real use case I can think
                    // of for having non-long-lived child processes... so I think for now, we're OK just always try to
                    // restart.
                    Some((child_spec_idx, worker_result)) =>  {
                        let child_spec = self.get_child_spec(child_spec_idx);

                        // Initialization failures are not eligible for restart -- they propagate immediately.
                        if let Err(WorkerError::Initialization { child_name, source }) = worker_result {
                            // If the error came from a nested supervisor, include the original child name
                            // to make the error chain more informative (e.g., "ctrl-pln/privileged-api").
                            let full_name = match child_name {
                                Some(inner) => format!("{}/{}", child_spec.name(), inner),
                                None => child_spec.name().to_string(),
                            };

                            error!(supervisor_id = %self.supervisor_id, worker_name = full_name, "Child process failed to initialize: {}", source);
                            worker_state.shutdown_workers().await;
                            return Err(SupervisorError::FailedToInitialize {
                                child_name: full_name,
                                source,
                            });
                        }

                        // Convert the worker result to a process error for restart evaluation.
                        //
                        // Only `Ok` and `Runtime` results can reach this point, since the initialization case
                        // returned above.
                        let worker_result = worker_result
                            .map_err(|e| match e {
                                WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
                                WorkerError::Initialization { .. } => unreachable!("handled above"),
                            });

                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                // Only the child that terminated is respawned.
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                // Every remaining child is stopped first, and then all children are respawned.
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    worker_state.shutdown_workers().await;
                                    self.spawn_all_children(&mut worker_state)?;
                                }
                            },
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    },
                    None => unreachable!("should not have empty worker joinset prior to shutdown"),
                }
            }
        }

        Ok(())
    }
503
504    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
505        // Simple wrapper around `run_inner` to satisfy the return type signature needed when running the supervisor as
506        // a nested child process in another supervisor.
507        debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");
508
509        // Create a standalone clone of ourselves so we can fulfill the future signature.
510        let sup = self.inner_clone();
511
512        Box::pin(async move {
513            sup.run_inner(process, process_shutdown)
514                .await
515                .map_err(WorkerError::from)
516        })
517    }
518
519    /// Runs the supervisor forever.
520    ///
521    /// # Errors
522    ///
523    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
524    pub async fn run(&mut self) -> Result<(), SupervisorError> {
525        // Create a no-op `ProcessShutdown` to satisfy the `run_inner` function. This is never used since we want to run
526        // forever, but we need to satisfy the signature.
527        let process_shutdown = ProcessShutdown::noop();
528        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
529            name: self.supervisor_id.to_string(),
530        })?;
531
532        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
533        self.run_inner(process.clone(), process_shutdown)
534            .into_process_future(process)
535            .await
536    }
537
538    /// Runs the supervisor until shutdown is triggered.
539    ///
540    /// When `shutdown` resolves, the supervisor will shutdown all child processes according to their shutdown strategy,
541    /// and then return.
542    ///
543    /// # Errors
544    ///
545    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
546    pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
547        let process_shutdown = ProcessShutdown::wrapped(shutdown);
548        self.run_with_process_shutdown(process_shutdown, None).await
549    }
550
551    /// Runs the supervisor until the given `ProcessShutdown` signal is received.
552    ///
553    /// This is an internal variant of `run_with_shutdown` that takes a `ProcessShutdown` directly, used when spawning
554    /// supervisors in dedicated runtimes where the shutdown signal is already wrapped in a `ProcessShutdown`.
555    ///
556    /// If `dataspace` is provided, the supervisor will use it instead of creating a new one. This is used to propagate
557    /// the parent's dataspace across OS thread boundaries for dedicated runtimes.
558    ///
559    /// # Errors
560    ///
561    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
562    pub(crate) async fn run_with_process_shutdown(
563        &mut self, process_shutdown: ProcessShutdown, dataspace: Option<DataspaceRegistry>,
564    ) -> Result<(), SupervisorError> {
565        let process =
566            Process::supervisor_with_dataspace(&self.supervisor_id, None, dataspace).context(InvalidName {
567                name: self.supervisor_id.to_string(),
568            })?;
569
570        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
571        self.run_inner(process.clone(), process_shutdown)
572            .into_process_future(process)
573            .await
574    }
575
576    fn inner_clone(&self) -> Self {
577        // This is no different than if we just implemented `Clone` directly, but it allows us to avoid exposing a
578        // _public_ implementation of `Clone`, which we don't want normal users to be able to do. We only need this
579        // internally to support nested supervisors.
580        Self {
581            supervisor_id: Arc::clone(&self.supervisor_id),
582            child_specs: self.child_specs.clone(),
583            restart_strategy: self.restart_strategy,
584            runtime_mode: self.runtime_mode.clone(),
585        }
586    }
587}
588
/// Bookkeeping the supervisor holds for a single running child task.
struct ProcessState {
    // Index of the child's specification within the supervisor's list of children.
    worker_id: usize,
    // How the child should be stopped.
    shutdown_strategy: ShutdownStrategy,
    // Triggers the child's paired `ProcessShutdown` signal for a graceful stop.
    shutdown_handle: ShutdownHandle,
    // Forcefully aborts the child's task.
    abort_handle: AbortHandle,
}
595
/// The set of child tasks currently running under a supervisor.
struct WorkerState {
    // The supervisor's own process, used as the parent when creating child processes.
    process: Process,
    // All spawned child tasks.
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    // Per-task bookkeeping, keyed by task ID. Insertion order is relied upon during shutdown.
    worker_map: FastIndexMap<Id, ProcessState>,
}
601
602impl WorkerState {
603    fn new(process: Process) -> Self {
604        Self {
605            process,
606            worker_tasks: JoinSet::new(),
607            worker_map: FastIndexMap::default(),
608        }
609    }
610
611    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
612        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
613        let process = child_spec.create_process(&self.process)?;
614        let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
615        let shutdown_strategy = child_spec.shutdown_strategy();
616        let abort_handle = self.worker_tasks.spawn(worker_future.into_process_future(process));
617        self.worker_map.insert(
618            abort_handle.id(),
619            ProcessState {
620                worker_id,
621                shutdown_strategy,
622                shutdown_handle,
623                abort_handle,
624            },
625        );
626        Ok(())
627    }
628
629    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
630        debug!("Waiting for next process to complete.");
631
632        match self.worker_tasks.join_next_with_id().await {
633            Some(Ok((worker_task_id, worker_result))) => {
634                let process_state = self
635                    .worker_map
636                    .swap_remove(&worker_task_id)
637                    .expect("worker task ID not found");
638                Some((process_state.worker_id, worker_result))
639            }
640            Some(Err(e)) => {
641                let worker_task_id = e.id();
642                let process_state = self
643                    .worker_map
644                    .swap_remove(&worker_task_id)
645                    .expect("worker task ID not found");
646                let e = if e.is_cancelled() {
647                    ProcessError::Aborted
648                } else {
649                    ProcessError::Panicked
650                };
651                Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
652            }
653            None => None,
654        }
655    }
656
657    async fn shutdown_workers(&mut self) {
658        debug!("Shutting down all processes.");
659
660        // Pop entries from the worker map, which grabs us workers in the reverse order they were added. This lets us
661        // ensure we're shutting down any _dependent_ processes (processes which depend on previously-started processes)
662        // first.
663        //
664        // For each entry, we trigger shutdown in whatever way necessary, and then wait for the process to exit by
665        // driving the `JoinSet`. If other workers complete while we're waiting, we'll simply remove them from the
666        // worker map and continue waiting for the current worker we're shutting down.
667        //
668        // We do this until the worker map is empty, at which point we can be sure that all processes have exited.
669        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
670            let ProcessState {
671                worker_id,
672                shutdown_strategy,
673                shutdown_handle,
674                abort_handle,
675            } = process_state;
676
677            // Trigger the process to shutdown based on the configured shutdown strategy.
678            let shutdown_deadline = match shutdown_strategy {
679                ShutdownStrategy::Graceful(timeout) => {
680                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
681                    shutdown_handle.trigger();
682
683                    tokio::time::sleep(timeout)
684                }
685                ShutdownStrategy::Brutal => {
686                    debug!(worker_id, "Forcefully aborting process.");
687                    abort_handle.abort();
688
689                    // We have to return a future that never resolves, since we're already aborting it. This is a little
690                    // hacky but it's also difficult to do an optional future, so this is what we're going with for now.
691                    tokio::time::sleep(Duration::MAX)
692                }
693            };
694            pin!(shutdown_deadline);
695
696            // Wait for the process to exit by driving the `JoinSet`. If other workers complete while we're waiting,
697            // we'll simply remove them from the worker map and continue waiting.
698            loop {
699                select! {
700                    worker_result = self.worker_tasks.join_next_with_id() => {
701                        match worker_result {
702                            Some(Ok((worker_task_id, _))) => {
703                                if worker_task_id == current_worker_task_id {
704                                    debug!(?worker_task_id, "Target process exited successfully.");
705                                    break;
706                                } else {
707                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
708                                    self.worker_map.swap_remove(&worker_task_id);
709                                }
710                            },
711                            Some(Err(e)) => {
712                                let worker_task_id = e.id();
713                                if worker_task_id == current_worker_task_id {
714                                    debug!(?worker_task_id, "Target process exited with error.");
715                                    break;
716                                } else {
717                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
718                                    self.worker_map.swap_remove(&worker_task_id);
719                                }
720                            }
721                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
722                        }
723                    },
724                    // We've exceeded the shutdown timeout, so we need to abort the process.
725                    _ = &mut shutdown_deadline => {
726                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
727                        abort_handle.abort();
728                    }
729                }
730            }
731        }
732
733        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
734        debug_assert!(
735            self.worker_tasks.is_empty(),
736            "worker tasks should be empty after shutdown"
737        );
738    }
739}
740
741#[cfg(test)]
742mod tests {
743    use std::sync::atomic::{AtomicUsize, Ordering};
744
745    use async_trait::async_trait;
746    use tokio::{
747        sync::oneshot,
748        task::JoinHandle,
749        time::{sleep, timeout},
750    };
751
752    use super::*;
753
/// How a [`MockWorker`] behaves while it is being initialized.
#[derive(Clone)]
enum InitBehavior {
    /// Initialization completes right away, without waiting.
    Instant,

    /// Initialization sleeps for the given duration and then succeeds.
    Slow(Duration),

    /// Initialization returns an error carrying the given message.
    Fail(&'static str),
}
766
/// How a [`MockWorker`] behaves once initialization has succeeded.
#[derive(Clone)]
enum RunBehavior {
    /// Waits for the shutdown signal and then exits successfully.
    UntilShutdown,

    /// Exits with an error carrying the given message once the given delay has elapsed, unless the shutdown signal
    /// arrives first.
    FailAfter(Duration, &'static str),
}
776
777    /// A configurable mock worker for testing supervisor behavior.
778    struct MockWorker {
779        name: &'static str,
780        init_behavior: InitBehavior,
781        run_behavior: RunBehavior,
782        start_count: Arc<AtomicUsize>,
783    }
784
785    impl MockWorker {
786        /// Creates a worker that runs until shutdown.
787        fn long_running(name: &'static str) -> Self {
788            Self {
789                name,
790                init_behavior: InitBehavior::Instant,
791                run_behavior: RunBehavior::UntilShutdown,
792                start_count: Arc::new(AtomicUsize::new(0)),
793            }
794        }
795
796        /// Creates a worker that fails after the given delay.
797        fn failing(name: &'static str, delay: Duration) -> Self {
798            Self {
799                name,
800                init_behavior: InitBehavior::Instant,
801                run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
802                start_count: Arc::new(AtomicUsize::new(0)),
803            }
804        }
805
806        /// Creates a worker that fails during initialization.
807        fn init_failure(name: &'static str) -> Self {
808            Self {
809                name,
810                init_behavior: InitBehavior::Fail("init failed"),
811                run_behavior: RunBehavior::UntilShutdown,
812                start_count: Arc::new(AtomicUsize::new(0)),
813            }
814        }
815
816        /// Creates a worker with slow initialization.
817        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
818            Self {
819                name,
820                init_behavior: InitBehavior::Slow(init_delay),
821                run_behavior: RunBehavior::UntilShutdown,
822                start_count: Arc::new(AtomicUsize::new(0)),
823            }
824        }
825
826        /// Returns a shared handle to the start count for this worker.
827        fn start_count(&self) -> Arc<AtomicUsize> {
828            Arc::clone(&self.start_count)
829        }
830    }
831
832    #[async_trait]
833    impl Supervisable for MockWorker {
834        fn name(&self) -> &str {
835            self.name
836        }
837
838        fn shutdown_strategy(&self) -> ShutdownStrategy {
839            ShutdownStrategy::Graceful(Duration::from_millis(500))
840        }
841
842        async fn initialize(
843            &self, mut process_shutdown: ProcessShutdown,
844        ) -> Result<SupervisorFuture, InitializationError> {
845            match &self.init_behavior {
846                InitBehavior::Instant => {}
847                InitBehavior::Slow(delay) => {
848                    sleep(*delay).await;
849                }
850                InitBehavior::Fail(msg) => {
851                    return Err(InitializationError::Failed {
852                        source: GenericError::msg(*msg),
853                    });
854                }
855            }
856
857            let start_count = Arc::clone(&self.start_count);
858            let run_behavior = self.run_behavior.clone();
859
860            Ok(Box::pin(async move {
861                start_count.fetch_add(1, Ordering::SeqCst);
862
863                match run_behavior {
864                    RunBehavior::UntilShutdown => {
865                        process_shutdown.wait_for_shutdown().await;
866                        Ok(())
867                    }
868                    RunBehavior::FailAfter(delay, msg) => {
869                        select! {
870                            _ = sleep(delay) => {
871                                Err(GenericError::msg(msg))
872                            }
873                            _ = process_shutdown.wait_for_shutdown() => {
874                                Ok(())
875                            }
876                        }
877                    }
878                }
879            }))
880        }
881    }
882
883    /// Helper: run a supervisor with a oneshot-based shutdown trigger.
884    /// Returns the supervisor result and provides the shutdown sender.
885    async fn run_supervisor_with_trigger(
886        mut supervisor: Supervisor,
887    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
888        let (tx, rx) = oneshot::channel();
889        let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });
890        // Give the supervisor a moment to start and spawn children.
891        sleep(Duration::from_millis(50)).await;
892        (tx, handle)
893    }
894
895    // -- Supervisor run mode tests ---------------------------------------------------------
896
897    #[tokio::test]
898    async fn standalone_supervisor_shuts_down_cleanly() {
899        let mut sup = Supervisor::new("test-sup").unwrap();
900        sup.add_worker(MockWorker::long_running("worker1"));
901        sup.add_worker(MockWorker::long_running("worker2"));
902
903        let (tx, handle) = run_supervisor_with_trigger(sup).await;
904        tx.send(()).unwrap();
905
906        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
907        assert!(result.is_ok());
908    }
909
910    #[tokio::test]
911    async fn nested_supervisor_shuts_down_cleanly() {
912        let mut child_sup = Supervisor::new("child-sup").unwrap();
913        child_sup.add_worker(MockWorker::long_running("inner-worker"));
914
915        let mut parent_sup = Supervisor::new("parent-sup").unwrap();
916        parent_sup.add_worker(MockWorker::long_running("outer-worker"));
917        parent_sup.add_worker(child_sup);
918
919        let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
920        tx.send(()).unwrap();
921
922        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
923        assert!(result.is_ok());
924    }
925
926    #[tokio::test]
927    async fn supervisor_with_no_children_returns_error() {
928        let mut sup = Supervisor::new("empty-sup").unwrap();
929
930        let (tx, rx) = oneshot::channel::<()>();
931        let result = sup.run_with_shutdown(rx).await;
932        drop(tx);
933
934        assert!(matches!(result, Err(SupervisorError::NoChildren)));
935    }
936
937    // -- Child restart behavior tests ------------------------------------------------------
938
939    #[tokio::test]
940    async fn one_for_one_restarts_only_failed_child() {
941        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
942        let failing_count = failing.start_count();
943
944        let stable = MockWorker::long_running("stable-worker");
945        let stable_count = stable.start_count();
946
947        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
948            RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
949        );
950        sup.add_worker(stable);
951        sup.add_worker(failing);
952
953        let (tx, handle) = run_supervisor_with_trigger(sup).await;
954
955        // Wait for a few restarts to happen.
956        sleep(Duration::from_millis(300)).await;
957        let _ = tx.send(());
958
959        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
960        assert!(result.is_ok());
961
962        // The failing worker should have been started multiple times.
963        assert!(
964            failing_count.load(Ordering::SeqCst) >= 2,
965            "failing worker should have been restarted"
966        );
967        // The stable worker should only have been started once (never restarted).
968        assert_eq!(
969            stable_count.load(Ordering::SeqCst),
970            1,
971            "stable worker should not have been restarted"
972        );
973    }
974
975    #[tokio::test]
976    async fn one_for_all_restarts_all_children() {
977        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
978        let failing_count = failing.start_count();
979
980        let stable = MockWorker::long_running("stable-worker");
981        let stable_count = stable.start_count();
982
983        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
984            RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
985        );
986        sup.add_worker(stable);
987        sup.add_worker(failing);
988
989        let (tx, handle) = run_supervisor_with_trigger(sup).await;
990
991        // Wait for at least one restart cycle.
992        sleep(Duration::from_millis(300)).await;
993        let _ = tx.send(());
994
995        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
996        assert!(result.is_ok());
997
998        // Both workers should have been started multiple times.
999        assert!(
1000            failing_count.load(Ordering::SeqCst) >= 2,
1001            "failing worker should have been restarted"
1002        );
1003        assert!(
1004            stable_count.load(Ordering::SeqCst) >= 2,
1005            "stable worker should also have been restarted"
1006        );
1007    }
1008
1009    #[tokio::test]
1010    async fn restart_limit_exceeded_shuts_down_supervisor() {
1011        let mut sup = Supervisor::new("test-sup")
1012            .unwrap()
1013            .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
1014        // This worker fails immediately, which will exhaust the restart budget quickly.
1015        sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));
1016
1017        let (tx, rx) = oneshot::channel::<()>();
1018        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });
1019
1020        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1021        drop(tx);
1022
1023        assert!(matches!(result, Err(SupervisorError::Shutdown)));
1024    }
1025
1026    // -- Initialization failure tests ------------------------------------------------------
1027
1028    #[tokio::test]
1029    async fn init_failure_propagates_with_child_name() {
1030        let mut sup = Supervisor::new("test-sup").unwrap();
1031        sup.add_worker(MockWorker::long_running("good-worker"));
1032        sup.add_worker(MockWorker::init_failure("bad-worker"));
1033
1034        let (_tx, rx) = oneshot::channel::<()>();
1035        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
1036            .await
1037            .unwrap();
1038
1039        match result {
1040            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
1041                assert_eq!(child_name, "bad-worker");
1042            }
1043            other => panic!("expected FailedToInitialize, got: {:?}", other),
1044        }
1045    }
1046
1047    #[tokio::test]
1048    async fn init_failure_does_not_trigger_restart() {
1049        let init_fail = MockWorker::init_failure("bad-worker");
1050        let start_count = init_fail.start_count();
1051
1052        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1053            RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
1054        );
1055        sup.add_worker(init_fail);
1056
1057        let (_tx, rx) = oneshot::channel::<()>();
1058        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
1059            .await
1060            .unwrap();
1061
1062        assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));
1063        // The worker never got past init, so start_count should be 0.
1064        assert_eq!(start_count.load(Ordering::SeqCst), 0);
1065    }
1066
1067    // -- Shutdown responsiveness tests -----------------------------------------------------
1068
1069    #[tokio::test]
1070    async fn shutdown_completes_promptly_in_steady_state() {
1071        let mut sup = Supervisor::new("test-sup").unwrap();
1072        sup.add_worker(MockWorker::long_running("worker1"));
1073        sup.add_worker(MockWorker::long_running("worker2"));
1074
1075        let (tx, handle) = run_supervisor_with_trigger(sup).await;
1076        tx.send(()).unwrap();
1077
1078        // Shutdown should complete well within 1 second (workers respond to shutdown signal immediately).
1079        let result = timeout(Duration::from_secs(1), handle).await;
1080        assert!(result.is_ok(), "shutdown should complete promptly");
1081    }
1082
1083    #[tokio::test]
1084    async fn shutdown_during_slow_init_completes_promptly() {
1085        let mut sup = Supervisor::new("test-sup").unwrap();
1086        // This worker takes 30 seconds to initialize — but we'll trigger shutdown immediately.
1087        sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));
1088
1089        let (tx, rx) = oneshot::channel();
1090        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });
1091
1092        // Give the supervisor just enough time to spawn the task, then trigger shutdown.
1093        sleep(Duration::from_millis(20)).await;
1094        tx.send(()).unwrap();
1095
1096        // Shutdown should complete quickly even though the worker hasn't finished initializing.
1097        // The supervisor loop sees the shutdown signal and aborts the still-initializing task.
1098        let result = timeout(Duration::from_secs(2), handle).await;
1099        assert!(result.is_ok(), "shutdown during slow init should complete promptly");
1100    }
1101}