1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::str::FromStr;
5use std::{hash::Hash, time::Duration};
6
7use super::aws_validator::AwsValidator;
8use super::http_validator::HttpValidator;
9use super::match_validator::MatchValidator;
10
/// Default timeout, in seconds, used by both [`AwsConfig::timeout`] and
/// [`CustomHttpConfig::timeout_seconds`] when no explicit value is configured.
pub const DEFAULT_HTTPS_TIMEOUT_SEC: u64 = 3;
/// STS endpoint used by [`AwsConfig`] when `aws_sts_endpoint` is not configured.
pub const DEFAULT_AWS_STS_ENDPOINT: &str = "https://sts.amazonaws.com";
13
14#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
15pub struct AwsConfig {
16 #[serde(default = "default_aws_sts_endpoint")]
18 pub aws_sts_endpoint: String,
19 pub forced_datetime_utc: Option<DateTime<Utc>>,
21 #[serde(default = "default_timeout")]
22 pub timeout: Duration,
23}
24
25fn default_aws_sts_endpoint() -> String {
26 DEFAULT_AWS_STS_ENDPOINT.to_string()
27}
28
29fn default_timeout() -> Duration {
30 Duration::from_secs(DEFAULT_HTTPS_TIMEOUT_SEC)
31}
32
33impl Default for AwsConfig {
34 fn default() -> Self {
35 AwsConfig {
36 aws_sts_endpoint: default_aws_sts_endpoint(),
37 forced_datetime_utc: None,
38 timeout: default_timeout(),
39 }
40 }
41}
42
43#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
44#[serde(tag = "kind")]
45pub enum AwsType {
46 AwsId,
47 AwsSecret(AwsConfig),
48 AwsSession,
49}
50
51#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
52#[serde(rename_all = "UPPERCASE")]
53pub enum HttpMethod {
54 Get,
55 Post,
56 Put,
57 Delete,
58 Patch,
59}
60
61impl FromStr for HttpMethod {
62 type Err = String;
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 match s.to_uppercase().as_str() {
66 "GET" => Ok(HttpMethod::Get),
67 "POST" => Ok(HttpMethod::Post),
68 "PUT" => Ok(HttpMethod::Put),
69 "DELETE" => Ok(HttpMethod::Delete),
70 "PATCH" => Ok(HttpMethod::Patch),
71 _ => Err(format!("Invalid HTTP method: {}", s)),
72 }
73 }
74}
75
76#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
77pub struct RequestHeader {
78 pub key: String,
79 pub value: String,
81}
82
83impl RequestHeader {
84 pub fn get_value_with_match(&self, matche: &str) -> String {
85 let mut value = self.value.clone();
87 value = value.replace("$MATCH", matche);
88 value
89 }
90}
91
92#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
93pub struct HttpValidatorOption {
94 pub timeout: Duration,
95 }
99
100#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
101pub struct CustomHttpConfig {
102 pub endpoint: String,
103 #[serde(default)]
104 pub hosts: Vec<String>,
105 #[serde(default = "default_http_method")]
106 pub http_method: HttpMethod,
107 pub request_headers: BTreeMap<String, String>,
108 #[serde(default = "default_valid_http_status_code")]
109 pub valid_http_status_code: Vec<HttpStatusCodeRange>,
110 #[serde(default = "default_invalid_http_status_code")]
111 pub invalid_http_status_code: Vec<HttpStatusCodeRange>,
112 #[serde(default = "default_timeout_seconds")]
113 pub timeout_seconds: u32,
114}
115
116impl Default for CustomHttpConfig {
117 fn default() -> Self {
118 CustomHttpConfig {
119 endpoint: "".to_string(),
120 hosts: vec![],
121 http_method: HttpMethod::Get,
122 request_headers: BTreeMap::new(),
123 valid_http_status_code: vec![],
124 invalid_http_status_code: vec![],
125 timeout_seconds: DEFAULT_HTTPS_TIMEOUT_SEC as u32,
126 }
127 }
128}
129
130impl CustomHttpConfig {
131 pub fn get_endpoints(&self) -> Result<Vec<String>, String> {
132 if self.endpoint.contains("$HOST") && self.hosts.is_empty() {
136 return Err("Endpoint contains $HOST but no hosts are provided".to_string());
137 }
138 if !self.endpoint.contains("$HOST") && !self.hosts.is_empty() {
139 return Err("Endpoint does not contain $HOST but hosts are provided".to_string());
140 }
141
142 let mut endpoints = vec![];
144 for host in self.hosts.clone() {
145 endpoints.push(self.endpoint.replace("$HOST", &host));
146 }
147 if endpoints.is_empty() {
148 endpoints.push(self.endpoint.to_string());
150 }
151 Ok(endpoints)
152 }
153
154 pub fn with_endpoint(mut self, endpoint: String) -> Self {
157 self.endpoint = endpoint;
158 self
159 }
160
161 pub fn with_hosts(mut self, hosts: Vec<String>) -> Self {
162 self.hosts = hosts;
163 self
164 }
165
166 pub fn with_request_headers(mut self, request_headers: BTreeMap<String, String>) -> Self {
167 self.request_headers = request_headers;
168 self
169 }
170
171 pub fn with_valid_http_status_code(
172 mut self,
173 valid_http_status_code: Vec<HttpStatusCodeRange>,
174 ) -> Self {
175 self.valid_http_status_code = valid_http_status_code;
176 self
177 }
178
179 pub fn with_invalid_http_status_code(
180 mut self,
181 invalid_http_status_code: Vec<HttpStatusCodeRange>,
182 ) -> Self {
183 self.invalid_http_status_code = invalid_http_status_code;
184 self
185 }
186
187 pub fn set_endpoint(&mut self, endpoint: String) {
190 self.endpoint = endpoint;
191 }
192
193 pub fn set_hosts(&mut self, hosts: Vec<String>) {
194 self.hosts = hosts;
195 }
196
197 pub fn set_http_method(&mut self, http_method: HttpMethod) {
198 self.http_method = http_method;
199 }
200
201 pub fn set_request_headers(&mut self, request_headers: BTreeMap<String, String>) {
202 self.request_headers = request_headers;
203 }
204
205 pub fn set_valid_http_status_code(&mut self, valid_http_status_code: Vec<HttpStatusCodeRange>) {
206 self.valid_http_status_code = valid_http_status_code;
207 }
208
209 pub fn set_invalid_http_status_code(
210 &mut self,
211 invalid_http_status_code: Vec<HttpStatusCodeRange>,
212 ) {
213 self.invalid_http_status_code = invalid_http_status_code;
214 }
215
216 pub fn set_timeout_seconds(&mut self, timeout_seconds: u32) {
217 self.timeout_seconds = timeout_seconds;
218 }
219}
220
221fn default_timeout_seconds() -> u32 {
222 DEFAULT_HTTPS_TIMEOUT_SEC as u32
223}
224
225fn default_http_method() -> HttpMethod {
226 HttpMethod::Get
227}
228
229fn default_valid_http_status_code() -> Vec<HttpStatusCodeRange> {
230 vec![HttpStatusCodeRange {
231 start: 200,
232 end: 300,
233 }]
234}
235
236fn default_invalid_http_status_code() -> Vec<HttpStatusCodeRange> {
237 vec![HttpStatusCodeRange {
238 start: 400,
239 end: 500,
240 }]
241}
242
243#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
244pub struct HttpStatusCodeRange {
245 pub start: u16,
246 pub end: u16,
247}
248
249#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
250#[serde(tag = "type", content = "config")]
251pub enum MatchValidationType {
252 Aws(AwsType),
253 CustomHttp(CustomHttpConfig),
254}
255
256impl MatchValidationType {
257 pub fn can_create_match_validator(&self) -> bool {
259 match self {
260 MatchValidationType::Aws(aws_type) => matches!(aws_type, AwsType::AwsSecret(_)),
261 MatchValidationType::CustomHttp(_) => true,
262 }
263 }
264 pub fn get_internal_match_validation_type(&self) -> InternalMatchValidationType {
265 match self {
266 MatchValidationType::Aws(_) => InternalMatchValidationType::Aws,
267 MatchValidationType::CustomHttp(http_config) => {
268 InternalMatchValidationType::CustomHttp(http_config.get_endpoints().unwrap())
269 }
270 }
271 }
272 pub fn into_match_validator(&self) -> Result<Box<dyn MatchValidator>, String> {
273 match self {
274 MatchValidationType::Aws(aws_type) => match aws_type {
275 AwsType::AwsSecret(aws_config) => {
276 Ok(Box::new(AwsValidator::new(aws_config.clone())))
277 }
278 _ => Err("This aws type shall not be used to create a validator".to_string()),
279 },
280 MatchValidationType::CustomHttp(http_config) => Ok(Box::new(
281 HttpValidator::new_from_config(http_config.clone()),
282 )),
283 }
284 }
285}
286
/// Hashable key identifying what a `MatchValidationType` validates against.
///
/// `MatchValidationType::get_internal_match_validation_type` maps every AWS variant to
/// [`InternalMatchValidationType::Aws`] and a custom HTTP configuration to its expanded
/// endpoint list. NOTE(review): presumably used to group or deduplicate validators —
/// confirm at the call sites.
///
/// `Debug` and `Clone` are derived so the key can be logged, asserted on, and duplicated
/// like other public types in this module.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum InternalMatchValidationType {
    /// Any AWS validation.
    Aws,
    /// Custom HTTP validation, keyed by its concrete endpoint URLs.
    CustomHttp(Vec<String>),
}