1use std::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{ready, Context, Poll},
7};
8
9use futures::FutureExt as _;
10use pin_project_lite::pin_project;
11use tower::{retry::Policy, Layer, Service};
12use tracing::debug;
13
/// Error produced by [`RetryCircuitBreaker`] and its response futures.
///
/// `E` is the error type of the wrapped service and `R` is the request type. Whenever the
/// breaker rejects or fails a request, the request itself is handed back in [`Error::Open`].
#[derive(Debug)]
pub enum Error<E, R> {
    /// The wrapped service failed, either while becoming ready or while handling the request.
    Service(E),

    /// The circuit breaker is open (or the retry policy just opened it); carries the request
    /// back to the caller.
    Open(R),
}
23
24impl<E, R> std::error::Error for Error<E, R>
25where
26 E: std::error::Error + 'static,
27 R: fmt::Debug,
28{
29 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
30 match self {
31 Self::Service(e) => Some(e),
32 Self::Open(_) => None,
33 }
34 }
35}
36
37impl<E, R> fmt::Display for Error<E, R>
38where
39 E: fmt::Display,
40{
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 Self::Service(e) => write!(f, "service error: {}", e),
44 Self::Open(_) => write!(f, "circuit breaker open"),
45 }
46 }
47}
48
49impl<E, R> PartialEq for Error<E, R>
50where
51 E: PartialEq,
52 R: PartialEq,
53{
54 fn eq(&self, other: &Self) -> bool {
55 match (self, other) {
56 (Self::Service(a), Self::Service(b)) => a == b,
57 (Self::Open(a), Self::Open(b)) => a == b,
58 _ => false,
59 }
60 }
61}
62
pin_project! {
    /// Response future returned by [`RetryCircuitBreaker`].
    ///
    /// If the breaker was open when the request was made, there is no inner future and this
    /// resolves immediately to [`Error::Open`] carrying the request. Otherwise it drives the
    /// inner service's future and lets the retry policy decide whether to open the breaker.
    pub struct ResponseFuture<P, F, Request> {
        // Breaker state shared with the service that created this future.
        state: Arc<Mutex<State<P>>>,
        // The inner service's response future, or `None` if the breaker was open at call time.
        #[pin]
        inner: Option<F>,
        // Breaker open at call time: the original request. Otherwise: the policy's clone of the
        // request, or `None` if the policy could not clone it.
        req: Option<Request>,
    }
}
72
73impl<P, F, T, E, Request> Future for ResponseFuture<P, F, Request>
74where
75 P: Policy<Request, T, E>,
76 P::Future: Send + 'static,
77 F: Future<Output = Result<T, E>>,
78{
79 type Output = Result<T, Error<E, Request>>;
80
81 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82 let this = self.project();
96 if let Some(inner) = this.inner.as_pin_mut() {
97 let mut result = ready!(inner.poll(cx));
98
99 let mut state = this.state.lock().unwrap();
100 match this.req.take() {
101 Some(mut req) => match state.policy.retry(&mut req, &mut result) {
102 Some(backoff) => {
103 if state.backoff.is_none() {
107 debug!("no existing backoff future present, setting delay backoff");
108 state.backoff = Some(backoff.boxed());
109 }
110
111 Poll::Ready(Err(Error::Open(req)))
112 }
113 None => {
114 debug!("request completed, no retry indicated");
115 Poll::Ready(result.map_err(Error::Service))
116 }
117 },
118 None => {
119 debug!("request completed, but request not cloneable so returning response as-is");
120 Poll::Ready(result.map_err(Error::Service))
121 }
122 }
123 } else {
124 debug!("circuit breaker open prior to call, returning error");
125 Poll::Ready(Err(Error::Open(
126 this.req.take().expect("response future polled after completion"),
127 )))
128 }
129 }
130}
131
/// Type-erased backoff delay produced by the retry policy.
type BackoffFuture = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Mutable breaker state shared between a [`RetryCircuitBreaker`] and its response futures.
struct State<P> {
    /// Retry policy consulted for request cloning and retry decisions.
    policy: P,
    /// Pending backoff delay; while this is `Some`, the breaker is open.
    backoff: Option<BackoffFuture>,
}
136
137impl<P> std::fmt::Debug for State<P>
138where
139 P: std::fmt::Debug,
140{
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 let backoff = if self.backoff.is_some() { "set" } else { "unset" };
143 f.debug_struct("State")
144 .field("policy", &self.policy)
145 .field("backoff", &backoff)
146 .finish()
147 }
148}
149
/// [`Layer`] that wraps services in a [`RetryCircuitBreaker`] driven by a retry policy.
///
/// Each wrapped service receives its own clone of the policy and its own, independent breaker
/// state.
///
/// Derives `Clone` and `Debug` so the layer can be duplicated and inspected like other
/// tower layers.
#[derive(Clone, Debug)]
pub struct RetryCircuitBreakerLayer<P> {
    policy: P,
}
156
157impl<P> RetryCircuitBreakerLayer<P> {
158 pub const fn new(policy: P) -> Self {
160 Self { policy }
161 }
162}
163
164impl<P, S> Layer<S> for RetryCircuitBreakerLayer<P>
165where
166 P: Clone,
167{
168 type Service = RetryCircuitBreaker<S, P>;
169
170 fn layer(&self, inner: S) -> Self::Service {
171 RetryCircuitBreaker::new(inner, self.policy.clone())
172 }
173}
174
/// Tower middleware that acts as a circuit breaker driven by a retry [`Policy`].
///
/// How a completed response is handled:
///
/// - If the policy was able to clone the request, the response is passed to the policy's
///   `retry`.
/// - If `retry` returns a backoff future, the breaker opens. The request is returned as
///   [`Error::Open`], and `poll_ready` stays pending until that backoff future completes.
/// - Any `call` made while the backoff is still pending is rejected with [`Error::Open`]
///   without reaching the inner service.
#[derive(Debug)]
pub struct RetryCircuitBreaker<S, P> {
    // The wrapped service.
    inner: S,
    // Retry policy plus the pending backoff (if any), shared with every response future this
    // service has produced.
    state: Arc<Mutex<State<P>>>,
}
197
198impl<S, P> RetryCircuitBreaker<S, P> {
199 pub fn new(inner: S, policy: P) -> Self {
201 Self {
202 inner,
203 state: Arc::new(Mutex::new(State { policy, backoff: None })),
204 }
205 }
206}
207
208impl<S, P, Request> Service<Request> for RetryCircuitBreaker<S, P>
209where
210 S: Service<Request>,
211 P: Policy<Request, S::Response, S::Error>,
212 P::Future: Send + 'static,
213{
214 type Response = S::Response;
215 type Error = Error<S::Error, Request>;
216 type Future = ResponseFuture<P, S::Future, Request>;
217
218 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
219 {
220 let mut state = self.state.lock().unwrap();
222 if let Some(backoff) = state.backoff.as_mut() {
223 ready!(backoff.as_mut().poll(cx));
224
225 debug!("circuit breaker backoff complete");
226
227 state.backoff = None;
229 }
230 }
231
232 self.inner.poll_ready(cx).map_err(Error::Service)
234 }
235
236 fn call(&mut self, req: Request) -> Self::Future {
237 let response_state = Arc::clone(&self.state);
238
239 let mut state = self.state.lock().unwrap();
240 if state.backoff.is_some() {
241 ResponseFuture {
242 state: response_state,
243 inner: None,
244 req: Some(req),
245 }
246 } else {
247 let cloned_req = state.policy.clone_request(&req);
249 let inner = self.inner.call(req);
250
251 ResponseFuture {
252 state: response_state,
253 inner: Some(inner),
254 req: cloned_req,
255 }
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use std::{
        future::{ready, Ready},
        time::Duration,
    };

    use tokio::time::Sleep;
    use tokio_test::{assert_pending, assert_ready, assert_ready_ok};
    use tower::{retry::Policy, ServiceExt as _};

    use super::*;

    /// Delay returned by both test retry policies whenever a response is an error.
    const BACKOFF_DUR: Duration = Duration::from_secs(1);

    /// Test request that tells `LoopbackService` whether to succeed or fail, and with what value.
    #[derive(Clone, Debug, Eq, PartialEq)]
    enum BasicRequest {
        Ok(String),
        Err(String),
    }

    impl BasicRequest {
        fn success<S: AsRef<str>>(value: S) -> Self {
            Self::Ok(value.as_ref().to_string())
        }

        fn failure<S: AsRef<str>>(value: S) -> Self {
            Self::Err(value.as_ref().to_string())
        }

        /// Expected breaker output when the request reached `LoopbackService` and its result was
        /// passed through (service errors wrapped in `Error::Service`).
        fn as_service_response(&self) -> Result<String, Error<String, Self>> {
            match self {
                Self::Ok(value) => Ok(value.clone()),
                Self::Err(value) => Err(Error::Service(value.clone())),
            }
        }

        /// Expected breaker output when the breaker hands this request back as `Error::Open`.
        fn as_open_response(&self) -> Result<String, Error<String, Self>> {
            Err(Error::Open(self.clone()))
        }
    }

    // NOTE(review): this comparison does not appear to be used by the tests below.
    impl PartialEq<Result<String, String>> for BasicRequest {
        fn eq(&self, other: &Result<String, String>) -> bool {
            match self {
                Self::Ok(value) => other.as_ref() == Ok(value),
                Self::Err(value) => other.as_ref() == Err(value),
            }
        }
    }

    /// Always-ready service that echoes the request payload back as `Ok` or `Err`.
    #[derive(Debug)]
    struct LoopbackService;

    impl Service<BasicRequest> for LoopbackService {
        type Response = String;
        type Error = String;
        type Future = Ready<Result<String, String>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: BasicRequest) -> Self::Future {
            let res = match req {
                BasicRequest::Ok(value) => Ok(value),
                BasicRequest::Err(value) => Err(value),
            };
            ready(res)
        }
    }

    /// Policy that clones every request and asks for a `BACKOFF_DUR` backoff on any error.
    #[derive(Debug)]
    struct CloneableTestRetryPolicy;

    impl<Req, T, E> Policy<Req, T, E> for CloneableTestRetryPolicy
    where
        Req: Clone,
    {
        type Future = Sleep;

        fn retry(&mut self, _: &mut Req, res: &mut Result<T, E>) -> Option<Self::Future> {
            match res {
                Ok(_) => None,
                Err(_) => Some(tokio::time::sleep(BACKOFF_DUR)),
            }
        }

        fn clone_request(&mut self, req: &Req) -> Option<Req> {
            Some(req.clone())
        }
    }

    /// Same retry behaviour as `CloneableTestRetryPolicy`, but it never clones requests. The
    /// breaker therefore never consults `retry` and passes responses through unchanged.
    #[derive(Debug)]
    struct NonCloneableTestRetryPolicy;

    impl<Req, T, E> Policy<Req, T, E> for NonCloneableTestRetryPolicy {
        type Future = Sleep;

        fn retry(&mut self, _: &mut Req, res: &mut Result<T, E>) -> Option<Self::Future> {
            match res {
                Ok(_) => None,
                Err(_) => Some(tokio::time::sleep(BACKOFF_DUR)),
            }
        }

        fn clone_request(&mut self, _: &Req) -> Option<Req> {
            None
        }
    }

    /// A failing request opens the breaker. Readiness then stays pending until the backoff
    /// elapses, after which requests flow normally again. The clock is paused so that
    /// `tokio::time::advance` controls when the backoff `Sleep` fires.
    #[tokio::test(start_paused = true)]
    async fn basic() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, CloneableTestRetryPolicy);

        // Breaker closed: a good request passes straight through.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // A failing request makes the policy return a backoff, so the breaker opens and the
        // request comes back as `Error::Open`.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(bad_req.clone());
        let result = fut.await;
        assert_eq!(result, bad_req.as_open_response());

        // While the backoff is pending, readiness is not reported.
        let mut svc_fut = tokio_test::task::spawn(circuit_breaker.ready());
        assert_pending!(svc_fut.poll());

        // Once the backoff elapses, the waiting task is woken and the breaker becomes ready.
        tokio::time::advance(BACKOFF_DUR + Duration::from_millis(1)).await;
        assert!(svc_fut.is_woken());
        let svc = assert_ready_ok!(svc_fut.poll());

        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // The breaker stays closed for subsequent requests.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());
    }

    /// When the policy cannot clone requests, responses (including errors) are returned as-is
    /// and the breaker never opens.
    #[tokio::test]
    async fn retry_policy_no_clone() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, NonCloneableTestRetryPolicy);
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // A fresh breaker, so the outcome of the bad request depends only on this call.
        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, NonCloneableTestRetryPolicy);
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(bad_req.clone());
        let result = fut.await;
        assert_eq!(result, bad_req.as_service_response());
    }

    /// A request dispatched before the breaker opened still completes normally, even if
    /// another in-flight request opens the breaker first.
    #[tokio::test(start_paused = true)]
    async fn concurrent_calls_can_advance() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, CloneableTestRetryPolicy);

        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let fut = svc.call(good_req.clone());
        let result = fut.await;
        assert_eq!(result, good_req.as_service_response());

        // Dispatch a bad and a good request while the breaker is still closed. Neither
        // response future is polled yet.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let bad_fut = svc.call(bad_req.clone());

        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let good_fut = svc.call(good_req.clone());

        // Completing the bad request opens the breaker.
        let bad_result = bad_fut.await;
        assert_eq!(bad_result, bad_req.as_open_response());

        // The good request was already dispatched, so its response is unaffected.
        let good_result = good_fut.await;
        assert_eq!(good_result, good_req.as_service_response());

        // The breaker is still open, so readiness is pending.
        let mut svc_fut = tokio_test::task::spawn(circuit_breaker.ready());
        assert_pending!(svc_fut.poll());
    }

    /// If the breaker opens after `poll_ready` succeeded but before `call`, the call is
    /// short-circuited with `Error::Open` instead of reaching the inner service.
    #[tokio::test(start_paused = true)]
    async fn breaker_open_between_ready_and_call() {
        let good_req = BasicRequest::success("good");
        let bad_req = BasicRequest::failure("bad");

        let mut circuit_breaker = RetryCircuitBreaker::new(LoopbackService, CloneableTestRetryPolicy);

        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");
        let bad_fut = svc.call(bad_req.clone());

        // Readiness is obtained while the breaker is still closed.
        let svc = circuit_breaker.ready().await.expect("should never fail to be ready");

        // Completing the bad request opens the breaker.
        let bad_result = bad_fut.await;
        assert_eq!(bad_result, bad_req.as_open_response());

        // `call` sees the pending backoff and rejects the good request.
        let good_fut = svc.call(good_req.clone());
        let good_result = good_fut.await;
        assert_eq!(good_result, good_req.as_open_response());
    }
}