1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType};
8use substrait::proto::fetch_rel::{CountMode, OffsetMode};
9use substrait::proto::rel::RelType;
10use substrait::proto::rel_common::{Emit, EmitKind};
11use substrait::proto::sort_field::{SortDirection, SortKind};
12use substrait::proto::{
13 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
14 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
15};
16
17use super::{
18 ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
19};
20use crate::extensions::any::Any;
21use crate::extensions::registry::{ExtensionError, ExtensionType};
22use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions};
23use crate::parser::errors::{ParseContext, ParseError};
24use crate::parser::expressions::{FieldIndex, Name};
25
26pub struct RelationParsingContext<'a> {
28 pub extensions: &'a SimpleExtensions,
29 pub registry: &'a ExtensionRegistry,
30 pub line_no: i64,
31 pub line: &'a str,
32}
33
34impl<'a> RelationParsingContext<'a> {
35 pub fn resolve_extension_detail(
37 &self,
38 extension_name: &str,
39 extension_args: &ExtensionArgs,
40 ) -> Result<Option<Any>, ParseError> {
41 let detail = self
42 .registry
43 .parse_extension(extension_name, extension_args);
44
45 match detail {
46 Ok(any) => Ok(Some(any)),
47 Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
48 name: extension_name.to_string(),
49 context: ParseContext::new(self.line_no, self.line.to_string()),
50 }),
51 Err(err) => Err(ParseError::ExtensionDetail(
52 ParseContext::new(self.line_no, self.line.to_string()),
53 err,
54 )),
55 }
56 }
57
58 pub fn resolve_adv_ext_detail(
64 &self,
65 ext_type: ExtensionType,
66 name: &str,
67 args: &ExtensionArgs,
68 ) -> Result<Any, ParseError> {
69 let result = match ext_type {
70 ExtensionType::Enhancement => self.registry.parse_enhancement(name, args),
71 ExtensionType::Optimization => self.registry.parse_optimization(name, args),
72 ExtensionType::Relation => unreachable!("Relation is not an advanced extension type"),
73 };
74 result.map_err(|err| match err {
75 ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
76 name: name.to_string(),
77 context: ParseContext::new(self.line_no, self.line.to_string()),
78 },
79 err => ParseError::ExtensionDetail(
80 ParseContext::new(self.line_no, self.line.to_string()),
81 err,
82 ),
83 })
84 }
85}
86
87pub trait RelationParsePair: Sized {
90 fn rule() -> Rule;
91
92 fn message() -> &'static str;
93
94 fn parse_pair_with_context(
102 extensions: &SimpleExtensions,
103 pair: Pair<Rule>,
104 input_children: Vec<Box<Rel>>,
105 _input_field_count: usize,
106 ) -> Result<Self, MessageParseError>;
107
108 fn into_rel(self) -> Rel;
109}
110
111pub struct TableName(Vec<String>);
112
113impl ParsePair for TableName {
114 fn rule() -> Rule {
115 Rule::table_name
116 }
117
118 fn message() -> &'static str {
119 "TableName"
120 }
121
122 fn parse_pair(pair: Pair<Rule>) -> Self {
123 assert_eq!(pair.as_rule(), Self::rule());
124 let pairs = pair.into_inner();
125 let mut names = Vec::with_capacity(pairs.len());
126 let mut iter = RuleIter::from(pairs);
127 while let Some(name) = iter.parse_if_next::<Name>() {
128 names.push(name.0);
129 }
130 iter.done();
131 Self(names)
132 }
133}
134
135#[derive(Debug, Clone)]
136pub struct Column {
137 pub name: String,
138 pub typ: Type,
139}
140
141impl ScopedParsePair for Column {
142 fn rule() -> Rule {
143 Rule::named_column
144 }
145
146 fn message() -> &'static str {
147 "Column"
148 }
149
150 fn parse_pair(
151 extensions: &SimpleExtensions,
152 pair: Pair<Rule>,
153 ) -> Result<Self, MessageParseError> {
154 assert_eq!(pair.as_rule(), Self::rule());
155 let mut iter = RuleIter::from(pair.into_inner());
156 let name = iter.parse_next::<Name>().0;
157 let typ = iter.parse_next_scoped(extensions)?;
158 iter.done();
159 Ok(Self { name, typ })
160 }
161}
162
163pub struct NamedColumnList(Vec<Column>);
164
165impl ScopedParsePair for NamedColumnList {
166 fn rule() -> Rule {
167 Rule::named_column_list
168 }
169
170 fn message() -> &'static str {
171 "NamedColumnList"
172 }
173
174 fn parse_pair(
175 extensions: &SimpleExtensions,
176 pair: Pair<Rule>,
177 ) -> Result<Self, MessageParseError> {
178 assert_eq!(pair.as_rule(), Self::rule());
179 let mut columns = Vec::new();
180 for col in pair.into_inner() {
181 columns.push(Column::parse_pair(extensions, col)?);
182 }
183 Ok(Self(columns))
184 }
185}
186
187#[allow(clippy::vec_box)]
192pub(crate) fn expect_one_child(
193 message: &'static str,
194 pair: &Pair<Rule>,
195 mut input_children: Vec<Box<Rel>>,
196) -> Result<Box<Rel>, MessageParseError> {
197 match input_children.len() {
198 0 => Err(MessageParseError::invalid(
199 message,
200 pair.as_span(),
201 format!("{message} missing child"),
202 )),
203 1 => Ok(input_children.pop().unwrap()),
204 n => Err(MessageParseError::invalid(
205 message,
206 pair.as_span(),
207 format!("{message} should have 1 input child, got {n}"),
208 )),
209 }
210}
211
212fn parse_reference_emit(pair: Pair<Rule>) -> EmitKind {
214 assert_eq!(pair.as_rule(), Rule::reference_list);
215 let output_mapping = pair
216 .into_inner()
217 .map(|p| FieldIndex::parse_pair(p).0)
218 .collect::<Vec<i32>>();
219 EmitKind::Emit(Emit { output_mapping })
220}
221
222pub struct ParsedNamedArgs<'a> {
228 map: HashMap<&'a str, Pair<'a, Rule>>,
229}
230
231impl<'a> ParsedNamedArgs<'a> {
232 pub fn new(
233 pairs: pest::iterators::Pairs<'a, Rule>,
234 rule: Rule,
235 ) -> Result<Self, MessageParseError> {
236 let mut map = HashMap::new();
237 for pair in pairs {
238 assert_eq!(pair.as_rule(), rule);
239 let mut inner = pair.clone().into_inner();
240 let name_pair = inner.next().unwrap();
241 let value_pair = inner.next().unwrap();
242 assert_eq!(inner.next(), None);
243 let name = name_pair.as_str();
244 if map.contains_key(name) {
245 return Err(MessageParseError::invalid(
246 "NamedArg",
247 name_pair.as_span(),
248 format!("Duplicate argument: {name}"),
249 ));
250 }
251 map.insert(name, value_pair);
252 }
253 Ok(Self { map })
254 }
255
256 pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
260 let pair = self.map.remove(name).inspect(|pair| {
261 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
262 });
263 (self, pair)
264 }
265
266 pub fn done(self) -> Result<(), MessageParseError> {
268 if let Some((name, pair)) = self.map.iter().next() {
269 return Err(MessageParseError::invalid(
270 "NamedArgExtractor",
271 pair.as_span(),
273 format!("Unknown argument: {name}"),
274 ));
275 }
276 Ok(())
277 }
278}
279
280impl RelationParsePair for ReadRel {
281 fn rule() -> Rule {
282 Rule::read_relation
283 }
284
285 fn message() -> &'static str {
286 "ReadRel"
287 }
288
289 fn into_rel(self) -> Rel {
290 Rel {
291 rel_type: Some(RelType::Read(Box::new(self))),
292 }
293 }
294
295 fn parse_pair_with_context(
296 extensions: &SimpleExtensions,
297 pair: Pair<Rule>,
298 input_children: Vec<Box<Rel>>,
299 _input_field_count: usize,
300 ) -> Result<Self, MessageParseError> {
301 assert_eq!(pair.as_rule(), Self::rule());
302 if !input_children.is_empty() {
304 return Err(MessageParseError::invalid(
305 Self::message(),
306 pair.as_span(),
307 "ReadRel should have no input children",
308 ));
309 }
310 if _input_field_count != 0 {
311 let error = pest::error::Error::new_from_span(
312 pest::error::ErrorVariant::CustomError {
313 message: "ReadRel should have 0 input fields".to_string(),
314 },
315 pair.as_span(),
316 );
317 return Err(MessageParseError::new(
318 "ReadRel",
319 ErrorKind::InvalidValue,
320 Box::new(error),
321 ));
322 }
323
324 let mut iter = RuleIter::from(pair.into_inner());
325 let table = iter.parse_next::<TableName>().0;
326 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
327 iter.done();
328
329 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
330 let struct_ = r#type::Struct {
331 types,
332 type_variation_reference: 0,
333 nullability: r#type::Nullability::Required as i32,
334 };
335 let named_struct = NamedStruct {
336 names,
337 r#struct: Some(struct_),
338 };
339
340 let read_rel = ReadRel {
341 base_schema: Some(named_struct),
342 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
343 names: table,
344 advanced_extension: None,
345 })),
346 ..Default::default()
347 };
348
349 Ok(read_rel)
350 }
351}
352
353impl RelationParsePair for FilterRel {
354 fn rule() -> Rule {
355 Rule::filter_relation
356 }
357
358 fn message() -> &'static str {
359 "FilterRel"
360 }
361
362 fn into_rel(self) -> Rel {
363 Rel {
364 rel_type: Some(RelType::Filter(Box::new(self))),
365 }
366 }
367
368 fn parse_pair_with_context(
369 extensions: &SimpleExtensions,
370 pair: Pair<Rule>,
371 input_children: Vec<Box<Rel>>,
372 _input_field_count: usize,
373 ) -> Result<Self, MessageParseError> {
374 assert_eq!(pair.as_rule(), Self::rule());
377 let input = expect_one_child(Self::message(), &pair, input_children)?;
378 let mut iter = RuleIter::from(pair.into_inner());
379 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
381 let references_pair = iter.pop(Rule::reference_list);
383 iter.done();
384
385 let emit = parse_reference_emit(references_pair);
386 let common = RelCommon {
387 emit_kind: Some(emit),
388 ..Default::default()
389 };
390
391 Ok(FilterRel {
392 input: Some(input),
393 condition: Some(Box::new(condition)),
394 common: Some(common),
395 advanced_extension: None,
396 })
397 }
398}
399
400impl RelationParsePair for ProjectRel {
401 fn rule() -> Rule {
402 Rule::project_relation
403 }
404
405 fn message() -> &'static str {
406 "ProjectRel"
407 }
408
409 fn into_rel(self) -> Rel {
410 Rel {
411 rel_type: Some(RelType::Project(Box::new(self))),
412 }
413 }
414
415 fn parse_pair_with_context(
416 extensions: &SimpleExtensions,
417 pair: Pair<Rule>,
418 input_children: Vec<Box<Rel>>,
419 _input_field_count: usize,
420 ) -> Result<Self, MessageParseError> {
421 assert_eq!(pair.as_rule(), Self::rule());
422 let input = expect_one_child(Self::message(), &pair, input_children)?;
423
424 let arguments_pair = unwrap_single_pair(pair);
426
427 let mut expressions = Vec::new();
428 let mut output_mapping = Vec::new();
429
430 for arg in arguments_pair.into_inner() {
432 let inner_arg = unwrap_single_pair(arg);
433 match inner_arg.as_rule() {
434 Rule::reference => {
435 let field_index = FieldIndex::parse_pair(inner_arg);
437 output_mapping.push(field_index.0);
438 }
439 Rule::expression => {
440 let _expr = Expression::parse_pair(extensions, inner_arg)?;
442 expressions.push(_expr);
443 output_mapping.push(_input_field_count as i32 + (expressions.len() as i32 - 1));
445 }
446 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
447 }
448 }
449
450 let emit = EmitKind::Emit(Emit { output_mapping });
451 let common = RelCommon {
452 emit_kind: Some(emit),
453 ..Default::default()
454 };
455
456 Ok(ProjectRel {
457 input: Some(input),
458 expressions,
459 common: Some(common),
460 advanced_extension: None,
461 })
462 }
463}
464
465impl RelationParsePair for AggregateRel {
466 fn rule() -> Rule {
467 Rule::aggregate_relation
468 }
469
470 fn message() -> &'static str {
471 "AggregateRel"
472 }
473
474 fn into_rel(self) -> Rel {
475 Rel {
476 rel_type: Some(RelType::Aggregate(Box::new(self))),
477 }
478 }
479
480 fn parse_pair_with_context(
481 extensions: &SimpleExtensions,
482 pair: Pair<Rule>,
483 input_children: Vec<Box<Rel>>,
484 _input_field_count: usize,
485 ) -> Result<Self, MessageParseError> {
486 assert_eq!(pair.as_rule(), Self::rule());
487 let input = expect_one_child(Self::message(), &pair, input_children)?;
488 let mut iter = RuleIter::from(pair.into_inner());
489 let group_by_pair = iter.pop(Rule::aggregate_group_by);
490 let output_pair = iter.pop(Rule::aggregate_output);
491 iter.done();
492
493 let inner = group_by_pair
494 .into_inner()
495 .next()
496 .expect("aggregate_group_by must have one inner item");
497
498 let grouping_sets = parse_grouping_sets(extensions, inner);
499 let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
500
501 let (measures, output_mapping) =
502 parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
503
504 let emit = EmitKind::Emit(Emit { output_mapping });
505 let common = RelCommon {
506 emit_kind: Some(emit),
507 ..Default::default()
508 };
509
510 Ok(AggregateRel {
511 input: Some(input),
512 grouping_expressions,
513 groupings,
514 measures,
515 common: Some(common),
516 advanced_extension: None,
517 })
518 }
519}
520
521fn parse_aggregate_measures(
526 extensions: &SimpleExtensions,
527 output_pair: Pair<'_, Rule>,
528 grouping_expressions: &[Expression],
529) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
530 assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
531 let mut measures = Vec::new();
532 let mut output_mapping = Vec::new();
533
534 for aggregate_output_item in output_pair.into_inner() {
535 let inner_item = unwrap_single_pair(aggregate_output_item);
536 match inner_item.as_rule() {
537 Rule::reference => {
538 let field_index = FieldIndex::parse_pair(inner_item);
539 output_mapping.push(field_index.0);
540 }
541 Rule::aggregate_measure => {
542 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
543 output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
544 measures.push(measure);
545 }
546 _ => panic!(
547 "Unexpected inner output item rule: {:?}",
548 inner_item.as_rule()
549 ),
550 }
551 }
552
553 Ok((measures, output_mapping))
554}
555
556fn parse_grouping_sets(
565 extensions: &SimpleExtensions,
566 inner: Pair<'_, Rule>,
567) -> Vec<Vec<Expression>> {
568 assert!(
569 matches!(
570 inner.as_rule(),
571 Rule::expression_list | Rule::grouping_set_list
572 ),
573 "Expected expression_list or grouping_set_list, got {:?}",
574 inner.as_rule()
575 );
576 match inner.as_rule() {
577 Rule::expression_list => {
578 vec![parse_expression_list(extensions, inner)]
579 }
580 Rule::grouping_set_list => inner
581 .into_inner()
582 .map(|pair| parse_grouping_set(extensions, pair))
583 .collect(),
584 _ => unreachable!(
585 "Unexpected rule in aggregate_group_by: {:?}",
586 inner.as_rule()
587 ),
588 }
589}
590
591fn parse_grouping_set(extensions: &SimpleExtensions, pair: Pair<'_, Rule>) -> Vec<Expression> {
595 assert_eq!(pair.as_rule(), Rule::grouping_set);
596 let inner = pair
597 .into_inner()
598 .next()
599 .expect("grouping_set must have one inner item");
600 match inner.as_rule() {
601 Rule::empty => vec![],
602 Rule::expression_list => parse_expression_list(extensions, inner),
603 _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
604 }
605}
606
607fn parse_expression_list(extensions: &SimpleExtensions, pair: Pair<'_, Rule>) -> Vec<Expression> {
609 pair.into_inner()
610 .map(|expr_pair| {
611 Expression::parse_pair(extensions, expr_pair)
612 .expect("By the grammar rule, only expressions should be parsed")
613 })
614 .collect()
615}
616
617fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
621 let mut expressions: Vec<Expression> = Vec::new();
622 let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
623
624 let groupings = expression_sets
625 .iter()
626 .map(|set| {
627 let expression_references = set
628 .iter()
629 .map(|exp| {
630 let key = exp.encode_to_vec();
634 let next_idx = expressions.len() as u32;
635 *seen.entry(key).or_insert_with(|| {
636 expressions.push(exp.clone());
637 next_idx
638 })
639 })
640 .collect();
641 Grouping {
642 expression_references,
643 #[allow(deprecated)]
644 grouping_expressions: vec![],
645 }
646 })
647 .collect();
648
649 (groupings, expressions)
650}
651
652impl ScopedParsePair for SortField {
653 fn rule() -> Rule {
654 Rule::sort_field
655 }
656
657 fn message() -> &'static str {
658 "SortField"
659 }
660
661 fn parse_pair(
662 _extensions: &SimpleExtensions,
663 pair: Pair<Rule>,
664 ) -> Result<Self, MessageParseError> {
665 assert_eq!(pair.as_rule(), Self::rule());
666 let mut iter = RuleIter::from(pair.into_inner());
667 let reference_pair = iter.pop(Rule::reference);
668 let field_index = FieldIndex::parse_pair(reference_pair);
669 let direction_pair = iter.pop(Rule::sort_direction);
670 let direction = match direction_pair.as_str().trim_start_matches('&') {
674 "AscNullsFirst" => SortDirection::AscNullsFirst,
675 "AscNullsLast" => SortDirection::AscNullsLast,
676 "DescNullsFirst" => SortDirection::DescNullsFirst,
677 "DescNullsLast" => SortDirection::DescNullsLast,
678 other => {
679 return Err(MessageParseError::invalid(
680 "SortDirection",
681 direction_pair.as_span(),
682 format!("Unknown sort direction: {other}"),
683 ));
684 }
685 };
686 iter.done();
687 Ok(SortField {
688 expr: Some(Expression {
689 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
690 field_index.to_field_reference(),
691 ))),
692 }),
693 sort_kind: Some(SortKind::Direction(direction as i32)),
695 })
696 }
697}
698
699impl RelationParsePair for SortRel {
700 fn rule() -> Rule {
701 Rule::sort_relation
702 }
703
704 fn message() -> &'static str {
705 "SortRel"
706 }
707
708 fn into_rel(self) -> Rel {
709 Rel {
710 rel_type: Some(RelType::Sort(Box::new(self))),
711 }
712 }
713
714 fn parse_pair_with_context(
715 extensions: &SimpleExtensions,
716 pair: Pair<Rule>,
717 input_children: Vec<Box<Rel>>,
718 _input_field_count: usize,
719 ) -> Result<Self, MessageParseError> {
720 assert_eq!(pair.as_rule(), Self::rule());
721 let input = expect_one_child(Self::message(), &pair, input_children)?;
722 let mut iter = RuleIter::from(pair.into_inner());
723 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
724 let reference_list_pair = iter.pop(Rule::reference_list);
725 let mut sorts = Vec::new();
726 for sort_field_pair in sort_field_list_pair.into_inner() {
727 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
728 sorts.push(sort_field);
729 }
730 let emit = parse_reference_emit(reference_list_pair);
731 let common = RelCommon {
732 emit_kind: Some(emit),
733 ..Default::default()
734 };
735 iter.done();
736 Ok(SortRel {
737 input: Some(input),
738 sorts,
739 common: Some(common),
740 advanced_extension: None,
741 })
742 }
743}
744
745impl ScopedParsePair for CountMode {
746 fn rule() -> Rule {
747 Rule::fetch_value
748 }
749 fn message() -> &'static str {
750 "CountMode"
751 }
752 fn parse_pair(
753 extensions: &SimpleExtensions,
754 pair: Pair<Rule>,
755 ) -> Result<Self, MessageParseError> {
756 assert_eq!(pair.as_rule(), Self::rule());
757 let mut arg_inner = RuleIter::from(pair.into_inner());
758 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
759 int_pair
760 } else {
761 arg_inner.pop(Rule::expression)
762 };
763 match value_pair.as_rule() {
764 Rule::integer => {
765 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
766 MessageParseError::invalid(
767 Self::message(),
768 value_pair.as_span(),
769 format!("Invalid integer: {e}"),
770 )
771 })?;
772 if value < 0 {
773 return Err(MessageParseError::invalid(
774 Self::message(),
775 value_pair.as_span(),
776 format!("Fetch limit must be non-negative, got: {value}"),
777 ));
778 }
779 Ok(CountMode::CountExpr(i64_literal_expr(value)))
780 }
781 Rule::expression => {
782 let expr = Expression::parse_pair(extensions, value_pair)?;
783 Ok(CountMode::CountExpr(Box::new(expr)))
784 }
785 _ => Err(MessageParseError::invalid(
786 Self::message(),
787 value_pair.as_span(),
788 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
789 )),
790 }
791 }
792}
793
794fn i64_literal_expr(value: i64) -> Box<Expression> {
795 Box::new(Expression {
796 rex_type: Some(RexType::Literal(Literal {
797 nullable: false,
798 type_variation_reference: 0,
799 literal_type: Some(LiteralType::I64(value)),
800 })),
801 })
802}
803
804impl ScopedParsePair for OffsetMode {
805 fn rule() -> Rule {
806 Rule::fetch_value
807 }
808 fn message() -> &'static str {
809 "OffsetMode"
810 }
811 fn parse_pair(
812 extensions: &SimpleExtensions,
813 pair: Pair<Rule>,
814 ) -> Result<Self, MessageParseError> {
815 assert_eq!(pair.as_rule(), Self::rule());
816 let mut arg_inner = RuleIter::from(pair.into_inner());
817 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
818 int_pair
819 } else {
820 arg_inner.pop(Rule::expression)
821 };
822 match value_pair.as_rule() {
823 Rule::integer => {
824 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
825 MessageParseError::invalid(
826 Self::message(),
827 value_pair.as_span(),
828 format!("Invalid integer: {e}"),
829 )
830 })?;
831 if value < 0 {
832 return Err(MessageParseError::invalid(
833 Self::message(),
834 value_pair.as_span(),
835 format!("Fetch offset must be non-negative, got: {value}"),
836 ));
837 }
838 Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
839 }
840 Rule::expression => {
841 let expr = Expression::parse_pair(extensions, value_pair)?;
842 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
843 }
844 _ => Err(MessageParseError::invalid(
845 Self::message(),
846 value_pair.as_span(),
847 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
848 )),
849 }
850 }
851}
852
853impl RelationParsePair for FetchRel {
854 fn rule() -> Rule {
855 Rule::fetch_relation
856 }
857
858 fn message() -> &'static str {
859 "FetchRel"
860 }
861
862 fn into_rel(self) -> Rel {
863 Rel {
864 rel_type: Some(RelType::Fetch(Box::new(self))),
865 }
866 }
867
868 fn parse_pair_with_context(
869 extensions: &SimpleExtensions,
870 pair: Pair<Rule>,
871 input_children: Vec<Box<Rel>>,
872 _input_field_count: usize,
873 ) -> Result<Self, MessageParseError> {
874 assert_eq!(pair.as_rule(), Self::rule());
875 let input = expect_one_child(Self::message(), &pair, input_children)?;
876 let mut iter = RuleIter::from(pair.into_inner());
877
878 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
880 None => {
881 iter.pop(Rule::empty);
883 (None, None)
884 }
885 Some(fetch_args_pair) => {
886 let extractor =
887 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
888 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
889 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
890 extractor.done()?;
891 (limit_pair, offset_pair)
892 }
893 };
894
895 let reference_list_pair = iter.pop(Rule::reference_list);
896 let emit = parse_reference_emit(reference_list_pair);
897 let common = RelCommon {
898 emit_kind: Some(emit),
899 ..Default::default()
900 };
901 iter.done();
902
903 let count_mode = limit_pair
905 .map(|pair| CountMode::parse_pair(extensions, pair))
906 .transpose()?;
907 let offset_mode = offset_pair
908 .map(|pair| OffsetMode::parse_pair(extensions, pair))
909 .transpose()?;
910 Ok(FetchRel {
911 input: Some(input),
912 common: Some(common),
913 advanced_extension: None,
914 offset_mode,
915 count_mode,
916 })
917 }
918}
919
920impl ParsePair for join_rel::JoinType {
921 fn rule() -> Rule {
922 Rule::join_type
923 }
924
925 fn message() -> &'static str {
926 "JoinType"
927 }
928
929 fn parse_pair(pair: Pair<Rule>) -> Self {
930 assert_eq!(pair.as_rule(), Self::rule());
931 let join_type_str = pair.as_str().trim_start_matches('&');
932 match join_type_str {
933 "Inner" => join_rel::JoinType::Inner,
934 "Left" => join_rel::JoinType::Left,
935 "Right" => join_rel::JoinType::Right,
936 "Outer" => join_rel::JoinType::Outer,
937 "LeftSemi" => join_rel::JoinType::LeftSemi,
938 "RightSemi" => join_rel::JoinType::RightSemi,
939 "LeftAnti" => join_rel::JoinType::LeftAnti,
940 "RightAnti" => join_rel::JoinType::RightAnti,
941 "LeftSingle" => join_rel::JoinType::LeftSingle,
942 "RightSingle" => join_rel::JoinType::RightSingle,
943 "LeftMark" => join_rel::JoinType::LeftMark,
944 "RightMark" => join_rel::JoinType::RightMark,
945 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
946 }
947 }
948}
949
950impl RelationParsePair for JoinRel {
951 fn rule() -> Rule {
952 Rule::join_relation
953 }
954
955 fn message() -> &'static str {
956 "JoinRel"
957 }
958
959 fn into_rel(self) -> Rel {
960 Rel {
961 rel_type: Some(RelType::Join(Box::new(self))),
962 }
963 }
964
965 fn parse_pair_with_context(
966 extensions: &SimpleExtensions,
967 pair: Pair<Rule>,
968 input_children: Vec<Box<Rel>>,
969 _input_field_count: usize,
970 ) -> Result<Self, MessageParseError> {
971 assert_eq!(pair.as_rule(), Self::rule());
972
973 if input_children.len() != 2 {
975 return Err(MessageParseError::invalid(
976 Self::message(),
977 pair.as_span(),
978 format!(
979 "JoinRel should have exactly 2 input children, got {}",
980 input_children.len()
981 ),
982 ));
983 }
984
985 let mut children_iter = input_children.into_iter();
986 let left = children_iter.next().unwrap();
987 let right = children_iter.next().unwrap();
988
989 let mut iter = RuleIter::from(pair.into_inner());
990
991 let join_type = iter.parse_next::<join_rel::JoinType>();
993
994 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
996
997 let reference_list_pair = iter.pop(Rule::reference_list);
999 iter.done();
1000
1001 let emit = parse_reference_emit(reference_list_pair);
1002 let common = RelCommon {
1003 emit_kind: Some(emit),
1004 ..Default::default()
1005 };
1006
1007 Ok(JoinRel {
1008 common: Some(common),
1009 left: Some(left),
1010 right: Some(right),
1011 expression: Some(Box::new(condition)),
1012 post_join_filter: None, r#type: join_type as i32,
1014 advanced_extension: None,
1015 })
1016 }
1017}
1018
1019#[cfg(test)]
1020mod tests {
1021 use pest::Parser;
1022
1023 use super::*;
1024 use crate::fixtures::TestContext;
1025 use crate::parser::{ExpressionParser, Rule};
1026
1027 #[test]
1028 fn test_parse_relation() {
1029 }
1031
1032 #[test]
1033 fn test_parse_read_relation() {
1034 let extensions = SimpleExtensions::default();
1035 let read = ReadRel::parse_pair_with_context(
1036 &extensions,
1037 parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
1038 vec![],
1039 0,
1040 )
1041 .unwrap();
1042 let names = match &read.read_type {
1043 Some(read_rel::ReadType::NamedTable(table)) => &table.names,
1044 _ => panic!("Expected NamedTable"),
1045 };
1046 assert_eq!(names, &["ab", "cd", "ef"]);
1047 let columns = &read
1048 .base_schema
1049 .as_ref()
1050 .unwrap()
1051 .r#struct
1052 .as_ref()
1053 .unwrap()
1054 .types;
1055 assert_eq!(columns.len(), 2);
1056 }
1057
1058 fn example_read_relation() -> ReadRel {
1060 let extensions = SimpleExtensions::default();
1061 ReadRel::parse_pair_with_context(
1062 &extensions,
1063 parse_exact(
1064 Rule::read_relation,
1065 "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
1066 ),
1067 vec![],
1068 0,
1069 )
1070 .unwrap()
1071 }
1072
1073 #[test]
1074 fn test_parse_filter_relation() {
1075 let extensions = SimpleExtensions::default();
1076 let filter = FilterRel::parse_pair_with_context(
1077 &extensions,
1078 parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
1079 vec![Box::new(example_read_relation().into_rel())],
1080 3,
1081 )
1082 .unwrap();
1083 let emit_kind = &filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1084 let emit = match emit_kind {
1085 EmitKind::Emit(emit) => &emit.output_mapping,
1086 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1087 };
1088 assert_eq!(emit, &[0, 1, 2]);
1089 }
1090
1091 #[test]
1092 fn test_parse_project_relation() {
1093 let extensions = SimpleExtensions::default();
1094 let project = ProjectRel::parse_pair_with_context(
1095 &extensions,
1096 parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
1097 vec![Box::new(example_read_relation().into_rel())],
1098 3,
1099 )
1100 .unwrap();
1101
1102 assert_eq!(project.expressions.len(), 1);
1104
1105 let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1106 let emit = match emit_kind {
1107 EmitKind::Emit(emit) => &emit.output_mapping,
1108 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1109 };
1110 assert_eq!(emit, &[0, 1, 3]);
1112 }
1113
1114 #[test]
1115 fn test_parse_project_relation_complex() {
1116 let extensions = SimpleExtensions::default();
1117 let project = ProjectRel::parse_pair_with_context(
1118 &extensions,
1119 parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
1120 vec![Box::new(example_read_relation().into_rel())],
1121 5, )
1123 .unwrap();
1124
1125 assert_eq!(project.expressions.len(), 2);
1127
1128 let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1129 let emit = match emit_kind {
1130 EmitKind::Emit(emit) => &emit.output_mapping,
1131 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1132 };
1133 assert_eq!(emit, &[5, 0, 6, 2, 1]);
1136 }
1137
1138 #[test]
1139 fn test_parse_aggregate_relation() {
1140 let extensions = TestContext::new()
1141 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1142 .with_function(1, 10, "sum")
1143 .with_function(1, 11, "count")
1144 .extensions;
1145
1146 let aggregate = AggregateRel::parse_pair_with_context(
1147 &extensions,
1148 parse_exact(
1149 Rule::aggregate_relation,
1150 "Aggregate[($0, $1), _ => sum($2), $0, count($2)]",
1151 ),
1152 vec![Box::new(example_read_relation().into_rel())],
1153 3,
1154 )
1155 .unwrap();
1156
1157 assert_eq!(aggregate.grouping_expressions.len(), 2);
1159 assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
1160 assert_eq!(aggregate.groupings.len(), 2);
1161 assert_eq!(aggregate.measures.len(), 2);
1162
1163 let emit_kind = &aggregate
1164 .common
1165 .as_ref()
1166 .unwrap()
1167 .emit_kind
1168 .as_ref()
1169 .unwrap();
1170 let emit = match emit_kind {
1171 EmitKind::Emit(emit) => &emit.output_mapping,
1172 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1173 };
1174 assert_eq!(emit, &[2, 0, 3]);
1177 }
1178
1179 #[test]
1180 fn test_parse_aggregate_relation_maintain_column_order() {
1181 let extensions = TestContext::new()
1182 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1183 .with_function(1, 10, "sum")
1184 .with_function(1, 11, "count")
1185 .extensions;
1186
1187 let aggregate = AggregateRel::parse_pair_with_context(
1188 &extensions,
1189 parse_exact(
1190 Rule::aggregate_relation,
1191 "Aggregate[$0 => sum($1), $0, count($1)]",
1192 ),
1193 vec![Box::new(example_read_relation().into_rel())],
1194 3,
1195 )
1196 .unwrap();
1197
1198 assert_eq!(aggregate.grouping_expressions.len(), 1);
1200 assert_eq!(aggregate.groupings.len(), 1);
1201 assert_eq!(aggregate.measures.len(), 2);
1202
1203 let emit_kind = &aggregate
1204 .common
1205 .as_ref()
1206 .unwrap()
1207 .emit_kind
1208 .as_ref()
1209 .unwrap();
1210 let emit = match emit_kind {
1211 EmitKind::Emit(emit) => &emit.output_mapping,
1212 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1213 };
1214 assert_eq!(emit, &[1, 0, 2]);
1216 }
1217
1218 #[test]
1219 fn test_parse_aggregate_relation_simple() {
1220 let extensions = TestContext::new()
1221 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1222 .with_function(1, 10, "sum")
1223 .extensions;
1224
1225 let aggregate = AggregateRel::parse_pair_with_context(
1226 &extensions,
1227 parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"),
1228 vec![Box::new(example_read_relation().into_rel())],
1229 3,
1230 )
1231 .unwrap();
1232
1233 assert_eq!(aggregate.grouping_expressions.len(), 2);
1234 assert_eq!(aggregate.groupings.len(), 1);
1235 assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
1237 }
1238
1239 #[test]
1240 fn test_parse_aggregate_relation_global_aggregate() {
1241 let extensions = TestContext::new()
1242 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1243 .with_function(1, 10, "sum")
1244 .with_function(1, 11, "count")
1245 .extensions;
1246
1247 let aggregate = AggregateRel::parse_pair_with_context(
1248 &extensions,
1249 parse_exact(
1250 Rule::aggregate_relation,
1251 "Aggregate[_ => sum($0), count($1)]",
1252 ),
1253 vec![Box::new(example_read_relation().into_rel())],
1254 3,
1255 )
1256 .unwrap();
1257
1258 assert_eq!(aggregate.grouping_expressions.len(), 0);
1260 assert_eq!(aggregate.groupings.len(), 1);
1261 assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
1262 assert_eq!(aggregate.measures.len(), 2);
1263
1264 let emit_kind = &aggregate
1265 .common
1266 .as_ref()
1267 .unwrap()
1268 .emit_kind
1269 .as_ref()
1270 .unwrap();
1271 let emit = match emit_kind {
1272 EmitKind::Emit(emit) => &emit.output_mapping,
1273 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1274 };
1275 assert_eq!(emit, &[0, 1]);
1277 }
1278
1279 #[test]
1280 fn test_parse_aggregate_relation_grouping_sets() {
1281 let extensions = TestContext::new()
1282 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1283 .with_function(1, 11, "count")
1284 .extensions;
1285
1286 let read_rel = ReadRel::parse_pair_with_context(
1287 &extensions,
1288 parse_exact(
1289 Rule::read_relation,
1290 "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
1291 ),
1292 vec![],
1293 0,
1294 )
1295 .unwrap();
1296
1297 let aggregate = AggregateRel::parse_pair_with_context(
1298 &extensions,
1299 parse_exact(
1300 Rule::aggregate_relation,
1301 "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]",
1302 ),
1303 vec![Box::new(read_rel.into_rel())],
1304 4,
1305 )
1306 .unwrap();
1307
1308 assert_eq!(aggregate.grouping_expressions.len(), 3);
1309 assert_eq!(aggregate.groupings.len(), 4);
1310 assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
1312 assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
1314 assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
1316 assert!(aggregate.groupings[3].expression_references.is_empty());
1318 assert_eq!(aggregate.measures.len(), 1);
1319 }
1320
1321 #[test]
1322 fn test_fetch_relation_positive_values() {
1323 let extensions = SimpleExtensions::default();
1324
1325 let fetch_rel = FetchRel::parse_pair_with_context(
1327 &extensions,
1328 parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
1329 vec![Box::new(example_read_relation().into_rel())],
1330 3,
1331 )
1332 .unwrap();
1333
1334 assert_eq!(
1336 fetch_rel.count_mode,
1337 Some(CountMode::CountExpr(i64_literal_expr(10)))
1338 );
1339 assert_eq!(
1340 fetch_rel.offset_mode,
1341 Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
1342 );
1343 }
1344
1345 #[test]
1346 fn test_fetch_relation_negative_limit_rejected() {
1347 let extensions = SimpleExtensions::default();
1348
1349 let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
1351 if let Ok(mut pairs) = parsed_result {
1352 let pair = pairs.next().unwrap();
1353 if pair.as_str() == "Fetch[limit=-5 => $0]" {
1354 let result = FetchRel::parse_pair_with_context(
1356 &extensions,
1357 pair,
1358 vec![Box::new(example_read_relation().into_rel())],
1359 3,
1360 );
1361 assert!(result.is_err());
1362 let error_msg = result.unwrap_err().to_string();
1363 assert!(error_msg.contains("Fetch limit must be non-negative"));
1364 } else {
1365 println!("Grammar prevents negative limit values at parse time");
1368 }
1369 } else {
1370 println!("Grammar prevents negative limit values at parse time");
1372 }
1373 }
1374
1375 #[test]
1376 fn test_fetch_relation_negative_offset_rejected() {
1377 let extensions = SimpleExtensions::default();
1378
1379 let parsed_result =
1381 ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
1382 if let Ok(mut pairs) = parsed_result {
1383 let pair = pairs.next().unwrap();
1384 if pair.as_str() == "Fetch[offset=-10 => $0]" {
1385 let result = FetchRel::parse_pair_with_context(
1387 &extensions,
1388 pair,
1389 vec![Box::new(example_read_relation().into_rel())],
1390 3,
1391 );
1392 assert!(result.is_err());
1393 let error_msg = result.unwrap_err().to_string();
1394 assert!(error_msg.contains("Fetch offset must be non-negative"));
1395 } else {
1396 println!("Grammar prevents negative offset values at parse time");
1398 }
1399 } else {
1400 println!("Grammar prevents negative offset values at parse time");
1402 }
1403 }
1404
1405 #[test]
1406 fn test_parse_join_relation() {
1407 let extensions = TestContext::new()
1408 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1409 .with_function(1, 10, "eq")
1410 .extensions;
1411
1412 let left_rel = example_read_relation().into_rel();
1413 let right_rel = example_read_relation().into_rel();
1414
1415 let join = JoinRel::parse_pair_with_context(
1416 &extensions,
1417 parse_exact(
1418 Rule::join_relation,
1419 "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
1420 ),
1421 vec![Box::new(left_rel), Box::new(right_rel)],
1422 6, )
1424 .unwrap();
1425
1426 assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
1428
1429 assert!(join.left.is_some());
1431 assert!(join.right.is_some());
1432
1433 assert!(join.expression.is_some());
1435
1436 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1437 let emit = match emit_kind {
1438 EmitKind::Emit(emit) => &emit.output_mapping,
1439 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1440 };
1441 assert_eq!(emit, &[0, 1, 3, 4]);
1443 }
1444
1445 #[test]
1446 fn test_parse_join_relation_left_outer() {
1447 let extensions = TestContext::new()
1448 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1449 .with_function(1, 10, "eq")
1450 .extensions;
1451
1452 let left_rel = example_read_relation().into_rel();
1453 let right_rel = example_read_relation().into_rel();
1454
1455 let join = JoinRel::parse_pair_with_context(
1456 &extensions,
1457 parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
1458 vec![Box::new(left_rel), Box::new(right_rel)],
1459 6,
1460 )
1461 .unwrap();
1462
1463 assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
1465
1466 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1467 let emit = match emit_kind {
1468 EmitKind::Emit(emit) => &emit.output_mapping,
1469 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1470 };
1471 assert_eq!(emit, &[0, 1, 2]);
1473 }
1474
1475 #[test]
1476 fn test_parse_join_relation_left_semi() {
1477 let extensions = TestContext::new()
1478 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1479 .with_function(1, 10, "eq")
1480 .extensions;
1481
1482 let left_rel = example_read_relation().into_rel();
1483 let right_rel = example_read_relation().into_rel();
1484
1485 let join = JoinRel::parse_pair_with_context(
1486 &extensions,
1487 parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
1488 vec![Box::new(left_rel), Box::new(right_rel)],
1489 6,
1490 )
1491 .unwrap();
1492
1493 assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
1495
1496 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1497 let emit = match emit_kind {
1498 EmitKind::Emit(emit) => &emit.output_mapping,
1499 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1500 };
1501 assert_eq!(emit, &[0, 1]);
1503 }
1504
1505 #[test]
1506 fn test_parse_join_relation_requires_two_children() {
1507 let extensions = SimpleExtensions::default();
1508
1509 let result = JoinRel::parse_pair_with_context(
1511 &extensions,
1512 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1513 vec![],
1514 0,
1515 );
1516 assert!(result.is_err());
1517
1518 let result = JoinRel::parse_pair_with_context(
1520 &extensions,
1521 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1522 vec![Box::new(example_read_relation().into_rel())],
1523 3,
1524 );
1525 assert!(result.is_err());
1526
1527 let result = JoinRel::parse_pair_with_context(
1529 &extensions,
1530 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1531 vec![
1532 Box::new(example_read_relation().into_rel()),
1533 Box::new(example_read_relation().into_rel()),
1534 Box::new(example_read_relation().into_rel()),
1535 ],
1536 9,
1537 );
1538 assert!(result.is_err());
1539 }
1540
1541 fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
1542 let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1543 assert_eq!(pairs.as_str(), input);
1544 let pair = pairs.next().unwrap();
1545 assert_eq!(pairs.next(), None);
1546 pair
1547 }
1548}