dd_sds/scanner/regex_rule/
config.rs

1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{RegexPatternCaptureGroupsValidationError, validate_and_create_regex};
7use crate::{CompiledRule, CreateScannerError, Labels};
8use regex_automata::util::captures::GroupInfo;
9use serde::{Deserialize, Serialize};
10use serde_with::DefaultOnNull;
11use serde_with::serde_as;
12use std::sync::Arc;
13use strum::{AsRefStr, EnumIter};
14
/// Value used for `ProximityKeywordsConfig::look_ahead_character_count` when a
/// `RegexRuleConfig` builder has to create a proximity keywords config because none
/// was set yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
16
/// Serializable configuration of a regex-based rule.
///
/// It is turned into a `RegexCompiledRule` by its `RuleConfig` implementation.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Regex pattern source; it is validated and compiled when the rule is compiled.
    pub pattern: String,
    /// Optional included / excluded keyword configuration.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator, compiled together with the rule.
    pub validator: Option<SecondaryValidator>,
    /// Rule labels. A missing field or an explicit `null` deserializes to the default.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Names of the capture groups targeted by the rule. When set, compilation
    /// currently requires exactly one entry, which must be a named group of
    /// `pattern` and must be called `sds_match`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
28
29impl RegexRuleConfig {
30    pub fn new(pattern: &str) -> Self {
31        #[allow(deprecated)]
32        Self {
33            pattern: pattern.to_owned(),
34            proximity_keywords: None,
35            validator: None,
36            labels: Labels::default(),
37            pattern_capture_groups: None,
38        }
39    }
40
41    pub fn with_pattern(&self, pattern: &str) -> Self {
42        self.mutate_clone(|x| x.pattern = pattern.to_string())
43    }
44
45    pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
46        self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
47    }
48
49    pub fn with_labels(&self, labels: Labels) -> Self {
50        self.mutate_clone(|x| x.labels = labels)
51    }
52
53    pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
54        self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
55    }
56
57    pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
58        self.mutate_clone(|x| match x.pattern_capture_groups {
59            Some(ref mut pattern_capture_groups) => {
60                pattern_capture_groups.push(pattern_capture_group.to_string());
61            }
62            None => {
63                x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
64            }
65        })
66    }
67
68    pub fn build(&self) -> Arc<dyn RuleConfig> {
69        Arc::new(self.clone())
70    }
71
72    fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
73        let mut clone = self.clone();
74        modify(&mut clone);
75        clone
76    }
77
78    pub fn with_included_keywords(
79        &self,
80        keywords: impl IntoIterator<Item = impl AsRef<str>>,
81    ) -> Self {
82        let mut this = self.clone();
83        let mut config = self.get_or_create_proximity_keywords_config();
84        config.included_keywords = keywords
85            .into_iter()
86            .map(|x| x.as_ref().to_string())
87            .collect::<Vec<_>>();
88        this.proximity_keywords = Some(config);
89        this
90    }
91
92    pub fn with_excluded_keywords(
93        &self,
94        keywords: impl IntoIterator<Item = impl AsRef<str>>,
95    ) -> Self {
96        let mut this = self.clone();
97        let mut config = self.get_or_create_proximity_keywords_config();
98        config.excluded_keywords = keywords
99            .into_iter()
100            .map(|x| x.as_ref().to_string())
101            .collect::<Vec<_>>();
102        this.proximity_keywords = Some(config);
103        this
104    }
105
106    pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
107        let mut this = self.clone();
108        this.validator = validator;
109        this
110    }
111
112    fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
113        self.proximity_keywords
114            .clone()
115            .unwrap_or_else(|| ProximityKeywordsConfig {
116                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
117                included_keywords: vec![],
118                excluded_keywords: vec![],
119            })
120    }
121}
122
123fn is_pattern_capture_groups_valid(
124    pattern_capture_groups: &Option<Vec<String>>,
125    group_info: &GroupInfo,
126) -> Result<(), RegexPatternCaptureGroupsValidationError> {
127    if pattern_capture_groups.is_none() {
128        return Ok(());
129    }
130    let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
131    if pattern_capture_groups.len() != 1 {
132        // We currently only allow one capture group
133        return Err(
134            RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
135                pattern_capture_groups.len(),
136            ),
137        );
138    }
139    let pattern_capture_group = pattern_capture_groups.first().unwrap();
140    if !group_info
141        .all_names()
142        .filter(|(_, _, name)| name.is_some())
143        .map(|(_, _, name)| name.unwrap())
144        .any(|name| name == pattern_capture_group)
145    {
146        return Err(
147            RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
148                pattern_capture_group.clone(),
149            ),
150        );
151    }
152    // At this point, the capture group is in the regex, and there is exactly one.
153    // Currently, it must be called `sds_match`.
154    if pattern_capture_group != "sds_match" {
155        return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
156    }
157    Ok(())
158}
159
160impl RuleConfig for RegexRuleConfig {
161    fn convert_to_compiled_rule(
162        &self,
163        rule_index: usize,
164        scanner_labels: Labels,
165    ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
166        let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
167
168        let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
169
170        let (included_keywords, excluded_keywords) = self
171            .proximity_keywords
172            .as_ref()
173            .map(|config| compile_keywords_proximity_config(config, &rule_labels))
174            .unwrap_or(Ok((None, None)))?;
175
176        is_pattern_capture_groups_valid(&self.pattern_capture_groups, regex.group_info())?;
177
178        Ok(Box::new(RegexCompiledRule {
179            rule_index,
180            regex,
181            included_keywords,
182            excluded_keywords,
183            validator: self.validator.clone().map(|x| x.compile()),
184            metrics: RuleMetrics::new(&rule_labels),
185            pattern_capture_groups: self.pattern_capture_groups.clone(),
186        }))
187    }
188
189    fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
190        Some(self)
191    }
192}
193
/// Keyword configuration attached to a regex rule; it is compiled by
/// `compile_keywords_proximity_config` when the rule is compiled.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Required when deserializing (no serde default).
    // NOTE(review): presumably the number of characters around a match that are
    // searched for keywords — confirm against the proximity keywords implementation.
    pub look_ahead_character_count: usize,

    /// Included keywords. A missing field or an explicit `null` deserializes to an
    /// empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Excluded keywords. A missing field or an explicit `null` deserializes to an
    /// empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
