dd_sds/scanner/regex_rule/config.rs

1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{RegexPatternCaptureGroupsValidationError, validate_and_create_regex};
7use crate::{CompiledRule, CreateScannerError, Labels};
8use regex_automata::util::captures::GroupInfo;
9use serde::{Deserialize, Serialize};
10use serde_with::DefaultOnNull;
11use serde_with::serde_as;
12use std::sync::Arc;
13use strum::{AsRefStr, EnumIter};
14
/// Default `look_ahead_character_count` for a proximity-keywords configuration that is
/// created implicitly, e.g. when `RegexRuleConfig::with_included_keywords` is called on a
/// rule that has no proximity keywords yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
16
17#[serde_as]
18#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
19pub struct RegexRuleConfig {
20    pub pattern: String,
21    pub proximity_keywords: Option<ProximityKeywordsConfig>,
22    pub validator: Option<SecondaryValidator>,
23    #[serde_as(deserialize_as = "DefaultOnNull")]
24    #[serde(default)]
25    pub labels: Labels,
26    pub pattern_capture_groups: Option<Vec<String>>,
27}
28
29impl RegexRuleConfig {
30    pub fn new(pattern: &str) -> Self {
31        #[allow(deprecated)]
32        Self {
33            pattern: pattern.to_owned(),
34            proximity_keywords: None,
35            validator: None,
36            labels: Labels::default(),
37            pattern_capture_groups: None,
38        }
39    }
40
41    pub fn with_pattern(&self, pattern: &str) -> Self {
42        self.mutate_clone(|x| x.pattern = pattern.to_string())
43    }
44
45    pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
46        self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
47    }
48
49    pub fn with_labels(&self, labels: Labels) -> Self {
50        self.mutate_clone(|x| x.labels = labels)
51    }
52
53    pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
54        self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
55    }
56
57    pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
58        self.mutate_clone(|x| match x.pattern_capture_groups {
59            Some(ref mut pattern_capture_groups) => {
60                pattern_capture_groups.push(pattern_capture_group.to_string());
61            }
62            None => {
63                x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
64            }
65        })
66    }
67
68    pub fn build(&self) -> Arc<dyn RuleConfig> {
69        Arc::new(self.clone())
70    }
71
72    fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
73        let mut clone = self.clone();
74        modify(&mut clone);
75        clone
76    }
77
78    pub fn with_included_keywords(
79        &self,
80        keywords: impl IntoIterator<Item = impl AsRef<str>>,
81    ) -> Self {
82        let mut this = self.clone();
83        let mut config = self.get_or_create_proximity_keywords_config();
84        config.included_keywords = keywords
85            .into_iter()
86            .map(|x| x.as_ref().to_string())
87            .collect::<Vec<_>>();
88        this.proximity_keywords = Some(config);
89        this
90    }
91
92    pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
93        let mut this = self.clone();
94        this.validator = validator;
95        this
96    }
97
98    fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
99        self.proximity_keywords
100            .clone()
101            .unwrap_or_else(|| ProximityKeywordsConfig {
102                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
103                included_keywords: vec![],
104                excluded_keywords: vec![],
105            })
106    }
107}
108
109fn is_pattern_capture_groups_valid(
110    pattern_capture_groups: &Option<Vec<String>>,
111    group_info: &GroupInfo,
112) -> Result<(), RegexPatternCaptureGroupsValidationError> {
113    if pattern_capture_groups.is_none() {
114        return Ok(());
115    }
116    let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
117    if pattern_capture_groups.len() != 1 {
118        // We currently only allow one capture group
119        return Err(
120            RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
121                pattern_capture_groups.len(),
122            ),
123        );
124    }
125    let pattern_capture_group = pattern_capture_groups.first().unwrap();
126    if !group_info
127        .all_names()
128        .filter(|(_, _, name)| name.is_some())
129        .map(|(_, _, name)| name.unwrap())
130        .any(|name| name == pattern_capture_group)
131    {
132        return Err(
133            RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
134                pattern_capture_group.clone(),
135            ),
136        );
137    }
138    // At this point, the capture group is in the regex, and there is exactly one.
139    // Currently, it must be called `sds_match`.
140    if pattern_capture_group != "sds_match" {
141        return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
142    }
143    Ok(())
144}
145
146impl RuleConfig for RegexRuleConfig {
147    fn convert_to_compiled_rule(
148        &self,
149        rule_index: usize,
150        scanner_labels: Labels,
151    ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
152        let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
153
154        let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
155
156        let (included_keywords, excluded_keywords) = self
157            .proximity_keywords
158            .as_ref()
159            .map(|config| compile_keywords_proximity_config(config, &rule_labels))
160            .unwrap_or(Ok((None, None)))?;
161
162        is_pattern_capture_groups_valid(&self.pattern_capture_groups, regex.group_info())?;
163
164        Ok(Box::new(RegexCompiledRule {
165            rule_index,
166            regex,
167            included_keywords,
168            excluded_keywords,
169            validator: self.validator.clone().map(|x| x.compile()),
170            metrics: RuleMetrics::new(&rule_labels),
171            pattern_capture_groups: self.pattern_capture_groups.clone(),
172        }))
173    }
174}
175
176#[serde_as]
177#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
178pub struct ProximityKeywordsConfig {
179    pub look_ahead_character_count: usize,
180
181    #[serde_as(deserialize_as = "DefaultOnNull")]
182    #[serde(default)]
183    pub included_keywords: Vec<String>,
184
185    #[serde_as(deserialize_as = "DefaultOnNull")]
186    #[serde(default)]
187    pub excluded_keywords: Vec<String>,
188}
189
190#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
191#[serde(tag = "type")]
192pub enum SecondaryValidator {
193    AbaRtnChecksum,
194    BrazilianCnpjChecksum,
195    BrazilianCpfChecksum,
196    BtcChecksum,
197    BulgarianEGNChecksum,
198    ChineseIdChecksum,
199    CoordinationNumberChecksum,
200    CzechPersonalIdentificationNumberChecksum,
201    CzechTaxIdentificationNumberChecksum,
202    DutchBsnChecksum,
203    DutchPassportChecksum,
204    EntropyCheck,
205    EthereumChecksum,
206    FinnishHetuChecksum,
207    FranceNifChecksum,
208    FranceSsnChecksum,
209    GermanIdsChecksum,
210    GermanSvnrChecksum,
211    GithubTokenChecksum,
212    GreekTinChecksum,
213    HungarianTinChecksum,
214    IbanChecker,
215    IrishPpsChecksum,
216    ItalianNationalIdChecksum,
217    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
218    JwtExpirationChecker,
219    LatviaNationalIdChecksum,
220    LithuanianPersonalIdentificationNumberChecksum,
221    LuhnChecksum,
222    LuxembourgIndividualNINChecksum,
223    Mod11_10checksum,
224    Mod11_2checksum,
225    Mod1271_36Checksum,
226    Mod27_26checksum,
227    Mod37_2checksum,
228    Mod37_36checksum,
229    Mod661_26checksum,
230    Mod97_10checksum,
231    MoneroAddress,
232    NhsCheckDigit,
233    NirChecksum,
234    PolishNationalIdChecksum,
235    PolishNipChecksum,
236    PortugueseTaxIdChecksum,
237    RodneCisloNumberChecksum,
238    RomanianPersonalNumericCode,
239    SlovenianPINChecksum,
240    SpanishDniChecksum,
241    SpanishNussChecksum,
242    SwedenPINChecksum,
243}
244
245#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
246#[serde(tag = "type", content = "config")]
247pub enum ClaimRequirement {
248    /// Just check that the claim exists
249    Present,
250    /// Check that the claim exists and is not expired
251    NotExpired,
252    /// Check that the claim exists and has an exact value
253    ExactValue(String),
254    /// Check that the claim exists and matches a regex pattern
255    RegexMatch(String),
256}
257
258#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
259pub struct JwtClaimsValidatorConfig {
260    #[serde(default)]
261    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
262    #[serde(default)]
263    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
264}
265
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn should_override_pattern() {
        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(rule_config.pattern, "456");
    }

    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let rule_config = RegexRuleConfig::new("123");
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "123".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: None,
            }
        );
    }

    #[test]
    fn should_use_capture_group() {
        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "hey (?<capture_group>world)".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
            }
        );
    }

    #[test]
    fn should_append_capture_groups_one_at_a_time() {
        let rule_config = RegexRuleConfig::new("(?<sds_match>a)(?<other>b)")
            .with_pattern_capture_group("sds_match")
            .with_pattern_capture_group("other");
        assert_eq!(
            rule_config.pattern_capture_groups,
            Some(vec!["sds_match".to_string(), "other".to_string()])
        );
    }

    #[test]
    fn included_keywords_should_use_default_lookahead_when_unset() {
        let rule_config = RegexRuleConfig::new("123").with_included_keywords(["secret", "key"]);
        assert_eq!(
            rule_config.proximity_keywords,
            Some(ProximityKeywordsConfig {
                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
                included_keywords: vec!["secret".to_string(), "key".to_string()],
                excluded_keywords: vec![],
            })
        );
    }

    #[test]
    fn included_keywords_should_preserve_existing_proximity_settings() {
        let rule_config = RegexRuleConfig::new("123")
            .with_proximity_keywords(ProximityKeywordsConfig {
                look_ahead_character_count: 10,
                included_keywords: vec!["old".to_string()],
                excluded_keywords: vec!["test".to_string()],
            })
            .with_included_keywords(["new"]);
        assert_eq!(
            rule_config.proximity_keywords,
            Some(ProximityKeywordsConfig {
                look_ahead_character_count: 10,
                included_keywords: vec!["new".to_string()],
                excluded_keywords: vec!["test".to_string()],
            })
        );
    }

    #[test]
    fn should_set_and_clear_validator() {
        let with_validator =
            RegexRuleConfig::new("123").with_validator(Some(SecondaryValidator::LuhnChecksum));
        assert_eq!(
            with_validator.validator,
            Some(SecondaryValidator::LuhnChecksum)
        );
        assert_eq!(with_validator.with_validator(None).validator, None);
    }

    #[test]
    fn proximity_keywords_should_have_default() {
        let json_config = r#"{"look_ahead_character_count": 0}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );

        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );
    }

    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        // Setting the checker populates `third_party_active_checker` only; the
        // `match_validation_type` field is left unset.
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        // The same setter works with a non-HTTP (AWS) validation type.
        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        // The getter returns the checker that was configured.
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }

    #[test]
    fn test_secondary_validator_enum_iter() {
        // Test that we can iterate over all SecondaryValidator variants
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        // Verify some variants
        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
    }

    #[test]
    fn test_secondary_validator_are_sorted() {
        let validator_names: Vec<String> = SecondaryValidator::iter()
            .map(|a| a.as_ref().to_string())
            .collect();
        let mut sorted_validator_names = validator_names.clone();
        sorted_validator_names.sort();
        assert_eq!(
            sorted_validator_names, validator_names,
            "Secondary validators must be declared in alphabetical order; `left` is the expected order and `right` the actual one"
        );
    }

    // The order has to be stable to pass linter checks. Otherwise, each instantiation will change the file
    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        // Create a config with claims in non-alphabetical order
        let mut required_claims = BTreeMap::new();
        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "aaa".to_string(),
            ClaimRequirement::ExactValue("test".to_string()),
        );
        required_claims.insert(
            "mmm".to_string(),
            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
        );

        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: std::collections::BTreeMap::new(),
        };

        // Serialize multiple times to ensure stable order
        let serialized1 = serde_json::to_string(&config).unwrap();
        let serialized2 = serde_json::to_string(&config).unwrap();

        // Both serializations should be identical
        assert_eq!(serialized1, serialized2, "Serialization should be stable");

        // Keys should be in alphabetical order
        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
    }

    #[test]
    fn test_capture_groups_validation() {
        let test_cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_group>world)",
                vec!["capture_group".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_grou>world)",
                vec!["capture_group".to_string()],
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string(), "sds_match2".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected_result) in test_cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            assert_eq!(
                is_pattern_capture_groups_valid(
                    &rule_config.pattern_capture_groups,
                    &get_memoized_regex(pattern, validate_and_create_regex)
                        .unwrap()
                        .group_info()
                ),
                expected_result
            );
        }

        // Rules that do not target any capture group are always valid.
        let regex =
            get_memoized_regex("hello (?<sds_match>world)", validate_and_create_regex).unwrap();
        assert_eq!(
            is_pattern_capture_groups_valid(&None, regex.group_info()),
            Ok(())
        );
    }
}