dd_sds/scanner/regex_rule/
config.rs

1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{
7    RegexPatternCaptureGroupsValidationError, validate_and_create_regex,
8    validate_named_capture_group_minimum_length,
9};
10use crate::{CompiledRule, CreateScannerError, Labels};
11use regex_automata::util::captures::GroupInfo;
12use serde::{Deserialize, Serialize};
13use serde_with::DefaultOnNull;
14use serde_with::serde_as;
15use std::sync::Arc;
16use strum::{AsRefStr, EnumIter};
17
/// Look-ahead character count given to a freshly created `ProximityKeywordsConfig`
/// when keywords are added to a rule that has no proximity configuration yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
19
/// Configuration of a regex-based rule.
///
/// The `RuleConfig` implementation in this file turns it into a `RegexCompiledRule`.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Source of the regex; compiled (and memoized) when the rule is built.
    pub pattern: String,
    /// Optional included/excluded keyword settings for this rule.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator; compiled together with the rule when set.
    pub validator: Option<SecondaryValidator>,
    /// Rule-level labels, combined with the scanner labels at compile time.
    /// A missing or `null` value deserializes to the default `Labels`.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Names of the capture groups to target. Validation in this file currently
    /// accepts exactly one entry, and it has to be `sds_match`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
31
32impl RegexRuleConfig {
33    pub fn new(pattern: &str) -> Self {
34        #[allow(deprecated)]
35        Self {
36            pattern: pattern.to_owned(),
37            proximity_keywords: None,
38            validator: None,
39            labels: Labels::default(),
40            pattern_capture_groups: None,
41        }
42    }
43
44    pub fn with_pattern(&self, pattern: &str) -> Self {
45        self.mutate_clone(|x| x.pattern = pattern.to_string())
46    }
47
48    pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
49        self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
50    }
51
52    pub fn with_labels(&self, labels: Labels) -> Self {
53        self.mutate_clone(|x| x.labels = labels)
54    }
55
56    pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
57        self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
58    }
59
60    pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
61        self.mutate_clone(|x| match x.pattern_capture_groups {
62            Some(ref mut pattern_capture_groups) => {
63                pattern_capture_groups.push(pattern_capture_group.to_string());
64            }
65            None => {
66                x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
67            }
68        })
69    }
70
71    pub fn build(&self) -> Arc<dyn RuleConfig> {
72        Arc::new(self.clone())
73    }
74
75    fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
76        let mut clone = self.clone();
77        modify(&mut clone);
78        clone
79    }
80
81    pub fn with_included_keywords(
82        &self,
83        keywords: impl IntoIterator<Item = impl AsRef<str>>,
84    ) -> Self {
85        let mut this = self.clone();
86        let mut config = self.get_or_create_proximity_keywords_config();
87        config.included_keywords = keywords
88            .into_iter()
89            .map(|x| x.as_ref().to_string())
90            .collect::<Vec<_>>();
91        this.proximity_keywords = Some(config);
92        this
93    }
94
95    pub fn with_excluded_keywords(
96        &self,
97        keywords: impl IntoIterator<Item = impl AsRef<str>>,
98    ) -> Self {
99        let mut this = self.clone();
100        let mut config = self.get_or_create_proximity_keywords_config();
101        config.excluded_keywords = keywords
102            .into_iter()
103            .map(|x| x.as_ref().to_string())
104            .collect::<Vec<_>>();
105        this.proximity_keywords = Some(config);
106        this
107    }
108
109    pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
110        let mut this = self.clone();
111        this.validator = validator;
112        this
113    }
114
115    fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
116        self.proximity_keywords
117            .clone()
118            .unwrap_or_else(|| ProximityKeywordsConfig {
119                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
120                included_keywords: vec![],
121                excluded_keywords: vec![],
122            })
123    }
124}
125
126fn is_pattern_capture_groups_valid(
127    pattern: &str,
128    pattern_capture_groups: &Option<Vec<String>>,
129    group_info: &GroupInfo,
130) -> Result<(), RegexPatternCaptureGroupsValidationError> {
131    if pattern_capture_groups.is_none() {
132        return Ok(());
133    }
134    let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
135    if pattern_capture_groups.len() != 1 {
136        // We currently only allow one capture group
137        return Err(
138            RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
139                pattern_capture_groups.len(),
140            ),
141        );
142    }
143    let pattern_capture_group = pattern_capture_groups.first().unwrap();
144    if !group_info
145        .all_names()
146        .filter(|(_, _, name)| name.is_some())
147        .map(|(_, _, name)| name.unwrap())
148        .any(|name| name == pattern_capture_group)
149    {
150        return Err(
151            RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
152                pattern_capture_group.clone(),
153            ),
154        );
155    }
156    // At this point, the capture group is in the regex, and there is exactly one.
157    // Currently, it must be called `sds_match`.
158    if pattern_capture_group != "sds_match" {
159        return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
160    }
161    validate_named_capture_group_minimum_length(pattern, pattern_capture_group)?;
162    Ok(())
163}
164
165impl RuleConfig for RegexRuleConfig {
166    fn convert_to_compiled_rule(
167        &self,
168        rule_index: usize,
169        scanner_labels: Labels,
170    ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
171        let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
172
173        let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
174
175        let (included_keywords, excluded_keywords) = self
176            .proximity_keywords
177            .as_ref()
178            .map(|config| compile_keywords_proximity_config(config, &rule_labels))
179            .unwrap_or(Ok((None, None)))?;
180
181        is_pattern_capture_groups_valid(
182            &self.pattern,
183            &self.pattern_capture_groups,
184            regex.group_info(),
185        )?;
186
187        Ok(Box::new(RegexCompiledRule {
188            rule_index,
189            regex,
190            included_keywords,
191            excluded_keywords,
192            validator: self.validator.clone().map(|x| x.compile()),
193            metrics: RuleMetrics::new(&rule_labels),
194            pattern_capture_groups: self.pattern_capture_groups.clone(),
195        }))
196    }
197
198    fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
199        Some(self)
200    }
201}
202
/// Keyword settings attached to a regex rule.
///
/// Both keyword lists are optional in the serialized form: a missing or `null`
/// list deserializes to an empty vector.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Look-ahead size, in characters.
    // NOTE(review): exact semantics live in `compile_keywords_proximity_config`,
    // which is not visible from this file — confirm there.
    pub look_ahead_character_count: usize,

    /// Keywords to include; empty when absent or `null` in the input.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Keywords to exclude; empty when absent or `null` in the input.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
