1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use saluki_common::collections::FastIndexMap;
5use saluki_error::GenericError;
6use snafu::{OptionExt as _, Snafu};
7use tokio::{
8 pin, select,
9 task::{AbortHandle, Id, JoinSet},
10};
11use tracing::{debug, error, warn};
12
13use super::{
14 dedicated::{spawn_dedicated_runtime, RuntimeConfiguration, RuntimeMode},
15 restart::{RestartAction, RestartMode, RestartState, RestartStrategy},
16 shutdown::{ProcessShutdown, ShutdownHandle},
17};
18use crate::runtime::{
19 process::{Process, ProcessExt as _},
20 state::DataspaceRegistry,
21};
22
/// Boxed, `Send` future returned by [`Supervisable::initialize`]; running it to completion runs the process.
pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;

/// Future actually spawned for each child.
///
/// Like [`SupervisorFuture`], but it resolves to a [`WorkerError`] so that initialization failures can be told apart
/// from runtime failures.
type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;
32
/// Error produced by a child's [`WorkerFuture`].
///
/// Initialization failures are kept separate from runtime failures because the supervisor treats them differently.
/// An initialization failure stops the supervisor instead of going through the restart strategy.
#[derive(Debug)]
enum WorkerError {
    /// The child failed while initializing.
    Initialization {
        /// Name of the failing descendant when the failure was reported by a nested supervisor.
        ///
        /// `None` when the failing child is this worker itself; the supervisor then uses the child's own name.
        child_name: Option<String>,
        /// The underlying initialization error.
        source: InitializationError,
    },

    /// The child failed after initialization (returned an error, was aborted, or panicked).
    Runtime(GenericError),
}
52
53impl From<SupervisorError> for WorkerError {
54 fn from(err: SupervisorError) -> Self {
55 match err {
56 SupervisorError::FailedToInitialize { child_name, source } => WorkerError::Initialization {
59 child_name: Some(child_name),
60 source,
61 },
62 other => WorkerError::Runtime(other.into()),
64 }
65 }
66}
67
/// Describes how a child process terminated unexpectedly.
#[derive(Debug, Snafu)]
pub enum ProcessError {
    /// The child's task was cancelled before it completed.
    #[snafu(display("Child process was aborted by the supervisor."))]
    Aborted,

    /// The child's task panicked.
    #[snafu(display("Child process panicked."))]
    Panicked,

    /// The child's future returned an error.
    #[snafu(display("Child process terminated with an error: {}", source))]
    Terminated {
        /// The error returned by the child.
        source: GenericError,
    },
}
86
/// Error returned by [`Supervisable::initialize`].
///
/// A child that fails to initialize is not restarted. Its supervisor shuts down and reports
/// [`SupervisorError::FailedToInitialize`].
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum InitializationError {
    /// Initialization failed with the given underlying error.
    #[snafu(display("Process failed to initialize: {}", source))]
    Failed {
        /// The underlying error.
        source: GenericError,
    },
}
102
103impl From<GenericError> for InitializationError {
104 fn from(source: GenericError) -> Self {
105 Self::Failed { source }
106 }
107}
108
/// How a supervisor stops a child process during shutdown.
///
/// This is a small, plain configuration value, so it derives the common traits. That lets it be logged, compared, and
/// copied freely.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownStrategy {
    /// Signal the process to shut down, then wait up to the given duration for it to exit.
    ///
    /// If the process is still running when the duration elapses, it is aborted.
    Graceful(Duration),

    /// Abort the process immediately, without signaling it first.
    Brutal,
}
117
/// A unit of work that can be run, and restarted, under a [`Supervisor`].
#[async_trait]
pub trait Supervisable: Send + Sync {
    /// Returns the name of this process.
    ///
    /// The name is used to create the child's [`Process`] and appears in log messages. If `Process::worker` rejects
    /// it, spawning fails with [`SupervisorError::InvalidName`].
    fn name(&self) -> &str;

    /// Returns how this process should be stopped when its supervisor shuts it down.
    ///
    /// Defaults to a graceful shutdown with a 5 second timeout.
    fn shutdown_strategy(&self) -> ShutdownStrategy {
        ShutdownStrategy::Graceful(Duration::from_secs(5))
    }

