1use crate::proximity_keywords::compile_keywords_proximity_config;
2use crate::scanner::config::RuleConfig;
3use crate::scanner::metrics::RuleMetrics;
4use crate::scanner::regex_rule::compiled::RegexCompiledRule;
5use crate::scanner::regex_rule::regex_store::get_memoized_regex;
6use crate::validation::{RegexPatternCaptureGroupsValidationError, validate_and_create_regex};
7use crate::{CompiledRule, CreateScannerError, Labels};
8use regex_automata::util::captures::GroupInfo;
9use serde::{Deserialize, Serialize};
10use serde_with::DefaultOnNull;
11use serde_with::serde_as;
12use std::sync::Arc;
13use strum::{AsRefStr, EnumIter};
14
/// Value assigned to `look_ahead_character_count` when a `RegexRuleConfig` builder method
/// has to create a `ProximityKeywordsConfig` because none is set yet.
pub const DEFAULT_KEYWORD_LOOKAHEAD: usize = 30;
16
/// Configuration of a regex-based scanning rule.
///
/// It is turned into a compiled rule by `RuleConfig::convert_to_compiled_rule`.
// `#[serde_as]` must stay above the derive so that the field-level `serde_as` attributes
// are rewritten before serde's derive sees them.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct RegexRuleConfig {
    /// Regex pattern. It is validated and compiled (through the memoized regex store) when
    /// the rule is compiled.
    pub pattern: String,
    /// Optional proximity keyword settings. They are compiled into included/excluded keyword
    /// sets when the rule is compiled; `None` leaves both sets unset.
    pub proximity_keywords: Option<ProximityKeywordsConfig>,
    /// Optional secondary validator, compiled together with the rule.
    pub validator: Option<SecondaryValidator>,
    /// Rule-specific labels, combined with the scanner labels when the rule is compiled.
    /// A missing or `null` value deserializes to the default labels.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub labels: Labels,
    /// Optional list of capture-group names to target. When set, rule compilation only
    /// succeeds if it holds exactly one name, that name is a capture group of `pattern`, and
    /// the name is `sds_match`.
    pub pattern_capture_groups: Option<Vec<String>>,
}
28
29impl RegexRuleConfig {
30 pub fn new(pattern: &str) -> Self {
31 #[allow(deprecated)]
32 Self {
33 pattern: pattern.to_owned(),
34 proximity_keywords: None,
35 validator: None,
36 labels: Labels::default(),
37 pattern_capture_groups: None,
38 }
39 }
40
41 pub fn with_pattern(&self, pattern: &str) -> Self {
42 self.mutate_clone(|x| x.pattern = pattern.to_string())
43 }
44
45 pub fn with_proximity_keywords(&self, proximity_keywords: ProximityKeywordsConfig) -> Self {
46 self.mutate_clone(|x| x.proximity_keywords = Some(proximity_keywords))
47 }
48
49 pub fn with_labels(&self, labels: Labels) -> Self {
50 self.mutate_clone(|x| x.labels = labels)
51 }
52
53 pub fn with_pattern_capture_groups(&self, pattern_capture_groups: Vec<String>) -> Self {
54 self.mutate_clone(|x| x.pattern_capture_groups = Some(pattern_capture_groups))
55 }
56
57 pub fn with_pattern_capture_group(&self, pattern_capture_group: &str) -> Self {
58 self.mutate_clone(|x| match x.pattern_capture_groups {
59 Some(ref mut pattern_capture_groups) => {
60 pattern_capture_groups.push(pattern_capture_group.to_string());
61 }
62 None => {
63 x.pattern_capture_groups = Some(vec![pattern_capture_group.to_string()]);
64 }
65 })
66 }
67
68 pub fn build(&self) -> Arc<dyn RuleConfig> {
69 Arc::new(self.clone())
70 }
71
72 fn mutate_clone(&self, modify: impl FnOnce(&mut Self)) -> Self {
73 let mut clone = self.clone();
74 modify(&mut clone);
75 clone
76 }
77
78 pub fn with_included_keywords(
79 &self,
80 keywords: impl IntoIterator<Item = impl AsRef<str>>,
81 ) -> Self {
82 let mut this = self.clone();
83 let mut config = self.get_or_create_proximity_keywords_config();
84 config.included_keywords = keywords
85 .into_iter()
86 .map(|x| x.as_ref().to_string())
87 .collect::<Vec<_>>();
88 this.proximity_keywords = Some(config);
89 this
90 }
91
92 pub fn with_validator(&self, validator: Option<SecondaryValidator>) -> Self {
93 let mut this = self.clone();
94 this.validator = validator;
95 this
96 }
97
98 fn get_or_create_proximity_keywords_config(&self) -> ProximityKeywordsConfig {
99 self.proximity_keywords
100 .clone()
101 .unwrap_or_else(|| ProximityKeywordsConfig {
102 look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
103 included_keywords: vec![],
104 excluded_keywords: vec![],
105 })
106 }
107}
108
109fn is_pattern_capture_groups_valid(
110 pattern_capture_groups: &Option<Vec<String>>,
111 group_info: &GroupInfo,
112) -> Result<(), RegexPatternCaptureGroupsValidationError> {
113 if pattern_capture_groups.is_none() {
114 return Ok(());
115 }
116 let pattern_capture_groups = pattern_capture_groups.as_ref().unwrap();
117 if pattern_capture_groups.len() != 1 {
118 return Err(
120 RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(
121 pattern_capture_groups.len(),
122 ),
123 );
124 }
125 let pattern_capture_group = pattern_capture_groups.first().unwrap();
126 if !group_info
127 .all_names()
128 .filter(|(_, _, name)| name.is_some())
129 .map(|(_, _, name)| name.unwrap())
130 .any(|name| name == pattern_capture_group)
131 {
132 return Err(
133 RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
134 pattern_capture_group.clone(),
135 ),
136 );
137 }
138 if pattern_capture_group != "sds_match" {
141 return Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch);
142 }
143 Ok(())
144}
145
146impl RuleConfig for RegexRuleConfig {
147 fn convert_to_compiled_rule(
148 &self,
149 rule_index: usize,
150 scanner_labels: Labels,
151 ) -> Result<Box<dyn CompiledRule>, CreateScannerError> {
152 let regex = get_memoized_regex(&self.pattern, validate_and_create_regex)?;
153
154 let rule_labels = scanner_labels.clone_with_labels(self.labels.clone());
155
156 let (included_keywords, excluded_keywords) = self
157 .proximity_keywords
158 .as_ref()
159 .map(|config| compile_keywords_proximity_config(config, &rule_labels))
160 .unwrap_or(Ok((None, None)))?;
161
162 is_pattern_capture_groups_valid(&self.pattern_capture_groups, regex.group_info())?;
163
164 Ok(Box::new(RegexCompiledRule {
165 rule_index,
166 regex,
167 included_keywords,
168 excluded_keywords,
169 validator: self.validator.clone().map(|x| x.compile()),
170 metrics: RuleMetrics::new(&rule_labels),
171 pattern_capture_groups: self.pattern_capture_groups.clone(),
172 }))
173 }
174}
175
/// Proximity keyword settings of a regex rule.
///
/// It is compiled into included/excluded keyword sets by `compile_keywords_proximity_config`
/// when the rule is compiled.
// `#[serde_as]` must stay above the derive so the field-level `serde_as` attributes are
// rewritten before serde's derive runs.
#[serde_as]
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ProximityKeywordsConfig {
    /// Presumably the number of characters around a match that are searched for keywords.
    /// TODO confirm the exact window against `compile_keywords_proximity_config`.
    pub look_ahead_character_count: usize,

    /// Included keywords. A missing or `null` value deserializes to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub included_keywords: Vec<String>,

    /// Excluded keywords. A missing or `null` value deserializes to an empty list.
    #[serde_as(deserialize_as = "DefaultOnNull")]
    #[serde(default)]
    pub excluded_keywords: Vec<String>,
}
189
/// Secondary validator that can be attached to a regex rule through
/// `RegexRuleConfig::validator`.
///
/// Serialization uses serde's internally tagged form: the variant name is stored in a
/// `"type"` field (e.g. `{"type": "LuhnChecksum"}`). `JwtClaimsValidator` also carries a
/// `"config"` field.
///
/// Variants must stay in alphabetical order of their names; the test
/// `test_secondary_validator_are_sorted` checks this through `EnumIter` + `AsRefStr`.
/// When iterating with `EnumIter`, strum fills data-carrying variants with `Default` values,
/// which is why `JwtClaimsValidatorConfig` derives `Default`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, EnumIter, AsRefStr)]
#[serde(tag = "type")]
pub enum SecondaryValidator {
    AbaRtnChecksum,
    BrazilianCnpjChecksum,
    BrazilianCpfChecksum,
    BtcChecksum,
    BulgarianEGNChecksum,
    ChineseIdChecksum,
    CoordinationNumberChecksum,
    CzechPersonalIdentificationNumberChecksum,
    CzechTaxIdentificationNumberChecksum,
    DutchBsnChecksum,
    DutchPassportChecksum,
    EntropyCheck,
    EthereumChecksum,
    FinnishHetuChecksum,
    FranceNifChecksum,
    FranceSsnChecksum,
    GermanIdsChecksum,
    GermanSvnrChecksum,
    GithubTokenChecksum,
    GreekTinChecksum,
    HungarianTinChecksum,
    IbanChecker,
    IrishPpsChecksum,
    ItalianNationalIdChecksum,
    /// Validator driven by a `JwtClaimsValidatorConfig` (required headers and claims).
    JwtClaimsValidator { config: JwtClaimsValidatorConfig },
    JwtExpirationChecker,
    LatviaNationalIdChecksum,
    LithuanianPersonalIdentificationNumberChecksum,
    LuhnChecksum,
    LuxembourgIndividualNINChecksum,
    Mod11_10checksum,
    Mod11_2checksum,
    Mod1271_36Checksum,
    Mod27_26checksum,
    Mod37_2checksum,
    Mod37_36checksum,
    Mod661_26checksum,
    Mod97_10checksum,
    MoneroAddress,
    NhsCheckDigit,
    NirChecksum,
    PolishNationalIdChecksum,
    PolishNipChecksum,
    PortugueseTaxIdChecksum,
    RodneCisloNumberChecksum,
    RomanianPersonalNumericCode,
    SlovenianPINChecksum,
    SpanishDniChecksum,
    SpanishNussChecksum,
    SwedenPINChecksum,
}
244
/// Requirement attached to one entry of a `JwtClaimsValidatorConfig` map.
///
/// Serialization uses serde's adjacently tagged form:
/// - unit variants serialize as `{"type": "<Variant>"}`;
/// - variants with a value serialize as `{"type": "<Variant>", "config": "<string>"}`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", content = "config")]
pub enum ClaimRequirement {
    /// Presumably: the entry only has to be present. Confirm in the validator implementation.
    Present,
    /// Presumably: the entry must not be expired. Confirm in the validator implementation.
    NotExpired,
    /// Presumably: the entry's value must equal this string. Confirm in the validator
    /// implementation.
    ExactValue(String),
    /// Presumably: the entry's value must match this regex pattern. Confirm in the validator
    /// implementation.
    RegexMatch(String),
}
257
/// Configuration carried by `SecondaryValidator::JwtClaimsValidator`.
///
/// It maps header names and claim names (e.g. `exp`) to the `ClaimRequirement` each must
/// satisfy. `BTreeMap` keeps keys sorted, so serialization order is deterministic; the test
/// `test_jwt_claims_validator_config_serialization_order` relies on this.
#[derive(Serialize, Deserialize, Default, Clone, Debug, PartialEq)]
pub struct JwtClaimsValidatorConfig {
    /// Requirements on JWT headers. Defaults to empty when missing from the serialized form.
    #[serde(default)]
    pub required_headers: std::collections::BTreeMap<String, ClaimRequirement>,
    /// Requirements on JWT claims. Defaults to empty when missing from the serialized form.
    #[serde(default)]
    pub required_claims: std::collections::BTreeMap<String, ClaimRequirement>,
}
265
#[cfg(test)]
mod test {
    use crate::{AwsType, CustomHttpConfig, MatchValidationType, RootRuleConfig};
    use std::collections::BTreeMap;
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn should_override_pattern() {
        let rule_config = RegexRuleConfig::new("123").with_pattern("456");
        assert_eq!(rule_config.pattern, "456");
    }

