1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{RegexPatternCaptureGroupsValidationError, validate_and_create_regex};
7use crate::{CompiledRule, CreateScannerError, Labels};
8use regex_automata::util::captures::GroupInfo;
9use serde::{Deserialize, Serialize};
10use serde_with::DefaultOnNull;
11use serde_with::serde_as;
12use std::sync::Arc;
13use strum::{AsRefStr, EnumIter};
14
/// Look-ahead character count given to a freshly created [`ProximityKeywordsConfig`] when a
/// `RegexRuleConfig` keyword builder (`with_included_keywords` / `with_excluded_keywords`) has
/// to create proximity settings because none were set yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
16
/// Configuration of a rule that scans for matches of a regular expression.
///
/// It is turned into a compiled rule through its [`RuleConfig`] implementation. The `with_*`
/// builder methods return modified copies and leave the original untouched.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Regular expression of the rule; compiled via `get_memoized_regex` with
    /// `validate_and_create_regex` when the rule is converted.
    pub pattern: String,
    /// Optional proximity-keyword settings (look-ahead window plus included/excluded keyword
    /// lists), compiled with `compile_keywords_proximity_config` at conversion time.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator, compiled with `SecondaryValidator::compile` at conversion
    /// time (presumably run against each candidate match — confirm in `RegexCompiledRule`).
    pub validator: Option<SecondaryValidator>,
    /// Rule-specific labels, merged into the scanner labels with `clone_with_labels`.
    /// A missing field or an explicit JSON `null` deserializes to `Labels::default()`.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Optional named capture groups to target. When set, conversion requires exactly one name,
    /// that the group exists in `pattern`, and that it is `sds_match`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
28
29impl RegexRuleConfig {
30 pub fn new(pattern: &str) -> Self {
31 #[allow(deprecated)]
32 Self {
33 pattern: pattern.to_owned(),
34 proximity_keywords: None,
35 validator: None,
36 labels: Labels::default(),
37 pattern_capture_groups: None,
38 }
39 }
40
41 pub fn with_pattern(&self, pattern: &str) -> Self {
42 self.mutate_clone(|x| x.pattern = pattern.to_string())
43 }
44
45 pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
46 self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
47 }
48
49 pub fn with_labels(&self, labels: Labels) -> Self {
50 self.mutate_clone(|x| x.labels = labels)
51 }
52
53 pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
54 self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
55 }
56
57 pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
58 self.mutate_clone(|x| match x.pattern_capture_groups {
59 Some(ref mut pattern_capture_groups) => {
60 pattern_capture_groups.push(pattern_capture_group.to_string());
61 }
62 None => {
63 x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
64 }
65 })
66 }
67
68 pub fn build(&self) -> Arc<dyn RuleConfig> {
69 Arc::new(self.clone())
70 }
71
72 fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
73 let mut clone = self.clone();
74 modify(&mut clone);
75 clone
76 }
77
78 pub fn with_included_keywords(
79 &self,
80 keywords: impl IntoIterator<Item = impl AsRef<str>>,
81 ) -> Self {
82 let mut this = self.clone();
83 let mut config = self.get_or_create_proximity_keywords_config();
84 config.included_keywords = keywords
85 .into_iter()
86 .map(|x| x.as_ref().to_string())
87 .collect::<Vec<_>>();
88 this.proximity_keywords = Some(config);
89 this
90 }
91
92 pub fn with_excluded_keywords(
93 &self,
94 keywords: impl IntoIterator<Item = impl AsRef<str>>,
95 ) -> Self {
96 let mut this = self.clone();
97 let mut config = self.get_or_create_proximity_keywords_config();
98 config.excluded_keywords = keywords
99 .into_iter()
100 .map(|x| x.as_ref().to_string())
101 .collect::<Vec<_>>();
102 this.proximity_keywords = Some(config);
103 this
104 }
105
106 pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
107 let mut this = self.clone();
108 this.validator = validator;
109 this
110 }
111
112 fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
113 self.proximity_keywords
114 .clone()
115 .unwrap_or_else(|| ProximityKeywordsConfig {
116 look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
117 included_keywords: vec![],
118 excluded_keywords: vec![],
119 })
120 }
121}
122
123fn is_pattern_capture_groups_valid(
124 pattern_capture_groups: &Option<Vec<String>>,
125 group_info: &GroupInfo,
126) -> Result<(), RegexPatternCaptureGroupsValidationError> {
127 if pattern_capture_groups.is_none() {
128 return Ok(());
129 }
130 let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
131 if pattern_capture_groups.len() != 1 {
132 return Err(
134 RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
135 pattern_capture_groups.len(),
136 ),
137 );
138 }
139 let pattern_capture_group = pattern_capture_groups.first().unwrap();
140 if !group_info
141 .all_names()
142 .filter(|(_, _, name)| name.is_some())
143 .map(|(_, _, name)| name.unwrap())
144 .any(|name| name == pattern_capture_group)
145 {
146 return Err(
147 RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
148 pattern_capture_group.clone(),
149 ),
150 );
151 }
152 if pattern_capture_group != "sds_match" {
155 return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
156 }
157 Ok(())
158}
159
160impl RuleConfig for RegexRuleConfig {
161 fn convert_to_compiled_rule(
162 &self,
163 rule_index: usize,
164 scanner_labels: Labels,
165 ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
166 let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
167
168 let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
169
170 let (included_keywords, excluded_keywords) = self
171 .proximity_keywords
172 .as_ref()
173 .map(|config| compile_keywords_proximity_config(config, &rule_labels))
174 .unwrap_or(Ok((None, None)))?;
175
176 is_pattern_capture_groups_valid(&self.pattern_capture_groups, regex.group_info())?;
177
178 Ok(Box::new(RegexCompiledRule {
179 rule_index,
180 regex,
181 included_keywords,
182 excluded_keywords,
183 validator: self.validator.clone().map(|x| x.compile()),
184 metrics: RuleMetrics::new(&rule_labels),
185 pattern_capture_groups: self.pattern_capture_groups.clone(),
186 }))
187 }
188
189 fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
190 Some(self)
191 }
192}
193
/// Proximity-keyword settings of a `RegexRuleConfig`, compiled by
/// `compile_keywords_proximity_config` when the rule is converted.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Size, in characters, of the window searched for keywords. Presumably this is measured
    /// backwards from the match start, given the "look ahead" name — confirm in
    /// `proximity_keywords`.
    pub look_ahead_character_count: usize,

    /// Keywords whose presence in the window is presumably required for a match.
    /// A missing field or a JSON `null` deserializes to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Keywords whose presence in the window presumably discards a match.
    /// A missing field or a JSON `null` deserializes to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
