Skip to main content

substrait_explain/parser/
extensions.rs

1use std::fmt;
2use std::str::FromStr;
3
4use substrait::proto::{Expression, Type};
5use thiserror::Error;
6
7use super::{
8    MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
9    unwrap_single_pair,
10};
11use crate::extensions::simple::{self, ExtensionKind};
12use crate::extensions::{
13    AddendumKind, ExtensionArgs, ExtensionColumn, ExtensionValue, InsertError, SimpleExtensions,
14    TupleValue,
15};
16use crate::parser::structural::IndentedLine;
17
18#[derive(Debug, Clone, Error)]
19pub enum ExtensionParseError {
20    #[error("Unexpected line, expected {0}")]
21    UnexpectedLine(ExtensionParserState),
22    #[error("Error adding extension: {0}")]
23    ExtensionError(#[from] InsertError),
24    #[error("Error parsing message: {0}")]
25    Message(#[from] super::MessageParseError),
26}
27
28/// The state of the extension parser - tracking what section of extension
29/// parsing we are in.
30#[derive(Clone, Copy, Debug, PartialEq, Eq)]
31pub enum ExtensionParserState {
32    // The extensions section, after parsing the 'Extensions:' header, before
33    // parsing any subsection headers.
34    Extensions,
35    // The extension URNs section, after parsing the 'URNs:' subsection header,
36    // and any URNs so far.
37    ExtensionUrns,
38    // In a subsection, after parsing the subsection header, and any
39    // declarations so far.
40    ExtensionDeclarations(ExtensionKind),
41}
42
43impl fmt::Display for ExtensionParserState {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
47            ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
48            ExtensionParserState::ExtensionDeclarations(kind) => {
49                write!(f, "Extension Declaration for {kind}")
50            }
51        }
52    }
53}
54
55/// The parser for the extension section of the Substrait file format.
56///
57/// This is responsible for parsing the extension section of the file, which
58/// contains the extension URNs and declarations. Note that this parser does not
59/// parse the header; otherwise, this is symmetric with the
60/// SimpleExtensions::write method.
61#[derive(Debug)]
62pub struct ExtensionParser {
63    state: ExtensionParserState,
64    extensions: SimpleExtensions,
65}
66
67impl Default for ExtensionParser {
68    fn default() -> Self {
69        Self {
70            state: ExtensionParserState::Extensions,
71            extensions: SimpleExtensions::new(),
72        }
73    }
74}
75
76impl ExtensionParser {
77    pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
78        if line.1.is_empty() {
79            // Blank lines are allowed between subsections, so if we see
80            // one, we revert out of the subsection.
81            self.state = ExtensionParserState::Extensions;
82            return Ok(());
83        }
84
85        match self.state {
86            ExtensionParserState::Extensions => self.parse_subsection(line),
87            ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
88            ExtensionParserState::ExtensionDeclarations(extension_kind) => {
89                self.parse_declarations(line, extension_kind)
90            }
91        }
92    }
93
94    fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
95        match line {
96            IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
97                self.state = ExtensionParserState::ExtensionUrns;
98                Ok(())
99            }
100            IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
101                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
102                Ok(())
103            }
104            IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
105                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
106                Ok(())
107            }
108            IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
109                self.state =
110                    ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
111                Ok(())
112            }
113            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
114        }
115    }
116
117    fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
118        match line {
119            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
120            IndentedLine(1, s) => {
121                let urn =
122                    URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
123                self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
124                Ok(())
125            }
126            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
127        }
128    }
129
130    fn parse_declarations(
131        &mut self,
132        line: IndentedLine,
133        extension_kind: ExtensionKind,
134    ) -> Result<(), ExtensionParseError> {
135        match line {
136            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
137            IndentedLine(1, s) => {
138                let decl = SimpleExtensionDeclaration::from_str(s)?;
139                self.extensions.add_extension(
140                    extension_kind,
141                    decl.urn_anchor,
142                    decl.anchor,
143                    decl.name,
144                )?;
145                Ok(())
146            }
147            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
148        }
149    }
150
151    pub fn extensions(&self) -> &SimpleExtensions {
152        &self.extensions
153    }
154
155    pub fn state(&self) -> ExtensionParserState {
156        self.state
157    }
158}
159
/// One parsed line of the URNs subsection, e.g. `@1: /my/urn1`.
#[derive(Clone, Debug, PartialEq)]
pub struct URNExtensionDeclaration {
    /// The number following `@` (`1` in the example).
    pub anchor: u32,
    /// The URN text following the colon (`/my/urn1` in the example).
    pub urn: String,
}
165
/// One parsed line of a declaration subsection, e.g. `#5 @2: my_function_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// The number following `#` (`5` in the example).
    pub anchor: u32,
    /// The number following `@` (`2` in the example).
    pub urn_anchor: u32,
    /// The declared name; may carry a signature suffix such as
    /// `equal:any_any`.
    pub name: String,
}
172
173impl ParsePair for URNExtensionDeclaration {
174    fn rule() -> Rule {
175        Rule::extension_urn_declaration
176    }
177
178    fn message() -> &'static str {
179        "URNExtensionDeclaration"
180    }
181
182    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
183        assert_eq!(pair.as_rule(), Self::rule());
184
185        let mut iter = RuleIter::from(pair.into_inner());
186        let anchor_pair = iter.pop(Rule::urn_anchor);
187        let anchor = unwrap_single_pair(anchor_pair)
188            .as_str()
189            .parse::<u32>()
190            .unwrap();
191        let urn = iter.pop(Rule::urn).as_str().to_string();
192        iter.done();
193
194        URNExtensionDeclaration { anchor, urn }
195    }
196}
197
198impl FromStr for URNExtensionDeclaration {
199    type Err = super::MessageParseError;
200
201    fn from_str(s: &str) -> Result<Self, Self::Err> {
202        Self::parse_str(s)
203    }
204}
205
206impl ParsePair for SimpleExtensionDeclaration {
207    fn rule() -> Rule {
208        Rule::simple_extension
209    }
210
211    fn message() -> &'static str {
212        "SimpleExtensionDeclaration"
213    }
214
215    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
216        assert_eq!(pair.as_rule(), Self::rule());
217        let mut iter = RuleIter::from(pair.into_inner());
218        let anchor_pair = iter.pop(Rule::anchor);
219        let anchor = unwrap_single_pair(anchor_pair)
220            .as_str()
221            .parse::<u32>()
222            .unwrap();
223        let urn_anchor_pair = iter.pop(Rule::urn_anchor);
224        let urn_anchor = unwrap_single_pair(urn_anchor_pair)
225            .as_str()
226            .parse::<u32>()
227            .unwrap();
228        // compound_name handles both plain names ("add") and compound names with signatures ("equal:any_any").
229        let name_pair = iter.pop(Rule::compound_name);
230        let name = name_pair.as_str().to_string();
231        iter.done();
232
233        SimpleExtensionDeclaration {
234            anchor,
235            urn_anchor,
236            name,
237        }
238    }
239}
240
241impl FromStr for SimpleExtensionDeclaration {
242    type Err = super::MessageParseError;
243
244    fn from_str(s: &str) -> Result<Self, Self::Err> {
245        Self::parse_str(s)
246    }
247}
248
249// Extension relation parsing implementations
250// These were moved from extensions/registry.rs to maintain clean architecture
251
252use crate::extensions::any::Any;
253use crate::parser::expressions::{FieldIndex, Name};
254use crate::textify::expressions::Reference;
255
256impl ScopedParsePair for ExtensionValue {
257    fn rule() -> Rule {
258        Rule::extension_argument
259    }
260
261    fn message() -> &'static str {
262        "ExtensionValue"
263    }
264
265    fn parse_pair(
266        extensions: &SimpleExtensions,
267        pair: pest::iterators::Pair<Rule>,
268    ) -> Result<Self, MessageParseError> {
269        assert_eq!(pair.as_rule(), Self::rule());
270
271        let inner = unwrap_single_pair(pair); // Extract the actual content
272
273        Ok(match inner.as_rule() {
274            Rule::enum_value => {
275                // Strip leading '&' and store the identifier
276                let s = inner.as_str().trim_start_matches('&').to_string();
277                ExtensionValue::Enum(s)
278            }
279            Rule::reference => {
280                // Reuse the existing FieldIndex parser, then extract the i32
281                let field_index = FieldIndex::parse_pair(inner);
282                ExtensionValue::from(Reference(field_index.0))
283            }
284            Rule::untyped_literal => {
285                // Literal can contain integer, float, boolean, or string_literal
286                let value_pair = unwrap_single_pair(inner);
287                match value_pair.as_rule() {
288                    Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
289                    Rule::integer => {
290                        ExtensionValue::Integer(value_pair.as_str().parse::<i64>().unwrap())
291                    }
292                    Rule::float => {
293                        ExtensionValue::Float(value_pair.as_str().parse::<f64>().unwrap())
294                    }
295                    Rule::boolean => ExtensionValue::Boolean(value_pair.as_str() == "true"),
296                    _ => panic!(
297                        "Unexpected extension scalar literal type: {:?}",
298                        value_pair.as_rule()
299                    ),
300                }
301            }
302            Rule::tuple => {
303                let tv = inner
304                    .into_inner()
305                    .map(|pair| ExtensionValue::parse_pair(extensions, pair))
306                    .collect::<Result<TupleValue, MessageParseError>>()?;
307                ExtensionValue::Tuple(tv)
308            }
309            Rule::expression => {
310                let expr = Expression::parse_pair(extensions, inner)?;
311                ExtensionValue::from(expr)
312            }
313            _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
314        })
315    }
316}
317
318impl ScopedParsePair for ExtensionColumn {
319    fn rule() -> Rule {
320        Rule::extension_column
321    }
322
323    fn message() -> &'static str {
324        "ExtensionColumn"
325    }
326
327    fn parse_pair(
328        extensions: &SimpleExtensions,
329        pair: pest::iterators::Pair<Rule>,
330    ) -> Result<Self, MessageParseError> {
331        assert_eq!(pair.as_rule(), Self::rule());
332
333        let inner = unwrap_single_pair(pair); // Extract the actual content
334
335        Ok(match inner.as_rule() {
336            Rule::named_column => {
337                let mut iter = inner.into_inner();
338                let name_pair = iter.next().unwrap(); // Grammar guarantees type exists
339                let type_pair = iter.next().unwrap(); // Grammar guarantees type exists
340
341                let name = Name::parse_pair(name_pair).0.to_string(); // Reuse existing Name parser
342                let ty = Type::parse_pair(extensions, type_pair)?;
343
344                ExtensionColumn::Named { name, r#type: ty }
345            }
346            Rule::reference => {
347                // Reuse the existing FieldIndex parser, then extract the i32
348                let field_index = FieldIndex::parse_pair(inner);
349                ExtensionColumn::Expr(Reference(field_index.0).into())
350            }
351            Rule::expression => {
352                let expr = Expression::parse_pair(extensions, inner)?;
353                ExtensionColumn::Expr(expr.into())
354            }
355            _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
356        })
357    }
358}
359
/// The flavour of extension relation, as selected by the prefix used in the
/// text syntax (`ExtensionLeaf`, `ExtensionSingle`, or `ExtensionMulti`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExtensionRelationKind {
    Leaf,
    Single,
    Multi,
}

