substrait_explain/parser/
extensions.rs

1use std::fmt;
2use std::str::FromStr;
3
4use thiserror::Error;
5
6use super::{ParsePair, Rule, RuleIter, unescape_string, unwrap_single_pair};
7use crate::extensions::registry::ExtensionType;
8use crate::extensions::simple::{self, ExtensionKind};
9use crate::extensions::{
10    ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, InsertError,
11    RawExpression, SimpleExtensions,
12};
13use crate::parser::structural::IndentedLine;
14
15#[derive(Debug, Clone, Error)]
16pub enum ExtensionParseError {
17    #[error("Unexpected line, expected {0}")]
18    UnexpectedLine(ExtensionParserState),
19    #[error("Error adding extension: {0}")]
20    ExtensionError(#[from] InsertError),
21    #[error("Error parsing message: {0}")]
22    Message(#[from] super::MessageParseError),
23}
24
25/// The state of the extension parser - tracking what section of extension
26/// parsing we are in.
27#[derive(Clone, Copy, Debug, PartialEq, Eq)]
28pub enum ExtensionParserState {
29    // The extensions section, after parsing the 'Extensions:' header, before
30    // parsing any subsection headers.
31    Extensions,
32    // The extension URNs section, after parsing the 'URNs:' subsection header,
33    // and any URNs so far.
34    ExtensionUrns,
35    // In a subsection, after parsing the subsection header, and any
36    // declarations so far.
37    ExtensionDeclarations(ExtensionKind),
38}
39
40impl fmt::Display for ExtensionParserState {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
44            ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
45            ExtensionParserState::ExtensionDeclarations(kind) => {
46                write!(f, "Extension Declaration for {kind}")
47            }
48        }
49    }
50}
51
52/// The parser for the extension section of the Substrait file format.
53///
54/// This is responsible for parsing the extension section of the file, which
55/// contains the extension URNs and declarations. Note that this parser does not
56/// parse the header; otherwise, this is symmetric with the
57/// SimpleExtensions::write method.
58#[derive(Debug)]
59pub struct ExtensionParser {
60    state: ExtensionParserState,
61    extensions: SimpleExtensions,
62}
63
64impl Default for ExtensionParser {
65    fn default() -> Self {
66        Self {
67            state: ExtensionParserState::Extensions,
68            extensions: SimpleExtensions::new(),
69        }
70    }
71}
72
73impl ExtensionParser {
74    pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
75        if line.1.is_empty() {
76            // Blank lines are allowed between subsections, so if we see
77            // one, we revert out of the subsection.
78            self.state = ExtensionParserState::Extensions;
79            return Ok(());
80        }
81
82        match self.state {
83            ExtensionParserState::Extensions => self.parse_subsection(line),
84            ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
85            ExtensionParserState::ExtensionDeclarations(extension_kind) => {
86                self.parse_declarations(line, extension_kind)
87            }
88        }
89    }
90
91    fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
92        match line {
93            IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
94                self.state = ExtensionParserState::ExtensionUrns;
95                Ok(())
96            }
97            IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
98                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
99                Ok(())
100            }
101            IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
102                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
103                Ok(())
104            }
105            IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
106                self.state =
107                    ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
108                Ok(())
109            }
110            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
111        }
112    }
113
114    fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
115        match line {
116            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
117            IndentedLine(1, s) => {
118                let urn =
119                    URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
120                self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
121                Ok(())
122            }
123            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
124        }
125    }
126
127    fn parse_declarations(
128        &mut self,
129        line: IndentedLine,
130        extension_kind: ExtensionKind,
131    ) -> Result<(), ExtensionParseError> {
132        match line {
133            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
134            IndentedLine(1, s) => {
135                let decl = SimpleExtensionDeclaration::from_str(s)?;
136                self.extensions.add_extension(
137                    extension_kind,
138                    decl.urn_anchor,
139                    decl.anchor,
140                    decl.name,
141                )?;
142                Ok(())
143            }
144            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
145        }
146    }
147
148    pub fn extensions(&self) -> &SimpleExtensions {
149        &self.extensions
150    }
151
152    pub fn state(&self) -> ExtensionParserState {
153        self.state
154    }
155}
156
/// One line of the URNs subsection, e.g. `@1: /my/urn1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URNExtensionDeclaration {
    /// Numeric anchor assigned to this URN (the number after `@`).
    pub anchor: u32,
    /// The URN text following the colon, e.g. `/my/urn1`.
    pub urn: String,
}
162
/// One line of a declarations subsection, e.g. `#5 @2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// Anchor of the declaration itself (the number after `#`).
    pub anchor: u32,
    /// Anchor of the URN that defines this extension (the number after `@`).
    pub urn_anchor: u32,
    /// Declared name, possibly compound with a signature suffix
    /// (e.g. `equal:any_any`).
    pub name: String,
}
169
170impl ParsePair for URNExtensionDeclaration {
171    fn rule() -> Rule {
172        Rule::extension_urn_declaration
173    }
174
175    fn message() -> &'static str {
176        "URNExtensionDeclaration"
177    }
178
179    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
180        assert_eq!(pair.as_rule(), Self::rule());
181
182        let mut iter = RuleIter::from(pair.into_inner());
183        let anchor_pair = iter.pop(Rule::urn_anchor);
184        let anchor = unwrap_single_pair(anchor_pair)
185            .as_str()
186            .parse::<u32>()
187            .unwrap();
188        let urn = iter.pop(Rule::urn).as_str().to_string();
189        iter.done();
190
191        URNExtensionDeclaration { anchor, urn }
192    }
193}
194
195impl FromStr for URNExtensionDeclaration {
196    type Err = super::MessageParseError;
197
198    fn from_str(s: &str) -> Result<Self, Self::Err> {
199        Self::parse_str(s)
200    }
201}
202
203impl ParsePair for SimpleExtensionDeclaration {
204    fn rule() -> Rule {
205        Rule::simple_extension
206    }
207
208    fn message() -> &'static str {
209        "SimpleExtensionDeclaration"
210    }
211
212    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
213        assert_eq!(pair.as_rule(), Self::rule());
214        let mut iter = RuleIter::from(pair.into_inner());
215        let anchor_pair = iter.pop(Rule::anchor);
216        let anchor = unwrap_single_pair(anchor_pair)
217            .as_str()
218            .parse::<u32>()
219            .unwrap();
220        let urn_anchor_pair = iter.pop(Rule::urn_anchor);
221        let urn_anchor = unwrap_single_pair(urn_anchor_pair)
222            .as_str()
223            .parse::<u32>()
224            .unwrap();
225        // compound_name handles both plain names ("add") and compound names with signatures ("equal:any_any").
226        let name_pair = iter.pop(Rule::compound_name);
227        let name = name_pair.as_str().to_string();
228        iter.done();
229
230        SimpleExtensionDeclaration {
231            anchor,
232            urn_anchor,
233            name,
234        }
235    }
236}
237
238impl FromStr for SimpleExtensionDeclaration {
239    type Err = super::MessageParseError;
240
241    fn from_str(s: &str) -> Result<Self, Self::Err> {
242        Self::parse_str(s)
243    }
244}
245
// Parsing for extension relations: argument values, output columns, and
// extension invocations. This code lives in the parser module rather than in
// extensions/registry.rs, so that parsing stays separate from the registry.
248
249use crate::extensions::any::Any;
250use crate::parser::expressions::{FieldIndex, Name};
251
252impl ParsePair for ExtensionValue {
253    fn rule() -> Rule {
254        Rule::extension_argument
255    }
256
257    fn message() -> &'static str {
258        "ExtensionValue"
259    }
260
261    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
262        assert_eq!(pair.as_rule(), Self::rule());
263
264        let inner = unwrap_single_pair(pair); // Extract the actual content
265
266        match inner.as_rule() {
267            Rule::enum_value => {
268                // Strip leading '&' and store the identifier
269                let s = inner.as_str().trim_start_matches('&').to_string();
270                ExtensionValue::Enum(s)
271            }
272            Rule::reference => {
273                // Reuse the existing FieldIndex parser, then extract the i32
274                let field_index = FieldIndex::parse_pair(inner);
275                ExtensionValue::Reference(field_index.0)
276            }
277            Rule::literal => {
278                // Literal can contain integer, float, boolean, or string_literal
279                let mut literal_inner = inner.into_inner();
280                let value_pair = literal_inner.next().unwrap();
281                match value_pair.as_rule() {
282                    Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
283                    Rule::integer => {
284                        let int_val = value_pair.as_str().parse::<i64>().unwrap();
285                        ExtensionValue::Integer(int_val)
286                    }
287                    Rule::float => {
288                        let float_val = value_pair.as_str().parse::<f64>().unwrap();
289                        ExtensionValue::Float(float_val)
290                    }
291                    Rule::boolean => {
292                        let bool_val = value_pair.as_str() == "true";
293                        ExtensionValue::Boolean(bool_val)
294                    }
295                    _ => panic!("Unexpected literal value type: {:?}", value_pair.as_rule()),
296                }
297            }
298            Rule::string_literal => ExtensionValue::String(unescape_string(inner)),
299            Rule::integer => {
300                // Direct integer (not wrapped in literal rule)
301                let int_val = inner.as_str().parse::<i64>().unwrap();
302                ExtensionValue::Integer(int_val)
303            }
304            Rule::float => {
305                // Direct float (not wrapped in literal rule)
306                let float_val = inner.as_str().parse::<f64>().unwrap();
307                ExtensionValue::Float(float_val)
308            }
309            Rule::boolean => {
310                // Direct boolean (not wrapped in literal rule)
311                let bool_val = inner.as_str() == "true";
312                ExtensionValue::Boolean(bool_val)
313            }
314            Rule::expression => {
315                ExtensionValue::Expression(RawExpression::new(inner.as_str().to_string()))
316            }
317            _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
318        }
319    }
320}
321
322impl ParsePair for ExtensionColumn {
323    fn rule() -> Rule {
324        Rule::extension_column
325    }
326
327    fn message() -> &'static str {
328        "ExtensionColumn"
329    }
330
331    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
332        assert_eq!(pair.as_rule(), Self::rule());
333
334        let inner = unwrap_single_pair(pair); // Extract the actual content
335
336        match inner.as_rule() {
337            Rule::named_column => {
338                let mut iter = inner.into_inner();
339                let name_pair = iter.next().unwrap(); // Grammar guarantees name exists
340                let type_pair = iter.next().unwrap(); // Grammar guarantees type exists
341
342                let name = Name::parse_pair(name_pair).0.to_string(); // Reuse existing Name parser
343                let type_spec = type_pair.as_str().to_string(); // Types are complex, store as string for now
344
345                ExtensionColumn::Named { name, type_spec }
346            }
347            Rule::reference => {
348                // Reuse the existing FieldIndex parser, then extract the i32
349                let field_index = FieldIndex::parse_pair(inner);
350                ExtensionColumn::Reference(field_index.0)
351            }
352            Rule::expression => {
353                ExtensionColumn::Expression(RawExpression::new(inner.as_str().to_string()))
354            }
355            _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
356        }
357    }
358}
359
360/// Fully parsed extension invocation, including the user-supplied name and the
361/// structured argument payload.
362#[derive(Debug, Clone)]
363pub struct ExtensionInvocation {
364    pub name: String,
365    pub args: ExtensionArgs,
366}
367
368impl ParsePair for ExtensionInvocation {
369    fn rule() -> Rule {
370        Rule::extension_relation
371    }
372
373    fn message() -> &'static str {
374        "ExtensionInvocation"
375    }
376
377    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
378        assert_eq!(pair.as_rule(), Self::rule());
379
380        let mut iter = pair.into_inner();
381
382        // Parse extension name to determine relation type and custom name
383        let extension_name_pair = iter.next().unwrap(); // Grammar guarantees extension_name exists
384        let full_extension_name = extension_name_pair.as_str();
385
386        // Extract the relation type and custom name from the extension name
387        // (e.g., "ExtensionLeaf:ParquetScan" -> "ExtensionLeaf" and "ParquetScan")
388        let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
389            let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
390            (parts[0], parts[1].to_string())
391        } else {
392            (full_extension_name, "UnknownExtension".to_string())
393        };
394
395        let relation_type = ExtensionRelationType::from_str(relation_type_str).unwrap();
396        let mut args = ExtensionArgs::new(relation_type);
397
398        // Parse optional arguments
399        let ext_arguments = iter.next().unwrap();
400        match ext_arguments.as_rule() {
401            Rule::arguments => {
402                arguments_rule_parsing(ext_arguments, &mut args);
403            }
404            r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
405        }
406
407        // parse optional output columns
408        let extension_columns = iter.next();
409        if let Some(value) = extension_columns {
410            match value.as_rule() {
411                Rule::extension_columns => {
412                    for col_pair in value.into_inner() {
413                        if col_pair.as_rule() == Rule::extension_column {
414                            let column = ExtensionColumn::parse_pair(col_pair);
415                            args.output_columns.push(column);
416                        }
417                    }
418                }
419                r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
420            }
421        }
422
423        ExtensionInvocation {
424            name: custom_name,
425            args,
426        }
427    }
428}
429
430/// A parsed `+ Enh:Name[args]` or `+ Opt:Name[args]` line.
431#[derive(Debug, Clone)]
432pub struct AdvExtInvocation {
433    /// Whether this is an enhancement (`ExtensionType::Enhancement`) or
434    /// optimization (`ExtensionType::Optimization`).  The grammar restricts
435    /// the value to those two variants; `ExtensionType::Relation` will never
436    /// appear here.
437    pub ext_type: ExtensionType,
438    pub name: String,
439    pub args: ExtensionArgs,
440}
441
442impl ParsePair for AdvExtInvocation {
443    fn rule() -> Rule {
444        Rule::adv_extension
445    }
446
447    fn message() -> &'static str {
448        "AdvExtInvocation"
449    }
450
451    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
452        assert_eq!(pair.as_rule(), Self::rule());
453
454        let mut iter = pair.into_inner();
455
456        // First token: adv_ext_type — grammar guarantees "Enh" or "Opt"
457        let type_pair = iter.next().unwrap(); // Grammar guarantees adv_ext_type exists
458        let ext_type = match type_pair.as_str() {
459            "Enh" => ExtensionType::Enhancement,
460            "Opt" => ExtensionType::Optimization,
461            other => unreachable!("Unexpected adv_ext_type: {other}"),
462        };
463
464        // Second token: name
465        let name_pair = iter.next().unwrap();
466        let name = Name::parse_pair(name_pair).0.to_string();
467
468        // Remaining token: arguments — grammar guarantees it is always present
469        // Use Leaf as the relation_type placeholder — adv_extensions don't have children
470        let mut args = ExtensionArgs::new(crate::extensions::ExtensionRelationType::Leaf);
471
472        let arguments_pair = iter.next().unwrap();
473        match arguments_pair.as_rule() {
474            Rule::arguments => {
475                arguments_rule_parsing(arguments_pair, &mut args);
476            }
477            r => unreachable!("Unexpected rule in AdvExtInvocation args: {r:?}"),
478        }
479
480        AdvExtInvocation {
481            ext_type,
482            name,
483            args,
484        }
485    }
486}
487
488fn arguments_rule_parsing(inner_pair: pest::iterators::Pair<'_, Rule>, args: &mut ExtensionArgs) {
489    for arg in inner_pair.into_inner() {
490        match arg.as_rule() {
491            Rule::extension_arguments => {
492                // Parse positional arguments
493                for arg_pair in arg.into_inner() {
494                    assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
495                    args.positional.push(ExtensionValue::parse_pair(arg_pair));
496                }
497            }
498            Rule::extension_named_arguments => {
499                for arg_pair in arg.into_inner() {
500                    assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
501                    let mut arg_iter = arg_pair.into_inner();
502                    let name_p = arg_iter.next().unwrap();
503                    let value_p = arg_iter.next().unwrap();
504                    let key = Name::parse_pair(name_p).0.to_string();
505                    let val = ExtensionValue::parse_pair(value_p);
506                    args.named.insert(key, val);
507                }
508            }
509            Rule::empty => {}
510            r => unreachable!("Unexpected rule in extension args: {r:?}"),
511        }
512    }
513}
514
515impl ExtensionRelationType {
516    /// Create appropriate relation structure from extension detail and children.
517    /// This method handles the structural logic for creating different extension relation types.
518    pub fn create_rel(
519        self,
520        detail: Option<Any>,
521        children: Vec<Box<substrait::proto::Rel>>,
522    ) -> Result<substrait::proto::Rel, String> {
523        use substrait::proto::rel::RelType;
524        use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
525
526        // Validate child count matches relation type
527        self.validate_child_count(children.len())?;
528
529        // The output column count is returned alongside the Rel by parse_extension_relation
530        // and flows up the parse tree through Rust return values
531        let rel_type = match self {
532            ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
533                common: None,
534                detail: detail.map(Into::into),
535            }),
536            ExtensionRelationType::Single => {
537                let input = children.into_iter().next().map(|child| *child);
538                RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
539                    common: None,
540                    detail: detail.map(Into::into),
541                    input: input.map(Box::new),
542                }))
543            }
544            ExtensionRelationType::Multi => {
545                let inputs = children.into_iter().map(|child| *child).collect();
546                RelType::ExtensionMulti(ExtensionMultiRel {
547                    common: None,
548                    detail: detail.map(Into::into),
549                    inputs,
550                })
551            }
552        };
553
554        Ok(substrait::proto::Rel {
555            rel_type: Some(rel_type),
556        })
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    #[test]
    fn test_parse_simple_extension_declaration() {
        // Format is "#<anchor> @<urn_anchor>: <name>"; whitespace between the
        // two anchors is optional.
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        // Extra whitespace between the anchors, and a name with underscores
        // and digits.
        let line2 = "#10  @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    #[test]
    fn test_parse_urn_extension_declaration_str() {
        // Exercises the `FromStr` impl (via `str::parse`), which must agree
        // with `ParsePair::parse_str`.
        let line = "@1: /my/urn1";
        let urn: URNExtensionDeclaration = line.parse().unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
        assert_eq!(urn, URNExtensionDeclaration::parse_str(line).unwrap());
    }

    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
  @  1: /urn/common
  @  2: /urn/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#
        .trim_start();

        // Parse the input using the structural parser
        let plan = Parser::parse(input).unwrap();

        // Verify the plan has the expected extensions
        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        // Convert the plan extensions back to SimpleExtensions
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        // Convert back to string
        let output = extensions.to_string("  ");

        // The output should match the input
        assert_eq!(output, input);
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        // A function name that includes a Substrait signature suffix
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
  @  1: extension:io.substrait:functions_string
  @  2: extension:io.substrait:functions_comparison
Functions:
  #  1 @  2: equal:any_any
  #  2 @  1: regexp_match_substring:str_str
  #  3 @  1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        // Compound names must survive the roundtrip
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        // Text output must reproduce the input exactly
        assert_eq!(extensions.to_string("  "), input);
    }
}