1use std::fmt;
2use std::str::FromStr;
3
4use thiserror::Error;
5
6use super::{ParsePair, Rule, RuleIter, unescape_string, unwrap_single_pair};
7use crate::extensions::registry::ExtensionType;
8use crate::extensions::simple::{self, ExtensionKind};
9use crate::extensions::{
10 ExtensionArgs, ExtensionColumn, ExtensionRelationType, ExtensionValue, InsertError,
11 RawExpression, SimpleExtensions,
12};
13use crate::parser::structural::IndentedLine;
14
15#[derive(Debug, Clone, Error)]
16pub enum ExtensionParseError {
17 #[error("Unexpected line, expected {0}")]
18 UnexpectedLine(ExtensionParserState),
19 #[error("Error adding extension: {0}")]
20 ExtensionError(#[from] InsertError),
21 #[error("Error parsing message: {0}")]
22 Message(#[from] super::MessageParseError),
23}
24
25#[derive(Clone, Copy, Debug, PartialEq, Eq)]
28pub enum ExtensionParserState {
29 Extensions,
32 ExtensionUrns,
35 ExtensionDeclarations(ExtensionKind),
38}
39
40impl fmt::Display for ExtensionParserState {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
44 ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
45 ExtensionParserState::ExtensionDeclarations(kind) => {
46 write!(f, "Extension Declaration for {kind}")
47 }
48 }
49 }
50}
51
52#[derive(Debug)]
59pub struct ExtensionParser {
60 state: ExtensionParserState,
61 extensions: SimpleExtensions,
62}
63
64impl Default for ExtensionParser {
65 fn default() -> Self {
66 Self {
67 state: ExtensionParserState::Extensions,
68 extensions: SimpleExtensions::new(),
69 }
70 }
71}
72
73impl ExtensionParser {
74 pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
75 if line.1.is_empty() {
76 self.state = ExtensionParserState::Extensions;
79 return Ok(());
80 }
81
82 match self.state {
83 ExtensionParserState::Extensions => self.parse_subsection(line),
84 ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
85 ExtensionParserState::ExtensionDeclarations(extension_kind) => {
86 self.parse_declarations(line, extension_kind)
87 }
88 }
89 }
90
91 fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
92 match line {
93 IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
94 self.state = ExtensionParserState::ExtensionUrns;
95 Ok(())
96 }
97 IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
98 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
99 Ok(())
100 }
101 IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
102 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
103 Ok(())
104 }
105 IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
106 self.state =
107 ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
108 Ok(())
109 }
110 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
111 }
112 }
113
114 fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
115 match line {
116 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
118 let urn =
119 URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
120 self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
121 Ok(())
122 }
123 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
124 }
125 }
126
127 fn parse_declarations(
128 &mut self,
129 line: IndentedLine,
130 extension_kind: ExtensionKind,
131 ) -> Result<(), ExtensionParseError> {
132 match line {
133 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
135 let decl = SimpleExtensionDeclaration::from_str(s)?;
136 self.extensions.add_extension(
137 extension_kind,
138 decl.urn_anchor,
139 decl.anchor,
140 decl.name,
141 )?;
142 Ok(())
143 }
144 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
145 }
146 }
147
148 pub fn extensions(&self) -> &SimpleExtensions {
149 &self.extensions
150 }
151
152 pub fn state(&self) -> ExtensionParserState {
153 self.state
154 }
155}
156
/// A parsed URN declaration line such as `@1: /my/urn1`.
#[derive(Clone, Debug, PartialEq)]
pub struct URNExtensionDeclaration {
    /// The numeric anchor written after `@`.
    pub anchor: u32,
    /// The URN text following the colon.
    pub urn: String,
}
162
/// A parsed extension declaration line such as `#5@2: my_function_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// The declaration's own anchor (the number after `#`).
    pub anchor: u32,
    /// The anchor of the URN this declaration refers to (the number after `@`).
    pub urn_anchor: u32,
    /// The declared name; may be compound, e.g. `equal:any_any`.
    pub name: String,
}
169
170impl ParsePair for URNExtensionDeclaration {
171 fn rule() -> Rule {
172 Rule::extension_urn_declaration
173 }
174
175 fn message() -> &'static str {
176 "URNExtensionDeclaration"
177 }
178
179 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
180 assert_eq!(pair.as_rule(), Self::rule());
181
182 let mut iter = RuleIter::from(pair.into_inner());
183 let anchor_pair = iter.pop(Rule::urn_anchor);
184 let anchor = unwrap_single_pair(anchor_pair)
185 .as_str()
186 .parse::<u32>()
187 .unwrap();
188 let urn = iter.pop(Rule::urn).as_str().to_string();
189 iter.done();
190
191 URNExtensionDeclaration { anchor, urn }
192 }
193}
194
195impl FromStr for URNExtensionDeclaration {
196 type Err = super::MessageParseError;
197
198 fn from_str(s: &str) -> Result<Self, Self::Err> {
199 Self::parse_str(s)
200 }
201}
202
203impl ParsePair for SimpleExtensionDeclaration {
204 fn rule() -> Rule {
205 Rule::simple_extension
206 }
207
208 fn message() -> &'static str {
209 "SimpleExtensionDeclaration"
210 }
211
212 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
213 assert_eq!(pair.as_rule(), Self::rule());
214 let mut iter = RuleIter::from(pair.into_inner());
215 let anchor_pair = iter.pop(Rule::anchor);
216 let anchor = unwrap_single_pair(anchor_pair)
217 .as_str()
218 .parse::<u32>()
219 .unwrap();
220 let urn_anchor_pair = iter.pop(Rule::urn_anchor);
221 let urn_anchor = unwrap_single_pair(urn_anchor_pair)
222 .as_str()
223 .parse::<u32>()
224 .unwrap();
225 let name_pair = iter.pop(Rule::compound_name);
227 let name = name_pair.as_str().to_string();
228 iter.done();
229
230 SimpleExtensionDeclaration {
231 anchor,
232 urn_anchor,
233 name,
234 }
235 }
236}
237
238impl FromStr for SimpleExtensionDeclaration {
239 type Err = super::MessageParseError;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Self::parse_str(s)
243 }
244}
245
246use crate::extensions::any::Any;
250use crate::parser::expressions::{FieldIndex, Name};
251
252impl ParsePair for ExtensionValue {
253 fn rule() -> Rule {
254 Rule::extension_argument
255 }
256
257 fn message() -> &'static str {
258 "ExtensionValue"
259 }
260
261 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
262 assert_eq!(pair.as_rule(), Self::rule());
263
264 let inner = unwrap_single_pair(pair); match inner.as_rule() {
267 Rule::enum_value => {
268 let s = inner.as_str().trim_start_matches('&').to_string();
270 ExtensionValue::Enum(s)
271 }
272 Rule::reference => {
273 let field_index = FieldIndex::parse_pair(inner);
275 ExtensionValue::Reference(field_index.0)
276 }
277 Rule::literal => {
278 let mut literal_inner = inner.into_inner();
280 let value_pair = literal_inner.next().unwrap();
281 match value_pair.as_rule() {
282 Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
283 Rule::integer => {
284 let int_val = value_pair.as_str().parse::<i64>().unwrap();
285 ExtensionValue::Integer(int_val)
286 }
287 Rule::float => {
288 let float_val = value_pair.as_str().parse::<f64>().unwrap();
289 ExtensionValue::Float(float_val)
290 }
291 Rule::boolean => {
292 let bool_val = value_pair.as_str() == "true";
293 ExtensionValue::Boolean(bool_val)
294 }
295 _ => panic!("Unexpected literal value type: {:?}", value_pair.as_rule()),
296 }
297 }
298 Rule::string_literal => ExtensionValue::String(unescape_string(inner)),
299 Rule::integer => {
300 let int_val = inner.as_str().parse::<i64>().unwrap();
302 ExtensionValue::Integer(int_val)
303 }
304 Rule::float => {
305 let float_val = inner.as_str().parse::<f64>().unwrap();
307 ExtensionValue::Float(float_val)
308 }
309 Rule::boolean => {
310 let bool_val = inner.as_str() == "true";
312 ExtensionValue::Boolean(bool_val)
313 }
314 Rule::expression => {
315 ExtensionValue::Expression(RawExpression::new(inner.as_str().to_string()))
316 }
317 _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
318 }
319 }
320}
321
322impl ParsePair for ExtensionColumn {
323 fn rule() -> Rule {
324 Rule::extension_column
325 }
326
327 fn message() -> &'static str {
328 "ExtensionColumn"
329 }
330
331 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
332 assert_eq!(pair.as_rule(), Self::rule());
333
334 let inner = unwrap_single_pair(pair); match inner.as_rule() {
337 Rule::named_column => {
338 let mut iter = inner.into_inner();
339 let name_pair = iter.next().unwrap(); let type_pair = iter.next().unwrap(); let name = Name::parse_pair(name_pair).0.to_string(); let type_spec = type_pair.as_str().to_string(); ExtensionColumn::Named { name, type_spec }
346 }
347 Rule::reference => {
348 let field_index = FieldIndex::parse_pair(inner);
350 ExtensionColumn::Reference(field_index.0)
351 }
352 Rule::expression => {
353 ExtensionColumn::Expression(RawExpression::new(inner.as_str().to_string()))
354 }
355 _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
356 }
357 }
358}
359
360#[derive(Debug, Clone)]
363pub struct ExtensionInvocation {
364 pub name: String,
365 pub args: ExtensionArgs,
366}
367
368impl ParsePair for ExtensionInvocation {
369 fn rule() -> Rule {
370 Rule::extension_relation
371 }
372
373 fn message() -> &'static str {
374 "ExtensionInvocation"
375 }
376
377 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
378 assert_eq!(pair.as_rule(), Self::rule());
379
380 let mut iter = pair.into_inner();
381
382 let extension_name_pair = iter.next().unwrap(); let full_extension_name = extension_name_pair.as_str();
385
386 let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
389 let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
390 (parts[0], parts[1].to_string())
391 } else {
392 (full_extension_name, "UnknownExtension".to_string())
393 };
394
395 let relation_type = ExtensionRelationType::from_str(relation_type_str).unwrap();
396 let mut args = ExtensionArgs::new(relation_type);
397
398 let ext_arguments = iter.next().unwrap();
400 match ext_arguments.as_rule() {
401 Rule::arguments => {
402 arguments_rule_parsing(ext_arguments, &mut args);
403 }
404 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
405 }
406
407 let extension_columns = iter.next();
409 if let Some(value) = extension_columns {
410 match value.as_rule() {
411 Rule::extension_columns => {
412 for col_pair in value.into_inner() {
413 if col_pair.as_rule() == Rule::extension_column {
414 let column = ExtensionColumn::parse_pair(col_pair);
415 args.output_columns.push(column);
416 }
417 }
418 }
419 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
420 }
421 }
422
423 ExtensionInvocation {
424 name: custom_name,
425 args,
426 }
427 }
428}
429
430#[derive(Debug, Clone)]
432pub struct AdvExtInvocation {
433 pub ext_type: ExtensionType,
438 pub name: String,
439 pub args: ExtensionArgs,
440}
441
442impl ParsePair for AdvExtInvocation {
443 fn rule() -> Rule {
444 Rule::adv_extension
445 }
446
447 fn message() -> &'static str {
448 "AdvExtInvocation"
449 }
450
451 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
452 assert_eq!(pair.as_rule(), Self::rule());
453
454 let mut iter = pair.into_inner();
455
456 let type_pair = iter.next().unwrap(); let ext_type = match type_pair.as_str() {
459 "Enh" => ExtensionType::Enhancement,
460 "Opt" => ExtensionType::Optimization,
461 other => unreachable!("Unexpected adv_ext_type: {other}"),
462 };
463
464 let name_pair = iter.next().unwrap();
466 let name = Name::parse_pair(name_pair).0.to_string();
467
468 let mut args = ExtensionArgs::new(crate::extensions::ExtensionRelationType::Leaf);
471
472 let arguments_pair = iter.next().unwrap();
473 match arguments_pair.as_rule() {
474 Rule::arguments => {
475 arguments_rule_parsing(arguments_pair, &mut args);
476 }
477 r => unreachable!("Unexpected rule in AdvExtInvocation args: {r:?}"),
478 }
479
480 AdvExtInvocation {
481 ext_type,
482 name,
483 args,
484 }
485 }
486}
487
488fn arguments_rule_parsing(inner_pair: pest::iterators::Pair<'_, Rule>, args: &mut ExtensionArgs) {
489 for arg in inner_pair.into_inner() {
490 match arg.as_rule() {
491 Rule::extension_arguments => {
492 for arg_pair in arg.into_inner() {
494 assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
495 args.positional.push(ExtensionValue::parse_pair(arg_pair));
496 }
497 }
498 Rule::extension_named_arguments => {
499 for arg_pair in arg.into_inner() {
500 assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
501 let mut arg_iter = arg_pair.into_inner();
502 let name_p = arg_iter.next().unwrap();
503 let value_p = arg_iter.next().unwrap();
504 let key = Name::parse_pair(name_p).0.to_string();
505 let val = ExtensionValue::parse_pair(value_p);
506 args.named.insert(key, val);
507 }
508 }
509 Rule::empty => {}
510 r => unreachable!("Unexpected rule in extension args: {r:?}"),
511 }
512 }
513}
514
515impl ExtensionRelationType {
516 pub fn create_rel(
519 self,
520 detail: Option<Any>,
521 children: Vec<Box<substrait::proto::Rel>>,
522 ) -> Result<substrait::proto::Rel, String> {
523 use substrait::proto::rel::RelType;
524 use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
525
526 self.validate_child_count(children.len())?;
528
529 let rel_type = match self {
532 ExtensionRelationType::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
533 common: None,
534 detail: detail.map(Into::into),
535 }),
536 ExtensionRelationType::Single => {
537 let input = children.into_iter().next().map(|child| *child);
538 RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
539 common: None,
540 detail: detail.map(Into::into),
541 input: input.map(Box::new),
542 }))
543 }
544 ExtensionRelationType::Multi => {
545 let inputs = children.into_iter().map(|child| *child).collect();
546 RelType::ExtensionMulti(ExtensionMultiRel {
547 common: None,
548 detail: detail.map(Into::into),
549 inputs,
550 })
551 }
552 };
553
554 Ok(substrait::proto::Rel {
555 rel_type: Some(rel_type),
556 })
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// A URN declaration parses through `ParsePair::parse_str`.
    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    /// Extension declarations parse with and without a space between the
    /// anchor and the URN anchor.
    #[test]
    fn test_parse_simple_extension_declaration() {
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        let line2 = "#10 @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// A URN declaration parses through the `FromStr` impl (`str::parse`) and
    /// yields the same value as `parse_str`.
    #[test]
    fn test_parse_urn_extension_declaration_str() {
        let line = "@1: /my/urn1";
        let urn: URNExtensionDeclaration = line.parse().unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
        assert_eq!(urn, URNExtensionDeclaration::parse_str(line).unwrap());
    }

    /// Parsing an extensions section and printing it again reproduces the
    /// input text.
    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
 @ 1: /urn/common
 @ 2: /urn/specific_funcs
Functions:
 # 10 @ 1: func_a
 # 11 @ 2: func_b_special
Types:
 # 20 @ 1: SomeType
Type Variations:
 # 30 @ 2: VarX
"#
        .trim_start();

        let plan = Parser::parse(input).unwrap();

        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        let output = extensions.to_string(" ");

        assert_eq!(output, input);
    }

    /// A compound name keeps its `:signature` suffix.
    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    /// A compound name whose signature has several segments is kept whole.
    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    /// Compound names survive a parse / print round trip and can be looked up
    /// by anchor.
    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
 @ 1: extension:io.substrait:functions_string
 @ 2: extension:io.substrait:functions_comparison
Functions:
 # 1 @ 2: equal:any_any
 # 2 @ 1: regexp_match_substring:str_str
 # 3 @ 1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        assert_eq!(extensions.to_string(" "), input);
    }
}