saluki_core/runtime/supervisor.rs

use std::{future::Future, pin::Pin, sync::Arc, time::Duration};

use saluki_common::collections::FastIndexMap;
use saluki_error::{ErrorContext as _, GenericError};
use snafu::{OptionExt as _, Snafu};
use tokio::{
    pin, select,
    task::{AbortHandle, Id, JoinSet},
};
use tracing::{debug, error, warn};

use super::{
    restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
    shutdown::{ProcessShutdown, ShutdownHandle},
};
use crate::runtime::process::{Process, ProcessExt as _};

/// A `Future` that represents the execution of a supervised process.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;

/// Process errors.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child process was aborted by the supervisor.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child process panicked.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child process terminated with an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error that caused the termination.
        source: GenericError,
    },
}

/// Strategy for shutting down a process.
pub enum ShutdownStrategy {
    /// Waits up to the configured duration for the process to exit, forcefully aborting it if the timeout elapses.
    Graceful(Duration),

    /// Forcefully aborts the process without waiting.
    Brutal,
}

/// A supervisable process.
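///
/// # Example
///
/// A minimal, illustrative sketch of implementing the trait for a worker that does periodic work until shutdown is
/// triggered. `HeartbeatWorker` and its behavior are hypothetical, not part of this crate:
///
/// ```ignore
/// struct HeartbeatWorker;
///
/// impl Supervisable for HeartbeatWorker {
///     fn name(&self) -> &str {
///         "heartbeat"
///     }
///
///     fn initialize(&self, mut process_shutdown: ProcessShutdown) -> Option<SupervisorFuture> {
///         Some(Box::pin(async move {
///             let shutdown = process_shutdown.wait_for_shutdown();
///             tokio::pin!(shutdown);
///
///             loop {
///                 tokio::select! {
///                     // Exit cleanly once the supervisor triggers shutdown.
///                     _ = &mut shutdown => break,
///                     _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
///                         // Periodic work goes here.
///                     }
///                 }
///             }
///
///             Ok(())
///         }))
///     }
/// }
/// ```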
pub trait Supervisable: Send + Sync {
    /// Returns the name of the process.
    fn name(&self) -> &str;

    /// Defines the shutdown strategy for the process.
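    ///
    /// The default implementation waits up to five seconds for the process to exit gracefully. A sketch of
    /// overriding it to abort the process immediately instead:
    ///
    /// ```ignore
    /// fn shutdown_strategy(&self) -> ShutdownStrategy {
    ///     ShutdownStrategy::Brutal
    /// }
    /// ```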
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes a `Future` that represents the execution of the process.
    ///
    /// When `Some` is returned, the process is spawned and managed by the supervisor. When `None` is returned, the
    /// process is considered to be permanently failed. This can be useful for supervised tasks that are not expected to
    /// ever fail, or cannot support restart, but should still be managed within the same supervision hierarchy as other
    /// processes.
    fn initialize(&self, process_shutdown: ProcessShutdown) -> Option<SupervisorFuture>;
}

/// Supervisor errors.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// Supervisor or worker name is invalid.
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The invalid supervisor or worker name.
        name: String,
    },

    /// The supervisor has no child processes.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed to initialize.
    #[snafu(display("Child process failed to initialize."))]
    FailedToInitialize,

    /// The supervisor exceeded its restart limits and was forced to shut down.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shut down."))]
    Shutdown,
}

/// A child process specification.
///
/// All workers added to a [`Supervisor`] must be specified as a `ChildSpecification`. This acts as a template for how the
/// supervisor should create the underlying future that represents the process, and also carries information about the
/// process, such as its name, shutdown strategy, and more.
///
/// A child process specification can be created implicitly from an existing [`Supervisor`], or any type that implements
/// [`Supervisable`].
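///
/// For example, both conversions below go through `From`/`Into` (a sketch, where `HeartbeatWorker` is a hypothetical
/// `Supervisable` implementation):
///
/// ```ignore
/// // A worker type becomes `ChildSpecification::Worker`.
/// let worker_spec: ChildSpecification = HeartbeatWorker.into();
///
/// // A nested supervisor becomes `ChildSpecification::Supervisor`.
/// let nested_spec: ChildSpecification = Supervisor::new("nested")?.into();
/// ```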
pub enum ChildSpecification {
    Worker(Arc<dyn Supervisable>),
    Supervisor(Supervisor),
}

impl ChildSpecification {
    fn process_type(&self) -> &'static str {
        match self {
            Self::Worker(_) => "worker",
            Self::Supervisor(_) => "supervisor",
        }
    }

    fn name(&self) -> &str {
        match self {
            Self::Worker(worker) => worker.name(),
            Self::Supervisor(supervisor) => &supervisor.supervisor_id,
        }
    }

    fn shutdown_strategy(&self) -> ShutdownStrategy {
        match self {
            Self::Worker(worker) => worker.shutdown_strategy(),

            // Supervisors should always be given as much time as necessary to shut down gracefully, ensuring that the
            // entire supervision subtree can be shut down cleanly.
            Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
        }
    }

    fn initialize(
        &self, parent_process: &Process, process_shutdown: ProcessShutdown,
    ) -> Result<Option<(Process, SupervisorFuture)>, SupervisorError> {
        match self {
            Self::Worker(worker) => {
                let process = Process::worker(worker.name(), parent_process).context(InvalidName {
                    name: worker.name().to_string(),
                })?;
                Ok(worker.initialize(process_shutdown).map(|future| (process, future)))
            }
            Self::Supervisor(sup) => {
                let process = Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
                    name: sup.supervisor_id.to_string(),
                })?;
                Ok(Some((
                    process.clone(),
                    sup.as_nested_process(process, process_shutdown),
                )))
            }
        }
    }
}

impl Clone for ChildSpecification {
    fn clone(&self) -> Self {
        match self {
            Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
            Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
        }
    }
}

impl From<Supervisor> for ChildSpecification {
    fn from(supervisor: Supervisor) -> Self {
        Self::Supervisor(supervisor)
    }
}

impl<T> From<T> for ChildSpecification
where
    T: Supervisable + 'static,
{
    fn from(worker: T) -> Self {
        Self::Worker(Arc::new(worker))
    }
}

/// Supervises a set of workers.
///
/// # Workers
///
/// All workers are defined through an implementation of the [`Supervisable`] trait, which provides both the logic for
/// creating the underlying worker future that is spawned and other metadata, such as the worker's name, how the
/// worker should be shut down, and so on.
///
/// Supervisors also (indirectly) implement the [`Supervisable`] trait, allowing them to be supervised by other
/// supervisors in order to construct _supervision trees_.
///
/// # Instrumentation
///
/// Supervisors automatically create their own allocation group
/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory usage of the
/// supervisor itself and its children. Additionally, individual worker processes are wrapped in a dedicated
/// [`tracing::Span`] to allow tracing the causal relationship between arbitrary code and the worker executing it.
///
/// # Restart Strategies
///
/// Restart behavior, the main purpose of a supervisor, is fully configurable. A number of restart strategies are
/// available, which generally relate to the purpose of the supervisor: whether the workers being managed are
/// independent or interdependent.
///
/// All restart strategies are configured through [`RestartStrategy`], which has more information on the available
/// strategies and configuration settings.
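///
/// # Example
///
/// A minimal sketch of building and running a small supervision tree, where `HeartbeatWorker` and `FlushWorker` are
/// hypothetical `Supervisable` implementations:
///
/// ```ignore
/// // A nested supervisor that manages its own worker.
/// let mut flush_supervisor = Supervisor::new("flush")?;
/// flush_supervisor.add_worker(FlushWorker);
///
/// // The root supervisor manages a worker and the nested supervisor.
/// let mut root = Supervisor::new("root")?.with_restart_strategy(RestartStrategy::default());
/// root.add_worker(HeartbeatWorker);
/// root.add_worker(flush_supervisor);
///
/// // Runs until restart limits are exceeded or a child process fails to initialize.
/// root.run().await?;
/// ```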
pub struct Supervisor {
    supervisor_id: Arc<str>,
    child_specs: Vec<ChildSpecification>,
    restart_strategy: RestartStrategy,
}

impl Supervisor {
    /// Creates an empty `Supervisor` with the default restart strategy.
    pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
        // We try to throw an error about invalid names as early as possible. This is a manual check, so we might still
        // encounter an error later when actually running the supervisor, but this is a good first step to catch the
        // bulk of invalid names.
        if supervisor_id.as_ref().is_empty() {
            return Err(SupervisorError::InvalidName {
                name: supervisor_id.as_ref().to_string(),
            });
        }

        Ok(Self {
            supervisor_id: supervisor_id.as_ref().into(),
            child_specs: Vec::new(),
            restart_strategy: RestartStrategy::default(),
        })
    }

    /// Sets the restart strategy for the supervisor.
    pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
        self.restart_strategy = strategy;
        self
    }

    /// Adds a worker to the supervisor.
    ///
    /// A worker can be anything that implements the [`Supervisable`] trait. A [`Supervisor`] can also be added as a
    /// worker and managed in a nested fashion, known as a supervision tree.
    pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
        let child_spec = process.into();
        debug!(
            supervisor_id = %self.supervisor_id,
            "Adding new static child process #{}. ({}, {})",
            self.child_specs.len(),
            child_spec.process_type(),
            child_spec.name(),
        );
        self.child_specs.push(child_spec);
    }

    fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
        match self.child_specs.get(child_spec_idx) {
            Some(child_spec) => child_spec,
            None => unreachable!("child spec index should never be out of bounds"),
        }
    }

    fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        let child_spec = self.get_child_spec(child_spec_idx);
        debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
        worker_state.add_worker(child_spec_idx, child_spec)
    }

    fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
        for child_spec_idx in 0..self.child_specs.len() {
            self.spawn_child(child_spec_idx, worker_state)?;
        }

        Ok(())
    }

    async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        // Do the initial spawn of all child processes and supervisors.
        self.spawn_all_children(&mut worker_state)?;

        // Now we supervise.
        let shutdown = process_shutdown.wait_for_shutdown();
        pin!(shutdown);

        loop {
            select! {
                // Shutdown has been triggered.
                //
                // Propagate shutdown to all child processes and wait for them to exit.
                _ = &mut shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
                    // TODO: Erlang/OTP defaults to always trying to restart a process, even if it doesn't terminate due to a
                    // legitimate failure. It does allow configuring this behavior on a per-process basis, however. We don't
                    // support dynamically adding child processes, which is the only real use case I can think of for having
                    // non-long-lived child processes... so I think for now, we're OK just always trying to restart.
                    Some((child_spec_idx, worker_result)) => {
                        let child_spec = self.get_child_spec(child_spec_idx);
                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    worker_state.shutdown_workers().await;
                                    self.spawn_all_children(&mut worker_state)?;
                                }
                            },
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    },
                    None => unreachable!("should not have empty worker joinset prior to shutdown"),
                }
            }
        }

        Ok(())
    }

    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> SupervisorFuture {
        // Simple wrapper around `run_inner` to satisfy the return type signature needed when running the supervisor as
        // a nested child process in another supervisor.
        debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");

        // Create a standalone clone of ourselves so we can fulfill the future signature.
        let sup = self.inner_clone();

        Box::pin(async move {
            sup.run_inner(process, process_shutdown)
                .await
                .error_context("Nested supervisor failed to exit cleanly.")
        })
    }

    /// Runs the supervisor forever.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
    pub async fn run(&mut self) -> Result<(), SupervisorError> {
        // Create a no-op `ProcessShutdown` to satisfy the `run_inner` function. This is never used since we want to
        // run forever, but we need to satisfy the signature.
        let process_shutdown = ProcessShutdown::noop();
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_instrumented(process)
            .await
    }

    /// Runs the supervisor until shutdown is triggered.
    ///
    /// When `shutdown` resolves, the supervisor will shut down all child processes according to their shutdown strategy,
    /// and then return.
    ///
    /// # Errors
    ///
    /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
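    ///
    /// # Example
    ///
    /// A sketch of running the supervisor until Ctrl-C is received (`tokio::signal::ctrl_c` requires tokio's `signal`
    /// feature); any `Future + Send + 'static` can serve as the shutdown signal:
    ///
    /// ```ignore
    /// // Shut the supervision tree down cleanly when Ctrl-C is received.
    /// supervisor.run_with_shutdown(async {
    ///     let _ = tokio::signal::ctrl_c().await;
    /// }).await?;
    /// ```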
    pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
        let process_shutdown = ProcessShutdown::wrapped(shutdown);
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_instrumented(process)
            .await
    }

    fn inner_clone(&self) -> Self {
        // This is no different than if we just implemented `Clone` directly, but it allows us to avoid exposing a
        // _public_ implementation of `Clone`, which we don't want normal users to be able to do. We only need this
        // internally to support nested supervisors.
        Self {
            supervisor_id: Arc::clone(&self.supervisor_id),
            child_specs: self.child_specs.clone(),
            restart_strategy: self.restart_strategy,
        }
    }
}

struct ProcessState {
    worker_id: usize,
    shutdown_strategy: ShutdownStrategy,
    shutdown_handle: ShutdownHandle,
    abort_handle: AbortHandle,
}

struct WorkerState {
    process: Process,
    worker_tasks: JoinSet<Result<(), GenericError>>,
    worker_map: FastIndexMap<Id, ProcessState>,
}

impl WorkerState {
    fn new(process: Process) -> Self {
        Self {
            process,
            worker_tasks: JoinSet::new(),
            worker_map: FastIndexMap::default(),
        }
    }

    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
        match child_spec.initialize(&self.process, process_shutdown)? {
            Some((process, worker)) => {
                let shutdown_strategy = child_spec.shutdown_strategy();

                let abort_handle = self.worker_tasks.spawn(worker.into_instrumented(process));
                self.worker_map.insert(
                    abort_handle.id(),
                    ProcessState {
                        worker_id,
                        shutdown_strategy,
                        shutdown_handle,
                        abort_handle,
                    },
                );

                Ok(())
            }
            None => Err(SupervisorError::FailedToInitialize),
        }
    }

    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), ProcessError>)> {
        debug!("Waiting for next process to complete.");

        match self.worker_tasks.join_next_with_id().await {
            Some(Ok((worker_task_id, worker_result))) => {
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                Some((
                    process_state.worker_id,
                    worker_result.map_err(|e| ProcessError::Terminated { source: e }),
                ))
            }
            Some(Err(e)) => {
                let worker_task_id = e.id();
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                let e = if e.is_cancelled() {
                    ProcessError::Aborted
                } else {
                    ProcessError::Panicked
                };
                Some((process_state.worker_id, Err(e)))
            }
            None => None,
        }
    }

    async fn shutdown_workers(&mut self) {
        debug!("Shutting down all processes.");

        // Pop entries from the worker map, which grabs us workers in the reverse order they were added. This lets us
        // ensure we're shutting down any _dependent_ processes (processes which depend on previously-started processes)
        // first.
        //
        // For each entry, we trigger shutdown in whatever way necessary, and then wait for the process to exit by
        // driving the `JoinSet`. If other workers complete while we're waiting, we'll simply remove them from the
        // worker map and continue waiting for the current worker we're shutting down.
        //
        // We do this until the worker map is empty, at which point we can be sure that all processes have exited.
        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
            let ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            } = process_state;

            // Trigger the process to shut down based on the configured shutdown strategy.
            let shutdown_deadline = match shutdown_strategy {
                ShutdownStrategy::Graceful(timeout) => {
                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
                    shutdown_handle.trigger();

                    tokio::time::sleep(timeout)
                }
                ShutdownStrategy::Brutal => {
                    debug!(worker_id, "Forcefully aborting process.");
                    abort_handle.abort();

                    // We have to return a future that never resolves, since we're already aborting it. This is a little
                    // hacky but it's also difficult to do an optional future, so this is what we're going with for now.
                    tokio::time::sleep(Duration::MAX)
                }
            };
            pin!(shutdown_deadline);

            // Wait for the process to exit by driving the `JoinSet`. If other workers complete while we're waiting,
            // we'll simply remove them from the worker map and continue waiting.
            loop {
                select! {
                    worker_result = self.worker_tasks.join_next_with_id() => {
                        match worker_result {
                            Some(Ok((worker_task_id, _))) => {
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited successfully.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            },
                            Some(Err(e)) => {
                                let worker_task_id = e.id();
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited with error.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            }
                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
                        }
                    },
                    // We've exceeded the shutdown timeout, so we need to abort the process.
                    _ = &mut shutdown_deadline => {
                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
                        abort_handle.abort();
                    }
                }
            }
        }

        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
        debug_assert!(
            self.worker_tasks.is_empty(),
            "worker tasks should be empty after shutdown"
        );
    }
}