dd_sds/scanner/regex_rule/
config.rs

1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{
7    RegexPatternCaptureGroupsValidationError, validate_and_create_regex,
8    validate_named_capture_group_minimum_length,
9};
10use crate::{CompiledRule, CreateScannerError, Labels};
11use regex_automata::util::captures::GroupInfo;
12use serde::{Deserialize, Serialize};
13use serde_with::DefaultOnNull;
14use serde_with::serde_as;
15use std::sync::Arc;
16use strum::{AsRefStr, EnumIter};
17
/// Look-ahead character count used when a proximity keywords configuration has to be
/// created implicitly, i.e. when keywords are added to a rule that has none configured yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
19
/// Configuration of a regex-based rule.
///
/// Instances are usually created with [`RegexRuleConfig::new`] and refined with the
/// `with_*` builder methods, each of which returns a modified copy.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Source of the regex; it is validated and compiled when the rule is converted
    /// into a compiled rule.
    pub pattern: String,
    /// Optional included/excluded keywords configuration.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator, compiled together with the rule.
    pub validator: Option<SecondaryValidator>,
    /// Labels combined with the scanner labels when the rule is compiled.
    /// A missing or `null` value deserializes to the default labels.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Names of the targeted capture groups. Validation currently accepts exactly one
    /// name, which must be `sds_match` and must exist in `pattern`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
31
32impl RegexRuleConfig {
33    pub fn new(pattern: &str) -> Self {
34        #[allow(deprecated)]
35        Self {
36            pattern: pattern.to_owned(),
37            proximity_keywords: None,
38            validator: None,
39            labels: Labels::default(),
40            pattern_capture_groups: None,
41        }
42    }
43
44    pub fn with_pattern(&self, pattern: &str) -> Self {
45        self.mutate_clone(|x| x.pattern = pattern.to_string())
46    }
47
48    pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
49        self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
50    }
51
52    pub fn with_labels(&self, labels: Labels) -> Self {
53        self.mutate_clone(|x| x.labels = labels)
54    }
55
56    pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
57        self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
58    }
59
60    pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
61        self.mutate_clone(|x| match x.pattern_capture_groups {
62            Some(ref mut pattern_capture_groups) => {
63                pattern_capture_groups.push(pattern_capture_group.to_string());
64            }
65            None => {
66                x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
67            }
68        })
69    }
70
71    pub fn build(&self) -> Arc<dyn RuleConfig> {
72        Arc::new(self.clone())
73    }
74
75    fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
76        let mut clone = self.clone();
77        modify(&mut clone);
78        clone
79    }
80
81    pub fn with_included_keywords(
82        &self,
83        keywords: impl IntoIterator<Item = impl AsRef<str>>,
84    ) -> Self {
85        let mut this = self.clone();
86        let mut config = self.get_or_create_proximity_keywords_config();
87        config.included_keywords = keywords
88            .into_iter()
89            .map(|x| x.as_ref().to_string())
90            .collect::<Vec<_>>();
91        this.proximity_keywords = Some(config);
92        this
93    }
94
95    pub fn with_excluded_keywords(
96        &self,
97        keywords: impl IntoIterator<Item = impl AsRef<str>>,
98    ) -> Self {
99        let mut this = self.clone();
100        let mut config = self.get_or_create_proximity_keywords_config();
101        config.excluded_keywords = keywords
102            .into_iter()
103            .map(|x| x.as_ref().to_string())
104            .collect::<Vec<_>>();
105        this.proximity_keywords = Some(config);
106        this
107    }
108
109    pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
110        let mut this = self.clone();
111        this.validator = validator;
112        this
113    }
114
115    fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
116        self.proximity_keywords
117            .clone()
118            .unwrap_or_else(|| ProximityKeywordsConfig {
119                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
120                included_keywords: vec![],
121                excluded_keywords: vec![],
122            })
123    }
124}
125
126fn is_pattern_capture_groups_valid(
127    pattern: &str,
128    pattern_capture_groups: &Option<Vec<String>>,
129    group_info: &GroupInfo,
130) -> Result<(), RegexPatternCaptureGroupsValidationError> {
131    if pattern_capture_groups.is_none() {
132        return Ok(());
133    }
134    let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
135    if pattern_capture_groups.len() != 1 {
136        // We currently only allow one capture group
137        return Err(
138            RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
139                pattern_capture_groups.len(),
140            ),
141        );
142    }
143    let pattern_capture_group = pattern_capture_groups.first().unwrap();
144    if !group_info
145        .all_names()
146        .filter(|(_, _, name)| name.is_some())
147        .map(|(_, _, name)| name.unwrap())
148        .any(|name| name == pattern_capture_group)
149    {
150        return Err(
151            RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
152                pattern_capture_group.clone(),
153            ),
154        );
155    }
156    // At this point, the capture group is in the regex, and there is exactly one.
157    // Currently, it must be called `sds_match`.
158    if pattern_capture_group != "sds_match" {
159        return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
160    }
161    validate_named_capture_group_minimum_length(pattern, pattern_capture_group)?;
162    Ok(())
163}
164
165impl RuleConfig for RegexRuleConfig {
166    fn convert_to_compiled_rule(
167        &self,
168        rule_index: usize,
169        scanner_labels: Labels,
170    ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
171        let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
172
173        let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
174
175        let (included_keywords, excluded_keywords) = self
176            .proximity_keywords
177            .as_ref()
178            .map(|config| compile_keywords_proximity_config(config, &rule_labels))
179            .unwrap_or(Ok((None, None)))?;
180
181        is_pattern_capture_groups_valid(
182            &self.pattern,
183            &self.pattern_capture_groups,
184            regex.group_info(),
185        )?;
186
187        Ok(Box::new(RegexCompiledRule {
188            rule_index,
189            regex,
190            included_keywords,
191            excluded_keywords,
192            validator: self.validator.clone().map(|x| x.compile()),
193            metrics: RuleMetrics::new(&rule_labels),
194            pattern_capture_groups: self.pattern_capture_groups.clone(),
195        }))
196    }
197
198    fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
199        Some(self)
200    }
201}
202
/// Proximity keywords settings of a regex rule.
///
/// Both keyword lists deserialize to an empty list when the field is missing or `null`.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Look-ahead size, in characters; `DEFAULT_KEYWORD_LOOKAHEAD` is used when this
    /// configuration is created implicitly by the keyword builder methods.
    pub look_ahead_character_count: usize,

    /// Keywords of the "included" list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Keywords of the "excluded" list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
