1use std::collections::BTreeMap;
2use std::collections::btree_map::Entry;
3use std::fmt;
4
5use pext::simple_extension_declaration::MappingType;
6use substrait::proto::extensions as pext;
7use thiserror::Error;
8
/// First line emitted by `SimpleExtensions::write` when there is anything to print.
pub const EXTENSIONS_HEADER: &str = "=== Extensions";
/// Header written before the list of extension URIs.
pub const EXTENSION_URIS_HEADER: &str = "URIs:";
/// Header written before the list of function extensions.
pub const EXTENSION_FUNCTIONS_HEADER: &str = "Functions:";
/// Header written before the list of type extensions.
pub const EXTENSION_TYPES_HEADER: &str = "Types:";
/// Header written before the list of type-variation extensions.
pub const EXTENSION_TYPE_VARIATIONS_HEADER: &str = "Type Variations:";
14
/// The category of a simple extension declaration.
///
/// The derived ordering follows declaration order
/// (`Function < Type < TypeVariation`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExtensionKind {
    /// A function extension.
    Function,
    /// A type extension.
    Type,
    /// A type-variation extension.
    TypeVariation,
}

impl ExtensionKind {
    /// Returns the lowercase, snake_case identifier for this kind.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Function => "function",
            Self::Type => "type",
            Self::TypeVariation => "type_variation",
        }
    }
}

/// Human-readable label for the kind (`"Function"`, `"Type"`, `"Type Variation"`).
impl fmt::Display for ExtensionKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Function => "Function",
            Self::Type => "Type",
            Self::TypeVariation => "Type Variation",
        };
        // Written verbatim: width/fill flags on the formatter are not applied.
        f.write_str(label)
    }
}
44
/// Problems reported while registering extension URIs or extension
/// declarations in a `SimpleExtensions`.
#[derive(Error, Debug, PartialEq, Clone)]
pub enum InsertError {
    /// The declaration's `mapping_type` was `None`, so there was nothing to register.
    #[error("Extension declaration missing mapping type")]
    MissingMappingType,

    /// A URI was added under an anchor that is already taken.
    /// `prev` is the URI already stored; `name` is the URI that was rejected.
    #[error("Duplicate URI anchor {anchor} for {prev} and {name}")]
    DuplicateUriAnchor {
        anchor: u32,
        prev: String,
        name: String,
    },

    /// An extension was added under an `(anchor, kind)` pair that is already taken.
    /// `prev` is the name already stored; `name` is the name that was rejected.
    #[error("Duplicate extension {kind} anchor {anchor} for {prev} and {name}")]
    DuplicateAnchor {
        kind: ExtensionKind,
        anchor: u32,
        prev: String,
        name: String,
    },

    /// The extension refers to URI anchor `uri`, which has not been registered.
    /// The extension itself is still stored in this case.
    #[error("Missing URI anchor {uri} for extension {kind} anchor {anchor} name {name}")]
    MissingUri {
        kind: ExtensionKind,
        anchor: u32,
        name: String,
        uri: u32,
    },

