Skip to main content

substrait_explain/extensions/
simple.rs

1use std::collections::BTreeMap;
2use std::collections::btree_map::Entry;
3use std::fmt;
4
5use pext::simple_extension_declaration::MappingType;
6use substrait::proto::extensions as pext;
7use thiserror::Error;
8
/// Line that opens the extensions section in the text output (written first by `SimpleExtensions::write`).
pub const EXTENSIONS_HEADER: &str = "=== Extensions";
/// Sub-header preceding the list of `@anchor: urn` lines.
pub const EXTENSION_URNS_HEADER: &str = "URNs:";
/// Sub-header preceding the list of function declarations.
pub const EXTENSION_FUNCTIONS_HEADER: &str = "Functions:";
/// Sub-header preceding the list of type declarations.
pub const EXTENSION_TYPES_HEADER: &str = "Types:";
/// Sub-header preceding the list of type-variation declarations.
pub const EXTENSION_TYPE_VARIATIONS_HEADER: &str = "Type Variations:";
14
/// The categories of simple extension a plan can declare.
///
/// Variant order is significant: the derived `Ord` participates in the
/// `(anchor, kind)` key ordering used by `SimpleExtensions`, so reordering
/// the variants would change iteration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExtensionKind {
    /// A function declaration.
    Function,
    /// A type declaration.
    Type,
    /// A type-variation declaration.
    TypeVariation,
}

impl ExtensionKind {
    /// Lower-case, snake_case identifier for this kind (e.g. `"type_variation"`).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Function => "function",
            Self::Type => "type",
            Self::TypeVariation => "type_variation",
        }
    }
}

impl fmt::Display for ExtensionKind {
    /// Writes a human-readable, title-cased label (e.g. `"Type Variation"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Function => "Function",
            Self::Type => "Type",
            Self::TypeVariation => "Type Variation",
        };
        f.write_str(label)
    }
}
44
45#[derive(Error, Debug, PartialEq, Clone)]
46pub enum InsertError {
47    #[error("Extension declaration missing mapping type")]
48    MissingMappingType,
49
50    #[error("Duplicate URN anchor {anchor} for {prev} and {name}")]
51    DuplicateUrnAnchor {
52        anchor: u32,
53        prev: String,
54        name: String,
55    },
56
57    #[error("Duplicate extension {kind} anchor {anchor} for {prev} and {name}")]
58    DuplicateAnchor {
59        kind: ExtensionKind,
60        anchor: u32,
61        prev: String,
62        name: String,
63    },
64
65    #[error("Missing URN anchor {urn} for extension {kind} anchor {anchor} name {name}")]
66    MissingUrn {
67        kind: ExtensionKind,
68        anchor: u32,
69        name: String,
70        urn: u32,
71    },
72
73    #[error(
74        "Duplicate extension {kind} anchor {anchor} for {prev} and {name}, also missing URN {urn}"
75    )]
76    DuplicateAndMissingUrn {
77        kind: ExtensionKind,
78        anchor: u32,
79        prev: String,
80        name: String,
81        urn: u32,
82    },
83}
84
/// A Substrait compound function name such as `"equal:any_any"`, `"count:"`,
/// or `"add"`.
///
/// Names come in two shapes:
/// * *simple* — no `:` at all, e.g. `"add"`;
/// * *full* — contains a `:`, e.g. `"count:"` or `"equal:any_any"`.
///
/// Everything after the first `:` is the type-signature suffix: either the
/// argument types (`"i64_i64"`) or nothing, for zero argument types (`"count:"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompoundName {
    /// The complete name, signature suffix included (e.g. `"equal:any_any"`).
    name: String,
    /// Byte offset of the first `:`; equals `name.len()` if there is none.
    index: usize,
}

impl CompoundName {
    /// Wrap `name`, remembering where its base part ends.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        let index = match name.find(':') {
            Some(colon) => colon,
            None => name.len(),
        };
        Self { name, index }
    }

    /// The portion before the first `:` (e.g. `"equal"` for `"equal:any_any"`).
    pub fn base(&self) -> &str {
        // `index` is a char boundary: either the ASCII ':' or the end of the string.
        let (head, _signature) = self.name.split_at(self.index);
        head
    }

    /// The complete stored name (e.g. `"equal:any_any"` or `"count:"`).
    pub fn full(&self) -> &str {
        self.name.as_str()
    }

    /// Whether this stored name is selected by the written `pattern`.
    ///
    /// A pattern containing `:` must equal the full name exactly; a pattern
    /// without `:` only needs to equal the base name.
    pub fn matches(&self, pattern: &str) -> bool {
        let candidate = if pattern.contains(':') {
            self.full()
        } else {
            self.base()
        };
        candidate == pattern
    }
}

