dd_sds/match_validation/config.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::str::FromStr;
5use std::{hash::Hash, time::Duration};
6
7use super::aws_validator::AwsValidator;
8use super::http_validator::HttpValidator;
9use super::match_validator::MatchValidator;
10
/// Default request timeout, in seconds, shared by the AWS and custom HTTP validation configs.
pub const DEFAULT_HTTPS_TIMEOUT_SEC: u64 = 3;
/// STS endpoint used for `AwsConfig::aws_sts_endpoint` unless a config overrides it.
pub const DEFAULT_AWS_STS_ENDPOINT: &str = "https://sts.amazonaws.com";
13
14#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
15pub struct AwsConfig {
16    // Override default AWS STS endpoint for testing
17    #[serde(default = "default_aws_sts_endpoint")]
18    pub aws_sts_endpoint: String,
19    // Override default datetime for testing
20    pub forced_datetime_utc: Option<DateTime<Utc>>,
21    #[serde(default = "default_timeout")]
22    pub timeout: Duration,
23}
24
25fn default_aws_sts_endpoint() -> String {
26    DEFAULT_AWS_STS_ENDPOINT.to_string()
27}
28
29fn default_timeout() -> Duration {
30    Duration::from_secs(DEFAULT_HTTPS_TIMEOUT_SEC)
31}
32
33impl Default for AwsConfig {
34    fn default() -> Self {
35        AwsConfig {
36            aws_sts_endpoint: default_aws_sts_endpoint(),
37            forced_datetime_utc: None,
38            timeout: default_timeout(),
39        }
40    }
41}
42
43#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
44#[serde(tag = "kind")]
45pub enum AwsType {
46    AwsId,
47    AwsSecret(AwsConfig),
48    AwsSession,
49}
50
51#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
52#[serde(rename_all = "UPPERCASE")]
53pub enum HttpMethod {
54    Get,
55    Post,
56    Put,
57    Delete,
58    Patch,
59}
60
61impl FromStr for HttpMethod {
62    type Err = String;
63
64    fn from_str(s: &str) -> Result<Self, Self::Err> {
65        match s.to_uppercase().as_str() {
66            "GET" => Ok(HttpMethod::Get),
67            "POST" => Ok(HttpMethod::Post),
68            "PUT" => Ok(HttpMethod::Put),
69            "DELETE" => Ok(HttpMethod::Delete),
70            "PATCH" => Ok(HttpMethod::Patch),
71            _ => Err(format!("Invalid HTTP method: {}", s)),
72        }
73    }
74}
75
76#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
77pub struct RequestHeader {
78    pub key: String,
79    // $MATCH is a special keyword that will be replaced by the matched string
80    pub value: String,
81}
82
83impl RequestHeader {
84    pub fn get_value_with_match(&self, matche: &str) -> String {
85        // Replace $MATCH in value
86        let mut value = self.value.clone();
87        value = value.replace("$MATCH", matche);
88        value
89    }
90}
91
92#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
93pub struct HttpValidatorOption {
94    pub timeout: Duration,
95    // TODO(trosenblatt) add more options
96    // pub max_retries: u64,
97    // pub retry_delay: u64,
98}
99
100#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
101pub struct CustomHttpConfig {
102    pub endpoint: String,
103    #[serde(default)]
104    pub hosts: Vec<String>,
105    #[serde(default = "default_http_method")]
106    pub http_method: HttpMethod,
107    pub request_headers: BTreeMap<String, String>,
108    #[serde(default = "default_valid_http_status_code")]
109    pub valid_http_status_code: Vec<HttpStatusCodeRange>,
110    #[serde(default = "default_invalid_http_status_code")]
111    pub invalid_http_status_code: Vec<HttpStatusCodeRange>,
112    #[serde(default = "default_timeout_seconds")]
113    pub timeout_seconds: u32,
114}
115
116impl Default for CustomHttpConfig {
117    fn default() -> Self {
118        CustomHttpConfig {
119            endpoint: "".to_string(),
120            hosts: vec![],
121            http_method: HttpMethod::Get,
122            request_headers: BTreeMap::new(),
123            valid_http_status_code: vec![],
124            invalid_http_status_code: vec![],
125            timeout_seconds: DEFAULT_HTTPS_TIMEOUT_SEC as u32,
126        }
127    }
128}
129
130impl CustomHttpConfig {
131    pub fn get_endpoints(&self) -> Result<Vec<String>, String> {
132        // Handle errors cases
133        // - endpoint contains $HOST but no hosts are provided
134        // - endpoint does not contain $HOST but hosts are provided
135        if self.endpoint.contains("$HOST") && self.hosts.is_empty() {
136            return Err("Endpoint contains $HOST but no hosts are provided".to_string());
137        }
138        if !self.endpoint.contains("$HOST") && !self.hosts.is_empty() {
139            return Err("Endpoint does not contain $HOST but hosts are provided".to_string());
140        }
141
142        // Replace $HOST in endpoint and build the endpoints vector
143        let mut endpoints = vec![];
144        for host in self.hosts.clone() {
145            endpoints.push(self.endpoint.replace("$HOST", &host));
146        }
147        if endpoints.is_empty() {
148            // If no hosts are provided, use the endpoint as is
149            endpoints.push(self.endpoint.to_string());
150        }
151        Ok(endpoints)
152    }
153
154    // Builders
155
156    pub fn with_endpoint(mut self, endpoint: String) -> Self {
157        self.endpoint = endpoint;
158        self
159    }
160
161    pub fn with_hosts(mut self, hosts: Vec<String>) -> Self {
162        self.hosts = hosts;
163        self
164    }
165
166    pub fn with_request_headers(mut self, request_headers: BTreeMap<String, String>) -> Self {
167        self.request_headers = request_headers;
168        self
169    }
170
171    pub fn with_valid_http_status_code(
172        mut self,
173        valid_http_status_code: Vec<HttpStatusCodeRange>,
174    ) -> Self {
175        self.valid_http_status_code = valid_http_status_code;
176        self
177    }
178
179    pub fn with_invalid_http_status_code(
180        mut self,
181        invalid_http_status_code: Vec<HttpStatusCodeRange>,
182    ) -> Self {
183        self.invalid_http_status_code = invalid_http_status_code;
184        self
185    }
186
187    // Setters
188
189    pub fn set_endpoint(&mut self, endpoint: String) {
190        self.endpoint = endpoint;
191    }
192
193    pub fn set_hosts(&mut self, hosts: Vec<String>) {
194        self.hosts = hosts;
195    }
196
197    pub fn set_http_method(&mut self, http_method: HttpMethod) {
198        self.http_method = http_method;
199    }
200
201    pub fn set_request_headers(&mut self, request_headers: BTreeMap<String, String>) {
202        self.request_headers = request_headers;
203    }
204
205    pub fn set_valid_http_status_code(&mut self, valid_http_status_code: Vec<HttpStatusCodeRange>) {
206        self.valid_http_status_code = valid_http_status_code;
207    }
208
209    pub fn set_invalid_http_status_code(
210        &mut self,
211        invalid_http_status_code: Vec<HttpStatusCodeRange>,
212    ) {
213        self.invalid_http_status_code = invalid_http_status_code;
214    }
215
216    pub fn set_timeout_seconds(&mut self, timeout_seconds: u32) {
217        self.timeout_seconds = timeout_seconds;
218    }
219}
220
221fn default_timeout_seconds() -> u32 {
222    DEFAULT_HTTPS_TIMEOUT_SEC as u32
223}
224
225fn default_http_method() -> HttpMethod {
226    HttpMethod::Get
227}
228
229fn default_valid_http_status_code() -> Vec<HttpStatusCodeRange> {
230    vec![HttpStatusCodeRange {
231        start: 200,
232        end: 300,
233    }]
234}
235
236fn default_invalid_http_status_code() -> Vec<HttpStatusCodeRange> {
237    vec![HttpStatusCodeRange {
238        start: 400,
239        end: 500,
240    }]
241}
242
243#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
244pub struct HttpStatusCodeRange {
245    pub start: u16,
246    pub end: u16,
247}
248
249#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
250#[serde(tag = "type", content = "config")]
251pub enum MatchValidationType {
252    Aws(AwsType),
253    CustomHttp(CustomHttpConfig),
254}
255
256impl MatchValidationType {
257    // Method used to check if the validator can be created based on this type
258    pub fn can_create_match_validator(&self) -> bool {
259        match self {
260            MatchValidationType::Aws(aws_type) => matches!(aws_type, AwsType::AwsSecret(_)),
261            MatchValidationType::CustomHttp(_) => true,
262        }
263    }
264    pub fn get_internal_match_validation_type(&self) -> InternalMatchValidationType {
265        match self {
266            MatchValidationType::Aws(_) => InternalMatchValidationType::Aws,
267            MatchValidationType::CustomHttp(http_config) => {
268                InternalMatchValidationType::CustomHttp(http_config.get_endpoints().unwrap())
269            }
270        }
271    }
272    pub fn into_match_validator(&self) -> Result<Box<dyn MatchValidator>, String> {
273        match self {
274            MatchValidationType::Aws(aws_type) => match aws_type {
275                AwsType::AwsSecret(aws_config) => {
276                    Ok(Box::new(AwsValidator::new(aws_config.clone())))
277                }
278                _ => Err("This aws type shall not be used to create a validator".to_string()),
279            },
280            MatchValidationType::CustomHttp(http_config) => Ok(Box::new(
281                HttpValidator::new_from_config(http_config.clone()),
282            )),
283        }
284    }
285}
286
/// Match validation type stored in the compiled rule and used to retrieve its `MatchValidator`.
///
/// It deliberately omits the full configuration, because hashing and comparing the full
/// configuration would be heavy. `Debug` and `Clone` are derived so keys can be printed in
/// diagnostics and duplicated when needed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InternalMatchValidationType {
    /// Key shared by every AWS match validation type.
    Aws,
    /// Custom HTTP validation, identified by its expanded endpoint URLs.
    CustomHttp(Vec<String>),
}