Skip to main content

substrait_explain/parser/
extensions.rs

1use std::fmt;
2use std::str::FromStr;
3
4use substrait::proto::{Expression, Type};
5use thiserror::Error;
6
7use super::{
8    MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
9    unwrap_single_pair,
10};
11use crate::extensions::simple::{self, ExtensionKind};
12use crate::extensions::{
13    AddendumKind, ExtensionArgs, ExtensionColumn, ExtensionValue, InsertError, SimpleExtensions,
14    TupleValue,
15};
16use crate::parser::structural::IndentedLine;
17
18#[derive(Debug, Clone, Error)]
19pub enum ExtensionParseError {
20    #[error("Unexpected line, expected {0}")]
21    UnexpectedLine(ExpectedExtensionLine),
22    #[error("Error adding extension: {0}")]
23    ExtensionError(#[from] InsertError),
24    #[error("Error parsing message: {0}")]
25    Message(#[from] super::MessageParseError),
26}
27
28/// The kind of extension-section line expected next.
29///
30/// `ExtensionParser` also uses this as its internal state, since each parser
31/// state corresponds directly to the next accepted line shape.
32#[derive(Clone, Copy, Debug, PartialEq, Eq)]
33pub enum ExpectedExtensionLine {
34    // The extensions section, after parsing the 'Extensions:' header, before
35    // parsing any subsection headers.
36    Extensions,
37    // The extension URNs section, after parsing the 'URNs:' subsection header,
38    // and any URNs so far.
39    ExtensionUrns,
40    // In a subsection, after parsing the subsection header, and any
41    // declarations so far.
42    ExtensionDeclarations(ExtensionKind),
43}
44
45impl fmt::Display for ExpectedExtensionLine {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            ExpectedExtensionLine::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
49            ExpectedExtensionLine::ExtensionUrns => write!(f, "Extension URNs"),
50            ExpectedExtensionLine::ExtensionDeclarations(kind) => {
51                write!(f, "Extension Declaration for {kind}")
52            }
53        }
54    }
55}
56
57/// The parser for the extension section of the Substrait file format.
58///
59/// This is responsible for parsing the extension section of the file, which
60/// contains the extension URNs and declarations. Note that this parser does not
61/// parse the header; otherwise, this is symmetric with the
62/// SimpleExtensions::write method.
63#[derive(Debug)]
64pub struct ExtensionParser {
65    state: ExpectedExtensionLine,
66    extensions: SimpleExtensions,
67}
68
69impl Default for ExtensionParser {
70    fn default() -> Self {
71        Self {
72            state: ExpectedExtensionLine::Extensions,
73            extensions: SimpleExtensions::new(),
74        }
75    }
76}
77
78impl ExtensionParser {
79    pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
80        if line.1.is_empty() {
81            // Blank lines are allowed between subsections, so if we see
82            // one, we revert out of the subsection.
83            self.state = ExpectedExtensionLine::Extensions;
84            return Ok(());
85        }
86
87        match self.state {
88            ExpectedExtensionLine::Extensions => self.parse_subsection(line),
89            ExpectedExtensionLine::ExtensionUrns => self.parse_extension_urns(line),
90            ExpectedExtensionLine::ExtensionDeclarations(extension_kind) => {
91                self.parse_declarations(line, extension_kind)
92            }
93        }
94    }
95
96    fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
97        match line {
98            IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
99                self.state = ExpectedExtensionLine::ExtensionUrns;
100                Ok(())
101            }
102            IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
103                self.state = ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::Function);
104                Ok(())
105            }
106            IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
107                self.state = ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::Type);
108                Ok(())
109            }
110            IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
111                self.state =
112                    ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::TypeVariation);
113                Ok(())
114            }
115            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
116        }
117    }
118
119    fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
120        match line {
121            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
122            IndentedLine(1, s) => {
123                let urn =
124                    URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
125                self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
126                Ok(())
127            }
128            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
129        }
130    }
131
132    fn parse_declarations(
133        &mut self,
134        line: IndentedLine,
135        extension_kind: ExtensionKind,
136    ) -> Result<(), ExtensionParseError> {
137        match line {
138            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
139            IndentedLine(1, s) => {
140                let decl = SimpleExtensionDeclaration::from_str(s)?;
141                self.extensions.add_extension(
142                    extension_kind,
143                    decl.urn_anchor,
144                    decl.anchor,
145                    decl.name,
146                )?;
147                Ok(())
148            }
149            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
150        }
151    }
152
153    pub fn extensions(&self) -> &SimpleExtensions {
154        &self.extensions
155    }
156
157    #[cfg(test)]
158    pub(crate) fn state(&self) -> ExpectedExtensionLine {
159        self.state
160    }
161}
162
/// One entry of the `URNs:` subsection, e.g. `@1: /my/urn1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URNExtensionDeclaration {
    /// The anchor number after `@`; declarations refer to it as their URN anchor.
    pub anchor: u32,
    /// The extension URN text.
    pub urn: String,
}

