dd_sds/scanner/regex_rule/config.rs

1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{RegexPatternCaptureGroupsValidationError, validate_and_create_regex};
7use crate::{CompiledRule, CreateScannerError, Labels};
8use regex_automata::util::captures::GroupInfo;
9use serde::{Deserialize, Serialize};
10use serde_with::DefaultOnNull;
11use serde_with::serde_as;
12use std::sync::Arc;
13use strum::{AsRefStr, EnumIter};
14
/// Look-ahead character count given to a [`ProximityKeywordsConfig`] that a
/// `RegexRuleConfig` keyword builder method has to create from scratch (i.e. when no
/// proximity keywords were configured yet).
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
16
/// Configuration of a scanner rule that matches content with a regular expression.
///
/// Can be built with [`RegexRuleConfig::new`] plus the `with_*` builder methods, and is
/// turned into a compiled rule by [`RuleConfig::convert_to_compiled_rule`].
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Regular expression to scan for. It is validated and compiled (through the memoizing
    /// regex store) when the rule is compiled.
    pub pattern: String,
    /// Optional included/excluded keywords, compiled together with the rule.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator; it is compiled and stored in the compiled rule.
    pub validator: Option<SecondaryValidator>,
    /// Labels of this rule, merged on top of the scanner's labels at compile time.
    /// A missing or `null` value deserializes to the default labels.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Optional names of regex capture groups targeted by this rule; passed through to the
    /// compiled rule. Compilation currently accepts exactly one name, which must be a named
    /// group of `pattern` and must be `sds_match`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
28
29impl RegexRuleConfig {
30    pub fn new(pattern: &str) -> Self {
31        #[allow(deprecated)]
32        Self {
33            pattern: pattern.to_owned(),
34            proximity_keywords: None,
35            validator: None,
36            labels: Labels::default(),
37            pattern_capture_groups: None,
38        }
39    }
40
41    pub fn with_pattern(&self, pattern: &str) -> Self {
42        self.mutate_clone(|x| x.pattern = pattern.to_string())
43    }
44
45    pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
46        self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
47    }
48
49    pub fn with_labels(&self, labels: Labels) -> Self {
50        self.mutate_clone(|x| x.labels = labels)
51    }
52
53    pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
54        self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
55    }
56
57    pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
58        self.mutate_clone(|x| match x.pattern_capture_groups {
59            Some(ref mut pattern_capture_groups) => {
60                pattern_capture_groups.push(pattern_capture_group.to_string());
61            }
62            None => {
63                x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
64            }
65        })
66    }
67
68    pub fn build(&self) -> Arc<dyn RuleConfig> {
69        Arc::new(self.clone())
70    }
71
72    fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
73        let mut clone = self.clone();
74        modify(&mut clone);
75        clone
76    }
77
78    pub fn with_included_keywords(
79        &self,
80        keywords: impl IntoIterator<Item = impl AsRef<str>>,
81    ) -> Self {
82        let mut this = self.clone();
83        let mut config = self.get_or_create_proximity_keywords_config();
84        config.included_keywords = keywords
85            .into_iter()
86            .map(|x| x.as_ref().to_string())
87            .collect::<Vec<_>>();
88        this.proximity_keywords = Some(config);
89        this
90    }
91
92    pub fn with_excluded_keywords(
93        &self,
94        keywords: impl IntoIterator<Item = impl AsRef<str>>,
95    ) -> Self {
96        let mut this = self.clone();
97        let mut config = self.get_or_create_proximity_keywords_config();
98        config.excluded_keywords = keywords
99            .into_iter()
100            .map(|x| x.as_ref().to_string())
101            .collect::<Vec<_>>();
102        this.proximity_keywords = Some(config);
103        this
104    }
105
106    pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
107        let mut this = self.clone();
108        this.validator = validator;
109        this
110    }
111
112    fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
113        self.proximity_keywords
114            .clone()
115            .unwrap_or_else(|| ProximityKeywordsConfig {
116                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
117                included_keywords: vec![],
118                excluded_keywords: vec![],
119            })
120    }
121}
122
123fn is_pattern_capture_groups_valid(
124    pattern_capture_groups: &Option<Vec<String>>,
125    group_info: &GroupInfo,
126) -> Result<(), RegexPatternCaptureGroupsValidationError> {
127    if pattern_capture_groups.is_none() {
128        return Ok(());
129    }
130    let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
131    if pattern_capture_groups.len() != 1 {
132        // We currently only allow one capture group
133        return Err(
134            RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
135                pattern_capture_groups.len(),
136            ),
137        );
138    }
139    let pattern_capture_group = pattern_capture_groups.first().unwrap();
140    if !group_info
141        .all_names()
142        .filter(|(_, _, name)| name.is_some())
143        .map(|(_, _, name)| name.unwrap())
144        .any(|name| name == pattern_capture_group)
145    {
146        return Err(
147            RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
148                pattern_capture_group.clone(),
149            ),
150        );
151    }
152    // At this point, the capture group is in the regex, and there is exactly one.
153    // Currently, it must be called `sds_match`.
154    if pattern_capture_group != "sds_match" {
155        return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
156    }
157    Ok(())
158}
159
160impl RuleConfig for RegexRuleConfig {
161    fn convert_to_compiled_rule(
162        &self,
163        rule_index: usize,
164        scanner_labels: Labels,
165    ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
166        let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
167
168        let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
169
170        let (included_keywords, excluded_keywords) = self
171            .proximity_keywords
172            .as_ref()
173            .map(|config| compile_keywords_proximity_config(config, &rule_labels))
174            .unwrap_or(Ok((None, None)))?;
175
176        is_pattern_capture_groups_valid(&self.pattern_capture_groups, regex.group_info())?;
177
178        Ok(Box::new(RegexCompiledRule {
179            rule_index,
180            regex,
181            included_keywords,
182            excluded_keywords,
183            validator: self.validator.clone().map(|x| x.compile()),
184            metrics: RuleMetrics::new(&rule_labels),
185            pattern_capture_groups: self.pattern_capture_groups.clone(),
186        }))
187    }
188
189    fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
190        Some(self)
191    }
192}
193
/// Keyword proximity settings of a [`RegexRuleConfig`].
///
/// Compiled by `compile_keywords_proximity_config` into the rule's included and excluded
/// keywords when the rule itself is compiled.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Look-ahead window size, in characters, used for keyword proximity.
    /// NOTE(review): the exact window semantics live in `compile_keywords_proximity_config`.
    pub look_ahead_character_count: usize,

    /// Included keywords (presumably required near a match — confirm). A missing or `null`
    /// value deserializes to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Excluded keywords (presumably reject a nearby match — confirm). A missing or `null`
    /// value deserializes to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