impl fmt::Display for CompoundName {
    /// Writes the full name, signature suffix included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.full())
    }
}
134
135/// ExtensionLookup contains mappings from anchors to extension URNs, functions,
136/// types, and type variations.
137#[derive(Default, Debug, Clone, PartialEq)]
138pub struct SimpleExtensions {
139    // Maps from extension URN anchor to URN
140    urns: BTreeMap<u32, String>,
141    // Maps from anchor and extension kind to (URN anchor, name)
142    extensions: BTreeMap<(u32, ExtensionKind), (u32, CompoundName)>,
143}
144
145impl SimpleExtensions {
146    pub fn new() -> Self {
147        Self::default()
148    }
149
150    pub fn from_extensions<'a>(
151        urns: impl IntoIterator<Item = &'a pext::SimpleExtensionUrn>,
152        extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
153    ) -> (Self, Vec<InsertError>) {
154        // TODO: this checks for missing URNs and duplicate anchors, but not
155        // duplicate names with the same anchor.
156
157        let mut exts = Self::new();
158
159        let mut errors = Vec::<InsertError>::new();
160
161        for urn in urns {
162            if let Err(e) = exts.add_extension_urn(urn.urn.clone(), urn.extension_urn_anchor) {
163                errors.push(e);
164            }
165        }
166
167        for extension in extensions {
168            match &extension.mapping_type {
169                Some(MappingType::ExtensionType(t)) => {
170                    if let Err(e) = exts.add_extension(
171                        ExtensionKind::Type,
172                        t.extension_urn_reference,
173                        t.type_anchor,
174                        t.name.clone(),
175                    ) {
176                        errors.push(e);
177                    }
178                }
179                Some(MappingType::ExtensionFunction(f)) => {
180                    if let Err(e) = exts.add_extension(
181                        ExtensionKind::Function,
182                        f.extension_urn_reference,
183                        f.function_anchor,
184                        f.name.clone(),
185                    ) {
186                        errors.push(e);
187                    }
188                }
189                Some(MappingType::ExtensionTypeVariation(v)) => {
190                    if let Err(e) = exts.add_extension(
191                        ExtensionKind::TypeVariation,
192                        v.extension_urn_reference,
193                        v.type_variation_anchor,
194                        v.name.clone(),
195                    ) {
196                        errors.push(e);
197                    }
198                }
199                None => {
200                    errors.push(InsertError::MissingMappingType);
201                }
202            }
203        }
204
205        (exts, errors)
206    }
207
208    pub fn add_extension_urn(&mut self, urn: String, anchor: u32) -> Result<(), InsertError> {
209        match self.urns.entry(anchor) {
210            Entry::Occupied(e) => {
211                return Err(InsertError::DuplicateUrnAnchor {
212                    anchor,
213                    prev: e.get().clone(),
214                    name: urn,
215                });
216            }
217            Entry::Vacant(e) => {
218                e.insert(urn);
219            }
220        }
221        Ok(())
222    }
223
224    pub fn add_extension(
225        &mut self,
226        kind: ExtensionKind,
227        urn: u32,
228        anchor: u32,
229        name: String,
230    ) -> Result<(), InsertError> {
231        let missing_urn = !self.urns.contains_key(&urn);
232
233        let prev = match self.extensions.entry((anchor, kind)) {
234            Entry::Occupied(e) => Some(e.get().1.full().to_string()),
235            Entry::Vacant(v) => {
236                v.insert((urn, CompoundName::new(name.clone())));
237                None
238            }
239        };
240
241        match (missing_urn, prev) {
242            (true, Some(prev)) => Err(InsertError::DuplicateAndMissingUrn {
243                kind,
244                anchor,
245                prev,
246                name,
247                urn,
248            }),
249            (false, Some(prev)) => Err(InsertError::DuplicateAnchor {
250                kind,
251                anchor,
252                prev,
253                name,
254            }),
255            (true, None) => Err(InsertError::MissingUrn {
256                kind,
257                anchor,
258                name,
259                urn,
260            }),
261            (false, None) => Ok(()),
262        }
263    }
264
265    pub fn is_empty(&self) -> bool {
266        self.urns.is_empty() && self.extensions.is_empty()
267    }
268
269    /// Convert the extension URNs to protobuf format for Plan construction.
270    pub fn to_extension_urns(&self) -> Vec<pext::SimpleExtensionUrn> {
271        self.urns
272            .iter()
273            .map(|(anchor, urn)| pext::SimpleExtensionUrn {
274                extension_urn_anchor: *anchor,
275                urn: urn.clone(),
276            })
277            .collect()
278    }
279
280    /// Convert the extensions to protobuf format for Plan construction.
281    pub fn to_extension_declarations(&self) -> Vec<pext::SimpleExtensionDeclaration> {
282        self.extensions
283            .iter()
284            .map(|((anchor, kind), (urn_ref, name))| {
285                let mapping_type = match kind {
286                    ExtensionKind::Function => MappingType::ExtensionFunction(
287                        pext::simple_extension_declaration::ExtensionFunction {
288                            extension_urn_reference: *urn_ref,
289                            function_anchor: *anchor,
290                            name: name.full().to_string(),
291                        },
292                    ),
293                    ExtensionKind::Type => MappingType::ExtensionType(
294                        pext::simple_extension_declaration::ExtensionType {
295                            extension_urn_reference: *urn_ref,
296                            type_anchor: *anchor,
297                            name: name.full().to_string(),
298                        },
299                    ),
300                    ExtensionKind::TypeVariation => MappingType::ExtensionTypeVariation(
301                        pext::simple_extension_declaration::ExtensionTypeVariation {
302                            extension_urn_reference: *urn_ref,
303                            type_variation_anchor: *anchor,
304                            name: name.full().to_string(),
305                        },
306                    ),
307                };
308                pext::SimpleExtensionDeclaration {
309                    mapping_type: Some(mapping_type),
310                }
311            })
312            .collect()
313    }
314
315    /// Write the extensions to the given writer, with the given indent.
316    ///
317    /// The header will be included if there are any extensions.
318    pub fn write<W: fmt::Write>(&self, w: &mut W, indent: &str) -> fmt::Result {
319        if self.is_empty() {
320            // No extensions, so no need to write anything.
321            return Ok(());
322        }
323
324        writeln!(w, "{EXTENSIONS_HEADER}")?;
325        if !self.urns.is_empty() {
326            writeln!(w, "{EXTENSION_URNS_HEADER}")?;
327            for (anchor, urn) in &self.urns {
328                writeln!(w, "{indent}@{anchor:3}: {urn}")?;
329            }
330        }
331
332        let kinds_and_headers = [
333            (ExtensionKind::Function, EXTENSION_FUNCTIONS_HEADER),
334            (ExtensionKind::Type, EXTENSION_TYPES_HEADER),
335            (
336                ExtensionKind::TypeVariation,
337                EXTENSION_TYPE_VARIATIONS_HEADER,
338            ),
339        ];
340        for (kind, header) in kinds_and_headers {
341            let mut filtered = self
342                .extensions
343                .iter()
344                .filter(|((_a, k), _)| *k == kind)
345                .peekable();
346            if filtered.peek().is_none() {
347                continue;
348            }
349
350            writeln!(w, "{header}")?;
351            for ((anchor, _), (urn_ref, name)) in filtered {
352                writeln!(w, "{indent}#{anchor:3} @{urn_ref:3}: {name}")?;
353            }
354        }
355        Ok(())
356    }
357
358    pub fn to_string(&self, indent: &str) -> String {
359        let mut output = String::new();
360        self.write(&mut output, indent).unwrap();
361        output
362    }
363}
364
365#[derive(Error, Debug, Clone, PartialEq)]
366pub enum MissingReference {
367    #[error("Missing URN for {0}")]
368    MissingUrn(u32),
369    #[error("Missing anchor for {0}: {1}")]
370    MissingAnchor(ExtensionKind, u32),
371    #[error("Missing name for {0}: {1}")]
372    MissingName(ExtensionKind, String),
373    #[error("Mismatched {0}: {1}#{2}")]
374    /// When the name of the value does not match the expected name
375    Mismatched(ExtensionKind, String, u32),
376    #[error("Duplicate name without anchor for {0}: {1}")]
377    DuplicateName(ExtensionKind, String),
378}
379
380#[derive(Debug, Clone, PartialEq)]
381pub struct SimpleExtension {
382    pub kind: ExtensionKind,
383    pub name: String,
384    pub anchor: u32,
385    pub urn: u32,
386}
387
388/// The result of resolving a function anchor to its full metadata.
389pub struct ResolvedFunction<'a> {
390    pub anchor: u32,
391    pub urn: u32,
392    /// The full compound name stored for this anchor.
393    pub name: &'a CompoundName,
394    /// `true` when the base name is unique across all registered functions
395    /// (controls whether the signature suffix is needed in compact mode).
396    pub base_name_unique: bool,
397    /// `true` when the full compound name is unique across all registered
398    /// functions (controls whether the `#anchor` suffix is needed).
399    pub name_unique: bool,
400}
401
402impl SimpleExtensions {
403    pub fn find_urn(&self, anchor: u32) -> Result<&str, MissingReference> {
404        self.urns
405            .get(&anchor)
406            .map(String::as_str)
407            .ok_or(MissingReference::MissingUrn(anchor))
408    }
409
410    pub fn find_by_anchor(
411        &self,
412        kind: ExtensionKind,
413        anchor: u32,
414    ) -> Result<(u32, &CompoundName), MissingReference> {
415        let &(urn, ref name) = self
416            .extensions
417            .get(&(anchor, kind))
418            .ok_or(MissingReference::MissingAnchor(kind, anchor))?;
419
420        Ok((urn, name))
421    }
422
423    pub fn find_by_name(&self, kind: ExtensionKind, name: &str) -> Result<u32, MissingReference> {
424        let mut matches = self
425            .extensions
426            .iter()
427            .filter(move |((_a, k), (_, n))| *k == kind && n.full() == name)
428            .map(|((anchor, _), _)| *anchor);
429
430        let anchor = matches
431            .next()
432            .ok_or(MissingReference::MissingName(kind, name.to_string()))?;
433
434        match matches.next() {
435            Some(_) => Err(MissingReference::DuplicateName(kind, name.to_string())),
436            None => Ok(anchor),
437        }
438    }
439
440    /// Returns `true` when no other extension of the same kind has the same
441    /// full compound name (i.e. the anchor display can be suppressed).
442    ///
443    /// Returns `Err` when `anchor` is not registered for `kind`.
444    pub fn is_name_unique(
445        &self,
446        kind: ExtensionKind,
447        anchor: u32,
448        name: &str,
449    ) -> Result<bool, MissingReference> {
450        let mut found = false;
451        let mut other = false;
452        for (&(a, k), (_, n)) in self.extensions.iter() {
453            if k != kind {
454                continue;
455            }
456
457            if a == anchor {
458                found = true;
459                if n.full() != name {
460                    return Err(MissingReference::Mismatched(kind, name.to_string(), anchor));
461                }
462                continue;
463            }
464
465            if n.full() != name {
466                // Neither anchor nor name match, so this is irrelevant.
467                continue;
468            }
469
470            // At this point, the anchor is different, but the name is the same.
471            other = true;
472            if found {
473                break;
474            }
475        }
476
477        match (found, other) {
478            // Found the one we're looking for, and no other matches.
479            (true, false) => Ok(true),
480            // Found the one we're looking for, and another match.
481            (true, true) => Ok(false),
482            // Didn't find the one we're looking for.
483            (false, _) => Err(MissingReference::MissingAnchor(kind, anchor)),
484        }
485    }
486
487    /// Look up a function anchor and return its full resolution metadata.
488    /// The caller already has `anchor` from the
489    /// Substrait plan and needs the name, URN, and uniqueness flags.
490    pub fn lookup_function(&self, anchor: u32) -> Result<ResolvedFunction<'_>, MissingReference> {
491        let (urn, name) = self.find_by_anchor(ExtensionKind::Function, anchor)?;
492        let name_unique = self.is_name_unique(ExtensionKind::Function, anchor, name.full())?;
493        let base_name_unique = self.is_base_name_unique(ExtensionKind::Function, anchor)?;
494        Ok(ResolvedFunction {
495            anchor,
496            urn,
497            name,
498            name_unique,
499            base_name_unique,
500        })
501    }
502
503    /// Resolve a [`CompoundName`] written in the plan to a [`ResolvedFunction`].
504    ///
505    /// * `anchor = Some(a)` — the anchor identifies the function; the name is
506    ///   validated for consistency using [`CompoundName::matches`].
507    /// * `anchor = None` — the name must be unambiguous on its own:
508    ///   - Simple (no `:`): base-name search; fails if more than one function
509    ///     shares that base name.
510    ///   - Full (has `:`): exact match only.
511    pub fn resolve_function(
512        &self,
513        name: &str,
514        anchor: Option<u32>,
515    ) -> Result<ResolvedFunction<'_>, MissingReference> {
516        let resolved_anchor = match anchor {
517            Some(a) => {
518                let (_, stored) = self.find_by_anchor(ExtensionKind::Function, a)?;
519                if stored.matches(name) {
520                    a
521                } else {
522                    return Err(MissingReference::Mismatched(
523                        ExtensionKind::Function,
524                        name.to_string(),
525                        a,
526                    ));
527                }
528            }
529            None => {
530                if name.contains(':') {
531                    self.find_by_name(ExtensionKind::Function, name)?
532                } else {
533                    self.find_by_base_name(ExtensionKind::Function, name)?
534                }
535            }
536        };
537        self.lookup_function(resolved_anchor)
538    }
539
540    fn is_base_name_unique(
541        &self,
542        kind: ExtensionKind,
543        anchor: u32,
544    ) -> Result<bool, MissingReference> {
545        let (_, name) = self.find_by_anchor(kind, anchor)?;
546        let my_base = name.base();
547
548        let other_exists = self
549            .extensions
550            .iter()
551            .any(|(&(a, k), (_, n))| k == kind && a != anchor && n.base() == my_base);
552
553        Ok(!other_exists)
554    }
555
556    fn find_by_base_name(&self, kind: ExtensionKind, base: &str) -> Result<u32, MissingReference> {
557        let mut matches = self
558            .extensions
559            .iter()
560            .filter(|&(&(_a, k), (_, n))| k == kind && n.matches(base))
561            .map(|(&(anchor, _), _)| anchor);
562
563        let anchor = matches
564            .next()
565            .ok_or_else(|| MissingReference::MissingName(kind, base.to_string()))?;
566
567        match matches.next() {
568            Some(_) => Err(MissingReference::DuplicateName(kind, base.to_string())),
569            None => Ok(anchor),
570        }
571    }
572}
573
574#[cfg(test)]
575mod tests {
576    use pext::simple_extension_declaration::{
577        ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
578    };
579    use substrait::proto::extensions as pext;
580
581    use super::*;
582
583    fn new_urn(anchor: u32, urn_str: &str) -> pext::SimpleExtensionUrn {
584        pext::SimpleExtensionUrn {
585            extension_urn_anchor: anchor,
586            urn: urn_str.to_string(),
587        }
588    }
589
590    fn new_ext_fn(anchor: u32, urn_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
591        pext::SimpleExtensionDeclaration {
592            mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
593                extension_urn_reference: urn_ref,
594                function_anchor: anchor,
595                name: name.to_string(),
596            })),
597        }
598    }
599
600    fn new_ext_type(anchor: u32, urn_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
601        pext::SimpleExtensionDeclaration {
602            mapping_type: Some(MappingType::ExtensionType(ExtensionType {
603                extension_urn_reference: urn_ref,
604                type_anchor: anchor,
605                name: name.to_string(),
606            })),
607        }
608    }
609
610    fn new_type_var(anchor: u32, urn_ref: u32, name: &str) -> pext::SimpleExtensionDeclaration {
611        pext::SimpleExtensionDeclaration {
612            mapping_type: Some(MappingType::ExtensionTypeVariation(
613                ExtensionTypeVariation {
614                    extension_urn_reference: urn_ref,
615                    type_variation_anchor: anchor,
616                    name: name.to_string(),
617                },
618            )),
619        }
620    }
621
622    fn assert_no_errors(errs: &[InsertError]) {
623        for err in errs {
624            println!("Error: {err:?}");
625        }
626        assert!(errs.is_empty());
627    }
628
629    fn unwrap_new_extensions<'a>(
630        urns: impl IntoIterator<Item = &'a pext::SimpleExtensionUrn>,
631        extensions: impl IntoIterator<Item = &'a pext::SimpleExtensionDeclaration>,
632    ) -> SimpleExtensions {
633        let (exts, errs) = SimpleExtensions::from_extensions(urns, extensions);
634        assert_no_errors(&errs);
635        exts
636    }
637
638    #[test]
639    fn test_extension_lookup_empty() {
640        let lookup = SimpleExtensions::new();
641        assert!(lookup.find_urn(1).is_err());
642        assert!(lookup.find_by_anchor(ExtensionKind::Function, 1).is_err());
643        assert!(lookup.find_by_anchor(ExtensionKind::Type, 1).is_err());
644        assert!(
645            lookup
646                .find_by_anchor(ExtensionKind::TypeVariation, 1)
647                .is_err()
648        );
649        assert!(lookup.find_by_name(ExtensionKind::Function, "any").is_err());
650        assert!(lookup.find_by_name(ExtensionKind::Type, "any").is_err());
651        assert!(
652            lookup
653                .find_by_name(ExtensionKind::TypeVariation, "any")
654                .is_err()
655        );
656    }
657
658    #[test]
659    fn test_from_extensions_basic() {
660        let urns = vec![new_urn(1, "urn1"), new_urn(2, "urn2")];
661        let extensions = vec![
662            new_ext_fn(10, 1, "func1"),
663            new_ext_type(20, 1, "type1"),
664            new_type_var(30, 2, "var1"),
665        ];
666        let exts = unwrap_new_extensions(&urns, &extensions);
667
668        assert_eq!(exts.find_urn(1).unwrap(), "urn1");
669        assert_eq!(exts.find_urn(2).unwrap(), "urn2");
670        assert!(exts.find_urn(3).is_err());
671
672        let (urn, name) = exts.find_by_anchor(ExtensionKind::Function, 10).unwrap();
673        assert_eq!(name.full(), "func1");
674        assert_eq!(urn, 1);
675        assert!(exts.find_by_anchor(ExtensionKind::Function, 11).is_err());
676
677        let (urn, name) = exts.find_by_anchor(ExtensionKind::Type, 20).unwrap();
678        assert_eq!(name.full(), "type1");
679        assert_eq!(urn, 1);
680        assert!(exts.find_by_anchor(ExtensionKind::Type, 21).is_err());
681
682        let (urn, name) = exts
683            .find_by_anchor(ExtensionKind::TypeVariation, 30)
684            .unwrap();
685        assert_eq!(name.full(), "var1");
686        assert_eq!(urn, 2);
687        assert!(
688            exts.find_by_anchor(ExtensionKind::TypeVariation, 31)
689                .is_err()
690        );
691    }
692
693    #[test]
694    fn test_from_extensions_duplicates() {
695        let urns = vec![
696            new_urn(1, "urn_old"),
697            new_urn(1, "urn_new"),
698            new_urn(2, "second"),
699        ];
700        let extensions = vec![
701            new_ext_fn(10, 1, "func_old"),
702            new_ext_fn(10, 2, "func_new"), // Duplicate function anchor
703            new_ext_fn(11, 3, "func_missing"),
704        ];
705        let (exts, errs) = SimpleExtensions::from_extensions(&urns, &extensions);
706        assert_eq!(
707            errs,
708            vec![
709                InsertError::DuplicateUrnAnchor {
710                    anchor: 1,
711                    name: "urn_new".to_string(),
712                    prev: "urn_old".to_string()
713                },
714                InsertError::DuplicateAnchor {
715                    kind: ExtensionKind::Function,
716                    anchor: 10,
717                    prev: "func_old".to_string(),
718                    name: "func_new".to_string()
719                },
720                InsertError::MissingUrn {
721                    kind: ExtensionKind::Function,
722                    anchor: 11,
723                    name: "func_missing".to_string(),
724                    urn: 3,
725                },
726            ]
727        );
728
729        // This is a duplicate anchor, so the first one is used.
730        assert_eq!(exts.find_urn(1).unwrap(), "urn_old");
731        let (urn, name) = exts.find_by_anchor(ExtensionKind::Function, 10).unwrap();
732        assert_eq!(urn, 1);
733        assert_eq!(name.full(), "func_old");
734    }
735
736    #[test]
737    fn test_from_extensions_invalid_mapping_type() {
738        let extensions = vec![pext::SimpleExtensionDeclaration { mapping_type: None }];
739
740        let (_exts, errs) = SimpleExtensions::from_extensions(vec![], &extensions);
741        assert_eq!(errs.len(), 1);
742        let err = &errs[0];
743        assert_eq!(err, &InsertError::MissingMappingType);
744    }
745
746    #[test]
747    fn test_find_by_name() {
748        let urns = vec![new_urn(1, "urn1")];
749        let extensions = vec![
750            new_ext_fn(10, 1, "name1"),
751            new_ext_fn(11, 1, "name2"),
752            new_ext_fn(12, 1, "name1"), // Duplicate name
753            new_ext_type(20, 1, "type_name1"),
754            new_type_var(30, 1, "var_name1"),
755        ];
756        let exts = unwrap_new_extensions(&urns, &extensions);
757
758        let err = exts
759            .find_by_name(ExtensionKind::Function, "name1")
760            .unwrap_err();
761        assert_eq!(
762            err,
763            MissingReference::DuplicateName(ExtensionKind::Function, "name1".to_string())
764        );
765
766        let found = exts.find_by_name(ExtensionKind::Function, "name2").unwrap();
767        assert_eq!(found, 11);
768
769        let found = exts
770            .find_by_name(ExtensionKind::Type, "type_name1")
771            .unwrap();
772        assert_eq!(found, 20);
773
774        let err = exts
775            .find_by_name(ExtensionKind::Type, "non_existent_type_name")
776            .unwrap_err();
777        assert_eq!(
778            err,
779            MissingReference::MissingName(
780                ExtensionKind::Type,
781                "non_existent_type_name".to_string()
782            )
783        );
784
785        let found = exts
786            .find_by_name(ExtensionKind::TypeVariation, "var_name1")
787            .unwrap();
788        assert_eq!(found, 30);
789
790        let err = exts
791            .find_by_name(ExtensionKind::TypeVariation, "non_existent_var_name")
792            .unwrap_err();
793        assert_eq!(
794            err,
795            MissingReference::MissingName(
796                ExtensionKind::TypeVariation,
797                "non_existent_var_name".to_string()
798            )
799        );
800    }
801
802    #[test]
803    fn test_display_extension_lookup_empty() {
804        let lookup = SimpleExtensions::new();
805        let mut output = String::new();
806        lookup.write(&mut output, "  ").unwrap();
807        let expected = r"";
808        assert_eq!(output, expected.trim_start());
809    }
810
811    #[test]
812    fn test_display_extension_lookup_with_content() {
813        let urns = vec![
814            new_urn(1, "/my/urn1"),
815            new_urn(42, "/another/urn"),
816            new_urn(4091, "/big/anchor"),
817        ];
818        let extensions = vec![
819            new_ext_fn(10, 1, "my_func"),
820            new_ext_type(20, 1, "my_type"),
821            new_type_var(30, 42, "my_var"),
822            new_ext_fn(11, 42, "another_func"),
823            new_ext_fn(108812, 4091, "big_func"),
824        ];
825        let exts = unwrap_new_extensions(&urns, &extensions);
826        let display_str = exts.to_string("  ");
827
828        let expected = r"
829=== Extensions
830URNs:
831  @  1: /my/urn1
832  @ 42: /another/urn
833  @4091: /big/anchor
834Functions:
835  # 10 @  1: my_func
836  # 11 @ 42: another_func
837  #108812 @4091: big_func
838Types:
839  # 20 @  1: my_type
840Type Variations:
841  # 30 @ 42: my_var
842";
843        assert_eq!(display_str, expected.trim_start());
844    }
845
846    #[test]
847    fn test_extensions_output() {
848        // Manually build the extensions
849        let mut extensions = SimpleExtensions::new();
850        extensions
851            .add_extension_urn("/urn/common".to_string(), 1)
852            .unwrap();
853        extensions
854            .add_extension_urn("/urn/specific_funcs".to_string(), 2)
855            .unwrap();
856        extensions
857            .add_extension(ExtensionKind::Function, 1, 10, "func_a".to_string())
858            .unwrap();
859        extensions
860            .add_extension(ExtensionKind::Function, 2, 11, "func_b_special".to_string())
861            .unwrap();
862        extensions
863            .add_extension(ExtensionKind::Type, 1, 20, "SomeType".to_string())
864            .unwrap();
865        extensions
866            .add_extension(ExtensionKind::TypeVariation, 2, 30, "VarX".to_string())
867            .unwrap();
868
869        // Convert to string
870        let output = extensions.to_string("  ");
871
872        // The output should match the expected format
873        let expected_output = r#"
874=== Extensions
875URNs:
876  @  1: /urn/common
877  @  2: /urn/specific_funcs
878Functions:
879  # 10 @  1: func_a
880  # 11 @  2: func_b_special
881Types:
882  # 20 @  1: SomeType
883Type Variations:
884  # 30 @  2: VarX
885"#;
886
887        assert_eq!(output, expected_output.trim_start());
888    }
889
890    #[test]
891    fn test_compound_name_full_zero_arg_type_signature() {
892        // A Full name whose type signature encodes zero argument types (nothing after the colon).
893        let n = CompoundName::new("add:");
894        assert_eq!(n.full(), "add:");
895        assert_eq!(n.base(), "add");
896        // Full pattern: exact match only.
897        assert!(n.matches("add:"));
898        assert!(!n.matches("add:i64_i64"));
899        // Simple pattern: base match.
900        assert!(n.matches("add"));
901    }
902
903    #[test]
904    fn test_compound_name_with_signature() {
905        let n = CompoundName::new("equal:any_any");
906        assert_eq!(n.full(), "equal:any_any");
907        assert_eq!(n.base(), "equal");
908
909        let n2 = CompoundName::new("regexp_match_substring:str_str_i64");
910        assert_eq!(n2.base(), "regexp_match_substring");
911        assert_eq!(n2.full(), "regexp_match_substring:str_str_i64");
912
913        let n3 = CompoundName::new("add:i64_i64");
914        assert_eq!(n3.base(), "add");
915    }
916
917    // ---- Tests for lookup_function ----
918
919    fn make_overloaded_extensions() -> SimpleExtensions {
920        let urns = vec![new_urn(1, "urn:comparison")];
921        let extensions = vec![
922            new_ext_fn(1, 1, "equal:any_any"),
923            new_ext_fn(2, 1, "equal:str_str"),
924            new_ext_fn(3, 1, "add:i64_i64"),
925        ];
926        unwrap_new_extensions(&urns, &extensions)
927    }
928
929    #[test]
930    fn test_lookup_function_uniqueness_flags() {
931        // `equal:any_any` and `equal:str_str` share the base name "equal" →
932        // base_name_unique false, compound name unique within the one URN.
933        // `add:i64_i64` is the only "add" → both flags true.
934        let exts = make_overloaded_extensions();
935
936        let r1 = exts.lookup_function(1).unwrap();
937        assert_eq!(r1.name.full(), "equal:any_any");
938        assert!(!r1.base_name_unique, "two 'equal' overloads");
939        assert!(r1.name_unique, "compound name 'equal:any_any' is unique");
940
941        let r2 = exts.lookup_function(2).unwrap();
942        assert_eq!(r2.name.full(), "equal:str_str");
943        assert!(!r2.base_name_unique);
944        assert!(r2.name_unique);
945
946        let r3 = exts.lookup_function(3).unwrap();
947        assert_eq!(r3.name.full(), "add:i64_i64");
948        assert!(r3.base_name_unique, "only one 'add' overload");
949        assert!(r3.name_unique, "compound name appears only once");
950    }
951
952    #[test]
953    fn test_lookup_function_missing_anchor() {
954        let exts = SimpleExtensions::new();
955        assert!(exts.lookup_function(99).is_err());
956    }
957
958    #[test]
959    fn test_lookup_function_plain_name_overloaded_across_urns() {
960        // Same plain name in two URNs → base_name_unique false, name_unique false.
961        let urns = vec![new_urn(1, "urn1"), new_urn(2, "urn2")];
962        let extensions = vec![
963            new_ext_fn(1, 1, "duplicated"),
964            new_ext_fn(2, 2, "duplicated"),
965        ];
966        let exts = unwrap_new_extensions(&urns, &extensions);
967
968        let r = exts.lookup_function(1).unwrap();
969        assert!(!r.base_name_unique);
970        assert!(!r.name_unique);
971    }
972
973    #[test]
974    fn test_lookup_function_different_base_names_each_unique() {
975        // `equal:any_any` and `like:str_str` have distinct base names → each unique.
976        let urns = vec![new_urn(1, "urn1")];
977        let extensions = vec![
978            new_ext_fn(1, 1, "equal:any_any"),
979            new_ext_fn(2, 1, "like:str_str"),
980        ];
981        let exts = unwrap_new_extensions(&urns, &extensions);
982
983        assert!(exts.lookup_function(1).unwrap().base_name_unique);
984        assert!(exts.lookup_function(2).unwrap().base_name_unique);
985    }
986
987    // ---- Tests for resolve_function ----
988
989    fn make_resolve_extensions() -> SimpleExtensions {
990        let urns = vec![new_urn(1, "test_urn")];
991        let extensions = vec![
992            new_ext_fn(1, 1, "equal:any_any"),
993            new_ext_fn(2, 1, "equal:str_str"),
994            new_ext_fn(3, 1, "add:i64_i64"),
995            new_ext_fn(4, 1, "add:"),
996        ];
997        unwrap_new_extensions(&urns, &extensions)
998    }
999
1000    #[test]
1001    fn test_resolve_function_with_anchor() {
1002        // Explicit anchor: exact compound name and mismatch errors.
1003        let exts = make_resolve_extensions();
1004
1005        // Exact compound name matches stored name → resolves.
1006        assert_eq!(
1007            exts.resolve_function("equal:any_any", Some(1))
1008                .unwrap()
1009                .anchor,
1010            1
1011        );
1012
1013        // Simple (no-sig) form matches any stored name with the same base.
1014        assert_eq!(exts.resolve_function("add", Some(3)).unwrap().anchor, 3);
1015        assert_eq!(exts.resolve_function("add", Some(4)).unwrap().anchor, 4);
1016
1017        // Full form requires exact match — "add:" does not match anchor 3 (stored "add:i64_i64").
1018        assert!(exts.resolve_function("add:", Some(3)).is_err());
1019
1020        // "add" does not match stored "equal:any_any" (different base) → error.
1021        assert!(exts.resolve_function("add", Some(1)).is_err());
1022
1023        // Same base name, different overload → error.
1024        assert!(exts.resolve_function("equal:any_any", Some(2)).is_err());
1025    }
1026
1027    #[test]
1028    fn test_resolve_function_without_anchor() {
1029        // No anchor: exact compound name resolution.
1030        let exts = make_resolve_extensions();
1031
1032        assert_eq!(
1033            exts.resolve_function("equal:any_any", None).unwrap().anchor,
1034            1
1035        );
1036        assert_eq!(
1037            exts.resolve_function("equal:str_str", None).unwrap().anchor,
1038            2
1039        );
1040    }
1041
1042    #[test]
1043    fn test_resolve_function_without_anchor_full_sig() {
1044        // Full form (has colon) resolves by exact match only — no fallback.
1045        let exts = make_resolve_extensions();
1046
1047        assert_eq!(exts.resolve_function("add:", None).unwrap().anchor, 4);
1048        assert_eq!(
1049            exts.resolve_function("add:i64_i64", None).unwrap().anchor,
1050            3
1051        );
1052
1053        // "equal:" is not registered → error (no fallback to base-name search).
1054        assert!(exts.resolve_function("equal:", None).is_err());
1055    }
1056
1057    #[test]
1058    fn test_resolve_function_without_anchor_no_sig() {
1059        // No-signature form (no colon) without anchor: base-name search.
1060        let exts = make_resolve_extensions();
1061
1062        // Ambiguous base name → error.
1063        assert!(exts.resolve_function("add", None).is_err());
1064        assert!(exts.resolve_function("equal", None).is_err());
1065    }
1066
1067    #[test]
1068    fn test_resolve_function_plain_stored_name() {
1069        // Functions stored without a signature still resolve by their plain name.
1070        let urns = vec![new_urn(1, "urn")];
1071        let extensions = vec![new_ext_fn(10, 1, "coalesce")];
1072        let exts = unwrap_new_extensions(&urns, &extensions);
1073        assert_eq!(exts.resolve_function("coalesce", None).unwrap().anchor, 10);
1074    }
1075
1076    #[test]
1077    fn test_resolve_function_not_found() {
1078        let exts = SimpleExtensions::new();
1079        assert!(exts.resolve_function("nonexistent", None).is_err());
1080    }
1081
1082    #[test]
1083    fn test_compound_name_roundtrip_in_extensions_section() {
1084        // Verify that compound names survive a write → parse roundtrip through
1085        // the Extensions section text format.
1086        let urns = vec![new_urn(1, "substrait:functions_comparison")];
1087        let extensions = vec![
1088            new_ext_fn(1, 1, "equal:any_any"),
1089            new_ext_fn(2, 1, "equal:str_str"),
1090        ];
1091        let exts = unwrap_new_extensions(&urns, &extensions);
1092
1093        let text = exts.to_string("  ");
1094        assert!(
1095            text.contains("equal:any_any"),
1096            "compound name must appear in output"
1097        );
1098        assert!(
1099            text.contains("equal:str_str"),
1100            "compound name must appear in output"
1101        );
1102    }
1103}