substrait_explain/extensions/
simple.rs

1use std::collections::BTreeMap;
2use std::collections::btree_map::Entry;
3use std::fmt;
4
5use pext::simple_extension_declaration::MappingType;
6use substrait::proto::extensions as pext;
7use thiserror::Error;
8
/// Heading line that opens the extensions section of the textual output.
pub const EXTENSIONS_HEADER: &str = "=== Extensions";
/// Sub-heading introducing the list of extension URIs.
pub const EXTENSION_URIS_HEADER: &str = "URIs:";
/// Sub-heading introducing the list of extension functions.
pub const EXTENSION_FUNCTIONS_HEADER: &str = "Functions:";
/// Sub-heading introducing the list of extension types.
pub const EXTENSION_TYPES_HEADER: &str = "Types:";
/// Sub-heading introducing the list of extension type variations.
pub const EXTENSION_TYPE_VARIATIONS_HEADER: &str = "Type Variations:";
14
/// The kinds of simple extension that can be registered under an anchor.
///
/// URIs are tracked separately and are not an `ExtensionKind`.
///
/// Variant order is significant: the derived `Ord` breaks ties between
/// entries keyed by `(anchor, kind)` that share the same anchor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExtensionKind {
    /// An extension function.
    Function,
    /// An extension type.
    Type,
    /// An extension type variation.
    TypeVariation,
}
22
23impl ExtensionKind {
24    pub fn name(&self) -> &'static str {
25        match self {
26            // ExtensionKind::Uri => "uri",
27            ExtensionKind::Function => "function",
28            ExtensionKind::Type => "type",
29            ExtensionKind::TypeVariation => "type_variation",
30        }
31    }
32}
33
34impl fmt::Display for ExtensionKind {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        match self {
37            // ExtensionKind::Uri => write!(f, "URI"),
38            ExtensionKind::Function => write!(f, "Function"),
39            ExtensionKind::Type => write!(f, "Type"),
40            ExtensionKind::TypeVariation => write!(f, "Type Variation"),
41        }
42    }
43}
44
45#[derive(Error, Debug, PartialEq, Clone)]
46pub enum InsertError {
47    #[error("Extension declaration missing mapping type")]
48    MissingMappingType,
49
50    #[error("Duplicate URI anchor {anchor} for {prev} and {name}")]
51    DuplicateUriAnchor {
52        anchor: u32,
53        prev: String,
54        name: String,
55    },
56
57    #[error("Duplicate extension {kind} anchor {anchor} for {prev} and {name}")]
58    DuplicateAnchor {
59        kind: ExtensionKind,
60        anchor: u32,
61        prev: String,
62        name: String,
63    },
64
65    #[error("Missing URI anchor {uri} for extension {kind} anchor {anchor} name {name}")]
66    MissingUri {
67        kind: ExtensionKind,
68        anchor: u32,
69        name: String,
70        uri: u32,
71    },
72
73    #[error(
74        "Duplicate extension {kind} anchor {anchor} for {prev} and {name}, also missing URI {uri}"
75    )]
76    DuplicateAndMissingUri {
77        kind: ExtensionKind,
78        anchor: u32,
79        prev: String,
80        name: String,
81        uri: u32,
82    },
83}
84
85/// ExtensionLookup contains mappings from anchors to extension URIs, functions,
86/// types, and type variations.
87#[derive(Default, Debug, Clone, PartialEq)]
88pub struct SimpleExtensions {
89    // Maps from extension URI anchor to URI
90    uris: BTreeMap<u32, String>,
91    // Maps from anchor and extension kind to (URI, name)
92    extensions: BTreeMap<(u32, ExtensionKind), (u32, String)>,
93}
94
95impl SimpleExtensions {
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    pub fn from_extensions<'a>(
101        uris: impl IntoIterator<Item = &'a pext::SimpleExtensionUri>,
102        extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
103    ) -> (Self, Vec<InsertError>) {
104        // TODO: this checks for missing URIs and duplicate anchors, but not
105        // duplicate names with the same anchor.
106
107        let mut exts = Self::new();
108
109        let mut errors = Vec::<InsertError>::new();
110
111        for uri in uris {
112            if let Err(e) = exts.add_extension_uri(uri.uri.clone(), uri.extension_uri_anchor) {
113                errors.push(e);
114            }
115        }
116
117        for extension in extensions {
118            match &extension.mapping_type {
119                Some(MappingType::ExtensionType(t)) => {
120                    if let Err(e) = exts.add_extension(
121                        ExtensionKind::Type,
122                        t.extension_uri_reference,
123                        t.type_anchor,
124                        t.name.clone(),
125                    ) {
126                        errors.push(e);
127                    }
128                }
129                Some(MappingType::ExtensionFunction(f)) => {
130                    if let Err(e) = exts.add_extension(
131                        ExtensionKind::Function,
132                        f.extension_uri_reference,
133                        f.function_anchor,
134                        f.name.clone(),
135                    ) {
136                        errors.push(e);
137                    }
138                }
139                Some(MappingType::ExtensionTypeVariation(v)) => {
140                    if let Err(e) = exts.add_extension(
141                        ExtensionKind::TypeVariation,
142                        v.extension_uri_reference,
143                        v.type_variation_anchor,
144                        v.name.clone(),
145                    ) {
146                        errors.push(e);
147                    }
148                }
149                None => {
150                    errors.push(InsertError::MissingMappingType);
151                }
152            }
153        }
154
155        (exts, errors)
156    }
157
158    pub fn add_extension_uri(&mut self, uri: String, anchor: u32) -> Result<(), InsertError> {
159        match self.uris.entry(anchor) {
160            Entry::Occupied(e) => {
161                return Err(InsertError::DuplicateUriAnchor {
162                    anchor,
163                    prev: e.get().clone(),
164                    name: uri,
165                });
166            }
167            Entry::Vacant(e) => {
168                e.insert(uri);
169            }
170        }
171        Ok(())
172    }
173
174    pub fn add_extension(
175        &mut self,
176        kind: ExtensionKind,
177        uri: u32,
178        anchor: u32,
179        name: String,
180    ) -> Result<(), InsertError> {
181        let missing_uri = !self.uris.contains_key(&uri);
182
183        let prev = match self.extensions.entry((anchor, kind)) {
184            Entry::Occupied(e) => Some(e.get().1.clone()),
185            Entry::Vacant(v) => {
186                v.insert((uri, name.clone()));
187                None
188            }
189        };
190
191        match (missing_uri, prev) {
192            (true, Some(prev)) => Err(InsertError::DuplicateAndMissingUri {
193                kind,
194                anchor,
195                prev,
196                name,
197                uri,
198            }),
199            (false, Some(prev)) => Err(InsertError::DuplicateAnchor {
200                kind,
201                anchor,
202                prev,
203                name,
204            }),
205            (true, None) => Err(InsertError::MissingUri {
206                kind,
207                anchor,
208                name,
209                uri,
210            }),
211            (false, None) => Ok(()),
212        }
213    }
214
215    pub fn is_empty(&self) -> bool {
216        self.uris.is_empty() && self.extensions.is_empty()
217    }
218
219    /// Convert the extension URIs to protobuf format for Plan construction.
220    pub fn to_extension_uris(&self) -> Vec<pext::SimpleExtensionUri> {
221        self.uris
222            .iter()
223            .map(|(anchor, uri)| pext::SimpleExtensionUri {
224                extension_uri_anchor: *anchor,
225                uri: uri.clone(),
226            })
227            .collect()
228    }
229
230    /// Convert the extensions to protobuf format for Plan construction.
231    pub fn to_extension_declarations(&self) -> Vec<pext::SimpleExtensionDeclaration> {
232        self.extensions
233            .iter()
234            .map(|((anchor, kind), (uri_ref, name))| {
235                let mapping_type = match kind {
236                    ExtensionKind::Function => MappingType::ExtensionFunction(
237                        pext::simple_extension_declaration::ExtensionFunction {
238                            extension_uri_reference: *uri_ref,
239                            function_anchor: *anchor,
240                            name: name.clone(),
241                        },
242                    ),
243                    ExtensionKind::Type => MappingType::ExtensionType(
244                        pext::simple_extension_declaration::ExtensionType {
245                            extension_uri_reference: *uri_ref,
246                            type_anchor: *anchor,
247                            name: name.clone(),
248                        },
249                    ),
250                    ExtensionKind::TypeVariation => MappingType::ExtensionTypeVariation(
251                        pext::simple_extension_declaration::ExtensionTypeVariation {
252                            extension_uri_reference: *uri_ref,
253                            type_variation_anchor: *anchor,
254                            name: name.clone(),
255                        },
256                    ),
257                };
258                pext::SimpleExtensionDeclaration {
259                    mapping_type: Some(mapping_type),
260                }
261            })
262            .collect()
263    }
264
265    /// Write the extensions to the given writer, with the given indent.
266    ///
267    /// The header will be included if there are any extensions.
268    pub fn write<W: fmt::Write>(&self, w: &mut W, indent: &str) -> fmt::Result {
269        if self.is_empty() {
270            // No extensions, so no need to write anything.
271            return Ok(());
272        }
273
274        writeln!(w, "{EXTENSIONS_HEADER}")?;
275        if !self.uris.is_empty() {
276            writeln!(w, "{EXTENSION_URIS_HEADER}")?;
277            for (anchor, uri) in &self.uris {
278                writeln!(w, "{indent}@{anchor:3}: {uri}")?;
279            }
280        }
281
282        let kinds_and_headers = [
283            (ExtensionKind::Function, EXTENSION_FUNCTIONS_HEADER),
284            (ExtensionKind::Type, EXTENSION_TYPES_HEADER),
285            (
286                ExtensionKind::TypeVariation,
287                EXTENSION_TYPE_VARIATIONS_HEADER,
288            ),
289        ];
290        for (kind, header) in kinds_and_headers {
291            let mut filtered = self
292                .extensions
293                .iter()
294                .filter(|((_a, k), _)| *k == kind)
295                .peekable();
296            if filtered.peek().is_none() {
297                continue;
298            }
299
300            writeln!(w, "{header}")?;
301            for ((anchor, _), (uri_ref, name)) in filtered {
302                writeln!(w, "{indent}#{anchor:3} @{uri_ref:3}: {name}")?;
303            }
304        }
305        Ok(())
306    }
307
308    pub fn to_string(&self, indent: &str) -> String {
309        let mut output = String::new();
310        self.write(&mut output, indent).unwrap();
311        output
312    }
313}
314
315#[derive(Error, Debug, Clone, PartialEq)]
316pub enum MissingReference {
317    #[error("Missing URI for {0}")]
318    MissingUri(u32),
319    #[error("Missing anchor for {0}: {1}")]
320    MissingAnchor(ExtensionKind, u32),
321    #[error("Missing name for {0}: {1}")]
322    MissingName(ExtensionKind, String),
323    #[error("Mismatched {0}: {1}#{2}")]
324    /// When the name of the value does not match the expected name
325    Mismatched(ExtensionKind, String, u32),
326    #[error("Duplicate name without anchor for {0}: {1}")]
327    DuplicateName(ExtensionKind, String),
328}
329
330#[derive(Debug, Clone, PartialEq)]
331pub struct SimpleExtension {
332    pub kind: ExtensionKind,
333    pub name: String,
334    pub anchor: u32,
335    pub uri: u32,
336}
337
338impl SimpleExtensions {
339    pub fn find_uri(&self, anchor: u32) -> Result<&str, MissingReference> {
340        self.uris
341            .get(&anchor)
342            .map(String::as_str)
343            .ok_or(MissingReference::MissingUri(anchor))
344    }
345
346    pub fn find_by_anchor(
347        &self,
348        kind: ExtensionKind,
349        anchor: u32,
350    ) -> Result<(u32, &str), MissingReference> {
351        let &(uri, ref name) = self
352            .extensions
353            .get(&(anchor, kind))
354            .ok_or(MissingReference::MissingAnchor(kind, anchor))?;
355
356        Ok((uri, name))
357    }
358
359    pub fn find_by_name<'a>(
360        &'a self,
361        kind: ExtensionKind,
362        name: &'a str,
363    ) -> Result<u32, MissingReference> {
364        let mut matches = self
365            .extensions
366            .iter()
367            .filter(move |((_a, k), (_, n))| *k == kind && n.as_str() == name)
368            .map(|((anchor, _), _)| *anchor);
369
370        let anchor = matches
371            .next()
372            .ok_or(MissingReference::MissingName(kind, name.to_string()))?;
373
374        match matches.next() {
375            Some(_) => Err(MissingReference::DuplicateName(kind, name.to_string())),
376            None => Ok(anchor),
377        }
378    }
379
380    // Validate that the extension exists, has the given name and anchor, and
381    // returns true if the name is unique for that extension kind.
382    //
383    // If the name is not unique, returns Ok(false). This is a valid case where
384    // two extensions have the same name (and presumably different URIs), but
385    // different anchors.
386    pub fn is_name_unique(
387        &self,
388        kind: ExtensionKind,
389        anchor: u32,
390        name: &str,
391    ) -> Result<bool, MissingReference> {
392        let mut found = false;
393        let mut other = false;
394        for (&(a, k), (_, n)) in self.extensions.iter() {
395            if k != kind {
396                continue;
397            }
398
399            if a == anchor {
400                found = true;
401                if n != name {
402                    return Err(MissingReference::Mismatched(kind, name.to_string(), anchor));
403                }
404                continue;
405            }
406
407            if n.as_str() != name {
408                // Neither anchor nor name match, so this is irrelevant.
409                continue;
410            }
411
412            // At this point, the anchor is different, but the name is the same.
413            other = true;
414            if found {
415                break;
416            }
417        }
418
419        match (found, other) {
420            // Found the one we're looking for, and no other matches.
421            (true, false) => Ok(true),
422            // Found the one we're looking for, and another match.
423            (true, true) => Ok(false),
424            // Didn't find the one we're looking for;
425            (false, _) => Err(MissingReference::MissingAnchor(kind, anchor)),
426        }
427    }
428}
429
#[cfg(test)]
mod tests {
    use pext::simple_extension_declaration::{
        ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
    };
    use substrait::proto::extensions as pext;

