saluki_core/runtime/supervisor.rs

use std::{future::Future, pin::Pin, sync::Arc, time::Duration};

use async_trait::async_trait;
use saluki_common::collections::FastIndexMap;
use saluki_error::GenericError;
use snafu::{OptionExt as _, Snafu};
use tokio::{
    pin, select,
    task::{AbortHandle, Id, JoinSet},
};
use tracing::{debug, error, warn};

use super::{
    dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
    restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
    shutdown::{ProcessShutdown, ShutdownHandle},
};
use crate::runtime::{
    process::{Process, ProcessExt as _},
    state::DataspaceRegistry,
};

/// A `Future` that represents the execution of a supervised process.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;

/// A `Future` that represents the full lifecycle of a worker, including initialization.
///
/// Unlike [`SupervisorFuture`], which only represents the runtime phase, this future first performs async
/// initialization and then runs the worker. This allows initialization to happen concurrently when multiple workers are
/// spawned, and keeps the supervisor loop responsive to shutdown signals during initialization.
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;

/// Worker lifecycle errors.
///
/// Distinguishes between initialization failures (which should NOT trigger restart logic) and runtime failures (which
/// are eligible for restart).
#[derive(Debug)]
enum WorkerError {
    /// The worker failed during async initialization.
    ///
    /// The optional `child_name` carries the name of the original failing child when the error originates from a
    /// nested supervisor. This allows the parent to include it in its own `FailedToInitialize` error for better
    /// diagnostics across supervision tree levels.
    Initialization {
        child_name: Option<String>,
        source: InitializationError,
    },

    /// The worker failed during runtime execution.
    Runtime(GenericError),
}

impl From<SupervisorError> for WorkerError {
    fn from(err: SupervisorError) -> Self {
        match err {
            // Propagate initialization failures so the parent supervisor does NOT attempt to restart.
            // Preserve the original child name so the parent can include it in diagnostics.
            SupervisorError::FailedToInitialize { child_name, source } => WorkerError::Initialization {
                child_name: Some(child_name),
                source,
            },
            // All other supervisor errors (shutdown, no children, invalid name) are runtime-level.
            other => WorkerError::Runtime(other.into()),
        }
    }
}

/// Process errors.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child process was aborted by the supervisor.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child process panicked.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child process terminated with an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error that caused the termination.
        source: GenericError,
    },
}

/// Initialization errors.
///
/// Initialization errors are distinct from runtime errors: they indicate that a process could not be started at all
/// (for example, failed to bind a port, missing configuration). These errors do NOT trigger restart logic; instead, they
/// immediately propagate up and fail the supervisor.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    /// The process could not be initialized due to an error.
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        /// The underlying error that caused initialization to fail.
        source: GenericError,
    },
}

impl From<GenericError> for InitializationError {
    fn from(source: GenericError) -> Self {
        Self::Failed { source }
    }
}

/// Strategy for shutting down a process.
pub enum ShutdownStrategy {
    /// Waits up to the configured duration for the process to exit, then forcefully aborts it if it has not exited.
    Graceful(Duration),

    /// Forcefully aborts the process without waiting.
    Brutal,
}

/// A supervisable process.
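///
/// # Example
///
/// A minimal sketch of an implementor (`Heartbeat` here is hypothetical, shown only to illustrate the contract -- in
/// particular, moving `process_shutdown` into the returned future):
///
/// ```ignore
/// struct Heartbeat;
///
/// #[async_trait]
/// impl Supervisable for Heartbeat {
///     fn name(&self) -> &str {
///         "heartbeat"
///     }
///
///     async fn initialize(
///         &self, mut process_shutdown: ProcessShutdown,
///     ) -> Result<SupervisorFuture, InitializationError> {
///         // Any async setup (binding sockets, loading configuration, etc.) happens here.
///         Ok(Box::pin(async move {
///             // Run until the supervisor signals shutdown.
///             process_shutdown.wait_for_shutdown().await;
///             Ok(())
///         }))
///     }
/// }
/// ```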
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of the process.
    fn name(&self) -> &str;

    /// Returns the shutdown strategy for the process.
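    ///
    /// Defaults to [`ShutdownStrategy::Graceful`] with a five-second timeout. Implementors can override this; for
    /// example, a sketch of opting into an immediate abort:
    ///
    /// ```ignore
    /// fn shutdown_strategy(&self) -> ShutdownStrategy {
    ///     // Don't wait for a graceful exit; abort the task outright.
    ///     ShutdownStrategy::Brutal
    /// }
    /// ```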
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process asynchronously.
    ///
    /// During initialization, any resources or configuration for the process can be created asynchronously, and the
    /// same runtime that is used for running the process is used for initialization. The resulting future is expected
    /// to complete as soon as reasonably possible after `process_shutdown` resolves.
    ///
    /// **Important:** The `process_shutdown` signal must be moved into the returned [`SupervisorFuture`] so the
    /// worker can respond to supervisor-initiated shutdown. If `process_shutdown` is dropped during initialization,
    /// the worker will be unable to shut down gracefully and will be forcefully aborted after the shutdown timeout.
    ///
    /// # Errors
    ///
    /// If the process cannot be initialized, an error is returned.
    async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError>;
}

/// Supervisor errors.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// Supervisor or worker name is invalid.
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The invalid supervisor or worker name.
        name: String,
    },

    /// The supervisor has no child processes.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed to initialize.
    ///
    /// This error indicates that a child could not complete its async initialization. This is distinct from runtime
    /// failures and does NOT trigger restart logic.
    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        /// The name of the child that failed to initialize.
        child_name: String,

        /// The underlying initialization error.
        source: InitializationError,
    },

    /// The supervisor exceeded its restart limits and was forced to shut down.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shut down."))]
    Shutdown,
}

/// A child process specification.
///
/// All workers added to a [`Supervisor`] must be specified as a `ChildSpecification`. This acts as a template for how
/// the supervisor should create the underlying future that represents the process, as well as information about the
/// process, such as its name, shutdown strategy, and more.
///
/// A child process specification can be created implicitly from an existing [`Supervisor`], or any type that implements
/// [`Supervisable`].
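///
/// For example (a sketch; `Heartbeat` is the hypothetical worker from the [`Supervisable`] docs):
///
/// ```ignore
/// // From any `Supervisable` implementor:
/// let worker_spec: ChildSpecification = Heartbeat.into();
///
/// // From a nested supervisor:
/// let supervisor_spec: ChildSpecification = Supervisor::new("sub")?.into();
/// ```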
pub enum ChildSpecification {
    Worker(Arc<dyn Supervisable>),
    Supervisor(Supervisor),
}

impl ChildSpecification {
    fn process_type(&self) -> &'static str {
        match self {
            Self::Worker(_) => "worker",
            Self::Supervisor(_) => "supervisor",
        }
    }

    fn name(&self) -> &str {
        match self {
            Self::Worker(worker) => worker.name(),
            Self::Supervisor(supervisor) => &supervisor.supervisor_id,
        }
    }

    fn shutdown_strategy(&self) -> ShutdownStrategy {
        match self {
            Self::Worker(worker) => worker.shutdown_strategy(),

            // Supervisors should always be given as much time as necessary to shut down gracefully, ensuring that the
            // entire supervision subtree can be shut down cleanly.
            Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
        }
    }

    fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
        match self {
            Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
                name: worker.name().to_string(),
            }),
            Self::Supervisor(sup) => {
                Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
                    name: sup.supervisor_id.to_string(),
                })
            }
        }
    }

    fn create_worker_future(
        &self, process: Process, process_shutdown: ProcessShutdown,
    ) -> Result<WorkerFuture, SupervisorError> {
        match self {
            Self::Worker(worker) => {
                let worker = Arc::clone(worker);
                Ok(Box::pin(async move {
                    let run_future =
                        worker
                            .initialize(process_shutdown)
                            .await
                            .map_err(|source| WorkerError::Initialization {
                                child_name: None,
                                source,
                            })?;
                    run_future.await.map_err(WorkerError::Runtime)
                }))
            }
            Self::Supervisor(sup) => {
                match sup.runtime_mode() {
                    RuntimeMode::Ambient => {
                        // Run on the parent's ambient runtime.
                        Ok(sup.as_nested_process(process, process_shutdown))
                    }
                    RuntimeMode::Dedicated(config) => {
                        // Spawn in a dedicated runtime on a new OS thread, passing the parent's
                        // dataspace so the nested supervisor inherits it across the thread boundary.
                        let child_name = sup.supervisor_id.to_string();
                        let dataspace = process.dataspace().clone();
                        let handle =
                            spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown, dataspace)
                                .map_err(|e| SupervisorError::FailedToInitialize {
                                    child_name,
                                    source: e.into(),
                                })?;

                        Ok(Box::pin(async move { handle.await.map_err(WorkerError::from) }))
                    }
                }
            }
        }
    }
}

impl Clone for ChildSpecification {
    fn clone(&self) -> Self {
        match self {
            Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
            Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
        }
    }
}

impl From<Supervisor> for ChildSpecification {
    fn from(supervisor: Supervisor) -> Self {
        Self::Supervisor(supervisor)
    }
}

impl<T> From<T> for ChildSpecification
where
    T: Supervisable + 'static,
{
    fn from(worker: T) -> Self {
        Self::Worker(Arc::new(worker))
    }
}

/// Supervises a set of workers.
///
/// # Workers
///
/// All workers are defined by implementing the [`Supervisable`] trait, which provides both the logic for creating the
/// underlying worker future that is spawned and other metadata, such as the worker's name, how the worker should be
/// shut down, and so on.
///
/// Supervisors also (indirectly) implement the [`Supervisable`] trait, allowing them to be supervised by other
/// supervisors in order to construct _supervision trees_.
///
/// # Instrumentation
///
/// Supervisors automatically create their own allocation group
/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory
/// usage of the supervisor itself and its children. Additionally, individual worker processes are wrapped in a
/// dedicated [`tracing::Span`] to allow tracing the causal relationship between arbitrary code and the worker executing
/// it.
///
/// # Restart Strategies
///
/// Restart behavior, the core purpose of a supervisor, is fully configurable. A number of restart strategies are
/// available, which generally relate to the purpose of the supervisor: whether the workers being managed are
/// independent or interdependent.
///
/// All restart strategies are configured through [`RestartStrategy`], which has more information on the available
/// strategies and configuration settings.
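///
/// # Example
///
/// A sketch of assembling and running a small supervision tree (`Heartbeat` is the hypothetical worker from the
/// [`Supervisable`] docs; the restart parameters are arbitrary):
///
/// ```ignore
/// let mut inner = Supervisor::new("net")?.with_restart_strategy(
///     // Restart all children together, allowing up to 5 restarts within any 30-second window.
///     RestartStrategy::one_for_all().with_intensity_and_period(5, Duration::from_secs(30)),
/// );
/// inner.add_worker(Heartbeat);
///
/// let mut root = Supervisor::new("root")?;
/// root.add_worker(Heartbeat);
/// root.add_worker(inner); // Supervisors are themselves valid children.
///
/// // Run until Ctrl-C, then shut the whole tree down.
/// root.run_with_shutdown(tokio::signal::ctrl_c()).await?;
/// ```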
pub struct Supervisor {
    supervisor_id: Arc<str>,
    child_specs: Vec<ChildSpecification>,
    restart_strategy: RestartStrategy,
    runtime_mode: RuntimeMode,
}

impl Supervisor {
    /// Creates an empty `Supervisor` with the default restart strategy.
    pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
        // We try to throw an error about invalid names as early as possible. This is a manual check, so we might still
        // encounter an error later when actually running the supervisor, but this is a good first step to catch the
        // bulk of invalid names.
        if supervisor_id.as_ref().is_empty() {
            return Err(SupervisorError::InvalidName {
                name: supervisor_id.as_ref().to_string(),
            });
        }

        Ok(Self {
            supervisor_id: supervisor_id.as_ref().into(),
            child_specs: Vec::new(),
            restart_strategy: RestartStrategy::default(),
            runtime_mode: RuntimeMode::default(),
        })
    }

    /// Returns the supervisor's ID.
    pub fn id(&self) -> &str {
        &self.supervisor_id
    }

    /// Sets the restart strategy for the supervisor.
    pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
        self.restart_strategy = strategy;
        self
    }

    /// Configures this supervisor to run in a dedicated runtime.
    ///
    /// When this supervisor is added as a child to another supervisor, it will spawn its own OS thread(s) and Tokio
    /// runtime instead of running on the parent's ambient runtime.
    ///
    /// This provides runtime isolation, which can be useful for:
    /// - CPU-bound work that shouldn't block the parent's runtime
    /// - Isolating failures in one part of the system
    /// - Using different runtime configurations (for example, single-threaded vs multi-threaded)
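    ///
    /// A sketch of wiring this up (assuming `RuntimeConfiguration` provides a `Default` implementation; see the
    /// actual type for its options):
    ///
    /// ```ignore
    /// let isolated = Supervisor::new("isolated")?
    ///     .with_dedicated_runtime(RuntimeConfiguration::default());
    ///
    /// // When added as a child, `isolated` runs on its own OS thread and Tokio runtime.
    /// parent.add_worker(isolated);
    /// ```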
    pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
        self.runtime_mode = RuntimeMode::Dedicated(config);
        self
    }

    /// Returns the runtime mode for this supervisor.
    pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
        &self.runtime_mode
    }

    /// Adds a worker to the supervisor.
    ///
    /// A worker can be anything that implements the [`Supervisable`] trait. A [`Supervisor`] can also be added as a
    /// worker and managed in a nested fashion, known as a supervision tree.
    pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
        let child_spec = process.into();
        debug!(
            supervisor_id = %self.supervisor_id,
            "Adding new static child process #{}. ({}, {})",
            self.child_specs.len(),
            child_spec.process_type(),
            child_spec.name(),
        );
        self.child_specs.push(child_spec);
    }

    fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
        match self.child_specs.get(child_spec_idx) {
            Some(child_spec) => child_spec,
            None => unreachable!("child spec index should never be out of bounds"),
        }
    }

    fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        let child_spec = self.get_child_spec(child_spec_idx);
        debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
        worker_state.add_worker(child_spec_idx, child_spec)
    }

    fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
        for child_spec_idx in 0..self.child_specs.len() {
            self.spawn_child(child_spec_idx, worker_state)?;
        }

        Ok(())
    }

    async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        // Spawn all child processes. Since initialization is folded into each worker's task, this returns immediately
        // after spawning -- children initialize concurrently in the background.
        self.spawn_all_children(&mut worker_state)?;

        // Now we supervise.
        let shutdown = process_shutdown.wait_for_shutdown();
        pin!(shutdown);

        loop {
            select! {
                // Shutdown has been triggered.
                //
                // Propagate shutdown to all child processes and wait for them to exit.
                _ = &mut shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
                    // TODO: Erlang/OTP defaults to always trying to restart a process, even if it doesn't terminate due
                    // to a legitimate failure. It does allow configuring this behavior on a per-process basis, however.
                    // We don't support dynamically adding child processes, which is the only real use case I can think
                    // of for having non-long-lived child processes... so I think for now, we're OK just always trying
                    // to restart.
                    Some((child_spec_idx, worker_result)) => {
                        let child_spec = self.get_child_spec(child_spec_idx);

                        // Initialization failures are not eligible for restart -- they propagate immediately.
                        if let Err(WorkerError::Initialization { child_name, source }) = worker_result {
                            // If the error came from a nested supervisor, include the original child name
                            // to make the error chain more informative (e.g., "ctrl-pln/privileged-api").
                            let full_name = match child_name {
                                Some(inner) => format!("{}/{}", child_spec.name(), inner),
                                None => child_spec.name().to_string(),
                            };

                            error!(supervisor_id = %self.supervisor_id, worker_name = full_name, "Child process failed to initialize: {}", source);
                            worker_state.shutdown_workers().await;
                            return Err(SupervisorError::FailedToInitialize {
                                child_name: full_name,
                                source,
                            });
                        }

                        // Convert the worker result to a process error for restart evaluation.
                        let worker_result = worker_result
                            .map_err(|e| match e {
                                WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
                                WorkerError::Initialization { .. } => unreachable!("handled above"),
                            });

                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    worker_state.shutdown_workers().await;
                                    self.spawn_all_children(&mut worker_state)?;
                                }
                            },
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    },
                    None => unreachable!("should not have empty worker joinset prior to shutdown"),
                }
            }
        }

        Ok(())
    }

    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
        // Simple wrapper around `run_inner` to satisfy the return type signature needed when running the supervisor as
        // a nested child process in another supervisor.
        debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");

        // Create a standalone clone of ourselves so we can fulfill the future signature.
        let sup = self.inner_clone();

        Box::pin(async move {
            sup.run_inner(process, process_shutdown)
                .await
                .map_err(WorkerError::from)
        })
    }

    /// Runs the supervisor forever.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
    pub async fn run(&mut self) -> Result<(), SupervisorError> {
        // Create a no-op `ProcessShutdown` to pass to `run_inner`. It never fires, since we want to run forever, but
        // we still need it to satisfy the signature.
        let process_shutdown = ProcessShutdown::noop();
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_process_future(process)
            .await
    }

    /// Runs the supervisor until shutdown is triggered.
    ///
    /// When `shutdown` resolves, the supervisor will shut down all child processes according to their shutdown
    /// strategy, and then return.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
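    ///
    /// For example, shutting down on a oneshot signal (a sketch, mirroring how the tests below drive it):
    ///
    /// ```ignore
    /// let (tx, rx) = tokio::sync::oneshot::channel::<()>();
    /// // ... hand `tx` to whatever decides when to stop ...
    /// supervisor.run_with_shutdown(rx).await?;
    /// ```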
    pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
        let process_shutdown = ProcessShutdown::wrapped(shutdown);
        self.run_with_process_shutdown(process_shutdown, None).await
    }

    /// Runs the supervisor until the given `ProcessShutdown` signal is received.
    ///
    /// This is an internal variant of `run_with_shutdown` that takes a `ProcessShutdown` directly, used when spawning
    /// supervisors in dedicated runtimes where the shutdown signal is already wrapped in a `ProcessShutdown`.
    ///
    /// If `dataspace` is provided, the supervisor will use it instead of creating a new one. This is used to propagate
    /// the parent's dataspace across OS thread boundaries for dedicated runtimes.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
    pub(crate) async fn run_with_process_shutdown(
        &mut self, process_shutdown: ProcessShutdown, dataspace: Option<DataspaceRegistry>,
    ) -> Result<(), SupervisorError> {
        let process =
            Process::supervisor_with_dataspace(&self.supervisor_id, None, dataspace).context(InvalidName {
                name: self.supervisor_id.to_string(),
            })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_process_future(process)
            .await
    }

    fn inner_clone(&self) -> Self {
        // This is no different than if we just implemented `Clone` directly, but it avoids exposing a _public_
        // implementation of `Clone`, which we don't want normal users to have access to. We only need this internally
        // to support nested supervisors.
        Self {
            supervisor_id: Arc::clone(&self.supervisor_id),
            child_specs: self.child_specs.clone(),
            restart_strategy: self.restart_strategy,
            runtime_mode: self.runtime_mode.clone(),
        }
    }
}

struct ProcessState {
    worker_id: usize,
    shutdown_strategy: ShutdownStrategy,
    shutdown_handle: ShutdownHandle,
    abort_handle: AbortHandle,
}

struct WorkerState {
    process: Process,
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    worker_map: FastIndexMap<Id, ProcessState>,
}

impl WorkerState {
    fn new(process: Process) -> Self {
        Self {
            process,
            worker_tasks: JoinSet::new(),
            worker_map: FastIndexMap::default(),
        }
    }

    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
        let process = child_spec.create_process(&self.process)?;
        let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
        let shutdown_strategy = child_spec.shutdown_strategy();
        let abort_handle = self.worker_tasks.spawn(worker_future.into_process_future(process));
        self.worker_map.insert(
            abort_handle.id(),
            ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            },
        );
        Ok(())
    }

    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
        debug!("Waiting for next process to complete.");

        match self.worker_tasks.join_next_with_id().await {
            Some(Ok((worker_task_id, worker_result))) => {
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                Some((process_state.worker_id, worker_result))
            }
            Some(Err(e)) => {
                let worker_task_id = e.id();
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                let e = if e.is_cancelled() {
                    ProcessError::Aborted
                } else {
                    ProcessError::Panicked
                };
                Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
            }
            None => None,
        }
    }

    async fn shutdown_workers(&mut self) {
        debug!("Shutting down all processes.");

        // Pop entries from the worker map, which grabs us workers in the reverse order they were added. This lets us
        // ensure we're shutting down any _dependent_ processes (processes which depend on previously-started processes)
        // first.
        //
        // For each entry, we trigger shutdown in whatever way necessary, and then wait for the process to exit by
        // driving the `JoinSet`. If other workers complete while we're waiting, we'll simply remove them from the
        // worker map and continue waiting for the current worker we're shutting down.
        //
        // We do this until the worker map is empty, at which point we can be sure that all processes have exited.
        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
            let ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            } = process_state;

            // Trigger the process to shut down based on the configured shutdown strategy.
            let shutdown_deadline = match shutdown_strategy {
                ShutdownStrategy::Graceful(timeout) => {
                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
                    shutdown_handle.trigger();

                    tokio::time::sleep(timeout)
                }
                ShutdownStrategy::Brutal => {
                    debug!(worker_id, "Forcefully aborting process.");
                    abort_handle.abort();

                    // We have to return a future that never resolves, since we're already aborting the process. This
                    // is a little hacky, but expressing an optional future here is awkward, so this is what we're
                    // going with for now.
                    tokio::time::sleep(Duration::MAX)
                }
            };
            pin!(shutdown_deadline);

            // Wait for the process to exit by driving the `JoinSet`. If other workers complete while we're waiting,
            // we'll simply remove them from the worker map and continue waiting.
            loop {
                select! {
                    worker_result = self.worker_tasks.join_next_with_id() => {
                        match worker_result {
                            Some(Ok((worker_task_id, _))) => {
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited successfully.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            },
                            Some(Err(e)) => {
                                let worker_task_id = e.id();
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited with error.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            }
                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
                        }
                    },
                    // We've exceeded the shutdown timeout, so we need to abort the process.
                    _ = &mut shutdown_deadline => {
                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
                        abort_handle.abort();
                    }
                }
            }
        }

        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
        debug_assert!(
            self.worker_tasks.is_empty(),
            "worker tasks should be empty after shutdown"
        );
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;
    use tokio::{
        sync::oneshot,
        task::JoinHandle,
        time::{sleep, timeout},
    };

    use super::*;

    /// Behavior for a mock worker during initialization.
    #[derive(Clone)]
    enum InitBehavior {
        /// Initialization succeeds immediately.
        Instant,

        /// Initialization takes the given duration before succeeding.
        Slow(Duration),

        /// Initialization fails with the given message.
        Fail(&'static str),
    }

    /// Behavior for a mock worker during runtime (after initialization).
    #[derive(Clone)]
    enum RunBehavior {
        /// Runs until shutdown is received.
        UntilShutdown,

        /// Fails with the given error message after the given delay.
        FailAfter(Duration, &'static str),
    }

    /// A configurable mock worker for testing supervisor behavior.
    struct MockWorker {
        name: &'static str,
        init_behavior: InitBehavior,
        run_behavior: RunBehavior,
        start_count: Arc<AtomicUsize>,
    }

    impl MockWorker {
        /// Creates a worker that runs until shutdown.
        fn long_running(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker that fails after the given delay.
        fn failing(name: &'static str, delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker that fails during initialization.
        fn init_failure(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Fail("init failed"),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker with slow initialization.
        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Slow(init_delay),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Returns a shared handle to the start count for this worker.
        fn start_count(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.start_count)
        }
    }

    #[async_trait]
    impl Supervisable for MockWorker {
        fn name(&self) -> &str {
            self.name
        }

        fn shutdown_strategy(&self) -> ShutdownStrategy {
            ShutdownStrategy::Graceful(Duration::from_millis(500))
        }

        async fn initialize(
            &self, mut process_shutdown: ProcessShutdown,
        ) -> Result<SupervisorFuture, InitializationError> {
            match &self.init_behavior {
                InitBehavior::Instant => {}
                InitBehavior::Slow(delay) => {
                    sleep(*delay).await;
                }
                InitBehavior::Fail(msg) => {
                    return Err(InitializationError::Failed {
                        source: GenericError::msg(*msg),
                    });
                }
            }

            let start_count = Arc::clone(&self.start_count);
            let run_behavior = self.run_behavior.clone();

            Ok(Box::pin(async move {
                start_count.fetch_add(1, Ordering::SeqCst);

                match run_behavior {
                    RunBehavior::UntilShutdown => {
                        process_shutdown.wait_for_shutdown().await;
                        Ok(())
                    }
                    RunBehavior::FailAfter(delay, msg) => {
                        select! {
                            _ = sleep(delay) => {
                                Err(GenericError::msg(msg))
                            }
                            _ = process_shutdown.wait_for_shutdown() => {
                                Ok(())
                            }
                        }
                    }
                }
            }))
        }
    }

    /// Helper: run a supervisor with a oneshot-based shutdown trigger.
    /// Returns the shutdown sender and a join handle for the running supervisor task.
    async fn run_supervisor_with_trigger(
        mut supervisor: Supervisor,
    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });
        // Give the supervisor a moment to start and spawn children.
        sleep(Duration::from_millis(50)).await;
        (tx, handle)
    }

    // -- Supervisor run mode tests ---------------------------------------------------------

    #[tokio::test]
    async fn standalone_supervisor_shuts_down_cleanly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn nested_supervisor_shuts_down_cleanly() {
        let mut child_sup = Supervisor::new("child-sup").unwrap();
        child_sup.add_worker(MockWorker::long_running("inner-worker"));

        let mut parent_sup = Supervisor::new("parent-sup").unwrap();
        parent_sup.add_worker(MockWorker::long_running("outer-worker"));
        parent_sup.add_worker(child_sup);

        let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn supervisor_with_no_children_returns_error() {
        let mut sup = Supervisor::new("empty-sup").unwrap();

        let (tx, rx) = oneshot::channel::<()>();
        let result = sup.run_with_shutdown(rx).await;
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::NoChildren)));
    }

    // -- Child restart behavior tests ------------------------------------------------------

    #[tokio::test]
    async fn one_for_one_restarts_only_failed_child() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Wait for a few restarts to happen.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        // The failing worker should have been started multiple times.
        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        // The stable worker should only have been started once (never restarted).
        assert_eq!(
            stable_count.load(Ordering::SeqCst),
            1,
            "stable worker should not have been restarted"
        );
    }

    #[tokio::test]
    async fn one_for_all_restarts_all_children() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Wait for at least one restart cycle.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        // Both workers should have been started multiple times.
        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert!(
            stable_count.load(Ordering::SeqCst) >= 2,
            "stable worker should also have been restarted"
        );
    }

    #[tokio::test]
    async fn restart_limit_exceeded_shuts_down_supervisor() {
        let mut sup = Supervisor::new("test-sup")
            .unwrap()
            .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
        // This worker fails immediately, which will exhaust the restart budget quickly.
        sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));

        let (tx, rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::Shutdown)));
    }

    // -- Initialization failure tests ------------------------------------------------------

    #[tokio::test]
    async fn init_failure_propagates_with_child_name() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("good-worker"));
        sup.add_worker(MockWorker::init_failure("bad-worker"));

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        match result {
            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
                assert_eq!(child_name, "bad-worker");
            }
            other => panic!("expected FailedToInitialize, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn init_failure_does_not_trigger_restart() {
        let init_fail = MockWorker::init_failure("bad-worker");
        let start_count = init_fail.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
        );
        sup.add_worker(init_fail);

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));
        // The worker never got past init, so start_count should be 0.
        assert_eq!(start_count.load(Ordering::SeqCst), 0);
    }

    // -- Shutdown responsiveness tests -----------------------------------------------------

    #[tokio::test]
    async fn shutdown_completes_promptly_in_steady_state() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        // Shutdown should complete well within 1 second (workers respond to the shutdown signal immediately).
        let result = timeout(Duration::from_secs(1), handle).await;
        assert!(result.is_ok(), "shutdown should complete promptly");
    }

    #[tokio::test]
    async fn shutdown_during_slow_init_completes_promptly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        // This worker takes 30 seconds to initialize, but we'll trigger shutdown immediately.
        sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));

        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        // Give the supervisor just enough time to spawn the task, then trigger shutdown.
        sleep(Duration::from_millis(20)).await;
        tx.send(()).unwrap();

        // Shutdown should complete quickly even though the worker hasn't finished initializing.
        // The supervisor loop sees the shutdown signal and aborts the still-initializing task.
        let result = timeout(Duration::from_secs(2), handle).await;
        assert!(result.is_ok(), "shutdown during slow init should complete promptly");
    }
}