Skip to main content

saluki_io/net/util/retry/classifier/
http.rs

1use std::sync::Arc;
2
3use http::Response;
4
5use super::RetryClassifier;
6
7/// A predicate that decides whether a response should be treated as retriable.
8///
9/// The predicate receives the response and returns `true` if the response should be retried.
10pub type HttpRetryPredicate<B = ()> = Arc<dyn Fn(&Response<B>) -> bool + Send + Sync>;
11
12fn default_should_retry<B>(response: &Response<B>) -> bool {
13    let status = response.status();
14
15    match status {
16        // There are some status codes that likely indicate a fundamental misconfiguration or bug on the client side
17        // which won't be resolved by retrying the request.
18        http::StatusCode::BAD_REQUEST
19        | http::StatusCode::UNAUTHORIZED
20        | http::StatusCode::FORBIDDEN
21        | http::StatusCode::PAYLOAD_TOO_LARGE => false,
22
23        // For all other status codes, we'll only retry if they're in the client/server error range.
24        _ => status.is_client_error() || status.is_server_error(),
25    }
26}
27
28/// A standard HTTP response classifier.
29///
30/// Generally treats all client (4xx) and server (5xx) errors as retriable, with the exception of a few specific client
31/// errors that shouldn't be retried:
32///
33/// - 400 Bad Request (likely a client-side bug)
34/// - 401 Unauthorized (likely a client-side misconfiguration)
35/// - 403 Forbidden (likely a client-side misconfiguration)
36/// - 413 Payload Too Large (likely a client-side bug)
37///
38/// Additional [`HttpRetryPredicate`]s can be registered via [`StandardHttpClassifier::with_predicate`]. A response is
39/// retried if any predicate—including the default—returns `true` (OR semantics). This allows callers to
40/// selectively unlock retries for status codes that the default predicate would not retry, without affecting other
41/// status codes.
42pub struct StandardHttpClassifier<B = ()> {
43    predicates: Vec<HttpRetryPredicate<B>>,
44}
45
46impl<B> Clone for StandardHttpClassifier<B> {
47    fn clone(&self) -> Self {
48        Self {
49            predicates: self.predicates.clone(),
50        }
51    }
52}
53
54impl<B: 'static> Default for StandardHttpClassifier<B> {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl<B: 'static> StandardHttpClassifier<B> {
61    /// Creates a new [`StandardHttpClassifier`] with the default status-code predicate installed.
62    pub fn new() -> Self {
63        Self {
64            predicates: vec![Arc::new(default_should_retry::<B>)],
65        }
66    }
67
68    /// Adds a predicate.
69    ///
70    /// A response is retried if any predicate—including the default—returns `true` (OR semantics).
71    pub fn with_predicate(mut self, predicate: HttpRetryPredicate<B>) -> Self {
72        self.predicates.push(predicate);
73        self
74    }
75}
76
77impl<B, Error> RetryClassifier<http::Response<B>, Error> for StandardHttpClassifier<B> {
78    fn should_retry(&self, response: &Result<http::Response<B>, Error>) -> bool {
79        match response {
80            Ok(resp) => self.predicates.iter().any(|p| p(resp)),
81            Err(_) => true,
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use http::StatusCode;

    use super::*;

    /// Classifier input used throughout these tests: a bodiless response, or a unit error.
    type TestResponse = Result<http::Response<()>, ()>;

    /// Wraps a bodiless response carrying `status` in `Ok`.
    fn ok(status: StatusCode) -> TestResponse {
        let response = http::Response::builder().status(status).body(()).unwrap();
        Ok(response)
    }

    /// Produces an `Err` outcome (no response available).
    fn err() -> TestResponse {
        Err(())
    }

    /// Calls `should_retry` through fully-qualified syntax so the trait's type parameters are pinned explicitly.
    fn classify(classifier: &StandardHttpClassifier<()>, response: &TestResponse) -> bool {
        <StandardHttpClassifier<()> as RetryClassifier<http::Response<()>, ()>>::should_retry(classifier, response)
    }

    /// Asserts that every status in `statuses` is classified as retriable.
    fn assert_all_retried(classifier: &StandardHttpClassifier<()>, statuses: &[StatusCode]) {
        for &status in statuses {
            assert!(classify(classifier, &ok(status)), "{} should be retried", status);
        }
    }

    /// Asserts that no status in `statuses` is classified as retriable.
    fn assert_none_retried(classifier: &StandardHttpClassifier<()>, statuses: &[StatusCode]) {
        for &status in statuses {
            assert!(!classify(classifier, &ok(status)), "{} should not be retried", status);
        }
    }

    #[test]
    fn default_classifier_retries_5xx_and_most_4xx() {
        let classifier = StandardHttpClassifier::new();

        // Successful responses are never retried.
        assert!(!classify(&classifier, &ok(StatusCode::OK)));
        assert!(!classify(&classifier, &ok(StatusCode::NO_CONTENT)));

        // Server errors.
        assert_all_retried(
            &classifier,
            &[
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::BAD_GATEWAY,
                StatusCode::SERVICE_UNAVAILABLE,
                StatusCode::GATEWAY_TIMEOUT,
            ],
        );

        // Client errors outside the known-misconfiguration set.
        assert_all_retried(
            &classifier,
            &[
                StatusCode::REQUEST_TIMEOUT,
                StatusCode::TOO_MANY_REQUESTS,
                StatusCode::NOT_FOUND,
            ],
        );
    }

    #[test]
    fn default_classifier_does_not_retry_known_client_misconfig() {
        let classifier = StandardHttpClassifier::new();

        assert_none_retried(
            &classifier,
            &[
                StatusCode::BAD_REQUEST,
                StatusCode::UNAUTHORIZED,
                StatusCode::FORBIDDEN,
                StatusCode::PAYLOAD_TOO_LARGE,
            ],
        );
    }

    #[test]
    fn default_classifier_retries_transport_error() {
        let classifier = StandardHttpClassifier::new();
        let outcome = err();
        assert!(classify(&classifier, &outcome));
    }

    #[test]
    fn predicate_adds_retry_for_403() {
        let forbidden_only: HttpRetryPredicate = Arc::new(|response| response.status() == StatusCode::FORBIDDEN);
        let classifier = StandardHttpClassifier::new().with_predicate(forbidden_only);

        assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN)));

        // Sibling client-misconfig statuses without a matching predicate keep their default (non-retriable) behavior.
        assert!(!classify(&classifier, &ok(StatusCode::UNAUTHORIZED)));
        assert!(!classify(&classifier, &ok(StatusCode::BAD_REQUEST)));

        // Status codes that are retried by default are unaffected.
        assert!(classify(&classifier, &ok(StatusCode::INTERNAL_SERVER_ERROR)));
    }

    #[test]
    fn predicate_is_re_evaluated_each_call() {
        let allow_forbidden = Arc::new(AtomicBool::new(false));
        let predicate_flag = Arc::clone(&allow_forbidden);
        let predicate: HttpRetryPredicate = Arc::new(move |response| {
            response.status() == StatusCode::FORBIDDEN && predicate_flag.load(Ordering::SeqCst)
        });

        let classifier = StandardHttpClassifier::new().with_predicate(predicate);

        // Flip the flag back and forth; the classifier must observe each new value on the next call.
        for enabled in [false, true, false] {
            allow_forbidden.store(enabled, Ordering::SeqCst);
            assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN)) == enabled);
        }
    }
}