saluki_io/net/util/retry/backoff.rs

1use std::{
2    fmt,
3    sync::{Arc, Mutex},
4    time::Duration,
5};
6
7use rand::{rng, Rng as _, RngCore};
8
9#[derive(Clone)]
10pub enum BackoffRng {
11    /// A lazily-initialized, thread-local CSPRNG seeded by the operating system.
12    ///
13    /// Provided by [`rand::ThreadRng`][rand_threadrng].
14    ///
15    /// [rand_threadrng]: https://docs.rs/rand/latest/rand/rngs/struct.ThreadRng.html
16    SecureDefault,
17
18    /// A shared random number generator.
19    Shared(Arc<Mutex<Box<dyn RngCore + Send + Sync>>>),
20}
21
22impl fmt::Debug for BackoffRng {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        match self {
25            BackoffRng::SecureDefault => f.debug_tuple("SecureDefault").finish(),
26            BackoffRng::Shared(_) => f.debug_tuple("Shared").finish(),
27        }
28    }
29}
30
31impl RngCore for BackoffRng {
32    fn next_u32(&mut self) -> u32 {
33        match self {
34            BackoffRng::SecureDefault => rng().next_u32(),
35            BackoffRng::Shared(rng) => rng.lock().unwrap().next_u32(),
36        }
37    }
38
39    fn next_u64(&mut self) -> u64 {
40        match self {
41            BackoffRng::SecureDefault => rng().next_u64(),
42            BackoffRng::Shared(rng) => rng.lock().unwrap().next_u64(),
43        }
44    }
45
46    fn fill_bytes(&mut self, dest: &mut [u8]) {
47        match self {
48            BackoffRng::SecureDefault => rng().fill_bytes(dest),
49            BackoffRng::Shared(rng) => rng.lock().unwrap().fill_bytes(dest),
50        }
51    }
52}
53
54/// An exponential backoff strategy.
55///
56/// This backoff strategy provides backoff durations that increase exponentially based on a user-provided error count,
57/// with a minimum and maximum bound on the duration. Additionally, jitter can be added to the backoff duration in order
58/// to help avoiding multiple callers retrying their requests at the same time.
59#[derive(Clone, Debug)]
60pub struct ExponentialBackoff {
61    min_backoff: Duration,
62    max_backoff: Duration,
63    min_backoff_factor: f64,
64    rng: BackoffRng,
65}
66
67impl ExponentialBackoff {
68    /// Creates a new `ExponentialBackoff` with the given minimum and maximum backoff durations.
69    ///
70    /// Jitter is not applied to the calculated backoff durations.
71    pub fn new(min_backoff: Duration, max_backoff: Duration) -> Self {
72        Self {
73            min_backoff,
74            max_backoff,
75            min_backoff_factor: 1.0,
76            rng: BackoffRng::SecureDefault,
77        }
78    }
79
80    /// Creates a new `ExponentialBackoff` with the given minimum and maximum backoff durations, and minimum backoff
81    /// factor.
82    ///
83    /// Jitter is applied to the calculated backoff durations based on the minimum backoff factor, such that any given
84    /// backoff duration will be between `D/min_backoff_factor` and `D`, where `D` is the calculated backoff duration
85    /// for the given external error count. If the minimum backoff factor is set to 1.0 or less, then jitter will be
86    /// disabled.
87    ///
88    /// Concretely, this means that with a minimum backoff duration of 10ms, and a minimum backoff factor of 2.0, the
89    /// duration for an error count of one would be 20ms without jitter, but anywhere between 10ms and 20ms with jitter.
90    /// For an error count of two, it be 40ms without jitter, but anywhere between 20ms and 40ms with jitter.
91    pub fn with_jitter(min_backoff: Duration, max_backoff: Duration, min_backoff_factor: f64) -> Self {
92        Self {
93            min_backoff,
94            max_backoff,
95            min_backoff_factor: min_backoff_factor.max(1.0),
96            rng: BackoffRng::SecureDefault,
97        }
98    }
99
100    /// Sets the random number generator to use for calculating jittered backoff durations.
101    ///
102    /// Useful for testing purposes, where the RNG must be overridden to add determinism. The RNG is shared atomically
103    /// behind a mutex, allowing it to be cloned, so care should be taken to never use this outside of tests.
104    ///
105    /// Defaults to a lazily-initialized, thread-local CSPRNG seeded by the operating system.
106    pub fn with_rng<R>(self, rng: R) -> Self
107    where
108        R: RngCore + Send + Sync + 'static,
109    {
110        ExponentialBackoff {
111            min_backoff: self.min_backoff,
112            max_backoff: self.max_backoff,
113            min_backoff_factor: self.min_backoff_factor,
114            rng: BackoffRng::Shared(Arc::new(Mutex::new(Box::new(rng)))),
115        }
116    }
117
118    /// Calculates the backoff duration for the given error count.
119    ///
120    /// The error count value is generally user-defined, but should constitute the number of consecutive errors, or
121    /// attempts, that have been made when retrying an operation or request.
122    pub fn get_backoff_duration(&mut self, error_count: u32) -> Duration {
123        if error_count == 0 {
124            return self.min_backoff;
125        }
126
127        let mut backoff = self.min_backoff.saturating_mul(2u32.saturating_pow(error_count));
128
129        // Apply jitter if necessary.
130        if self.min_backoff_factor > 1.0 {
131            let backoff_lower = backoff.div_f64(self.min_backoff_factor);
132            let backoff_upper = backoff;
133            backoff = self.rng.random_range(backoff_lower..=backoff_upper)
134        }
135
136        backoff.clamp(self.min_backoff, self.max_backoff)
137    }
138}
139
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use proptest::prelude::*;

    use crate::net::util::retry::ExponentialBackoff;

    /// Generates an `ExponentialBackoff` with arbitrary, non-zero bounds where `max_backoff >= min_backoff`.
    ///
    /// The default OS-seeded RNG is replaced with one supplied by proptest.
    fn arb_exponential_backoff(min_backoff_factor: f64) -> impl Strategy<Value = ExponentialBackoff> {
        (1u64..=u64::MAX, 1u64..u64::MAX)
            .prop_map(move |(min_backoff, max_backoff)| {
                // Treat the second value as an offset from the minimum so the maximum is never below it.
                let max_backoff = min_backoff.saturating_add(max_backoff);
                ExponentialBackoff::with_jitter(
                    Duration::from_nanos(min_backoff),
                    Duration::from_nanos(max_backoff),
                    min_backoff_factor,
                )
            })
            .prop_perturb(|backoff, rng| backoff.with_rng(rng))
    }

    proptest! {
        #[test]
        fn property_test_exponential_backoff_no_jitter(
            mut backoff in arb_exponential_backoff(1.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            // The goal of this test is to show that for some arbitrary error count, the calculated backoff duration we
            // get is always less than or equal to the calculated backoff duration for an error count that is _larger_.
            // Without jitter, repeated calls with the same error count must also return identical durations.
            let larger_error_count = error_count.saturating_add(error_count_increase);

            let first = backoff.get_backoff_duration(error_count);
            let first_followup = backoff.get_backoff_duration(error_count);
            let second = backoff.get_backoff_duration(larger_error_count);
            let second_followup = backoff.get_backoff_duration(larger_error_count);

            prop_assert_eq!(first, first_followup);
            prop_assert_eq!(second, second_followup);
            prop_assert!(first <= second);
            prop_assert!(first >= backoff.min_backoff);
            prop_assert!(first <= backoff.max_backoff);
            prop_assert!(second >= backoff.min_backoff);
            prop_assert!(second <= backoff.max_backoff);
        }

        #[test]
        fn property_test_exponential_backoff_default_jitter(
            mut backoff in arb_exponential_backoff(2.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            // The goal of this test is to show that for some arbitrary error count, the calculated backoff duration we
            // get is always less than or equal to the calculated backoff duration for an error count that is _larger_.
            // With a factor of 2.0, the jittered range for a larger error count never dips below the unjittered value
            // for a smaller one, so monotonicity must still hold.
            let larger_error_count = error_count.saturating_add(error_count_increase);

            let first = backoff.get_backoff_duration(error_count);
            let second = backoff.get_backoff_duration(larger_error_count);

            prop_assert!(first <= second);
            prop_assert!(first >= backoff.min_backoff);
            prop_assert!(first <= backoff.max_backoff);
            prop_assert!(second >= backoff.min_backoff);
            prop_assert!(second <= backoff.max_backoff);
        }
    }
}