1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::str::FromStr;
5use std::{hash::Hash, time::Duration};
6
7use crate::match_validation::http_validator_v2::HttpValidatorV2;
8
9use super::aws_validator::AwsValidator;
10use super::config_v2::{CustomHttpConfigV2, PairedValidatorConfig};
11use super::http_validator::HttpValidator;
12use super::match_validator::MatchValidator;
13
/// Default request timeout, in seconds. Backs both `default_timeout` (the `AwsConfig`
/// timeout) and `default_timeout_seconds` / `CustomHttpConfig::default` (custom HTTP).
pub const DEFAULT_HTTPS_TIMEOUT_SEC: u64 = 3;
/// STS endpoint used when `AwsConfig::aws_sts_endpoint` is not specified.
pub const DEFAULT_AWS_STS_ENDPOINT: &str = "https://sts.amazonaws.com";
16
/// Configuration handed to `AwsValidator` when validating `AwsType::AwsSecret` matches.
///
/// Field order matters: it determines the order of keys in the serialized JSON.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct AwsConfig {
    /// STS endpoint to call; defaults to `DEFAULT_AWS_STS_ENDPOINT` when absent.
    #[serde(default = "default_aws_sts_endpoint")]
    pub aws_sts_endpoint: String,
    /// Optional fixed timestamp, omitted from serialized output when `None`.
    /// NOTE(review): presumably overrides "now" when signing requests (e.g. for tests) —
    /// confirm against `AwsValidator`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub forced_datetime_utc: Option<DateTime<Utc>>,
    /// Request timeout; defaults to `DEFAULT_HTTPS_TIMEOUT_SEC` seconds when absent.
    #[serde(default = "default_timeout")]
    pub timeout: Duration,
}
28
29fn default_aws_sts_endpoint() -> String {
30 DEFAULT_AWS_STS_ENDPOINT.to_string()
31}
32
33fn default_timeout() -> Duration {
34 Duration::from_secs(DEFAULT_HTTPS_TIMEOUT_SEC)
35}
36
37impl Default for AwsConfig {
38 fn default() -> Self {
39 AwsConfig {
40 aws_sts_endpoint: default_aws_sts_endpoint(),
41 forced_datetime_utc: None,
42 timeout: default_timeout(),
43 }
44 }
45}
46
/// AWS-related validation variants, serialized internally tagged by a `kind` field
/// (e.g. `{"kind": "AwsId"}`).
///
/// Only `AwsSecret` carries an [`AwsConfig`] and can be turned into a validator
/// (see `MatchValidationType::into_match_validator`).
/// NOTE(review): `AwsId` / `AwsSession` presumably mark the companion parts of a
/// credential — confirm.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "kind")]
pub enum AwsType {
    AwsId,
    AwsSecret(AwsConfig),
    AwsSession,
}
54
/// HTTP method used by a custom HTTP validator.
///
/// Serialized as the upper-case name (`"GET"`, `"POST"`, ...). Parsing via `FromStr` is
/// case-insensitive.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
    Get,
    Post,
    Put,
    Delete,
    Patch,
}
64
65impl FromStr for HttpMethod {
66 type Err = String;
67
68 fn from_str(s: &str) -> Result<Self, Self::Err> {
69 match s.to_uppercase().as_str() {
70 "GET" => Ok(HttpMethod::Get),
71 "POST" => Ok(HttpMethod::Post),
72 "PUT" => Ok(HttpMethod::Put),
73 "DELETE" => Ok(HttpMethod::Delete),
74 "PATCH" => Ok(HttpMethod::Patch),
75 _ => Err(format!("Invalid HTTP method: {s}")),
76 }
77 }
78}
79
/// A single request header. `value` may contain a `$MATCH` placeholder, which
/// `get_value_with_match` substitutes with the matched text.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RequestHeader {
    pub key: String,
    pub value: String,
}
86
87impl RequestHeader {
88 pub fn get_value_with_match(&self, matche: &str) -> String {
89 let mut value = self.value.clone();
91 value = value.replace("$MATCH", matche);
92 value
93 }
94}
95
/// Options for an HTTP validator; currently only a request timeout.
/// NOTE(review): not referenced anywhere in this file — confirm it is still used elsewhere.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HttpValidatorOption {
    /// Request timeout. Has no serde default, so it must be present when deserializing.
    pub timeout: Duration,
}
103
/// Configuration for the `CustomHttp` match validator (see `MatchValidationType::CustomHttp`).
///
/// Field order matters: it determines the order of keys in the serialized JSON.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct CustomHttpConfig {
    /// URL to call. It may contain a `$HOST` placeholder, which `get_endpoints` expands once
    /// per entry in `hosts`.
    pub endpoint: String,
    /// Values substituted for `$HOST` in `endpoint`; empty when absent.
    #[serde(default)]
    pub hosts: Vec<String>,
    /// HTTP method to use; `GET` when absent.
    #[serde(default = "default_http_method")]
    pub http_method: HttpMethod,
    /// Headers to send. NOTE(review): presumably header name -> value — confirm against
    /// `HttpValidator`. Has no serde default, so it must be present when deserializing.
    pub request_headers: BTreeMap<String, String>,
    /// Status-code ranges treated as "valid"; defaults to `{start: 200, end: 300}` when absent.
    #[serde(default = "default_valid_http_status_code")]
    pub valid_http_status_code: Vec<HttpStatusCodeRange>,
    /// Status-code ranges treated as "invalid"; defaults to `{start: 400, end: 500}` when
    /// absent.
    #[serde(default = "default_invalid_http_status_code")]
    pub invalid_http_status_code: Vec<HttpStatusCodeRange>,
    /// Request timeout in seconds; defaults to `DEFAULT_HTTPS_TIMEOUT_SEC` when absent.
    #[serde(default = "default_timeout_seconds")]
    pub timeout_seconds: u32,
}
119
120impl Default for CustomHttpConfig {
121 fn default() -> Self {
122 CustomHttpConfig {
123 endpoint: "".to_string(),
124 hosts: vec![],
125 http_method: HttpMethod::Get,
126 request_headers: BTreeMap::new(),
127 valid_http_status_code: vec![],
128 invalid_http_status_code: vec![],
129 timeout_seconds: DEFAULT_HTTPS_TIMEOUT_SEC as u32,
130 }
131 }
132}
133
134impl CustomHttpConfig {
135 pub fn get_endpoints(&self) -> Result<Vec<String>, String> {
136 if self.endpoint.contains("$HOST") && self.hosts.is_empty() {
140 return Err("Endpoint contains $HOST but no hosts are provided".to_string());
141 }
142 if !self.endpoint.contains("$HOST") && !self.hosts.is_empty() {
143 return Err("Endpoint does not contain $HOST but hosts are provided".to_string());
144 }
145
146 let mut endpoints = vec![];
148 for host in self.hosts.clone() {
149 endpoints.push(self.endpoint.replace("$HOST", &host));
150 }
151 if endpoints.is_empty() {
152 endpoints.push(self.endpoint.to_string());
154 }
155 Ok(endpoints)
156 }
157
158 pub fn with_endpoint(mut self, endpoint: String) -> Self {
161 self.endpoint = endpoint;
162 self
163 }
164
165 pub fn with_hosts(mut self, hosts: Vec<String>) -> Self {
166 self.hosts = hosts;
167 self
168 }
169
170 pub fn with_request_headers(mut self, request_headers: BTreeMap<String, String>) -> Self {
171 self.request_headers = request_headers;
172 self
173 }
174
175 pub fn with_valid_http_status_code(
176 mut self,
177 valid_http_status_code: Vec<HttpStatusCodeRange>,
178 ) -> Self {
179 self.valid_http_status_code = valid_http_status_code;
180 self
181 }
182
183 pub fn with_invalid_http_status_code(
184 mut self,
185 invalid_http_status_code: Vec<HttpStatusCodeRange>,
186 ) -> Self {
187 self.invalid_http_status_code = invalid_http_status_code;
188 self
189 }
190
191 pub fn set_endpoint(&mut self, endpoint: String) {
194 self.endpoint = endpoint;
195 }
196
197 pub fn set_hosts(&mut self, hosts: Vec<String>) {
198 self.hosts = hosts;
199 }
200
201 pub fn set_http_method(&mut self, http_method: HttpMethod) {
202 self.http_method = http_method;
203 }
204
205 pub fn set_request_headers(&mut self, request_headers: BTreeMap<String, String>) {
206 self.request_headers = request_headers;
207 }
208
209 pub fn set_valid_http_status_code(&mut self, valid_http_status_code: Vec<HttpStatusCodeRange>) {
210 self.valid_http_status_code = valid_http_status_code;
211 }
212
213 pub fn set_invalid_http_status_code(
214 &mut self,
215 invalid_http_status_code: Vec<HttpStatusCodeRange>,
216 ) {
217 self.invalid_http_status_code = invalid_http_status_code;
218 }
219
220 pub fn set_timeout_seconds(&mut self, timeout_seconds: u32) {
221 self.timeout_seconds = timeout_seconds;
222 }
223}
224
225fn default_timeout_seconds() -> u32 {
226 DEFAULT_HTTPS_TIMEOUT_SEC as u32
227}
228
/// Serde default for `CustomHttpConfig::http_method`.
fn default_http_method() -> HttpMethod {
    HttpMethod::Get
}
232
233fn default_valid_http_status_code() -> Vec<HttpStatusCodeRange> {
234 vec![HttpStatusCodeRange {
235 start: 200,
236 end: 300,
237 }]
238}
239
240fn default_invalid_http_status_code() -> Vec<HttpStatusCodeRange> {
241 vec![HttpStatusCodeRange {
242 start: 400,
243 end: 500,
244 }]
245}
246
/// A range of HTTP status codes delimited by `start` and `end`.
/// NOTE(review): whether `end` is inclusive is decided by the validator, which is not
/// visible here. The defaults (`200..300`, `400..500`) suggest it is exclusive — confirm.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct HttpStatusCodeRange {
    pub start: u16,
    pub end: u16,
}
252
/// Selects how a match should be validated.
///
/// Serialized adjacently tagged: `{"type": "<VariantName>", "config": <payload>}`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum MatchValidationType {
    Aws(AwsType),
    CustomHttp(CustomHttpConfig),
    CustomHttpV2(CustomHttpConfigV2),
    /// Cannot be turned into a standalone validator (see `into_match_validator`).
    PairedValidator(PairedValidatorConfig),
}
261
262impl MatchValidationType {
263 pub fn can_create_match_validator(&self) -> bool {
265 match self {
266 MatchValidationType::Aws(aws_type) => matches!(aws_type, AwsType::AwsSecret(_)),
267 MatchValidationType::CustomHttp(_) => true,
268 MatchValidationType::CustomHttpV2(_) => true,
269 MatchValidationType::PairedValidator(_) => false, }
271 }
272 pub fn get_internal_match_validation_type(&self) -> InternalMatchValidationType {
273 match self {
274 MatchValidationType::Aws(_) => InternalMatchValidationType::Aws,
275 MatchValidationType::CustomHttp(http_config) => {
276 InternalMatchValidationType::CustomHttp(http_config.get_endpoints().unwrap())
277 }
278 MatchValidationType::CustomHttpV2(_) => InternalMatchValidationType::CustomHttpV2,
279 MatchValidationType::PairedValidator(config) => {
280 InternalMatchValidationType::PairedValidator(config.kind.clone())
281 }
282 }
283 }
284 pub fn into_match_validator(&self) -> Result<Box<dyn MatchValidator>, String> {
285 match self {
286 MatchValidationType::Aws(aws_type) => match aws_type {
287 AwsType::AwsSecret(aws_config) => {
288 Ok(Box::new(AwsValidator::new(aws_config.clone())))
289 }
290 _ => Err("This aws type shall not be used to create a validator".to_string()),
291 },
292 MatchValidationType::CustomHttp(http_config) => Ok(Box::new(
293 HttpValidator::new_from_config(http_config.clone()),
294 )),
295 MatchValidationType::CustomHttpV2(http_config_v2) => Ok(Box::new(
296 HttpValidatorV2::new_from_config(http_config_v2.clone()),
297 )),
298 MatchValidationType::PairedValidator(_) => {
299 Err("PairedValidator cannot be used to create a standalone validator".to_string())
300 }
301 }
302 }
303}
304
/// Resolved, hashable form of [`MatchValidationType`], produced by
/// `MatchValidationType::get_internal_match_validation_type`.
///
/// NOTE(review): the `Eq + Hash` derives suggest it is used as a map/set key (e.g. to group
/// or deduplicate validators) — confirm against callers.
#[derive(PartialEq, Eq, Hash)]
pub enum InternalMatchValidationType {
    /// Any `MatchValidationType::Aws` variant.
    Aws,
    /// A `CustomHttp` config, identified by its expanded endpoint list.
    CustomHttp(Vec<String>),
    /// Any `CustomHttpV2` config; the config itself is not carried.
    CustomHttpV2,
    /// A paired validator, identified by its config's `kind`.
    PairedValidator(String),
}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialization_of_aws_config() {
        let aws_config = AwsConfig {
            aws_sts_endpoint: "https://sts.amazonaws.com".to_string(),
            forced_datetime_utc: None,
            timeout: Duration::from_secs(3),
        };
        let serialized = serde_json::to_string(&aws_config).unwrap();
        // `forced_datetime_utc` is omitted because it is `None`.
        assert_eq!(
            serialized,
            "{\"aws_sts_endpoint\":\"https://sts.amazonaws.com\",\"timeout\":{\"secs\":3,\"nanos\":0}}"
        );
    }

    #[test]
    fn test_deserialization_of_empty_aws_config_uses_defaults() {
        // Every field has a serde default (or is an Option), so `{}` must deserialize.
        let aws_config: AwsConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(aws_config, AwsConfig::default());
    }

    #[test]
    fn test_http_method_parsing_is_case_insensitive() {
        assert_eq!("post".parse::<HttpMethod>(), Ok(HttpMethod::Post));
        assert_eq!("Patch".parse::<HttpMethod>(), Ok(HttpMethod::Patch));
        assert_eq!(
            "TRACE".parse::<HttpMethod>(),
            Err("Invalid HTTP method: TRACE".to_string())
        );
    }

    #[test]
    fn test_get_endpoints() {
        let config = CustomHttpConfig::default()
            .with_endpoint("https://$HOST/api".to_string())
            .with_hosts(vec!["us".to_string(), "eu".to_string()]);
        assert_eq!(
            config.get_endpoints(),
            Ok(vec![
                "https://us/api".to_string(),
                "https://eu/api".to_string()
            ])
        );

        let config = CustomHttpConfig::default().with_endpoint("https://example.com".to_string());
        assert_eq!(
            config.get_endpoints(),
            Ok(vec!["https://example.com".to_string()])
        );

        let config = CustomHttpConfig::default().with_endpoint("https://$HOST/api".to_string());
        assert!(config.get_endpoints().is_err());

        let config = CustomHttpConfig::default()
            .with_endpoint("https://example.com".to_string())
            .with_hosts(vec!["us".to_string()]);
        assert!(config.get_endpoints().is_err());
    }
}