1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{
7 RegexPatternCaptureGroupsValidationError, validate_and_create_regex,
8 validate_named_capture_group_minimum_length,
9};
10use crate::{CompiledRule, CreateScannerError, Labels};
11use regex_automata::util::captures::GroupInfo;
12use serde::{Deserialize, Serialize};
13use serde_with::DefaultOnNull;
14use serde_with::serde_as;
15use std::sync::Arc;
16use strum::{AsRefStr, EnumIter};
17
/// Value used for `look_ahead_character_count` when a
/// [`ProximityKeywordsConfig`] has to be created implicitly by one of the
/// keyword builder methods on [`RegexRuleConfig`].
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
19
/// Serializable configuration of a regex-based rule.
///
/// It is turned into a compiled rule through its [`RuleConfig`]
/// implementation (`convert_to_compiled_rule`).
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Source of the regex; validated and compiled when the rule is compiled.
    pub pattern: String,
    /// Optional included/excluded keyword configuration for this rule.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator attached to the compiled rule.
    pub validator: Option<SecondaryValidator>,
    /// Labels of this rule. A missing field or an explicit `null` both
    /// deserialize to `Labels::default()`.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Names of the capture groups targeted by the rule. When set, rule
    /// compilation requires exactly one entry, present as a named group in
    /// `pattern` and equal to `"sds_match"`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
31
32impl RegexRuleConfig {
33 pub fn new(pattern: &str) -> Self {
34 #[allow(deprecated)]
35 Self {
36 pattern: pattern.to_owned(),
37 proximity_keywords: None,
38 validator: None,
39 labels: Labels::default(),
40 pattern_capture_groups: None,
41 }
42 }
43
44 pub fn with_pattern(&self, pattern: &str) -> Self {
45 self.mutate_clone(|x| x.pattern = pattern.to_string())
46 }
47
48 pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
49 self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
50 }
51
52 pub fn with_labels(&self, labels: Labels) -> Self {
53 self.mutate_clone(|x| x.labels = labels)
54 }
55
56 pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
57 self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
58 }
59
60 pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
61 self.mutate_clone(|x| match x.pattern_capture_groups {
62 Some(ref mut pattern_capture_groups) => {
63 pattern_capture_groups.push(pattern_capture_group.to_string());
64 }
65 None => {
66 x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
67 }
68 })
69 }
70
71 pub fn build(&self) -> Arc<dyn RuleConfig> {
72 Arc::new(self.clone())
73 }
74
75 fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
76 let mut clone = self.clone();
77 modify(&mut clone);
78 clone
79 }
80
81 pub fn with_included_keywords(
82 &self,
83 keywords: impl IntoIterator<Item = impl AsRef<str>>,
84 ) -> Self {
85 let mut this = self.clone();
86 let mut config = self.get_or_create_proximity_keywords_config();
87 config.included_keywords = keywords
88 .into_iter()
89 .map(|x| x.as_ref().to_string())
90 .collect::<Vec<_>>();
91 this.proximity_keywords = Some(config);
92 this
93 }
94
95 pub fn with_excluded_keywords(
96 &self,
97 keywords: impl IntoIterator<Item = impl AsRef<str>>,
98 ) -> Self {
99 let mut this = self.clone();
100 let mut config = self.get_or_create_proximity_keywords_config();
101 config.excluded_keywords = keywords
102 .into_iter()
103 .map(|x| x.as_ref().to_string())
104 .collect::<Vec<_>>();
105 this.proximity_keywords = Some(config);
106 this
107 }
108
109 pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
110 let mut this = self.clone();
111 this.validator = validator;
112 this
113 }
114
115 fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
116 self.proximity_keywords
117 .clone()
118 .unwrap_or_else(|| ProximityKeywordsConfig {
119 look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
120 included_keywords: vec![],
121 excluded_keywords: vec![],
122 })
123 }
124}
125
126fn is_pattern_capture_groups_valid(
127 pattern: &str,
128 pattern_capture_groups: &Option<Vec<String>>,
129 group_info: &GroupInfo,
130) -> Result<(), RegexPatternCaptureGroupsValidationError> {
131 if pattern_capture_groups.is_none() {
132 return Ok(());
133 }
134 let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
135 if pattern_capture_groups.len() != 1 {
136 return Err(
138 RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
139 pattern_capture_groups.len(),
140 ),
141 );
142 }
143 let pattern_capture_group = pattern_capture_groups.first().unwrap();
144 if !group_info
145 .all_names()
146 .filter(|(_, _, name)| name.is_some())
147 .map(|(_, _, name)| name.unwrap())
148 .any(|name| name == pattern_capture_group)
149 {
150 return Err(
151 RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
152 pattern_capture_group.clone(),
153 ),
154 );
155 }
156 if pattern_capture_group != "sds_match" {
159 return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
160 }
161 validate_named_capture_group_minimum_length(pattern, pattern_capture_group)?;
162 Ok(())
163}
164
165impl RuleConfig for RegexRuleConfig {
166 fn convert_to_compiled_rule(
167 &self,
168 rule_index: usize,
169 scanner_labels: Labels,
170 ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
171 let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
172
173 let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
174
175 let (included_keywords, excluded_keywords) = self
176 .proximity_keywords
177 .as_ref()
178 .map(|config| compile_keywords_proximity_config(config, &rule_labels))
179 .unwrap_or(Ok((None, None)))?;
180
181 is_pattern_capture_groups_valid(
182 &self.pattern,
183 &self.pattern_capture_groups,
184 regex.group_info(),
185 )?;
186
187 Ok(Box::new(RegexCompiledRule {
188 rule_index,
189 regex,
190 included_keywords,
191 excluded_keywords,
192 validator: self.validator.clone().map(|x| x.compile()),
193 metrics: RuleMetrics::new(&rule_labels),
194 pattern_capture_groups: self.pattern_capture_groups.clone(),
195 }))
196 }
197
198 fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
199 Some(self)
200 }
201}
202
/// Included/excluded keyword lists attached to a regex rule.
///
/// When deserializing, a missing keyword list or an explicit `null` both
/// produce an empty list.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    // NOTE(review): presumably the number of characters around a match in
    // which keywords are searched — confirm against the proximity keywords
    // implementation.
    pub look_ahead_character_count: usize,

    /// Included keywords; defaults to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Excluded keywords; defaults to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