216
/// Secondary validators that can be attached to a `RegexRuleConfig`.
///
/// Serialized as an internally tagged enum: the variant name is stored under
/// the `type` key.
///
/// Keep the variants sorted by name: `test_secondary_validator_are_sorted`
/// compares the declaration order with the sorted list of variant names
/// (plain `String` ordering) and fails on any mismatch.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    AustralianMedicareChecksum,
    AustralianTfnChecksum,
    AustrianSSNChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// The only variant carrying data: the claim/header requirements to check.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SingaporeNricChecksum,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UkNinoFormatCheck,
    UkTrnChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
284
/// Requirement placed on a single JWT header or claim in a
/// `JwtClaimsValidatorConfig`.
///
/// Serialized as an adjacently tagged enum: the variant name goes under `type`
/// and, for variants with a payload, the payload goes under `config`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    /// Just check that the claim exists
    Present,
    /// Check that the claim exists and is not expired
    NotExpired,
    /// Check that the claim exists and has an exact value
    ExactValue(String),
    /// Check that the claim exists and matches a regex pattern
    RegexMatch(String),
}
297
/// Configuration of `SecondaryValidator::JwtClaimsValidator`.
///
/// Both maps default to empty when missing from the input. They are `BTreeMap`s
/// so that entries serialize in key order, which keeps the output stable from
/// one serialization to the next.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements keyed by JWT header name.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements keyed by JWT claim name.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
305
306#[cfg(test)]
307mod test {
308    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
309    use std::collections::BTreeMap;
310    use strum::IntoEnumIterator;
311
312    use super::*;
313
314    #[test]
315    fn should_override_pattern() {
316        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
317        assert_eq!(rule_config.pattern, "456");
318    }
319
320    #[test]
321    #[allow(deprecated)]
322    fn should_have_default() {
323        let rule_config = RegexRuleConfig::new("123");
324        assert_eq!(
325            rule_config,
326            RegexRuleConfig {
327                pattern: "123".to_string(),
328                proximity_keywords: None,
329                validator: None,
330                labels: Labels::empty(),
331                pattern_capture_groups: None,
332            }
333        );
334    }
335
336    #[test]
337    fn should_use_capture_group() {
338        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
339            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
340        assert_eq!(
341            rule_config,
342            RegexRuleConfig {
343                pattern: "hey (?<capture_group>world)".to_string(),
344                proximity_keywords: None,
345                validator: None,
346                labels: Labels::empty(),
347                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
348            }
349        );
350    }
351
352    #[test]
353    fn proximity_keywords_should_have_default() {
354        let json_config = r#"{"look_ahead_character_count": 0}"#;
355        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
356        assert_eq!(
357            test,
358            ProximityKeywordsConfig {
359                look_ahead_character_count: 0,
360                included_keywords: vec![],
361                excluded_keywords: vec![]
362            }
363        );
364
365        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
366        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
367        assert_eq!(
368            test,
369            ProximityKeywordsConfig {
370                look_ahead_character_count: 0,
371                included_keywords: vec![],
372                excluded_keywords: vec![]
373            }
374        );
375    }
376
    /// Exercises the `third_party_active_checker` builder on `RootRuleConfig` and
    /// the matching `get_third_party_active_checker` getter.
    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        // Test setting only the new field
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        // The other field (`match_validation_type`) must remain unset.
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        // Same builder, this time with an AWS validation type.
        // NOTE(review): this comment used to say the value is set "via deprecated
        // field", but the code calls `third_party_active_checker` exactly like the
        // first case — confirm whether a deprecated setter was meant to be used.
        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        // The getter returns the value given to the builder.
        // NOTE(review): this comment used to mention `get_match_validation_type`
        // prioritization, which is not called anywhere in this test — confirm
        // whether that scenario still needs coverage.
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }
420
421    #[test]
422    fn test_secondary_validator_enum_iter() {
423        // Test that we can iterate over all SecondaryValidator variants
424        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
425        // Verify some variants
426        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
427        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
428    }
429
430    #[test]
431    fn test_secondary_validator_are_sorted() {
432        let validator_names: Vec<String> = SecondaryValidator::iter()
433            .map(|a| a.as_ref().to_string())
434            .collect();
435        let mut sorted_validator_names = validator_names.clone();
436        sorted_validator_names.sort();
437        assert_eq!(
438            sorted_validator_names, validator_names,
439            "Secondary validators should be sorted by alphabetical order, but it's not the case, expected order:"
440        );
441    }
442
443    // The order has to be stable to pass linter checks. Otherwise, each instantiation will change the file
444    #[test]
445    fn test_jwt_claims_validator_config_serialization_order() {
446        // Create a config with claims in non-alphabetical order
447        let mut required_claims = BTreeMap::new();
448        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
449        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
450        required_claims.insert(
451            "aaa".to_string(),
452            ClaimRequirement::ExactValue("test".to_string()),
453        );
454        required_claims.insert(
455            "mmm".to_string(),
456            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
457        );
458
459        let config = JwtClaimsValidatorConfig {
460            required_claims,
461            required_headers: std::collections::BTreeMap::new(),
462        };
463
464        // Serialize multiple times to ensure stable order
465        let serialized1 = serde_json::to_string(&config).unwrap();
466        let serialized2 = serde_json::to_string(&config).unwrap();
467
468        // Both serializations should be identical
469        assert_eq!(serialized1, serialized2, "Serialization should be stable");
470
471        // Keys should be in alphabetical order
472        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
473        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
474        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
475    }
476
477    #[test]
478    fn test_capture_groups_validation() {
479        let test_cases: Vec<(
480            &str,
481            Vec<String>,
482            Result<(), RegexPatternCaptureGroupsValidationError>,
483        )> = vec![
484            (
485                "hello (?<sds_match>world)",
486                vec!["sds_match".to_string()],
487                Ok(()),
488            ),
489            (
490                "hello (?<capture_group>world)",
491                vec!["capture_group".to_string()],
492                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
493            ),
494            (
495                "hello (?<sds_match>world) and (?<another_group>world)",
496                vec!["sds_match".to_string()],
497                Ok(()),
498            ),
499            (
500                "hello (?<capture_grou>world)",
501                vec!["capture_group".to_string()],
502                Err(
503                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
504                        "capture_group".to_string(),
505                    ),
506                ),
507            ),
508            (
509                "hello (?<sds_match>d*)",
510                vec!["sds_match".to_string()],
511                Err(RegexPatternCaptureGroupsValidationError::CaptureGroupMatchesEmptyString),
512            ),
513            (
514                "hello (?<sds_match>world)",
515                vec!["sds_match".to_string(), "sds_match2".to_string()],
516                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
517            ),
518        ];
519        for (pattern, capture_groups, expected_result) in test_cases {
520            let rule_config =
521                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
522            assert_eq!(
523                is_pattern_capture_groups_valid(
524                    &rule_config.pattern,
525                    &rule_config.pattern_capture_groups,
526                    &get_memoized_regex(pattern, validate_and_create_regex)
527                        .unwrap()
528                        .group_info()
529                ),
530                expected_result
531            );
532        }
533    }
534}