    /// Both of the above at once: the `(anchor, kind)` pair is already taken
    /// and the referenced URI anchor is not registered.
    #[error(
        "Duplicate extension {kind} anchor {anchor} for {prev} and {name}, also missing URI {uri}"
    )]
    DuplicateAndMissingUri {
        kind: ExtensionKind,
        anchor: u32,
        prev: String,
        name: String,
        uri: u32,
    },
}
84
/// Registry of extension URIs and the extensions (functions, types, type
/// variations) declared against them.
///
/// Both maps are `BTreeMap`s, so iteration is ordered by key.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct SimpleExtensions {
    /// URI anchor -> URI string.
    uris: BTreeMap<u32, String>,
    /// (extension anchor, kind) -> (URI anchor the extension refers to, extension name).
    extensions: BTreeMap<(u32, ExtensionKind), (u32, String)>,
}
94
95impl SimpleExtensions {
96 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn from_extensions<'a>(
101 uris: impl IntoIterator<Item = &'a pext::SimpleExtensionUri>,
102 extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
103 ) -> (Self, Vec<InsertError>) {
104 let mut exts = Self::new();
108
109 let mut errors = Vec::<InsertError>::new();
110
111 for uri in uris {
112 if let Err(e) = exts.add_extension_uri(uri.uri.clone(), uri.extension_uri_anchor) {
113 errors.push(e);
114 }
115 }
116
117 for extension in extensions {
118 match &extension.mapping_type {
119 Some(MappingType::ExtensionType(t)) => {
120 if let Err(e) = exts.add_extension(
121 ExtensionKind::Type,
122 t.extension_uri_reference,
123 t.type_anchor,
124 t.name.clone(),
125 ) {
126 errors.push(e);
127 }
128 }
129 Some(MappingType::ExtensionFunction(f)) => {
130 if let Err(e) = exts.add_extension(
131 ExtensionKind::Function,
132 f.extension_uri_reference,
133 f.function_anchor,
134 f.name.clone(),
135 ) {
136 errors.push(e);
137 }
138 }
139 Some(MappingType::ExtensionTypeVariation(v)) => {
140 if let Err(e) = exts.add_extension(
141 ExtensionKind::TypeVariation,
142 v.extension_uri_reference,
143 v.type_variation_anchor,
144 v.name.clone(),
145 ) {
146 errors.push(e);
147 }
148 }
149 None => {
150 errors.push(InsertError::MissingMappingType);
151 }
152 }
153 }
154
155 (exts, errors)
156 }
157
158 pub fn add_extension_uri(&mut self, uri: String, anchor: u32) -> Result<(), InsertError> {
159 match self.uris.entry(anchor) {
160 Entry::Occupied(e) => {
161 return Err(InsertError::DuplicateUriAnchor {
162 anchor,
163 prev: e.get().clone(),
164 name: uri,
165 });
166 }
167 Entry::Vacant(e) => {
168 e.insert(uri);
169 }
170 }
171 Ok(())
172 }
173
174 pub fn add_extension(
175 &mut self,
176 kind: ExtensionKind,
177 uri: u32,
178 anchor: u32,
179 name: String,
180 ) -> Result<(), InsertError> {
181 let missing_uri = !self.uris.contains_key(&uri);
182
183 let prev = match self.extensions.entry((anchor, kind)) {
184 Entry::Occupied(e) => Some(e.get().1.clone()),
185 Entry::Vacant(v) => {
186 v.insert((uri, name.clone()));
187 None
188 }
189 };
190
191 match (missing_uri, prev) {
192 (true, Some(prev)) => Err(InsertError::DuplicateAndMissingUri {
193 kind,
194 anchor,
195 prev,
196 name,
197 uri,
198 }),
199 (false, Some(prev)) => Err(InsertError::DuplicateAnchor {
200 kind,
201 anchor,
202 prev,
203 name,
204 }),
205 (true, None) => Err(InsertError::MissingUri {
206 kind,
207 anchor,
208 name,
209 uri,
210 }),
211 (false, None) => Ok(()),
212 }
213 }
214
215 pub fn is_empty(&self) -> bool {
216 self.uris.is_empty() && self.extensions.is_empty()
217 }
218
219 pub fn to_extension_uris(&self) -> Vec<pext::SimpleExtensionUri> {
221 self.uris
222 .iter()
223 .map(|(anchor, uri)| pext::SimpleExtensionUri {
224 extension_uri_anchor: *anchor,
225 uri: uri.clone(),
226 })
227 .collect()
228 }
229
230 pub fn to_extension_declarations(&self) -> Vec<pext::SimpleExtensionDeclaration> {
232 self.extensions
233 .iter()
234 .map(|((anchor, kind), (uri_ref, name))| {
235 let mapping_type = match kind {
236 ExtensionKind::Function => MappingType::ExtensionFunction(
237 pext::simple_extension_declaration::ExtensionFunction {
238 extension_uri_reference: *uri_ref,
239 function_anchor: *anchor,
240 name: name.clone(),
241 },
242 ),
243 ExtensionKind::Type => MappingType::ExtensionType(
244 pext::simple_extension_declaration::ExtensionType {
245 extension_uri_reference: *uri_ref,
246 type_anchor: *anchor,
247 name: name.clone(),
248 },
249 ),
250 ExtensionKind::TypeVariation => MappingType::ExtensionTypeVariation(
251 pext::simple_extension_declaration::ExtensionTypeVariation {
252 extension_uri_reference: *uri_ref,
253 type_variation_anchor: *anchor,
254 name: name.clone(),
255 },
256 ),
257 };
258 pext::SimpleExtensionDeclaration {
259 mapping_type: Some(mapping_type),
260 }
261 })
262 .collect()
263 }
264
265 pub fn write<W: fmt::Write>(&self, w: &mut W, indent: &str) -> fmt::Result {
269 if self.is_empty() {
270 return Ok(());
272 }
273
274 writeln!(w, "{EXTENSIONS_HEADER}")?;
275 if !self.uris.is_empty() {
276 writeln!(w, "{EXTENSION_URIS_HEADER}")?;
277 for (anchor, uri) in &self.uris {
278 writeln!(w, "{indent}@{anchor:3}: {uri}")?;
279 }
280 }
281
282 let kinds_and_headers = [
283 (ExtensionKind::Function, EXTENSION_FUNCTIONS_HEADER),
284 (ExtensionKind::Type, EXTENSION_TYPES_HEADER),
285 (
286 ExtensionKind::TypeVariation,
287 EXTENSION_TYPE_VARIATIONS_HEADER,
288 ),
289 ];
290 for (kind, header) in kinds_and_headers {
291 let mut filtered = self
292 .extensions
293 .iter()
294 .filter(|((_a, k), _)| *k == kind)
295 .peekable();
296 if filtered.peek().is_none() {
297 continue;
298 }
299
300 writeln!(w, "{header}")?;
301 for ((anchor, _), (uri_ref, name)) in filtered {
302 writeln!(w, "{indent}#{anchor:3} @{uri_ref:3}: {name}")?;
303 }
304 }
305 Ok(())
306 }
307
308 pub fn to_string(&self, indent: &str) -> String {
309 let mut output = String::new();
310 self.write(&mut output, indent).unwrap();
311 output
312 }
313}
314
/// Failures reported by the lookup methods of `SimpleExtensions`.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum MissingReference {
    /// No URI is registered under the given anchor.
    #[error("Missing URI for {0}")]
    MissingUri(u32),
    /// No extension of the given kind is registered under the given anchor.
    #[error("Missing anchor for {0}: {1}")]
    MissingAnchor(ExtensionKind, u32),
    /// No extension of the given kind has the given name.
    #[error("Missing name for {0}: {1}")]
    MissingName(ExtensionKind, String),
    /// The anchor is registered for this kind, but under a different name
    /// than the one supplied (fields: kind, supplied name, anchor).
    #[error("Mismatched {0}: {1}#{2}")]
    Mismatched(ExtensionKind, String, u32),
    /// More than one extension of the given kind has the given name, so a
    /// lookup by name alone is ambiguous.
    #[error("Duplicate name without anchor for {0}: {1}")]
    DuplicateName(ExtensionKind, String),
}
329
/// A single extension entry described by its kind, name and anchors.
// NOTE(review): not referenced elsewhere in the visible source; field
// meanings below are inferred from `SimpleExtensions` — confirm with callers.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtension {
    /// Which category of extension this is.
    pub kind: ExtensionKind,
    /// The extension's name.
    pub name: String,
    /// The extension's own anchor.
    pub anchor: u32,
    /// Presumably the anchor of the URI the extension refers to — TODO confirm.
    pub uri: u32,
}
337
338impl SimpleExtensions {
339 pub fn find_uri(&self, anchor: u32) -> Result<&str, MissingReference> {
340 self.uris
341 .get(&anchor)
342 .map(String::as_str)
343 .ok_or(MissingReference::MissingUri(anchor))
344 }
345
346 pub fn find_by_anchor(
347 &self,
348 kind: ExtensionKind,
349 anchor: u32,
350 ) -> Result<(u32, &str), MissingReference> {
351 let &(uri, ref name) = self
352 .extensions
353 .get(&(anchor, kind))
354 .ok_or(MissingReference::MissingAnchor(kind, anchor))?;
355
356 Ok((uri, name))
357 }
358
359 pub fn find_by_name<'a>(
360 &'a self,
361 kind: ExtensionKind,
362 name: &'a str,
363 ) -> Result<u32, MissingReference> {
364 let mut matches = self
365 .extensions
366 .iter()
367 .filter(move |((_a, k), (_, n))| *k == kind && n.as_str() == name)
368 .map(|((anchor, _), _)| *anchor);
369
370 let anchor = matches
371 .next()
372 .ok_or(MissingReference::MissingName(kind, name.to_string()))?;
373
374 match matches.next() {
375 Some(_) => Err(MissingReference::DuplicateName(kind, name.to_string())),
376 None => Ok(anchor),
377 }
378 }
379
380 pub fn is_name_unique(
387 &self,
388 kind: ExtensionKind,
389 anchor: u32,
390 name: &str,
391 ) -> Result<bool, MissingReference> {
392 let mut found = false;
393 let mut other = false;
394 for (&(a, k), (_, n)) in self.extensions.iter() {
395 if k != kind {
396 continue;
397 }
398
399 if a == anchor {
400 found = true;
401 if n != name {
402 return Err(MissingReference::Mismatched(kind, name.to_string(), anchor));
403 }
404 continue;
405 }
406
407 if n.as_str() != name {
408 continue;
410 }
411
412 other = true;
414 if found {
415 break;
416 }
417 }
418
419 match (found, other) {
420 (true, false) => Ok(true),
422 (true, true) => Ok(false),
424 (false, _) => Err(MissingReference::MissingAnchor(kind, anchor)),
426 }
427 }
428}
429
430#[cfg(test)]
431mod tests {
432 use pext::simple_extension_declaration::{
433 ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
434 };
435 use substrait::proto::extensions as pext;
436
437 use super::*;
438
439 fn new_uri(anchor: u32, uri_str: &str) -> pext::SimpleExtensionUri {
440 pext::SimpleExtensionUri {
441 extension_uri_anchor: anchor,
442 uri: uri_str.to_string(),
443 }
444 }
445
446 fn new_ext_fn(anchor: u32, uri_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
447 pext::SimpleExtensionDeclaration {
448 mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
449 extension_uri_reference: uri_ref,
450 function_anchor: anchor,
451 name: name.to_string(),
452 })),
453 }
454 }
455
456 fn new_ext_type(anchor: u32, uri_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
457 pext::SimpleExtensionDeclaration {
458 mapping_type: Some(MappingType::ExtensionType(ExtensionType {
459 extension_uri_reference: uri_ref,
460 type_anchor: anchor,
461 name: name.to_string(),
462 })),
463 }
464 }
465
466 fn new_type_var(anchor: u32, uri_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
467 pext::SimpleExtensionDeclaration {
468 mapping_type: Some(MappingType::ExtensionTypeVariation(
469 ExtensionTypeVariation {
470 extension_uri_reference: uri_ref,
471 type_variation_anchor: anchor,
472 name: name.to_string(),
473 },
474 )),
475 }
476 }
477
478 fn assert_no_errors(errs: &[InsertError]) {
479 for err in errs {
480 println!("Error: {err:?}");
481 }
482 assert!(errs.is_empty());
483 }
484
485 fn unwrap_new_extensions<'a>(
486 uris: impl IntoIterator<Item = &'a pext::SimpleExtensionUri>,
487 extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
488 ) -> SimpleExtensions {
489 let (exts, errs) = SimpleExtensions::from_extensions(uris, extensions);
490 assert_no_errors(&errs);
491 exts
492 }
493
// Every lookup on an empty registry must fail, for every kind.
#[test]
fn test_extension_lookup_empty() {
    let lookup = SimpleExtensions::new();
    assert!(lookup.find_uri(1).is_err());

    let kinds = [
        ExtensionKind::Function,
        ExtensionKind::Type,
        ExtensionKind::TypeVariation,
    ];
    for kind in kinds {
        assert!(lookup.find_by_anchor(kind, 1).is_err());
    }
    for kind in kinds {
        assert!(lookup.find_by_name(kind, "any").is_err());
    }
}
513
// One extension of each kind: registered anchors resolve, neighbouring ones do not.
#[test]
fn test_from_extensions_basic() {
    let uris = [new_uri(1, "uri1"), new_uri(2, "uri2")];
    let extensions = [
        new_ext_fn(10, 1, "func1"),
        new_ext_type(20, 1, "type1"),
        new_type_var(30, 2, "var1"),
    ];
    let exts = unwrap_new_extensions(&uris, &extensions);

    assert_eq!(exts.find_uri(1), Ok("uri1"));
    assert_eq!(exts.find_uri(2), Ok("uri2"));
    assert!(exts.find_uri(3).is_err());

    assert_eq!(
        exts.find_by_anchor(ExtensionKind::Function, 10),
        Ok((1, "func1"))
    );
    assert!(exts.find_by_anchor(ExtensionKind::Function, 11).is_err());

    assert_eq!(
        exts.find_by_anchor(ExtensionKind::Type, 20),
        Ok((1, "type1"))
    );
    assert!(exts.find_by_anchor(ExtensionKind::Type, 21).is_err());

    assert_eq!(
        exts.find_by_anchor(ExtensionKind::TypeVariation, 30),
        Ok((2, "var1"))
    );
    assert!(
        exts.find_by_anchor(ExtensionKind::TypeVariation, 31)
            .is_err()
    );
}
548
// Duplicate anchors and a dangling URI reference are reported in input
// order, and the first value registered under an anchor is the one kept.
#[test]
fn test_from_extensions_duplicates() {
    let uris = vec![
        new_uri(1, "uri_old"),
        new_uri(1, "uri_new"),
        new_uri(2, "second"),
    ];
    let extensions = vec![
        new_ext_fn(10, 1, "func_old"),
        new_ext_fn(10, 2, "func_new"),
        new_ext_fn(11, 3, "func_missing"),
    ];

    let expected_errors = vec![
        InsertError::DuplicateUriAnchor {
            anchor: 1,
            name: "uri_new".to_string(),
            prev: "uri_old".to_string(),
        },
        InsertError::DuplicateAnchor {
            kind: ExtensionKind::Function,
            anchor: 10,
            prev: "func_old".to_string(),
            name: "func_new".to_string(),
        },
        InsertError::MissingUri {
            kind: ExtensionKind::Function,
            anchor: 11,
            name: "func_missing".to_string(),
            uri: 3,
        },
    ];

    let (exts, errs) = SimpleExtensions::from_extensions(&uris, &extensions);
    assert_eq!(errs, expected_errors);

    assert_eq!(exts.find_uri(1).unwrap(), "uri_old");
    assert_eq!(
        exts.find_by_anchor(ExtensionKind::Function, 10).unwrap(),
        (1, "func_old")
    );
}
592
// A declaration without a mapping type yields exactly one MissingMappingType error.
#[test]
fn test_from_extensions_invalid_mapping_type() {
    let extensions = vec![pext::SimpleExtensionDeclaration { mapping_type: None }];

    let (_exts, errs) = SimpleExtensions::from_extensions(vec![], &extensions);
    assert_eq!(errs, vec![InsertError::MissingMappingType]);
}
602
// Name lookup: unique names resolve to their anchor, repeated names are
// ambiguous, unknown names are missing.
#[test]
fn test_find_by_name() {
    let uris = vec![new_uri(1, "uri1")];
    let extensions = vec![
        new_ext_fn(10, 1, "name1"),
        new_ext_fn(11, 1, "name2"),
        // Same name as anchor 10, making "name1" ambiguous.
        new_ext_fn(12, 1, "name1"),
        new_ext_type(20, 1, "type_name1"),
        new_type_var(30, 1, "var_name1"),
    ];
    let exts = unwrap_new_extensions(&uris, &extensions);

    assert_eq!(
        exts.find_by_name(ExtensionKind::Function, "name1"),
        Err(MissingReference::DuplicateName(
            ExtensionKind::Function,
            "name1".to_string()
        ))
    );
    assert_eq!(exts.find_by_name(ExtensionKind::Function, "name2"), Ok(11));

    assert_eq!(exts.find_by_name(ExtensionKind::Type, "type_name1"), Ok(20));
    assert_eq!(
        exts.find_by_name(ExtensionKind::Type, "non_existent_type_name"),
        Err(MissingReference::MissingName(
            ExtensionKind::Type,
            "non_existent_type_name".to_string()
        ))
    );

    assert_eq!(
        exts.find_by_name(ExtensionKind::TypeVariation, "var_name1"),
        Ok(30)
    );
    assert_eq!(
        exts.find_by_name(ExtensionKind::TypeVariation, "non_existent_var_name"),
        Err(MissingReference::MissingName(
            ExtensionKind::TypeVariation,
            "non_existent_var_name".to_string()
        ))
    );
}
658
// An empty registry writes nothing at all, not even the section header.
#[test]
fn test_display_extension_lookup_empty() {
    let lookup = SimpleExtensions::new();
    let mut output = String::new();
    lookup.write(&mut output, " ").unwrap();
    assert_eq!(output, "");
}
667
// Full text rendering, including anchors wider than the minimum width of 3.
#[test]
fn test_display_extension_lookup_with_content() {
    let uris = vec![
        new_uri(1, "/my/uri1"),
        new_uri(42, "/another/uri"),
        new_uri(4091, "/big/anchor"),
    ];
    let extensions = vec![
        new_ext_fn(10, 1, "my_func"),
        new_ext_type(20, 1, "my_type"),
        new_type_var(30, 42, "my_var"),
        new_ext_fn(11, 42, "another_func"),
        new_ext_fn(108812, 4091, "big_func"),
    ];
    let exts = unwrap_new_extensions(&uris, &extensions);
    let display_str = exts.to_string(" ");

    // `write` right-aligns every anchor to a minimum width of 3, so short
    // anchors are padded with spaces ("  1", " 42") while longer ones are
    // printed in full. Lines are listed one by one so the padding is
    // explicit and cannot be lost to whitespace normalisation.
    let expected = [
        "=== Extensions",
        "URIs:",
        " @  1: /my/uri1",
        " @ 42: /another/uri",
        " @4091: /big/anchor",
        "Functions:",
        " # 10 @  1: my_func",
        " # 11 @ 42: another_func",
        " #108812 @4091: big_func",
        "Types:",
        " # 20 @  1: my_type",
        "Type Variations:",
        " # 30 @ 42: my_var",
    ]
    .join("\n")
        + "\n";
    assert_eq!(display_str, expected);
}
702
// Text rendering of a registry populated through the incremental `add_*` API.
#[test]
fn test_extensions_output() {
    let mut extensions = SimpleExtensions::new();
    extensions
        .add_extension_uri("/uri/common".to_string(), 1)
        .unwrap();
    extensions
        .add_extension_uri("/uri/specific_funcs".to_string(), 2)
        .unwrap();
    extensions
        .add_extension(ExtensionKind::Function, 1, 10, "func_a".to_string())
        .unwrap();
    extensions
        .add_extension(ExtensionKind::Function, 2, 11, "func_b_special".to_string())
        .unwrap();
    extensions
        .add_extension(ExtensionKind::Type, 1, 20, "SomeType".to_string())
        .unwrap();
    extensions
        .add_extension(ExtensionKind::TypeVariation, 2, 30, "VarX".to_string())
        .unwrap();

    let output = extensions.to_string(" ");

    // `write` right-aligns every anchor to a minimum width of 3, so the
    // single-digit URI anchors are padded to "  1" / "  2". Lines are listed
    // one by one so the padding is explicit.
    let expected_output = [
        "=== Extensions",
        "URIs:",
        " @  1: /uri/common",
        " @  2: /uri/specific_funcs",
        "Functions:",
        " # 10 @  1: func_a",
        " # 11 @  2: func_b_special",
        "Types:",
        " # 20 @  1: SomeType",
        "Type Variations:",
        " # 30 @  2: VarX",
    ]
    .join("\n")
        + "\n";

    assert_eq!(output, expected_output);
}
746}