    #[test]
    #[allow(deprecated)]
    fn should_have_default() {
        let rule_config = RegexRuleConfig::new("123");
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "123".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: None,
            }
        );
    }

    #[test]
    fn should_use_capture_group() {
        let rule_config = RegexRuleConfig::new("hey (?<capture_group>world)")
            .with_pattern_capture_groups(vec!["capture_group".to_string()]);
        assert_eq!(
            rule_config,
            RegexRuleConfig {
                pattern: "hey (?<capture_group>world)".to_string(),
                proximity_keywords: None,
                validator: None,
                labels: Labels::empty(),
                pattern_capture_groups: Some(vec!["capture_group".to_string()]),
            }
        );
    }

    #[test]
    fn should_append_pattern_capture_group() {
        let rule_config = RegexRuleConfig::new("(?<a>x)(?<b>y)")
            .with_pattern_capture_group("a")
            .with_pattern_capture_group("b");
        assert_eq!(
            rule_config.pattern_capture_groups,
            Some(vec!["a".to_string(), "b".to_string()])
        );
    }

    #[test]
    fn included_keywords_should_use_default_lookahead_when_unset() {
        let rule_config = RegexRuleConfig::new("123").with_included_keywords(["foo", "bar"]);
        assert_eq!(
            rule_config.proximity_keywords,
            Some(ProximityKeywordsConfig {
                look_ahead_character_count: DEFAULT_KEYWORD_LOOKAHEAD,
                included_keywords: vec!["foo".to_string(), "bar".to_string()],
                excluded_keywords: vec![],
            })
        );
    }

    #[test]
    fn included_keywords_should_preserve_existing_proximity_settings() {
        let rule_config = RegexRuleConfig::new("123")
            .with_proximity_keywords(ProximityKeywordsConfig {
                look_ahead_character_count: 10,
                included_keywords: vec!["old".to_string()],
                excluded_keywords: vec!["excluded".to_string()],
            })
            .with_included_keywords(["new"]);
        assert_eq!(
            rule_config.proximity_keywords,
            Some(ProximityKeywordsConfig {
                look_ahead_character_count: 10,
                included_keywords: vec!["new".to_string()],
                excluded_keywords: vec!["excluded".to_string()],
            })
        );
    }

    #[test]
    fn should_set_and_clear_validator() {
        let with_validator =
            RegexRuleConfig::new("123").with_validator(Some(SecondaryValidator::LuhnChecksum));
        assert_eq!(
            with_validator.validator,
            Some(SecondaryValidator::LuhnChecksum)
        );
        assert_eq!(with_validator.with_validator(None).validator, None);
    }

    #[test]
    fn regex_rule_config_should_default_missing_or_null_fields() {
        for json_config in [
            r#"{"pattern": "123"}"#,
            r#"{"pattern": "123", "labels": null, "proximity_keywords": null}"#,
        ] {
            let config: RegexRuleConfig = serde_json::from_str(json_config).unwrap();
            assert_eq!(config, RegexRuleConfig::new("123"));
        }
    }

    #[test]
    fn proximity_keywords_should_have_default() {
        let json_config = r#"{"look_ahead_character_count": 0}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );

        let json_config = r#"{"look_ahead_character_count": 0, "excluded_keywords": null, "included_keywords": null}"#;
        let test: ProximityKeywordsConfig = serde_json::from_str(json_config).unwrap();
        assert_eq!(
            test,
            ProximityKeywordsConfig {
                look_ahead_character_count: 0,
                included_keywords: vec![],
                excluded_keywords: vec![]
            }
        );
    }

    #[test]
    #[allow(deprecated)]
    fn test_third_party_active_checker() {
        let http_config = CustomHttpConfig::default().with_endpoint("http://test.com".to_string());
        let validation_type = MatchValidationType::CustomHttp(http_config.clone());
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type.clone())
        );
        assert_eq!(rule_config.match_validation_type, None);
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type)
        );

        let aws_type = AwsType::AwsId;
        let validation_type2 = MatchValidationType::Aws(aws_type);
        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(validation_type2.clone());

        assert_eq!(
            rule_config.third_party_active_checker,
            Some(validation_type2.clone())
        );
        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&validation_type2)
        );

        let rule_config = RootRuleConfig::new(RegexRuleConfig::new("123"))
            .third_party_active_checker(MatchValidationType::CustomHttp(http_config.clone()));

        assert_eq!(
            rule_config.get_third_party_active_checker(),
            Some(&MatchValidationType::CustomHttp(http_config.clone()))
        );
    }

    #[test]
    fn test_secondary_validator_enum_iter() {
        let validators: Vec<SecondaryValidator> = SecondaryValidator::iter().collect();
        assert!(validators.contains(&SecondaryValidator::GithubTokenChecksum));
        assert!(validators.contains(&SecondaryValidator::JwtExpirationChecker));
    }

    #[test]
    fn test_secondary_validator_are_sorted() {
        let validator_names: Vec<String> = SecondaryValidator::iter()
            .map(|a| a.as_ref().to_string())
            .collect();
        let mut sorted_validator_names = validator_names.clone();
        sorted_validator_names.sort();
        assert_eq!(
            sorted_validator_names, validator_names,
            "Secondary validators should be sorted by alphabetical order, but it's not the case, expected order:"
        );
    }

    #[test]
    fn test_jwt_claims_validator_config_serialization_order() {
        let mut required_claims = BTreeMap::new();
        required_claims.insert("zzz".to_string(), ClaimRequirement::Present);
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "aaa".to_string(),
            ClaimRequirement::ExactValue("test".to_string()),
        );
        required_claims.insert(
            "mmm".to_string(),
            ClaimRequirement::RegexMatch(r"^test.*".to_string()),
        );

        let config = JwtClaimsValidatorConfig {
            required_claims,
            required_headers: std::collections::BTreeMap::new(),
        };

        let serialized1 = serde_json::to_string(&config).unwrap();
        let serialized2 = serde_json::to_string(&config).unwrap();

        assert_eq!(serialized1, serialized2, "Serialization should be stable");

        assert!(serialized1.find("aaa").unwrap() < serialized1.find("exp").unwrap());
        assert!(serialized1.find("exp").unwrap() < serialized1.find("mmm").unwrap());
        assert!(serialized1.find("mmm").unwrap() < serialized1.find("zzz").unwrap());
    }

    #[test]
    fn jwt_claims_validator_should_round_trip_through_serde() {
        let mut required_claims = BTreeMap::new();
        required_claims.insert("exp".to_string(), ClaimRequirement::NotExpired);
        required_claims.insert(
            "sub".to_string(),
            ClaimRequirement::ExactValue("me".to_string()),
        );
        let validator = SecondaryValidator::JwtClaimsValidator {
            config: JwtClaimsValidatorConfig {
                required_headers: BTreeMap::from([(
                    "alg".to_string(),
                    ClaimRequirement::Present,
                )]),
                required_claims,
            },
        };

        let serialized = serde_json::to_string(&validator).unwrap();
        let deserialized: SecondaryValidator = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, validator);
    }

    #[test]
    fn test_capture_groups_validation() {
        let test_cases: Vec<(
            &str,
            Vec<String>,
            Result<(), RegexPatternCaptureGroupsValidationError>,
        )> = vec![
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_group>world)",
                vec!["capture_group".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TargetedCaptureGroupMustBeSdsMatch),
            ),
            (
                "hello (?<sds_match>world) and (?<another_group>world)",
                vec!["sds_match".to_string()],
                Ok(()),
            ),
            (
                "hello (?<capture_grou>world)",
                vec!["capture_group".to_string()],
                Err(
                    RegexPatternCaptureGroupsValidationError::CaptureGroupNotPresent(
                        "capture_group".to_string(),
                    ),
                ),
            ),
            (
                "hello (?<sds_match>world)",
                vec!["sds_match".to_string(), "sds_match2".to_string()],
                Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(2)),
            ),
        ];
        for (pattern, capture_groups, expected_result) in test_cases {
            let rule_config =
                RegexRuleConfig::new(pattern).with_pattern_capture_groups(capture_groups);
            assert_eq!(
                is_pattern_capture_groups_valid(
                    &rule_config.pattern_capture_groups,
                    &get_memoized_regex(pattern, validate_and_create_regex)
                        .unwrap()
                        .group_info()
                ),
                expected_result
            );
        }
    }

    #[test]
    fn test_capture_groups_validation_edge_cases() {
        let regex =
            get_memoized_regex("hello (?<sds_match>world)", validate_and_create_regex).unwrap();

        // No capture-group setting at all is always valid.
        assert_eq!(
            is_pattern_capture_groups_valid(&None, regex.group_info()),
            Ok(())
        );
        // An empty list is rejected with a count of 0.
        assert_eq!(
            is_pattern_capture_groups_valid(&Some(vec![]), regex.group_info()),
            Err(RegexPatternCaptureGroupsValidationError::TooManyCaptureGroups(0))
        );
    }
}