impl FromStr for ExtensionRelationKind {
    type Err = String;

    /// Maps a text-syntax prefix to its kind. The comparison is exact and
    /// case-sensitive; any other input yields an error message naming it.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const PREFIXES: [(&str, ExtensionRelationKind); 3] = [
            ("ExtensionLeaf", ExtensionRelationKind::Leaf),
            ("ExtensionSingle", ExtensionRelationKind::Single),
            ("ExtensionMulti", ExtensionRelationKind::Multi),
        ];

        PREFIXES
            .iter()
            .find(|(prefix, _)| *prefix == s)
            .map(|&(_, kind)| kind)
            .ok_or_else(|| format!("Unknown extension relation type: {s}"))
    }
}
381
382impl ExtensionRelationKind {
383    pub(crate) fn validate_child_count(self, child_count: usize) -> Result<(), String> {
384        match self {
385            ExtensionRelationKind::Leaf => {
386                if child_count == 0 {
387                    Ok(())
388                } else {
389                    Err(format!(
390                        "ExtensionLeaf should have no input children, got {child_count}"
391                    ))
392                }
393            }
394            ExtensionRelationKind::Single => {
395                if child_count == 1 {
396                    Ok(())
397                } else {
398                    Err(format!(
399                        "ExtensionSingle should have exactly 1 input child, got {child_count}"
400                    ))
401                }
402            }
403            ExtensionRelationKind::Multi => Ok(()),
404        }
405    }
406
407    /// Create appropriate relation structure from extension detail and children.
408    pub(crate) fn create_rel(
409        self,
410        detail: Option<Any>,
411        children: Vec<substrait::proto::Rel>,
412    ) -> substrait::proto::Rel {
413        use substrait::proto::rel::RelType;
414        use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
415
416        let rel_type = match self {
417            ExtensionRelationKind::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
418                common: None,
419                detail: detail.map(Into::into),
420            }),
421            ExtensionRelationKind::Single => {
422                let input = children.into_iter().next();
423                RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
424                    common: None,
425                    detail: detail.map(Into::into),
426                    input: input.map(Box::new),
427                }))
428            }
429            ExtensionRelationKind::Multi => RelType::ExtensionMulti(ExtensionMultiRel {
430                common: None,
431                detail: detail.map(Into::into),
432                inputs: children,
433            }),
434        };
435
436        substrait::proto::Rel {
437            rel_type: Some(rel_type),
438        }
439    }
440}
441
442/// Fully parsed extension invocation, including the user-supplied name and the
443/// structured argument payload.
444#[derive(Debug, Clone)]
445pub(crate) struct ExtensionInvocation {
446    pub(crate) relation_kind: ExtensionRelationKind,
447    pub(crate) name: String,
448    pub(crate) args: ExtensionArgs,
449}
450
451impl ScopedParsePair for ExtensionInvocation {
452    fn rule() -> Rule {
453        Rule::extension_relation
454    }
455
456    fn message() -> &'static str {
457        "ExtensionInvocation"
458    }
459
460    fn parse_pair(
461        extensions: &SimpleExtensions,
462        pair: pest::iterators::Pair<Rule>,
463    ) -> Result<Self, MessageParseError> {
464        assert_eq!(pair.as_rule(), Self::rule());
465
466        let mut iter = pair.into_inner();
467
468        // Parse extension name to determine relation type and custom name
469        let extension_name_pair = iter.next().unwrap(); // Grammar guarantees extension_name exists
470        let full_extension_name = extension_name_pair.as_str();
471
472        // Extract the relation type and custom name from the extension name
473        // (e.g., "ExtensionLeaf:ParquetScan" -> "ExtensionLeaf" and "ParquetScan")
474        let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
475            let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
476            (parts[0], parts[1].to_string())
477        } else {
478            (full_extension_name, "UnknownExtension".to_string())
479        };
480
481        let relation_kind = ExtensionRelationKind::from_str(relation_type_str).unwrap();
482        let mut args = ExtensionArgs::default();
483
484        // Parse optional arguments
485        let ext_arguments = iter.next().unwrap();
486        match ext_arguments.as_rule() {
487            Rule::arguments => {
488                arguments_rule_parsing(extensions, ext_arguments, &mut args)?;
489            }
490            r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
491        }
492
493        // parse optional output columns
494        let extension_columns = iter.next();
495        if let Some(value) = extension_columns {
496            match value.as_rule() {
497                Rule::extension_columns => {
498                    for col_pair in value.into_inner() {
499                        if col_pair.as_rule() == Rule::extension_column {
500                            let column = ExtensionColumn::parse_pair(extensions, col_pair)?;
501                            args.output_columns.push(column);
502                        }
503                    }
504                }
505                r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
506            }
507        }
508
509        Ok(ExtensionInvocation {
510            relation_kind,
511            name: custom_name,
512            args,
513        })
514    }
515}
516
517/// A parsed `+` addendum line.
518#[derive(Debug, Clone)]
519pub(crate) struct AddendumInvocation {
520    pub(crate) kind: AddendumKind,
521    pub(crate) name: String,
522    pub(crate) args: ExtensionArgs,
523}
524
525impl ScopedParsePair for AddendumInvocation {
526    fn rule() -> Rule {
527        Rule::addendum
528    }
529
530    fn message() -> &'static str {
531        "AddendumInvocation"
532    }
533
534    fn parse_pair(
535        extensions: &SimpleExtensions,
536        pair: pest::iterators::Pair<Rule>,
537    ) -> Result<Self, MessageParseError> {
538        assert_eq!(pair.as_rule(), Self::rule());
539
540        let mut iter = pair.into_inner();
541
542        // First token: addendum_type - grammar guarantees a known addendum prefix.
543        let type_pair = iter.next().unwrap(); // Grammar guarantees addendum_type exists
544        let kind = match type_pair.as_str() {
545            "Enh" => AddendumKind::Enhancement,
546            "Opt" => AddendumKind::Optimization,
547            "Ext" => AddendumKind::ExtensionTable,
548            other => unreachable!("Unexpected addendum_type: {other}"),
549        };
550
551        // Second token: name
552        let name_pair = iter.next().unwrap();
553        let name = Name::parse_pair(name_pair).0.to_string();
554
555        // Remaining token: arguments — grammar guarantees it is always present.
556        let mut args = ExtensionArgs::default();
557
558        let arguments_pair = iter.next().unwrap();
559        match arguments_pair.as_rule() {
560            Rule::arguments => {
561                arguments_rule_parsing(extensions, arguments_pair, &mut args)?;
562            }
563            r => unreachable!("Unexpected rule in AddendumInvocation args: {r:?}"),
564        }
565
566        Ok(AddendumInvocation { kind, name, args })
567    }
568}
569
570fn arguments_rule_parsing(
571    extensions: &SimpleExtensions,
572    inner_pair: pest::iterators::Pair<'_, Rule>,
573    args: &mut ExtensionArgs,
574) -> Result<(), MessageParseError> {
575    for arg in inner_pair.into_inner() {
576        match arg.as_rule() {
577            Rule::extension_arguments => {
578                for arg_pair in arg.into_inner() {
579                    assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
580                    args.push(ExtensionValue::parse_pair(extensions, arg_pair)?);
581                }
582            }
583            Rule::extension_named_arguments => {
584                for arg_pair in arg.into_inner() {
585                    assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
586                    let mut arg_iter = arg_pair.into_inner();
587                    let name_p = arg_iter.next().unwrap();
588                    let value_p = arg_iter.next().unwrap();
589                    let key = Name::parse_pair(name_p).0.to_string();
590                    let val = ExtensionValue::parse_pair(extensions, value_p)?;
591                    args.insert(key, val);
592                }
593            }
594            Rule::empty => {}
595            r => unreachable!("Unexpected rule in extension args: {r:?}"),
596        }
597    }
598    Ok(())
599}
600
#[cfg(test)]
mod tests {
    use substrait::proto;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::literal::LiteralType;