/// One entry of a functions/types/type-variations subsection, e.g.
/// `#5 @2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// The anchor number after `#` identifying this declaration.
    pub anchor: u32,
    /// The anchor (after `@`) of the URN this declaration belongs to.
    pub urn_anchor: u32,
    /// The declared name; may be compound, e.g. `equal:any_any`.
    pub name: String,
}
175
176impl ParsePair for URNExtensionDeclaration {
177    fn rule() -> Rule {
178        Rule::extension_urn_declaration
179    }
180
181    fn message() -> &'static str {
182        "URNExtensionDeclaration"
183    }
184
185    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
186        assert_eq!(pair.as_rule(), Self::rule());
187
188        let mut iter = RuleIter::from(pair.into_inner());
189        let anchor_pair = iter.pop(Rule::urn_anchor);
190        let anchor = unwrap_single_pair(anchor_pair)
191            .as_str()
192            .parse::<u32>()
193            .unwrap();
194        let urn = iter.pop(Rule::urn).as_str().to_string();
195        iter.done();
196
197        URNExtensionDeclaration { anchor, urn }
198    }
199}
200
201impl FromStr for URNExtensionDeclaration {
202    type Err = super::MessageParseError;
203
204    fn from_str(s: &str) -> Result<Self, Self::Err> {
205        Self::parse_str(s)
206    }
207}
208
209impl ParsePair for SimpleExtensionDeclaration {
210    fn rule() -> Rule {
211        Rule::simple_extension
212    }
213
214    fn message() -> &'static str {
215        "SimpleExtensionDeclaration"
216    }
217
218    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
219        assert_eq!(pair.as_rule(), Self::rule());
220        let mut iter = RuleIter::from(pair.into_inner());
221        let anchor_pair = iter.pop(Rule::anchor);
222        let anchor = unwrap_single_pair(anchor_pair)
223            .as_str()
224            .parse::<u32>()
225            .unwrap();
226        let urn_anchor_pair = iter.pop(Rule::urn_anchor);
227        let urn_anchor = unwrap_single_pair(urn_anchor_pair)
228            .as_str()
229            .parse::<u32>()
230            .unwrap();
231        // compound_name handles both plain names ("add") and compound names with signatures ("equal:any_any").
232        let name_pair = iter.pop(Rule::compound_name);
233        let name = name_pair.as_str().to_string();
234        iter.done();
235
236        SimpleExtensionDeclaration {
237            anchor,
238            urn_anchor,
239            name,
240        }
241    }
242}
243
244impl FromStr for SimpleExtensionDeclaration {
245    type Err = super::MessageParseError;
246
247    fn from_str(s: &str) -> Result<Self, Self::Err> {
248        Self::parse_str(s)
249    }
250}
251
// Parsing for extension relations, addenda, and their arguments and columns.
// These implementations were moved here from extensions/registry.rs to keep
// text parsing inside the parser module.
254
255use crate::extensions::any::Any;
256use crate::parser::expressions::{FieldIndex, Name};
257use crate::textify::expressions::Reference;
258
259impl ScopedParsePair for ExtensionValue {
260    fn rule() -> Rule {
261        Rule::extension_argument
262    }
263
264    fn message() -> &'static str {
265        "ExtensionValue"
266    }
267
268    fn parse_pair(
269        extensions: &SimpleExtensions,
270        pair: pest::iterators::Pair<Rule>,
271    ) -> Result<Self, MessageParseError> {
272        assert_eq!(pair.as_rule(), Self::rule());
273
274        let inner = unwrap_single_pair(pair); // Extract the actual content
275
276        Ok(match inner.as_rule() {
277            Rule::enum_value => {
278                // Strip leading '&' and store the identifier
279                let s = inner.as_str().trim_start_matches('&').to_string();
280                ExtensionValue::Enum(s)
281            }
282            Rule::reference => {
283                // Reuse the existing FieldIndex parser, then extract the i32
284                let field_index = FieldIndex::parse_pair(inner);
285                ExtensionValue::from(Reference(field_index.0))
286            }
287            Rule::untyped_literal => {
288                // Literal can contain integer, float, boolean, or string_literal
289                let value_pair = unwrap_single_pair(inner);
290                match value_pair.as_rule() {
291                    Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
292                    Rule::integer => {
293                        ExtensionValue::Integer(value_pair.as_str().parse::<i64>().unwrap())
294                    }
295                    Rule::float => {
296                        ExtensionValue::Float(value_pair.as_str().parse::<f64>().unwrap())
297                    }
298                    Rule::boolean => ExtensionValue::Boolean(value_pair.as_str() == "true"),
299                    _ => panic!(
300                        "Unexpected extension scalar literal type: {:?}",
301                        value_pair.as_rule()
302                    ),
303                }
304            }
305            Rule::tuple => {
306                let tv = inner
307                    .into_inner()
308                    .map(|pair| ExtensionValue::parse_pair(extensions, pair))
309                    .collect::<Result<TupleValue, MessageParseError>>()?;
310                ExtensionValue::Tuple(tv)
311            }
312            Rule::expression => {
313                let expr = Expression::parse_pair(extensions, inner)?;
314                ExtensionValue::from(expr)
315            }
316            _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
317        })
318    }
319}
320
321impl ScopedParsePair for ExtensionColumn {
322    fn rule() -> Rule {
323        Rule::extension_column
324    }
325
326    fn message() -> &'static str {
327        "ExtensionColumn"
328    }
329
330    fn parse_pair(
331        extensions: &SimpleExtensions,
332        pair: pest::iterators::Pair<Rule>,
333    ) -> Result<Self, MessageParseError> {
334        assert_eq!(pair.as_rule(), Self::rule());
335
336        let inner = unwrap_single_pair(pair); // Extract the actual content
337
338        Ok(match inner.as_rule() {
339            Rule::named_column => {
340                let mut iter = inner.into_inner();
341                let name_pair = iter.next().unwrap(); // Grammar guarantees type exists
342                let type_pair = iter.next().unwrap(); // Grammar guarantees type exists
343
344                let name = Name::parse_pair(name_pair).0.to_string(); // Reuse existing Name parser
345                let ty = Type::parse_pair(extensions, type_pair)?;
346
347                ExtensionColumn::Named { name, r#type: ty }
348            }
349            Rule::reference => {
350                // Reuse the existing FieldIndex parser, then extract the i32
351                let field_index = FieldIndex::parse_pair(inner);
352                ExtensionColumn::Expr(Reference(field_index.0).into())
353            }
354            Rule::expression => {
355                let expr = Expression::parse_pair(extensions, inner)?;
356                ExtensionColumn::Expr(expr.into())
357            }
358            _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
359        })
360    }
361}
362
/// Relation kind encoded by the text syntax prefix (`ExtensionLeaf`,
/// `ExtensionSingle`, or `ExtensionMulti`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExtensionRelationKind {
    /// `ExtensionLeaf`: takes no input relations.
    Leaf,
    /// `ExtensionSingle`: takes exactly one input relation.
    Single,
    /// `ExtensionMulti`: takes any number of input relations.
    Multi,
}

