saluki_io/net/util/retry/
backoff.rs1use std::{
2 fmt,
3 sync::{Arc, Mutex},
4 time::Duration,
5};
6
7use rand::{rng, Rng as _, RngCore};
8
9#[derive(Clone)]
10pub enum BackoffRng {
11 SecureDefault,
17
18 Shared(Arc<Mutex<Box<dyn RngCore + Send + Sync>>>),
20}
21
22impl fmt::Debug for BackoffRng {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 match self {
25 BackoffRng::SecureDefault => f.debug_tuple("SecureDefault").finish(),
26 BackoffRng::Shared(_) => f.debug_tuple("Shared").finish(),
27 }
28 }
29}
30
31impl RngCore for BackoffRng {
32 fn next_u32(&mut self) -> u32 {
33 match self {
34 BackoffRng::SecureDefault => rng().next_u32(),
35 BackoffRng::Shared(rng) => rng.lock().unwrap().next_u32(),
36 }
37 }
38
39 fn next_u64(&mut self) -> u64 {
40 match self {
41 BackoffRng::SecureDefault => rng().next_u64(),
42 BackoffRng::Shared(rng) => rng.lock().unwrap().next_u64(),
43 }
44 }
45
46 fn fill_bytes(&mut self, dest: &mut [u8]) {
47 match self {
48 BackoffRng::SecureDefault => rng().fill_bytes(dest),
49 BackoffRng::Shared(rng) => rng.lock().unwrap().fill_bytes(dest),
50 }
51 }
52}
53
54#[derive(Clone, Debug)]
60pub struct ExponentialBackoff {
61 min_backoff: Duration,
62 max_backoff: Duration,
63 min_backoff_factor: f64,
64 rng: BackoffRng,
65}
66
67impl ExponentialBackoff {
68 pub fn new(min_backoff: Duration, max_backoff: Duration) -> Self {
72 Self {
73 min_backoff,
74 max_backoff,
75 min_backoff_factor: 1.0,
76 rng: BackoffRng::SecureDefault,
77 }
78 }
79
80 pub fn with_jitter(min_backoff: Duration, max_backoff: Duration, min_backoff_factor: f64) -> Self {
92 Self {
93 min_backoff,
94 max_backoff,
95 min_backoff_factor: min_backoff_factor.max(1.0),
96 rng: BackoffRng::SecureDefault,
97 }
98 }
99
100 pub fn with_rng<R>(self, rng: R) -> Self
107 where
108 R: RngCore + Send + Sync + 'static,
109 {
110 ExponentialBackoff {
111 min_backoff: self.min_backoff,
112 max_backoff: self.max_backoff,
113 min_backoff_factor: self.min_backoff_factor,
114 rng: BackoffRng::Shared(Arc::new(Mutex::new(Box::new(rng)))),
115 }
116 }
117
118 pub fn get_backoff_duration(&mut self, error_count: u32) -> Duration {
123 if error_count == 0 {
124 return self.min_backoff;
125 }
126
127 let mut backoff = self.min_backoff.saturating_mul(2u32.saturating_pow(error_count));
128
129 if self.min_backoff_factor > 1.0 {
131 let backoff_lower = backoff.div_f64(self.min_backoff_factor);
132 let backoff_upper = backoff;
133 backoff = self.rng.random_range(backoff_lower..=backoff_upper)
134 }
135
136 backoff.clamp(self.min_backoff, self.max_backoff)
137 }
138}
139
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use proptest::prelude::*;

    use crate::net::util::retry::ExponentialBackoff;

    /// Strategy yielding backoff calculators with arbitrary bounds and the given jitter factor.
    ///
    /// The maximum is built as `min + extra` (saturating), so it is never below the minimum. The calculator's
    /// randomness comes from the RNG proptest hands to `prop_perturb`.
    fn arb_exponential_backoff(min_backoff_factor: f64) -> impl Strategy<Value = ExponentialBackoff> {
        let bounds = (1u64..=u64::MAX, 1u64..u64::MAX);

        bounds
            .prop_map(move |(min_nanos, extra_nanos)| {
                let max_nanos = min_nanos.saturating_add(extra_nanos);
                ExponentialBackoff::with_jitter(
                    Duration::from_nanos(min_nanos),
                    Duration::from_nanos(max_nanos),
                    min_backoff_factor,
                )
            })
            .prop_perturb(|backoff, rng| backoff.with_rng(rng))
    }

    proptest! {
        // Without jitter, results must be repeatable, non-decreasing in the error count, and within bounds.
        #[test]
        fn property_test_exponential_backoff_no_jitter(
            mut backoff in arb_exponential_backoff(1.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            let later_count = error_count.saturating_add(error_count_increase);

            let earlier = backoff.get_backoff_duration(error_count);
            let earlier_again = backoff.get_backoff_duration(error_count);
            let later = backoff.get_backoff_duration(later_count);
            let later_again = backoff.get_backoff_duration(later_count);

            assert_eq!(earlier, earlier_again);
            assert_eq!(later, later_again);
            assert!(earlier <= later);
            assert!(earlier >= backoff.min_backoff);
            assert!(earlier <= backoff.max_backoff);
            assert!(later >= backoff.min_backoff);
            assert!(later <= backoff.max_backoff);
        }

        // With a jitter factor of 2.0, results must still be non-decreasing in the error count and within bounds.
        #[test]
        fn property_test_exponential_backoff_default_jitter(
            mut backoff in arb_exponential_backoff(2.0),
            error_count in 0..u32::MAX,
            error_count_increase in 1..5u32
        ) {
            let later_count = error_count.saturating_add(error_count_increase);

            let earlier = backoff.get_backoff_duration(error_count);
            let later = backoff.get_backoff_duration(later_count);

            assert!(earlier <= later);
            assert!(earlier >= backoff.min_backoff);
            assert!(earlier <= backoff.max_backoff);
            assert!(later >= backoff.min_backoff);
            assert!(later <= backoff.max_backoff);
        }
    }
}