substrait_explain/parser/
extensions.rs

1use std::fmt;
2use std::str::FromStr;
3
4use thiserror::Error;
5
6use super::{ParsePair, Rule, RuleIter, unwrap_single_pair};
7use crate::extensions::simple::{self, ExtensionKind};
8use crate::extensions::{InsertError, SimpleExtensions};
9use crate::parser::structural::IndentedLine;
10
11#[derive(Debug, Clone, Error)]
12pub enum ExtensionParseError {
13    #[error("Unexpected line, expected {0}")]
14    UnexpectedLine(ExtensionParserState),
15    #[error("Error adding extension: {0}")]
16    ExtensionError(#[from] InsertError),
17    #[error("Error parsing message: {0}")]
18    Message(#[from] super::MessageParseError),
19}
20
21/// The state of the extension parser - tracking what section of extension
22/// parsing we are in.
23#[derive(Clone, Copy, Debug, PartialEq, Eq)]
24pub enum ExtensionParserState {
25    // The extensions section, after parsing the 'Extensions:' header, before
26    // parsing any subsection headers.
27    Extensions,
28    // The extension URIs section, after parsing the 'URIs:' subsection header,
29    // and any URIs so far.
30    ExtensionUris,
31    // In a subsection, after parsing the subsection header, and any
32    // declarations so far.
33    ExtensionDeclarations(ExtensionKind),
34}
35
36impl fmt::Display for ExtensionParserState {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        match self {
39            ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URIs:'"),
40            ExtensionParserState::ExtensionUris => write!(f, "Extension URIs"),
41            ExtensionParserState::ExtensionDeclarations(kind) => {
42                write!(f, "Extension Declaration for {kind}")
43            }
44        }
45    }
46}
47
48/// The parser for the extension section of the Substrait file format.
49///
50/// This is responsible for parsing the extension section of the file, which
51/// contains the extension URIs and declarations. Note that this parser does not
52/// parse the header; otherwise, this is symmetric with the
53/// SimpleExtensions::write method.
54#[derive(Debug)]
55pub struct ExtensionParser {
56    state: ExtensionParserState,
57    extensions: SimpleExtensions,
58}
59
60impl Default for ExtensionParser {
61    fn default() -> Self {
62        Self {
63            state: ExtensionParserState::Extensions,
64            extensions: SimpleExtensions::new(),
65        }
66    }
67}
68
69impl ExtensionParser {
70    pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
71        if line.1.is_empty() {
72            // Blank lines are allowed between subsections, so if we see
73            // one, we revert out of the subsection.
74            self.state = ExtensionParserState::Extensions;
75            return Ok(());
76        }
77
78        match self.state {
79            ExtensionParserState::Extensions => self.parse_subsection(line),
80            ExtensionParserState::ExtensionUris => self.parse_extension_uris(line),
81            ExtensionParserState::ExtensionDeclarations(extension_kind) => {
82                self.parse_declarations(line, extension_kind)
83            }
84        }
85    }
86
87    fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
88        match line {
89            IndentedLine(0, simple::EXTENSION_URIS_HEADER) => {
90                self.state = ExtensionParserState::ExtensionUris;
91                Ok(())
92            }
93            IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
94                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
95                Ok(())
96            }
97            IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
98                self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
99                Ok(())
100            }
101            IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
102                self.state =
103                    ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
104                Ok(())
105            }
106            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
107        }
108    }
109
110    fn parse_extension_uris(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
111        match line {
112            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
113            IndentedLine(1, s) => {
114                let uri =
115                    URIExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
116                self.extensions.add_extension_uri(uri.uri, uri.anchor)?;
117                Ok(())
118            }
119            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
120        }
121    }
122
123    fn parse_declarations(
124        &mut self,
125        line: IndentedLine,
126        extension_kind: ExtensionKind,
127    ) -> Result<(), ExtensionParseError> {
128        match line {
129            IndentedLine(0, _s) => self.parse_subsection(line), // Pass the original line with 0 indent
130            IndentedLine(1, s) => {
131                let decl = SimpleExtensionDeclaration::from_str(s)?;
132                self.extensions.add_extension(
133                    extension_kind,
134                    decl.uri_anchor,
135                    decl.anchor,
136                    decl.name,
137                )?;
138                Ok(())
139            }
140            _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
141        }
142    }
143
144    pub fn extensions(&self) -> &SimpleExtensions {
145        &self.extensions
146    }
147
148    pub fn state(&self) -> ExtensionParserState {
149        self.state
150    }
151}
152
/// One parsed line of the URIs subsection, e.g. `@1: /my/uri1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URIExtensionDeclaration {
    /// Numeric anchor written after `@` (`1` in the example above).
    pub anchor: u32,
    /// The URI text following the anchor (`/my/uri1` in the example above).
    pub uri: String,
}
158
/// One parsed line of a declarations subsection, e.g.
/// `#5@2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// Numeric anchor written after `#` (`5` in the example above).
    pub anchor: u32,
    /// Numeric URI anchor written after `@` (`2` in the example above).
    pub uri_anchor: u32,
    /// Declared name (`my_function_name` in the example above).
    pub name: String,
}
165
166impl ParsePair for URIExtensionDeclaration {
167    fn rule() -> Rule {
168        Rule::extension_uri_declaration
169    }
170
171    fn message() -> &'static str {
172        "URIExtensionDeclaration"
173    }
174
175    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
176        assert_eq!(pair.as_rule(), Self::rule());
177
178        let mut iter = RuleIter::from(pair.into_inner());
179        let anchor_pair = iter.pop(Rule::uri_anchor);
180        let anchor = unwrap_single_pair(anchor_pair)
181            .as_str()
182            .parse::<u32>()
183            .unwrap();
184        let uri = iter.pop(Rule::uri).as_str().to_string();
185        iter.done();
186
187        URIExtensionDeclaration { anchor, uri }
188    }
189}
190
191impl FromStr for URIExtensionDeclaration {
192    type Err = super::MessageParseError;
193
194    fn from_str(s: &str) -> Result<Self, Self::Err> {
195        Self::parse_str(s)
196    }
197}
198
199impl ParsePair for SimpleExtensionDeclaration {
200    fn rule() -> Rule {
201        Rule::simple_extension
202    }
203
204    fn message() -> &'static str {
205        "SimpleExtensionDeclaration"
206    }
207
208    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
209        assert_eq!(pair.as_rule(), Self::rule());
210        let mut iter = RuleIter::from(pair.into_inner());
211        let anchor_pair = iter.pop(Rule::anchor);
212        let anchor = unwrap_single_pair(anchor_pair)
213            .as_str()
214            .parse::<u32>()
215            .unwrap();
216        let uri_anchor_pair = iter.pop(Rule::uri_anchor);
217        let uri_anchor = unwrap_single_pair(uri_anchor_pair)
218            .as_str()
219            .parse::<u32>()
220            .unwrap();
221        let name_pair = iter.pop(Rule::name);
222        let name = unwrap_single_pair(name_pair).as_str().to_string();
223        iter.done();
224
225        SimpleExtensionDeclaration {
226            anchor,
227            uri_anchor,
228            name,
229        }
230    }
231}
232
233impl FromStr for SimpleExtensionDeclaration {
234    type Err = super::MessageParseError;
235
236    fn from_str(s: &str) -> Result<Self, Self::Err> {
237        Self::parse_str(s)
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses a URI declaration through the `FromStr` impl.
    #[test]
    fn test_parse_uri_extension_declaration() {
        let line = "@1: /my/uri1";
        let uri = URIExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(uri.anchor, 1);
        assert_eq!(uri.uri, "/my/uri1");
    }

    /// Parses extension declarations through the `FromStr` impl.
    #[test]
    fn test_parse_simple_extension_declaration() {
        // Declarations have the form "#anchor @uri_anchor: name".
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.uri_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        // Whitespace between the two anchors, and a name mixing underscores
        // and digits.
        let line2 = "#10  @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.uri_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// Parses a URI declaration by calling `parse_str` directly.
    #[test]
    fn test_parse_uri_extension_declaration_str() {
        let line = "@1: /my/uri1";
        let uri = URIExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(uri.anchor, 1);
        assert_eq!(uri.uri, "/my/uri1");
    }

    /// Parses a plan containing only an extensions section and checks that
    /// writing the extensions back out reproduces the input text.
    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URIs:
  @  1: /uri/common
  @  2: /uri/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#
        .trim_start();

        // Parse the input using the structural parser
        let plan = Parser::parse(input).unwrap();

        // Verify the plan has the expected extensions
        assert_eq!(plan.extension_uris.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        // Convert the plan extensions back to SimpleExtensions
        let (extensions, _errors) =
            SimpleExtensions::from_extensions(&plan.extension_uris, &plan.extensions);

        // Convert back to string
        let output = extensions.to_string("  ");

        // The output should match the input
        assert_eq!(output, input);
    }
}