saluki_io/net/util/middleware/retry_circuit_breaker.rs

use std::{
    fmt,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{ready, Context, Poll},
};

use futures::FutureExt as _;
use pin_project_lite::pin_project;
use tower::{retry::Policy, Layer, Service};
use tracing::debug;

/// An error from [`RetryCircuitBreaker`].
#[derive(Debug)]
pub enum Error<E, R> {
    /// The inner service responded with an error.
    Service(E),

    /// The circuit breaker is open and requests are being rejected.
    Open(R),
}

impl<E, R> std::error::Error for Error<E, R>
where
    E: std::error::Error + 'static,
    R: fmt::Debug,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Service(e) => Some(e),
            Self::Open(_) => None,
        }
    }
}

impl<E, R> fmt::Display for Error<E, R>
where
    E: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Service(e) => write!(f, "service error: {}", e),
            Self::Open(_) => write!(f, "circuit breaker open"),
        }
    }
}

impl<E, R> PartialEq for Error<E, R>
where
    E: PartialEq,
    R: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Service(a), Self::Service(b)) => a == b,
            (Self::Open(a), Self::Open(b)) => a == b,
            _ => false,
        }
    }
}

pin_project! {
    /// Response future for [`RetryCircuitBreaker`].
    pub struct ResponseFuture<P, F, Request> {
        state: Arc<Mutex<State<P>>>,
        #[pin]
        inner: Option<F>,
        req: Option<Request>,
    }
}

impl<P, F, T, E, Request> Future for ResponseFuture<P, F, Request>
where
    P: Policy<Request, T, E>,
    P::Future: Send + 'static,
    F: Future<Output = Result<T, E>>,
{
    type Output = Result<T, Error<E, Request>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Our response future exists in one of two states: the circuit breaker was either closed or open when we created it.
        //
        // When the circuit breaker was open at creation time, there's no inner response future to poll. We simply store
        // the original request and pass it back while indicating to the caller that the circuit breaker is open. Simple.
        //
        // When the circuit breaker was closed at creation time, we can proceed, and we generate a legitimate response
        // future to poll. However, the retry policy may return `None` when trying to clone the request, which indicates
        // the request isn't eligible to be retried at all. Thus, when we don't have an original request here, we just
        // return the inner service's response as-is. When we _do_ have the original request, we utilize the retry policy
        // to determine if it can be retried, and if so, potentially update our circuit breaker state based on what the
        // retry policy tells us.

        let this = self.project();
        if let Some(inner) = this.inner.as_pin_mut() {
            let mut result = ready!(inner.poll(cx));

            let mut state = this.state.lock().unwrap();
            match this.req.take() {
                Some(mut req) => match state.policy.retry(&mut req, &mut result) {
                    Some(backoff) => {
                        // The policy has indicated that the request should be retried, so we need to open the circuit
                        // breaker by setting the backoff future to use. Another request's retry decision may have
                        // already beaten us to the punch, though, so don't overwrite it if it's already set.
                        if state.backoff.is_none() {
                            debug!("no existing backoff future present, setting delay backoff");
                            state.backoff = Some(backoff.boxed());
                        }

                        Poll::Ready(Err(Error::Open(req)))
                    }
                    None => {
                        debug!("request completed, no retry indicated");
                        Poll::Ready(result.map_err(Error::Service))
                    }
                },
                None => {
                    debug!("request completed, but request not cloneable so returning response as-is");
                    Poll::Ready(result.map_err(Error::Service))
                }
            }
        } else {
            debug!("circuit breaker open prior to call, returning error");
            Poll::Ready(Err(Error::Open(
                this.req.take().expect("response future polled after completion"),
            )))
        }
    }
}

struct State<P> {
    policy: P,
    backoff: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
}

impl<P> std::fmt::Debug for State<P>
where
    P: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let backoff = if self.backoff.is_some() { "set" } else { "unset" };
        f.debug_struct("State")
            .field("policy", &self.policy)
            .field("backoff", &backoff)
            .finish()
    }
}

/// Wraps a service in a [circuit breaker][circuit_breaker] and signals when a request must be retried at a later time.
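///
/// # Example
///
/// A minimal sketch of composing this layer via `tower::ServiceBuilder`. `MyPolicy` and `MyService` are hypothetical
/// placeholders for a concrete retry policy and inner service:
///
/// ```ignore
/// use tower::ServiceBuilder;
///
/// // `MyPolicy` implements `tower::retry::Policy`; `MyService` is the service being wrapped.
/// let svc = ServiceBuilder::new()
///     .layer(RetryCircuitBreakerLayer::new(MyPolicy::default()))
///     .service(MyService::default());
/// ```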
///
/// [circuit_breaker]: https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern
pub struct RetryCircuitBreakerLayer<P> {
    policy: P,
}

impl<P> RetryCircuitBreakerLayer<P> {
    /// Creates a new [`RetryCircuitBreakerLayer`] with the given policy.
    pub const fn new(policy: P) -> Self {
        Self { policy }
    }
}

impl<P, S> Layer<S> for RetryCircuitBreakerLayer<P>
where
    P: Clone,
{
    type Service = RetryCircuitBreaker<S, P>;

    fn layer(&self, inner: S) -> Self::Service {
        RetryCircuitBreaker::new(inner, self.policy.clone())
    }
}

/// Wraps a service in a [circuit breaker][circuit_breaker] and signals when a request must be retried at a later time.
///
/// This circuit breaker implementation is specific to retrying requests. In many cases, a request can fail in one of
/// two ways: with an unrecoverable error, which should not be retried, or with a recoverable error, which should be
/// retried after some period of time. When a request can be retried, it may not be advantageous to wait for that
/// request to be retried successfully; instead, the request may be better off stored in a queue and retried at a
/// later time, potentially to avoid applying backpressure to the client.
///
/// [`RetryCircuitBreaker`] provides this capability by separating the logic of determining whether or not a request
/// should be retried from actually performing the retry itself. When a request leads to an unrecoverable error,
/// that error is immediately passed back to the caller without affecting the circuit breaker state. However, when a
/// recoverable error is encountered, the circuit breaker signals to the caller that the request should be retried,
/// and updates its internal state to open the circuit breaker for a configurable period of time. Further requests
/// to the circuit breaker are rejected with an error (indicating the open state) until that period of time has
/// passed.
///
/// [circuit_breaker]: https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern
#[derive(Debug)]
pub struct RetryCircuitBreaker<S, P> {
    inner: S,
    state: Arc<Mutex<State<P>>>,
}

impl<S, P> RetryCircuitBreaker<S, P> {
    /// Creates a new [`RetryCircuitBreaker`].
    pub fn new(inner: S, policy: P) -> Self {
        Self {
            inner,
            state: Arc::new(Mutex::new(State { policy, backoff: None })),
        }
    }
}

impl<S, P, Request> Service<Request> for RetryCircuitBreaker<S, P>
where
    S: Service<Request>,
    P: Policy<Request, S::Response, S::Error>,
    P::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = Error<S::Error, Request>;
    type Future = ResponseFuture<P, S::Future, Request>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        {
            // Check if we're currently in a backoff state.
            let mut state = self.state.lock().unwrap();
            if let Some(backoff) = state.backoff.as_mut() {
                ready!(backoff.as_mut().poll(cx));

                debug!("circuit breaker backoff complete");

                // The backoff future has completed, so we can reset the circuit breaker state.
                state.backoff = None;
            }
        }

        // Check the readiness of the inner service.
        self.inner.poll_ready(cx).map_err(Error::Service)
    }

    fn call(&mut self, req: Request) -> Self::Future {
        let response_state = Arc::clone(&self.state);

        let mut state = self.state.lock().unwrap();
        if state.backoff.is_some() {
            ResponseFuture {
                state: response_state,
                inner: None,
                req: Some(req),
            }
        } else {
            // The circuit breaker is closed, so we can proceed with the request.
            let cloned_req = state.policy.clone_request(&req);
            let inner = self.inner.call(req);

            ResponseFuture {
                state: response_state,
                inner: Some(inner),
                req: cloned_req,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        future::{ready, Ready},
        time::Duration,
    };

    use tokio::time::Sleep;
    use tokio_test::{assert_pending, assert_ready_ok};
    use tower::{retry::Policy, ServiceExt as _};

    use super::*;

    const BACKOFF_DUR: Duration = Duration::from_secs(1);

    #[derive(Clone, Debug, Eq, PartialEq)]
    enum BasicRequest {
        Ok(String),
        Err(String),
    }

    impl BasicRequest {
        fn success<S: AsRef<str>>(value: S) -> Self {
            Self::Ok(value.as_ref().to_string())
        }

        fn failure<S: AsRef<str>>(value: S) -> Self {
            Self::Err(value.as_ref().to_string())
        }

        fn as_service_response(&self) -> Result<String, Error<String, Self>> {
            match self {
                Self::Ok(value) => Ok(value.clone()),
                Self::Err(value) => Err(Error::Service(value.clone())),
            }
        }

        fn as_open_response(&self) -> Result<String, Error<String, Self>> {
            Err(Error::Open(self.clone()))
        }
    }

    impl PartialEq<Result<String, String>> for BasicRequest {
        fn eq(&self, other: &Result<String, String>) -> bool {
            match self {
                Self::Ok(value) => other.as_ref() == Ok(value),
                Self::Err(value) => other.as_ref() == Err(value),
            }
        }
    }

    #[derive(Debug)]
    struct LoopbackService;

    impl Service<BasicRequest> for LoopbackService {
        type Response = String;
        type Error = String;
        type Future = Ready<Result<String, String>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: BasicRequest) -> Self::Future {
            let res = match req {
                BasicRequest::Ok(value) => Ok(value),
                BasicRequest::Err(value) => Err(value),
            };
            ready(res)
        }
    }

    #[derive(Debug)]
    struct CloneableTestRetryPolicy;

    impl<Req, T, E> Policy<Req, T, E> for CloneableTestRetryPolicy
    where
        Req: Clone,
    {
        type Future = Sleep;

        fn retry(&mut self, _: &mut Req, res: &mut Result<T, E>) -> Option<Self::Future> {
            match res {
                Ok(_) => None,
                Err(_) => Some(tokio::time::sleep(BACKOFF_DUR)),
            }
        }

        fn clone_request(&mut self, req: &Req) -> Option<Req> {
            Some(req.clone())
        }
    }

    #[derive(Debug)]
    struct NonCloneableTestRetryPolicy;

    impl<Req, T, E> Policy<Req, T, E> for NonCloneableTestRetryPolicy {
        type Future = Sleep;

        fn retry(&mut self, _: &mut Req, res: &mut Result<T, E>) -> Option<Self::Future> {
            match res {
                Ok(_) => None,
                Err(_) => Some(tokio::time::sleep(BACKOFF_DUR)),
            }
        }

        fn clone_request(&mut self, _: &Req) -> Option<Req> {
            None
        }
    }

    #[tokio::test(start_paused = true)]
    async fn basic() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, CloneableTestRetryPolicy);

        // First request should succeed.
        //
        // We should see that it called through to the inner service.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // Second request should fail and should be retried.
        //
        // We should see that it called through to the inner service.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(bad_req.clone());
        let result = fut.await;
        assert_eq!(result, bad_req.as_open_response());

        // When trying to make our third request, we should have to wait for the backoff duration before the service
        // indicates that it's ready for another call.
        let mut svc_fut = tokio_test::task::spawn(circuit_breaker.ready());
        assert_pending!(svc_fut.poll());

        // Advance time past the backoff duration, which should make our service ready.
        tokio::time::advance(BACKOFF_DUR + Duration::from_millis(1)).await;
        assert!(svc_fut.is_woken());
        let svc = assert_ready_ok!(svc_fut.poll());

        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // Fourth request should succeed unimpeded since the breaker is closed again.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());
    }

    #[tokio::test]
    async fn retry_policy_no_clone() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        // First request should succeed.
        //
        // We should see that it called through to the inner service.
        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, NonCloneableTestRetryPolicy);
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // Second request should fail with a service error, because without being able to clone the request, it
        // can't be retried anyway.
        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, NonCloneableTestRetryPolicy);
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(bad_req.clone());
        let result = fut.await;
        assert_eq!(result, bad_req.as_service_response());
    }

    #[tokio::test(start_paused = true)]
    async fn concurrent_calls_can_advance() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, CloneableTestRetryPolicy);

        // First request should succeed. This is just a warmup.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // Now we'll create two calls -- one that should fail and one that should succeed -- but we won't poll them until
        // both are created. This simulates two concurrent calls, and what we want to show is that the circuit breaker
        // should only mark itself as open to _new_ calls after the state changes to open, and should not affect
        // in-flight requests.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let bad_fut = svc.call(bad_req.clone());

        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let good_fut = svc.call(good_req.clone());

        let bad_result = bad_fut.await;
        assert_eq!(bad_result, bad_req.as_open_response());

        let good_result = good_fut.await;
        assert_eq!(good_result, good_req.as_service_response());

        // Now we'll go to make a fourth request, and we'll manually check the readiness of the service to ensure that
        // we're now in a backoff state.
        let mut svc_fut = tokio_test::task::spawn(circuit_breaker.ready());
        assert_pending!(svc_fut.poll());
    }

    #[tokio::test(start_paused = true)]
    async fn breaker_open_between_ready_and_call() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, CloneableTestRetryPolicy);

        // We'll create two calls -- one that should fail and one that should succeed -- but order their creation and
        // polling such that the bad request updates the breaker state to open _before_ we create the good request. This
        // exercises the fact that even though the service may report itself as ready, an in-flight request that
        // completes and ultimately changes the breaker state to open should cause subsequent calls to immediately fail.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let bad_fut = svc.call(bad_req.clone());

        // We're just making sure here that the service is ready to accept another call, but we're not creating that
        // call yet.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");

        // Run our bad request first and ensure it fails with an open error.
        let bad_result = bad_fut.await;
        assert_eq!(bad_result, bad_req.as_open_response());

        // Now _create_ the good request and ensure that it also fails with an open error.
        let good_fut = svc.call(good_req.clone());
        let good_result = good_fut.await;
        assert_eq!(good_result, good_req.as_open_response());
    }
}