    use super::*;
    use crate::OutputOptions;
    use crate::extensions::{Expr, ExtensionValue};
    use crate::fixtures::TestContext;
    use crate::parser::{Parser, ScopedParse};

    /// Parses `text` as an extension argument with no extensions registered,
    /// panicking on failure.
    fn parse_extension_value(text: &str) -> ExtensionValue {
        ExtensionValue::parse(&SimpleExtensions::default(), text).unwrap()
    }

    /// Textifies `value` with a default test context, asserting no errors.
    fn ctx_text(value: &Expr) -> String {
        TestContext::new().textify_no_errors(value)
    }

    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    #[test]
    fn test_parse_simple_extension_declaration() {
        // Declarations have the form `#anchor @urn_anchor: name`.
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        // Test with a different name format, e.g. with underscores and numbers
        let line2 = "#10  @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    #[test]
    fn test_parse_urn_extension_declaration_str() {
        // Goes through the `FromStr` impl, which is the entry point the
        // extension parser itself uses for URN lines.
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");

        // `str::parse` must agree with the direct call.
        let parsed: URNExtensionDeclaration = line.parse().unwrap();
        assert_eq!(parsed, urn);
    }

    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
  @  1: /urn/common
  @  2: /urn/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#
        .trim_start();

        // Parse the input using the structural parser
        let plan = Parser::parse(input).unwrap();

        // Verify the plan has the expected extensions
        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        // Convert the plan extensions back to SimpleExtensions
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        // Convert back to string
        let output = extensions.to_string("  ");

        // The output should match the input
        assert_eq!(output, input);
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        // A function name that includes a Substrait signature suffix
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
  @  1: extension:io.substrait:functions_string
  @  2: extension:io.substrait:functions_comparison
Functions:
  #  1 @  2: equal:any_any
  #  2 @  1: regexp_match_substring:str_str
  #  3 @  1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        // Compound names must survive the roundtrip
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        // Text output must reproduce the input exactly
        assert_eq!(extensions.to_string("  "), input);
    }

