Skip to main content

saluki_io/net/util/retry/
backoff.rs

1use std::{
2    convert::Infallible,
3    fmt,
4    sync::{Arc, Mutex},
5    time::Duration,
6};
7
8use rand::{rand_core::TryRng, rng, Rng as _, RngExt as _};
9
/// Source of randomness used when applying jitter to backoff durations.
///
/// Implements `TryRng` with an `Infallible` error type, so it can be used anywhere an infallible RNG is expected.
#[derive(Clone)]
pub enum BackoffRng {
    /// A lazily initialized, thread-local CSPRNG seeded by the operating system.
    ///
    /// Provided by [`rand::ThreadRng`][rand_threadrng].
    ///
    /// [rand_threadrng]: https://docs.rs/rand/latest/rand/rngs/struct.ThreadRng.html
    SecureDefault,

    /// A shared random number generator.
    ///
    /// The generator is boxed and held behind an `Arc<Mutex<..>>`, so cloning this variant clones the `Arc`: every
    /// clone locks and advances the same underlying generator.
    Shared(Arc<Mutex<Box<dyn TryRng<Error = Infallible> + Send + Sync>>>),
}
22
23impl fmt::Debug for BackoffRng {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            BackoffRng::SecureDefault => f.debug_tuple("SecureDefault").finish(),
27            BackoffRng::Shared(_) => f.debug_tuple("Shared").finish(),
28        }
29    }
30}
31
32impl TryRng for BackoffRng {
33    type Error = Infallible;
34
35    fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
36        Ok(match self {
37            BackoffRng::SecureDefault => rng().next_u32(),
38            BackoffRng::Shared(rng) => rng.lock().unwrap().next_u32(),
39        })
40    }
41
42    fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
43        Ok(match self {
44            BackoffRng::SecureDefault => rng().next_u64(),
45            BackoffRng::Shared(rng) => rng.lock().unwrap().next_u64(),
46        })
47    }
48
49    fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
50        match self {
51            BackoffRng::SecureDefault => rng().fill_bytes(dst),
52            BackoffRng::Shared(rng) => rng.lock().unwrap().fill_bytes(dst),
53        }
54        Ok(())
55    }
56}
57
/// An exponential backoff strategy.
///
/// This backoff strategy provides backoff durations that increase exponentially based on a user-provided error count,
/// with a minimum and maximum bound on the duration. Additionally, jitter can be added to the backoff duration in order
/// to help avoid multiple callers retrying their requests at the same time.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Lower bound for returned durations, and the base value that is scaled up by the error count.
    min_backoff: Duration,
    /// Upper bound used when clamping calculated backoff durations.
    max_backoff: Duration,
    /// Divisor that sets the lower edge of the jitter range; jitter is only applied when this is greater than 1.0.
    min_backoff_factor: f64,
    /// Source of randomness used to pick a jittered duration.
    rng: BackoffRng,
}
70
71impl ExponentialBackoff {
72    /// Creates a new `ExponentialBackoff` with the given minimum and maximum backoff durations.
73    ///
74    /// Jitter is not applied to the calculated backoff durations.
75    pub fn new(min_backoff: Duration, max_backoff: Duration) -> Self {
76        Self {
77            min_backoff,
78            max_backoff,
79            min_backoff_factor: 1.0,
80            rng: BackoffRng::SecureDefault,
81        }
82    }
83
84    /// Creates a new `ExponentialBackoff` with the given minimum and maximum backoff durations, and minimum backoff
85    /// factor.
86    ///
87    /// Jitter is applied to the calculated backoff durations based on the minimum backoff factor, such that any given
88    /// backoff duration will be between `D/min_backoff_factor` and `D`, where `D` is the calculated backoff duration
89    /// for the given external error count. If the minimum backoff factor is set to 1.0 or less, then jitter will be
90    /// disabled.
91    ///
92    /// Concretely, this means that with a minimum backoff duration of 10ms, and a minimum backoff factor of 2.0, the
93    /// duration for an error count of one would be 20ms without jitter, but anywhere between 10ms and 20ms with jitter.
94    /// For an error count of two, it be 40ms without jitter, but anywhere between 20ms and 40ms with jitter.
95    pub fn with_jitter(min_backoff: Duration, max_backoff: Duration, min_backoff_factor: f64) -> Self {
96        Self {
97            min_backoff,
98            max_backoff,
99            min_backoff_factor: min_backoff_factor.max(1.0),
100            rng: BackoffRng::SecureDefault,
101        }
102    }
103
104    /// Sets the random number generator to use for calculating jittered backoff durations.
105    ///
106    /// Useful for testing purposes, where the RNG must be overridden to add determinism. The RNG is shared atomically
107    /// behind a mutex, allowing it to be cloned, so care should be taken to never use this outside of tests.
108    ///
109    /// Defaults to a lazily initialized, thread-local CSPRNG seeded by the operating system.
110    pub fn with_rng<R>(self, rng: R) -> Self
111    where
112        R: TryRng<Error = Infallible> + Send + Sync + 'static,
113    {
114        ExponentialBackoff {
115            min_backoff: self.min_backoff,
116            max_backoff: self.max_backoff,
117            min_backoff_factor: self.min_backoff_factor,
118            rng: BackoffRng::Shared(Arc::new(Mutex::new(Box::new(rng)))),
119        }
120    }
121
122    /// Calculates the backoff duration for the given error count.
123    ///
124    /// The error count value is generally user-defined, but should constitute the number of consecutive errors, or
125    /// attempts, that have been made when retrying an operation or request.
126    pub fn get_backoff_duration(&mut self, error_count: u32) -> Duration {
127        if error_count == 0 {
128            return self.min_backoff;
129        }
130
131        let mut backoff = self.min_backoff.saturating_mul(2u32.saturating_pow(error_count));
132
133        // Apply jitter if necessary.
134        if self.min_backoff_factor > 1.0 {
135            let backoff_lower = backoff.div_f64(self.min_backoff_factor);
136            let backoff_upper = backoff;
137            backoff = self.rng.random_range(backoff_lower..=backoff_upper)
138        }
139
140        backoff.clamp(self.min_backoff, self.max_backoff)
141    }
142}
143
#[cfg(test)]
mod tests {
    use std::{convert::Infallible, time::Duration};