    /// Initializes the process and returns the future that runs it.
    ///
    /// This is called every time the process is started or restarted. `process_shutdown` is signaled when the
    /// supervisor wants the process to stop gracefully.
    ///
    /// An error returned here is an initialization failure, and the process is not restarted. The supervisor instead
    /// shuts down all of its children and returns [`SupervisorError::FailedToInitialize`].
    ///
    /// An error returned by the resulting future is a runtime failure. Those are handled by the supervisor's restart
    /// strategy.
    async fn initialize(&self, process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError>;
}
144
/// Errors returned by a [`Supervisor`].
#[derive(Debug, Snafu)]
#[snafu(context(suffix(false)))]
pub enum SupervisorError {
    /// A supervisor or worker name was rejected (for example, an empty supervisor ID).
    #[snafu(display("Invalid name for supervisor or worker: '{}'", name))]
    InvalidName {
        /// The rejected name.
        name: String,
    },

    /// The supervisor was run without any child processes.
    #[snafu(display("Supervisor has no child processes."))]
    NoChildren,

    /// A child process failed while initializing.
    #[snafu(display("Child process '{}' failed to initialize: {}", child_name, source))]
    FailedToInitialize {
        /// Name of the failing child.
        ///
        /// When the failure came from inside nested supervisors, this is a `/`-separated path starting at the
        /// reporting supervisor's direct child.
        child_name: String,

        /// The underlying initialization error.
        source: InitializationError,
    },