216
/// Secondary validators that can be attached to a [`RegexRuleConfig`].
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `type` field. Variants must stay in alphabetical order of their names; a
/// unit test in this file (`test_secondary_validator_are_sorted`) enforces it.
///
/// NOTE(review): what each validator actually checks is defined where the
/// variants are compiled (`compile()`), which is not part of this file.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    AustralianMedicareChecksum,
    AustralianTfnChecksum,
    AustrianSSNChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    // The only variant carrying data: its settings live in a `config` field.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SingaporeNricChecksum,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UkNinoFormatCheck,
    UkTrnChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
284
/// Requirement placed on a single JWT header or claim, used as the map value
/// in [`JwtClaimsValidatorConfig`].
///
/// Serialized as an adjacently tagged enum: the variant name goes in `type`
/// and the payload, if any, in `config`.
///
/// NOTE(review): the per-variant notes below are inferred from the variant
/// names only — confirm against the validator implementation.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    // Presumably: the entry only has to exist.
    Present,
    // Presumably: the entry is a timestamp that must not be in the past.
    NotExpired,
    // Presumably: the entry must equal the given string.
    ExactValue(String),
    // Presumably: the entry must match the given regex source.
    RegexMatch(String),
}
297
/// Configuration carried by `SecondaryValidator::JwtClaimsValidator`.
///
/// Both maps default to empty when missing from the serialized form. They are
/// `BTreeMap`s, so entries are serialized in sorted key order, which keeps
/// the output stable.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements keyed by header name.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements keyed by claim name.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
305
/// Unit tests for the regex rule configuration types of this file.
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    /// `with_pattern` replaces the pattern given to `new`.
    #[test]
    fn should_override_pattern() {
        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(rule_config.pattern, "456");
    }

    /// `new` leaves every optional setting unset.
    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let rule_config = RegexRuleConfig::new("123");
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "123".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: None,
            }
        );
    }

    /// `with_pattern_capture_groups` stores the given names untouched (no
    /// validation happens at configuration time).
    #[test]
    fn should_use_capture_group() {
        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "hey (?<capture_group>world)".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
            }
        );
    }

    /// Missing and `null` keyword lists both deserialize to empty lists.
    #[test]
    fn proximity_keywords_should_have_default() {
        let json_config = r#"{"look_ahead_character_count": 0}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );

        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );
    }

    /// `third_party_active_checker` stores the validation type and exposes
    /// it through `get_third_party_active_checker`.
    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        // Custom HTTP checker.
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        // AWS checker.
        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        // Checker built inline rather than from a pre-built value.
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }

    /// `EnumIter` yields the expected variants.
    #[test]
    fn test_secondary_validator_enum_iter() {
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
    }

    /// The variants of `SecondaryValidator` are declared in sorted order.
    #[test]
    fn test_secondary_validator_are_sorted() {
        let validator_names: Vec<String> = SecondaryValidator::iter()
            .map(|a| a.as_ref().to_string())
            .collect();
        let mut sorted_validator_names = validator_names.clone();
        sorted_validator_names.sort();
        // The message announces the expected order, so actually print it.
        assert_eq!(
            sorted_validator_names, validator_names,
            "Secondary validators should be sorted by alphabetical order, but it's not the case, expected order: {:?}",
            sorted_validator_names
        );
    }

    /// Serialization of `JwtClaimsValidatorConfig` is deterministic and
    /// ordered by key, whatever the insertion order.
    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        let mut required_claims = BTreeMap::new();
        // Deliberately inserted out of order.
        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "aaa".to_string(),
            ClaimRequirement::ExactValue("test".to_string()),
        );
        required_claims.insert(
            "mmm".to_string(),
            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
        );

        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: BTreeMap::new(),
        };

        let serialized1 = serde_json::to_string(&config).unwrap();
        let serialized2 = serde_json::to_string(&config).unwrap();

        assert_eq!(serialized1, serialized2, "Serialization should be stable");

        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
    }

    /// Table-driven check of every outcome of
    /// `is_pattern_capture_groups_valid`.
    #[test]
    fn test_capture_groups_validation() {
        let test_cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_group>world)",
                vec!["capture_group".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_grou>world)",
                vec!["capture_group".to_string()],
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            (
                "hello (?<sds_match>d*)",
                vec!["sds_match".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::CaptureGroupMatchesEmptyString),
            ),
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string(), "sds_match2".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected_result) in test_cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            let regex = get_memoized_regex(pattern, validate_and_create_regex).unwrap();
            assert_eq!(
                is_pattern_capture_groups_valid(
                    &rule_config.pattern,
                    &rule_config.pattern_capture_groups,
                    regex.group_info()
                ),
                expected_result,
                "unexpected validation result for pattern {:?}",
                pattern
            );
        }
    }
}