1use std::fmt;
2use std::str::FromStr;
3
4use thiserror::Error;
5
6use super::{ParsePair, Rule, RuleIter, unescape_string, unwrap_single_pair};
7use crate::extensions::registry::ExtensionType;
8use crate::extensions::simple::{self, ExtensionKind};
9use crate::extensions::{
10 ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, InsertError,
11 RawExpression, SimpleExtensions, TupleValue,
12};
13use crate::parser::structural::IndentedLine;
14
/// Errors produced while parsing the extensions section one line at a time.
#[derive(Debug, Clone, Error)]
pub enum ExtensionParseError {
    /// The line did not fit what the parser accepts in its current state.
    /// The payload is the state the parser was in when the line was rejected.
    #[error("Unexpected line, expected {0}")]
    UnexpectedLine(ExtensionParserState),
    /// A declaration was parsed, but adding it to the collected
    /// `SimpleExtensions` failed.
    #[error("Error adding extension: {0}")]
    ExtensionError(#[from] InsertError),
    /// The text of a declaration line could not be parsed.
    #[error("Error parsing message: {0}")]
    Message(#[from] super::MessageParseError),
}
24
/// Position of an [`ExtensionParser`] within the extensions section.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExtensionParserState {
    /// Waiting for a subsection header. This is the initial state and the
    /// state the parser returns to after an empty line.
    Extensions,
    /// Inside the URNs subsection; indented lines are URN declarations.
    ExtensionUrns,
    /// Inside a declarations subsection; indented lines declare extensions of
    /// the carried kind.
    ExtensionDeclarations(ExtensionKind),
}
39
40impl fmt::Display for ExtensionParserState {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
44 ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
45 ExtensionParserState::ExtensionDeclarations(kind) => {
46 write!(f, "Extension Declaration for {kind}")
47 }
48 }
49 }
50}
51
/// Line-oriented parser for the extensions section.
///
/// Lines are fed in one at a time through `parse_line`; the parser tracks
/// which subsection it is in and accumulates the declarations it has seen.
#[derive(Debug)]
pub struct ExtensionParser {
    /// Which kind of line the parser expects next.
    state: ExtensionParserState,
    /// Extensions collected from the lines parsed so far.
    extensions: SimpleExtensions,
}
63
64impl Default for ExtensionParser {
65 fn default() -> Self {
66 Self {
67 state: ExtensionParserState::Extensions,
68 extensions: SimpleExtensions::new(),
69 }
70 }
71}
72
73impl ExtensionParser {
74 pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
75 if line.1.is_empty() {
76 self.state = ExtensionParserState::Extensions;
79 return Ok(());
80 }
81
82 match self.state {
83 ExtensionParserState::Extensions => self.parse_subsection(line),
84 ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
85 ExtensionParserState::ExtensionDeclarations(extension_kind) => {
86 self.parse_declarations(line, extension_kind)
87 }
88 }
89 }
90
91 fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
92 match line {
93 IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
94 self.state = ExtensionParserState::ExtensionUrns;
95 Ok(())
96 }
97 IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
98 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
99 Ok(())
100 }
101 IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
102 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
103 Ok(())
104 }
105 IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
106 self.state =
107 ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
108 Ok(())
109 }
110 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
111 }
112 }
113
114 fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
115 match line {
116 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
118 let urn =
119 URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
120 self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
121 Ok(())
122 }
123 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
124 }
125 }
126
127 fn parse_declarations(
128 &mut self,
129 line: IndentedLine,
130 extension_kind: ExtensionKind,
131 ) -> Result<(), ExtensionParseError> {
132 match line {
133 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
135 let decl = SimpleExtensionDeclaration::from_str(s)?;
136 self.extensions.add_extension(
137 extension_kind,
138 decl.urn_anchor,
139 decl.anchor,
140 decl.name,
141 )?;
142 Ok(())
143 }
144 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
145 }
146 }
147
148 pub fn extensions(&self) -> &SimpleExtensions {
149 &self.extensions
150 }
151
152 pub fn state(&self) -> ExtensionParserState {
153 self.state
154 }
155}
156
/// A parsed URN declaration line, e.g. `@1: /my/urn1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URNExtensionDeclaration {
    /// Numeric anchor of the URN (the number after `@`).
    pub anchor: u32,
    /// The URN text itself.
    pub urn: String,
}
162
/// A parsed simple extension declaration line, e.g. `#5@2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// Anchor of the extension itself (the number after `#`).
    pub anchor: u32,
    /// Anchor of the URN the extension belongs to (the number after `@`).
    pub urn_anchor: u32,
    /// Declared name; may be a compound name such as `equal:any_any`.
    pub name: String,
}
169
170impl ParsePair for URNExtensionDeclaration {
171 fn rule() -> Rule {
172 Rule::extension_urn_declaration
173 }
174
175 fn message() -> &'static str {
176 "URNExtensionDeclaration"
177 }
178
179 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
180 assert_eq!(pair.as_rule(), Self::rule());
181
182 let mut iter = RuleIter::from(pair.into_inner());
183 let anchor_pair = iter.pop(Rule::urn_anchor);
184 let anchor = unwrap_single_pair(anchor_pair)
185 .as_str()
186 .parse::<u32>()
187 .unwrap();
188 let urn = iter.pop(Rule::urn).as_str().to_string();
189 iter.done();
190
191 URNExtensionDeclaration { anchor, urn }
192 }
193}
194
/// Lets a URN declaration be obtained with `str::parse` / `from_str`.
impl FromStr for URNExtensionDeclaration {
    type Err = super::MessageParseError;

