// substrait_explain/parser/extensions.rs

1use std::fmt;
2use std::str::FromStr;
3
4use thiserror::Error;
5
6use super::{ParsePair, Rule, RuleIter, unescape_string, unwrap_single_pair};
7use crate::extensions::simple::{self, ExtensionKind};
8use crate::extensions::{
9    ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, InsertError,
10    RawExpression, SimpleExtensions,
11};
12use crate::parser::structural::IndentedLine;
13
14#[derive(Debug, Clone, Error)]
15pub enum ExtensionParseError {
16    #[error("Unexpected line, expected {0}")]
17    UnexpectedLine(ExtensionParserState),
18    #[error("Error adding extension: {0}")]
19    ExtensionError(#[from] InsertError),
20    #[error("Error parsing message: {0}")]
21    Message(#[from] super::MessageParseError),
22}
23
24/// The state of the extension parser - tracking what section of extension
25/// parsing we are in.
26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum ExtensionParserState {
28    // The extensions section, after parsing the 'Extensions:' header, before
29    // parsing any subsection headers.
30    Extensions,
31    // The extension URNs section, after parsing the 'URNs:' subsection header,
32    // and any URNs so far.
33    ExtensionUrns,
34    // In a subsection, after parsing the subsection header, and any
35    // declarations so far.
36    ExtensionDeclarations(ExtensionKind),
37}
38
39impl fmt::Display for ExtensionParserState {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
43            ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
44            ExtensionParserState::ExtensionDeclarations(kind) => {
45                write!(f, "Extension Declaration for {kind}")
46            }
47        }
48    }
49}
50
51/// The parser for the extension section of the Substrait file format.
52///
53/// This is responsible for parsing the extension section of the file, which
54/// contains the extension URNs and declarations. Note that this parser does not
55/// parse the header; otherwise, this is symmetric with the
56/// SimpleExtensions::write method.
57#[derive(Debug)]
58pub struct ExtensionParser {
59    state: ExtensionParserState,
60    extensions: SimpleExtensions,
61}
62
63impl Default for ExtensionParser {
64    fn default() -> Self {
65        Self {
66            state: ExtensionParserState::Extensions,
67            extensions: SimpleExtensions::new(),
68        }
69    }
70}
71
72impl ExtensionParser {
73    pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
74        if line.1.is_empty() {
75            // Blank lines are allowed between subsections, so if we see
76            // one, we revert out of the subsection.
77            self.state = ExtensionParserState::Extensions;
78            return Ok(());
79        }
80
81        match self.state {
82            ExtensionParserState::Extensions => self.parse_subsection(line),
83            ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
84            ExtensionParserState::ExtensionDeclarations(extension_kind) => {
85                self.parse_declarations(line, extension_kind)
86            }
87        }
88    }
89
90    fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
91        match line {
92            IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
93                self.state = ExtensionParserState::ExtensionUrns;
94                Ok(())
95            }
96            IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
97                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
98                Ok(())
99            }
100            IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
101                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
102                Ok(())
103            }
104            IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
105                self.state =
106                    ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
107                Ok(())
108            }
109            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
110        }
111    }
112
113    fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
114        match line {
115            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
116            IndentedLine(1, s) => {
117                let urn =
118                    URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
119                self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
120                Ok(())
121            }
122            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
123        }
124    }
125
126    fn parse_declarations(
127        &mut self,
128        line: IndentedLine,
129        extension_kind: ExtensionKind,
130    ) -> Result<(), ExtensionParseError> {
131        match line {
132            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
133            IndentedLine(1, s) => {
134                let decl = SimpleExtensionDeclaration::from_str(s)?;
135                self.extensions.add_extension(
136                    extension_kind,
137                    decl.urn_anchor,
138                    decl.anchor,
139                    decl.name,
140                )?;
141                Ok(())
142            }
143            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
144        }
145    }
146
147    pub fn extensions(&self) -> &SimpleExtensions {
148        &self.extensions
149    }
150
151    pub fn state(&self) -> ExtensionParserState {
152        self.state
153    }
154}
155
/// One parsed line of the URNs subsection, e.g. `@1: /my/urn1`.
#[derive(Clone, Debug, PartialEq)]
pub struct URNExtensionDeclaration {
    /// The numeric anchor written after `@`.
    pub anchor: u32,
    /// The URN text written after the colon.
    pub urn: String,
}
161
/// One parsed declaration line of a functions / types / type-variations
/// subsection, e.g. `#5@2: my_function_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// The declaration's own anchor, written after `#`.
    pub anchor: u32,
    /// The anchor of the URN this declaration refers to, written after `@`.
    pub urn_anchor: u32,
    /// The declared name, written after the colon.
    pub name: String,
}
168
169impl ParsePair for URNExtensionDeclaration {
170    fn rule() -> Rule {
171        Rule::extension_urn_declaration
172    }
173
174    fn message() -> &'static str {
175        "URNExtensionDeclaration"
176    }
177
178    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
179        assert_eq!(pair.as_rule(), Self::rule());
180
181        let mut iter = RuleIter::from(pair.into_inner());
182        let anchor_pair = iter.pop(Rule::urn_anchor);
183        let anchor = unwrap_single_pair(anchor_pair)
184            .as_str()
185            .parse::<u32>()
186            .unwrap();
187        let urn = iter.pop(Rule::urn).as_str().to_string();
188        iter.done();
189
190        URNExtensionDeclaration { anchor, urn }
191    }
192}
193
194impl FromStr for URNExtensionDeclaration {
195    type Err = super::MessageParseError;
196
197    fn from_str(s: &str) -> Result<Self, Self::Err> {
198        Self::parse_str(s)
199    }
200}
201
202impl ParsePair for SimpleExtensionDeclaration {
203    fn rule() -> Rule {
204        Rule::simple_extension
205    }
206
207    fn message() -> &'static str {
208        "SimpleExtensionDeclaration"
209    }
210
211    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
212        assert_eq!(pair.as_rule(), Self::rule());
213        let mut iter = RuleIter::from(pair.into_inner());
214        let anchor_pair = iter.pop(Rule::anchor);
215        let anchor = unwrap_single_pair(anchor_pair)
216            .as_str()
217            .parse::<u32>()
218            .unwrap();
219        let urn_anchor_pair = iter.pop(Rule::urn_anchor);
220        let urn_anchor = unwrap_single_pair(urn_anchor_pair)
221            .as_str()
222            .parse::<u32>()
223            .unwrap();
224        let name_pair = iter.pop(Rule::name);
225        let name = unwrap_single_pair(name_pair).as_str().to_string();
226        iter.done();
227
228        SimpleExtensionDeclaration {
229            anchor,
230            urn_anchor,
231            name,
232        }
233    }
234}
235
236impl FromStr for SimpleExtensionDeclaration {
237    type Err = super::MessageParseError;
238
239    fn from_str(s: &str) -> Result<Self, Self::Err> {
240        Self::parse_str(s)
241    }
242}
243
// Parsing implementations for extension relations.
// These were moved here from extensions/registry.rs to maintain a clean architecture.
246
247use crate::extensions::any::Any;
248use crate::parser::expressions::{FieldIndex, Name};
249
250impl ParsePair for ExtensionValue {
251    fn rule() -> Rule {
252        Rule::extension_argument
253    }
254
255    fn message() -> &'static str {
256        "ExtensionValue"
257    }
258
259    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
260        assert_eq!(pair.as_rule(), Self::rule());
261
262        let inner = unwrap_single_pair(pair); // Extract the actual content
263
264        match inner.as_rule() {
265            Rule::reference => {
266                // Reuse the existing FieldIndex parser, then extract the i32
267                let field_index = FieldIndex::parse_pair(inner);
268                ExtensionValue::Reference(field_index.0)
269            }
270            Rule::literal => {
271                // Literal can contain integer, float, boolean, or string_literal
272                let mut literal_inner = inner.into_inner();
273                let value_pair = literal_inner.next().unwrap();
274                match value_pair.as_rule() {
275                    Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
276                    Rule::integer => {
277                        let int_val = value_pair.as_str().parse::<i64>().unwrap();
278                        ExtensionValue::Integer(int_val)
279                    }
280                    Rule::float => {
281                        let float_val = value_pair.as_str().parse::<f64>().unwrap();
282                        ExtensionValue::Float(float_val)
283                    }
284                    Rule::boolean => {
285                        let bool_val = value_pair.as_str() == "true";
286                        ExtensionValue::Boolean(bool_val)
287                    }
288                    _ => panic!("Unexpected literal value type: {:?}", value_pair.as_rule()),
289                }
290            }
291            Rule::string_literal => ExtensionValue::String(unescape_string(inner)),
292            Rule::integer => {
293                // Direct integer (not wrapped in literal rule)
294                let int_val = inner.as_str().parse::<i64>().unwrap();
295                ExtensionValue::Integer(int_val)
296            }
297            Rule::float => {
298                // Direct float (not wrapped in literal rule)
299                let float_val = inner.as_str().parse::<f64>().unwrap();
300                ExtensionValue::Float(float_val)
301            }
302            Rule::boolean => {
303                // Direct boolean (not wrapped in literal rule)
304                let bool_val = inner.as_str() == "true";
305                ExtensionValue::Boolean(bool_val)
306            }
307            Rule::expression => {
308                ExtensionValue::Expression(RawExpression::new(inner.as_str().to_string()))
309            }
310            _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
311        }
312    }
313}
314
315impl ParsePair for ExtensionColumn {
316    fn rule() -> Rule {
317        Rule::extension_column
318    }
319
320    fn message() -> &'static str {
321        "ExtensionColumn"
322    }
323
324    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
325        assert_eq!(pair.as_rule(), Self::rule());
326
327        let inner = unwrap_single_pair(pair); // Extract the actual content
328
329        match inner.as_rule() {
330            Rule::named_column => {
331                let mut iter = inner.into_inner();
332                let name_pair = iter.next().unwrap(); // Grammar guarantees name exists
333                let type_pair = iter.next().unwrap(); // Grammar guarantees type exists
334
335                let name = Name::parse_pair(name_pair).0.to_string(); // Reuse existing Name parser
336                let type_spec = type_pair.as_str().to_string(); // Types are complex, store as string for now
337
338                ExtensionColumn::Named { name, type_spec }
339            }
340            Rule::reference => {
341                // Reuse the existing FieldIndex parser, then extract the i32
342                let field_index = FieldIndex::parse_pair(inner);
343                ExtensionColumn::Reference(field_index.0)
344            }
345            Rule::expression => {
346                ExtensionColumn::Expression(RawExpression::new(inner.as_str().to_string()))
347            }
348            _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
349        }
350    }
351}
352
353/// Fully parsed extension invocation, including the user-supplied name and the
354/// structured argument payload.
355#[derive(Debug, Clone)]
356pub struct ExtensionInvocation {
357    pub name: String,
358    pub args: ExtensionArgs,
359}
360
361impl ParsePair for ExtensionInvocation {
362    fn rule() -> Rule {
363        Rule::extension_relation
364    }
365
366    fn message() -> &'static str {
367        "ExtensionInvocation"
368    }
369
370    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
371        assert_eq!(pair.as_rule(), Self::rule());
372
373        let mut iter = pair.into_inner();
374
375        // Parse extension name to determine relation type and custom name
376        let extension_name_pair = iter.next().unwrap(); // Grammar guarantees extension_name exists
377        let full_extension_name = extension_name_pair.as_str();
378
379        // Extract the relation type and custom name from the extension name
380        // (e.g., "ExtensionLeaf:ParquetScan" -> "ExtensionLeaf" and "ParquetScan")
381        let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
382            let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
383            (parts[0], parts[1].to_string())
384        } else {
385            (full_extension_name, "UnknownExtension".to_string())
386        };
387
388        let relation_type = ExtensionRelationType::from_str(relation_type_str).unwrap();
389        let mut args = ExtensionArgs::new(relation_type);
390
391        // Parse optional arguments and columns
392        for inner_pair in iter {
393            match inner_pair.as_rule() {
394                Rule::extension_arguments => {
395                    // Parse positional arguments
396                    for arg_pair in inner_pair.into_inner() {
397                        if arg_pair.as_rule() == Rule::extension_argument {
398                            let value = ExtensionValue::parse_pair(arg_pair);
399                            args.positional.push(value);
400                        }
401                    }
402                }
403                Rule::extension_named_arguments => {
404                    // Parse named arguments
405                    for arg_pair in inner_pair.into_inner() {
406                        if arg_pair.as_rule() == Rule::extension_named_argument {
407                            let mut arg_iter = arg_pair.into_inner();
408                            let name_pair = arg_iter.next().unwrap();
409                            let value_pair = arg_iter.next().unwrap();
410
411                            let name = Name::parse_pair(name_pair).0.to_string();
412                            let value = ExtensionValue::parse_pair(value_pair);
413                            args.named.insert(name, value);
414                        }
415                    }
416                }
417                Rule::extension_columns => {
418                    // Parse output columns
419                    for col_pair in inner_pair.into_inner() {
420                        if col_pair.as_rule() == Rule::extension_column {
421                            let column = ExtensionColumn::parse_pair(col_pair);
422                            args.output_columns.push(column);
423                        }
424                    }
425                }
426                Rule::empty => {} // "_" — no arguments
427                r => panic!("Unexpected rule in ExtensionArgs: {:?}", r),
428            }
429        }
430
431        ExtensionInvocation {
432            name: custom_name,
433            args,
434        }
435    }
436}
437
438impl ExtensionRelationType {
439    /// Create appropriate relation structure from extension detail and children.
440    /// This method handles the structural logic for creating different extension relation types.
441    pub fn create_rel(
442        self,
443        detail: Option<Any>,
444        children: Vec<Box<substrait::proto::Rel>>,
445    ) -> Result<substrait::proto::Rel, String> {
446        use substrait::proto::rel::RelType;
447        use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
448
449        // Validate child count matches relation type
450        self.validate_child_count(children.len())?;
451
452        let rel_type = match self {
453            ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
454                common: None,
455                detail: detail.map(Into::into),
456            }),
457            ExtensionRelationType::Single => {
458                let input = children.into_iter().next().map(|child| *child);
459                RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
460                    common: None,
461                    detail: detail.map(Into::into),
462                    input: input.map(Box::new),
463                }))
464            }
465            ExtensionRelationType::Multi => {
466                let inputs = children.into_iter().map(|child| *child).collect();
467                RelType::ExtensionMulti(ExtensionMultiRel {
468                    common: None,
469                    detail: detail.map(Into::into),
470                    inputs,
471                })
472            }
473        };
474
475        Ok(substrait::proto::Rel {
476            rel_type: Some(rel_type),
477        })
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// A URN declaration parses through `parse_str`.
    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    /// Declarations parse through `FromStr`, with or without padding between
    /// the two anchors.
    #[test]
    fn test_parse_simple_extension_declaration() {
        // The format is "#anchor @urn_anchor: name".
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        // Extra whitespace between the anchors, and a name that contains
        // underscores and digits.
        let line2 = "#10  @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// The `FromStr` entry point yields the same result as `parse_str`.
    #[test]
    fn test_parse_urn_extension_declaration_str() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
        assert_eq!(urn, URNExtensionDeclaration::parse_str(line).unwrap());
    }

    /// Parsing an extensions section and writing it back reproduces the
    /// input text exactly.
    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
  @  1: /urn/common
  @  2: /urn/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#
        .trim_start();

        // Parse the input using the structural parser
        let plan = Parser::parse(input).unwrap();

        // Verify the plan has the expected extensions
        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        // Convert the plan extensions back to SimpleExtensions
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        // Convert back to string
        let output = extensions.to_string("  ");

        // The output should match the input
        assert_eq!(output, input);
    }
}