207
/// Secondary validator that a [`RegexRuleConfig`] can attach to its matches.
///
/// Serialized as an internally tagged enum: the variant name goes in a `type` field, so
/// variant names are part of the config format and must not be renamed. `EnumIter` and
/// `AsRefStr` allow listing all validators by name.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    // Keep variants in alphabetical order of their names; this is enforced by
    // `test_secondary_validator_are_sorted`.
    AbaRtnChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// Carries a [`JwtClaimsValidatorConfig`] listing required JWT headers and claims.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
269
/// Requirement placed on a single JWT header or claim by a [`JwtClaimsValidatorConfig`].
///
/// Serialized as an adjacently tagged enum: the variant name is stored under `type`, and the
/// payload (for `ExactValue` / `RegexMatch`) under `config`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    /// Just check that the claim exists
    Present,
    /// Check that the claim exists and is not expired
    NotExpired,
    /// Check that the claim exists and has an exact value
    ExactValue(String),
    /// Check that the claim exists and matches a regex pattern
    RegexMatch(String),
}
282
/// Configuration of the `JwtClaimsValidator` secondary validator.
///
/// `BTreeMap` is used so that entries serialize in sorted key order, keeping serialized
/// configs stable across runs.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements keyed by JWT header name; empty when missing from the input.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements keyed by JWT claim name; empty when missing from the input.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
290
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn should_override_pattern() {
        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(rule_config.pattern, "456");
    }

    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let rule_config = RegexRuleConfig::new("123");
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "123".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: None,
            }
        );
    }

    #[test]
    fn should_use_capture_group() {
        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "hey (?<capture_group>world)".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
            }
        );
    }

    #[test]
    fn should_append_pattern_capture_group() {
        let rule_config = RegexRuleConfig::new("(?<first>a)(?<second>b)")
            .with_pattern_capture_group("first")
            .with_pattern_capture_group("second");
        assert_eq!(
            rule_config.pattern_capture_groups,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    #[test]
    fn should_create_proximity_keywords_with_default_look_ahead() {
        let rule_config = RegexRuleConfig::new("123")
            .with_included_keywords(["hello", "world"])
            .with_excluded_keywords(["bye"]);
        assert_eq!(
            rule_config.proximity_keywords,
            Some(ProximityKeywordsConfig {
                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
                included_keywords: vec!["hello".to_string(), "world".to_string()],
                excluded_keywords: vec!["bye".to_string()],
            })
        );
    }

    #[test]
    fn should_keep_existing_proximity_keywords_settings() {
        let rule_config = RegexRuleConfig::new("123")
            .with_proximity_keywords(ProximityKeywordsConfig {
                look_ahead_character_count: 5,
                included_keywords: vec!["old".to_string()],
                excluded_keywords: vec!["excluded".to_string()],
            })
            .with_included_keywords(["new"]);
        assert_eq!(
            rule_config.proximity_keywords,
            Some(ProximityKeywordsConfig {
                look_ahead_character_count: 5,
                included_keywords: vec!["new".to_string()],
                excluded_keywords: vec!["excluded".to_string()],
            })
        );
    }

    #[test]
    fn should_set_and_clear_validator() {
        let rule_config =
            RegexRuleConfig::new("123").with_validator(Some(SecondaryValidator::LuhnChecksum));
        assert_eq!(rule_config.validator, Some(SecondaryValidator::LuhnChecksum));
        assert_eq!(rule_config.with_validator(None).validator, None);
    }

    #[test]
    fn proximity_keywords_should_have_default() {
        let json_config = r#"{"look_ahead_character_count": 0}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );

        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );
    }

    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        // Setting the checker stores it in `third_party_active_checker` and leaves
        // `match_validation_type` unset.
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        // A different validation type (AWS) is stored and returned the same way.
        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        // `get_third_party_active_checker` returns the checker configured through the builder.
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }

    #[test]
    fn test_secondary_validator_enum_iter() {
        // Test that we can iterate over all SecondaryValidator variants
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        // Verify some variants
        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
    }

    #[test]
    fn test_secondary_validator_are_sorted() {
        let validator_names: Vec<String> = SecondaryValidator::iter()
            .map(|a| a.as_ref().to_string())
            .collect();
        let mut sorted_validator_names = validator_names.clone();
        sorted_validator_names.sort();
        assert_eq!(
            sorted_validator_names, validator_names,
            "Secondary validators should be declared in alphabetical order (left: expected order, right: actual order)"
        );
    }

    // The order has to be stable to pass linter checks. Otherwise, each instantiation will change the file
    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        // Create a config with claims in non-alphabetical order
        let mut required_claims = BTreeMap::new();
        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "aaa".to_string(),
            ClaimRequirement::ExactValue("test".to_string()),
        );
        required_claims.insert(
            "mmm".to_string(),
            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
        );

        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: std::collections::BTreeMap::new(),
        };

        // Serialize multiple times to ensure stable order
        let serialized1 = serde_json::to_string(&config).unwrap();
        let serialized2 = serde_json::to_string(&config).unwrap();

        // Both serializations should be identical
        assert_eq!(serialized1, serialized2, "Serialization should be stable");

        // Keys should be in alphabetical order
        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
    }

    #[test]
    fn test_capture_groups_validation() {
        let test_cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_group>world)",
                vec!["capture_group".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_grou>world)",
                vec!["capture_group".to_string()],
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string(), "sds_match2".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected_result) in test_cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            assert_eq!(
                is_pattern_capture_groups_valid(
                    &rule_config.pattern_capture_groups,
                    &get_memoized_regex(pattern, validate_and_create_regex)
                        .unwrap()
                        .group_info()
                ),
                expected_result
            );
        }
    }
}