dd_sds/match_validation/
config.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::str::FromStr;
5use std::{hash::Hash, time::Duration};
6
7use crate::match_validation::http_validator_v2::HttpValidatorV2;
8
9use super::aws_validator::AwsValidator;
10use super::config_v2::{CustomHttpConfigV2, PairedValidatorConfig};
11use super::http_validator::HttpValidator;
12use super::match_validator::MatchValidator;
13
/// Number of seconds used as the default timeout by the validator configurations in this module.
pub const DEFAULT_HTTPS_TIMEOUT_SEC: u64 = 3;
/// AWS STS endpoint used when a configuration does not override it.
pub const DEFAULT_AWS_STS_ENDPOINT: &str = "https://sts.amazonaws.com";
16
17#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
18pub struct AwsConfig {
19    // Override default AWS STS endpoint for testing
20    #[serde(default = "default_aws_sts_endpoint")]
21    pub aws_sts_endpoint: String,
22    // Override default datetime for testing
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub forced_datetime_utc: Option<DateTime<Utc>>,
25    #[serde(default = "default_timeout")]
26    pub timeout: Duration,
27}
28
29fn default_aws_sts_endpoint() -> String {
30    DEFAULT_AWS_STS_ENDPOINT.to_string()
31}
32
33fn default_timeout() -> Duration {
34    Duration::from_secs(DEFAULT_HTTPS_TIMEOUT_SEC)
35}
36
37impl Default for AwsConfig {
38    fn default() -> Self {
39        AwsConfig {
40            aws_sts_endpoint: default_aws_sts_endpoint(),
41            forced_datetime_utc: None,
42            timeout: default_timeout(),
43        }
44    }
45}
46
47#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
48#[serde(tag = "kind")]
49pub enum AwsType {
50    AwsId,
51    AwsSecret(AwsConfig),
52    AwsSession,
53}
54
55#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
56#[serde(rename_all = "UPPERCASE")]
57pub enum HttpMethod {
58    Get,
59    Post,
60    Put,
61    Delete,
62    Patch,
63}
64
65impl FromStr for HttpMethod {
66    type Err = String;
67
68    fn from_str(s: &str) -> Result<Self, Self::Err> {
69        match s.to_uppercase().as_str() {
70            "GET" => Ok(HttpMethod::Get),
71            "POST" => Ok(HttpMethod::Post),
72            "PUT" => Ok(HttpMethod::Put),
73            "DELETE" => Ok(HttpMethod::Delete),
74            "PATCH" => Ok(HttpMethod::Patch),
75            _ => Err(format!("Invalid HTTP method: {s}")),
76        }
77    }
78}
79
80#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
81pub struct RequestHeader {
82    pub key: String,
83    // $MATCH is a special keyword that will be replaced by the matched string
84    pub value: String,
85}
86
87impl RequestHeader {
88    pub fn get_value_with_match(&self, matche: &str) -> String {
89        // Replace $MATCH in value
90        let mut value = self.value.clone();
91        value = value.replace("$MATCH", matche);
92        value
93    }
94}
95
96#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
97pub struct HttpValidatorOption {
98    pub timeout: Duration,
99    // TODO(trosenblatt) add more options
100    // pub max_retries: u64,
101    // pub retry_delay: u64,
102}
103
104#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
105pub struct CustomHttpConfig {
106    pub endpoint: String,
107    #[serde(default)]
108    pub hosts: Vec<String>,
109    #[serde(default = "default_http_method")]
110    pub http_method: HttpMethod,
111    pub request_headers: BTreeMap<String, String>,
112    #[serde(default = "default_valid_http_status_code")]
113    pub valid_http_status_code: Vec<HttpStatusCodeRange>,
114    #[serde(default = "default_invalid_http_status_code")]
115    pub invalid_http_status_code: Vec<HttpStatusCodeRange>,
116    #[serde(default = "default_timeout_seconds")]
117    pub timeout_seconds: u32,
118}
119
120impl Default for CustomHttpConfig {
121    fn default() -> Self {
122        CustomHttpConfig {
123            endpoint: "".to_string(),
124            hosts: vec![],
125            http_method: HttpMethod::Get,
126            request_headers: BTreeMap::new(),
127            valid_http_status_code: vec![],
128            invalid_http_status_code: vec![],
129            timeout_seconds: DEFAULT_HTTPS_TIMEOUT_SEC as u32,
130        }
131    }
132}
133
134impl CustomHttpConfig {
135    pub fn get_endpoints(&self) -> Result<Vec<String>, String> {
136        // Handle errors cases
137        // - endpoint contains $HOST but no hosts are provided
138        // - endpoint does not contain $HOST but hosts are provided
139        if self.endpoint.contains("$HOST") && self.hosts.is_empty() {
140            return Err("Endpoint contains $HOST but no hosts are provided".to_string());
141        }
142        if !self.endpoint.contains("$HOST") && !self.hosts.is_empty() {
143            return Err("Endpoint does not contain $HOST but hosts are provided".to_string());
144        }
145
146        // Replace $HOST in endpoint and build the endpoints vector
147        let mut endpoints = vec![];
148        for host in self.hosts.clone() {
149            endpoints.push(self.endpoint.replace("$HOST", &host));
150        }
151        if endpoints.is_empty() {
152            // If no hosts are provided, use the endpoint as is
153            endpoints.push(self.endpoint.to_string());
154        }
155        Ok(endpoints)
156    }
157
158    // Builders
159
160    pub fn with_endpoint(mut self, endpoint: String) -> Self {
161        self.endpoint = endpoint;
162        self
163    }
164
165    pub fn with_hosts(mut self, hosts: Vec<String>) -> Self {
166        self.hosts = hosts;
167        self
168    }
169
170    pub fn with_request_headers(mut self, request_headers: BTreeMap<String, String>) -> Self {
171        self.request_headers = request_headers;
172        self
173    }
174
175    pub fn with_valid_http_status_code(
176        mut self,
177        valid_http_status_code: Vec<HttpStatusCodeRange>,
178    ) -> Self {
179        self.valid_http_status_code = valid_http_status_code;
180        self
181    }
182
183    pub fn with_invalid_http_status_code(
184        mut self,
185        invalid_http_status_code: Vec<HttpStatusCodeRange>,
186    ) -> Self {
187        self.invalid_http_status_code = invalid_http_status_code;
188        self
189    }
190
191    // Setters
192
193    pub fn set_endpoint(&mut self, endpoint: String) {
194        self.endpoint = endpoint;
195    }
196
197    pub fn set_hosts(&mut self, hosts: Vec<String>) {
198        self.hosts = hosts;
199    }
200
201    pub fn set_http_method(&mut self, http_method: HttpMethod) {
202        self.http_method = http_method;
203    }
204
205    pub fn set_request_headers(&mut self, request_headers: BTreeMap<String, String>) {
206        self.request_headers = request_headers;
207    }
208
209    pub fn set_valid_http_status_code(&mut self, valid_http_status_code: Vec<HttpStatusCodeRange>) {
210        self.valid_http_status_code = valid_http_status_code;
211    }
212
213    pub fn set_invalid_http_status_code(
214        &mut self,
215        invalid_http_status_code: Vec<HttpStatusCodeRange>,
216    ) {
217        self.invalid_http_status_code = invalid_http_status_code;
218    }
219
220    pub fn set_timeout_seconds(&mut self, timeout_seconds: u32) {
221        self.timeout_seconds = timeout_seconds;
222    }
223}
224
225fn default_timeout_seconds() -> u32 {
226    DEFAULT_HTTPS_TIMEOUT_SEC as u32
227}
228
229fn default_http_method() -> HttpMethod {
230    HttpMethod::Get
231}
232
233fn default_valid_http_status_code() -> Vec<HttpStatusCodeRange> {
234    vec![HttpStatusCodeRange {
235        start: 200,
236        end: 300,
237    }]
238}
239
240fn default_invalid_http_status_code() -> Vec<HttpStatusCodeRange> {
241    vec![HttpStatusCodeRange {
242        start: 400,
243        end: 500,
244    }]
245}
246
247#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
248pub struct HttpStatusCodeRange {
249    pub start: u16,
250    pub end: u16,
251}
252
253#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
254#[serde(tag = "type", content = "config")]
255pub enum MatchValidationType {
256    Aws(AwsType),
257    CustomHttp(CustomHttpConfig),
258    CustomHttpV2(CustomHttpConfigV2),
259    PairedValidator(PairedValidatorConfig),
260}
261
262impl MatchValidationType {
263    // Method used to check if the validator can be created based on this type
264    pub fn can_create_match_validator(&self) -> bool {
265        match self {
266            MatchValidationType::Aws(aws_type) => matches!(aws_type, AwsType::AwsSecret(_)),
267            MatchValidationType::CustomHttp(_) => true,
268            MatchValidationType::CustomHttpV2(_) => true,
269            MatchValidationType::PairedValidator(_) => false, // Paired validators don't create standalone validators
270        }
271    }
272    pub fn get_internal_match_validation_type(&self) -> InternalMatchValidationType {
273        match self {
274            MatchValidationType::Aws(_) => InternalMatchValidationType::Aws,
275            MatchValidationType::CustomHttp(http_config) => {
276                InternalMatchValidationType::CustomHttp(http_config.get_endpoints().unwrap())
277            }
278            MatchValidationType::CustomHttpV2(_) => InternalMatchValidationType::CustomHttpV2,
279            MatchValidationType::PairedValidator(config) => {
280                InternalMatchValidationType::PairedValidator(config.kind.clone())
281            }
282        }
283    }
284    pub fn into_match_validator(&self) -> Result<Box<dyn MatchValidator>, String> {
285        match self {
286            MatchValidationType::Aws(aws_type) => match aws_type {
287                AwsType::AwsSecret(aws_config) => {
288                    Ok(Box::new(AwsValidator::new(aws_config.clone())))
289                }
290                _ => Err("This aws type shall not be used to create a validator".to_string()),
291            },
292            MatchValidationType::CustomHttp(http_config) => Ok(Box::new(
293                HttpValidator::new_from_config(http_config.clone()),
294            )),
295            MatchValidationType::CustomHttpV2(http_config_v2) => Ok(Box::new(
296                HttpValidatorV2::new_from_config(http_config_v2.clone()),
297            )),
298            MatchValidationType::PairedValidator(_) => {
299                Err("PairedValidator cannot be used to create a standalone validator".to_string())
300            }
301        }
302    }
303}
304
/// Match validation type stored in the compiled rule.
///
/// It is used to retrieve the `MatchValidator`. The full configuration is not needed for
/// that purpose, and hashing or comparing it would be heavy, so only this reduced form
/// is kept.
#[derive(Hash, PartialEq, Eq)]
pub enum InternalMatchValidationType {
    /// Any AWS validation type.
    Aws,
    /// Custom HTTP validation, identified by its list of endpoints.
    CustomHttp(Vec<String>),
    /// Custom HTTP validation, version 2 configuration.
    CustomHttpV2,
    /// Paired validation; stores the vendor kind.
    PairedValidator(String),
}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing an `AwsConfig` without a forced datetime must omit that field and keep
    /// the remaining fields in declaration order.
    #[test]
    fn test_serialization_of_aws_config() {
        let config = AwsConfig {
            aws_sts_endpoint: String::from("https://sts.amazonaws.com"),
            forced_datetime_utc: None,
            timeout: Duration::from_secs(3),
        };

        // The forced_datetime_utc is not serialized because it is None
        let expected =
            r#"{"aws_sts_endpoint":"https://sts.amazonaws.com","timeout":{"secs":3,"nanos":0}}"#;
        let actual = serde_json::to_string(&config).unwrap();

        assert_eq!(actual, expected);
    }
}