1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use saluki_common::collections::FastIndexMap;
5use saluki_error::{ErrorContext as _, GenericError};
6use snafu::{OptionExt as _, Snafu};
7use tokio::{
8 pin, select,
9 task::{AbortHandle, Id, JoinSet},
10};
11use tracing::{debug, error, warn};
12
13use super::{
14 dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
15 restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
16 shutdown::{ProcessShutdown, ShutdownHandle},
17};
18use crate::runtime::process::{Process, ProcessExt as _};
19
/// Future type returned by [`Supervisable::initialize`], representing the running body of a
/// supervised process. Resolving to `Err` indicates a runtime failure of the process.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;

// Internal future type used to drive a single child (worker or nested supervisor) to
// completion, keeping initialization failures distinct from runtime failures.
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;
29
// Internal error type distinguishing the phase in which a child failed.
#[derive(Debug)]
enum WorkerError {
    // The child failed during `Supervisable::initialize`, before its run future was produced.
    // Initialization failures are fatal to the supervisor and are not subject to restarts.
    Initialization(InitializationError),

    // The child failed (or was aborted / panicked) after successful initialization.
    Runtime(GenericError),
}
42
/// Errors describing how a supervised child process terminated.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child's task was cancelled (aborted) by the supervisor.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child's task panicked while running.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child ran to completion but returned an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        source: GenericError,
    },
}
61
/// Errors that can occur while a process is initializing.
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    /// Initialization failed with an underlying error.
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        source: GenericError,
    },

    /// The process can never be initialized (e.g. a required capability is absent),
    /// so retrying would be pointless.
    #[snafu(display("Process is permanently unavailable"))]
    PermanentlyUnavailable,
}
83
/// How a supervisor should shut down a child process.
pub enum ShutdownStrategy {
    /// Signal the child to shut down and wait up to the given duration before
    /// forcefully aborting it.
    Graceful(Duration),

