1use std::collections::BTreeMap;
2use std::collections::btree_map::Entry;
3use std::fmt;
4
5use pext::simple_extension_declaration::MappingType;
6use substrait::proto::extensions as pext;
7use thiserror::Error;
8
/// Title line that opens the textual rendering of a [`SimpleExtensions`] registry.
pub const EXTENSIONS_HEADER: &str = "=== Extensions";
/// Sub-header introducing the list of registered extension URNs.
pub const EXTENSION_URNS_HEADER: &str = "URNs:";
/// Sub-header introducing the list of extension function declarations.
pub const EXTENSION_FUNCTIONS_HEADER: &str = "Functions:";
/// Sub-header introducing the list of extension type declarations.
pub const EXTENSION_TYPES_HEADER: &str = "Types:";
/// Sub-header introducing the list of extension type-variation declarations.
pub const EXTENSION_TYPE_VARIATIONS_HEADER: &str = "Type Variations:";
14
/// The category of a simple extension declaration.
///
/// Variant order is significant: the derived `Ord` sorts `Function` before
/// `Type` before `TypeVariation`, and values of this type are used inside
/// ordered map keys.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum ExtensionKind {
    /// An extension function declaration.
    Function,
    /// An extension type declaration.
    Type,
    /// An extension type-variation declaration.
    TypeVariation,
}
22
23impl ExtensionKind {
24 pub fn name(&self) -> &'static str {
25 match self {
26 ExtensionKind::Function => "function",
28 ExtensionKind::Type => "type",
29 ExtensionKind::TypeVariation => "type_variation",
30 }
31 }
32}
33
34impl fmt::Display for ExtensionKind {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 ExtensionKind::Function => write!(f, "Function"),
39 ExtensionKind::Type => write!(f, "Type"),
40 ExtensionKind::TypeVariation => write!(f, "Type Variation"),
41 }
42 }
43}
44
45#[derive(Error, Debug, PartialEq, Clone)]
46pub enum InsertError {
47 #[error("Extension declaration missing mapping type")]
48 MissingMappingType,
49
50 #[error("Duplicate URN anchor {anchor} for {prev} and {name}")]
51 DuplicateUrnAnchor {
52 anchor: u32,
53 prev: String,
54 name: String,
55 },
56
57 #[error("Duplicate extension {kind} anchor {anchor} for {prev} and {name}")]
58 DuplicateAnchor {
59 kind: ExtensionKind,
60 anchor: u32,
61 prev: String,
62 name: String,
63 },
64
65 #[error("Missing URN anchor {urn} for extension {kind} anchor {anchor} name {name}")]
66 MissingUrn {
67 kind: ExtensionKind,
68 anchor: u32,
69 name: String,
70 urn: u32,
71 },
72
73 #[error(
74 "Duplicate extension {kind} anchor {anchor} for {prev} and {name}, also missing URN {urn}"
75 )]
76 DuplicateAndMissingUrn {
77 kind: ExtensionKind,
78 anchor: u32,
79 prev: String,
80 name: String,
81 urn: u32,
82 },
83}
84
85#[derive(Default, Debug, Clone, PartialEq)]
88pub struct SimpleExtensions {
89 urns: BTreeMap<u32, String>,
91 extensions: BTreeMap<(u32, ExtensionKind), (u32, String)>,
93}
94
95impl SimpleExtensions {
96 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn from_extensions<'a>(
101 urns: impl IntoIterator<Item = &'a pext::SimpleExtensionUrn>,
102 extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
103 ) -> (Self, Vec<InsertError>) {
104 let mut exts = Self::new();
108
109 let mut errors = Vec::<InsertError>::new();
110
111 for urn in urns {
112 if let Err(e) = exts.add_extension_urn(urn.urn.clone(), urn.extension_urn_anchor) {
113 errors.push(e);
114 }
115 }
116
117 for extension in extensions {
118 match &extension.mapping_type {
119 Some(MappingType::ExtensionType(t)) => {
120 if let Err(e) = exts.add_extension(
121 ExtensionKind::Type,
122 t.extension_urn_reference,
123 t.type_anchor,
124 t.name.clone(),
125 ) {
126 errors.push(e);
127 }
128 }
129 Some(MappingType::ExtensionFunction(f)) => {
130 if let Err(e) = exts.add_extension(
131 ExtensionKind::Function,
132 f.extension_urn_reference,
133 f.function_anchor,
134 f.name.clone(),
135 ) {
136 errors.push(e);
137 }
138 }
139 Some(MappingType::ExtensionTypeVariation(v)) => {
140 if let Err(e) = exts.add_extension(
141 ExtensionKind::TypeVariation,
142 v.extension_urn_reference,
143 v.type_variation_anchor,
144 v.name.clone(),
145 ) {
146 errors.push(e);
147 }
148 }
149 None => {
150 errors.push(InsertError::MissingMappingType);
151 }
152 }
153 }
154
155 (exts, errors)
156 }
157
158 pub fn add_extension_urn(&mut self, urn: String, anchor: u32) -> Result<(), InsertError> {
159 match self.urns.entry(anchor) {
160 Entry::Occupied(e) => {
161 return Err(InsertError::DuplicateUrnAnchor {
162 anchor,
163 prev: e.get().clone(),
164 name: urn,
165 });
166 }
167 Entry::Vacant(e) => {
168 e.insert(urn);
169 }
170 }
171 Ok(())
172 }
173
174 pub fn add_extension(
175 &mut self,
176 kind: ExtensionKind,
177 urn: u32,
178 anchor: u32,
179 name: String,
180 ) -> Result<(), InsertError> {
181 let missing_urn = !self.urns.contains_key(&urn);
182
183 let prev = match self.extensions.entry((anchor, kind)) {
184 Entry::Occupied(e) => Some(e.get().1.clone()),
185 Entry::Vacant(v) => {
186 v.insert((urn, name.clone()));
187 None
188 }
189 };
190
191 match (missing_urn, prev) {
192 (true, Some(prev)) => Err(InsertError::DuplicateAndMissingUrn {
193 kind,
194 anchor,
195 prev,
196 name,
197 urn,
198 }),
199 (false, Some(prev)) => Err(InsertError::DuplicateAnchor {
200 kind,
201 anchor,
202 prev,
203 name,
204 }),
205 (true, None) => Err(InsertError::MissingUrn {
206 kind,
207 anchor,
208 name,
209 urn,
210 }),
211 (false, None) => Ok(()),
212 }
213 }
214
215 pub fn is_empty(&self) -> bool {
216 self.urns.is_empty() && self.extensions.is_empty()
217 }
218
219 pub fn to_extension_urns(&self) -> Vec<pext::SimpleExtensionUrn> {
221 self.urns
222 .iter()
223 .map(|(anchor, urn)| pext::SimpleExtensionUrn {
224 extension_urn_anchor: *anchor,
225 urn: urn.clone(),
226 })
227 .collect()
228 }
229
230 pub fn to_extension_declarations(&self) -> Vec<pext::SimpleExtensionDeclaration> {
232 self.extensions
233 .iter()
234 .map(|((anchor, kind), (urn_ref, name))| {
235 let mapping_type = match kind {
236 ExtensionKind::Function => MappingType::ExtensionFunction(
237 #[allow(deprecated)]
238 pext::simple_extension_declaration::ExtensionFunction {
239 extension_urn_reference: *urn_ref,
240 extension_uri_reference: Default::default(), function_anchor: *anchor,
242 name: name.clone(),
243 },
244 ),
245 ExtensionKind::Type => MappingType::ExtensionType(
246 #[allow(deprecated)]
247 pext::simple_extension_declaration::ExtensionType {
248 extension_urn_reference: *urn_ref,
249 extension_uri_reference: Default::default(), type_anchor: *anchor,
251 name: name.clone(),
252 },
253 ),
254 ExtensionKind::TypeVariation => MappingType::ExtensionTypeVariation(
255 #[allow(deprecated)]
256 pext::simple_extension_declaration::ExtensionTypeVariation {
257 extension_urn_reference: *urn_ref,
258 extension_uri_reference: Default::default(), type_variation_anchor: *anchor,
260 name: name.clone(),
261 },
262 ),
263 };
264 pext::SimpleExtensionDeclaration {
265 mapping_type: Some(mapping_type),
266 }
267 })
268 .collect()
269 }
270
271 pub fn write<W: fmt::Write>(&self, w: &mut W, indent: &str) -> fmt::Result {
275 if self.is_empty() {
276 return Ok(());
278 }
279
280 writeln!(w, "{EXTENSIONS_HEADER}")?;
281 if !self.urns.is_empty() {
282 writeln!(w, "{EXTENSION_URNS_HEADER}")?;
283 for (anchor, urn) in &self.urns {
284 writeln!(w, "{indent}@{anchor:3}: {urn}")?;
285 }
286 }
287
288 let kinds_and_headers = [
289 (ExtensionKind::Function, EXTENSION_FUNCTIONS_HEADER),
290 (ExtensionKind::Type, EXTENSION_TYPES_HEADER),
291 (
292 ExtensionKind::TypeVariation,
293 EXTENSION_TYPE_VARIATIONS_HEADER,
294 ),
295 ];
296 for (kind, header) in kinds_and_headers {
297 let mut filtered = self
298 .extensions
299 .iter()
300 .filter(|((_a, k), _)| *k == kind)
301 .peekable();
302 if filtered.peek().is_none() {
303 continue;
304 }
305
306 writeln!(w, "{header}")?;
307 for ((anchor, _), (urn_ref, name)) in filtered {
308 writeln!(w, "{indent}#{anchor:3} @{urn_ref:3}: {name}")?;
309 }
310 }
311 Ok(())
312 }
313
314 pub fn to_string(&self, indent: &str) -> String {
315 let mut output = String::new();
316 self.write(&mut output, indent).unwrap();
317 output
318 }
319}
320
321#[derive(Error, Debug, Clone, PartialEq)]
322pub enum MissingReference {
323 #[error("Missing URN for {0}")]
324 MissingUrn(u32),
325 #[error("Missing anchor for {0}: {1}")]
326 MissingAnchor(ExtensionKind, u32),
327 #[error("Missing name for {0}: {1}")]
328 MissingName(ExtensionKind, String),
329 #[error("Mismatched {0}: {1}#{2}")]
330 Mismatched(ExtensionKind, String, u32),
332 #[error("Duplicate name without anchor for {0}: {1}")]
333 DuplicateName(ExtensionKind, String),
334}
335
336#[derive(Debug, Clone, PartialEq)]
337pub struct SimpleExtension {
338 pub kind: ExtensionKind,
339 pub name: String,
340 pub anchor: u32,
341 pub urn: u32,
342}
343
344impl SimpleExtensions {
345 pub fn find_urn(&self, anchor: u32) -> Result<&str, MissingReference> {
346 self.urns
347 .get(&anchor)
348 .map(String::as_str)
349 .ok_or(MissingReference::MissingUrn(anchor))
350 }
351
352 pub fn find_by_anchor(
353 &self,
354 kind: ExtensionKind,
355 anchor: u32,
356 ) -> Result<(u32, &str), MissingReference> {
357 let &(urn, ref name) = self
358 .extensions
359 .get(&(anchor, kind))
360 .ok_or(MissingReference::MissingAnchor(kind, anchor))?;
361
362 Ok((urn, name))
363 }
364
365 pub fn find_by_name<'a>(
366 &'a self,
367 kind: ExtensionKind,
368 name: &'a str,
369 ) -> Result<u32, MissingReference> {
370 let mut matches = self
371 .extensions
372 .iter()
373 .filter(move |((_a, k), (_, n))| *k == kind && n.as_str() == name)
374 .map(|((anchor, _), _)| *anchor);
375
376 let anchor = matches
377 .next()
378 .ok_or(MissingReference::MissingName(kind, name.to_string()))?;
379
380 match matches.next() {
381 Some(_) => Err(MissingReference::DuplicateName(kind, name.to_string())),
382 None => Ok(anchor),
383 }
384 }
385
386 pub fn is_name_unique(
393 &self,
394 kind: ExtensionKind,
395 anchor: u32,
396 name: &str,
397 ) -> Result<bool, MissingReference> {
398 let mut found = false;
399 let mut other = false;
400 for (&(a, k), (_, n)) in self.extensions.iter() {
401 if k != kind {
402 continue;
403 }
404
405 if a == anchor {
406 found = true;
407 if n != name {
408 return Err(MissingReference::Mismatched(kind, name.to_string(), anchor));
409 }
410 continue;
411 }
412
413 if n.as_str() != name {
414 continue;
416 }
417
418 other = true;
420 if found {
421 break;
422 }
423 }
424
425 match (found, other) {
426 (true, false) => Ok(true),
428 (true, true) => Ok(false),
430 (false, _) => Err(MissingReference::MissingAnchor(kind, anchor)),
432 }
433 }
434}
435
#[cfg(test)]
mod tests {
    use pext::simple_extension_declaration::{
        ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
    };
    use substrait::proto::extensions as pext;

