1use std::collections::HashMap;
2
3use substrait::proto::fetch_rel::{CountMode, OffsetMode};
4use substrait::proto::rel::RelType;
5use substrait::proto::rel_common::{Emit, EmitKind};
6use substrait::proto::sort_field::{SortDirection, SortKind};
7use substrait::proto::{
8 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
9 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
10};
11
12use super::{
13 ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
14};
15use crate::extensions::SimpleExtensions;
16use crate::parser::expressions::{FieldIndex, Name};
17
18pub trait RelationParsePair: Sized {
21 fn rule() -> Rule;
22
23 fn message() -> &'static str;
24
25 fn parse_pair_with_context(
33 extensions: &SimpleExtensions,
34 pair: pest::iterators::Pair<Rule>,
35 input_children: Vec<Box<Rel>>,
36 input_field_count: usize,
37 ) -> Result<Self, MessageParseError>;
38
39 fn into_rel(self) -> Rel;
40}
41
42pub struct TableName(Vec<String>);
43
44impl ParsePair for TableName {
45 fn rule() -> Rule {
46 Rule::table_name
47 }
48
49 fn message() -> &'static str {
50 "TableName"
51 }
52
53 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
54 assert_eq!(pair.as_rule(), Self::rule());
55 let pairs = pair.into_inner();
56 let mut names = Vec::with_capacity(pairs.len());
57 let mut iter = RuleIter::from(pairs);
58 while let Some(name) = iter.parse_if_next::<Name>() {
59 names.push(name.0);
60 }
61 iter.done();
62 Self(names)
63 }
64}
65
66#[derive(Debug, Clone)]
67pub struct Column {
68 pub name: String,
69 pub typ: Type,
70}
71
72impl ScopedParsePair for Column {
73 fn rule() -> Rule {
74 Rule::named_column
75 }
76
77 fn message() -> &'static str {
78 "Column"
79 }
80
81 fn parse_pair(
82 extensions: &SimpleExtensions,
83 pair: pest::iterators::Pair<Rule>,
84 ) -> Result<Self, MessageParseError> {
85 assert_eq!(pair.as_rule(), Self::rule());
86 let mut iter = RuleIter::from(pair.into_inner());
87 let name = iter.parse_next::<Name>().0;
88 let typ = iter.parse_next_scoped(extensions)?;
89 iter.done();
90 Ok(Self { name, typ })
91 }
92}
93
94pub struct NamedColumnList(Vec<Column>);
95
96impl ScopedParsePair for NamedColumnList {
97 fn rule() -> Rule {
98 Rule::named_column_list
99 }
100
101 fn message() -> &'static str {
102 "NamedColumnList"
103 }
104
105 fn parse_pair(
106 extensions: &SimpleExtensions,
107 pair: pest::iterators::Pair<Rule>,
108 ) -> Result<Self, MessageParseError> {
109 assert_eq!(pair.as_rule(), Self::rule());
110 let mut columns = Vec::new();
111 for col in pair.into_inner() {
112 columns.push(Column::parse_pair(extensions, col)?);
113 }
114 Ok(Self(columns))
115 }
116}
117
118#[allow(clippy::vec_box)]
123pub(crate) fn expect_one_child(
124 message: &'static str,
125 pair: &pest::iterators::Pair<Rule>,
126 mut input_children: Vec<Box<Rel>>,
127) -> Result<Box<Rel>, MessageParseError> {
128 match input_children.len() {
129 0 => Err(MessageParseError::invalid(
130 message,
131 pair.as_span(),
132 format!("{message} missing child"),
133 )),
134 1 => Ok(input_children.pop().unwrap()),
135 n => Err(MessageParseError::invalid(
136 message,
137 pair.as_span(),
138 format!("{message} should have 1 input child, got {n}"),
139 )),
140 }
141}
142
143fn parse_reference_emit(pair: pest::iterators::Pair<Rule>) -> EmitKind {
145 assert_eq!(pair.as_rule(), Rule::reference_list);
146 let output_mapping = pair
147 .into_inner()
148 .map(|p| FieldIndex::parse_pair(p).0)
149 .collect::<Vec<i32>>();
150 EmitKind::Emit(Emit { output_mapping })
151}
152
153pub struct ParsedNamedArgs<'a> {
159 map: HashMap<&'a str, pest::iterators::Pair<'a, Rule>>,
160}
161
162impl<'a> ParsedNamedArgs<'a> {
163 pub fn new(
164 pairs: pest::iterators::Pairs<'a, Rule>,
165 rule: Rule,
166 ) -> Result<Self, MessageParseError> {
167 let mut map = HashMap::new();
168 for pair in pairs {
169 assert_eq!(pair.as_rule(), rule);
170 let mut inner = pair.clone().into_inner();
171 let name_pair = inner.next().unwrap();
172 let value_pair = inner.next().unwrap();
173 assert_eq!(inner.next(), None);
174 let name = name_pair.as_str();
175 if map.contains_key(name) {
176 return Err(MessageParseError::invalid(
177 "NamedArg",
178 name_pair.as_span(),
179 format!("Duplicate argument: {name}"),
180 ));
181 }
182 map.insert(name, value_pair);
183 }
184 Ok(Self { map })
185 }
186
187 pub fn pop(
191 mut self,
192 name: &str,
193 rule: Rule,
194 ) -> (Self, Option<pest::iterators::Pair<'a, Rule>>) {
195 let pair = self.map.remove(name).inspect(|pair| {
196 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
197 });
198 (self, pair)
199 }
200
201 pub fn done(self) -> Result<(), MessageParseError> {
203 if let Some((name, pair)) = self.map.iter().next() {
204 return Err(MessageParseError::invalid(
205 "NamedArgExtractor",
206 pair.as_span(),
208 format!("Unknown argument: {name}"),
209 ));
210 }
211 Ok(())
212 }
213}
214
215impl RelationParsePair for ReadRel {
216 fn rule() -> Rule {
217 Rule::read_relation
218 }
219
220 fn message() -> &'static str {
221 "ReadRel"
222 }
223
224 fn into_rel(self) -> Rel {
225 Rel {
226 rel_type: Some(RelType::Read(Box::new(self))),
227 }
228 }
229
230 fn parse_pair_with_context(
231 extensions: &SimpleExtensions,
232 pair: pest::iterators::Pair<Rule>,
233 input_children: Vec<Box<Rel>>,
234 input_field_count: usize,
235 ) -> Result<Self, MessageParseError> {
236 assert_eq!(pair.as_rule(), Self::rule());
237 if !input_children.is_empty() {
239 return Err(MessageParseError::invalid(
240 Self::message(),
241 pair.as_span(),
242 "ReadRel should have no input children",
243 ));
244 }
245 if input_field_count != 0 {
246 let error = pest::error::Error::new_from_span(
247 pest::error::ErrorVariant::CustomError {
248 message: "ReadRel should have 0 input fields".to_string(),
249 },
250 pair.as_span(),
251 );
252 return Err(MessageParseError::new(
253 "ReadRel",
254 ErrorKind::InvalidValue,
255 Box::new(error),
256 ));
257 }
258
259 let mut iter = RuleIter::from(pair.into_inner());
260 let table = iter.parse_next::<TableName>().0;
261 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
262 iter.done();
263
264 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
265 let struct_ = r#type::Struct {
266 types,
267 type_variation_reference: 0,
268 nullability: r#type::Nullability::Required as i32,
269 };
270 let named_struct = NamedStruct {
271 names,
272 r#struct: Some(struct_),
273 };
274
275 let read_rel = ReadRel {
276 base_schema: Some(named_struct),
277 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
278 names: table,
279 advanced_extension: None,
280 })),
281 ..Default::default()
282 };
283
284 Ok(read_rel)
285 }
286}
287
288impl RelationParsePair for FilterRel {
289 fn rule() -> Rule {
290 Rule::filter_relation
291 }
292
293 fn message() -> &'static str {
294 "FilterRel"
295 }
296
297 fn into_rel(self) -> Rel {
298 Rel {
299 rel_type: Some(RelType::Filter(Box::new(self))),
300 }
301 }
302
303 fn parse_pair_with_context(
304 extensions: &SimpleExtensions,
305 pair: pest::iterators::Pair<Rule>,
306 input_children: Vec<Box<Rel>>,
307 _input_field_count: usize,
308 ) -> Result<Self, MessageParseError> {
309 assert_eq!(pair.as_rule(), Self::rule());
312 let input = expect_one_child(Self::message(), &pair, input_children)?;
313 let mut iter = RuleIter::from(pair.into_inner());
314 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
316 let references_pair = iter.pop(Rule::reference_list);
318 iter.done();
319
320 let emit = parse_reference_emit(references_pair);
321 let common = RelCommon {
322 emit_kind: Some(emit),
323 ..Default::default()
324 };
325
326 Ok(FilterRel {
327 input: Some(input),
328 condition: Some(Box::new(condition)),
329 common: Some(common),
330 advanced_extension: None,
331 })
332 }
333}
334
335impl RelationParsePair for ProjectRel {
336 fn rule() -> Rule {
337 Rule::project_relation
338 }
339
340 fn message() -> &'static str {
341 "ProjectRel"
342 }
343
344 fn into_rel(self) -> Rel {
345 Rel {
346 rel_type: Some(RelType::Project(Box::new(self))),
347 }
348 }
349
350 fn parse_pair_with_context(
351 extensions: &SimpleExtensions,
352 pair: pest::iterators::Pair<Rule>,
353 input_children: Vec<Box<Rel>>,
354 input_field_count: usize,
355 ) -> Result<Self, MessageParseError> {
356 assert_eq!(pair.as_rule(), Self::rule());
357 let input = expect_one_child(Self::message(), &pair, input_children)?;
358
359 let arguments_pair = unwrap_single_pair(pair);
361
362 let mut expressions = Vec::new();
363 let mut output_mapping = Vec::new();
364
365 for arg in arguments_pair.into_inner() {
367 let inner_arg = crate::parser::unwrap_single_pair(arg);
368 match inner_arg.as_rule() {
369 Rule::reference => {
370 let field_index = FieldIndex::parse_pair(inner_arg);
372 output_mapping.push(field_index.0);
373 }
374 Rule::expression => {
375 let _expr = Expression::parse_pair(extensions, inner_arg)?;
377 expressions.push(_expr);
378 output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
380 }
381 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
382 }
383 }
384
385 let emit = EmitKind::Emit(Emit { output_mapping });
386 let common = RelCommon {
387 emit_kind: Some(emit),
388 ..Default::default()
389 };
390
391 Ok(ProjectRel {
392 input: Some(input),
393 expressions,
394 common: Some(common),
395 advanced_extension: None,
396 })
397 }
398}
399
400impl RelationParsePair for AggregateRel {
401 fn rule() -> Rule {
402 Rule::aggregate_relation
403 }
404
405 fn message() -> &'static str {
406 "AggregateRel"
407 }
408
409 fn into_rel(self) -> Rel {
410 Rel {
411 rel_type: Some(RelType::Aggregate(Box::new(self))),
412 }
413 }
414
415 fn parse_pair_with_context(
416 extensions: &SimpleExtensions,
417 pair: pest::iterators::Pair<Rule>,
418 input_children: Vec<Box<Rel>>,
419 _input_field_count: usize,
420 ) -> Result<Self, MessageParseError> {
421 assert_eq!(pair.as_rule(), Self::rule());
422 let input = expect_one_child(Self::message(), &pair, input_children)?;
423 let mut iter = RuleIter::from(pair.into_inner());
424 let group_by_pair = iter.pop(Rule::aggregate_group_by);
425 let output_pair = iter.pop(Rule::aggregate_output);
426 iter.done();
427 let mut grouping_expressions = Vec::new();
428 for group_by_item in group_by_pair.into_inner() {
429 match group_by_item.as_rule() {
430 Rule::reference => {
431 let field_index = FieldIndex::parse_pair(group_by_item);
432 grouping_expressions.push(Expression {
433 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
434 field_index.to_field_reference(),
435 ))),
436 });
437 }
438 Rule::empty => {
439 }
441 _ => panic!(
442 "Unexpected group-by item rule: {:?}",
443 group_by_item.as_rule()
444 ),
445 }
446 }
447
448 let mut measures = Vec::new();
450 let mut output_mapping = Vec::new();
451 let group_by_count = grouping_expressions.len();
452 let mut measure_count = 0;
453
454 for output_item in output_pair.into_inner() {
455 let inner_item = unwrap_single_pair(output_item);
456 match inner_item.as_rule() {
457 Rule::reference => {
458 let field_index = FieldIndex::parse_pair(inner_item);
459 output_mapping.push(field_index.0);
460 }
461 Rule::aggregate_measure => {
462 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
463 measures.push(measure);
464 output_mapping.push(group_by_count as i32 + measure_count);
465 measure_count += 1;
466 }
467 _ => panic!(
468 "Unexpected inner output item rule: {:?}",
469 inner_item.as_rule()
470 ),
471 }
472 }
473
474 let emit = EmitKind::Emit(Emit { output_mapping });
475 let common = RelCommon {
476 emit_kind: Some(emit),
477 ..Default::default()
478 };
479
480 Ok(AggregateRel {
481 input: Some(input),
482 grouping_expressions,
483 groupings: vec![], measures,
485 common: Some(common),
486 advanced_extension: None,
487 })
488 }
489}
490
491impl ScopedParsePair for SortField {
492 fn rule() -> Rule {
493 Rule::sort_field
494 }
495
496 fn message() -> &'static str {
497 "SortField"
498 }
499
500 fn parse_pair(
501 _extensions: &SimpleExtensions,
502 pair: pest::iterators::Pair<Rule>,
503 ) -> Result<Self, MessageParseError> {
504 assert_eq!(pair.as_rule(), Self::rule());
505 let mut iter = RuleIter::from(pair.into_inner());
506 let reference_pair = iter.pop(Rule::reference);
507 let field_index = FieldIndex::parse_pair(reference_pair);
508 let direction_pair = iter.pop(Rule::sort_direction);
509 let direction = match direction_pair.as_str().trim_start_matches('&') {
513 "AscNullsFirst" => SortDirection::AscNullsFirst,
514 "AscNullsLast" => SortDirection::AscNullsLast,
515 "DescNullsFirst" => SortDirection::DescNullsFirst,
516 "DescNullsLast" => SortDirection::DescNullsLast,
517 other => {
518 return Err(MessageParseError::invalid(
519 "SortDirection",
520 direction_pair.as_span(),
521 format!("Unknown sort direction: {other}"),
522 ));
523 }
524 };
525 iter.done();
526 Ok(SortField {
527 expr: Some(Expression {
528 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
529 field_index.to_field_reference(),
530 ))),
531 }),
532 sort_kind: Some(SortKind::Direction(direction as i32)),
534 })
535 }
536}
537
538impl RelationParsePair for SortRel {
539 fn rule() -> Rule {
540 Rule::sort_relation
541 }
542
543 fn message() -> &'static str {
544 "SortRel"
545 }
546
547 fn into_rel(self) -> Rel {
548 Rel {
549 rel_type: Some(RelType::Sort(Box::new(self))),
550 }
551 }
552
553 fn parse_pair_with_context(
554 extensions: &SimpleExtensions,
555 pair: pest::iterators::Pair<Rule>,
556 input_children: Vec<Box<Rel>>,
557 _input_field_count: usize,
558 ) -> Result<Self, MessageParseError> {
559 assert_eq!(pair.as_rule(), Self::rule());
560 let input = expect_one_child(Self::message(), &pair, input_children)?;
561 let mut iter = RuleIter::from(pair.into_inner());
562 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
563 let reference_list_pair = iter.pop(Rule::reference_list);
564 let mut sorts = Vec::new();
565 for sort_field_pair in sort_field_list_pair.into_inner() {
566 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
567 sorts.push(sort_field);
568 }
569 let emit = parse_reference_emit(reference_list_pair);
570 let common = RelCommon {
571 emit_kind: Some(emit),
572 ..Default::default()
573 };
574 iter.done();
575 Ok(SortRel {
576 input: Some(input),
577 sorts,
578 common: Some(common),
579 advanced_extension: None,
580 })
581 }
582}
583
584impl ScopedParsePair for CountMode {
585 fn rule() -> Rule {
586 Rule::fetch_value
587 }
588 fn message() -> &'static str {
589 "CountMode"
590 }
591 fn parse_pair(
592 extensions: &SimpleExtensions,
593 pair: pest::iterators::Pair<Rule>,
594 ) -> Result<Self, MessageParseError> {
595 assert_eq!(pair.as_rule(), Self::rule());
596 let mut arg_inner = RuleIter::from(pair.into_inner());
597 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
598 int_pair
599 } else {
600 arg_inner.pop(Rule::expression)
601 };
602 match value_pair.as_rule() {
603 Rule::integer => {
604 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
605 MessageParseError::invalid(
606 Self::message(),
607 value_pair.as_span(),
608 format!("Invalid integer: {e}"),
609 )
610 })?;
611 if value < 0 {
612 return Err(MessageParseError::invalid(
613 Self::message(),
614 value_pair.as_span(),
615 format!("Fetch limit must be non-negative, got: {value}"),
616 ));
617 }
618 Ok(CountMode::Count(value))
619 }
620 Rule::expression => {
621 let expr = Expression::parse_pair(extensions, value_pair)?;
622 Ok(CountMode::CountExpr(Box::new(expr)))
623 }
624 _ => Err(MessageParseError::invalid(
625 Self::message(),
626 value_pair.as_span(),
627 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
628 )),
629 }
630 }
631}
632
633impl ScopedParsePair for OffsetMode {
634 fn rule() -> Rule {
635 Rule::fetch_value
636 }
637 fn message() -> &'static str {
638 "OffsetMode"
639 }
640 fn parse_pair(
641 extensions: &SimpleExtensions,
642 pair: pest::iterators::Pair<Rule>,
643 ) -> Result<Self, MessageParseError> {
644 assert_eq!(pair.as_rule(), Self::rule());
645 let mut arg_inner = RuleIter::from(pair.into_inner());
646 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
647 int_pair
648 } else {
649 arg_inner.pop(Rule::expression)
650 };
651 match value_pair.as_rule() {
652 Rule::integer => {
653 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
654 MessageParseError::invalid(
655 Self::message(),
656 value_pair.as_span(),
657 format!("Invalid integer: {e}"),
658 )
659 })?;
660 if value < 0 {
661 return Err(MessageParseError::invalid(
662 Self::message(),
663 value_pair.as_span(),
664 format!("Fetch offset must be non-negative, got: {value}"),
665 ));
666 }
667 Ok(OffsetMode::Offset(value))
668 }
669 Rule::expression => {
670 let expr = Expression::parse_pair(extensions, value_pair)?;
671 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
672 }
673 _ => Err(MessageParseError::invalid(
674 Self::message(),
675 value_pair.as_span(),
676 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
677 )),
678 }
679 }
680}
681
682impl RelationParsePair for FetchRel {
683 fn rule() -> Rule {
684 Rule::fetch_relation
685 }
686
687 fn message() -> &'static str {
688 "FetchRel"
689 }
690
691 fn into_rel(self) -> Rel {
692 Rel {
693 rel_type: Some(RelType::Fetch(Box::new(self))),
694 }
695 }
696
697 fn parse_pair_with_context(
698 extensions: &SimpleExtensions,
699 pair: pest::iterators::Pair<Rule>,
700 input_children: Vec<Box<Rel>>,
701 _input_field_count: usize,
702 ) -> Result<Self, MessageParseError> {
703 assert_eq!(pair.as_rule(), Self::rule());
704 let input = expect_one_child(Self::message(), &pair, input_children)?;
705 let mut iter = RuleIter::from(pair.into_inner());
706
707 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
709 None => {
710 iter.pop(Rule::empty);
712 (None, None)
713 }
714 Some(fetch_args_pair) => {
715 let extractor =
716 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
717 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
718 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
719 extractor.done()?;
720 (limit_pair, offset_pair)
721 }
722 };
723
724 let reference_list_pair = iter.pop(Rule::reference_list);
725 let emit = parse_reference_emit(reference_list_pair);
726 let common = RelCommon {
727 emit_kind: Some(emit),
728 ..Default::default()
729 };
730 iter.done();
731
732 let count_mode = limit_pair
734 .map(|pair| CountMode::parse_pair(extensions, pair))
735 .transpose()?;
736 let offset_mode = offset_pair
737 .map(|pair| OffsetMode::parse_pair(extensions, pair))
738 .transpose()?;
739 Ok(FetchRel {
740 input: Some(input),
741 common: Some(common),
742 advanced_extension: None,
743 offset_mode,
744 count_mode,
745 })
746 }
747}
748
749impl ParsePair for join_rel::JoinType {
750 fn rule() -> Rule {
751 Rule::join_type
752 }
753
754 fn message() -> &'static str {
755 "JoinType"
756 }
757
758 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
759 assert_eq!(pair.as_rule(), Self::rule());
760 let join_type_str = pair.as_str().trim_start_matches('&');
761 match join_type_str {
762 "Inner" => join_rel::JoinType::Inner,
763 "Left" => join_rel::JoinType::Left,
764 "Right" => join_rel::JoinType::Right,
765 "Outer" => join_rel::JoinType::Outer,
766 "LeftSemi" => join_rel::JoinType::LeftSemi,
767 "RightSemi" => join_rel::JoinType::RightSemi,
768 "LeftAnti" => join_rel::JoinType::LeftAnti,
769 "RightAnti" => join_rel::JoinType::RightAnti,
770 "LeftSingle" => join_rel::JoinType::LeftSingle,
771 "RightSingle" => join_rel::JoinType::RightSingle,
772 "LeftMark" => join_rel::JoinType::LeftMark,
773 "RightMark" => join_rel::JoinType::RightMark,
774 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
775 }
776 }
777}
778
779impl RelationParsePair for JoinRel {
780 fn rule() -> Rule {
781 Rule::join_relation
782 }
783
784 fn message() -> &'static str {
785 "JoinRel"
786 }
787
788 fn into_rel(self) -> Rel {
789 Rel {
790 rel_type: Some(RelType::Join(Box::new(self))),
791 }
792 }
793
794 fn parse_pair_with_context(
795 extensions: &SimpleExtensions,
796 pair: pest::iterators::Pair<Rule>,
797 input_children: Vec<Box<Rel>>,
798 _input_field_count: usize,
799 ) -> Result<Self, MessageParseError> {
800 assert_eq!(pair.as_rule(), Self::rule());
801
802 if input_children.len() != 2 {
804 return Err(MessageParseError::invalid(
805 Self::message(),
806 pair.as_span(),
807 format!(
808 "JoinRel should have exactly 2 input children, got {}",
809 input_children.len()
810 ),
811 ));
812 }
813
814 let mut children_iter = input_children.into_iter();
815 let left = children_iter.next().unwrap();
816 let right = children_iter.next().unwrap();
817
818 let mut iter = RuleIter::from(pair.into_inner());
819
820 let join_type = iter.parse_next::<join_rel::JoinType>();
822
823 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
825
826 let reference_list_pair = iter.pop(Rule::reference_list);
828 iter.done();
829
830 let emit = parse_reference_emit(reference_list_pair);
831 let common = RelCommon {
832 emit_kind: Some(emit),
833 ..Default::default()
834 };
835
836 Ok(JoinRel {
837 common: Some(common),
838 left: Some(left),
839 right: Some(right),
840 expression: Some(Box::new(condition)),
841 post_join_filter: None, r#type: join_type as i32,
843 advanced_extension: None,
844 })
845 }
846}
847
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    /// Extension URI registered by the aggregate tests.
    const AGGREGATE_FUNCTIONS_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml";
    /// Extension URI registered by the join tests.
    const COMPARISON_FUNCTIONS_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml";

    /// Parses `input` with `rule`, requiring that the whole input is consumed
    /// and that exactly one top-level pair is produced.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Three-column read relation used as the input child of the other tests.
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    /// Returns the explicit output mapping stored in `common`.
    ///
    /// Panics when `common` or its emit kind is missing, or when the emit kind
    /// is not `EmitKind::Emit`.
    fn output_mapping(common: &Option<RelCommon>) -> &[i32] {
        let emit_kind = common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        }
    }

    // NOTE(review): this test has no body and therefore checks nothing;
    // either give it assertions or remove it.
    #[test]
    fn test_parse_relation() {}

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();
        assert_eq!(output_mapping(&filter.common), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // Only the literal is an expression; it is numbered after the 3 input fields.
        assert_eq!(project.expressions.len(), 1);
        assert_eq!(output_mapping(&project.common), &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![Box::new(example_read_relation().into_rel())],
            // The parser is told there are 5 input fields, so expressions start at 5.
            5,
        )
        .unwrap();

        assert_eq!(project.expressions.len(), 2);
        assert_eq!(output_mapping(&project.common), &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = TestContext::new()
            .with_uri(1, AGGREGATE_FUNCTIONS_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0, $1 => sum($2), $0, count($2)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);
        // Measures are numbered after the two grouping expressions.
        assert_eq!(output_mapping(&aggregate.common), &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = TestContext::new()
            .with_uri(1, AGGREGATE_FUNCTIONS_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);
        assert_eq!(output_mapping(&aggregate.common), &[1, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_no_group_by() {
        let extensions = TestContext::new()
            .with_uri(1, AGGREGATE_FUNCTIONS_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);
        assert_eq!(output_mapping(&aggregate.common), &[0, 1]);
    }

    // NOTE(review): this test uses the same input and assertions as
    // `test_parse_aggregate_relation_no_group_by`; confirm whether a different
    // input was intended.
    #[test]
    fn test_parse_aggregate_relation_empty_group_by() {
        let extensions = TestContext::new()
            .with_uri(1, AGGREGATE_FUNCTIONS_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);
        assert_eq!(output_mapping(&aggregate.common), &[0, 1]);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(fetch_rel.count_mode, Some(CountMode::Count(10)));
        assert_eq!(fetch_rel.offset_mode, Some(OffsetMode::Offset(5)));
    }

    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        let extensions = SimpleExtensions::default();
        let input = "Fetch[limit=-5 => $0]";

        // The relation parser is only exercised when the grammar accepts the
        // whole input; otherwise the grammar itself has rejected the value.
        let Some(pair) = ExpressionParser::parse(Rule::fetch_relation, input)
            .ok()
            .and_then(|mut pairs| pairs.next())
            .filter(|pair| pair.as_str() == input)
        else {
            println!("Grammar prevents negative limit values at parse time");
            return;
        };

        let result = FetchRel::parse_pair_with_context(
            &extensions,
            pair,
            vec![Box::new(example_read_relation().into_rel())],
            3,
        );
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Fetch limit must be non-negative"));
    }

    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        let extensions = SimpleExtensions::default();
        let input = "Fetch[offset=-10 => $0]";

        // The relation parser is only exercised when the grammar accepts the
        // whole input; otherwise the grammar itself has rejected the value.
        let Some(pair) = ExpressionParser::parse(Rule::fetch_relation, input)
            .ok()
            .and_then(|mut pairs| pairs.next())
            .filter(|pair| pair.as_str() == input)
        else {
            println!("Grammar prevents negative offset values at parse time");
            return;
        };

        let result = FetchRel::parse_pair_with_context(
            &extensions,
            pair,
            vec![Box::new(example_read_relation().into_rel())],
            3,
        );
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Fetch offset must be non-negative"));
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = TestContext::new()
            .with_uri(1, COMPARISON_FUNCTIONS_URI)
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![Box::new(left_rel), Box::new(right_rel)],
            // 3 fields from each side.
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
        assert!(join.left.is_some());
        assert!(join.right.is_some());
        assert!(join.expression.is_some());
        assert_eq!(output_mapping(&join.common), &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = TestContext::new()
            .with_uri(1, COMPARISON_FUNCTIONS_URI)
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
        assert_eq!(output_mapping(&join.common), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = TestContext::new()
            .with_uri(1, COMPARISON_FUNCTIONS_URI)
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
        assert_eq!(output_mapping(&join.common), &[0, 1]);
    }

    #[test]
    fn test_parse_join_relation_requires_two_children() {
        let extensions = SimpleExtensions::default();

        // Zero, one and three children must all be rejected; each child
        // contributes 3 input fields.
        for child_count in [0usize, 1, 3] {
            let children = (0..child_count)
                .map(|_| Box::new(example_read_relation().into_rel()))
                .collect();
            let result = JoinRel::parse_pair_with_context(
                &extensions,
                parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
                children,
                child_count * 3,
            );
            assert!(
                result.is_err(),
                "join with {child_count} children should be rejected"
            );
        }
    }
}