    /// Abort the child's task immediately, with no shutdown signal.
    Brutal,
}
92
/// A process that can be run under a [`Supervisor`].
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of this process, used for the child's `Process` identity and logging.
    fn name(&self) -> &str;

    /// Returns the shutdown strategy for this process.
    ///
    /// Defaults to a graceful shutdown with a five-second deadline.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process and returns the future that runs it to completion.
    ///
    /// The returned future should observe `process_shutdown` and exit cleanly when it fires.
    ///
    /// # Errors
    ///
    /// Returns an [`InitializationError`] if the process cannot be initialized; the
    /// supervisor treats this as fatal rather than restartable.
    async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError>;
}
115
/// Errors returned when building or running a [`Supervisor`].
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// The supervisor or a child was given a name that is not a valid process name.
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        name: String,
    },

    /// The supervisor was started without any child specifications.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed during initialization; this is fatal and bypasses
    /// the restart strategy.
    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        child_name: String,

        source: InitializationError,
    },

    /// The restart strategy's intensity/period limits were exceeded.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))]
    Shutdown,
}
148
/// Specification of a child that a [`Supervisor`] can run: either a leaf worker or a
/// nested supervisor (forming a supervision tree).
pub enum ChildSpecification {
    /// A leaf worker process.
    Worker(Arc<dyn Supervisable>),
    /// A nested supervisor, run as a child of the parent supervisor.
    Supervisor(Supervisor),
}
161
impl ChildSpecification {
    /// Returns a human-readable label for the kind of child this specification describes.
    fn process_type(&self) -> &'static str {
        match self {
            Self::Worker(_) => "worker",
            Self::Supervisor(_) => "supervisor",
        }
    }

    /// Returns the name of the child: the worker's own name, or the nested supervisor's ID.
    fn name(&self) -> &str {
        match self {
            Self::Worker(worker) => worker.name(),
            Self::Supervisor(supervisor) => &supervisor.supervisor_id,
        }
    }

    /// Returns the shutdown strategy to apply when tearing this child down.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        match self {
            Self::Worker(worker) => worker.shutdown_strategy(),

            // Nested supervisors are always shut down gracefully with an effectively unbounded
            // deadline, deferring to the nested supervisor's own per-child shutdown strategies.
            Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
        }
    }

    /// Creates the `Process` handle for this child, parented to `parent_process`.
    ///
    /// # Errors
    ///
    /// Returns `SupervisorError::InvalidName` if the child's name is rejected by `Process`.
    fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
        match self {
            Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
                name: worker.name().to_string(),
            }),
            Self::Supervisor(sup) => {
                Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
                    name: sup.supervisor_id.to_string(),
                })
            }
        }
    }

    /// Builds the future that drives this child to completion.
    ///
    /// Workers are initialized lazily inside the returned future, so their initialization
    /// errors surface as `WorkerError::Initialization` when the future is polled. Nested
    /// supervisors either run inline on the ambient runtime or are spawned onto a dedicated
    /// runtime, depending on their configured runtime mode.
    fn create_worker_future(
        &self, process: Process, process_shutdown: ProcessShutdown,
    ) -> Result<WorkerFuture, SupervisorError> {
        match self {
            Self::Worker(worker) => {
                let worker = Arc::clone(worker);
                Ok(Box::pin(async move {
                    let run_future = worker
                        .initialize(process_shutdown)
                        .await
                        .map_err(WorkerError::Initialization)?;
                    run_future.await.map_err(WorkerError::Runtime)
                }))
            }
            Self::Supervisor(sup) => {
                match sup.runtime_mode() {
                    RuntimeMode::Ambient => {
                        Ok(sup.as_nested_process(process, process_shutdown))
                    }
                    RuntimeMode::Dedicated(config) => {
                        // Spawning the dedicated runtime happens eagerly here, so a spawn
                        // failure is reported immediately as an initialization error rather
                        // than through the returned future.
                        let child_name = sup.supervisor_id.to_string();
                        let handle = spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown)
                            .map_err(|e| SupervisorError::FailedToInitialize {
                                child_name,
                                source: InitializationError::Failed { source: e },
                            })?;

                        Ok(Box::pin(async move { handle.await.map_err(WorkerError::Runtime) }))
                    }
                }
            }
        }
    }
}
236
237impl Clone for ChildSpecification {
238 fn clone(&self) -> Self {
239 match self {
240 Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
241 Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
242 }
243 }
244}
245
246impl From<Supervisor> for ChildSpecification {
247 fn from(supervisor: Supervisor) -> Self {
248 Self::Supervisor(supervisor)
249 }
250}
251
252impl<T> From<T> for ChildSpecification
253where
254 T: Supervisable + 'static,
255{
256 fn from(worker: T) -> Self {
257 Self::Worker(Arc::new(worker))
258 }
259}
260
/// A supervisor that runs a set of child processes (workers and/or nested supervisors),
/// restarting them according to a configurable restart strategy.
pub struct Supervisor {
    /// Identifier for this supervisor; shared cheaply across clones.
    supervisor_id: Arc<str>,
    /// Static child specifications, in the order they were added.
    child_specs: Vec<ChildSpecification>,
    /// Policy governing whether/how children are restarted when they exit.
    restart_strategy: RestartStrategy,
    /// Whether this supervisor (when nested) runs on the ambient or a dedicated runtime.
    runtime_mode: RuntimeMode,
}
294
295impl Supervisor {
296 pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
298 if supervisor_id.as_ref().is_empty() {
302 return Err(SupervisorError::InvalidName {
303 name: supervisor_id.as_ref().to_string(),
304 });
305 }
306
307 Ok(Self {
308 supervisor_id: supervisor_id.as_ref().into(),
309 child_specs: Vec::new(),
310 restart_strategy: RestartStrategy::default(),
311 runtime_mode: RuntimeMode::default(),
312 })
313 }
314
315 pub fn id(&self) -> &str {
317 &self.supervisor_id
318 }
319
320 pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
322 self.restart_strategy = strategy;
323 self
324 }
325
326 pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
336 self.runtime_mode = RuntimeMode::Dedicated(config);
337 self
338 }
339
340 pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
342 &self.runtime_mode
343 }
344
345 pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
350 let child_spec = process.into();
351 debug!(
352 supervisor_id = %self.supervisor_id,
353 "Adding new static child process #{}. ({}, {})",
354 self.child_specs.len(),
355 child_spec.process_type(),
356 child_spec.name(),
357 );
358 self.child_specs.push(child_spec);
359 }
360
361 fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
362 match self.child_specs.get(child_spec_idx) {
363 Some(child_spec) => child_spec,
364 None => unreachable!("child spec index should never be out of bounds"),
365 }
366 }
367
368 fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
369 let child_spec = self.get_child_spec(child_spec_idx);
370 debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
371 worker_state.add_worker(child_spec_idx, child_spec)
372 }
373
374 fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
375 debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
376 for child_spec_idx in 0..self.child_specs.len() {
377 self.spawn_child(child_spec_idx, worker_state)?;
378 }
379
380 Ok(())
381 }
382
383 async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
384 if self.child_specs.is_empty() {
385 return Err(SupervisorError::NoChildren);
386 }
387
388 let mut restart_state = RestartState::new(self.restart_strategy);
389 let mut worker_state = WorkerState::new(process);
390
391 self.spawn_all_children(&mut worker_state)?;
394
395 let shutdown = process_shutdown.wait_for_shutdown();
397 pin!(shutdown);
398
399 loop {
400 select! {
401 _ = &mut shutdown => {
405 debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
406 worker_state.shutdown_workers().await;
407 break;
408 },
409 worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
410 Some((child_spec_idx, worker_result)) => {
416 let child_spec = self.get_child_spec(child_spec_idx);
417
418 if let Err(WorkerError::Initialization(e)) = worker_result {
420 error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), "Child process failed to initialize: {}", e);
421 worker_state.shutdown_workers().await;
422 return Err(SupervisorError::FailedToInitialize {
423 child_name: child_spec.name().to_string(),
424 source: e,
425 });
426 }
427
428 let worker_result = worker_result
430 .map_err(|e| match e {
431 WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
432 WorkerError::Initialization(_) => unreachable!("handled above"),
433 });
434
435 match restart_state.evaluate_restart() {
436 RestartAction::Restart(mode) => match mode {
437 RestartMode::OneForOne => {
438 warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
439 self.spawn_child(child_spec_idx, &mut worker_state)?;
440 }
441 RestartMode::OneForAll => {
442 warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
443 worker_state.shutdown_workers().await;
444 self.spawn_all_children(&mut worker_state)?;
445 }
446 },
447 RestartAction::Shutdown => {
448 error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
449 worker_state.shutdown_workers().await;
450 return Err(SupervisorError::Shutdown);
451 }
452 }
453 },
454 None => unreachable!("should not have empty worker joinset prior to shutdown"),
455 }
456 }
457 }
458
459 Ok(())
460 }
461
462 fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
463 debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");
466
467 let sup = self.inner_clone();
469
470 Box::pin(async move {
471 sup.run_inner(process, process_shutdown)
472 .await
473 .error_context("Nested supervisor failed to exit cleanly.")
474 .map_err(WorkerError::Runtime)
475 })
476 }
477
478 pub async fn run(&mut self) -> Result<(), SupervisorError> {
484 let process_shutdown = ProcessShutdown::noop();
487 let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
488 name: self.supervisor_id.to_string(),
489 })?;
490
491 debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
492 self.run_inner(process.clone(), process_shutdown)
493 .into_instrumented(process)
494 .await
495 }
496
497 pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
506 let process_shutdown = ProcessShutdown::wrapped(shutdown);
507 self.run_with_process_shutdown(process_shutdown).await
508 }
509
510 pub(crate) async fn run_with_process_shutdown(
519 &mut self, process_shutdown: ProcessShutdown,
520 ) -> Result<(), SupervisorError> {
521 let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
522 name: self.supervisor_id.to_string(),
523 })?;
524
525 debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
526 self.run_inner(process.clone(), process_shutdown)
527 .into_instrumented(process)
528 .await
529 }
530
531 fn inner_clone(&self) -> Self {
532 Self {
536 supervisor_id: Arc::clone(&self.supervisor_id),
537 child_specs: self.child_specs.clone(),
538 restart_strategy: self.restart_strategy,
539 runtime_mode: self.runtime_mode.clone(),
540 }
541 }
542}
543
// Per-child bookkeeping held by `WorkerState` for the lifetime of a spawned task.
struct ProcessState {
    /// Index into the supervisor's child spec list for the child backing this task.
    worker_id: usize,
    /// Shutdown strategy captured from the child spec at spawn time.
    shutdown_strategy: ShutdownStrategy,
    /// Handle used to trigger a graceful shutdown of this child.
    shutdown_handle: ShutdownHandle,
    /// Handle used to forcefully abort this child's task.
    abort_handle: AbortHandle,
}
550
// Tracks the set of currently-running child tasks for a supervisor.
struct WorkerState {
    /// The supervisor's own process handle; parent of every spawned child process.
    process: Process,
    /// Join set holding every in-flight child task.
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    /// Maps each task's ID to its bookkeeping state, keyed for join-result lookup.
    worker_map: FastIndexMap<Id, ProcessState>,
}
556
impl WorkerState {
    /// Creates an empty worker state rooted at the given supervisor process.
    fn new(process: Process) -> Self {
        Self {
            process,
            worker_tasks: JoinSet::new(),
            worker_map: FastIndexMap::default(),
        }
    }