    use super::*;

    /// Indentation passed to the renderer by the display tests. The expected
    /// outputs below are spelled out for exactly this value.
    const INDENT: &str = "  ";

    /// Builds a URN message registered under `anchor`.
    fn new_urn(anchor: u32, urn_str: &str) -> pext::SimpleExtensionUrn {
        pext::SimpleExtensionUrn {
            extension_urn_anchor: anchor,
            urn: urn_str.to_owned(),
        }
    }

    /// Wraps a mapping in a declaration message.
    fn declare(mapping: MappingType) -> pext::SimpleExtensionDeclaration {
        pext::SimpleExtensionDeclaration {
            mapping_type: Some(mapping),
        }
    }

    // The helpers below must still initialize the deprecated
    // `extension_uri_reference` field, hence the `allow(deprecated)`.

    /// Builds a function declaration.
    fn new_ext_fn(anchor: u32, urn_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
        #[allow(deprecated)]
        let function = ExtensionFunction {
            extension_urn_reference: urn_ref,
            extension_uri_reference: Default::default(),
            function_anchor: anchor,
            name: name.to_owned(),
        };
        declare(MappingType::ExtensionFunction(function))
    }

    /// Builds a type declaration.
    fn new_ext_type(anchor: u32, urn_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
        #[allow(deprecated)]
        let ty = ExtensionType {
            extension_urn_reference: urn_ref,
            extension_uri_reference: Default::default(),
            type_anchor: anchor,
            name: name.to_owned(),
        };
        declare(MappingType::ExtensionType(ty))
    }