impl FromStr for ExtensionRelationKind {
    type Err = String;

    /// Maps a text prefix to its kind. Matching is exact and case-sensitive;
    /// any other input yields a descriptive error message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let kind = match s {
            "ExtensionLeaf" => Self::Leaf,
            "ExtensionSingle" => Self::Single,
            "ExtensionMulti" => Self::Multi,
            unknown => return Err(format!("Unknown extension relation type: {unknown}")),
        };
        Ok(kind)
    }
}
384
385impl ExtensionRelationKind {
386    pub(crate) fn validate_child_count(self, child_count: usize) -> Result<(), String> {
387        match self {
388            ExtensionRelationKind::Leaf => {
389                if child_count == 0 {
390                    Ok(())
391                } else {
392                    Err(format!(
393                        "ExtensionLeaf should have no input children, got {child_count}"
394                    ))
395                }
396            }
397            ExtensionRelationKind::Single => {
398                if child_count == 1 {
399                    Ok(())
400                } else {
401                    Err(format!(
402                        "ExtensionSingle should have exactly 1 input child, got {child_count}"
403                    ))
404                }
405            }
406            ExtensionRelationKind::Multi => Ok(()),
407        }
408    }
409
410    /// Create appropriate relation structure from extension detail and children.
411    pub(crate) fn create_rel(
412        self,
413        detail: Option<Any>,
414        children: Vec<substrait::proto::Rel>,
415    ) -> substrait::proto::Rel {
416        use substrait::proto::rel::RelType;
417        use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
418
419        let rel_type = match self {
420            ExtensionRelationKind::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
421                common: None,
422                detail: detail.map(Into::into),
423            }),
424            ExtensionRelationKind::Single => {
425                let input = children.into_iter().next();
426                RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
427                    common: None,
428                    detail: detail.map(Into::into),
429                    input: input.map(Box::new),
430                }))
431            }
432            ExtensionRelationKind::Multi => RelType::ExtensionMulti(ExtensionMultiRel {
433                common: None,
434                detail: detail.map(Into::into),
435                inputs: children,
436            }),
437        };
438
439        substrait::proto::Rel {
440            rel_type: Some(rel_type),
441        }
442    }
443}
444
445/// Fully parsed extension invocation, including the user-supplied name and the
446/// structured argument payload.
447#[derive(Debug, Clone)]
448pub(crate) struct ExtensionInvocation {
449    pub(crate) relation_kind: ExtensionRelationKind,
450    pub(crate) name: String,
451    pub(crate) args: ExtensionArgs,
452}
453
454impl ScopedParsePair for ExtensionInvocation {
455    fn rule() -> Rule {
456        Rule::extension_relation
457    }
458
459    fn message() -> &'static str {
460        "ExtensionInvocation"
461    }
462
463    fn parse_pair(
464        extensions: &SimpleExtensions,
465        pair: pest::iterators::Pair<Rule>,
466    ) -> Result<Self, MessageParseError> {
467        assert_eq!(pair.as_rule(), Self::rule());
468
469        let mut iter = pair.into_inner();
470
471        // Parse extension name to determine relation type and custom name
472        let extension_name_pair = iter.next().unwrap(); // Grammar guarantees extension_name exists
473        let full_extension_name = extension_name_pair.as_str();
474
475        // Extract the relation type and custom name from the extension name
476        // (e.g., "ExtensionLeaf:ParquetScan" -> "ExtensionLeaf" and "ParquetScan")
477        let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
478            let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
479            (parts[0], parts[1].to_string())
480        } else {
481            (full_extension_name, "UnknownExtension".to_string())
482        };
483
484        let relation_kind = ExtensionRelationKind::from_str(relation_type_str).unwrap();
485        let mut args = ExtensionArgs::default();
486
487        // Parse optional arguments
488        let ext_arguments = iter.next().unwrap();
489        match ext_arguments.as_rule() {
490            Rule::arguments => {
491                arguments_rule_parsing(extensions, ext_arguments, &mut args)?;
492            }
493            r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
494        }
495
496        // parse optional output columns
497        let extension_columns = iter.next();
498        if let Some(value) = extension_columns {
499            match value.as_rule() {
500                Rule::extension_columns => {
501                    for col_pair in value.into_inner() {
502                        if col_pair.as_rule() == Rule::extension_column {
503                            let column = ExtensionColumn::parse_pair(extensions, col_pair)?;
504                            args.output_columns.push(column);
505                        }
506                    }
507                }
508                r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
509            }
510        }
511
512        Ok(ExtensionInvocation {
513            relation_kind,
514            name: custom_name,
515            args,
516        })
517    }
518}
519
520/// A parsed `+` addendum line.
521#[derive(Debug, Clone)]
522pub(crate) struct AddendumInvocation {
523    pub(crate) kind: AddendumKind,
524    pub(crate) name: String,
525    pub(crate) args: ExtensionArgs,
526}
527
528impl ScopedParsePair for AddendumInvocation {
529    fn rule() -> Rule {
530        Rule::addendum
531    }
532
533    fn message() -> &'static str {
534        "AddendumInvocation"
535    }
536
537    fn parse_pair(
538        extensions: &SimpleExtensions,
539        pair: pest::iterators::Pair<Rule>,
540    ) -> Result<Self, MessageParseError> {
541        assert_eq!(pair.as_rule(), Self::rule());
542
543        let mut iter = pair.into_inner();
544
545        // First token: addendum_type - grammar guarantees a known addendum prefix.
546        let type_pair = iter.next().unwrap(); // Grammar guarantees addendum_type exists
547        let kind = match type_pair.as_str() {
548            "Enh" => AddendumKind::Enhancement,
549            "Opt" => AddendumKind::Optimization,
550            "Ext" => AddendumKind::ExtensionTable,
551            other => unreachable!("Unexpected addendum_type: {other}"),
552        };
553
554        // Second token: name
555        let name_pair = iter.next().unwrap();
556        let name = Name::parse_pair(name_pair).0.to_string();
557
558        // Remaining token: arguments — grammar guarantees it is always present.
559        let mut args = ExtensionArgs::default();
560
561        let arguments_pair = iter.next().unwrap();
562        match arguments_pair.as_rule() {
563            Rule::arguments => {
564                arguments_rule_parsing(extensions, arguments_pair, &mut args)?;
565            }
566            r => unreachable!("Unexpected rule in AddendumInvocation args: {r:?}"),
567        }
568
569        Ok(AddendumInvocation { kind, name, args })
570    }
571}
572
573fn arguments_rule_parsing(
574    extensions: &SimpleExtensions,
575    inner_pair: pest::iterators::Pair<'_, Rule>,
576    args: &mut ExtensionArgs,
577) -> Result<(), MessageParseError> {
578    for arg in inner_pair.into_inner() {
579        match arg.as_rule() {
580            Rule::extension_arguments => {
581                for arg_pair in arg.into_inner() {
582                    assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
583                    args.push(ExtensionValue::parse_pair(extensions, arg_pair)?);
584                }
585            }
586            Rule::extension_named_arguments => {
587                for arg_pair in arg.into_inner() {
588                    assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
589                    let mut arg_iter = arg_pair.into_inner();
590                    let name_p = arg_iter.next().unwrap();
591                    let value_p = arg_iter.next().unwrap();
592                    let key = Name::parse_pair(name_p).0.to_string();
593                    let val = ExtensionValue::parse_pair(extensions, value_p)?;
594                    args.insert(key, val);
595                }
596            }
597            Rule::empty => {}
598            r => unreachable!("Unexpected rule in extension args: {r:?}"),
599        }
600    }
601    Ok(())
602}
603
#[cfg(test)]
mod tests {
    use substrait::proto;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::literal::LiteralType;

