1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{RegexPatternCaptureGroupsValidationError, validate_and_create_regex};
7use crate::{CompiledRule, CreateScannerError, Labels};
8use regex_automata::util::captures::GroupInfo;
9use serde::{Deserialize, Serialize};
10use serde_with::DefaultOnNull;
11use serde_with::serde_as;
12use std::sync::Arc;
13use strum::{AsRefStr, EnumIter};
14
/// Value given to [`ProximityKeywordsConfig::look_ahead_character_count`] when a
/// [`RegexRuleConfig`] keyword builder has to create a proximity-keywords configuration
/// because none was set yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
16
/// Configuration of a rule that matches data with a regular expression.
///
/// The `RuleConfig` implementation in this file turns it into a `RegexCompiledRule`.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Source of the regular expression; it is validated and compiled when the rule is
    /// converted into a compiled rule.
    pub pattern: String,
    /// Optional included / excluded keyword lists and their look-ahead character count.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator; compiled together with the rule when present.
    pub validator: Option<SecondaryValidator>,
    // `default` covers a missing key and `DefaultOnNull` covers an explicit `null`:
    // both deserialize to `Labels::default()`.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Names of the capture groups targeted by the rule. When set, rule compilation in
    /// this file requires exactly one name, requires that it is a named group of
    /// `pattern`, and requires that it is `sds_match`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
28
29impl RegexRuleConfig {
30 pub fn new(pattern: &str) -> Self {
31 #[allow(deprecated)]
32 Self {
33 pattern: pattern.to_owned(),
34 proximity_keywords: None,
35 validator: None,
36 labels: Labels::default(),
37 pattern_capture_groups: None,
38 }
39 }
40
41 pub fn with_pattern(&self, pattern: &str) -> Self {
42 self.mutate_clone(|x| x.pattern = pattern.to_string())
43 }
44
45 pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
46 self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
47 }
48
49 pub fn with_labels(&self, labels: Labels) -> Self {
50 self.mutate_clone(|x| x.labels = labels)
51 }
52
53 pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
54 self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
55 }
56
57 pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
58 self.mutate_clone(|x| match x.pattern_capture_groups {
59 Some(ref mut pattern_capture_groups) => {
60 pattern_capture_groups.push(pattern_capture_group.to_string());
61 }
62 None => {
63 x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
64 }
65 })
66 }
67
68 pub fn build(&self) -> Arc<dyn RuleConfig> {
69 Arc::new(self.clone())
70 }
71
72 fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
73 let mut clone = self.clone();
74 modify(&mut clone);
75 clone
76 }
77
78 pub fn with_included_keywords(
79 &self,
80 keywords: impl IntoIterator<Item = impl AsRef<str>>,
81 ) -> Self {
82 let mut this = self.clone();
83 let mut config = self.get_or_create_proximity_keywords_config();
84 config.included_keywords = keywords
85 .into_iter()
86 .map(|x| x.as_ref().to_string())
87 .collect::<Vec<_>>();
88 this.proximity_keywords = Some(config);
89 this
90 }
91
92 pub fn with_excluded_keywords(
93 &self,
94 keywords: impl IntoIterator<Item = impl AsRef<str>>,
95 ) -> Self {
96 let mut this = self.clone();
97 let mut config = self.get_or_create_proximity_keywords_config();
98 config.excluded_keywords = keywords
99 .into_iter()
100 .map(|x| x.as_ref().to_string())
101 .collect::<Vec<_>>();
102 this.proximity_keywords = Some(config);
103 this
104 }
105
106 pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
107 let mut this = self.clone();
108 this.validator = validator;
109 this
110 }
111
112 fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
113 self.proximity_keywords
114 .clone()
115 .unwrap_or_else(|| ProximityKeywordsConfig {
116 look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
117 included_keywords: vec![],
118 excluded_keywords: vec![],
119 })
120 }
121}
122
123fn is_pattern_capture_groups_valid(
124 pattern_capture_groups: &Option<Vec<String>>,
125 group_info: &GroupInfo,
126) -> Result<(), RegexPatternCaptureGroupsValidationError> {
127 if pattern_capture_groups.is_none() {
128 return Ok(());
129 }
130 let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
131 if pattern_capture_groups.len() != 1 {
132 return Err(
134 RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
135 pattern_capture_groups.len(),
136 ),
137 );
138 }
139 let pattern_capture_group = pattern_capture_groups.first().unwrap();
140 if !group_info
141 .all_names()
142 .filter(|(_, _, name)| name.is_some())
143 .map(|(_, _, name)| name.unwrap())
144 .any(|name| name == pattern_capture_group)
145 {
146 return Err(
147 RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
148 pattern_capture_group.clone(),
149 ),
150 );
151 }
152 if pattern_capture_group != "sds_match" {
155 return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
156 }
157 Ok(())
158}
159
160impl RuleConfig for RegexRuleConfig {
161 fn convert_to_compiled_rule(
162 &self,
163 rule_index: usize,
164 scanner_labels: Labels,
165 ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
166 let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
167
168 let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
169
170 let (included_keywords, excluded_keywords) = self
171 .proximity_keywords
172 .as_ref()
173 .map(|config| compile_keywords_proximity_config(config, &rule_labels))
174 .unwrap_or(Ok((None, None)))?;
175
176 is_pattern_capture_groups_valid(&self.pattern_capture_groups, regex.group_info())?;
177
178 Ok(Box::new(RegexCompiledRule {
179 rule_index,
180 regex,
181 included_keywords,
182 excluded_keywords,
183 validator: self.validator.clone().map(|x| x.compile()),
184 metrics: RuleMetrics::new(&rule_labels),
185 pattern_capture_groups: self.pattern_capture_groups.clone(),
186 }))
187 }
188
189 fn as_regex_rule(&self) -> Option<&RegexRuleConfig> {
190 Some(self)
191 }
192}
193
/// Keyword lists attached to a regex rule, together with the look-ahead character count
/// they are compiled with.
///
/// When deserializing, `look_ahead_character_count` has no default, while both keyword
/// lists fall back to an empty list when the key is missing or explicitly `null`.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Look-ahead size, expressed as a character count.
    pub look_ahead_character_count: usize,

    /// Included keywords; empty when absent or `null` in the input.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Excluded keywords; empty when absent or `null` in the input.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
