substrait_explain/parser/
extensions.rs1use std::fmt;
2use std::str::FromStr;
3
4use thiserror::Error;
5
6use super::{ParsePair, Rule, RuleIter, unescape_string, unwrap_single_pair};
7use crate::extensions::simple::{self, ExtensionKind};
8use crate::extensions::{
9 ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, InsertError,
10 RawExpression, SimpleExtensions,
11};
12use crate::parser::structural::IndentedLine;
13
14#[derive(Debug, Clone, Error)]
15pub enum ExtensionParseError {
16 #[error("Unexpected line, expected {0}")]
17 UnexpectedLine(ExtensionParserState),
18 #[error("Error adding extension: {0}")]
19 ExtensionError(#[from] InsertError),
20 #[error("Error parsing message: {0}")]
21 Message(#[from] super::MessageParseError),
22}
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq)]
27pub enum ExtensionParserState {
28 Extensions,
31 ExtensionUrns,
34 ExtensionDeclarations(ExtensionKind),
37}
38
39impl fmt::Display for ExtensionParserState {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
43 ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
44 ExtensionParserState::ExtensionDeclarations(kind) => {
45 write!(f, "Extension Declaration for {kind}")
46 }
47 }
48 }
49}
50
51#[derive(Debug)]
58pub struct ExtensionParser {
59 state: ExtensionParserState,
60 extensions: SimpleExtensions,
61}
62
63impl Default for ExtensionParser {
64 fn default() -> Self {
65 Self {
66 state: ExtensionParserState::Extensions,
67 extensions: SimpleExtensions::new(),
68 }
69 }
70}
71
72impl ExtensionParser {
73 pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
74 if line.1.is_empty() {
75 self.state = ExtensionParserState::Extensions;
78 return Ok(());
79 }
80
81 match self.state {
82 ExtensionParserState::Extensions => self.parse_subsection(line),
83 ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
84 ExtensionParserState::ExtensionDeclarations(extension_kind) => {
85 self.parse_declarations(line, extension_kind)
86 }
87 }
88 }
89
90 fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
91 match line {
92 IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
93 self.state = ExtensionParserState::ExtensionUrns;
94 Ok(())
95 }
96 IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
97 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
98 Ok(())
99 }
100 IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
101 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
102 Ok(())
103 }
104 IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
105 self.state =
106 ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
107 Ok(())
108 }
109 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
110 }
111 }
112
113 fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
114 match line {
115 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
117 let urn =
118 URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
119 self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
120 Ok(())
121 }
122 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
123 }
124 }
125
126 fn parse_declarations(
127 &mut self,
128 line: IndentedLine,
129 extension_kind: ExtensionKind,
130 ) -> Result<(), ExtensionParseError> {
131 match line {
132 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
134 let decl = SimpleExtensionDeclaration::from_str(s)?;
135 self.extensions.add_extension(
136 extension_kind,
137 decl.urn_anchor,
138 decl.anchor,
139 decl.name,
140 )?;
141 Ok(())
142 }
143 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
144 }
145 }
146
147 pub fn extensions(&self) -> &SimpleExtensions {
148 &self.extensions
149 }
150
151 pub fn state(&self) -> ExtensionParserState {
152 self.state
153 }
154}
155
/// A parsed URN declaration line, e.g. `@1: /my/urn1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URNExtensionDeclaration {
    /// Numeric anchor following the `@` (`1` in the example above).
    pub anchor: u32,
    /// The URN text following the colon (`/my/urn1` in the example above).
    pub urn: String,
}
161
/// A parsed simple extension declaration line, e.g. `#5@2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// Anchor of the declaration itself, following the `#` (`5` in the
    /// example above).
    pub anchor: u32,
    /// Anchor of the URN the declaration refers to, following the `@` (`2` in
    /// the example above).
    pub urn_anchor: u32,
    /// Declared name (`my_function_name` in the example above).
    pub name: String,
}
168
169impl ParsePair for URNExtensionDeclaration {
170 fn rule() -> Rule {
171 Rule::extension_urn_declaration
172 }
173
174 fn message() -> &'static str {
175 "URNExtensionDeclaration"
176 }
177
178 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
179 assert_eq!(pair.as_rule(), Self::rule());
180
181 let mut iter = RuleIter::from(pair.into_inner());
182 let anchor_pair = iter.pop(Rule::urn_anchor);
183 let anchor = unwrap_single_pair(anchor_pair)
184 .as_str()
185 .parse::<u32>()
186 .unwrap();
187 let urn = iter.pop(Rule::urn).as_str().to_string();
188 iter.done();
189
190 URNExtensionDeclaration { anchor, urn }
191 }
192}
193
194impl FromStr for URNExtensionDeclaration {
195 type Err = super::MessageParseError;
196
197 fn from_str(s: &str) -> Result<Self, Self::Err> {
198 Self::parse_str(s)
199 }
200}
201
202impl ParsePair for SimpleExtensionDeclaration {
203 fn rule() -> Rule {
204 Rule::simple_extension
205 }
206
207 fn message() -> &'static str {
208 "SimpleExtensionDeclaration"
209 }
210
211 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
212 assert_eq!(pair.as_rule(), Self::rule());
213 let mut iter = RuleIter::from(pair.into_inner());
214 let anchor_pair = iter.pop(Rule::anchor);
215 let anchor = unwrap_single_pair(anchor_pair)
216 .as_str()
217 .parse::<u32>()
218 .unwrap();
219 let urn_anchor_pair = iter.pop(Rule::urn_anchor);
220 let urn_anchor = unwrap_single_pair(urn_anchor_pair)
221 .as_str()
222 .parse::<u32>()
223 .unwrap();
224 let name_pair = iter.pop(Rule::name);
225 let name = unwrap_single_pair(name_pair).as_str().to_string();
226 iter.done();
227
228 SimpleExtensionDeclaration {
229 anchor,
230 urn_anchor,
231 name,
232 }
233 }
234}
235
236impl FromStr for SimpleExtensionDeclaration {
237 type Err = super::MessageParseError;
238
239 fn from_str(s: &str) -> Result<Self, Self::Err> {
240 Self::parse_str(s)
241 }
242}
243
244use crate::extensions::any::Any;
248use crate::parser::expressions::{FieldIndex, Name};
249
250impl ParsePair for ExtensionValue {
251 fn rule() -> Rule {
252 Rule::extension_argument
253 }
254
255 fn message() -> &'static str {
256 "ExtensionValue"
257 }
258
259 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
260 assert_eq!(pair.as_rule(), Self::rule());
261
262 let inner = unwrap_single_pair(pair); match inner.as_rule() {
265 Rule::reference => {
266 let field_index = FieldIndex::parse_pair(inner);
268 ExtensionValue::Reference(field_index.0)
269 }
270 Rule::literal => {
271 let mut literal_inner = inner.into_inner();
273 let value_pair = literal_inner.next().unwrap();
274 match value_pair.as_rule() {
275 Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
276 Rule::integer => {
277 let int_val = value_pair.as_str().parse::<i64>().unwrap();
278 ExtensionValue::Integer(int_val)
279 }
280 Rule::float => {
281 let float_val = value_pair.as_str().parse::<f64>().unwrap();
282 ExtensionValue::Float(float_val)
283 }
284 Rule::boolean => {
285 let bool_val = value_pair.as_str() == "true";
286 ExtensionValue::Boolean(bool_val)
287 }
288 _ => panic!("Unexpected literal value type: {:?}", value_pair.as_rule()),
289 }
290 }
291 Rule::string_literal => ExtensionValue::String(unescape_string(inner)),
292 Rule::integer => {
293 let int_val = inner.as_str().parse::<i64>().unwrap();
295 ExtensionValue::Integer(int_val)
296 }
297 Rule::float => {
298 let float_val = inner.as_str().parse::<f64>().unwrap();
300 ExtensionValue::Float(float_val)
301 }
302 Rule::boolean => {
303 let bool_val = inner.as_str() == "true";
305 ExtensionValue::Boolean(bool_val)
306 }
307 Rule::expression => {
308 ExtensionValue::Expression(RawExpression::new(inner.as_str().to_string()))
309 }
310 _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
311 }
312 }
313}
314
315impl ParsePair for ExtensionColumn {
316 fn rule() -> Rule {
317 Rule::extension_column
318 }
319
320 fn message() -> &'static str {
321 "ExtensionColumn"
322 }
323
324 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
325 assert_eq!(pair.as_rule(), Self::rule());
326
327 let inner = unwrap_single_pair(pair); match inner.as_rule() {
330 Rule::named_column => {
331 let mut iter = inner.into_inner();
332 let name_pair = iter.next().unwrap(); let type_pair = iter.next().unwrap(); let name = Name::parse_pair(name_pair).0.to_string(); let type_spec = type_pair.as_str().to_string(); ExtensionColumn::Named { name, type_spec }
339 }
340 Rule::reference => {
341 let field_index = FieldIndex::parse_pair(inner);
343 ExtensionColumn::Reference(field_index.0)
344 }
345 Rule::expression => {
346 ExtensionColumn::Expression(RawExpression::new(inner.as_str().to_string()))
347 }
348 _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
349 }
350 }
351}
352
353#[derive(Debug, Clone)]
356pub struct ExtensionInvocation {
357 pub name: String,
358 pub args: ExtensionArgs,
359}
360
361impl ParsePair for ExtensionInvocation {
362 fn rule() -> Rule {
363 Rule::extension_relation
364 }
365
366 fn message() -> &'static str {
367 "ExtensionInvocation"
368 }
369
370 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
371 assert_eq!(pair.as_rule(), Self::rule());
372
373 let mut iter = pair.into_inner();
374
375 let extension_name_pair = iter.next().unwrap(); let full_extension_name = extension_name_pair.as_str();
378
379 let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
382 let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
383 (parts[0], parts[1].to_string())
384 } else {
385 (full_extension_name, "UnknownExtension".to_string())
386 };
387
388 let relation_type = ExtensionRelationType::from_str(relation_type_str).unwrap();
389 let mut args = ExtensionArgs::new(relation_type);
390
391 for inner_pair in iter {
393 match inner_pair.as_rule() {
394 Rule::extension_arguments => {
395 for arg_pair in inner_pair.into_inner() {
397 if arg_pair.as_rule() == Rule::extension_argument {
398 let value = ExtensionValue::parse_pair(arg_pair);
399 args.positional.push(value);
400 }
401 }
402 }
403 Rule::extension_named_arguments => {
404 for arg_pair in inner_pair.into_inner() {
406 if arg_pair.as_rule() == Rule::extension_named_argument {
407 let mut arg_iter = arg_pair.into_inner();
408 let name_pair = arg_iter.next().unwrap();
409 let value_pair = arg_iter.next().unwrap();
410
411 let name = Name::parse_pair(name_pair).0.to_string();
412 let value = ExtensionValue::parse_pair(value_pair);
413 args.named.insert(name, value);
414 }
415 }
416 }
417 Rule::extension_columns => {
418 for col_pair in inner_pair.into_inner() {
420 if col_pair.as_rule() == Rule::extension_column {
421 let column = ExtensionColumn::parse_pair(col_pair);
422 args.output_columns.push(column);
423 }
424 }
425 }
426 Rule::empty => {} r => panic!("Unexpected rule in ExtensionArgs: {:?}", r),
428 }
429 }
430
431 ExtensionInvocation {
432 name: custom_name,
433 args,
434 }
435 }
436}
437
438impl ExtensionRelationType {
439 pub fn create_rel(
442 self,
443 detail: Option<Any>,
444 children: Vec<Box<substrait::proto::Rel>>,
445 ) -> Result<substrait::proto::Rel, String> {
446 use substrait::proto::rel::RelType;
447 use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
448
449 self.validate_child_count(children.len())?;
451
452 let rel_type = match self {
453 ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
454 common: None,
455 detail: detail.map(Into::into),
456 }),
457 ExtensionRelationType::Single => {
458 let input = children.into_iter().next().map(|child| *child);
459 RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
460 common: None,
461 detail: detail.map(Into::into),
462 input: input.map(Box::new),
463 }))
464 }
465 ExtensionRelationType::Multi => {
466 let inputs = children.into_iter().map(|child| *child).collect();
467 RelType::ExtensionMulti(ExtensionMultiRel {
468 common: None,
469 detail: detail.map(Into::into),
470 inputs,
471 })
472 }
473 };
474
475 Ok(substrait::proto::Rel {
476 rel_type: Some(rel_type),
477 })
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// `parse_str` extracts the anchor and the URN from a URN declaration.
    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    /// `from_str` extracts both anchors and the name from a simple extension
    /// declaration, with or without a space between the two anchors.
    #[test]
    fn test_parse_simple_extension_declaration() {
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        let line2 = "#10 @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// The `FromStr` entry point yields the same result as `parse_str` for a
    /// URN declaration.
    #[test]
    fn test_parse_urn_extension_declaration_str() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    /// Parsing an extensions section into a plan and rendering the plan's
    /// extensions back to text reproduces the input exactly.
    ///
    /// NOTE(review): the comparison is byte-for-byte, so the whitespace inside
    /// the literal below and the indent string passed to `to_string` must
    /// match what the formatter emits - confirm both against the formatter.
    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
 @ 1: /urn/common
 @ 2: /urn/specific_funcs
Functions:
 # 10 @ 1: func_a
 # 11 @ 2: func_b_special
Types:
 # 20 @ 1: SomeType
Type Variations:
 # 30 @ 2: VarX
"#
        .trim_start();

        let plan = Parser::parse(input).unwrap();

        // Two URNs, and one declaration per line under the three declaration
        // subsections (2 functions + 1 type + 1 type variation).
        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        let output = extensions.to_string(" ");

        assert_eq!(output, input);
    }
}