1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use saluki_common::collections::FastIndexMap;
5use saluki_common::sync::shutdown::{ShutdownCoordinator, ShutdownHandle};
6use saluki_error::GenericError;
7use snafu::{OptionExt as _, Snafu};
8use tokio::{
9 pin, select,
10 task::{AbortHandle, Id, JoinSet},
11};
12use tracing::{debug, error, warn};
13
14use super::{
15 dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
16 restart::{RestartAction, RestartMode, RestartState, RestartStrategy, RestartType},
17};
18use crate::runtime::{
19 process::{Process, ProcessExt as _},
20 state::DataspaceRegistry,
21};
22
/// Boxed, sendable future that runs a supervised worker to completion.
///
/// Resolves to `Ok(())` on success or to a [`GenericError`] on failure. This is the future type that
/// [`Supervisable::initialize`] hands back to the supervisor.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;
25
/// Boxed, sendable future that the supervisor spawns for a child process.
///
/// Its error type is [`WorkerError`], which keeps initialization failures distinguishable from every
/// other kind of failure.
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;
32
/// Internal error produced by a spawned child process.
#[derive(Debug)]
enum WorkerError {
    /// The child, or a process nested below it, failed while initializing.
    Initialization {
        /// Name of the nested process that failed, when the failure was reported by a nested supervisor;
        /// `None` when the worker itself failed to initialize.
        child_name: Option<String>,
        /// The underlying initialization error.
        source: InitializationError,
    },