207
/// Secondary validators that can be attached to a [`RegexRuleConfig`].
///
/// Serialized with serde's internally tagged representation: the variant name is stored
/// under the `"type"` key. `JwtClaimsValidator` is the only variant carrying data.
///
/// Keep the variants sorted by name: `test_secondary_validator_are_sorted` in this file
/// compares the declaration order with the sorted order of the variant names.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    AustrianSSNChecksum,
    BelgiumNationalRegisterChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EstoniaPersonalCodeChecksum,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreeceAmkaChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// Carries the header / claim requirements as its `config` field.
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SloveniaTinChecksum,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
    UsDeaChecksum,
    UsNpiChecksum,
    VerhoeffChecksum,
}
270
/// Requirement attached to one header or claim name in [`JwtClaimsValidatorConfig`].
///
/// Serialized with serde's adjacently tagged representation: the variant name is stored
/// under `"type"` and the variant payload, when there is one, under `"config"`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    // NOTE(review): presumably "the entry only has to exist" — confirm against the validator.
    Present,
    // NOTE(review): presumably "the entry is a timestamp that has not passed yet" — confirm.
    NotExpired,
    // NOTE(review): presumably "the entry must equal this string" — confirm.
    ExactValue(String),
    // NOTE(review): presumably "the entry must match this regex source" — confirm.
    RegexMatch(String),
}
283
/// Configuration carried by `SecondaryValidator::JwtClaimsValidator`: one
/// [`ClaimRequirement`] per header name and per claim name.
///
/// Both maps default to empty when their key is missing from the input. They are
/// `BTreeMap`s, so entries are serialized in key order and the output is stable.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements keyed by header name.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements keyed by claim name.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
291
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    /// `with_pattern` replaces the pattern given to `new`.
    #[test]
    fn should_override_pattern() {
        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(rule_config.pattern, "456");
    }

    /// `new` leaves every optional field unset and uses empty labels.
    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let rule_config = RegexRuleConfig::new("123");
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "123".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: None,
            }
        );
    }

    /// `with_pattern_capture_groups` stores the given names verbatim.
    #[test]
    fn should_use_capture_group() {
        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "hey (?<capture_group>world)".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
            }
        );
    }

    /// Keyword lists deserialize to empty vectors when missing and when `null`.
    #[test]
    fn proximity_keywords_should_have_default() {
        let json_config = r#"{"look_ahead_character_count": 0}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );

        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );
    }

    /// The value passed to `third_party_active_checker` is stored and returned by
    /// `get_third_party_active_checker`, for both the custom HTTP and the AWS kinds.
    // NOTE(review): the lint override presumably covers the direct access to the
    // `third_party_active_checker` field / builder — confirm which item is deprecated.
    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        // Last use of `http_config`, so it can be moved instead of cloned.
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config))
        );
    }

    /// `EnumIter` yields the unit variants this test looks for.
    #[test]
    fn test_secondary_validator_enum_iter() {
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
    }

    /// The variants of `SecondaryValidator` must be declared sorted by name.
    #[test]
    fn test_secondary_validator_are_sorted() {
        let validator_names: Vec<String> = SecondaryValidator::iter()
            .map(|a| a.as_ref().to_string())
            .collect();
        let mut sorted_validator_names = validator_names.clone();
        sorted_validator_names.sort();
        assert_eq!(
            sorted_validator_names, validator_names,
            "Secondary validators should be declared in alphabetical order: `left` is the expected (sorted) order, `right` is the declared order"
        );
    }

    /// Serializing a `JwtClaimsValidatorConfig` is stable and emits claims in key order,
    /// whatever the insertion order was.
    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        let mut required_claims = BTreeMap::new();
        // Inserted out of order on purpose.
        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "aaa".to_string(),
            ClaimRequirement::ExactValue("test".to_string()),
        );
        required_claims.insert(
            "mmm".to_string(),
            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
        );

        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: BTreeMap::new(),
        };

        let serialized1 = serde_json::to_string(&config).unwrap();
        let serialized2 = serde_json::to_string(&config).unwrap();

        assert_eq!(serialized1, serialized2, "Serialization should be stable");

        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
    }

    /// Table-driven check of `is_pattern_capture_groups_valid`: each case is a pattern,
    /// the targeted capture groups and the expected validation result.
    #[test]
    fn test_capture_groups_validation() {
        let test_cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_group>world)",
                vec!["capture_group".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_grou>world)",
                vec!["capture_group".to_string()],
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string(), "sds_match2".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected_result) in test_cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            let regex = get_memoized_regex(pattern, validate_and_create_regex).unwrap();
            // The message names the failing case, since all cases share this assertion.
            assert_eq!(
                is_pattern_capture_groups_valid(
                    &rule_config.pattern_capture_groups,
                    regex.group_info()
                ),
                expected_result,
                "unexpected validation result for pattern {:?}",
                pattern
            );
        }
    }
}