216
/// Identifies the secondary validator attached to a rule through
/// `RegexRuleConfig::validator`.
///
/// Serialized with the variant name stored under the `type` key.
///
/// Variants must stay sorted by name: `test_secondary_validator_are_sorted` compares the
/// declaration order with the byte-wise sorted list of variant names, so new variants have
/// to be inserted at their sorted position rather than appended.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    AtlassianTokenChecksum,
    AustralianMedicareChecksum,
    AustralianTfnChecksum,
    AustrianSSNChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// The only variant carrying data: the header/claim requirements to enforce.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    NonHexChecker,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SingaporeNricChecksum,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UkNinoFormatCheck,
    UkTrnChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
286
/// Requirement placed on a single JWT header or claim by `JwtClaimsValidatorConfig`.
///
/// Serialized with the variant name under `type` and, for variants carrying a value,
/// that value under `config`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    /// Just check that the claim exists
    Present,
    /// Check that the claim exists and is not expired
    NotExpired,
    /// Check that the claim exists and has an exact value
    ExactValue(String),
    /// Check that the claim exists and matches a regex pattern
    RegexMatch(String),
}
299
/// Requirements used by `SecondaryValidator::JwtClaimsValidator`.
///
/// Both maps are keyed by name and default to empty when missing from the serialized form.
/// `BTreeMap` keeps the keys sorted, so the serialized output has a stable order (see
/// `test_jwt_claims_validator_config_serialization_order`).
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements keyed by header name.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements keyed by claim name.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
307
308#[cfg(test)]
309mod test {
310    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
311    use std::collections::BTreeMap;
312    use strum::IntoEnumIterator;
313
314    use super::*;
315
316    #[test]
317    fn should_override_pattern() {
318        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
319        assert_eq!(rule_config.pattern, "456");
320    }
321
322    #[test]
323    #[allow(deprecated)]
324    fn should_have_default() {
325        let rule_config = RegexRuleConfig::new("123");
326        assert_eq!(
327            rule_config,
328            RegexRuleConfig {
329                pattern: "123".to_string(),
330                proximity_keywords: None,
331                validator: None,
332                labels: Labels::empty(),
333                pattern_capture_groups: None,
334            }
335        );
336    }
337
338    #[test]
339    fn should_use_capture_group() {
340        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
341            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
342        assert_eq!(
343            rule_config,
344            RegexRuleConfig {
345                pattern: "hey (?<capture_group>world)".to_string(),
346                proximity_keywords: None,
347                validator: None,
348                labels: Labels::empty(),
349                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
350            }
351        );
352    }
353
354    #[test]
355    fn proximity_keywords_should_have_default() {
356        let json_config = r#"{"look_ahead_character_count": 0}"#;
357        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
358        assert_eq!(
359            test,
360            ProximityKeywordsConfig {
361                look_ahead_character_count: 0,
362                included_keywords: vec![],
363                excluded_keywords: vec![]
364            }
365        );
366
367        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
368        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
369        assert_eq!(
370            test,
371            ProximityKeywordsConfig {
372                look_ahead_character_count: 0,
373                included_keywords: vec![],
374                excluded_keywords: vec![]
375            }
376        );
377    }
378
379    #[test]
380    #[allow(deprecated)]
381    fn test_third_party_active_checker() {
382        // Test setting only the new field
383        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
384        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
385        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
386            .third_party_active_checker(validation_type.clone());
387
388        assert_eq!(
389            rule_config.third_party_active_checker,
390            Some(validation_type.clone())
391        );
392        assert_eq!(rule_config.match_validation_type, None);
393        assert_eq!(
394            rule_config.get_third_party_active_checker(),
395            Some(&validation_type)
396        );
397
398        // Test setting via deprecated field updates both
399        let aws_type = AwsType::AwsId;
400        let validation_type2 = MatchValidationType::Aws(aws_type);
401        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
402            .third_party_active_checker(validation_type2.clone());
403
404        assert_eq!(
405            rule_config.third_party_active_checker,
406            Some(validation_type2.clone())
407        );
408        assert_eq!(
409            rule_config.get_third_party_active_checker(),
410            Some(&validation_type2)
411        );
412
413        // Test that get_match_validation_type prioritizes third_party_active_checker
414        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
415            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));
416
417        assert_eq!(
418            rule_config.get_third_party_active_checker(),
419            Some(&MatchValidationType::CustomHttp(http_config.clone()))
420        );
421    }
422
423    #[test]
424    fn test_secondary_validator_enum_iter() {
425        // Test that we can iterate over all SecondaryValidator variants
426        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
427        // Verify some variants
428        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
429        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
430    }
431
432    #[test]
433    fn test_secondary_validator_are_sorted() {
434        let validator_names: Vec<String> = SecondaryValidator::iter()
435            .map(|a| a.as_ref().to_string())
436            .collect();
437        let mut sorted_validator_names = validator_names.clone();
438        sorted_validator_names.sort();
439        assert_eq!(
440            sorted_validator_names, validator_names,
441            "Secondary validators should be sorted by alphabetical order, but it's not the case, expected order:"
442        );
443    }
444
445    // The order has to be stable to pass linter checks. Otherwise, each instantiation will change the file
446    #[test]
447    fn test_jwt_claims_validator_config_serialization_order() {
448        // Create a config with claims in non-alphabetical order
449        let mut required_claims = BTreeMap::new();
450        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
451        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
452        required_claims.insert(
453            "aaa".to_string(),
454            ClaimRequirement::ExactValue("test".to_string()),
455        );
456        required_claims.insert(
457            "mmm".to_string(),
458            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
459        );
460
461        let config = JwtClaimsValidatorConfig {
462            required_claims,
463            required_headers: std::collections::BTreeMap::new(),
464        };
465
466        // Serialize multiple times to ensure stable order
467        let serialized1 = serde_json::to_string(&config).unwrap();
468        let serialized2 = serde_json::to_string(&config).unwrap();
469
470        // Both serializations should be identical
471        assert_eq!(serialized1, serialized2, "Serialization should be stable");
472
473        // Keys should be in alphabetical order
474        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
475        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
476        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
477    }
478
479    #[test]
480    fn test_capture_groups_validation() {
481        let test_cases: Vec<(
482            &str,
483            Vec<String>,
484            Result<(), RegexPatternCaptureGroupsValidationError>,
485        )> = vec![
486            (
487                "hello (?<sds_match>world)",
488                vec!["sds_match".to_string()],
489                Ok(()),
490            ),
491            (
492                "hello (?<capture_group>world)",
493                vec!["capture_group".to_string()],
494                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
495            ),
496            (
497                "hello (?<sds_match>world) and (?<another_group>world)",
498                vec!["sds_match".to_string()],
499                Ok(()),
500            ),
501            (
502                "hello (?<capture_grou>world)",
503                vec!["capture_group".to_string()],
504                Err(
505                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
506                        "capture_group".to_string(),
507                    ),
508                ),
509            ),
510            (
511                "hello (?<sds_match>d*)",
512                vec!["sds_match".to_string()],
513                Err(RegexPatternCaptureGroupsValidationError::CaptureGroupMatchesEmptyString),
514            ),
515            (
516                "hello (?<sds_match>world)",
517                vec!["sds_match".to_string(), "sds_match2".to_string()],
518                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
519            ),
520        ];
521        for (pattern, capture_groups, expected_result) in test_cases {
522            let rule_config =
523                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
524            assert_eq!(
525                is_pattern_capture_groups_valid(
526                    &rule_config.pattern,
527                    &rule_config.pattern_capture_groups,
528                    &get_memoized_regex(pattern, validate_and_create_regex)
529                        .unwrap()
530                        .group_info()
531                ),
532                expected_result
533            );
534        }
535    }
536}