207
/// Secondary validators that can be attached to a regex rule through
/// `RegexRuleConfig::validator`.
///
/// Serialized as an internally tagged enum: the variant name is stored under the
/// `type` key. `JwtClaimsValidator` is the only variant carrying extra data.
///
/// Variants must be declared in the byte-wise sorted order of their names; the
/// `test_secondary_validator_are_sorted` test fails otherwise.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    AustrianSSNChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// Validates JWT headers and claims against the requirements in `config`.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
270
/// Requirement placed on a single entry of `JwtClaimsValidatorConfig`
/// (`required_headers` or `required_claims`).
///
/// Serialized as an adjacently tagged enum: the variant name is stored under `type`
/// and, for variants carrying a value, the value under `config`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    /// Just check that the claim exists
    Present,
    /// Check that the claim exists and is not expired
    NotExpired,
    /// Check that the claim exists and has an exact value
    ExactValue(String),
    /// Check that the claim exists and matches a regex pattern
    RegexMatch(String),
}
283
/// Configuration carried by `SecondaryValidator::JwtClaimsValidator`.
///
/// Both maps are `BTreeMap`s so that entries serialize in sorted key order, which
/// keeps the serialized output stable. Each map defaults to empty when its field is
/// missing.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements keyed by header name.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements keyed by claim name.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
291
// Unit tests for the regex rule configuration types: builder behavior, serde
// defaults, validator enum ordering and capture group validation.
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    // `with_pattern` replaces the pattern given to `new`.
    #[test]
    fn should_override_pattern() {
        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(rule_config.pattern, "456");
    }

    // `new` leaves every optional field unset and uses empty labels.
    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let rule_config = RegexRuleConfig::new("123");
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "123".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: None,
            }
        );
    }

    // The builder stores the capture group names as given; validation only happens
    // when the rule is compiled, so a name other than `sds_match` is accepted here.
    #[test]
    fn should_use_capture_group() {
        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "hey (?<capture_group>world)".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
            }
        );
    }

    // Keyword lists deserialize to empty vectors both when the fields are missing
    // and when they are explicitly `null`.
    #[test]
    fn proximity_keywords_should_have_default() {
        let json_config = r#"{"look_ahead_character_count": 0}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );

        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );
    }

    // Exercises `RootRuleConfig::third_party_active_checker` and its getter for
    // several `MatchValidationType` values.
    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        // Test setting only the new field
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        // Same check with an AWS validation type.
        // NOTE(review): the previous comment here said the value was set "via
        // deprecated field", but the code goes through `third_party_active_checker`
        // again — confirm whether a deprecated setter was meant to be exercised.
        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        // The getter returns the value that was set through
        // `third_party_active_checker`.
        // NOTE(review): the previous comment mentioned `get_match_validation_type`
        // and prioritization, neither of which is exercised below — confirm intent.
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }

    #[test]
    fn test_secondary_validator_enum_iter() {
        // Test that we can iterate over all SecondaryValidator variants
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        // Verify some variants
        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
    }

    // Variant names, in declaration order, must already be sorted (byte-wise string
    // order); on failure the assertion prints the expected order first.
    #[test]
    fn test_secondary_validator_are_sorted() {
        let validator_names: Vec<String> = SecondaryValidator::iter()
            .map(|a| a.as_ref().to_string())
            .collect();
        let mut sorted_validator_names = validator_names.clone();
        sorted_validator_names.sort();
        assert_eq!(
            sorted_validator_names, validator_names,
            "Secondary validators should be sorted by alphabetical order, but it's not the case, expected order:"
        );
    }

    // The order has to be stable to pass linter checks. Otherwise, each instantiation will change the file
    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        // Create a config with claims in non-alphabetical order
        let mut required_claims = BTreeMap::new();
        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "aaa".to_string(),
            ClaimRequirement::ExactValue("test".to_string()),
        );
        required_claims.insert(
            "mmm".to_string(),
            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
        );

        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: std::collections::BTreeMap::new(),
        };

        // Serialize multiple times to ensure stable order
        let serialized1 = serde_json::to_string(&config).unwrap();
        let serialized2 = serde_json::to_string(&config).unwrap();

        // Both serializations should be identical
        assert_eq!(serialized1, serialized2, "Serialization should be stable");

        // Keys should be in alphabetical order
        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
    }

    // Table-driven check of `is_pattern_capture_groups_valid`: each case is
    // (pattern, targeted capture groups, expected validation result).
    #[test]
    fn test_capture_groups_validation() {
        let test_cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            // Single targeted group named `sds_match`: valid.
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            // Group exists in the regex but is not called `sds_match`.
            (
                "hello (?<capture_group>world)",
                vec!["capture_group".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            // Extra named groups in the regex are fine as long as only one is targeted.
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            // Targeted name does not exist in the regex (note the `capture_grou` typo).
            (
                "hello (?<capture_grou>world)",
                vec!["capture_group".to_string()],
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            // More than one targeted group is rejected before any name is looked up.
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string(), "sds_match2".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected_result) in test_cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            assert_eq!(
                is_pattern_capture_groups_valid(
                    &rule_config.pattern_capture_groups,
                    &get_memoized_regex(pattern, validate_and_create_regex)
                        .unwrap()
                        .group_info()
                ),
                expected_result
            );
        }
    }
}