    use super::*;
    use crate::OutputOptions;
    use crate::extensions::{Expr, ExtensionValue};
    use crate::fixtures::TestContext;
    use crate::parser::Parser;
    use crate::parser::common::test_support::ScopedParse;

    fn parse_extension_value(text: &str) -> ExtensionValue {
        ExtensionValue::parse(&SimpleExtensions::default(), text).unwrap()
    }

    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    #[test]
    fn test_parse_simple_extension_declaration() {
        // Format: "#anchor @urn_anchor: name"
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        // Test with a different name format, e.g. with underscores and numbers
        let line2 = "#10  @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    #[test]
    fn test_parse_urn_extension_declaration_str() {
        // Exercises the `FromStr` impl rather than `parse_str`, including the
        // padded anchor form produced by the writer.
        let urn = URNExtensionDeclaration::from_str("@  2: /urn/specific_funcs").unwrap();
        assert_eq!(urn.anchor, 2);
        assert_eq!(urn.urn, "/urn/specific_funcs");

        let urn: URNExtensionDeclaration = "@1: /my/urn1".parse().unwrap();
        assert_eq!(
            urn,
            URNExtensionDeclaration {
                anchor: 1,
                urn: "/my/urn1".to_string(),
            }
        );
    }

    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
  @  1: /urn/common
  @  2: /urn/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#
        .trim_start();

        // Parse the input using the structural parser
        let plan = Parser::parse(input).unwrap();

        // Verify the plan has the expected extensions
        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        // Convert the plan extensions back to SimpleExtensions
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        // Convert back to string
        let output = extensions.to_string("  ");

        // The output should match the input
        assert_eq!(output, input);
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        // A function name that includes a Substrait signature suffix
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
  @  1: extension:io.substrait:functions_string
  @  2: extension:io.substrait:functions_comparison
Functions:
  #  1 @  2: equal:any_any
  #  2 @  1: regexp_match_substring:str_str
  #  3 @  1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        // Compound names must survive the roundtrip
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        // Text output must reproduce the input exactly
        assert_eq!(extensions.to_string("  "), input);
    }