    /// The restart strategy decided to stop restarting children, so the supervisor shut everything down.
    #[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))]
    Shutdown,
}
177
/// Specification of a child process managed by a [`Supervisor`].
///
/// A specification is a template. Each time the child is started or restarted, a fresh process and future are
/// created from it.
pub enum ChildSpecification {
    /// A worker process.
    Worker(Arc<dyn Supervisable>),
    /// A nested supervisor, which manages its own children.
    Supervisor(Supervisor),
}
190
191impl ChildSpecification {
192 fn process_type(&self) -> &'static str {
193 match self {
194 Self::Worker(_) => "worker",
195 Self::Supervisor(_) => "supervisor",
196 }
197 }
198
199 fn name(&self) -> &str {
200 match self {
201 Self::Worker(worker) => worker.name(),
202 Self::Supervisor(supervisor) => &supervisor.supervisor_id,
203 }
204 }
205
206 fn shutdown_strategy(&self) -> ShutdownStrategy {
207 match self {
208 Self::Worker(worker) => worker.shutdown_strategy(),
209
210 Self::Supervisor(_) => ShutdownStrategy::Graceful(Duration::MAX),
213 }
214 }
215
216 fn create_process(&self, parent_process: &Process) -> Result<Process, SupervisorError> {
217 match self {
218 Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName {
219 name: worker.name().to_string(),
220 }),
221 Self::Supervisor(sup) => {
222 Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName {
223 name: sup.supervisor_id.to_string(),
224 })
225 }
226 }
227 }
228
229 fn create_worker_future(
230 &self, process: Process, process_shutdown: ProcessShutdown,
231 ) -> Result<WorkerFuture, SupervisorError> {
232 match self {
233 Self::Worker(worker) => {
234 let worker = Arc::clone(worker);
235 Ok(Box::pin(async move {
236 let run_future =
237 worker
238 .initialize(process_shutdown)
239 .await
240 .map_err(|source| WorkerError::Initialization {
241 child_name: None,
242 source,
243 })?;
244 run_future.await.map_err(WorkerError::Runtime)
245 }))
246 }
247 Self::Supervisor(sup) => {
248 match sup.runtime_mode() {
249 RuntimeMode::Ambient => {
250 Ok(sup.as_nested_process(process, process_shutdown))
252 }
253 RuntimeMode::Dedicated(config) => {
254 let child_name = sup.supervisor_id.to_string();
257 let dataspace = process.dataspace().clone();
258 let handle =
259 spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown, dataspace)
260 .map_err(|e| SupervisorError::FailedToInitialize {
261 child_name,
262 source: e.into(),
263 })?;
264
265 Ok(Box::pin(async move { handle.await.map_err(WorkerError::from) }))
266 }
267 }
268 }
269 }
270 }
271}
272
273impl Clone for ChildSpecification {
274 fn clone(&self) -> Self {
275 match self {
276 Self::Worker(worker) => Self::Worker(Arc::clone(worker)),
277 Self::Supervisor(supervisor) => Self::Supervisor(supervisor.inner_clone()),
278 }
279 }
280}
281
282impl From<Supervisor> for ChildSpecification {
283 fn from(supervisor: Supervisor) -> Self {
284 Self::Supervisor(supervisor)
285 }
286}
287
288impl<T> From<T> for ChildSpecification
289where
290 T: Supervisable + 'static,
291{
292 fn from(worker: T) -> Self {
293 Self::Worker(Arc::new(worker))
294 }
295}
296
/// Supervises a set of child processes, restarting them according to a [`RestartStrategy`].
pub struct Supervisor {
    /// Identifier of this supervisor; also used as the name of its process.
    supervisor_id: Arc<str>,
    /// Child specifications, in the order they were added. Children are spawned in this order.
    child_specs: Vec<ChildSpecification>,
    /// Decides whether and how terminated children are restarted.
    restart_strategy: RestartStrategy,
    /// Whether this supervisor, when nested, runs in its parent's task set or on a dedicated runtime.
    runtime_mode: RuntimeMode,
}
329
330impl Supervisor {
331 pub fn new<S: AsRef<str>>(supervisor_id: S) -> Result<Self, SupervisorError> {
333 if supervisor_id.as_ref().is_empty() {
337 return Err(SupervisorError::InvalidName {
338 name: supervisor_id.as_ref().to_string(),
339 });
340 }
341
342 Ok(Self {
343 supervisor_id: supervisor_id.as_ref().into(),
344 child_specs: Vec::new(),
345 restart_strategy: RestartStrategy::default(),
346 runtime_mode: RuntimeMode::default(),
347 })
348 }
349
350 pub fn id(&self) -> &str {
352 &self.supervisor_id
353 }
354
355 pub fn with_restart_strategy(mut self, strategy: RestartStrategy) -> Self {
357 self.restart_strategy = strategy;
358 self
359 }
360
361 pub fn with_dedicated_runtime(mut self, config: RuntimeConfiguration) -> Self {
371 self.runtime_mode = RuntimeMode::Dedicated(config);
372 self
373 }
374
375 pub(crate) fn runtime_mode(&self) -> &RuntimeMode {
377 &self.runtime_mode
378 }
379
380 pub fn add_worker<T: Into<ChildSpecification>>(&mut self, process: T) {
385 let child_spec = process.into();
386 debug!(
387 supervisor_id = %self.supervisor_id,
388 "Adding new static child process #{}. ({}, {})",
389 self.child_specs.len(),
390 child_spec.process_type(),
391 child_spec.name(),
392 );
393 self.child_specs.push(child_spec);
394 }
395
396 fn get_child_spec(&self, child_spec_idx: usize) -> &ChildSpecification {
397 match self.child_specs.get(child_spec_idx) {
398 Some(child_spec) => child_spec,
399 None => unreachable!("child spec index should never be out of bounds"),
400 }
401 }
402
403 fn spawn_child(&self, child_spec_idx: usize, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
404 let child_spec = self.get_child_spec(child_spec_idx);
405 debug!(supervisor_id = %self.supervisor_id, "Spawning static child process #{} ({}).", child_spec_idx, child_spec.name());
406 worker_state.add_worker(child_spec_idx, child_spec)
407 }
408
409 fn spawn_all_children(&self, worker_state: &mut WorkerState) -> Result<(), SupervisorError> {
410 debug!(supervisor_id = %self.supervisor_id, "Spawning all static child processes.");
411 for child_spec_idx in 0..self.child_specs.len() {
412 self.spawn_child(child_spec_idx, worker_state)?;
413 }
414
415 Ok(())
416 }
417
418 async fn run_inner(&self, process: Process, mut process_shutdown: ProcessShutdown) -> Result<(), SupervisorError> {
419 if self.child_specs.is_empty() {
420 return Err(SupervisorError::NoChildren);
421 }
422
423 let mut restart_state = RestartState::new(self.restart_strategy);
424 let mut worker_state = WorkerState::new(process);
425
426 self.spawn_all_children(&mut worker_state)?;
429
430 let shutdown = process_shutdown.wait_for_shutdown();
432 pin!(shutdown);
433
434 loop {
435 select! {
436 _ = &mut shutdown => {
440 debug!(supervisor_id = %self.supervisor_id, "Shutdown triggered, shutting down all child processes.");
441 worker_state.shutdown_workers().await;
442 break;
443 },
444 worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
445 Some((child_spec_idx, worker_result)) => {
451 let child_spec = self.get_child_spec(child_spec_idx);
452
453 if let Err(WorkerError::Initialization { child_name, source }) = worker_result {
455 let full_name = match child_name {
458 Some(inner) => format!("{}/{}", child_spec.name(), inner),
459 None => child_spec.name().to_string(),
460 };
461
462 error!(supervisor_id = %self.supervisor_id, worker_name = full_name, "Child process failed to initialize: {}", source);
463 worker_state.shutdown_workers().await;
464 return Err(SupervisorError::FailedToInitialize {
465 child_name: full_name,
466 source,
467 });
468 }
469
470 let worker_result = worker_result
472 .map_err(|e| match e {
473 WorkerError::Runtime(e) => ProcessError::Terminated { source: e },
474 WorkerError::Initialization { .. } => unreachable!("handled above"),
475 });
476
477 match restart_state.evaluate_restart() {
478 RestartAction::Restart(mode) => match mode {
479 RestartMode::OneForOne => {
480 warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting.");
481 self.spawn_child(child_spec_idx, &mut worker_state)?;
482 }
483 RestartMode::OneForAll => {
484 warn!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Child process terminated, restarting all processes.");
485 worker_state.shutdown_workers().await;
486 self.spawn_all_children(&mut worker_state)?;
487 }
488 },
489 RestartAction::Shutdown => {
490 error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), ?worker_result, "Supervisor shutting down due to restart limits.");
491 worker_state.shutdown_workers().await;
492 return Err(SupervisorError::Shutdown);
493 }
494 }
495 },
496 None => unreachable!("should not have empty worker joinset prior to shutdown"),
497 }
498 }
499 }
500
501 Ok(())
502 }
503
504 fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
505 debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");
508
509 let sup = self.inner_clone();
511
512 Box::pin(async move {
513 sup.run_inner(process, process_shutdown)
514 .await
515 .map_err(WorkerError::from)
516 })
517 }
518
519 pub async fn run(&mut self) -> Result<(), SupervisorError> {
525 let process_shutdown = ProcessShutdown::noop();
528 let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
529 name: self.supervisor_id.to_string(),
530 })?;
531
532 debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
533 self.run_inner(process.clone(), process_shutdown)
534 .into_process_future(process)
535 .await
536 }
537
538 pub async fn run_with_shutdown<F: Future + Send + 'static>(&mut self, shutdown: F) -> Result<(), SupervisorError> {
547 let process_shutdown = ProcessShutdown::wrapped(shutdown);
548 self.run_with_process_shutdown(process_shutdown, None).await
549 }
550
551 pub(crate) async fn run_with_process_shutdown(
563 &mut self, process_shutdown: ProcessShutdown, dataspace: Option<DataspaceRegistry>,
564 ) -> Result<(), SupervisorError> {
565 let process =
566 Process::supervisor_with_dataspace(&self.supervisor_id, None, dataspace).context(InvalidName {
567 name: self.supervisor_id.to_string(),
568 })?;
569
570 debug!(supervisor_id = %self.supervisor_id, "Supervisor starting.");
571 self.run_inner(process.clone(), process_shutdown)
572 .into_process_future(process)
573 .await
574 }
575
576 fn inner_clone(&self) -> Self {
577 Self {
581 supervisor_id: Arc::clone(&self.supervisor_id),
582 child_specs: self.child_specs.clone(),
583 restart_strategy: self.restart_strategy,
584 runtime_mode: self.runtime_mode.clone(),
585 }
586 }
587}
588
/// Bookkeeping for a single running child task.
struct ProcessState {
    /// Index of the child's specification in the owning supervisor's child list.
    worker_id: usize,
    /// How to stop this child during shutdown.
    shutdown_strategy: ShutdownStrategy,
    /// Triggers the child's graceful shutdown signal.
    shutdown_handle: ShutdownHandle,
    /// Aborts the child's task.
    abort_handle: AbortHandle,
}
595
/// The currently running children of a supervisor.
struct WorkerState {
    /// The supervisor's own process; used as the parent when creating child processes.
    process: Process,
    /// Running child tasks.
    worker_tasks: JoinSet<Result<(), WorkerError>>,
    /// Per-task bookkeeping, keyed by task ID. Kept in sync with `worker_tasks`.
    worker_map: FastIndexMap<Id, ProcessState>,
}
601
602impl WorkerState {
603 fn new(process: Process) -> Self {
604 Self {
605 process,
606 worker_tasks: JoinSet::new(),
607 worker_map: FastIndexMap::default(),
608 }
609 }
610
611 fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
612 let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
613 let process = child_spec.create_process(&self.process)?;
614 let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
615 let shutdown_strategy = child_spec.shutdown_strategy();
616 let abort_handle = self.worker_tasks.spawn(worker_future.into_process_future(process));
617 self.worker_map.insert(
618 abort_handle.id(),
619 ProcessState {
620 worker_id,
621 shutdown_strategy,
622 shutdown_handle,
623 abort_handle,
624 },
625 );
626 Ok(())
627 }
628
629 async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
630 debug!("Waiting for next process to complete.");
631
632 match self.worker_tasks.join_next_with_id().await {
633 Some(Ok((worker_task_id, worker_result))) => {
634 let process_state = self
635 .worker_map
636 .swap_remove(&worker_task_id)
637 .expect("worker task ID not found");
638 Some((process_state.worker_id, worker_result))
639 }
640 Some(Err(e)) => {
641 let worker_task_id = e.id();
642 let process_state = self
643 .worker_map
644 .swap_remove(&worker_task_id)
645 .expect("worker task ID not found");
646 let e = if e.is_cancelled() {
647 ProcessError::Aborted
648 } else {
649 ProcessError::Panicked
650 };
651 Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
652 }
653 None => None,
654 }
655 }
656
657 async fn shutdown_workers(&mut self) {
658 debug!("Shutting down all processes.");
659
660 while let Some((current_worker_task_id, process_state)) = self.worker_map.pop() {
670 let ProcessState {
671 worker_id,
672 shutdown_strategy,
673 shutdown_handle,
674 abort_handle,
675 } = process_state;
676
677 let shutdown_deadline = match shutdown_strategy {
679 ShutdownStrategy::Graceful(timeout) => {
680 debug!(worker_id, shutdown_timeout = ?timeout, "Gracefully shutting down process.");
681 shutdown_handle.trigger();
682
683 tokio::time::sleep(timeout)
684 }
685 ShutdownStrategy::Brutal => {
686 debug!(worker_id, "Forcefully aborting process.");
687 abort_handle.abort();
688
689 tokio::time::sleep(Duration::MAX)
692 }
693 };
694 pin!(shutdown_deadline);
695
696 loop {
699 select! {
700 worker_result = self.worker_tasks.join_next_with_id() => {
701 match worker_result {
702 Some(Ok((worker_task_id, _))) => {
703 if worker_task_id == current_worker_task_id {
704 debug!(?worker_task_id, "Target process exited successfully.");
705 break;
706 } else {
707 debug!(?worker_task_id, "Non-target process exited successfully. Continuing to wait.");
708 self.worker_map.swap_remove(&worker_task_id);
709 }
710 },
711 Some(Err(e)) => {
712 let worker_task_id = e.id();
713 if worker_task_id == current_worker_task_id {
714 debug!(?worker_task_id, "Target process exited with error.");
715 break;
716 } else {
717 debug!(?worker_task_id, "Non-target process exited with error. Continuing to wait.");
718 self.worker_map.swap_remove(&worker_task_id);
719 }
720 }
721 None => unreachable!("worker task must exist in join set if we are waiting for it"),
722 }
723 },
724 _ = &mut shutdown_deadline => {
726 debug!(worker_id, "Shutdown timeout expired, forcefully aborting process.");
727 abort_handle.abort();
728 }
729 }
730 }
731 }
732
733 debug_assert!(self.worker_map.is_empty(), "worker map should be empty after shutdown");
734 debug_assert!(
735 self.worker_tasks.is_empty(),
736 "worker tasks should be empty after shutdown"
737 );
738 }
739}
740
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;
    use tokio::{
        sync::oneshot,
        task::JoinHandle,
        time::{sleep, timeout},
    };

    use super::*;

    /// How a `MockWorker` behaves in `initialize`.
    #[derive(Clone)]
    enum InitBehavior {
        /// Initialize successfully right away.
        Instant,

        /// Sleep for the given duration, then initialize successfully.
        Slow(Duration),

        /// Fail initialization with the given message.
        Fail(&'static str),
    }

    /// How a `MockWorker`'s run future behaves after initialization.
    #[derive(Clone)]
    enum RunBehavior {
        /// Wait for the shutdown signal, then return `Ok(())`.
        UntilShutdown,

        /// Return an error with the given message after the given delay, unless shutdown is signaled first.
        FailAfter(Duration, &'static str),
    }

    /// Configurable `Supervisable` implementation used to exercise the supervisor.
    struct MockWorker {
        /// Worker name.
        name: &'static str,
        init_behavior: InitBehavior,
        run_behavior: RunBehavior,
        /// Incremented each time a run future starts executing.
        start_count: Arc<AtomicUsize>,
    }

    impl MockWorker {
        /// Worker that initializes instantly and runs until shutdown.
        fn long_running(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Worker that initializes instantly and fails `delay` after it starts running.
        fn failing(name: &'static str, delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Instant,
                run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Worker whose initialization always fails.
        fn init_failure(name: &'static str) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Fail("init failed"),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Worker whose initialization takes `init_delay` and then runs until shutdown.
        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
            Self {
                name,
                init_behavior: InitBehavior::Slow(init_delay),
                run_behavior: RunBehavior::UntilShutdown,
                start_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Returns a shared handle to the start counter, usable after the worker is moved into a supervisor.
        fn start_count(&self) -> Arc<AtomicUsize> {
            Arc::clone(&self.start_count)
        }
    }

    #[async_trait]
    impl Supervisable for MockWorker {
        fn name(&self) -> &str {
            self.name
        }

        // Short graceful timeout keeps the shutdown tests fast.
        fn shutdown_strategy(&self) -> ShutdownStrategy {
            ShutdownStrategy::Graceful(Duration::from_millis(500))
        }

        async fn initialize(
            &self, mut process_shutdown: ProcessShutdown,
        ) -> Result<SupervisorFuture, InitializationError> {
            // The slow path sleeps without watching `process_shutdown`; only an abort can interrupt it.
            match &self.init_behavior {
                InitBehavior::Instant => {}
                InitBehavior::Slow(delay) => {
                    sleep(*delay).await;
                }
                InitBehavior::Fail(msg) => {
                    return Err(InitializationError::Failed {
                        source: GenericError::msg(*msg),
                    });
                }
            }

            let start_count = Arc::clone(&self.start_count);
            let run_behavior = self.run_behavior.clone();

            Ok(Box::pin(async move {
                start_count.fetch_add(1, Ordering::SeqCst);

                match run_behavior {
                    RunBehavior::UntilShutdown => {
                        process_shutdown.wait_for_shutdown().await;
                        Ok(())
                    }
                    RunBehavior::FailAfter(delay, msg) => {
                        select! {
                            _ = sleep(delay) => {
                                Err(GenericError::msg(msg))
                            }
                            _ = process_shutdown.wait_for_shutdown() => {
                                Ok(())
                            }
                        }
                    }
                }
            }))
        }
    }

    /// Spawns `supervisor` with a oneshot-triggered shutdown and waits briefly so its children can start.
    ///
    /// Returns the shutdown trigger and the handle of the supervisor task.
    async fn run_supervisor_with_trigger(
        mut supervisor: Supervisor,
    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });
        sleep(Duration::from_millis(50)).await;
        (tx, handle)
    }

    /// A root supervisor with long-running workers returns `Ok(())` once shutdown is triggered.
    #[tokio::test]
    async fn standalone_supervisor_shuts_down_cleanly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    /// Shutdown propagates through a nested supervisor, and the parent returns `Ok(())`.
    #[tokio::test]
    async fn nested_supervisor_shuts_down_cleanly() {
        let mut child_sup = Supervisor::new("child-sup").unwrap();
        child_sup.add_worker(MockWorker::long_running("inner-worker"));

        let mut parent_sup = Supervisor::new("parent-sup").unwrap();
        parent_sup.add_worker(MockWorker::long_running("outer-worker"));
        parent_sup.add_worker(child_sup);

        let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());
    }

    /// Running a supervisor without children fails with `NoChildren`.
    #[tokio::test]
    async fn supervisor_with_no_children_returns_error() {
        let mut sup = Supervisor::new("empty-sup").unwrap();

        let (tx, rx) = oneshot::channel::<()>();
        let result = sup.run_with_shutdown(rx).await;
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::NoChildren)));
    }

    /// With one-for-one restarts, only the failing worker is restarted; the stable worker starts exactly once.
    #[tokio::test]
    async fn one_for_one_restarts_only_failed_child() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        // High intensity so the restart limit is not reached during the test.
        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Give the failing worker time to fail and be restarted several times.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert_eq!(
            stable_count.load(Ordering::SeqCst),
            1,
            "stable worker should not have been restarted"
        );
    }

    /// With one-for-all restarts, a single failure restarts every worker, including the stable one.
    #[tokio::test]
    async fn one_for_all_restarts_all_children() {
        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
        let failing_count = failing.start_count();

        let stable = MockWorker::long_running("stable-worker");
        let stable_count = stable.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
        );
        sup.add_worker(stable);
        sup.add_worker(failing);

        let (tx, handle) = run_supervisor_with_trigger(sup).await;

        // Give the failing worker time to fail and trigger at least one full restart.
        sleep(Duration::from_millis(300)).await;
        let _ = tx.send(());

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        assert!(result.is_ok());

        assert!(
            failing_count.load(Ordering::SeqCst) >= 2,
            "failing worker should have been restarted"
        );
        assert!(
            stable_count.load(Ordering::SeqCst) >= 2,
            "stable worker should also have been restarted"
        );
    }

    /// A worker that keeps failing exhausts a low restart intensity, and the supervisor stops with `Shutdown`.
    #[tokio::test]
    async fn restart_limit_exceeded_shuts_down_supervisor() {
        let mut sup = Supervisor::new("test-sup")
            .unwrap()
            .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
        sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));

        let (tx, rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
        drop(tx);

        assert!(matches!(result, Err(SupervisorError::Shutdown)));
    }

    /// An initialization failure is reported as `FailedToInitialize` naming the failing worker.
    #[tokio::test]
    async fn init_failure_propagates_with_child_name() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("good-worker"));
        sup.add_worker(MockWorker::init_failure("bad-worker"));

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        match result {
            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
                assert_eq!(child_name, "bad-worker");
            }
            other => panic!("expected FailedToInitialize, got: {:?}", other),
        }
    }

    /// Initialization failures are not retried, even when the restart strategy would allow restarts.
    #[tokio::test]
    async fn init_failure_does_not_trigger_restart() {
        let init_fail = MockWorker::init_failure("bad-worker");
        let start_count = init_fail.start_count();

        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
            RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
        );
        sup.add_worker(init_fail);

        let (_tx, rx) = oneshot::channel::<()>();
        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
            .await
            .unwrap();

        assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));
        // The run future never started.
        assert_eq!(start_count.load(Ordering::SeqCst), 0);
    }

    /// Workers that exit as soon as shutdown is signaled let the supervisor shut down promptly.
    #[tokio::test]
    async fn shutdown_completes_promptly_in_steady_state() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::long_running("worker1"));
        sup.add_worker(MockWorker::long_running("worker2"));

        let (tx, handle) = run_supervisor_with_trigger(sup).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(1), handle).await;
        assert!(result.is_ok(), "shutdown should complete promptly");
    }

    /// Shutdown during a long initialization completes via the graceful timeout plus abort.
    ///
    /// The slow initialization does not observe the shutdown signal, so the supervisor must not wait for it to finish.
    #[tokio::test]
    async fn shutdown_during_slow_init_completes_promptly() {
        let mut sup = Supervisor::new("test-sup").unwrap();
        sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));

        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });

        // Trigger shutdown while the worker is still inside `initialize`.
        sleep(Duration::from_millis(20)).await;
        tx.send(()).unwrap();

        let result = timeout(Duration::from_secs(2), handle).await;
        assert!(result.is_ok(), "shutdown during slow init should complete promptly");
    }
}