    /// Any other failure, carried as a generic error.
    Runtime(GenericError),
}
52
53impl From<SupervisorError> for WorkerError {
54 fn from(err: SupervisorError) -> Self {
55 match err {
56 SupervisorError::FailedToInitialize { child_name, source } => WorkerError::Initialization {
59 child_name: Some(child_name),
60 source,
61 },
62 other => WorkerError::Runtime(other.into()),
64 }
65 }
66}
67
/// Reasons a child process can stop abnormally.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child's task was cancelled.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child's task panicked.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child's future completed with an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error the child terminated with.
        source: GenericError,
    },
}
86
/// Error returned when a process cannot be initialized.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    /// Initialization failed with the given underlying error.
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        /// The error that caused initialization to fail.
        source: GenericError,
    },
}
102
103impl From<GenericError> for InitializationError {
104 fn from(source: GenericError) -> Self {
105 Self::Failed { source }
106 }
107}
108
/// How a supervisor stops a child process during shutdown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownStrategy {
    /// Signal the process to shut down, then forcefully abort it if it has not exited once the given
    /// timeout elapses.
    Graceful(Duration),

    /// Forcefully abort the process immediately, without signalling it first.
    Brutal,
}
117
/// A process that can be started, restarted, and shut down by a [`Supervisor`].
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of this process.
    fn name(&self) -> &str;

    /// Returns how this process should be stopped when the supervisor shuts it down.
    ///
    /// Defaults to a graceful shutdown with a five second timeout.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process and returns the future that runs it.
    ///
    /// This is called each time the supervisor spawns the process, including restarts.
    /// `process_shutdown` is the handle the supervisor triggers when it wants this process to stop.
    ///
    /// # Errors
    ///
    /// Returns an [`InitializationError`] if the process could not be initialized. The supervisor treats
    /// this as fatal: it shuts down its other children and returns an error itself.
    async fn initialize(&self, process_shutdown: ShutdownHandle) -> Result<SupervisorFuture, InitializationError>;
}
144
/// Errors returned by a [`Supervisor`].
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// A supervisor or worker name was rejected.
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The offending name.
        name: String,
    },

    /// The supervisor was run without any registered child processes.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed to initialize.
    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        /// Name of the failing child; for a failure below a nested supervisor this is a `/`-separated path.
        child_name: String,

        /// The underlying initialization error.
        source: InitializationError,
    },

    /// The restart strategy decided that the supervisor must stop instead of restarting again.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))]
    Shutdown,
}
177
/// Specification of a child process to be added to a [`Supervisor`].
///
/// The type parameter selects the kind of child: [`WorkerSpec`] (the default) for workers, or
/// [`SupervisorSpec`] for nested supervisors.
pub struct ChildSpecification<S = WorkerSpec> {
    /// The kind-specific part of the specification.
    spec_inner: S,
}
196
/// Specification state for a worker child.
pub struct WorkerSpec {
    /// The worker to supervise, shared so it can be initialized again on restart.
    worker: Arc<dyn Supervisable>,
    /// Restart type applied to this worker.
    restart_type: RestartType,
}
202
/// Specification state for a nested supervisor child.
pub struct SupervisorSpec {
    /// The supervisor to run as a child.
    supervisor: Supervisor,
}
207
208impl ChildSpecification<WorkerSpec> {
209 pub fn worker<T: Supervisable + 'static>(worker: T) -> Self {
211 Self {
212 spec_inner: WorkerSpec {
213 worker: Arc::new(worker),
214 restart_type: RestartType::Permanent,
215 },
216 }
217 }
218
219 #[must_use]
223 pub fn with_restart_type(mut self, restart_type: RestartType) -> Self {
224 self.spec_inner.restart_type = restart_type;
225 self
226 }
227}
228
229impl<T> From<T> for ChildSpecification<WorkerSpec>
230where
231 T: Supervisable + 'static,
232{
233 fn from(worker: T) -> Self {
234 Self::worker(worker)
235 }
236}
237
238impl From<Supervisor> for ChildSpecification<SupervisorSpec> {
239 fn from(supervisor: Supervisor) -> Self {
240 Self {
241 spec_inner: SupervisorSpec { supervisor },
242 }
243 }
244}
245
/// Private module holding the marker trait used to seal [`ChildState`].
mod sealed {
    /// Marker trait; it lives in a private module, so it cannot be named from outside this module.
    pub trait Sealed {}
}
249
// The two specification states that are allowed to implement `ChildState`.
impl sealed::Sealed for WorkerSpec {}
impl sealed::Sealed for SupervisorSpec {}
252
/// Kind-specific behavior of a [`ChildSpecification`].
///
/// The trait requires `sealed::Sealed` as a supertrait, which is only implemented for [`WorkerSpec`] and
/// [`SupervisorSpec`].
pub trait ChildState: sealed::Sealed + Sized {
    /// Registers the child described by `spec` with `supervisor`.
    #[doc(hidden)]
    fn register(spec: ChildSpecification<Self>, supervisor: &mut Supervisor);
}
262
263impl ChildState for WorkerSpec {
264 fn register(spec: ChildSpecification<Self>, supervisor: &mut Supervisor) {
265 supervisor.push_child(ChildEntry {
266 child: SupervisedChild::Worker(spec.spec_inner.worker),
267 restart: spec.spec_inner.restart_type,
268 });
269 }
270}
271
272impl ChildState for SupervisorSpec {
273 fn register(spec: ChildSpecification<Self>, supervisor: &mut Supervisor) {
274 supervisor.push_child(ChildEntry {
275 child: SupervisedChild::Supervisor(spec.spec_inner.supervisor),
276 restart: RestartType::Permanent,
277 });
278 }
279}
280
/// A child process managed by a supervisor: either a worker or a nested supervisor.
enum SupervisedChild {
    /// A worker, shared so that it can be initialized again on every restart.
    Worker(Arc<dyn Supervisable>),
    /// A nested supervisor.
    Supervisor(Supervisor),
}
290
291impl SupervisedChild {
292 fn process_type(&self) -> &'static str {
293 match self {
294 Self::Worker(_) => "worker",
295 Self::Supervisor(_) => "supervisor",
296 }
297 }
298
299 fn name(&self) -> &str {
300 match self {
301 Self::Worker(worker) => worker.name(),
302 Self::Supervisor(supervisor) => &supervisor.supervisor_id,
303 }
304 }
305
306 fn shutdown_strategy(&self) -> ShutdownStrategy {
307 match self {
308 Self::Worker(worker) => worker.shutdown_strategy(),
309
310 Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
313 }
314 }
315
316 fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
317 match self {
318 Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
319 name: worker.name().to_string(),
320 }),
321 Self::Supervisor(sup) => {
322 Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
323 name: sup.supervisor_id.to_string(),
324 })
325 }
326 }
327 }
328
329 fn create_worker_future(
330 &self, process: Process, process_shutdown: ShutdownHandle,
331 ) -> Result<WorkerFuture, SupervisorError> {
332 match self {
333 Self::Worker(worker) => {
334 let worker = Arc::clone(worker);
335 Ok(Box::pin(async move {
336 let run_future =
337 worker
338 .initialize(process_shutdown)
339 .await
340 .map_err(|source| WorkerError::Initialization {
341 child_name: None,
342 source,
343 })?;
344 run_future.await.map_err(WorkerError::Runtime)
345 }))
346 }
347 Self::Supervisor(sup) => {
348 match sup.runtime_mode() {
349 RuntimeMode::Ambient => {
350 Ok(sup.as_nested_process(process, process_shutdown))
352 }
353 RuntimeMode::Dedicated(config) => {
354 let child_name = sup.supervisor_id.to_string();
357 let dataspace = process.dataspace().clone();
358 let handle =
359 spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown, dataspace)
360 .map_err(|e| SupervisorError::FailedToInitialize {
361 child_name,
362 source: e.into(),
363 })?;
364
365 Ok(Box::pin(async move { handle.await.map_err(WorkerError::from) }))
366 }
367 }
368 }
369 }
370 }
371}
372
373impl Clone for SupervisedChild {
374 fn clone(&self) -> Self {
375 match self {
376 Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
377 Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
378 }
379 }
380}
381
/// A registered child together with its restart type.
#[derive(Clone)]
struct ChildEntry {
    /// The child process definition.
    child: SupervisedChild,
    /// Restart type applied when this child exits.
    restart: RestartType,
}
388
/// Supervises a static set of child processes, restarting them according to a restart strategy.
pub struct Supervisor {
    /// Identifier of this supervisor; also used as its process name.
    supervisor_id: Arc<str>,
    /// Registered children, in registration order. A child's index in this list identifies it at runtime.
    child_specs: Vec<ChildEntry>,
    /// Strategy consulted whenever a child that is eligible for restart exits.
    restart_strategy: RestartStrategy,
    /// Runtime mode of this supervisor: ambient (the default) or dedicated.
    runtime_mode: RuntimeMode,
}
421
422impl Supervisor {
423 pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
425 if supervisor_id.as_ref().is_empty() {
429 return Err(SupervisorError::InvalidName {
430 name: supervisor_id.as_ref().to_string(),
431 });
432 }
433
434 Ok(Self {
435 supervisor_id: supervisor_id.as_ref().into(),
436 child_specs: Vec::new(),
437 restart_strategy: RestartStrategy::default(),
438 runtime_mode: RuntimeMode::default(),
439 })
440 }
441
442 pub fn id(&self) -> &str {
444 &self.supervisor_id
445 }
446
447 pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
449 self.restart_strategy = strategy;
450 self
451 }
452
453 pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
463 self.runtime_mode = RuntimeMode::Dedicated(config);
464 self
465 }
466
467 pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
469 &self.runtime_mode
470 }
471
472 pub fn add_worker<S, T>(&mut self, child: T)
480 where
481 S: ChildState,
482 T: Into<ChildSpecification<S>>,
483 {
484 S::register(child.into(), self);
485 }
486
487 fn push_child(&mut self, entry: ChildEntry) {
488 debug!(
489 supervisor_id = %self.supervisor_id,
490 "Adding new static child process #{}. ({}, {}, {:?})",
491 self.child_specs.len(),
492 entry.child.process_type(),
493 entry.child.name(),
494 entry.restart,
495 );
496 self.child_specs.push(entry);
497 }
498
499 fn get_child_spec(&self, child_spec_idx: usize) -> &SupervisedChild {
500 match self.child_specs.get(child_spec_idx) {
501 Some(entry) => &entry.child,
502 None => unreachable!("child spec index should never be out of bounds"),
503 }
504 }
505
506 fn get_restart_type(&self, child_spec_idx: usize) -> RestartType {
507 match self.child_specs.get(child_spec_idx) {
508 Some(entry) => entry.restart,
509 None => unreachable!("child spec index should never be out of bounds"),
510 }
511 }
512
513 fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
514 let child_spec = self.get_child_spec(child_spec_idx);
515 debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
516 worker_state.add_worker(child_spec_idx, child_spec)
517 }
518
519 fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
520 debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
521 for child_spec_idx in 0..self.child_specs.len() {
522 self.spawn_child(child_spec_idx, worker_state)?;
523 }
524
525 Ok(())
526 }
527
528 fn respawn_children_one_for_all(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
536 debug!(supervisor_id = %self.supervisor_id, "Restarting all eligible static child processes.");
537 for child_spec_idx in 0..self.child_specs.len() {
538 if self.get_restart_type(child_spec_idx) != RestartType::Temporary {
539 self.spawn_child(child_spec_idx, worker_state)?;
540 }
541 }
542
543 Ok(())
544 }
545
    /// Runs the supervision loop until shutdown is requested or an unrecoverable failure occurs.
    ///
    /// All registered children are spawned up front. The loop then waits for either the shutdown signal
    /// or the next child to exit, and decides whether that child (or the whole group) must be restarted.
    ///
    /// # Errors
    ///
    /// - [`SupervisorError::NoChildren`] if no children were registered.
    /// - [`SupervisorError::FailedToInitialize`] if any child fails to initialize.
    /// - [`SupervisorError::Shutdown`] if the restart strategy decides the supervisor must stop.
    /// - Any error returned while spawning or respawning a child.
    async fn run_inner(&self, process: Process, process_shutdown: ShutdownHandle) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        self.spawn_all_children(&mut worker_state)?;

        // Pinned so the same shutdown future can be polled across loop iterations.
        pin!(process_shutdown);

        loop {
            select! {
                _ = &mut process_shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                (child_spec_idx, worker_result) = worker_state.wait_for_next_worker() => {
                    let child_spec = self.get_child_spec(child_spec_idx);

                    // Initialization failures are never retried: stop everything and report the failure.
                    if let Err(WorkerError::Initialization { child_name, source }) = worker_result {
                        // A nested supervisor reports which of its own children failed; prefix that name
                        // with this child's name to form a path.
                        let full_name = match child_name {
                            Some(inner) => format!("{}/{}", child_spec.name(), inner),
                            None => child_spec.name().to_string(),
                        };

                        error!(supervisor_id = %self.supervisor_id, worker_name = full_name, "Child process failed to initialize: {}", source);
                        worker_state.shutdown_workers().await;
                        return Err(SupervisorError::FailedToInitialize {
                            child_name: full_name,
                            source,
                        });
                    }

                    // Any error result counts as an abnormal exit when deciding whether to restart.
                    let abnormal = worker_result.is_err();
                    let restart_type = self.get_restart_type(child_spec_idx);

                    // Only used for logging from here on.
                    let worker_result = worker_result.map_err(|e| match e {
                        WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
                        WorkerError::Initialization { .. } => unreachable!("handled above"),
                    });

                    if !restart_type.should_restart(abnormal) {
                        debug!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?restart_type, ?worker_result, "Child process exited and is not eligible for restart.");
                    } else {
                        // The restart strategy is only consulted for exits that are eligible for restart.
                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    // Stop the surviving children first, then bring the group back up.
                                    worker_state.shutdown_workers().await;
                                    self.respawn_children_one_for_all(&mut worker_state)?;
                                }
                            },
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }
635
636 fn as_nested_process(&self, process: Process, process_shutdown: ShutdownHandle) -> WorkerFuture {
637 debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");
640
641 let sup = self.inner_clone();
643
644 Box::pin(async move {
645 sup.run_inner(process, process_shutdown)
646 .await
647 .map_err(WorkerError::from)
648 })
649 }
650
651 pub async fn run(&mut self) -> Result<(), SupervisorError> {
657 let process_shutdown = ShutdownHandle::noop();
660 let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
661 name: self.supervisor_id.to_string(),
662 })?;
663
664 debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
665 self.run_inner(process.clone(), process_shutdown)
666 .into_process_future(process)
667 .await
668 }
669
670 pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
679 let (shutdown_coordinator, shutdown_handle) = ShutdownHandle::paired();
683 let run = self.run_with_shutdown_inner(shutdown_handle, None);
684 pin!(run, shutdown);
685
686 let mut shutdown_coordinator = Some(shutdown_coordinator);
687 loop {
688 select! {
689 result = &mut run => return result,
690 _ = &mut shutdown, if shutdown_coordinator.is_some() => {
691 shutdown_coordinator.take().expect("coordinator present per select guard").shutdown();
692 }
693 }
694 }
695 }
696
697 pub(crate) async fn run_with_shutdown_inner(
709 &mut self, process_shutdown: ShutdownHandle, dataspace: Option<DataspaceRegistry>,
710 ) -> Result<(), SupervisorError> {
711 let process =
712 Process::supervisor_with_dataspace(&self.supervisor_id, None, dataspace).context(InvalidName {
713 name: self.supervisor_id.to_string(),
714 })?;
715
716 debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
717 self.run_inner(process.clone(), process_shutdown)
718 .into_process_future(process)
719 .await
720 }
721
722 fn inner_clone(&self) -> Self {
723 Self {
727 supervisor_id: Arc::clone(&self.supervisor_id),
728 child_specs: self.child_specs.clone(),
729 restart_strategy: self.restart_strategy,
730 runtime_mode: self.runtime_mode.clone(),
731 }
732 }
733}
734
/// Bookkeeping for one running child process.
struct ProcessState {
    /// Index of the child in the supervisor's list of child specifications.
    worker_id: usize,
    /// How the process should be stopped during shutdown.
    shutdown_strategy: ShutdownStrategy,
    /// Coordinator paired with the shutdown handle that was given to the process.
    shutdown_coordinator: ShutdownCoordinator,
    /// Handle used to forcefully abort the process's task.
    abort_handle: AbortHandle,
}
741
/// Runtime state of a supervisor: its own process plus the tasks of its running children.
struct WorkerState {
    /// The supervisor's own process, used as the parent of every child process.
    process: Process,
    /// Tasks of the currently running children.
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    /// Per-task bookkeeping, keyed by task ID.
    worker_map: FastIndexMap<Id, ProcessState>,
}
747
748impl WorkerState {
749 fn new(process: Process) -> Self {
750 Self {
751 process,
752 worker_tasks: JoinSet::new(),
753 worker_map: FastIndexMap::default(),
754 }
755 }
756
757 fn add_worker(&mut self, worker_id: usize, child_spec: &SupervisedChild) -> Result<(), SupervisorError> {
758 let (shutdown_coordinator, shutdown_handle) = ShutdownHandle::paired();
759 let process = child_spec.create_process(&self.process)?;
760 let worker_future = child_spec.create_worker_future(process.clone(), shutdown_handle)?;
761 let shutdown_strategy = child_spec.shutdown_strategy();
762 let abort_handle = self.worker_tasks.spawn(worker_future.into_process_future(process));
763 self.worker_map.insert(
764 abort_handle.id(),
765 ProcessState {
766 worker_id,
767 shutdown_strategy,
768 shutdown_coordinator,
769 abort_handle,
770 },
771 );
772 Ok(())
773 }
774
775 async fn wait_for_next_worker(&mut self) -> (usize, Result<(), WorkerError>) {
776 debug!("Waiting for next process to complete.");
777
778 if self.worker_tasks.is_empty() {
784 std::future::pending::<()>().await;
785 }
786
787 match self.worker_tasks.join_next_with_id().await {
788 Some(Ok((worker_task_id, worker_result))) => {
789 let process_state = self
790 .worker_map
791 .swap_remove(&worker_task_id)
792 .expect("worker task ID not found");
793 (process_state.worker_id, worker_result)
794 }
795 Some(Err(e)) => {
796 let worker_task_id = e.id();
797 let process_state = self
798 .worker_map
799 .swap_remove(&worker_task_id)
800 .expect("worker task ID not found");
801 let e = if e.is_cancelled() {
802 ProcessError::Aborted
803 } else {
804 ProcessError::Panicked
805 };
806 (process_state.worker_id, Err(WorkerError::Runtime(e.into())))
807 }
808 None => unreachable!(
809 "join set is non-empty here: we park above while empty, and only this method removes workers"
810 ),
811 }
812 }
813
814 async fn shutdown_workers(&mut self) {
815 debug!("Shutting down all processes.");
816
817 while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
827 let ProcessState {
828 worker_id,
829 shutdown_strategy,
830 shutdown_coordinator,
831 abort_handle,
832 } = process_state;
833
834 let shutdown_deadline = match shutdown_strategy {
836 ShutdownStrategy::Graceful(timeout) => {
837 debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
838 shutdown_coordinator.shutdown();
839
840 tokio::time::sleep(timeout)
841 }
842 ShutdownStrategy::Brutal => {
843 debug!(worker_id, "Forcefully aborting process.");
844 abort_handle.abort();
845
846 tokio::time::sleep(Duration::MAX)
849 }
850 };
851 pin!(shutdown_deadline);
852
853 loop {
856 select! {
857 worker_result = self.worker_tasks.join_next_with_id() => {
858 match worker_result {
859 Some(Ok((worker_task_id, _))) => {
860 if worker_task_id == current_worker_task_id {
861 debug!(?worker_task_id, "Target process exited successfully.");
862 break;
863 } else {
864 debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
865 self.worker_map.swap_remove(&worker_task_id);
866 }
867 },
868 Some(Err(e)) => {
869 let worker_task_id = e.id();
870 if worker_task_id == current_worker_task_id {
871 debug!(?worker_task_id, "Target process exited with error.");
872 break;
873 } else {
874 debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
875 self.worker_map.swap_remove(&worker_task_id);
876 }
877 }
878 None => unreachable!("worker task must exist in join set if we are waiting for it"),
879 }
880 },
881 _ = &mut shutdown_deadline => {
883 debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
884 abort_handle.abort();
885 }
886 }
887 }
888 }
889
890 debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
891 debug_assert!(
892 self.worker_tasks.is_empty(),
893 "worker tasks should be empty after shutdown"
894 );
895 }
896}
897
898#[cfg(test)]
899mod tests {
900 use std::sync::atomic::{AtomicUsize, Ordering};
901
902 use async_trait::async_trait;
903 use tokio::{
904 sync::oneshot,
905 task::JoinHandle,
906 time::{sleep, timeout},
907 };
908
909 use super::*;
910
    /// How a [`MockWorker`] behaves while initializing.
    #[derive(Clone)]
    enum InitBehavior {
        /// Initialization succeeds immediately.
        Instant,

        /// Initialization succeeds after sleeping for the given duration.
        Slow(Duration),

        /// Initialization fails with the given message.
        Fail(&'static str),
    }
923
    /// How a [`MockWorker`] behaves once it is running.
    #[derive(Clone)]
    enum RunBehavior {
        /// Runs until shutdown is signalled, then exits cleanly.
        UntilShutdown,

        /// Fails with the given message after the given delay, unless shutdown is signalled first.
        FailAfter(Duration, &'static str),

        /// Exits cleanly after the given delay, or earlier if shutdown is signalled.
        CompleteAfter(Duration),
    }
936
    /// Configurable worker used to exercise the supervisor in tests.
    struct MockWorker {
        /// Name reported to the supervisor.
        name: &'static str,
        /// Behavior during initialization.
        init_behavior: InitBehavior,
        /// Behavior after a successful initialization.
        run_behavior: RunBehavior,
        /// Number of times the worker's run future has started executing.
        start_count: Arc<AtomicUsize>,
    }
944
945 impl MockWorker {
946 fn long_running(name: &'static str) -> Self {
948 Self {
949 name,
950 init_behavior: InitBehavior::Instant,
951 run_behavior: RunBehavior::UntilShutdown,
952 start_count: Arc::new(AtomicUsize::new(0)),
953 }
954 }
955
956 fn failing(name: &'static str, delay: Duration) -> Self {
958 Self {
959 name,
960 init_behavior: InitBehavior::Instant,
961 run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
962 start_count: Arc::new(AtomicUsize::new(0)),
963 }
964 }
965
966 fn completing(name: &'static str, delay: Duration) -> Self {
968 Self {
969 name,
970 init_behavior: InitBehavior::Instant,
971 run_behavior: RunBehavior::CompleteAfter(delay),
972 start_count: Arc::new(AtomicUsize::new(0)),
973 }
974 }
975
976 fn init_failure(name: &'static str) -> Self {
978 Self {
979 name,
980 init_behavior: InitBehavior::Fail("init failed"),
981 run_behavior: RunBehavior::UntilShutdown,
982 start_count: Arc::new(AtomicUsize::new(0)),
983 }
984 }
985
986 fn slow_init(name: &'static str, init_delay: Duration) -> Self {
988 Self {
989 name,
990 init_behavior: InitBehavior::Slow(init_delay),
991 run_behavior: RunBehavior::UntilShutdown,
992 start_count: Arc::new(AtomicUsize::new(0)),
993 }
994 }
995
996 fn start_count(&self) -> Arc<AtomicUsize> {
998 Arc::clone(&self.start_count)
999 }
1000 }
1001
1002 #[async_trait]
1003 impl Supervisable for MockWorker {
1004 fn name(&self) -> &str {
1005 self.name
1006 }
1007
1008 fn shutdown_strategy(&self) -> ShutdownStrategy {
1009 ShutdownStrategy::Graceful(Duration::from_millis(500))
1010 }
1011
1012 async fn initialize(&self, process_shutdown: ShutdownHandle) -> Result<SupervisorFuture, InitializationError> {
1013 match &self.init_behavior {
1014 InitBehavior::Instant => {}
1015 InitBehavior::Slow(delay) => {
1016 sleep(*delay).await;
1017 }
1018 InitBehavior::Fail(msg) => {
1019 return Err(InitializationError::Failed {
1020 source: GenericError::msg(*msg),
1021 });
1022 }
1023 }
1024
1025 let start_count = Arc::clone(&self.start_count);
1026 let run_behavior = self.run_behavior.clone();
1027
1028 Ok(Box::pin(async move {
1029 start_count.fetch_add(1, Ordering::SeqCst);
1030
1031 match run_behavior {
1032 RunBehavior::UntilShutdown => {
1033 process_shutdown.await;
1034 Ok(())
1035 }
1036 RunBehavior::FailAfter(delay, msg) => {
1037 select! {
1038 _ = sleep(delay) => {
1039 Err(GenericError::msg(msg))
1040 }
1041 _ = process_shutdown => {
1042 Ok(())
1043 }
1044 }
1045 }
1046 RunBehavior::CompleteAfter(delay) => {
1047 select! {
1048 _ = sleep(delay) => Ok(()),
1049 _ = process_shutdown => Ok(()),
1050 }
1051 }
1052 }
1053 }))
1054 }
1055 }
1056
1057 async fn run_supervisor_with_trigger(
1060 mut supervisor: Supervisor,
1061 ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
1062 let (tx, rx) = oneshot::channel();
1063 let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });
1064 sleep(Duration::from_millis(50)).await;
1066 (tx, handle)
1067 }
1068
1069 #[tokio::test]
1072 async fn standalone_supervisor_shuts_down_cleanly() {
1073 let mut sup = Supervisor::new("test-sup").unwrap();
1074 sup.add_worker(MockWorker::long_running("worker1"));
1075 sup.add_worker(MockWorker::long_running("worker2"));
1076
1077 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1078 tx.send(()).unwrap();
1079
1080 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1081 assert!(result.is_ok());
1082 }
1083
1084 #[tokio::test]
1085 async fn nested_supervisor_shuts_down_cleanly() {
1086 let mut child_sup = Supervisor::new("child-sup").unwrap();
1087 child_sup.add_worker(MockWorker::long_running("inner-worker"));
1088
1089 let mut parent_sup = Supervisor::new("parent-sup").unwrap();
1090 parent_sup.add_worker(MockWorker::long_running("outer-worker"));
1091 parent_sup.add_worker(child_sup);
1092
1093 let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
1094 tx.send(()).unwrap();
1095
1096 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1097 assert!(result.is_ok());
1098 }
1099
1100 #[tokio::test]
1101 async fn supervisor_with_no_children_returns_error() {
1102 let mut sup = Supervisor::new("empty-sup").unwrap();
1103
1104 let (tx, rx) = oneshot::channel::<()>();
1105 let result = sup.run_with_shutdown(rx).await;
1106 drop(tx);
1107
1108 assert!(matches!(result, Err(SupervisorError::NoChildren)));
1109 }
1110
1111 #[tokio::test]
1114 async fn one_for_one_restarts_only_failed_child() {
1115 let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
1116 let failing_count = failing.start_count();
1117
1118 let stable = MockWorker::long_running("stable-worker");
1119 let stable_count = stable.start_count();
1120
1121 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1122 RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
1123 );
1124 sup.add_worker(stable);
1125 sup.add_worker(failing);
1126
1127 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1128
1129 sleep(Duration::from_millis(300)).await;
1131 let _ = tx.send(());
1132
1133 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1134 assert!(result.is_ok());
1135
1136 assert!(
1138 failing_count.load(Ordering::SeqCst) >= 2,
1139 "failing worker should have been restarted"
1140 );
1141 assert_eq!(
1143 stable_count.load(Ordering::SeqCst),
1144 1,
1145 "stable worker should not have been restarted"
1146 );
1147 }
1148
1149 #[tokio::test]
1150 async fn one_for_all_restarts_all_children() {
1151 let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
1152 let failing_count = failing.start_count();
1153
1154 let stable = MockWorker::long_running("stable-worker");
1155 let stable_count = stable.start_count();
1156
1157 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1158 RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
1159 );
1160 sup.add_worker(stable);
1161 sup.add_worker(failing);
1162
1163 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1164
1165 sleep(Duration::from_millis(300)).await;
1167 let _ = tx.send(());
1168
1169 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1170 assert!(result.is_ok());
1171
1172 assert!(
1174 failing_count.load(Ordering::SeqCst) >= 2,
1175 "failing worker should have been restarted"
1176 );
1177 assert!(
1178 stable_count.load(Ordering::SeqCst) >= 2,
1179 "stable worker should also have been restarted"
1180 );
1181 }
1182
1183 #[tokio::test]
1184 async fn one_for_all_does_not_restart_temporary_children() {
1185 let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
1188 let failing_count = failing.start_count();
1189
1190 let temp = MockWorker::long_running("temp-worker");
1191 let temp_count = temp.start_count();
1192
1193 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1194 RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
1195 );
1196 sup.add_worker(ChildSpecification::worker(temp).with_restart_type(RestartType::Temporary));
1197 sup.add_worker(failing);
1198
1199 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1200
1201 sleep(Duration::from_millis(300)).await;
1203 let _ = tx.send(());
1204
1205 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1206 assert!(result.is_ok());
1207 assert!(
1208 failing_count.load(Ordering::SeqCst) >= 2,
1209 "permanent worker should have been restarted by one-for-all"
1210 );
1211 assert_eq!(
1212 temp_count.load(Ordering::SeqCst),
1213 1,
1214 "temporary child must not be restarted by a one-for-all group restart"
1215 );
1216 }
1217
1218 #[tokio::test]
1219 async fn one_for_all_restarts_transient_children() {
1220 let transient = MockWorker::completing("transient-worker", Duration::from_millis(30));
1223 let transient_count = transient.start_count();
1224
1225 let failing = MockWorker::failing("failing-worker", Duration::from_millis(80));
1227
1228 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1229 RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
1230 );
1231 sup.add_worker(ChildSpecification::worker(transient).with_restart_type(RestartType::Transient));
1232 sup.add_worker(failing);
1233
1234 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1235
1236 sleep(Duration::from_millis(300)).await;
1237 let _ = tx.send(());
1238
1239 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1240 assert!(result.is_ok());
1241 assert!(
1242 transient_count.load(Ordering::SeqCst) >= 2,
1243 "transient child must be restarted by a one-for-all group restart, even after a clean exit"
1244 );
1245 }
1246
1247 #[tokio::test]
1248 async fn restart_limit_exceeded_shuts_down_supervisor() {
1249 let mut sup = Supervisor::new("test-sup")
1250 .unwrap()
1251 .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
1252 sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));
1254
1255 let (tx, rx) = oneshot::channel::<()>();
1256 let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });
1257
1258 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1259 drop(tx);
1260
1261 assert!(matches!(result, Err(SupervisorError::Shutdown)));
1262 }
1263
1264 #[tokio::test]
1267 async fn temporary_child_is_not_restarted() {
1268 let temp = MockWorker::failing("temp-worker", Duration::from_millis(50));
1270 let temp_count = temp.start_count();
1271
1272 let stable = MockWorker::long_running("stable-worker");
1273
1274 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1275 RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
1276 );
1277 sup.add_worker(stable);
1278 sup.add_worker(ChildSpecification::worker(temp).with_restart_type(RestartType::Temporary));
1279
1280 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1281
1282 sleep(Duration::from_millis(300)).await;
1284 let _ = tx.send(());
1285
1286 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1287 assert!(result.is_ok());
1288 assert_eq!(
1289 temp_count.load(Ordering::SeqCst),
1290 1,
1291 "temporary worker must not be restarted"
1292 );
1293 }
1294
1295 #[tokio::test]
1296 async fn transient_child_is_not_restarted_on_clean_exit() {
1297 let transient = MockWorker::completing("transient-worker", Duration::from_millis(50));
1298 let transient_count = transient.start_count();
1299
1300 let stable = MockWorker::long_running("stable-worker");
1301
1302 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1303 RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
1304 );
1305 sup.add_worker(stable);
1306 sup.add_worker(ChildSpecification::worker(transient).with_restart_type(RestartType::Transient));
1307
1308 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1309
1310 sleep(Duration::from_millis(300)).await;
1311 let _ = tx.send(());
1312
1313 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1314 assert!(result.is_ok());
1315 assert_eq!(
1316 transient_count.load(Ordering::SeqCst),
1317 1,
1318 "transient worker must not be restarted after a clean exit"
1319 );
1320 }
1321
1322 #[tokio::test]
1323 async fn transient_child_is_restarted_on_failure() {
1324 let transient = MockWorker::failing("transient-worker", Duration::from_millis(50));
1325 let transient_count = transient.start_count();
1326
1327 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1328 RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
1329 );
1330 sup.add_worker(ChildSpecification::worker(transient).with_restart_type(RestartType::Transient));
1331
1332 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1333
1334 sleep(Duration::from_millis(300)).await;
1335 let _ = tx.send(());
1336
1337 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1338 assert!(result.is_ok());
1339 assert!(
1340 transient_count.load(Ordering::SeqCst) >= 2,
1341 "transient worker must be restarted after an abnormal exit"
1342 );
1343 }
1344
1345 #[tokio::test]
1346 async fn permanent_child_is_restarted_on_clean_exit() {
1347 let permanent = MockWorker::completing("permanent-worker", Duration::from_millis(50));
1350 let permanent_count = permanent.start_count();
1351
1352 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1353 RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
1354 );
1355 sup.add_worker(permanent);
1357
1358 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1359
1360 sleep(Duration::from_millis(300)).await;
1361 let _ = tx.send(());
1362
1363 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1364 assert!(result.is_ok());
1365 assert!(
1366 permanent_count.load(Ordering::SeqCst) >= 2,
1367 "permanent worker must be restarted even after a clean exit"
1368 );
1369 }
1370
1371 #[tokio::test]
1372 async fn temporary_failures_do_not_consume_restart_intensity() {
1373 let mut sup = Supervisor::new("test-sup")
1377 .unwrap()
1378 .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
1379
1380 let workers = [
1381 MockWorker::failing("temp-0", Duration::from_millis(20)),
1382 MockWorker::failing("temp-1", Duration::from_millis(20)),
1383 MockWorker::failing("temp-2", Duration::from_millis(20)),
1384 MockWorker::failing("temp-3", Duration::from_millis(20)),
1385 MockWorker::failing("temp-4", Duration::from_millis(20)),
1386 ];
1387 let counts: Vec<_> = workers.iter().map(|w| w.start_count()).collect();
1388 for worker in workers {
1389 sup.add_worker(ChildSpecification::worker(worker).with_restart_type(RestartType::Temporary));
1390 }
1391 sup.add_worker(MockWorker::long_running("stable-worker"));
1393
1394 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1395 sleep(Duration::from_millis(300)).await;
1396 let _ = tx.send(());
1397
1398 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1399 assert!(
1400 result.is_ok(),
1401 "supervisor must not trip its restart limit on temporary exits"
1402 );
1403 for count in counts {
1404 assert_eq!(
1405 count.load(Ordering::SeqCst),
1406 1,
1407 "each temporary worker runs exactly once"
1408 );
1409 }
1410 }
1411
1412 #[tokio::test]
1413 async fn supervisor_idles_when_all_temporary_children_exit() {
1414 let mut sup = Supervisor::new("test-sup").unwrap();
1417 sup.add_worker(
1418 ChildSpecification::worker(MockWorker::completing("temp-a", Duration::from_millis(30)))
1419 .with_restart_type(RestartType::Temporary),
1420 );
1421 sup.add_worker(
1422 ChildSpecification::worker(MockWorker::completing("temp-b", Duration::from_millis(30)))
1423 .with_restart_type(RestartType::Temporary),
1424 );
1425
1426 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1427
1428 sleep(Duration::from_millis(200)).await;
1430 assert!(
1431 !handle.is_finished(),
1432 "supervisor must keep running after all children exit"
1433 );
1434
1435 let _ = tx.send(());
1436 let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
1437 assert!(result.is_ok());
1438 }
1439
1440 #[tokio::test]
1443 async fn init_failure_propagates_with_child_name() {
1444 let mut sup = Supervisor::new("test-sup").unwrap();
1445 sup.add_worker(MockWorker::long_running("good-worker"));
1446 sup.add_worker(MockWorker::init_failure("bad-worker"));
1447
1448 let (_tx, rx) = oneshot::channel::<()>();
1449 let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
1450 .await
1451 .unwrap();
1452
1453 match result {
1454 Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
1455 assert_eq!(child_name, "bad-worker");
1456 }
1457 other => panic!("expected FailedToInitialize, got: {:?}", other),
1458 }
1459 }
1460
1461 #[tokio::test]
1462 async fn init_failure_does_not_trigger_restart() {
1463 let init_fail = MockWorker::init_failure("bad-worker");
1464 let start_count = init_fail.start_count();
1465
1466 let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
1467 RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
1468 );
1469 sup.add_worker(init_fail);
1470
1471 let (_tx, rx) = oneshot::channel::<()>();
1472 let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
1473 .await
1474 .unwrap();
1475
1476 assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));
1477 assert_eq!(start_count.load(Ordering::SeqCst), 0);
1479 }
1480
1481 #[tokio::test]
1484 async fn shutdown_completes_promptly_in_steady_state() {
1485 let mut sup = Supervisor::new("test-sup").unwrap();
1486 sup.add_worker(MockWorker::long_running("worker1"));
1487 sup.add_worker(MockWorker::long_running("worker2"));
1488
1489 let (tx, handle) = run_supervisor_with_trigger(sup).await;
1490 tx.send(()).unwrap();
1491
1492 let result = timeout(Duration::from_secs(1), handle).await;
1494 assert!(result.is_ok(), "shutdown should complete promptly");
1495 }
1496
1497 #[tokio::test]
1498 async fn shutdown_during_slow_init_completes_promptly() {
1499 let mut sup = Supervisor::new("test-sup").unwrap();
1500 sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));
1502
1503 let (tx, rx) = oneshot::channel();
1504 let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });
1505
1506 sleep(Duration::from_millis(20)).await;
1508 tx.send(()).unwrap();
1509
1510 let result = timeout(Duration::from_secs(2), handle).await;
1513 assert!(result.is_ok(), "shutdown during slow init should complete promptly");
1514 }
1515}