    #[test]
    fn test_tuple_mixed_types_parses() {
        // tuple has overlapping grammar syntax with expression.
        let val = parse_extension_value("(&HASH, 8, 'hello')");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(items.len(), 3);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(items[1]).unwrap(), 8);
        assert_eq!(<&str>::try_from(items[2]).unwrap(), "hello");
    }

    #[test]
    fn test_empty_tuple_parses() {
        let val = parse_extension_value("()");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert!(items.is_empty());
    }

    #[test]
    fn test_nested_tuple_parses() {
        let val = parse_extension_value("((&HASH, &RANGE), 8)");
        let ExtensionValue::Tuple(outer) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(outer.len(), 2);
        let ExtensionValue::Tuple(inner) = outer.iter().next().unwrap() else {
            panic!("expected inner Tuple");
        };
        assert_eq!(inner.len(), 2);
        assert!(matches!(inner.iter().next().unwrap(), ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(outer.iter().nth(1).unwrap()).unwrap(), 8);
    }

    #[test]
    fn test_tuple_in_addendum_parses() {
        let inv = AddendumInvocation::parse(
            &SimpleExtensions::default(),
            "+ Enh:Foo[(&HASH, &RANGE), count=8]",
        )
        .unwrap();
        assert_eq!(inv.kind, AddendumKind::Enhancement);
        assert_eq!(inv.name, "Foo");
        assert_eq!(inv.args.positional.len(), 1);
        let ExtensionValue::Tuple(items) = &inv.args.positional[0] else {
            panic!("expected Tuple positional arg");
        };
        assert_eq!(items.len(), 2);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(items[1], ExtensionValue::Enum(s) if s == "RANGE"));
        assert_eq!(inv.args.named.len(), 1);
    }

    #[test]
    fn extension_relation_kind_parses_text_prefixes() {
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionLeaf").unwrap(),
            ExtensionRelationKind::Leaf
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionSingle").unwrap(),
            ExtensionRelationKind::Single
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionMulti").unwrap(),
            ExtensionRelationKind::Multi
        );
    }

    #[test]
    fn extension_relation_kind_rejects_unknown_prefix() {
        assert!(ExtensionRelationKind::from_str("ExtensionFoo").is_err());
        assert!(ExtensionRelationKind::from_str("extensionleaf").is_err());
    }

    #[test]
    fn extension_leaf_rejects_children() {
        assert!(ExtensionRelationKind::Leaf.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Leaf.validate_child_count(1).is_err());
    }

    #[test]
    fn extension_multi_allows_any_child_count() {
        assert!(ExtensionRelationKind::Multi.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(1).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert!(ExtensionRelationKind::Single.validate_child_count(1).is_ok());
        assert!(ExtensionRelationKind::Single.validate_child_count(0).is_err());
        assert!(ExtensionRelationKind::Single.validate_child_count(2).is_err());
    }

    #[test]
    fn test_tuple_textify_roundtrip() {
        let ctx = TestContext::new();
        for text in &[
            "(&HASH, &RANGE)",
            "(&HASH, 8, 'hello')",
            "()",
            "(&HASH,)",
            "((&HASH, &RANGE), 8)",
        ] {
            let val = parse_extension_value(text);
            let rendered = ctx.textify_no_errors(&val);
            assert_eq!(&rendered, text, "roundtrip failed for {text}");
        }
    }

    #[test]
    fn test_literal_expression_value_textifies_to_canonical_literal() {
        let expr = proto::Expression {
            rex_type: Some(RexType::Literal(proto::expression::Literal {
                literal_type: Some(LiteralType::I64(42)),
                nullable: false,
                type_variation_reference: 0,
            })),
        };
        let value = ExtensionValue::from(expr.clone());
        let ctx = TestContext::new();

        let rendered = ctx.textify_no_errors(&value);
        assert_eq!(rendered, "42");

        let parsed = parse_extension_value(&rendered);
        let parsed_expr = Expr::try_from(&parsed).unwrap();
        assert_eq!(parsed_expr.as_proto(), &expr);
    }

    #[test]
    fn test_extension_scalar_literals_stay_scalar_in_verbose_output() {
        let ctx = TestContext::new().with_options(OutputOptions::verbose());

        let scalar = ExtensionValue::from(42_i64);
        assert_eq!(ctx.textify_no_errors(&scalar), "42");

        let expression = ExtensionValue::from(Expr::from(42_i64));
        assert_eq!(ctx.textify_no_errors(&expression), "42:i64");
    }

    #[test]
    fn test_typed_extension_literal_parses_as_expression() {
        let value = parse_extension_value("42:i16");
        assert!(i64::try_from(&value).is_err());

        let expr = Expr::try_from(&value).unwrap();
        assert_eq!(ctx_text(&expr), "42:i16");
    }

    fn ctx_text(value: &Expr) -> String {
        TestContext::new().textify_no_errors(value)
    }
}