saluki_io/net/util/retry/classifier/
http.rs1use std::sync::Arc;
2
3use http::Response;
4
5use super::RetryClassifier;
6
7pub type HttpRetryPredicate<B = ()> = Arc<dyn Fn(&Response<B>) -> bool + Send + Sync>;
11
12fn default_should_retry<B>(response: &Response<B>) -> bool {
13 let status = response.status();
14
15 match status {
16 http::StatusCode::BAD_REQUEST
19 | http::StatusCode::UNAUTHORIZED
20 | http::StatusCode::FORBIDDEN
21 | http::StatusCode::PAYLOAD_TOO_LARGE => false,
22
23 _ => status.is_client_error() || status.is_server_error(),
25 }
26}
27
28pub struct StandardHttpClassifier<B = ()> {
43 predicates: Vec<HttpRetryPredicate<B>>,
44}
45
46impl<B> Clone for StandardHttpClassifier<B> {
47 fn clone(&self) -> Self {
48 Self {
49 predicates: self.predicates.clone(),
50 }
51 }
52}
53
54impl<B: 'static> Default for StandardHttpClassifier<B> {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl<B: 'static> StandardHttpClassifier<B> {
61 pub fn new() -> Self {
63 Self {
64 predicates: vec![Arc::new(default_should_retry::<B>)],
65 }
66 }
67
68 pub fn with_predicate(mut self, predicate: HttpRetryPredicate<B>) -> Self {
72 self.predicates.push(predicate);
73 self
74 }
75}
76
77impl<B, Error> RetryClassifier<http::Response<B>, Error> for StandardHttpClassifier<B> {
78 fn should_retry(&self, response: &Result<http::Response<B>, Error>) -> bool {
79 match response {
80 Ok(resp) => self.predicates.iter().any(|p| p(resp)),
81 Err(_) => true,
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use http::StatusCode;

    use super::*;

    type TestResponse = Result<http::Response<()>, ()>;

    /// Builds a successful transport result carrying an empty-bodied response with `status`.
    fn ok(status: StatusCode) -> TestResponse {
        Ok(http::Response::builder().status(status).body(()).unwrap())
    }

    /// Builds a transport-level failure.
    fn err() -> TestResponse {
        Err(())
    }

    /// Calls `RetryClassifier::should_retry`, pinning the trait's type parameters used throughout these tests.
    fn classify(classifier: &StandardHttpClassifier<()>, response: &TestResponse) -> bool {
        <StandardHttpClassifier<()> as RetryClassifier<http::Response<()>, ()>>::should_retry(classifier, response)
    }

    #[test]
    fn default_classifier_retries_5xx_and_most_4xx() {
        let classifier = StandardHttpClassifier::new();

        assert!(!classify(&classifier, &ok(StatusCode::OK)));
        assert!(!classify(&classifier, &ok(StatusCode::NO_CONTENT)));

        for status in [
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ] {
            assert!(classify(&classifier, &ok(status)), "{} should be retried", status);
        }

        for status in [
            StatusCode::REQUEST_TIMEOUT,
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::NOT_FOUND,
        ] {
            assert!(classify(&classifier, &ok(status)), "{} should be retried", status);
        }
    }

    #[test]
    fn default_classifier_does_not_retry_known_client_misconfig() {
        let classifier = StandardHttpClassifier::new();

        for status in [
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::PAYLOAD_TOO_LARGE,
        ] {
            assert!(!classify(&classifier, &ok(status)), "{} should not be retried", status);
        }
    }

    #[test]
    fn default_classifier_does_not_retry_non_error_statuses() {
        let classifier = StandardHttpClassifier::new();

        // Informational, success and redirect statuses are neither client nor server errors.
        for status in [
            StatusCode::CONTINUE,
            StatusCode::ACCEPTED,
            StatusCode::MOVED_PERMANENTLY,
            StatusCode::FOUND,
            StatusCode::NOT_MODIFIED,
        ] {
            assert!(!classify(&classifier, &ok(status)), "{} should not be retried", status);
        }
    }

    #[test]
    fn default_classifier_retries_transport_error() {
        let classifier = StandardHttpClassifier::new();
        assert!(classify(&classifier, &err()));
    }

    #[test]
    fn default_impl_matches_new() {
        let from_new = StandardHttpClassifier::new();
        let from_default = StandardHttpClassifier::default();

        for status in [
            StatusCode::OK,
            StatusCode::NOT_FOUND,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
            StatusCode::INTERNAL_SERVER_ERROR,
        ] {
            assert_eq!(
                classify(&from_new, &ok(status)),
                classify(&from_default, &ok(status)),
                "new() and default() disagree on {}",
                status
            );
        }
        assert_eq!(classify(&from_new, &err()), classify(&from_default, &err()));
    }

    #[test]
    fn predicate_adds_retry_for_403() {
        let classifier = StandardHttpClassifier::new()
            .with_predicate(Arc::new(|response| response.status() == StatusCode::FORBIDDEN));

        // The extra predicate turns 403 into a retryable status...
        assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN)));
        // ...without affecting the other statuses the default rule rejects...
        assert!(!classify(&classifier, &ok(StatusCode::UNAUTHORIZED)));
        assert!(!classify(&classifier, &ok(StatusCode::BAD_REQUEST)));
        // ...or the ones it already retries.
        assert!(classify(&classifier, &ok(StatusCode::INTERNAL_SERVER_ERROR)));
    }

    #[test]
    fn clone_preserves_custom_predicates() {
        let classifier = StandardHttpClassifier::new()
            .with_predicate(Arc::new(|response| response.status() == StatusCode::FORBIDDEN));
        let cloned = classifier.clone();

        assert!(classify(&cloned, &ok(StatusCode::FORBIDDEN)));
        assert!(!classify(&cloned, &ok(StatusCode::BAD_REQUEST)));
        assert!(classify(&cloned, &ok(StatusCode::INTERNAL_SERVER_ERROR)));

        // Cloning leaves the original classifier intact.
        assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN)));
    }

    #[test]
    fn predicate_is_re_evaluated_each_call() {
        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = Arc::clone(&flag);
        let predicate: HttpRetryPredicate =
            Arc::new(move |response| response.status() == StatusCode::FORBIDDEN && flag_clone.load(Ordering::SeqCst));

        let classifier = StandardHttpClassifier::new().with_predicate(predicate);

        assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN)));

        flag.store(true, Ordering::SeqCst);
        assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN)));

        flag.store(false, Ordering::SeqCst);
        assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN)));
    }
}