    /// Parses `s` as a URN declaration by delegating to `parse_str`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse_str(s)
    }
}
202
203impl ParsePair for SimpleExtensionDeclaration {
204 fn rule() -> Rule {
205 Rule::simple_extension
206 }
207
208 fn message() -> &'static str {
209 "SimpleExtensionDeclaration"
210 }
211
212 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
213 assert_eq!(pair.as_rule(), Self::rule());
214 let mut iter = RuleIter::from(pair.into_inner());
215 let anchor_pair = iter.pop(Rule::anchor);
216 let anchor = unwrap_single_pair(anchor_pair)
217 .as_str()
218 .parse::<u32>()
219 .unwrap();
220 let urn_anchor_pair = iter.pop(Rule::urn_anchor);
221 let urn_anchor = unwrap_single_pair(urn_anchor_pair)
222 .as_str()
223 .parse::<u32>()
224 .unwrap();
225 let name_pair = iter.pop(Rule::compound_name);
227 let name = name_pair.as_str().to_string();
228 iter.done();
229
230 SimpleExtensionDeclaration {
231 anchor,
232 urn_anchor,
233 name,
234 }
235 }
236}
237
/// Lets a simple extension declaration be obtained with `str::parse` /
/// `from_str`.
impl FromStr for SimpleExtensionDeclaration {
    type Err = super::MessageParseError;

