1use std::collections::HashMap;
2
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{Literal, RexType};
5use substrait::proto::fetch_rel::{CountMode, OffsetMode};
6use substrait::proto::rel::RelType;
7use substrait::proto::rel_common::{Emit, EmitKind};
8use substrait::proto::sort_field::{SortDirection, SortKind};
9use substrait::proto::{
10 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
11 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
12};
13
14use super::{
15 ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
16};
17use crate::extensions::SimpleExtensions;
18use crate::parser::expressions::{FieldIndex, Name};
19
20pub trait RelationParsePair: Sized {
23 fn rule() -> Rule;
24
25 fn message() -> &'static str;
26
27 fn parse_pair_with_context(
35 extensions: &SimpleExtensions,
36 pair: pest::iterators::Pair<Rule>,
37 input_children: Vec<Box<Rel>>,
38 input_field_count: usize,
39 ) -> Result<Self, MessageParseError>;
40
41 fn into_rel(self) -> Rel;
42}
43
44pub struct TableName(Vec<String>);
45
46impl ParsePair for TableName {
47 fn rule() -> Rule {
48 Rule::table_name
49 }
50
51 fn message() -> &'static str {
52 "TableName"
53 }
54
55 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
56 assert_eq!(pair.as_rule(), Self::rule());
57 let pairs = pair.into_inner();
58 let mut names = Vec::with_capacity(pairs.len());
59 let mut iter = RuleIter::from(pairs);
60 while let Some(name) = iter.parse_if_next::<Name>() {
61 names.push(name.0);
62 }
63 iter.done();
64 Self(names)
65 }
66}
67
68#[derive(Debug, Clone)]
69pub struct Column {
70 pub name: String,
71 pub typ: Type,
72}
73
74impl ScopedParsePair for Column {
75 fn rule() -> Rule {
76 Rule::named_column
77 }
78
79 fn message() -> &'static str {
80 "Column"
81 }
82
83 fn parse_pair(
84 extensions: &SimpleExtensions,
85 pair: pest::iterators::Pair<Rule>,
86 ) -> Result<Self, MessageParseError> {
87 assert_eq!(pair.as_rule(), Self::rule());
88 let mut iter = RuleIter::from(pair.into_inner());
89 let name = iter.parse_next::<Name>().0;
90 let typ = iter.parse_next_scoped(extensions)?;
91 iter.done();
92 Ok(Self { name, typ })
93 }
94}
95
96pub struct NamedColumnList(Vec<Column>);
97
98impl ScopedParsePair for NamedColumnList {
99 fn rule() -> Rule {
100 Rule::named_column_list
101 }
102
103 fn message() -> &'static str {
104 "NamedColumnList"
105 }
106
107 fn parse_pair(
108 extensions: &SimpleExtensions,
109 pair: pest::iterators::Pair<Rule>,
110 ) -> Result<Self, MessageParseError> {
111 assert_eq!(pair.as_rule(), Self::rule());
112 let mut columns = Vec::new();
113 for col in pair.into_inner() {
114 columns.push(Column::parse_pair(extensions, col)?);
115 }
116 Ok(Self(columns))
117 }
118}
119
120#[allow(clippy::vec_box)]
125pub(crate) fn expect_one_child(
126 message: &'static str,
127 pair: &pest::iterators::Pair<Rule>,
128 mut input_children: Vec<Box<Rel>>,
129) -> Result<Box<Rel>, MessageParseError> {
130 match input_children.len() {
131 0 => Err(MessageParseError::invalid(
132 message,
133 pair.as_span(),
134 format!("{message} missing child"),
135 )),
136 1 => Ok(input_children.pop().unwrap()),
137 n => Err(MessageParseError::invalid(
138 message,
139 pair.as_span(),
140 format!("{message} should have 1 input child, got {n}"),
141 )),
142 }
143}
144
145fn parse_reference_emit(pair: pest::iterators::Pair<Rule>) -> EmitKind {
147 assert_eq!(pair.as_rule(), Rule::reference_list);
148 let output_mapping = pair
149 .into_inner()
150 .map(|p| FieldIndex::parse_pair(p).0)
151 .collect::<Vec<i32>>();
152 EmitKind::Emit(Emit { output_mapping })
153}
154
155pub struct ParsedNamedArgs<'a> {
161 map: HashMap<&'a str, pest::iterators::Pair<'a, Rule>>,
162}
163
164impl<'a> ParsedNamedArgs<'a> {
165 pub fn new(
166 pairs: pest::iterators::Pairs<'a, Rule>,
167 rule: Rule,
168 ) -> Result<Self, MessageParseError> {
169 let mut map = HashMap::new();
170 for pair in pairs {
171 assert_eq!(pair.as_rule(), rule);
172 let mut inner = pair.clone().into_inner();
173 let name_pair = inner.next().unwrap();
174 let value_pair = inner.next().unwrap();
175 assert_eq!(inner.next(), None);
176 let name = name_pair.as_str();
177 if map.contains_key(name) {
178 return Err(MessageParseError::invalid(
179 "NamedArg",
180 name_pair.as_span(),
181 format!("Duplicate argument: {name}"),
182 ));
183 }
184 map.insert(name, value_pair);
185 }
186 Ok(Self { map })
187 }
188
189 pub fn pop(
193 mut self,
194 name: &str,
195 rule: Rule,
196 ) -> (Self, Option<pest::iterators::Pair<'a, Rule>>) {
197 let pair = self.map.remove(name).inspect(|pair| {
198 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
199 });
200 (self, pair)
201 }
202
203 pub fn done(self) -> Result<(), MessageParseError> {
205 if let Some((name, pair)) = self.map.iter().next() {
206 return Err(MessageParseError::invalid(
207 "NamedArgExtractor",
208 pair.as_span(),
210 format!("Unknown argument: {name}"),
211 ));
212 }
213 Ok(())
214 }
215}
216
217impl RelationParsePair for ReadRel {
218 fn rule() -> Rule {
219 Rule::read_relation
220 }
221
222 fn message() -> &'static str {
223 "ReadRel"
224 }
225
226 fn into_rel(self) -> Rel {
227 Rel {
228 rel_type: Some(RelType::Read(Box::new(self))),
229 }
230 }
231
232 fn parse_pair_with_context(
233 extensions: &SimpleExtensions,
234 pair: pest::iterators::Pair<Rule>,
235 input_children: Vec<Box<Rel>>,
236 input_field_count: usize,
237 ) -> Result<Self, MessageParseError> {
238 assert_eq!(pair.as_rule(), Self::rule());
239 if !input_children.is_empty() {
241 return Err(MessageParseError::invalid(
242 Self::message(),
243 pair.as_span(),
244 "ReadRel should have no input children",
245 ));
246 }
247 if input_field_count != 0 {
248 let error = pest::error::Error::new_from_span(
249 pest::error::ErrorVariant::CustomError {
250 message: "ReadRel should have 0 input fields".to_string(),
251 },
252 pair.as_span(),
253 );
254 return Err(MessageParseError::new(
255 "ReadRel",
256 ErrorKind::InvalidValue,
257 Box::new(error),
258 ));
259 }
260
261 let mut iter = RuleIter::from(pair.into_inner());
262 let table = iter.parse_next::<TableName>().0;
263 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
264 iter.done();
265
266 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
267 let struct_ = r#type::Struct {
268 types,
269 type_variation_reference: 0,
270 nullability: r#type::Nullability::Required as i32,
271 };
272 let named_struct = NamedStruct {
273 names,
274 r#struct: Some(struct_),
275 };
276
277 let read_rel = ReadRel {
278 base_schema: Some(named_struct),
279 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
280 names: table,
281 advanced_extension: None,
282 })),
283 ..Default::default()
284 };
285
286 Ok(read_rel)
287 }
288}
289
290impl RelationParsePair for FilterRel {
291 fn rule() -> Rule {
292 Rule::filter_relation
293 }
294
295 fn message() -> &'static str {
296 "FilterRel"
297 }
298
299 fn into_rel(self) -> Rel {
300 Rel {
301 rel_type: Some(RelType::Filter(Box::new(self))),
302 }
303 }
304
305 fn parse_pair_with_context(
306 extensions: &SimpleExtensions,
307 pair: pest::iterators::Pair<Rule>,
308 input_children: Vec<Box<Rel>>,
309 _input_field_count: usize,
310 ) -> Result<Self, MessageParseError> {
311 assert_eq!(pair.as_rule(), Self::rule());
314 let input = expect_one_child(Self::message(), &pair, input_children)?;
315 let mut iter = RuleIter::from(pair.into_inner());
316 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
318 let references_pair = iter.pop(Rule::reference_list);
320 iter.done();
321
322 let emit = parse_reference_emit(references_pair);
323 let common = RelCommon {
324 emit_kind: Some(emit),
325 ..Default::default()
326 };
327
328 Ok(FilterRel {
329 input: Some(input),
330 condition: Some(Box::new(condition)),
331 common: Some(common),
332 advanced_extension: None,
333 })
334 }
335}
336
337impl RelationParsePair for ProjectRel {
338 fn rule() -> Rule {
339 Rule::project_relation
340 }
341
342 fn message() -> &'static str {
343 "ProjectRel"
344 }
345
346 fn into_rel(self) -> Rel {
347 Rel {
348 rel_type: Some(RelType::Project(Box::new(self))),
349 }
350 }
351
352 fn parse_pair_with_context(
353 extensions: &SimpleExtensions,
354 pair: pest::iterators::Pair<Rule>,
355 input_children: Vec<Box<Rel>>,
356 input_field_count: usize,
357 ) -> Result<Self, MessageParseError> {
358 assert_eq!(pair.as_rule(), Self::rule());
359 let input = expect_one_child(Self::message(), &pair, input_children)?;
360
361 let arguments_pair = unwrap_single_pair(pair);
363
364 let mut expressions = Vec::new();
365 let mut output_mapping = Vec::new();
366
367 for arg in arguments_pair.into_inner() {
369 let inner_arg = crate::parser::unwrap_single_pair(arg);
370 match inner_arg.as_rule() {
371 Rule::reference => {
372 let field_index = FieldIndex::parse_pair(inner_arg);
374 output_mapping.push(field_index.0);
375 }
376 Rule::expression => {
377 let _expr = Expression::parse_pair(extensions, inner_arg)?;
379 expressions.push(_expr);
380 output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
382 }
383 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
384 }
385 }
386
387 let emit = EmitKind::Emit(Emit { output_mapping });
388 let common = RelCommon {
389 emit_kind: Some(emit),
390 ..Default::default()
391 };
392
393 Ok(ProjectRel {
394 input: Some(input),
395 expressions,
396 common: Some(common),
397 advanced_extension: None,
398 })
399 }
400}
401
402impl RelationParsePair for AggregateRel {
403 fn rule() -> Rule {
404 Rule::aggregate_relation
405 }
406
407 fn message() -> &'static str {
408 "AggregateRel"
409 }
410
411 fn into_rel(self) -> Rel {
412 Rel {
413 rel_type: Some(RelType::Aggregate(Box::new(self))),
414 }
415 }
416
417 fn parse_pair_with_context(
418 extensions: &SimpleExtensions,
419 pair: pest::iterators::Pair<Rule>,
420 input_children: Vec<Box<Rel>>,
421 _input_field_count: usize,
422 ) -> Result<Self, MessageParseError> {
423 assert_eq!(pair.as_rule(), Self::rule());
424 let input = expect_one_child(Self::message(), &pair, input_children)?;
425 let mut iter = RuleIter::from(pair.into_inner());
426 let group_by_pair = iter.pop(Rule::aggregate_group_by);
427 let output_pair = iter.pop(Rule::aggregate_output);
428 iter.done();
429 let mut grouping_expressions = Vec::new();
430 for group_by_item in group_by_pair.into_inner() {
431 match group_by_item.as_rule() {
432 Rule::reference => {
433 let field_index = FieldIndex::parse_pair(group_by_item);
434 grouping_expressions.push(Expression {
435 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
436 field_index.to_field_reference(),
437 ))),
438 });
439 }
440 Rule::empty => {
441 }
443 _ => panic!(
444 "Unexpected group-by item rule: {:?}",
445 group_by_item.as_rule()
446 ),
447 }
448 }
449
450 let mut measures = Vec::new();
452 let mut output_mapping = Vec::new();
453 let group_by_count = grouping_expressions.len();
454 let mut measure_count = 0;
455
456 for output_item in output_pair.into_inner() {
457 let inner_item = unwrap_single_pair(output_item);
458 match inner_item.as_rule() {
459 Rule::reference => {
460 let field_index = FieldIndex::parse_pair(inner_item);
461 output_mapping.push(field_index.0);
462 }
463 Rule::aggregate_measure => {
464 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
465 measures.push(measure);
466 output_mapping.push(group_by_count as i32 + measure_count);
467 measure_count += 1;
468 }
469 _ => panic!(
470 "Unexpected inner output item rule: {:?}",
471 inner_item.as_rule()
472 ),
473 }
474 }
475
476 let emit = EmitKind::Emit(Emit { output_mapping });
477 let common = RelCommon {
478 emit_kind: Some(emit),
479 ..Default::default()
480 };
481
482 Ok(AggregateRel {
483 input: Some(input),
484 grouping_expressions,
485 groupings: vec![], measures,
487 common: Some(common),
488 advanced_extension: None,
489 })
490 }
491}
492
493impl ScopedParsePair for SortField {
494 fn rule() -> Rule {
495 Rule::sort_field
496 }
497
498 fn message() -> &'static str {
499 "SortField"
500 }
501
502 fn parse_pair(
503 _extensions: &SimpleExtensions,
504 pair: pest::iterators::Pair<Rule>,
505 ) -> Result<Self, MessageParseError> {
506 assert_eq!(pair.as_rule(), Self::rule());
507 let mut iter = RuleIter::from(pair.into_inner());
508 let reference_pair = iter.pop(Rule::reference);
509 let field_index = FieldIndex::parse_pair(reference_pair);
510 let direction_pair = iter.pop(Rule::sort_direction);
511 let direction = match direction_pair.as_str().trim_start_matches('&') {
515 "AscNullsFirst" => SortDirection::AscNullsFirst,
516 "AscNullsLast" => SortDirection::AscNullsLast,
517 "DescNullsFirst" => SortDirection::DescNullsFirst,
518 "DescNullsLast" => SortDirection::DescNullsLast,
519 other => {
520 return Err(MessageParseError::invalid(
521 "SortDirection",
522 direction_pair.as_span(),
523 format!("Unknown sort direction: {other}"),
524 ));
525 }
526 };
527 iter.done();
528 Ok(SortField {
529 expr: Some(Expression {
530 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
531 field_index.to_field_reference(),
532 ))),
533 }),
534 sort_kind: Some(SortKind::Direction(direction as i32)),
536 })
537 }
538}
539
540impl RelationParsePair for SortRel {
541 fn rule() -> Rule {
542 Rule::sort_relation
543 }
544
545 fn message() -> &'static str {
546 "SortRel"
547 }
548
549 fn into_rel(self) -> Rel {
550 Rel {
551 rel_type: Some(RelType::Sort(Box::new(self))),
552 }
553 }
554
555 fn parse_pair_with_context(
556 extensions: &SimpleExtensions,
557 pair: pest::iterators::Pair<Rule>,
558 input_children: Vec<Box<Rel>>,
559 _input_field_count: usize,
560 ) -> Result<Self, MessageParseError> {
561 assert_eq!(pair.as_rule(), Self::rule());
562 let input = expect_one_child(Self::message(), &pair, input_children)?;
563 let mut iter = RuleIter::from(pair.into_inner());
564 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
565 let reference_list_pair = iter.pop(Rule::reference_list);
566 let mut sorts = Vec::new();
567 for sort_field_pair in sort_field_list_pair.into_inner() {
568 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
569 sorts.push(sort_field);
570 }
571 let emit = parse_reference_emit(reference_list_pair);
572 let common = RelCommon {
573 emit_kind: Some(emit),
574 ..Default::default()
575 };
576 iter.done();
577 Ok(SortRel {
578 input: Some(input),
579 sorts,
580 common: Some(common),
581 advanced_extension: None,
582 })
583 }
584}
585
586impl ScopedParsePair for CountMode {
587 fn rule() -> Rule {
588 Rule::fetch_value
589 }
590 fn message() -> &'static str {
591 "CountMode"
592 }
593 fn parse_pair(
594 extensions: &SimpleExtensions,
595 pair: pest::iterators::Pair<Rule>,
596 ) -> Result<Self, MessageParseError> {
597 assert_eq!(pair.as_rule(), Self::rule());
598 let mut arg_inner = RuleIter::from(pair.into_inner());
599 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
600 int_pair
601 } else {
602 arg_inner.pop(Rule::expression)
603 };
604 match value_pair.as_rule() {
605 Rule::integer => {
606 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
607 MessageParseError::invalid(
608 Self::message(),
609 value_pair.as_span(),
610 format!("Invalid integer: {e}"),
611 )
612 })?;
613 if value < 0 {
614 return Err(MessageParseError::invalid(
615 Self::message(),
616 value_pair.as_span(),
617 format!("Fetch limit must be non-negative, got: {value}"),
618 ));
619 }
620 Ok(CountMode::CountExpr(i64_literal_expr(value)))
621 }
622 Rule::expression => {
623 let expr = Expression::parse_pair(extensions, value_pair)?;
624 Ok(CountMode::CountExpr(Box::new(expr)))
625 }
626 _ => Err(MessageParseError::invalid(
627 Self::message(),
628 value_pair.as_span(),
629 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
630 )),
631 }
632 }
633}
634
635fn i64_literal_expr(value: i64) -> Box<Expression> {
636 Box::new(Expression {
637 rex_type: Some(RexType::Literal(Literal {
638 nullable: false,
639 type_variation_reference: 0,
640 literal_type: Some(LiteralType::I64(value)),
641 })),
642 })
643}
644
645impl ScopedParsePair for OffsetMode {
646 fn rule() -> Rule {
647 Rule::fetch_value
648 }
649 fn message() -> &'static str {
650 "OffsetMode"
651 }
652 fn parse_pair(
653 extensions: &SimpleExtensions,
654 pair: pest::iterators::Pair<Rule>,
655 ) -> Result<Self, MessageParseError> {
656 assert_eq!(pair.as_rule(), Self::rule());
657 let mut arg_inner = RuleIter::from(pair.into_inner());
658 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
659 int_pair
660 } else {
661 arg_inner.pop(Rule::expression)
662 };
663 match value_pair.as_rule() {
664 Rule::integer => {
665 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
666 MessageParseError::invalid(
667 Self::message(),
668 value_pair.as_span(),
669 format!("Invalid integer: {e}"),
670 )
671 })?;
672 if value < 0 {
673 return Err(MessageParseError::invalid(
674 Self::message(),
675 value_pair.as_span(),
676 format!("Fetch offset must be non-negative, got: {value}"),
677 ));
678 }
679 Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
680 }
681 Rule::expression => {
682 let expr = Expression::parse_pair(extensions, value_pair)?;
683 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
684 }
685 _ => Err(MessageParseError::invalid(
686 Self::message(),
687 value_pair.as_span(),
688 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
689 )),
690 }
691 }
692}
693
694impl RelationParsePair for FetchRel {
695 fn rule() -> Rule {
696 Rule::fetch_relation
697 }
698
699 fn message() -> &'static str {
700 "FetchRel"
701 }
702
703 fn into_rel(self) -> Rel {
704 Rel {
705 rel_type: Some(RelType::Fetch(Box::new(self))),
706 }
707 }
708
709 fn parse_pair_with_context(
710 extensions: &SimpleExtensions,
711 pair: pest::iterators::Pair<Rule>,
712 input_children: Vec<Box<Rel>>,
713 _input_field_count: usize,
714 ) -> Result<Self, MessageParseError> {
715 assert_eq!(pair.as_rule(), Self::rule());
716 let input = expect_one_child(Self::message(), &pair, input_children)?;
717 let mut iter = RuleIter::from(pair.into_inner());
718
719 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
721 None => {
722 iter.pop(Rule::empty);
724 (None, None)
725 }
726 Some(fetch_args_pair) => {
727 let extractor =
728 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
729 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
730 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
731 extractor.done()?;
732 (limit_pair, offset_pair)
733 }
734 };
735
736 let reference_list_pair = iter.pop(Rule::reference_list);
737 let emit = parse_reference_emit(reference_list_pair);
738 let common = RelCommon {
739 emit_kind: Some(emit),
740 ..Default::default()
741 };
742 iter.done();
743
744 let count_mode = limit_pair
746 .map(|pair| CountMode::parse_pair(extensions, pair))
747 .transpose()?;
748 let offset_mode = offset_pair
749 .map(|pair| OffsetMode::parse_pair(extensions, pair))
750 .transpose()?;
751 Ok(FetchRel {
752 input: Some(input),
753 common: Some(common),
754 advanced_extension: None,
755 offset_mode,
756 count_mode,
757 })
758 }
759}
760
761impl ParsePair for join_rel::JoinType {
762 fn rule() -> Rule {
763 Rule::join_type
764 }
765
766 fn message() -> &'static str {
767 "JoinType"
768 }
769
770 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
771 assert_eq!(pair.as_rule(), Self::rule());
772 let join_type_str = pair.as_str().trim_start_matches('&');
773 match join_type_str {
774 "Inner" => join_rel::JoinType::Inner,
775 "Left" => join_rel::JoinType::Left,
776 "Right" => join_rel::JoinType::Right,
777 "Outer" => join_rel::JoinType::Outer,
778 "LeftSemi" => join_rel::JoinType::LeftSemi,
779 "RightSemi" => join_rel::JoinType::RightSemi,
780 "LeftAnti" => join_rel::JoinType::LeftAnti,
781 "RightAnti" => join_rel::JoinType::RightAnti,
782 "LeftSingle" => join_rel::JoinType::LeftSingle,
783 "RightSingle" => join_rel::JoinType::RightSingle,
784 "LeftMark" => join_rel::JoinType::LeftMark,
785 "RightMark" => join_rel::JoinType::RightMark,
786 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
787 }
788 }
789}
790
791impl RelationParsePair for JoinRel {
792 fn rule() -> Rule {
793 Rule::join_relation
794 }
795
796 fn message() -> &'static str {
797 "JoinRel"
798 }
799
800 fn into_rel(self) -> Rel {
801 Rel {
802 rel_type: Some(RelType::Join(Box::new(self))),
803 }
804 }
805
806 fn parse_pair_with_context(
807 extensions: &SimpleExtensions,
808 pair: pest::iterators::Pair<Rule>,
809 input_children: Vec<Box<Rel>>,
810 _input_field_count: usize,
811 ) -> Result<Self, MessageParseError> {
812 assert_eq!(pair.as_rule(), Self::rule());
813
814 if input_children.len() != 2 {
816 return Err(MessageParseError::invalid(
817 Self::message(),
818 pair.as_span(),
819 format!(
820 "JoinRel should have exactly 2 input children, got {}",
821 input_children.len()
822 ),
823 ));
824 }
825
826 let mut children_iter = input_children.into_iter();
827 let left = children_iter.next().unwrap();
828 let right = children_iter.next().unwrap();
829
830 let mut iter = RuleIter::from(pair.into_inner());
831
832 let join_type = iter.parse_next::<join_rel::JoinType>();
834
835 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
837
838 let reference_list_pair = iter.pop(Rule::reference_list);
840 iter.done();
841
842 let emit = parse_reference_emit(reference_list_pair);
843 let common = RelCommon {
844 emit_kind: Some(emit),
845 ..Default::default()
846 };
847
848 Ok(JoinRel {
849 common: Some(common),
850 left: Some(left),
851 right: Some(right),
852 expression: Some(Box::new(condition)),
853 post_join_filter: None, r#type: join_type as i32,
855 advanced_extension: None,
856 })
857 }
858}
859
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    // NOTE(review): empty placeholder; it always passes and exercises nothing.
    #[test]
    fn test_parse_relation() {
    }

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        // The dotted table name is split into its components.
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    /// A leaf read with three columns (`a`, `b`, `c`), used as the input
    /// relation for the other tests.
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();
        let emit_kind = &filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1, 2]);
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(project.expressions.len(), 1);

        // The literal 42 is the first new expression, so it is numbered after
        // the 3 input fields: index 3.
        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        // NOTE(review): passes 5 as the input field count although the example
        // read has 3 columns; the expected mapping [5, 0, 6, 2, 1] follows from
        // that 5 (literals 42 and 100 become indices 5 and 6).
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![Box::new(example_read_relation().into_rel())],
            5,
        )
        .unwrap();

        assert_eq!(project.expressions.len(), 2);

        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0, $1 => sum($2), $0, count($2)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);

        // Measures are numbered after the two grouping keys: sum -> 2, count -> 3.
        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[1, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_no_group_by() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        // `_` in the group-by position means there are no grouping keys.
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1]);
    }

    // NOTE(review): identical to `test_parse_aggregate_relation_no_group_by`
    // (same input and assertions); consider removing or varying it.
    #[test]
    fn test_parse_aggregate_relation_empty_group_by() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1]);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // Integer arguments are stored as i64 literal expressions.
        assert_eq!(
            fetch_rel.count_mode,
            Some(CountMode::CountExpr(i64_literal_expr(10)))
        );
        assert_eq!(
            fetch_rel.offset_mode,
            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
        );
    }

    // NOTE(review): if the grammar rejects or only partially matches the input,
    // this test prints a message and passes without asserting anything.
    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        let extensions = SimpleExtensions::default();

        let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
        if let Ok(mut pairs) = parsed_result {
            let pair = pairs.next().unwrap();
            if pair.as_str() == "Fetch[limit=-5 => $0]" {
                let result = FetchRel::parse_pair_with_context(
                    &extensions,
                    pair,
                    vec![Box::new(example_read_relation().into_rel())],
                    3,
                );
                assert!(result.is_err());
                let error_msg = result.unwrap_err().to_string();
                assert!(error_msg.contains("Fetch limit must be non-negative"));
            } else {
                println!("Grammar prevents negative limit values at parse time");
            }
        } else {
            println!("Grammar prevents negative limit values at parse time");
        }
    }

    // NOTE(review): like the limit variant, this passes vacuously when the
    // grammar rejects the input.
    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        let extensions = SimpleExtensions::default();

        let parsed_result =
            ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
        if let Ok(mut pairs) = parsed_result {
            let pair = pairs.next().unwrap();
            if pair.as_str() == "Fetch[offset=-10 => $0]" {
                let result = FetchRel::parse_pair_with_context(
                    &extensions,
                    pair,
                    vec![Box::new(example_read_relation().into_rel())],
                    3,
                );
                assert!(result.is_err());
                let error_msg = result.unwrap_err().to_string();
                assert!(error_msg.contains("Fetch offset must be non-negative"));
            } else {
                println!("Grammar prevents negative offset values at parse time");
            }
        } else {
            println!("Grammar prevents negative offset values at parse time");
        }
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        // Two 3-column inputs, hence an input field count of 6.
        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);

        assert!(join.left.is_some());
        assert!(join.right.is_some());

        assert!(join.expression.is_some());

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1]);
    }

    /// A join must be rejected with zero, one, or three input children.
    #[test]
    fn test_parse_join_relation_requires_two_children() {
        let extensions = SimpleExtensions::default();

        let result = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
            vec![],
            0,
        );
        assert!(result.is_err());

        let result = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        );
        assert!(result.is_err());

        let result = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
            vec![
                Box::new(example_read_relation().into_rel()),
                Box::new(example_read_relation().into_rel()),
                Box::new(example_read_relation().into_rel()),
            ],
            9,
        );
        assert!(result.is_err());
    }

    /// Parses `input` with `rule` and asserts that the whole input is consumed
    /// by exactly one top-level pair, which is returned.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }
}