    use proptest::prelude::*;
    use rand::rand_core::TryRng;

    use crate::net::util::retry::ExponentialBackoff;

    /// Adapter to bridge `proptest`'s `TestRng` (`rand` 0.9 `RngCore`) to `rand` 0.10's `TryRng`. Created by Claude as
    /// a workaround to compilation issues when updating our workspace version of `rand` to 0.10 while `proptest` was
    /// still using 0.9.
    // TODO: remove this when proptest updates to `rand` 0.10
    struct PropTestRng(proptest::test_runner::TestRng);

    // Each method forwards to the wrapped `TestRng` and wraps the result in `Ok`, since the error type is `Infallible`.
    impl TryRng for PropTestRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            use proptest::prelude::RngCore as _;
            Ok(self.0.next_u32())
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            use proptest::prelude::RngCore as _;
            Ok(self.0.next_u64())
        }

        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
            use proptest::prelude::RngCore as _;
            self.0.fill_bytes(dst);
            Ok(())
        }
    }

    /// Generates an `ExponentialBackoff` with arbitrary bounds and the given minimum backoff factor.
    ///
    /// The second generated value is added (saturating) to the minimum to form the maximum, so the maximum is never
    /// below the minimum. The backoff's RNG is replaced with the one `proptest` hands to `prop_perturb`, so jitter
    /// draws come from `proptest`'s RNG rather than the thread-local default.
    fn arb_exponential_backoff(min_backoff_factor: f64) -> impl Strategy<Value = ExponentialBackoff> {
        (1u64..=u64::MAX, 1u64..u64::MAX)
            .prop_map(move |(min_backoff, max_backoff)| {
                let max_backoff = min_backoff.saturating_add(max_backoff);
                ExponentialBackoff::with_jitter(
                    Duration::from_nanos(min_backoff),
                    Duration::from_nanos(max_backoff),
                    min_backoff_factor,
                )
            })
            .prop_perturb(|backoff, rng| backoff.with_rng(PropTestRng(rng)))
    }

    proptest! {
        #[test]
        fn property_test_exponential_backoff_no_jitter(
            mut backoff in arb_exponential_backoff(1.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            // The goal of this test is to show that for some arbitrary error count, the calculated backoff duration we
            // get is always less than or equal to the calculated backoff duration for an error count that is _larger_.
            //
            // With a factor of 1.0 jitter is disabled, so asking twice for the same error count must also return the
            // same duration.
            let first = backoff.get_backoff_duration(error_count);
            let first_followup = backoff.get_backoff_duration(error_count);
            let second = backoff.get_backoff_duration(error_count.saturating_add(error_count_increase));
            let second_followup = backoff.get_backoff_duration(error_count.saturating_add(error_count_increase));

            assert_eq!(first, first_followup);
            assert_eq!(second, second_followup);
            assert!(first <= second);
            assert!(first >= backoff.min_backoff);
            assert!(first <= backoff.max_backoff);
            assert!(second >= backoff.min_backoff);
            assert!(second <= backoff.max_backoff);
        }

        #[test]
        fn property_test_exponential_backoff_default_jitter(
            mut backoff in arb_exponential_backoff(2.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            // The goal of this test is to show that for some arbitrary error count, the calculated backoff duration we
            // get is always less than or equal to the calculated backoff duration for an error count that is _larger_.
            //
            // NOTE(review): for error counts of 32 or more, `2u32.saturating_pow` saturates, so both calls draw from
            // the same jitter range and `first <= second` then appears to rely on clamping to `max_backoff` -- confirm
            // whether the error count range should be narrowed to exercise the non-saturated region.
            let first = backoff.get_backoff_duration(error_count);
            let second = backoff.get_backoff_duration(error_count.saturating_add(error_count_increase));

            assert!(first <= second);
            assert!(first >= backoff.min_backoff);
            assert!(first <= backoff.max_backoff);
            assert!(second >= backoff.min_backoff);
            assert!(second <= backoff.max_backoff);
        }
    }
}