    /// Builds a type-variation declaration.
    fn new_type_var(anchor: u32, urn_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
        #[allow(deprecated)]
        let variation = ExtensionTypeVariation {
            extension_urn_reference: urn_ref,
            extension_uri_reference: Default::default(),
            type_variation_anchor: anchor,
            name: name.to_owned(),
        };
        declare(MappingType::ExtensionTypeVariation(variation))
    }

    /// Prints every error, then fails if there were any.
    fn assert_no_errors(errs: &[InsertError]) {
        errs.iter().for_each(|err| println!("Error: {err:?}"));
        assert!(errs.is_empty());
    }

    /// Builds a registry and asserts that registration reported no errors.
    fn unwrap_new_extensions<'a>(
        urns: impl IntoIterator<Item = &'a pext::SimpleExtensionUrn>,
        extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
    ) -> SimpleExtensions {
        let (exts, errs) = SimpleExtensions::from_extensions(urns, extensions);
        assert_no_errors(&errs);
        exts
    }

    /// Joins expected output lines, terminating each with a newline.
    fn lines(expected: &[&str]) -> String {
        expected.iter().map(|line| format!("{line}\n")).collect()
    }

    #[test]
    fn test_extension_lookup_empty() {
        let lookup = SimpleExtensions::new();
        assert!(lookup.find_urn(1).is_err());
        for kind in [
            ExtensionKind::Function,
            ExtensionKind::Type,
            ExtensionKind::TypeVariation,
        ] {
            assert!(lookup.find_by_anchor(kind, 1).is_err());
            assert!(lookup.find_by_name(kind, "any").is_err());
        }
    }

    #[test]
    fn test_from_extensions_basic() {
        let urns = vec![new_urn(1, "urn1"), new_urn(2, "urn2")];
        let extensions = vec![
            new_ext_fn(10, 1, "func1"),
            new_ext_type(20, 1, "type1"),
            new_type_var(30, 2, "var1"),
        ];
        let exts = unwrap_new_extensions(&urns, &extensions);

        assert_eq!(exts.find_urn(1).unwrap(), "urn1");
        assert_eq!(exts.find_urn(2).unwrap(), "urn2");
        assert!(exts.find_urn(3).is_err());

        let cases = [
            (ExtensionKind::Function, 10, 1, "func1"),
            (ExtensionKind::Type, 20, 1, "type1"),
            (ExtensionKind::TypeVariation, 30, 2, "var1"),
        ];
        for (kind, anchor, urn, name) in cases {
            assert_eq!(exts.find_by_anchor(kind, anchor).unwrap(), (urn, name));
            assert!(exts.find_by_anchor(kind, anchor + 1).is_err());
        }
    }

    #[test]
    fn test_from_extensions_duplicates() {
        let urns = vec![
            new_urn(1, "urn_old"),
            new_urn(1, "urn_new"),
            new_urn(2, "second"),
        ];
        let extensions = vec![
            new_ext_fn(10, 1, "func_old"),
            new_ext_fn(10, 2, "func_new"),
            new_ext_fn(11, 3, "func_missing"),
        ];
        let (exts, errs) = SimpleExtensions::from_extensions(&urns, &extensions);

        let expected_errors = vec![
            InsertError::DuplicateUrnAnchor {
                anchor: 1,
                name: "urn_new".to_string(),
                prev: "urn_old".to_string(),
            },
            InsertError::DuplicateAnchor {
                kind: ExtensionKind::Function,
                anchor: 10,
                prev: "func_old".to_string(),
                name: "func_new".to_string(),
            },
            InsertError::MissingUrn {
                kind: ExtensionKind::Function,
                anchor: 11,
                name: "func_missing".to_string(),
                urn: 3,
            },
        ];
        assert_eq!(errs, expected_errors);

        // The first registration wins for both URNs and declarations.
        assert_eq!(exts.find_urn(1).unwrap(), "urn_old");
        assert_eq!(
            exts.find_by_anchor(ExtensionKind::Function, 10).unwrap(),
            (1, "func_old")
        );
    }

    #[test]
    fn test_from_extensions_invalid_mapping_type() {
        let extensions = vec![pext::SimpleExtensionDeclaration { mapping_type: None }];

        let (_exts, errs) = SimpleExtensions::from_extensions(std::iter::empty(), &extensions);
        assert_eq!(errs, vec![InsertError::MissingMappingType]);
    }

    #[test]
    fn test_find_by_name() {
        let urns = vec![new_urn(1, "urn1")];
        let extensions = vec![
            new_ext_fn(10, 1, "name1"),
            new_ext_fn(11, 1, "name2"),
            new_ext_fn(12, 1, "name1"),
            new_ext_type(20, 1, "type_name1"),
            new_type_var(30, 1, "var_name1"),
        ];
        let exts = unwrap_new_extensions(&urns, &extensions);

        assert_eq!(
            exts.find_by_name(ExtensionKind::Function, "name1"),
            Err(MissingReference::DuplicateName(
                ExtensionKind::Function,
                "name1".to_string()
            ))
        );
        assert_eq!(exts.find_by_name(ExtensionKind::Function, "name2"), Ok(11));
        assert_eq!(exts.find_by_name(ExtensionKind::Type, "type_name1"), Ok(20));
        assert_eq!(
            exts.find_by_name(ExtensionKind::Type, "non_existent_type_name"),
            Err(MissingReference::MissingName(
                ExtensionKind::Type,
                "non_existent_type_name".to_string()
            ))
        );
        assert_eq!(
            exts.find_by_name(ExtensionKind::TypeVariation, "var_name1"),
            Ok(30)
        );
        assert_eq!(
            exts.find_by_name(ExtensionKind::TypeVariation, "non_existent_var_name"),
            Err(MissingReference::MissingName(
                ExtensionKind::TypeVariation,
                "non_existent_var_name".to_string()
            ))
        );
    }

    #[test]
    fn test_display_extension_lookup_empty() {
        let lookup = SimpleExtensions::new();
        let mut output = String::new();
        lookup.write(&mut output, INDENT).unwrap();
        assert_eq!(output, "");
    }

    #[test]
    fn test_display_extension_lookup_with_content() {
        let urns = vec![
            new_urn(1, "/my/urn1"),
            new_urn(42, "/another/urn"),
            new_urn(4091, "/big/anchor"),
        ];
        let extensions = vec![
            new_ext_fn(10, 1, "my_func"),
            new_ext_type(20, 1, "my_type"),
            new_type_var(30, 42, "my_var"),
            new_ext_fn(11, 42, "another_func"),
            new_ext_fn(108812, 4091, "big_func"),
        ];
        let exts = unwrap_new_extensions(&urns, &extensions);

        let expected = lines(&[
            "=== Extensions",
            "URNs:",
            "  @  1: /my/urn1",
            "  @ 42: /another/urn",
            "  @4091: /big/anchor",
            "Functions:",
            "  # 10 @  1: my_func",
            "  # 11 @ 42: another_func",
            "  #108812 @4091: big_func",
            "Types:",
            "  # 20 @  1: my_type",
            "Type Variations:",
            "  # 30 @ 42: my_var",
        ]);
        assert_eq!(exts.to_string(INDENT), expected);
    }

    #[test]
    fn test_extensions_output() {
        let mut extensions = SimpleExtensions::new();
        extensions
            .add_extension_urn("/urn/common".to_string(), 1)
            .unwrap();
        extensions
            .add_extension_urn("/urn/specific_funcs".to_string(), 2)
            .unwrap();
        let declarations = [
            (ExtensionKind::Function, 1, 10, "func_a"),
            (ExtensionKind::Function, 2, 11, "func_b_special"),
            (ExtensionKind::Type, 1, 20, "SomeType"),
            (ExtensionKind::TypeVariation, 2, 30, "VarX"),
        ];
        for (kind, urn, anchor, name) in declarations {
            extensions
                .add_extension(kind, urn, anchor, name.to_string())
                .unwrap();
        }

        let expected = lines(&[
            "=== Extensions",
            "URNs:",
            "  @  1: /urn/common",
            "  @  2: /urn/specific_funcs",
            "Functions:",
            "  # 10 @  1: func_a",
            "  # 11 @  2: func_b_special",
            "Types:",
            "  # 20 @  1: SomeType",
            "Type Variations:",
            "  # 30 @  2: VarX",
        ]);
        assert_eq!(extensions.to_string(INDENT), expected);
    }
}