    /// Parses `s` as a simple extension declaration by delegating to
    /// `parse_str`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse_str(s)
    }
}
245
246use crate::extensions::any::Any;
250use crate::parser::expressions::{FieldIndex, Name};
251
252impl ParsePair for ExtensionValue {
253 fn rule() -> Rule {
254 Rule::extension_argument
255 }
256
257 fn message() -> &'static str {
258 "ExtensionValue"
259 }
260
261 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
262 assert_eq!(pair.as_rule(), Self::rule());
263
264 let inner = unwrap_single_pair(pair); match inner.as_rule() {
267 Rule::enum_value => {
268 let s = inner.as_str().trim_start_matches('&').to_string();
270 ExtensionValue::Enum(s)
271 }
272 Rule::reference => {
273 let field_index = FieldIndex::parse_pair(inner);
275 ExtensionValue::Reference(field_index.0)
276 }
277 Rule::literal => {
278 let mut literal_inner = inner.into_inner();
280 let value_pair = literal_inner.next().unwrap();
281 match value_pair.as_rule() {
282 Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
283 Rule::integer => {
284 let int_val = value_pair.as_str().parse::<i64>().unwrap();
285 ExtensionValue::Integer(int_val)
286 }
287 Rule::float => {
288 let float_val = value_pair.as_str().parse::<f64>().unwrap();
289 ExtensionValue::Float(float_val)
290 }
291 Rule::boolean => {
292 let bool_val = value_pair.as_str() == "true";
293 ExtensionValue::Boolean(bool_val)
294 }
295 _ => panic!("Unexpected literal value type: {:?}", value_pair.as_rule()),
296 }
297 }
298 Rule::string_literal => ExtensionValue::String(unescape_string(inner)),
299 Rule::integer => {
300 let int_val = inner.as_str().parse::<i64>().unwrap();
302 ExtensionValue::Integer(int_val)
303 }
304 Rule::float => {
305 let float_val = inner.as_str().parse::<f64>().unwrap();
307 ExtensionValue::Float(float_val)
308 }
309 Rule::boolean => {
310 let bool_val = inner.as_str() == "true";
312 ExtensionValue::Boolean(bool_val)
313 }
314 Rule::tuple => {
315 let tv = inner
316 .into_inner()
317 .map(ExtensionValue::parse_pair)
318 .collect::<TupleValue>();
319 ExtensionValue::Tuple(tv)
320 }
321 Rule::expression => {
322 ExtensionValue::Expression(RawExpression::new(inner.as_str().to_string()))
323 }
324 _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
325 }
326 }
327}
328
329impl ParsePair for ExtensionColumn {
330 fn rule() -> Rule {
331 Rule::extension_column
332 }
333
334 fn message() -> &'static str {
335 "ExtensionColumn"
336 }
337
338 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
339 assert_eq!(pair.as_rule(), Self::rule());
340
341 let inner = unwrap_single_pair(pair); match inner.as_rule() {
344 Rule::named_column => {
345 let mut iter = inner.into_inner();
346 let name_pair = iter.next().unwrap(); let type_pair = iter.next().unwrap(); let name = Name::parse_pair(name_pair).0.to_string(); let type_spec = type_pair.as_str().to_string(); ExtensionColumn::Named { name, type_spec }
353 }
354 Rule::reference => {
355 let field_index = FieldIndex::parse_pair(inner);
357 ExtensionColumn::Reference(field_index.0)
358 }
359 Rule::expression => {
360 ExtensionColumn::Expression(RawExpression::new(inner.as_str().to_string()))
361 }
362 _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
363 }
364 }
365}
366
/// A parsed extension relation invocation: the custom name plus its arguments.
#[derive(Debug, Clone)]
pub struct ExtensionInvocation {
    /// Custom extension name: the part after the first `:` of the invocation
    /// name, or `"UnknownExtension"` when the name has no `:`.
    pub name: String,
    /// Relation type, positional and named arguments, and output columns.
    pub args: ExtensionArgs,
}
374
375impl ParsePair for ExtensionInvocation {
376 fn rule() -> Rule {
377 Rule::extension_relation
378 }
379
380 fn message() -> &'static str {
381 "ExtensionInvocation"
382 }
383
384 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
385 assert_eq!(pair.as_rule(), Self::rule());
386
387 let mut iter = pair.into_inner();
388
389 let extension_name_pair = iter.next().unwrap(); let full_extension_name = extension_name_pair.as_str();
392
393 let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
396 let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
397 (parts[0], parts[1].to_string())
398 } else {
399 (full_extension_name, "UnknownExtension".to_string())
400 };
401
402 let relation_type = ExtensionRelationType::from_str(relation_type_str).unwrap();
403 let mut args = ExtensionArgs::new(relation_type);
404
405 let ext_arguments = iter.next().unwrap();
407 match ext_arguments.as_rule() {
408 Rule::arguments => {
409 arguments_rule_parsing(ext_arguments, &mut args);
410 }
411 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
412 }
413
414 let extension_columns = iter.next();
416 if let Some(value) = extension_columns {
417 match value.as_rule() {
418 Rule::extension_columns => {
419 for col_pair in value.into_inner() {
420 if col_pair.as_rule() == Rule::extension_column {
421 let column = ExtensionColumn::parse_pair(col_pair);
422 args.output_columns.push(column);
423 }
424 }
425 }
426 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
427 }
428 }
429
430 ExtensionInvocation {
431 name: custom_name,
432 args,
433 }
434 }
435}
436
/// A parsed advanced-extension invocation, e.g.
/// `+ Enh:Foo[(&HASH, &RANGE), count=8]`.
#[derive(Debug, Clone)]
pub struct AdvExtInvocation {
    /// Whether this is an enhancement (`Enh`) or an optimization (`Opt`).
    pub ext_type: ExtensionType,
    /// Name of the invoked extension.
    pub name: String,
    /// Positional and named arguments of the invocation.
    pub args: ExtensionArgs,
}
448
449impl ParsePair for AdvExtInvocation {
450 fn rule() -> Rule {
451 Rule::adv_extension
452 }
453
454 fn message() -> &'static str {
455 "AdvExtInvocation"
456 }
457
458 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
459 assert_eq!(pair.as_rule(), Self::rule());
460
461 let mut iter = pair.into_inner();
462
463 let type_pair = iter.next().unwrap(); let ext_type = match type_pair.as_str() {
466 "Enh" => ExtensionType::Enhancement,
467 "Opt" => ExtensionType::Optimization,
468 other => unreachable!("Unexpected adv_ext_type: {other}"),
469 };
470
471 let name_pair = iter.next().unwrap();
473 let name = Name::parse_pair(name_pair).0.to_string();
474
475 let mut args = ExtensionArgs::new(crate::extensions::ExtensionRelationType::Leaf);
478
479 let arguments_pair = iter.next().unwrap();
480 match arguments_pair.as_rule() {
481 Rule::arguments => {
482 arguments_rule_parsing(arguments_pair, &mut args);
483 }
484 r => unreachable!("Unexpected rule in AdvExtInvocation args: {r:?}"),
485 }
486
487 AdvExtInvocation {
488 ext_type,
489 name,
490 args,
491 }
492 }
493}
494
495fn arguments_rule_parsing(inner_pair: pest::iterators::Pair<'_, Rule>, args: &mut ExtensionArgs) {
496 for arg in inner_pair.into_inner() {
497 match arg.as_rule() {
498 Rule::extension_arguments => {
499 for arg_pair in arg.into_inner() {
501 assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
502 args.positional.push(ExtensionValue::parse_pair(arg_pair));
503 }
504 }
505 Rule::extension_named_arguments => {
506 for arg_pair in arg.into_inner() {
507 assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
508 let mut arg_iter = arg_pair.into_inner();
509 let name_p = arg_iter.next().unwrap();
510 let value_p = arg_iter.next().unwrap();
511 let key = Name::parse_pair(name_p).0.to_string();
512 let val = ExtensionValue::parse_pair(value_p);
513 args.named.insert(key, val);
514 }
515 }
516 Rule::empty => {}
517 r => unreachable!("Unexpected rule in extension args: {r:?}"),
518 }
519 }
520}
521
522impl ExtensionRelationType {
523 pub fn create_rel(
526 self,
527 detail: Option<Any>,
528 children: Vec<Box<substrait::proto::Rel>>,
529 ) -> Result<substrait::proto::Rel, String> {
530 use substrait::proto::rel::RelType;
531 use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
532
533 self.validate_child_count(children.len())?;
535
536 let rel_type = match self {
539 ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
540 common: None,
541 detail: detail.map(Into::into),
542 }),
543 ExtensionRelationType::Single => {
544 let input = children.into_iter().next().map(|child| *child);
545 RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
546 common: None,
547 detail: detail.map(Into::into),
548 input: input.map(Box::new),
549 }))
550 }
551 ExtensionRelationType::Multi => {
552 let inputs = children.into_iter().map(|child| *child).collect();
553 RelType::ExtensionMulti(ExtensionMultiRel {
554 common: None,
555 detail: detail.map(Into::into),
556 inputs,
557 })
558 }
559 };
560
561 Ok(substrait::proto::Rel {
562 rel_type: Some(rel_type),
563 })
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::ExtensionValue;
    use crate::fixtures::TestContext;
    use crate::parser::Parser;

    /// URN declarations parse through the `FromStr` implementation.
    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    /// Simple declarations parse with and without a space between anchors.
    #[test]
    fn test_parse_simple_extension_declaration() {
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        let line2 = "#10 @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// URN declarations parse through `parse_str` directly.
    #[test]
    fn test_parse_urn_extension_declaration_str() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    /// A full extensions section survives parse -> `SimpleExtensions` ->
    /// text unchanged.
    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
 @ 1: /urn/common
 @ 2: /urn/specific_funcs
Functions:
 # 10 @ 1: func_a
 # 11 @ 2: func_b_special
Types:
 # 20 @ 1: SomeType
Type Variations:
 # 30 @ 2: VarX
"#
        .trim_start();

        let plan = Parser::parse(input).unwrap();

        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        let output = extensions.to_string(" ");

        assert_eq!(output, input);
    }

    /// Names containing a `:` are kept whole.
    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    /// Compound names with several `_`-separated parts after the `:` are
    /// kept whole.
    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    /// Compound names survive the plan round trip and stay distinguishable by
    /// anchor.
    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
 @ 1: extension:io.substrait:functions_string
 @ 2: extension:io.substrait:functions_comparison
