// substrait_explain/parser/extensions.rs
use std::fmt;
use std::str::FromStr;

use thiserror::Error;

use super::{ParsePair, Rule, RuleIter, unwrap_single_pair};
use crate::extensions::simple::{self, ExtensionKind};
use crate::extensions::{InsertError, SimpleExtensions};
use crate::parser::structural::IndentedLine;

11#[derive(Debug, Clone, Error)]
12pub enum ExtensionParseError {
13 #[error("Unexpected line, expected {0}")]
14 UnexpectedLine(ExtensionParserState),
15 #[error("Error adding extension: {0}")]
16 ExtensionError(#[from] InsertError),
17 #[error("Error parsing message: {0}")]
18 Message(#[from] super::MessageParseError),
19}
20
21#[derive(Clone, Copy, Debug, PartialEq, Eq)]
24pub enum ExtensionParserState {
25 Extensions,
28 ExtensionUris,
31 ExtensionDeclarations(ExtensionKind),
34}
35
36impl fmt::Display for ExtensionParserState {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 match self {
39 ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URIs:'"),
40 ExtensionParserState::ExtensionUris => write!(f, "Extension URIs"),
41 ExtensionParserState::ExtensionDeclarations(kind) => {
42 write!(f, "Extension Declaration for {kind}")
43 }
44 }
45 }
46}
47
48#[derive(Debug)]
55pub struct ExtensionParser {
56 state: ExtensionParserState,
57 extensions: SimpleExtensions,
58}
59
60impl Default for ExtensionParser {
61 fn default() -> Self {
62 Self {
63 state: ExtensionParserState::Extensions,
64 extensions: SimpleExtensions::new(),
65 }
66 }
67}
68
69impl ExtensionParser {
70 pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
71 if line.1.is_empty() {
72 self.state = ExtensionParserState::Extensions;
75 return Ok(());
76 }
77
78 match self.state {
79 ExtensionParserState::Extensions => self.parse_subsection(line),
80 ExtensionParserState::ExtensionUris => self.parse_extension_uris(line),
81 ExtensionParserState::ExtensionDeclarations(extension_kind) => {
82 self.parse_declarations(line, extension_kind)
83 }
84 }
85 }
86
87 fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
88 match line {
89 IndentedLine(0, simple::EXTENSION_URIS_HEADER) => {
90 self.state = ExtensionParserState::ExtensionUris;
91 Ok(())
92 }
93 IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
94 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
95 Ok(())
96 }
97 IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
98 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
99 Ok(())
100 }
101 IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
102 self.state =
103 ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
104 Ok(())
105 }
106 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
107 }
108 }
109
110 fn parse_extension_uris(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
111 match line {
112 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
114 let uri =
115 URIExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
116 self.extensions.add_extension_uri(uri.uri, uri.anchor)?;
117 Ok(())
118 }
119 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
120 }
121 }
122
123 fn parse_declarations(
124 &mut self,
125 line: IndentedLine,
126 extension_kind: ExtensionKind,
127 ) -> Result<(), ExtensionParseError> {
128 match line {
129 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
131 let decl = SimpleExtensionDeclaration::from_str(s)?;
132 self.extensions.add_extension(
133 extension_kind,
134 decl.uri_anchor,
135 decl.anchor,
136 decl.name,
137 )?;
138 Ok(())
139 }
140 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
141 }
142 }
143
144 pub fn extensions(&self) -> &SimpleExtensions {
145 &self.extensions
146 }
147
148 pub fn state(&self) -> ExtensionParserState {
149 self.state
150 }
151}
152
/// One parsed line of the URIs subsection, e.g. `@1: /my/uri1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URIExtensionDeclaration {
    /// Numeric anchor that declarations use to refer to this URI.
    pub anchor: u32,
    /// The URI text following the colon.
    pub uri: String,
}

/// One parsed declaration line of a Functions / Types / Type Variations
/// subsection, e.g. `#5@2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// The declaration's own anchor (the number after `#`).
    pub anchor: u32,
    /// Anchor of the URI this declaration belongs to (the number after `@`).
    pub uri_anchor: u32,
    /// The declared name following the colon.
    pub name: String,
}