    use super::*;

    fn new_uri(anchor: u32, uri_str: &str) -> pext::SimpleExtensionUri {
        pext::SimpleExtensionUri {
            extension_uri_anchor: anchor,
            uri: uri_str.to_string(),
        }
    }

    fn new_ext_fn(anchor: u32, uri_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
        pext::SimpleExtensionDeclaration {
            mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                extension_uri_reference: uri_ref,
                function_anchor: anchor,
                name: name.to_string(),
            })),
        }
    }

    fn new_ext_type(anchor: u32, uri_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
        pext::SimpleExtensionDeclaration {
            mapping_type: Some(MappingType::ExtensionType(ExtensionType {
                extension_uri_reference: uri_ref,
                type_anchor: anchor,
                name: name.to_string(),
            })),
        }
    }

    fn new_type_var(anchor: u32, uri_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
        pext::SimpleExtensionDeclaration {
            mapping_type: Some(MappingType::ExtensionTypeVariation(
                ExtensionTypeVariation {
                    extension_uri_reference: uri_ref,
                    type_variation_anchor: anchor,
                    name: name.to_string(),
                },
            )),
        }
    }

    /// Prints every error before failing so a broken fixture is easy to diagnose.
    fn assert_no_errors(errs: &[InsertError]) {
        for err in errs {
            println!("Error: {err:?}");
        }
        assert!(errs.is_empty());
    }

    /// Builds extensions from protobuf parts, asserting that no insert errors occurred.
    fn unwrap_new_extensions<'a>(
        uris: impl IntoIterator<Item = &'a pext::SimpleExtensionUri>,
        extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
    ) -> SimpleExtensions {
        let (exts, errs) = SimpleExtensions::from_extensions(uris, extensions);
        assert_no_errors(&errs);
        exts
    }

    #[test]
    fn test_extension_lookup_empty() {
        let lookup = SimpleExtensions::new();
        assert!(lookup.find_uri(1).is_err());
        assert!(lookup.find_by_anchor(ExtensionKind::Function, 1).is_err());
        assert!(lookup.find_by_anchor(ExtensionKind::Type, 1).is_err());
        assert!(
            lookup
                .find_by_anchor(ExtensionKind::TypeVariation, 1)
                .is_err()
        );
        assert!(lookup.find_by_name(ExtensionKind::Function, "any").is_err());
        assert!(lookup.find_by_name(ExtensionKind::Type, "any").is_err());
        assert!(
            lookup
                .find_by_name(ExtensionKind::TypeVariation, "any")
                .is_err()
        );
    }

    #[test]
    fn test_from_extensions_basic() {
        let uris = vec![new_uri(1, "uri1"), new_uri(2, "uri2")];
        let extensions = vec![
            new_ext_fn(10, 1, "func1"),
            new_ext_type(20, 1, "type1"),
            new_type_var(30, 2, "var1"),
        ];
        let exts = unwrap_new_extensions(&uris, &extensions);

        assert_eq!(exts.find_uri(1).unwrap(), "uri1");
        assert_eq!(exts.find_uri(2).unwrap(), "uri2");
        assert!(exts.find_uri(3).is_err());

        let (uri, name) = exts.find_by_anchor(ExtensionKind::Function, 10).unwrap();
        assert_eq!(name, "func1");
        assert_eq!(uri, 1);
        assert!(exts.find_by_anchor(ExtensionKind::Function, 11).is_err());

        let (uri, name) = exts.find_by_anchor(ExtensionKind::Type, 20).unwrap();
        assert_eq!(name, "type1");
        assert_eq!(uri, 1);
        assert!(exts.find_by_anchor(ExtensionKind::Type, 21).is_err());

        let (uri, name) = exts
            .find_by_anchor(ExtensionKind::TypeVariation, 30)
            .unwrap();
        assert_eq!(name, "var1");
        assert_eq!(uri, 2);
        assert!(
            exts.find_by_anchor(ExtensionKind::TypeVariation, 31)
                .is_err()
        );
    }

    #[test]
    fn test_from_extensions_duplicates() {
        let uris = vec![
            new_uri(1, "uri_old"),
            new_uri(1, "uri_new"),
            new_uri(2, "second"),
        ];
        let extensions = vec![
            new_ext_fn(10, 1, "func_old"),
            new_ext_fn(10, 2, "func_new"), // Duplicate function anchor
            new_ext_fn(11, 3, "func_missing"),
        ];
        let (exts, errs) = SimpleExtensions::from_extensions(&uris, &extensions);
        assert_eq!(
            errs,
            vec![
                InsertError::DuplicateUriAnchor {
                    anchor: 1,
                    name: "uri_new".to_string(),
                    prev: "uri_old".to_string()
                },
                InsertError::DuplicateAnchor {
                    kind: ExtensionKind::Function,
                    anchor: 10,
                    prev: "func_old".to_string(),
                    name: "func_new".to_string()
                },
                InsertError::MissingUri {
                    kind: ExtensionKind::Function,
                    anchor: 11,
                    name: "func_missing".to_string(),
                    uri: 3,
                },
            ]
        );

        // This is a duplicate anchor, so the first one is used.
        assert_eq!(exts.find_uri(1).unwrap(), "uri_old");
        assert_eq!(
            exts.find_by_anchor(ExtensionKind::Function, 10).unwrap(),
            (1, "func_old")
        );
    }

    #[test]
    fn test_from_extensions_invalid_mapping_type() {
        let extensions = vec![pext::SimpleExtensionDeclaration { mapping_type: None }];

        let (_exts, errs) = SimpleExtensions::from_extensions(vec![], &extensions);
        assert_eq!(errs.len(), 1);
        let err = &errs[0];
        assert_eq!(err, &InsertError::MissingMappingType);
    }

    #[test]
    fn test_add_extension_missing_uri_then_duplicate() {
        let mut exts = SimpleExtensions::new();

        let err = exts
            .add_extension(ExtensionKind::Type, 7, 20, "t1".to_string())
            .unwrap_err();
        assert_eq!(
            err,
            InsertError::MissingUri {
                kind: ExtensionKind::Type,
                anchor: 20,
                name: "t1".to_string(),
                uri: 7,
            }
        );
        assert_eq!(
            err.to_string(),
            "Missing URI anchor 7 for extension Type anchor 20 name t1"
        );

        let err = exts
            .add_extension(ExtensionKind::Type, 7, 20, "t2".to_string())
            .unwrap_err();
        assert_eq!(
            err,
            InsertError::DuplicateAndMissingUri {
                kind: ExtensionKind::Type,
                anchor: 20,
                prev: "t1".to_string(),
                name: "t2".to_string(),
                uri: 7,
            }
        );

        // The extension is recorded despite the missing URI, and the first
        // declaration wins over the duplicate.
        assert_eq!(
            exts.find_by_anchor(ExtensionKind::Type, 20).unwrap(),
            (7, "t1")
        );
    }

    #[test]
    fn test_find_by_name() {
        let uris = vec![new_uri(1, "uri1")];
        let extensions = vec![
            new_ext_fn(10, 1, "name1"),
            new_ext_fn(11, 1, "name2"),
            new_ext_fn(12, 1, "name1"), // Duplicate name
            new_ext_type(20, 1, "type_name1"),
            new_type_var(30, 1, "var_name1"),
        ];
        let exts = unwrap_new_extensions(&uris, &extensions);

        let err = exts
            .find_by_name(ExtensionKind::Function, "name1")
            .unwrap_err();
        assert_eq!(
            err,
            MissingReference::DuplicateName(ExtensionKind::Function, "name1".to_string())
        );

        let found = exts.find_by_name(ExtensionKind::Function, "name2").unwrap();
        assert_eq!(found, 11);

        let found = exts
            .find_by_name(ExtensionKind::Type, "type_name1")
            .unwrap();
        assert_eq!(found, 20);

        let err = exts
            .find_by_name(ExtensionKind::Type, "non_existent_type_name")
            .unwrap_err();
        assert_eq!(
            err,
            MissingReference::MissingName(
                ExtensionKind::Type,
                "non_existent_type_name".to_string()
            )
        );

        let found = exts
            .find_by_name(ExtensionKind::TypeVariation, "var_name1")
            .unwrap();
        assert_eq!(found, 30);

        let err = exts
            .find_by_name(ExtensionKind::TypeVariation, "non_existent_var_name")
            .unwrap_err();
        assert_eq!(
            err,
            MissingReference::MissingName(
                ExtensionKind::TypeVariation,
                "non_existent_var_name".to_string()
            )
        );
    }

    #[test]
    fn test_is_name_unique() {
        let uris = vec![new_uri(1, "uri1"), new_uri(2, "uri2")];
        let extensions = vec![
            new_ext_fn(10, 1, "add"),
            new_ext_fn(11, 2, "add"), // Same name, different anchor and URI
            new_ext_fn(12, 1, "sub"),
            new_ext_type(20, 1, "add"), // Same name, different kind
        ];
        let exts = unwrap_new_extensions(&uris, &extensions);

        assert_eq!(exts.is_name_unique(ExtensionKind::Function, 12, "sub"), Ok(true));
        assert_eq!(exts.is_name_unique(ExtensionKind::Function, 10, "add"), Ok(false));
        assert_eq!(exts.is_name_unique(ExtensionKind::Function, 11, "add"), Ok(false));
        assert_eq!(exts.is_name_unique(ExtensionKind::Type, 20, "add"), Ok(true));
        assert_eq!(
            exts.is_name_unique(ExtensionKind::Function, 12, "add"),
            Err(MissingReference::Mismatched(
                ExtensionKind::Function,
                "add".to_string(),
                12
            ))
        );
        assert_eq!(
            exts.is_name_unique(ExtensionKind::Function, 99, "add"),
            Err(MissingReference::MissingAnchor(ExtensionKind::Function, 99))
        );
    }

    #[test]
    fn test_round_trip_to_proto() {
        // Inputs are already sorted by anchor, which is the output order.
        let uris = vec![new_uri(1, "uri1"), new_uri(2, "uri2")];
        let extensions = vec![
            new_ext_fn(10, 1, "func1"),
            new_ext_type(20, 1, "type1"),
            new_type_var(30, 2, "var1"),
        ];
        let exts = unwrap_new_extensions(&uris, &extensions);

        assert_eq!(exts.to_extension_uris(), uris);
        assert_eq!(exts.to_extension_declarations(), extensions);
    }

    #[test]
    fn test_display_extension_lookup_empty() {
        let lookup = SimpleExtensions::new();
        assert!(lookup.is_empty());
        let mut output = String::new();
        lookup.write(&mut output, "  ").unwrap();
        // An empty set renders as nothing at all, not even the header.
        assert_eq!(output, "");
    }

    #[test]
    fn test_display_extension_lookup_with_content() {
        let uris = vec![
            new_uri(1, "/my/uri1"),
            new_uri(42, "/another/uri"),
            new_uri(4091, "/big/anchor"),
        ];
        let extensions = vec![
            new_ext_fn(10, 1, "my_func"),
            new_ext_type(20, 1, "my_type"),
            new_type_var(30, 42, "my_var"),
            new_ext_fn(11, 42, "another_func"),
            new_ext_fn(108812, 4091, "big_func"),
        ];
        let exts = unwrap_new_extensions(&uris, &extensions);
        let display_str = exts.to_string("  ");

        let expected = r"
=== Extensions
URIs:
  @  1: /my/uri1
  @ 42: /another/uri
  @4091: /big/anchor
Functions:
  # 10 @  1: my_func
  # 11 @ 42: another_func
  #108812 @4091: big_func
Types:
  # 20 @  1: my_type
Type Variations:
  # 30 @ 42: my_var
";
        assert_eq!(display_str, expected.trim_start());
    }

    #[test]
    fn test_extensions_output() {
        // Manually build the extensions
        let mut extensions = SimpleExtensions::new();
        extensions
            .add_extension_uri("/uri/common".to_string(), 1)
            .unwrap();
        extensions
            .add_extension_uri("/uri/specific_funcs".to_string(), 2)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Function, 1, 10, "func_a".to_string())
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Function, 2, 11, "func_b_special".to_string())
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 1, 20, "SomeType".to_string())
            .unwrap();
        extensions
            .add_extension(ExtensionKind::TypeVariation, 2, 30, "VarX".to_string())
            .unwrap();

        // Convert to string
        let output = extensions.to_string("  ");

        // The output should match the expected format
        let expected_output = r#"
=== Extensions
URIs:
  @  1: /uri/common
  @  2: /uri/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#;

        assert_eq!(output, expected_output.trim_start());
    }
}