207
/// Secondary check that can be attached to a `RegexRuleConfig` through its `validator` field.
/// `SecondaryValidator::compile` (defined elsewhere) turns a variant into its runtime form.
///
/// Serde representation: internally tagged by `"type"`. A unit variant serializes as
/// `{"type":"LuhnChecksum"}`; `JwtClaimsValidator` also carries a `"config"` object.
///
/// Keep the variants in alphabetical order of their names. `test_secondary_validator_are_sorted`
/// enforces this using the `AsRefStr` names in `EnumIter` (declaration) order.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// Validator with its own header/claim requirements, serialized under `"config"`.
    /// When the variant is produced by `EnumIter`, `config` is `JwtClaimsValidatorConfig::default()`.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
269
/// A requirement on a single JWT header or claim; the value type of the maps in
/// [`JwtClaimsValidatorConfig`].
///
/// Serde representation: adjacently tagged. The variant name goes in `"type"` and the payload,
/// if any, in `"config"`, e.g. `{"type":"ExactValue","config":"abc"}`. Unit variants have no
/// `"config"` key.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    /// Presumably: the entry only has to be present, and its value is not inspected — confirm in the validator.
    Present,
    /// Presumably: the entry must hold a time that has not passed yet — confirm in the validator.
    NotExpired,
    /// Presumably: the entry must equal this exact string — confirm in the validator.
    ExactValue(String),
    /// Presumably: the entry must match this regex pattern — confirm in the validator.
    RegexMatch(String),
}
282
/// Configuration carried by [`SecondaryValidator::JwtClaimsValidator`].
///
/// Both maps go from a header/claim name to the requirement placed on it. `BTreeMap` keeps keys
/// sorted, so serialization is deterministic and lists keys alphabetically. Either field may be
/// omitted when deserializing; it then defaults to an empty map.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements on JWT header entries, keyed by header name.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements on JWT payload claims, keyed by claim name.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
290
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    /// `with_pattern` replaces the pattern of the returned copy.
    #[test]
    fn should_override_pattern() {
        let overridden = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(overridden.pattern, "456");
    }

    /// `RegexRuleConfig::new` leaves every optional field unset.
    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let expected = RegexRuleConfig {
            pattern: "123".to_string(),
            proximity_keywords: None,
            validator: None,
            labels: Labels::empty(),
            pattern_capture_groups: None,
        };
        assert_eq!(RegexRuleConfig::new("123"), expected);
    }

    /// `with_pattern_capture_groups` stores the given list as-is.
    #[test]
    fn should_use_capture_group() {
        const PATTERN: &str = "hey (?<capture_group>world)";
        let actual = RegexRuleConfig::new(PATTERN)
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        let expected = RegexRuleConfig {
            pattern: PATTERN.to_string(),
            proximity_keywords: None,
            validator: None,
            labels: Labels::empty(),
            pattern_capture_groups: Some(vec!["capture_group".to_string()]),
        };
        assert_eq!(actual, expected);
    }

    /// Missing keyword lists and explicit `null` lists both deserialize to empty vectors.
    #[test]
    fn proximity_keywords_should_have_default() {
        let expected = ProximityKeywordsConfig {
            look_ahead_character_count: 0,
            included_keywords: vec![],
            excluded_keywords: vec![],
        };
        for json_config in [
            r#"{"look_ahead_character_count": 0}"#,
            r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#,
        ] {
            let parsed: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());

        // A custom HTTP checker is stored as given and does not populate `match_validation_type`.
        let http_validation = MatchValidationType::CustomHttp(http_config.clone());
        let http_rule = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(http_validation.clone());
        assert_eq!(
            http_rule.third_party_active_checker,
            Some(http_validation.clone())
        );
        assert_eq!(http_rule.match_validation_type, None);
        assert_eq!(
            http_rule.get_third_party_active_checker(),
            Some(&http_validation)
        );

        // An AWS checker goes through the same path.
        let aws_validation = MatchValidationType::Aws(AwsType::AwsId);
        let aws_rule = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(aws_validation.clone());
        assert_eq!(
            aws_rule.third_party_active_checker,
            Some(aws_validation.clone())
        );
        assert_eq!(
            aws_rule.get_third_party_active_checker(),
            Some(&aws_validation)
        );

        // The getter result equals an independently built checker with the same HTTP config.
        let rebuilt_rule = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));
        assert_eq!(
            rebuilt_rule.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }

    #[test]
    fn test_secondary_validator_enum_iter() {
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        for expected in [
            SecondaryValidator::GithubTokenChecksum,
            SecondaryValidator::JwtExpirationChecker,
        ] {
            assert!(validators.contains(&expected));
        }
    }

    #[test]
    fn test_secondary_validator_are_sorted() {
        let declared_order: Vec<String> = SecondaryValidator::iter()
            .map(|validator| validator.as_ref().to_string())
            .collect();
        let mut alphabetical_order = declared_order.clone();
        alphabetical_order.sort();
        assert_eq!(
            alphabetical_order, declared_order,
            "Secondary validators should be sorted by alphabetical order, but it's not the case, expected order:"
        );
    }

    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        // Listed out of alphabetical order on purpose: the BTreeMap must sort the keys.
        let required_claims = BTreeMap::from([
            ("zzz".to_string(), ClaimRequirement::Present),
            ("exp".to_string(), ClaimRequirement::NotExpired),
            (
                "aaa".to_string(),
                ClaimRequirement::ExactValue("test".to_string()),
            ),
            (
                "mmm".to_string(),
                ClaimRequirement::RegexMatch(r"^test.*".to_string()),
            ),
        ]);
        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: BTreeMap::new(),
        };

        let first = serde_json::to_string(&config).unwrap();
        let second = serde_json::to_string(&config).unwrap();
        assert_eq!(first, second, "Serialization should be stable");

        let offset_of = |key: &str| first.find(key).unwrap();
        for pair in ["aaa", "exp", "mmm", "zzz"].windows(2) {
            assert!(offset_of(pair[0]) < offset_of(pair[1]));
        }
    }

    #[test]
    fn test_capture_groups_validation() {
        let names = |raw: &[&str]| -> Vec<String> {
            raw.iter().map(|name| name.to_string()).collect()
        };
        // (pattern, targeted capture groups, expected validation outcome)
        let cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            ("hello (?<sds_match>world)", names(&["sds_match"]), Ok(())),
            (
                "hello (?<capture_group>world)",
                names(&["capture_group"]),
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                names(&["sds_match"]),
                Ok(()),
            ),
            (
                "hello (?<capture_grou>world)",
                names(&["capture_group"]),
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            (
                "hello (?<sds_match>world)",
                names(&["sds_match", "sds_match2"]),
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected) in cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            let regex = get_memoized_regex(pattern, validate_and_create_regex).unwrap();
            let outcome =
                is_pattern_capture_groups_valid(&rule_config.pattern_capture_groups, regex.group_info());
            assert_eq!(outcome, expected);
        }
    }
}