166impl ParsePair for URIExtensionDeclaration {
167 fn rule() -> Rule {
168 Rule::extension_uri_declaration
169 }
170
171 fn message() -> &'static str {
172 "URIExtensionDeclaration"
173 }
174
175 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
176 assert_eq!(pair.as_rule(), Self::rule());
177
178 let mut iter = RuleIter::from(pair.into_inner());
179 let anchor_pair = iter.pop(Rule::uri_anchor);
180 let anchor = unwrap_single_pair(anchor_pair)
181 .as_str()
182 .parse::<u32>()
183 .unwrap();
184 let uri = iter.pop(Rule::uri).as_str().to_string();
185 iter.done();
186
187 URIExtensionDeclaration { anchor, uri }
188 }
189}
190
191impl FromStr for URIExtensionDeclaration {
192 type Err = super::MessageParseError;
193
194 fn from_str(s: &str) -> Result<Self, Self::Err> {
195 Self::parse_str(s)
196 }
197}
198
199impl ParsePair for SimpleExtensionDeclaration {
200 fn rule() -> Rule {
201 Rule::simple_extension
202 }
203
204 fn message() -> &'static str {
205 "SimpleExtensionDeclaration"
206 }
207
208 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
209 assert_eq!(pair.as_rule(), Self::rule());
210 let mut iter = RuleIter::from(pair.into_inner());
211 let anchor_pair = iter.pop(Rule::anchor);
212 let anchor = unwrap_single_pair(anchor_pair)
213 .as_str()
214 .parse::<u32>()
215 .unwrap();
216 let uri_anchor_pair = iter.pop(Rule::uri_anchor);
217 let uri_anchor = unwrap_single_pair(uri_anchor_pair)
218 .as_str()
219 .parse::<u32>()
220 .unwrap();
221 let name_pair = iter.pop(Rule::name);
222 let name = unwrap_single_pair(name_pair).as_str().to_string();
223 iter.done();
224
225 SimpleExtensionDeclaration {
226 anchor,
227 uri_anchor,
228 name,
229 }
230 }
231}
232
233impl FromStr for SimpleExtensionDeclaration {
234 type Err = super::MessageParseError;
235
236 fn from_str(s: &str) -> Result<Self, Self::Err> {
237 Self::parse_str(s)
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    #[test]
    fn test_parse_uri_extension_declaration() {
        let line = "@1: /my/uri1";
        let uri = URIExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(uri.anchor, 1);
        assert_eq!(uri.uri, "/my/uri1");
    }

    #[test]
    fn test_parse_simple_extension_declaration() {
        // Compact form: no whitespace around the anchor markers.
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.uri_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        // Spaced form with multi-digit anchors.
        let line2 = "#10 @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.uri_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// Covers the `FromStr` entry point (the test above uses
    /// `ParsePair::parse_str` directly); both must agree.
    #[test]
    fn test_parse_uri_extension_declaration_str() {
        let line = "@1: /my/uri1";
        let uri = URIExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(
            uri,
            URIExtensionDeclaration {
                anchor: 1,
                uri: "/my/uri1".to_string(),
            }
        );
        assert_eq!(uri, URIExtensionDeclaration::parse_str(line).unwrap());
    }

    #[test]
    fn test_extension_parser_state_transitions() {
        let mut parser = ExtensionParser::default();
        assert_eq!(parser.state(), ExtensionParserState::Extensions);

        parser
            .parse_line(IndentedLine(0, simple::EXTENSION_URIS_HEADER))
            .unwrap();
        assert_eq!(parser.state(), ExtensionParserState::ExtensionUris);

        parser
            .parse_line(IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER))
            .unwrap();
        assert_eq!(
            parser.state(),
            ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function)
        );

        // A blank line ends the subsection.
        parser.parse_line(IndentedLine(0, "")).unwrap();
        assert_eq!(parser.state(), ExtensionParserState::Extensions);
    }

    #[test]
    fn test_extension_parser_rejects_unexpected_lines() {
        let mut parser = ExtensionParser::default();
        let err = parser.parse_line(IndentedLine(0, "Bogus:")).unwrap_err();
        assert!(matches!(
            err,
            ExtensionParseError::UnexpectedLine(ExtensionParserState::Extensions)
        ));

        parser
            .parse_line(IndentedLine(0, simple::EXTENSION_URIS_HEADER))
            .unwrap();
        let err = parser.parse_line(IndentedLine(2, "@1: /x")).unwrap_err();
        assert!(matches!(
            err,
            ExtensionParseError::UnexpectedLine(ExtensionParserState::ExtensionUris)
        ));

        let err = parser.parse_line(IndentedLine(1, "not a uri")).unwrap_err();
        assert!(matches!(err, ExtensionParseError::Message(_)));
    }

    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URIs:
  @  1: /uri/common
  @  2: /uri/specific_funcs
Functions:
  # 10 @  1: func_a
  # 11 @  2: func_b_special
Types:
  # 20 @  1: SomeType
Type Variations:
  # 30 @  2: VarX
"#
        .trim_start();

        let plan = Parser::parse(input).unwrap();

        assert_eq!(plan.extension_uris.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        let (extensions, _errors) =
            SimpleExtensions::from_extensions(&plan.extension_uris, &plan.extensions);

        let output = extensions.to_string("  ");

        assert_eq!(output, input);
    }
}