use std::{future::Future, pin::Pin, sync::Arc, time::Duration};

use async_trait::async_trait;
use saluki_common::collections::FastIndexMap;
use saluki_error::GenericError;
use snafu::{OptionExt as _, Snafu};
use tokio::{
    pin, select,
    task::{AbortHandle, Id, JoinSet},
};
use tracing::{debug, error, warn};

use super::{
    dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
    restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
    shutdown::{ProcessShutdown, ShutdownHandle},
};
use crate::runtime::{
    process::{Process, ProcessExt as _},
    state::DataspaceRegistry,
};

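/// The future returned by [`Supervisable::initialize`], representing a supervised process's main run loop.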
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;

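/// A boxed future for a spawned child task, yielding a worker-specific error.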
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;

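/// An error raised by a running worker task.
///
/// Initialization failures are treated as fatal by the owning supervisor, while
/// runtime failures are subject to the supervisor's restart strategy.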
#[derive(Debug)]
enum WorkerError {
    Initialization {
        /// The name of the failing child as reported by a nested supervisor, if any.
        child_name: Option<String>,
        source: InitializationError,
    },

    Runtime(GenericError),
}

impl From<SupervisorError> for WorkerError {
    fn from(err: SupervisorError) -> Self {
        match err {
            SupervisorError::FailedToInitialize { child_name, source } => WorkerError::Initialization {
                child_name: Some(child_name),
                source,
            },
            other => WorkerError::Runtime(other.into()),
        }
    }
}

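/// An error describing why a child process terminated.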
#[derive(Debug, Snafu)]
pub enum ProcessError {
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    #[snafu(display("Child process panicked."))]
    Panicked,

    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error the process terminated with.
        source: GenericError,
    },
}

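/// An error returned when a process fails to initialize.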
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        source: GenericError,
    },
}

impl From<GenericError> for InitializationError {
    fn from(source: GenericError) -> Self {
        Self::Failed { source }
    }
}

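/// How a child process should be stopped when its supervisor shuts it down.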
pub enum ShutdownStrategy {
    /// Trigger a graceful shutdown and wait up to the given duration before
    /// forcefully aborting the process.
    Graceful(Duration),

    /// Abort the process immediately, with no graceful shutdown phase.
    Brutal,
}

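/// A process that can be run and supervised.
///
/// Implementors expose a name, an optional [`ShutdownStrategy`], and an
/// `initialize` step that performs any up-front setup and returns the
/// process's long-running run future.
///
/// # Example
///
/// A minimal sketch of an implementation; the `Heartbeat` type here is
/// hypothetical and exists only for illustration:
///
/// ```ignore
/// struct Heartbeat;
///
/// #[async_trait]
/// impl Supervisable for Heartbeat {
///     fn name(&self) -> &str {
///         "heartbeat"
///     }
///
///     async fn initialize(
///         &self, mut process_shutdown: ProcessShutdown,
///     ) -> Result<SupervisorFuture, InitializationError> {
///         // Do any fallible setup here, before returning the run future.
///         Ok(Box::pin(async move {
///             // Run until the supervisor asks us to shut down.
///             process_shutdown.wait_for_shutdown().await;
///             Ok(())
///         }))
///     }
/// }
/// ```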
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of the process.
    fn name(&self) -> &str;

    /// Returns the shutdown strategy to use when stopping this process.
    ///
    /// Defaults to a graceful shutdown with a five second timeout.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process, returning the future that drives its main run loop.
    ///
    /// The returned future is polled by the supervisor until it completes, and
    /// should respect `process_shutdown` so the process can exit cleanly when
    /// shutdown is triggered.
    ///
    /// # Errors
    ///
    /// If the process fails to initialize, an error is returned. The owning
    /// supervisor treats initialization errors as fatal and does not retry them.
    async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError>;
}

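/// An error related to building or running a supervisor.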
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        name: String,
    },

    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        /// The name of the child process that failed to initialize.
        child_name: String,

        /// The underlying initialization error.
        source: InitializationError,
    },

    #[snafu(display("Supervisor has exceeded restart limits and was forced to shut down."))]
    Shutdown,
}

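/// A specification for a child process managed by a supervisor: either a leaf
/// worker or a nested supervisor.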
pub enum ChildSpecification {
    /// A leaf worker process.
    Worker(Arc<dyn Supervisable>),

    /// A nested supervisor, run as a child of the owning supervisor.
    Supervisor(Supervisor),
}

impl ChildSpecification {
    fn process_type(&self) -> &'static str {
        match self {
            Self::Worker(_) => "worker",
            Self::Supervisor(_) => "supervisor",
        }
    }

    fn name(&self) -> &str {
        match self {
            Self::Worker(worker) => worker.name(),
            Self::Supervisor(supervisor) => &supervisor.supervisor_id,
        }
    }

    fn shutdown_strategy(&self) -> ShutdownStrategy {
        match self {
            Self::Worker(worker) => worker.shutdown_strategy(),

            // Nested supervisors are given unbounded time to shut down gracefully,
            // so their own children can honor their individual shutdown strategies.
            Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
        }
    }

    fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
        match self {
            Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
                name: worker.name().to_string(),
            }),
            Self::Supervisor(sup) => {
                Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
                    name: sup.supervisor_id.to_string(),
                })
            }
        }
    }

    fn create_worker_future(
        &self, process: Process, process_shutdown: ProcessShutdown,
    ) -> Result<WorkerFuture, SupervisorError> {
        match self {
            Self::Worker(worker) => {
                let worker = Arc::clone(worker);
                Ok(Box::pin(async move {
                    let run_future =
                        worker
                            .initialize(process_shutdown)
                            .await
                            .map_err(|source| WorkerError::Initialization {
                                child_name: None,
                                source,
                            })?;
                    run_future.await.map_err(WorkerError::Runtime)
                }))
            }
            Self::Supervisor(sup) => {
                match sup.runtime_mode() {
                    // Ambient supervisors run directly on the parent's runtime.
                    RuntimeMode::Ambient => Ok(sup.as_nested_process(process, process_shutdown)),

                    // Dedicated supervisors run on their own runtime, spawned via
                    // `spawn_dedicated_runtime`.
                    RuntimeMode::Dedicated(config) => {
                        let child_name = sup.supervisor_id.to_string();
                        let dataspace = process.dataspace().clone();
                        let handle =
                            spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown, dataspace)
                                .map_err(|e| SupervisorError::FailedToInitialize {
                                    child_name,
                                    source: e.into(),
                                })?;

                        Ok(Box::pin(async move { handle.await.map_err(WorkerError::from) }))
                    }
                }
            }
        }
    }
}

impl Clone for ChildSpecification {
    fn clone(&self) -> Self {
        match self {
            Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
            Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
        }
    }
}

impl From<Supervisor> for ChildSpecification {
    fn from(supervisor: Supervisor) -> Self {
        Self::Supervisor(supervisor)
    }
}

impl<T> From<T> for ChildSpecification
where
    T: Supervisable + 'static,
{
    fn from(worker: T) -> Self {
        Self::Worker(Arc::new(worker))
    }
}

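/// A supervisor that manages a set of child processes.
///
/// A supervisor owns a static list of child specifications (leaf workers
/// implementing [`Supervisable`], or nested supervisors) and runs them as
/// concurrent tasks, restarting them according to its [`RestartStrategy`]
/// when they terminate.
///
/// # Example
///
/// A sketch of building and running a small supervision tree; `Heartbeat` and
/// `Flusher` are hypothetical [`Supervisable`] implementations:
///
/// ```ignore
/// let mut inner = Supervisor::new("io")?;
/// inner.add_worker(Flusher::default());
///
/// let mut root = Supervisor::new("root")?;
/// root.add_worker(Heartbeat::default());
/// root.add_worker(inner); // Nested supervisor.
///
/// // Run until the given shutdown future (e.g. Ctrl-C) completes.
/// root.run_with_shutdown(tokio::signal::ctrl_c()).await?;
/// ```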
pub struct Supervisor {
    supervisor_id: Arc<str>,
    child_specs: Vec<ChildSpecification>,
    restart_strategy: RestartStrategy,
    runtime_mode: RuntimeMode,
}

impl Supervisor {
    /// Creates a new `Supervisor` with the given identifier.
    ///
    /// # Errors
    ///
    /// If the identifier is empty, an error is returned.
    pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
        if supervisor_id.as_ref().is_empty() {
            return Err(SupervisorError::InvalidName {
                name: supervisor_id.as_ref().to_string(),
            });
        }

        Ok(Self {
            supervisor_id: supervisor_id.as_ref().into(),
            child_specs: Vec::new(),
            restart_strategy: RestartStrategy::default(),
            runtime_mode: RuntimeMode::default(),
        })
    }

    /// Returns the identifier of this supervisor.
    pub fn id(&self) -> &str {
        &self.supervisor_id
    }

    /// Sets the restart strategy used when child processes terminate.
    pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
        self.restart_strategy = strategy;
        self
    }

    /// Configures this supervisor to run on a dedicated runtime with the given
    /// configuration, instead of the ambient runtime of its parent.
    pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
        self.runtime_mode = RuntimeMode::Dedicated(config);
        self
    }

    /// Returns the runtime mode of this supervisor.
    pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
        &self.runtime_mode
    }

    /// Adds a child process (a worker or a nested supervisor) to this supervisor.
    pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
        let child_spec = process.into();
        debug!(
            supervisor_id = %self.supervisor_id,
            "Adding new static child process #{}. ({}, {})",
            self.child_specs.len(),
            child_spec.process_type(),
            child_spec.name(),
        );
        self.child_specs.push(child_spec);
    }

    fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
        match self.child_specs.get(child_spec_idx) {
            Some(child_spec) => child_spec,
            None => unreachable!("child spec index should never be out of bounds"),
        }
    }

    fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        let child_spec = self.get_child_spec(child_spec_idx);
        debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
        worker_state.add_worker(child_spec_idx, child_spec)
    }

    fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
        debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
        for child_spec_idx in 0..self.child_specs.len() {
            self.spawn_child(child_spec_idx, worker_state)?;
        }

        Ok(())
    }

    async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
        if self.child_specs.is_empty() {
            return Err(SupervisorError::NoChildren);
        }

        let mut restart_state = RestartState::new(self.restart_strategy);
        let mut worker_state = WorkerState::new(process);

        self.spawn_all_children(&mut worker_state)?;

        // Pin the shutdown future so it can be polled repeatedly across loop iterations.
        let shutdown = process_shutdown.wait_for_shutdown();
        pin!(shutdown);

        loop {
            select! {
                _ = &mut shutdown => {
                    debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
                    worker_state.shutdown_workers().await;
                    break;
                },
                worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
                    Some((child_spec_idx, worker_result)) => {
                        let child_spec = self.get_child_spec(child_spec_idx);

                        // Initialization failures are fatal: shut everything down and
                        // surface the error instead of attempting a restart.
                        if let Err(WorkerError::Initialization { child_name, source }) = worker_result {
                            let full_name = match child_name {
                                Some(inner) => format!("{}/{}", child_spec.name(), inner),
                                None => child_spec.name().to_string(),
                            };

                            error!(supervisor_id = %self.supervisor_id, worker_name = full_name, "Child process failed to initialize: {}", source);
                            worker_state.shutdown_workers().await;
                            return Err(SupervisorError::FailedToInitialize {
                                child_name: full_name,
                                source,
                            });
                        }

                        // Anything else is a runtime result, subject to the restart strategy.
                        let worker_result = worker_result
                            .map_err(|e| match e {
                                WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
                                WorkerError::Initialization { .. } => unreachable!("handled above"),
                            });

                        match restart_state.evaluate_restart() {
                            RestartAction::Restart(mode) => match mode {
                                RestartMode::OneForOne => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
                                    self.spawn_child(child_spec_idx, &mut worker_state)?;
                                }
                                RestartMode::OneForAll => {
                                    warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
                                    worker_state.shutdown_workers().await;
                                    self.spawn_all_children(&mut worker_state)?;
                                }
                            },
                            RestartAction::Shutdown => {
                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
                                worker_state.shutdown_workers().await;
                                return Err(SupervisorError::Shutdown);
                            }
                        }
                    },
                    None => unreachable!("should not have empty worker joinset prior to shutdown"),
                }
            }
        }

        Ok(())
    }

    /// Runs this supervisor as a nested child process of another supervisor.
    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
        debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");

        // Clone ourselves so the returned future is fully owned ('static).
        let sup = self.inner_clone();

        Box::pin(async move {
            sup.run_inner(process, process_shutdown)
                .await
                .map_err(WorkerError::from)
        })
    }

    /// Runs the supervisor until it completes or a fatal error occurs.
    ///
    /// This variant installs a no-op shutdown trigger, so the supervisor runs
    /// until a child fails to initialize or restart limits force a shutdown.
    ///
    /// # Errors
    ///
    /// If the supervisor fails to start, a child fails to initialize, or
    /// restart limits are exceeded, an error is returned.
    pub async fn run(&mut self) -> Result<(), SupervisorError> {
        // This supervisor is the root of the tree, so it has no parent process.
        let process_shutdown = ProcessShutdown::noop();
        let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
            name: self.supervisor_id.to_string(),
        })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_process_future(process)
            .await
    }

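    /// Runs the supervisor until the given shutdown future completes.
    ///
    /// When `shutdown` resolves, all child processes are shut down according
    /// to their individual shutdown strategies.
    ///
    /// # Example
    ///
    /// A sketch of tying supervisor shutdown to Ctrl-C (any `Send + 'static`
    /// future works; `some_worker` is hypothetical):
    ///
    /// ```ignore
    /// let mut sup = Supervisor::new("root")?;
    /// sup.add_worker(some_worker);
    /// sup.run_with_shutdown(tokio::signal::ctrl_c()).await?;
    /// ```
    ///
    /// # Errors
    ///
    /// If the supervisor fails to start, a child fails to initialize, or
    /// restart limits are exceeded, an error is returned.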
    pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
        let process_shutdown = ProcessShutdown::wrapped(shutdown);
        self.run_with_process_shutdown(process_shutdown, None).await
    }

    /// Runs the supervisor with an explicit shutdown handle and, optionally,
    /// an existing dataspace registry to attach to the root process.
    pub(crate) async fn run_with_process_shutdown(
        &mut self, process_shutdown: ProcessShutdown, dataspace: Option<DataspaceRegistry>,
    ) -> Result<(), SupervisorError> {
        let process =
            Process::supervisor_with_dataspace(&self.supervisor_id, None, dataspace).context(InvalidName {
                name: self.supervisor_id.to_string(),
            })?;

        debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
        self.run_inner(process.clone(), process_shutdown)
            .into_process_future(process)
            .await
    }

    /// Clones this supervisor without exposing `Clone` publicly, preserving the
    /// shared supervisor ID and the full set of child specifications.
    fn inner_clone(&self) -> Self {
        Self {
            supervisor_id: Arc::clone(&self.supervisor_id),
            child_specs: self.child_specs.clone(),
            restart_strategy: self.restart_strategy,
            runtime_mode: self.runtime_mode.clone(),
        }
    }
}

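/// Bookkeeping for a single spawned worker task: the index of its child
/// specification, its shutdown strategy, and the handles used to stop it.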
struct ProcessState {
    worker_id: usize,
    shutdown_strategy: ShutdownStrategy,
    shutdown_handle: ShutdownHandle,
    abort_handle: AbortHandle,
}

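/// Tracks the set of running worker tasks for a supervisor, mapping task IDs
/// back to their per-process state.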
struct WorkerState {
    process: Process,
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    worker_map: FastIndexMap<Id, ProcessState>,
}

impl WorkerState {
    fn new(process: Process) -> Self {
        Self {
            process,
            worker_tasks: JoinSet::new(),
            worker_map: FastIndexMap::default(),
        }
    }

    /// Spawns a new worker task for the given child specification, tracking its
    /// shutdown and abort handles.
    fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
        let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
        let process = child_spec.create_process(&self.process)?;
        let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
        let shutdown_strategy = child_spec.shutdown_strategy();
        let abort_handle = self.worker_tasks.spawn(worker_future.into_process_future(process));
        self.worker_map.insert(
            abort_handle.id(),
            ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            },
        );
        Ok(())
    }

    /// Waits for the next worker task to complete, returning its child
    /// specification index and result, or `None` if no tasks remain.
    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
        debug!("Waiting for next process to complete.");

        match self.worker_tasks.join_next_with_id().await {
            Some(Ok((worker_task_id, worker_result))) => {
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                Some((process_state.worker_id, worker_result))
            }
            Some(Err(e)) => {
                // The task itself failed: either we aborted it or it panicked.
                let worker_task_id = e.id();
                let process_state = self
                    .worker_map
                    .swap_remove(&worker_task_id)
                    .expect("worker task ID not found");
                let e = if e.is_cancelled() {
                    ProcessError::Aborted
                } else {
                    ProcessError::Panicked
                };
                Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
            }
            None => None,
        }
    }

    /// Shuts down all running worker tasks.
    async fn shutdown_workers(&mut self) {
        debug!("Shutting down all processes.");

        // Shut workers down one at a time, starting from the most recently
        // tracked entry. For each worker, we apply its shutdown strategy and
        // then wait for its task to complete, reaping any other tasks that
        // happen to finish while we wait so the join set and worker map stay
        // consistent.
        while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
            let ProcessState {
                worker_id,
                shutdown_strategy,
                shutdown_handle,
                abort_handle,
            } = process_state;

            let shutdown_deadline = match shutdown_strategy {
                ShutdownStrategy::Graceful(timeout) => {
                    debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
                    shutdown_handle.trigger();

                    tokio::time::sleep(timeout)
                }
                ShutdownStrategy::Brutal => {
                    debug!(worker_id, "Forcefully aborting process.");
                    abort_handle.abort();

                    // The task has already been aborted, so the deadline is
                    // irrelevant: sleep effectively forever and let the join
                    // below observe the abort.
                    tokio::time::sleep(Duration::MAX)
                }
            };
            pin!(shutdown_deadline);

            // Wait for the target task to complete, aborting it if the
            // shutdown deadline expires first.
            loop {
                select! {
                    worker_result = self.worker_tasks.join_next_with_id() => {
                        match worker_result {
                            Some(Ok((worker_task_id, _))) => {
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited successfully.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            },
                            Some(Err(e)) => {
                                let worker_task_id = e.id();
                                if worker_task_id == current_worker_task_id {
                                    debug!(?worker_task_id, "Target process exited with error.");
                                    break;
                                } else {
                                    debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
                                    self.worker_map.swap_remove(&worker_task_id);
                                }
                            }
                            None => unreachable!("worker task must exist in join set if we are waiting for it"),
                        }
                    },
                    _ = &mut shutdown_deadline => {
                        debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
                        abort_handle.abort();
                    }
                }
            }
        }

        debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
        debug_assert!(
            self.worker_tasks.is_empty(),
            "worker tasks should be empty after shutdown"
        );
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;
    use tokio::{
        sync::oneshot,
        task::JoinHandle,
        time::{sleep, timeout},
    };

    use super::*;

    /// How a mock worker behaves during initialization.
    #[derive(Clone)]
    enum InitBehavior {
        /// Initialization completes immediately.
        Instant,

        /// Initialization completes after the given delay.
        Slow(Duration),

        /// Initialization fails with the given error message.
        Fail(&'static str),
    }

    /// How a mock worker behaves once it is running.
    #[derive(Clone)]
    enum RunBehavior {
        /// Runs until shutdown is triggered.
        UntilShutdown,

        /// Fails with the given error message after the given delay.
        FailAfter(Duration, &'static str),
    }

    struct MockWorker {
        name: &'static str,
        init_behavior: InitBehavior,
        run_behavior: RunBehavior,
        start_count: Arc<AtomicUsize>,
    }

    impl MockWorker {
        /// Creates a worker that runs until shutdown is triggered.
        fn long_running(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker that fails after running for the given delay.
        fn failing(name: &'static str, delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker that fails during initialization.
        fn init_failure(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Fail("init failed"),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Creates a worker whose initialization takes the given amount of time.
        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Slow(init_delay),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Returns a counter tracking how many times this worker's run future has started.
        fn start_count(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.start_count)
        }
    }

    #[async_trait]
    impl Supervisable for MockWorker {
        fn name(&self) -> &str {
            self.name
        }

        fn shutdown_strategy(&self) -> ShutdownStrategy {
            ShutdownStrategy::Graceful(Duration::from_millis(500))
        }

        async fn initialize(
            &self, mut process_shutdown: ProcessShutdown,
        ) -> Result<SupervisorFuture, InitializationError> {
            match &self.init_behavior {
                InitBehavior::Instant => {}
                InitBehavior::Slow(delay) => {
                    sleep(*delay).await;
                }
                InitBehavior::Fail(msg) => {
                    return Err(InitializationError::Failed {
                        source: GenericError::msg(*msg),
                    });
                }
            }

            let start_count = Arc::clone(&self.start_count);
            let run_behavior = self.run_behavior.clone();

            Ok(Box::pin(async move {
                start_count.fetch_add(1, Ordering::SeqCst);

                match run_behavior {
                    RunBehavior::UntilShutdown => {
                        process_shutdown.wait_for_shutdown().await;
                        Ok(())
                    }
                    RunBehavior::FailAfter(delay, msg) => {
                        select! {
                            _ = sleep(delay) => {
                                Err(GenericError::msg(msg))
                            }
                            _ = process_shutdown.wait_for_shutdown() => {
                                Ok(())
                            }
                        }
                    }
                }
            }))
        }
    }

    /// Spawns the supervisor on a background task, returning a sender that
    /// triggers shutdown and a handle to the supervisor's result.
    async fn run_supervisor_with_trigger(
        mut supervisor: Supervisor,
    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });

        // Give the supervisor a moment to spawn its children.
        sleep(Duration::from_millis(50)).await;

        (tx, handle)
    }

    // Lifecycle: supervisors should start their children and shut down cleanly.

    #[tokio::test]
    async fn standalone_supervisor_shuts_down_cleanly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn nested_supervisor_shuts_down_cleanly() {
        let mut child_sup = Supervisor::new("child-sup").unwrap();
        child_sup.add_worker(MockWorker::long_running("inner-worker"));

        let mut parent_sup = Supervisor::new("parent-sup").unwrap();
        parent_sup.add_worker(MockWorker::long_running("outer-worker"));
        parent_sup.add_worker(child_sup);

        let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn supervisor_with_no_children_returns_error() {
        let mut sup = Supervisor::new("empty-sup").unwrap();

        let (tx, rx) = oneshot::channel::<()>();
        let result = sup.run_with_shutdown(rx).await;
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::NoChildren)));
    }

    // Restart behavior.

    #[tokio::test]
    async fn one_for_one_restarts_only_failed_child() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Let the failing worker fail (and restart) a few times before shutting down.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert_eq!(
            stable_count.load(Ordering::SeqCst),
            1,
            "stable worker should not have been restarted"
        );
    }

    #[tokio::test]
    async fn one_for_all_restarts_all_children() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Let the failing worker fail (and restart) a few times before shutting down.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert!(
            stable_count.load(Ordering::SeqCst) >= 2,
            "stable worker should also have been restarted"
        );
    }

    #[tokio::test]
    async fn restart_limit_exceeded_shuts_down_supervisor() {
        let mut sup = Supervisor::new("test-sup")
            .unwrap()
            .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));

        // This worker fails immediately, quickly exceeding the restart limit.
        sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));

        let (tx, rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::Shutdown)));
    }

    // Initialization failures.

    #[tokio::test]
    async fn init_failure_propagates_with_child_name() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("good-worker"));
        sup.add_worker(MockWorker::init_failure("bad-worker"));

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        match result {
            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
                assert_eq!(child_name, "bad-worker");
            }
            other => panic!("expected FailedToInitialize, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn init_failure_does_not_trigger_restart() {
        let init_fail = MockWorker::init_failure("bad-worker");
        let start_count = init_fail.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
        );
        sup.add_worker(init_fail);

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));

        // The worker's run future should never have started.
        assert_eq!(start_count.load(Ordering::SeqCst), 0);
    }

    // Shutdown timing.

    #[tokio::test]
    async fn shutdown_completes_promptly_in_steady_state() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(1), handle).await;
        assert!(result.is_ok(), "shutdown should complete promptly");
    }

    #[tokio::test]
    async fn shutdown_during_slow_init_completes_promptly() {
        let mut sup = Supervisor::new("test-sup").unwrap();

        // This worker's initialization takes far longer than the test timeout.
        sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));

        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        // Trigger shutdown while the worker is still initializing.
        sleep(Duration::from_millis(20)).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await;
        assert!(result.is_ok(), "shutdown during slow init should complete promptly");
    }
}