saluki_io/net/util/retry/
backoff.rs1use std::{
2 convert::Infallible,
3 fmt,
4 sync::{Arc, Mutex},
5 time::Duration,
6};
7
8use rand::{rand_core::TryRng, rng, Rng as _, RngExt as _};
9
/// Source of randomness used when applying jitter to backoff calculations.
#[derive(Clone)]
pub enum BackoffRng {
    /// The default RNG, acquired fresh per call via `rand::rng()`.
    SecureDefault,

    /// A caller-supplied RNG shared behind a mutex, so the backoff value
    /// (and its clones) can draw from one deterministic source — primarily
    /// useful for testing. The error type is `Infallible`, so draws from it
    /// cannot fail.
    Shared(Arc<Mutex<Box<dyn TryRng<Error = Infallible> + Send + Sync>>>),
}
22
23impl fmt::Debug for BackoffRng {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 match self {
26 BackoffRng::SecureDefault => f.debug_tuple("SecureDefault").finish(),
27 BackoffRng::Shared(_) => f.debug_tuple("Shared").finish(),
28 }
29 }
30}
31
32impl TryRng for BackoffRng {
33 type Error = Infallible;
34
35 fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
36 Ok(match self {
37 BackoffRng::SecureDefault => rng().next_u32(),
38 BackoffRng::Shared(rng) => rng.lock().unwrap().next_u32(),
39 })
40 }
41
42 fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
43 Ok(match self {
44 BackoffRng::SecureDefault => rng().next_u64(),
45 BackoffRng::Shared(rng) => rng.lock().unwrap().next_u64(),
46 })
47 }
48
49 fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
50 match self {
51 BackoffRng::SecureDefault => rng().fill_bytes(dst),
52 BackoffRng::Shared(rng) => rng.lock().unwrap().fill_bytes(dst),
53 }
54 Ok(())
55 }
56}
57
/// An exponential backoff policy with optional jitter.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    // Base duration that is doubled per error, and the lower clamp bound of
    // any returned backoff.
    min_backoff: Duration,
    // Upper clamp bound of any returned backoff.
    max_backoff: Duration,
    // Jitter factor: when > 1.0, the backoff is sampled uniformly from
    // `[backoff / factor, backoff]`. Constructors keep this >= 1.0.
    min_backoff_factor: f64,
    // Randomness source used when jitter is enabled.
    rng: BackoffRng,
}
70
71impl ExponentialBackoff {
72 pub fn new(min_backoff: Duration, max_backoff: Duration) -> Self {
76 Self {
77 min_backoff,
78 max_backoff,
79 min_backoff_factor: 1.0,
80 rng: BackoffRng::SecureDefault,
81 }
82 }
83
84 pub fn with_jitter(min_backoff: Duration, max_backoff: Duration, min_backoff_factor: f64) -> Self {
96 Self {
97 min_backoff,
98 max_backoff,
99 min_backoff_factor: min_backoff_factor.max(1.0),
100 rng: BackoffRng::SecureDefault,
101 }
102 }
103
104 pub fn with_rng<R>(self, rng: R) -> Self
111 where
112 R: TryRng<Error = Infallible> + Send + Sync + 'static,
113 {
114 ExponentialBackoff {
115 min_backoff: self.min_backoff,
116 max_backoff: self.max_backoff,
117 min_backoff_factor: self.min_backoff_factor,
118 rng: BackoffRng::Shared(Arc::new(Mutex::new(Box::new(rng)))),
119 }
120 }
121
122 pub fn get_backoff_duration(&mut self, error_count: u32) -> Duration {
127 if error_count == 0 {
128 return self.min_backoff;
129 }
130
131 let mut backoff = self.min_backoff.saturating_mul(2u32.saturating_pow(error_count));
132
133 if self.min_backoff_factor > 1.0 {
135 let backoff_lower = backoff.div_f64(self.min_backoff_factor);
136 let backoff_upper = backoff;
137 backoff = self.rng.random_range(backoff_lower..=backoff_upper)
138 }
139
140 backoff.clamp(self.min_backoff, self.max_backoff)
141 }
142}
143
#[cfg(test)]
mod tests {
    use std::{convert::Infallible, time::Duration};

    use proptest::prelude::*;
    use rand::rand_core::TryRng;

    use crate::net::util::retry::ExponentialBackoff;

    // Adapter exposing proptest's deterministic `TestRng` through the
    // `TryRng` interface required by `ExponentialBackoff::with_rng`, so the
    // jitter source is reproducible under proptest's seed management.
    struct PropTestRng(proptest::test_runner::TestRng);

    impl TryRng for PropTestRng {
        type Error = Infallible;

        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            use proptest::prelude::RngCore as _;
            Ok(self.0.next_u32())
        }

        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            use proptest::prelude::RngCore as _;
            Ok(self.0.next_u64())
        }

        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
            use proptest::prelude::RngCore as _;
            self.0.fill_bytes(dst);
            Ok(())
        }
    }

    // Strategy producing an `ExponentialBackoff` with arbitrary bounds
    // (nanosecond granularity) and the given jitter factor. The saturating
    // add guarantees max_backoff >= min_backoff, and `prop_perturb` wires in
    // proptest's RNG so jittered results are deterministic per test case.
    fn arb_exponential_backoff(min_backoff_factor: f64) -> impl Strategy<Value = ExponentialBackoff> {
        (1u64..=u64::MAX, 1u64..u64::MAX)
            .prop_map(move |(min_backoff, max_backoff)| {
                let max_backoff = min_backoff.saturating_add(max_backoff);
                ExponentialBackoff::with_jitter(
                    Duration::from_nanos(min_backoff),
                    Duration::from_nanos(max_backoff),
                    min_backoff_factor,
                )
            })
            .prop_perturb(|backoff, rng| backoff.with_rng(PropTestRng(rng)))
    }

    proptest! {
        // Without jitter (factor 1.0), backoff must be: deterministic for a
        // given error count, non-decreasing as the error count grows, and
        // always within [min_backoff, max_backoff].
        #[test]
        fn property_test_exponential_backoff_no_jitter(
            mut backoff in arb_exponential_backoff(1.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            let first = backoff.get_backoff_duration(error_count);
            let first_followup = backoff.get_backoff_duration(error_count);
            let second = backoff.get_backoff_duration(error_count.saturating_add(error_count_increase));
            let second_followup = backoff.get_backoff_duration(error_count.saturating_add(error_count_increase));

            assert_eq!(first, first_followup);
            assert_eq!(second, second_followup);
            assert!(first <= second);
            assert!(first >= backoff.min_backoff);
            assert!(first <= backoff.max_backoff);
            assert!(second >= backoff.min_backoff);
            assert!(second <= backoff.max_backoff);
        }

        // With a jitter factor of 2.0, exact values vary but ordering still
        // holds: a higher error count's jitter window starts at half its
        // (at least doubled) un-jittered duration, so it cannot fall below
        // the lower count's upper bound. Bounds must also always hold.
        // Repeat-call determinism is intentionally not asserted here, since
        // each call consumes randomness.
        #[test]
        fn property_test_exponential_backoff_default_jitter(
            mut backoff in arb_exponential_backoff(2.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            let first = backoff.get_backoff_duration(error_count);
            let second = backoff.get_backoff_duration(error_count.saturating_add(error_count_increase));

            assert!(first <= second);
            assert!(first >= backoff.min_backoff);
            assert!(first <= backoff.max_backoff);
            assert!(second >= backoff.min_backoff);
            assert!(second <= backoff.max_backoff);
        }
    }
}