    /// Spawns the given child specification as a task and tracks it for later
    /// identification and shutdown.
    ///
    /// `worker_id` is the child's index into the supervisor's spec list; it is returned
    /// from `wait_for_next_worker` to identify which child finished.
    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
        let process = child_spec.create_process(&self.process)?;
        let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
        let shutdown_strategy = child_spec.shutdown_strategy();
        let abort_handle = self.worker_tasks.spawn(worker_future.into_instrumented(process));
        // Keyed by task ID so join results can be mapped back to the originating child.
        self.worker_map.insert(
            abort_handle.id(),
            ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            },
        );
        Ok(())
    }

    /// Waits for the next child task to complete, returning its spec index and result.
    ///
    /// Join-level failures are normalized into `WorkerError::Runtime`: a cancelled task
    /// becomes `ProcessError::Aborted`, a panicked one `ProcessError::Panicked`. Returns
    /// `None` only when the join set is empty.
    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
        debug!("Waiting for next process to complete.");

        match self.worker_tasks.join_next_with_id().await {
            Some(Ok((worker_task_id, worker_result))) => {
                // Task completed: drop its bookkeeping and surface the result.
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                Some((process_state.worker_id, worker_result))
            }
            Some(Err(e)) => {
                let worker_task_id = e.id();
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                let e = if e.is_cancelled() {
                    ProcessError::Aborted
                } else {
                    ProcessError::Panicked
                };
                Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
            }
            None => None,
        }
    }

    /// Shuts down all tracked child tasks, honoring each child's shutdown strategy.
    ///
    /// Children are torn down one at a time (via `pop`): graceful children get a shutdown
    /// trigger plus a deadline after which they are aborted; brutal children are aborted
    /// immediately. While waiting for the target child's join result, results from other
    /// children that happen to exit are drained and discarded.
    async fn shutdown_workers(&mut self) {
        debug!("Shutting down all processes.");

        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
            let ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            } = process_state;

            let shutdown_deadline = match shutdown_strategy {
                ShutdownStrategy::Graceful(timeout) => {
                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
                    shutdown_handle.trigger();

                    tokio::time::sleep(timeout)
                }
                ShutdownStrategy::Brutal => {
                    debug!(worker_id, "Forcefully aborting process.");
                    abort_handle.abort();

                    // Effectively "no deadline": the task was already aborted, so its join
                    // result should arrive promptly. NOTE(review): assumes tokio clamps
                    // very long sleep durations rather than panicking — confirm against
                    // the pinned tokio version.
                    tokio::time::sleep(Duration::MAX)
                }
            };
            pin!(shutdown_deadline);

            loop {
                select! {
                    worker_result = self.worker_tasks.join_next_with_id() => {
                        match worker_result {
                            Some(Ok((worker_task_id, _))) => {
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited successfully.");
                                    break;
                                } else {
                                    // A different child exited while we waited; drop its
                                    // bookkeeping so it isn't shut down again later.
                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            },
                            Some(Err(e)) => {
                                let worker_task_id = e.id();
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited with error.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            }
                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
                        }
                    },
                    _ = &mut shutdown_deadline => {
                        // Deadline expired: force-abort and keep looping until the target's
                        // join result arrives. NOTE(review): the completed sleep may be polled
                        // again on later iterations; `abort()` is idempotent, but confirm that
                        // tokio's `Sleep` tolerates being polled after completion.
                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
                        abort_handle.abort();
                    }
                }
            }
        }

        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
        debug_assert!(
            self.worker_tasks.is_empty(),
            "worker tasks should be empty after shutdown"
        );
    }
}
695
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;
    use tokio::{
        sync::oneshot,
        task::JoinHandle,
        time::{sleep, timeout},
    };

    use super::*;

    // How a `MockWorker` behaves during `initialize`.
    #[derive(Clone)]
    enum InitBehavior {
        // Initialization succeeds immediately.
        Instant,

        // Initialization sleeps for the given duration before succeeding.
        Slow(Duration),

        // Initialization fails with the given message.
        Fail(&'static str),
    }

    // How a `MockWorker`'s run future behaves after successful initialization.
    #[derive(Clone)]
    enum RunBehavior {
        // Runs cleanly until shutdown is triggered.
        UntilShutdown,

        // Fails with the given message after the given delay, unless shutdown fires first.
        FailAfter(Duration, &'static str),
    }

    // Configurable `Supervisable` used to exercise supervisor behavior; counts how many
    // times its run future actually started (i.e. restarts are observable).
    struct MockWorker {
        name: &'static str,
        init_behavior: InitBehavior,
        run_behavior: RunBehavior,
        start_count: Arc<AtomicUsize>,
    }

    impl MockWorker {
        // Worker that initializes instantly and runs until shutdown.
        fn long_running(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        // Worker that initializes instantly and fails after `delay`.
        fn failing(name: &'static str, delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        // Worker whose initialization always fails.
        fn init_failure(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Fail("init failed"),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        // Worker whose initialization takes `init_delay` before it starts running.
        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Slow(init_delay),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        // Shared counter of how many times this worker's run future started.
        fn start_count(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.start_count)
        }
    }

    #[async_trait]
    impl Supervisable for MockWorker {
        fn name(&self) -> &str {
            self.name
        }

        // Short graceful window keeps shutdown-path tests fast.
        fn shutdown_strategy(&self) -> ShutdownStrategy {
            ShutdownStrategy::Graceful(Duration::from_millis(500))
        }

        async fn initialize(
            &self, mut process_shutdown: ProcessShutdown,
        ) -> Result<SupervisorFuture, InitializationError> {
            match &self.init_behavior {
                InitBehavior::Instant => {}
                InitBehavior::Slow(delay) => {
                    sleep(*delay).await;
                }
                InitBehavior::Fail(msg) => {
                    return Err(InitializationError::Failed {
                        source: GenericError::msg(*msg),
                    });
                }
            }

            let start_count = Arc::clone(&self.start_count);
            let run_behavior = self.run_behavior.clone();

            Ok(Box::pin(async move {
                // Count starts inside the run future so only post-init starts are counted.
                start_count.fetch_add(1, Ordering::SeqCst);

                match run_behavior {
                    RunBehavior::UntilShutdown => {
                        process_shutdown.wait_for_shutdown().await;
                        Ok(())
                    }
                    RunBehavior::FailAfter(delay, msg) => {
                        select! {
                            _ = sleep(delay) => {
                                Err(GenericError::msg(msg))
                            }
                            _ = process_shutdown.wait_for_shutdown() => {
                                Ok(())
                            }
                        }
                    }
                }
            }))
        }
    }

    // Spawns the supervisor on a task, waits briefly so children get going, and returns
    // the shutdown trigger plus the supervisor's join handle.
    async fn run_supervisor_with_trigger(
        mut supervisor: Supervisor,
    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });
        sleep(Duration::from_millis(50)).await;
        (tx, handle)
    }

    #[tokio::test]
    async fn standalone_supervisor_shuts_down_cleanly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn nested_supervisor_shuts_down_cleanly() {
        let mut child_sup = Supervisor::new("child-sup").unwrap();
        child_sup.add_worker(MockWorker::long_running("inner-worker"));

        let mut parent_sup = Supervisor::new("parent-sup").unwrap();
        parent_sup.add_worker(MockWorker::long_running("outer-worker"));
        parent_sup.add_worker(child_sup);

        let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn supervisor_with_no_children_returns_error() {
        let mut sup = Supervisor::new("empty-sup").unwrap();

        let (tx, rx) = oneshot::channel::<()>();
        let result = sup.run_with_shutdown(rx).await;
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::NoChildren)));
    }

    #[tokio::test]
    async fn one_for_one_restarts_only_failed_child() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        // Generous restart budget so limits are not hit during the test window.
        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Let the failing worker fail and restart at least once.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert_eq!(
            stable_count.load(Ordering::SeqCst),
            1,
            "stable worker should not have been restarted"
        );
    }

    #[tokio::test]
    async fn one_for_all_restarts_all_children() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert!(
            stable_count.load(Ordering::SeqCst) >= 2,
            "stable worker should also have been restarted"
        );
    }

    #[tokio::test]
    async fn restart_limit_exceeded_shuts_down_supervisor() {
        // One allowed restart in the window; the instantly-failing worker blows through it.
        let mut sup = Supervisor::new("test-sup")
            .unwrap()
            .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
        sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));

        let (tx, rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::Shutdown)));
    }

    #[tokio::test]
    async fn init_failure_propagates_with_child_name() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("good-worker"));
        sup.add_worker(MockWorker::init_failure("bad-worker"));

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        match result {
            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
                assert_eq!(child_name, "bad-worker");
            }
            other => panic!("expected FailedToInitialize, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn init_failure_does_not_trigger_restart() {
        let init_fail = MockWorker::init_failure("bad-worker");
        let start_count = init_fail.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
        );
        sup.add_worker(init_fail);

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));
        // The run future never started: init failures bypass the restart strategy.
        assert_eq!(start_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn shutdown_completes_promptly_in_steady_state() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(1), handle).await;
        assert!(result.is_ok(), "shutdown should complete promptly");
    }

    #[tokio::test]
    async fn shutdown_during_slow_init_completes_promptly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));

        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        // Trigger shutdown while the worker is still inside `initialize`.
        sleep(Duration::from_millis(20)).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await;
        assert!(result.is_ok(), "shutdown during slow init should complete promptly");
    }
}