    #[test]
    fn test_tuple_mixed_types_parses() {
        // tuple has overlapping grammar syntax with expression.
        let val = parse_extension_value("(&HASH, 8, 'hello')");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(items.len(), 3);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(items[1]).unwrap(), 8);
        assert_eq!(<&str>::try_from(items[2]).unwrap(), "hello");
    }

    #[test]
    fn test_empty_tuple_parses() {
        let val = parse_extension_value("()");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert!(items.is_empty());
    }

    #[test]
    fn test_nested_tuple_parses() {
        let val = parse_extension_value("((&HASH, &RANGE), 8)");
        let ExtensionValue::Tuple(outer) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(outer.len(), 2);
        let ExtensionValue::Tuple(inner) = outer.iter().next().unwrap() else {
            panic!("expected inner Tuple");
        };
        assert_eq!(inner.len(), 2);
        assert!(matches!(inner.iter().next().unwrap(), ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(outer.iter().nth(1).unwrap()).unwrap(), 8);
    }

    #[test]
    fn test_tuple_in_addendum_parses() {
        let inv = AddendumInvocation::parse(
            &SimpleExtensions::default(),
            "+ Enh:Foo[(&HASH, &RANGE), count=8]",
        )
        .unwrap();
        assert_eq!(inv.kind, AddendumKind::Enhancement);
        assert_eq!(inv.name, "Foo");
        assert_eq!(inv.args.positional.len(), 1);
        let ExtensionValue::Tuple(items) = &inv.args.positional[0] else {
            panic!("expected Tuple positional arg");
        };
        assert_eq!(items.len(), 2);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(items[1], ExtensionValue::Enum(s) if s == "RANGE"));
        assert_eq!(inv.args.named.len(), 1);
    }

    #[test]
    fn extension_relation_kind_parses_text_prefixes() {
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionLeaf").unwrap(),
            ExtensionRelationKind::Leaf
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionSingle").unwrap(),
            ExtensionRelationKind::Single
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionMulti").unwrap(),
            ExtensionRelationKind::Multi
        );
    }

    #[test]
    fn extension_multi_allows_any_child_count() {
        assert!(ExtensionRelationKind::Multi.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(1).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert!(
            ExtensionRelationKind::Single
                .validate_child_count(0)
                .is_err()
        );
        assert!(
            ExtensionRelationKind::Single
                .validate_child_count(1)
                .is_ok()
        );
        assert!(
            ExtensionRelationKind::Single
                .validate_child_count(2)
                .is_err()
        );
    }

    #[test]
    fn extension_leaf_rejects_nonzero_child_counts() {
        assert!(ExtensionRelationKind::Leaf.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Leaf.validate_child_count(1).is_err());
        assert!(ExtensionRelationKind::Leaf.validate_child_count(2).is_err());
    }

    #[test]
    fn test_tuple_textify_roundtrip() {
        let ctx = TestContext::new();
        for text in &[
            "(&HASH, &RANGE)",
            "(&HASH, 8, 'hello')",
            "()",
            "(&HASH,)",
            "((&HASH, &RANGE), 8)",
        ] {
            let val = parse_extension_value(text);
            let rendered = ctx.textify_no_errors(&val);
            assert_eq!(&rendered, text, "roundtrip failed for {text}");
        }
    }

    #[test]
    fn test_literal_expression_value_textifies_to_canonical_literal() {
        let expr = proto::Expression {
            rex_type: Some(RexType::Literal(proto::expression::Literal {
                literal_type: Some(LiteralType::I64(42)),
                nullable: false,
                type_variation_reference: 0,
            })),
        };
        let value = ExtensionValue::from(expr.clone());
        let ctx = TestContext::new();

        let rendered = ctx.textify_no_errors(&value);
        assert_eq!(rendered, "42");

        let parsed = parse_extension_value(&rendered);
        let parsed_expr = Expr::try_from(&parsed).unwrap();
        assert_eq!(parsed_expr.as_proto(), &expr);
    }

    #[test]
    fn test_extension_scalar_literals_stay_scalar_in_verbose_output() {
        let ctx = TestContext::new().with_options(OutputOptions::verbose());

        let scalar = ExtensionValue::from(42_i64);
        assert_eq!(ctx.textify_no_errors(&scalar), "42");

        let expression = ExtensionValue::from(Expr::from(42_i64));
        assert_eq!(ctx.textify_no_errors(&expression), "42:i64");
    }

    #[test]
    fn test_typed_extension_literal_parses_as_expression() {
        let value = parse_extension_value("42:i16");
        assert!(i64::try_from(&value).is_err());

        let expr = Expr::try_from(&value).unwrap();
        assert_eq!(ctx_text(&expr), "42:i16");
    }
}