Functions:
 # 1 @ 2: equal:any_any
 # 2 @ 1: regexp_match_substring:str_str
 # 3 @ 1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        assert_eq!(extensions.to_string(" "), input);
    }

    /// A tuple may mix enum, integer and string elements.
    #[test]
    fn test_tuple_mixed_types_parses() {
        let val = ExtensionValue::parse_str("(&HASH, 8, 'hello')").unwrap();
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(items.len(), 3);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(items[1], ExtensionValue::Integer(8)));
        assert!(matches!(items[2], ExtensionValue::String(s) if s == "hello"));
    }

    /// `()` parses as an empty tuple.
    #[test]
    fn test_empty_tuple_parses() {
        let val = ExtensionValue::parse_str("()").unwrap();
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert!(items.is_empty());
    }

    /// Tuples may contain tuples.
    #[test]
    fn test_nested_tuple_parses() {
        let val = ExtensionValue::parse_str("((&HASH, &RANGE), 8)").unwrap();
        let ExtensionValue::Tuple(outer) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(outer.len(), 2);
        let ExtensionValue::Tuple(inner) = outer.iter().next().unwrap() else {
            panic!("expected inner Tuple");
        };
        assert_eq!(inner.len(), 2);
        assert!(matches!(inner.iter().next().unwrap(), ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(
            outer.iter().nth(1).unwrap(),
            ExtensionValue::Integer(8)
        ));
    }

    /// A tuple is accepted as a positional argument of an advanced extension,
    /// alongside a named argument.
    #[test]
    fn test_tuple_in_adv_extension_parses() {
        let inv = AdvExtInvocation::parse_str("+ Enh:Foo[(&HASH, &RANGE), count=8]").unwrap();
        assert_eq!(inv.name, "Foo");
        assert_eq!(inv.args.positional.len(), 1);
        let ExtensionValue::Tuple(items) = &inv.args.positional[0] else {
            panic!("expected Tuple positional arg");
        };
        assert_eq!(items.len(), 2);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(items[1], ExtensionValue::Enum(s) if s == "RANGE"));
        assert_eq!(inv.args.named.len(), 1);
    }

    /// Parsing a tuple and textifying it yields the original text.
    #[test]
    fn test_tuple_textify_roundtrip() {
        let ctx = TestContext::new();
        for text in &[
            "(&HASH, &RANGE)",
            "(&HASH, 8, 'hello')",
            "()",
            "(&HASH,)",
            "((&HASH, &RANGE), 8)",
        ] {
            let val = ExtensionValue::parse_str(text).unwrap();
            let rendered = ctx.textify_no_errors(&val);
            assert_eq!(&rendered, text, "roundtrip failed for {text}");
        }
    }
}