1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use substrait::proto::expression::literal::LiteralType;
5use substrait::proto::expression::{Literal, RexType};
6use substrait::proto::fetch_rel::{CountMode, OffsetMode};
7use substrait::proto::rel::RelType;
8use substrait::proto::rel_common::{Emit, EmitKind};
9use substrait::proto::sort_field::{SortDirection, SortKind};
10use substrait::proto::{
11 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
12 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
13};
14
15use super::{
16 ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
17};
18use crate::extensions::any::Any;
19use crate::extensions::registry::ExtensionError;
20use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions};
21use crate::parser::errors::{ParseContext, ParseError};
22use crate::parser::expressions::{FieldIndex, Name};
23
24pub struct RelationParsingContext<'a> {
26 pub extensions: &'a SimpleExtensions,
27 pub registry: &'a ExtensionRegistry,
28 pub line_no: i64,
29 pub line: &'a str,
30}
31
32impl<'a> RelationParsingContext<'a> {
33 pub fn resolve_extension_detail(
35 &self,
36 extension_name: &str,
37 extension_args: &ExtensionArgs,
38 ) -> Result<Option<Any>, ParseError> {
39 let detail = self
40 .registry
41 .parse_extension(extension_name, extension_args);
42
43 match detail {
44 Ok(any) => Ok(Some(any)),
45 Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
46 name: extension_name.to_string(),
47 context: ParseContext::new(self.line_no, self.line.to_string()),
48 }),
49 Err(err) => Err(ParseError::ExtensionDetail(
50 ParseContext::new(self.line_no, self.line.to_string()),
51 err,
52 )),
53 }
54 }
55}
56
/// A Substrait relation message that can be parsed from one relation in the text format.
///
/// Unlike [`ScopedParsePair`], parsing a relation also receives the already-parsed child
/// relations and an input field count supplied by the caller.
pub trait RelationParsePair: Sized {
    /// The grammar rule this relation is parsed from.
    fn rule() -> Rule;

    /// The message name used when reporting parse errors.
    fn message() -> &'static str;

    /// Parses `pair` into this relation.
    ///
    /// - `extensions`: simple extensions used to resolve functions/types in expressions.
    /// - `pair`: a pair for [`Self::rule`] (implementations assert this).
    /// - `input_children`: the child relations, already parsed and boxed.
    /// - `_input_field_count`: the input field count passed by the caller. `ProjectRel`
    ///   offsets newly computed expression columns by it in the emit mapping, and
    ///   `ReadRel` rejects any non-zero value.
    ///
    /// # Errors
    /// Returns a [`MessageParseError`] when the children or arguments are invalid.
    fn parse_pair_with_context(
        extensions: &SimpleExtensions,
        pair: Pair<Rule>,
        input_children: Vec<Box<Rel>>,
        _input_field_count: usize,
    ) -> Result<Self, MessageParseError>;

    /// Wraps this relation in a generic [`Rel`].
    fn into_rel(self) -> Rel;
}
80
81pub struct TableName(Vec<String>);
82
83impl ParsePair for TableName {
84 fn rule() -> Rule {
85 Rule::table_name
86 }
87
88 fn message() -> &'static str {
89 "TableName"
90 }
91
92 fn parse_pair(pair: Pair<Rule>) -> Self {
93 assert_eq!(pair.as_rule(), Self::rule());
94 let pairs = pair.into_inner();
95 let mut names = Vec::with_capacity(pairs.len());
96 let mut iter = RuleIter::from(pairs);
97 while let Some(name) = iter.parse_if_next::<Name>() {
98 names.push(name.0);
99 }
100 iter.done();
101 Self(names)
102 }
103}
104
105#[derive(Debug, Clone)]
106pub struct Column {
107 pub name: String,
108 pub typ: Type,
109}
110
111impl ScopedParsePair for Column {
112 fn rule() -> Rule {
113 Rule::named_column
114 }
115
116 fn message() -> &'static str {
117 "Column"
118 }
119
120 fn parse_pair(
121 extensions: &SimpleExtensions,
122 pair: Pair<Rule>,
123 ) -> Result<Self, MessageParseError> {
124 assert_eq!(pair.as_rule(), Self::rule());
125 let mut iter = RuleIter::from(pair.into_inner());
126 let name = iter.parse_next::<Name>().0;
127 let typ = iter.parse_next_scoped(extensions)?;
128 iter.done();
129 Ok(Self { name, typ })
130 }
131}
132
133pub struct NamedColumnList(Vec<Column>);
134
135impl ScopedParsePair for NamedColumnList {
136 fn rule() -> Rule {
137 Rule::named_column_list
138 }
139
140 fn message() -> &'static str {
141 "NamedColumnList"
142 }
143
144 fn parse_pair(
145 extensions: &SimpleExtensions,
146 pair: Pair<Rule>,
147 ) -> Result<Self, MessageParseError> {
148 assert_eq!(pair.as_rule(), Self::rule());
149 let mut columns = Vec::new();
150 for col in pair.into_inner() {
151 columns.push(Column::parse_pair(extensions, col)?);
152 }
153 Ok(Self(columns))
154 }
155}
156
157#[allow(clippy::vec_box)]
162pub(crate) fn expect_one_child(
163 message: &'static str,
164 pair: &Pair<Rule>,
165 mut input_children: Vec<Box<Rel>>,
166) -> Result<Box<Rel>, MessageParseError> {
167 match input_children.len() {
168 0 => Err(MessageParseError::invalid(
169 message,
170 pair.as_span(),
171 format!("{message} missing child"),
172 )),
173 1 => Ok(input_children.pop().unwrap()),
174 n => Err(MessageParseError::invalid(
175 message,
176 pair.as_span(),
177 format!("{message} should have 1 input child, got {n}"),
178 )),
179 }
180}
181
182fn parse_reference_emit(pair: Pair<Rule>) -> EmitKind {
184 assert_eq!(pair.as_rule(), Rule::reference_list);
185 let output_mapping = pair
186 .into_inner()
187 .map(|p| FieldIndex::parse_pair(p).0)
188 .collect::<Vec<i32>>();
189 EmitKind::Emit(Emit { output_mapping })
190}
191
192pub struct ParsedNamedArgs<'a> {
198 map: HashMap<&'a str, Pair<'a, Rule>>,
199}
200
201impl<'a> ParsedNamedArgs<'a> {
202 pub fn new(
203 pairs: pest::iterators::Pairs<'a, Rule>,
204 rule: Rule,
205 ) -> Result<Self, MessageParseError> {
206 let mut map = HashMap::new();
207 for pair in pairs {
208 assert_eq!(pair.as_rule(), rule);
209 let mut inner = pair.clone().into_inner();
210 let name_pair = inner.next().unwrap();
211 let value_pair = inner.next().unwrap();
212 assert_eq!(inner.next(), None);
213 let name = name_pair.as_str();
214 if map.contains_key(name) {
215 return Err(MessageParseError::invalid(
216 "NamedArg",
217 name_pair.as_span(),
218 format!("Duplicate argument: {name}"),
219 ));
220 }
221 map.insert(name, value_pair);
222 }
223 Ok(Self { map })
224 }
225
226 pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
230 let pair = self.map.remove(name).inspect(|pair| {
231 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
232 });
233 (self, pair)
234 }
235
236 pub fn done(self) -> Result<(), MessageParseError> {
238 if let Some((name, pair)) = self.map.iter().next() {
239 return Err(MessageParseError::invalid(
240 "NamedArgExtractor",
241 pair.as_span(),
243 format!("Unknown argument: {name}"),
244 ));
245 }
246 Ok(())
247 }
248}
249
250impl RelationParsePair for ReadRel {
251 fn rule() -> Rule {
252 Rule::read_relation
253 }
254
255 fn message() -> &'static str {
256 "ReadRel"
257 }
258
259 fn into_rel(self) -> Rel {
260 Rel {
261 rel_type: Some(RelType::Read(Box::new(self))),
262 }
263 }
264
265 fn parse_pair_with_context(
266 extensions: &SimpleExtensions,
267 pair: Pair<Rule>,
268 input_children: Vec<Box<Rel>>,
269 _input_field_count: usize,
270 ) -> Result<Self, MessageParseError> {
271 assert_eq!(pair.as_rule(), Self::rule());
272 if !input_children.is_empty() {
274 return Err(MessageParseError::invalid(
275 Self::message(),
276 pair.as_span(),
277 "ReadRel should have no input children",
278 ));
279 }
280 if _input_field_count != 0 {
281 let error = pest::error::Error::new_from_span(
282 pest::error::ErrorVariant::CustomError {
283 message: "ReadRel should have 0 input fields".to_string(),
284 },
285 pair.as_span(),
286 );
287 return Err(MessageParseError::new(
288 "ReadRel",
289 ErrorKind::InvalidValue,
290 Box::new(error),
291 ));
292 }
293
294 let mut iter = RuleIter::from(pair.into_inner());
295 let table = iter.parse_next::<TableName>().0;
296 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
297 iter.done();
298
299 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
300 let struct_ = r#type::Struct {
301 types,
302 type_variation_reference: 0,
303 nullability: r#type::Nullability::Required as i32,
304 };
305 let named_struct = NamedStruct {
306 names,
307 r#struct: Some(struct_),
308 };
309
310 let read_rel = ReadRel {
311 base_schema: Some(named_struct),
312 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
313 names: table,
314 advanced_extension: None,
315 })),
316 ..Default::default()
317 };
318
319 Ok(read_rel)
320 }
321}
322
323impl RelationParsePair for FilterRel {
324 fn rule() -> Rule {
325 Rule::filter_relation
326 }
327
328 fn message() -> &'static str {
329 "FilterRel"
330 }
331
332 fn into_rel(self) -> Rel {
333 Rel {
334 rel_type: Some(RelType::Filter(Box::new(self))),
335 }
336 }
337
338 fn parse_pair_with_context(
339 extensions: &SimpleExtensions,
340 pair: Pair<Rule>,
341 input_children: Vec<Box<Rel>>,
342 _input_field_count: usize,
343 ) -> Result<Self, MessageParseError> {
344 assert_eq!(pair.as_rule(), Self::rule());
347 let input = expect_one_child(Self::message(), &pair, input_children)?;
348 let mut iter = RuleIter::from(pair.into_inner());
349 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
351 let references_pair = iter.pop(Rule::reference_list);
353 iter.done();
354
355 let emit = parse_reference_emit(references_pair);
356 let common = RelCommon {
357 emit_kind: Some(emit),
358 ..Default::default()
359 };
360
361 Ok(FilterRel {
362 input: Some(input),
363 condition: Some(Box::new(condition)),
364 common: Some(common),
365 advanced_extension: None,
366 })
367 }
368}
369
370impl RelationParsePair for ProjectRel {
371 fn rule() -> Rule {
372 Rule::project_relation
373 }
374
375 fn message() -> &'static str {
376 "ProjectRel"
377 }
378
379 fn into_rel(self) -> Rel {
380 Rel {
381 rel_type: Some(RelType::Project(Box::new(self))),
382 }
383 }
384
385 fn parse_pair_with_context(
386 extensions: &SimpleExtensions,
387 pair: Pair<Rule>,
388 input_children: Vec<Box<Rel>>,
389 _input_field_count: usize,
390 ) -> Result<Self, MessageParseError> {
391 assert_eq!(pair.as_rule(), Self::rule());
392 let input = expect_one_child(Self::message(), &pair, input_children)?;
393
394 let arguments_pair = unwrap_single_pair(pair);
396
397 let mut expressions = Vec::new();
398 let mut output_mapping = Vec::new();
399
400 for arg in arguments_pair.into_inner() {
402 let inner_arg = unwrap_single_pair(arg);
403 match inner_arg.as_rule() {
404 Rule::reference => {
405 let field_index = FieldIndex::parse_pair(inner_arg);
407 output_mapping.push(field_index.0);
408 }
409 Rule::expression => {
410 let _expr = Expression::parse_pair(extensions, inner_arg)?;
412 expressions.push(_expr);
413 output_mapping.push(_input_field_count as i32 + (expressions.len() as i32 - 1));
415 }
416 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
417 }
418 }
419
420 let emit = EmitKind::Emit(Emit { output_mapping });
421 let common = RelCommon {
422 emit_kind: Some(emit),
423 ..Default::default()
424 };
425
426 Ok(ProjectRel {
427 input: Some(input),
428 expressions,
429 common: Some(common),
430 advanced_extension: None,
431 })
432 }
433}
434
435impl RelationParsePair for AggregateRel {
436 fn rule() -> Rule {
437 Rule::aggregate_relation
438 }
439
440 fn message() -> &'static str {
441 "AggregateRel"
442 }
443
444 fn into_rel(self) -> Rel {
445 Rel {
446 rel_type: Some(RelType::Aggregate(Box::new(self))),
447 }
448 }
449
450 fn parse_pair_with_context(
451 extensions: &SimpleExtensions,
452 pair: Pair<Rule>,
453 input_children: Vec<Box<Rel>>,
454 _input_field_count: usize,
455 ) -> Result<Self, MessageParseError> {
456 assert_eq!(pair.as_rule(), Self::rule());
457 let input = expect_one_child(Self::message(), &pair, input_children)?;
458 let mut iter = RuleIter::from(pair.into_inner());
459 let group_by_pair = iter.pop(Rule::aggregate_group_by);
460 let output_pair = iter.pop(Rule::aggregate_output);
461 iter.done();
462 let mut grouping_expressions = Vec::new();
463 for group_by_item in group_by_pair.into_inner() {
464 match group_by_item.as_rule() {
465 Rule::reference => {
466 let field_index = FieldIndex::parse_pair(group_by_item);
467 grouping_expressions.push(Expression {
468 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
469 field_index.to_field_reference(),
470 ))),
471 });
472 }
473 Rule::empty => {
474 }
476 _ => panic!(
477 "Unexpected group-by item rule: {:?}",
478 group_by_item.as_rule()
479 ),
480 }
481 }
482
483 let mut measures = Vec::new();
485 let mut output_mapping = Vec::new();
486 let group_by_count = grouping_expressions.len();
487 let mut measure_count = 0;
488
489 for output_item in output_pair.into_inner() {
490 let inner_item = unwrap_single_pair(output_item);
491 match inner_item.as_rule() {
492 Rule::reference => {
493 let field_index = FieldIndex::parse_pair(inner_item);
494 output_mapping.push(field_index.0);
495 }
496 Rule::aggregate_measure => {
497 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
498 measures.push(measure);
499 output_mapping.push(group_by_count as i32 + measure_count);
500 measure_count += 1;
501 }
502 _ => panic!(
503 "Unexpected inner output item rule: {:?}",
504 inner_item.as_rule()
505 ),
506 }
507 }
508
509 let emit = EmitKind::Emit(Emit { output_mapping });
510 let common = RelCommon {
511 emit_kind: Some(emit),
512 ..Default::default()
513 };
514
515 Ok(AggregateRel {
516 input: Some(input),
517 grouping_expressions,
518 groupings: vec![], measures,
520 common: Some(common),
521 advanced_extension: None,
522 })
523 }
524}
525
526impl ScopedParsePair for SortField {
527 fn rule() -> Rule {
528 Rule::sort_field
529 }
530
531 fn message() -> &'static str {
532 "SortField"
533 }
534
535 fn parse_pair(
536 _extensions: &SimpleExtensions,
537 pair: Pair<Rule>,
538 ) -> Result<Self, MessageParseError> {
539 assert_eq!(pair.as_rule(), Self::rule());
540 let mut iter = RuleIter::from(pair.into_inner());
541 let reference_pair = iter.pop(Rule::reference);
542 let field_index = FieldIndex::parse_pair(reference_pair);
543 let direction_pair = iter.pop(Rule::sort_direction);
544 let direction = match direction_pair.as_str().trim_start_matches('&') {
548 "AscNullsFirst" => SortDirection::AscNullsFirst,
549 "AscNullsLast" => SortDirection::AscNullsLast,
550 "DescNullsFirst" => SortDirection::DescNullsFirst,
551 "DescNullsLast" => SortDirection::DescNullsLast,
552 other => {
553 return Err(MessageParseError::invalid(
554 "SortDirection",
555 direction_pair.as_span(),
556 format!("Unknown sort direction: {other}"),
557 ));
558 }
559 };
560 iter.done();
561 Ok(SortField {
562 expr: Some(Expression {
563 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
564 field_index.to_field_reference(),
565 ))),
566 }),
567 sort_kind: Some(SortKind::Direction(direction as i32)),
569 })
570 }
571}
572
573impl RelationParsePair for SortRel {
574 fn rule() -> Rule {
575 Rule::sort_relation
576 }
577
578 fn message() -> &'static str {
579 "SortRel"
580 }
581
582 fn into_rel(self) -> Rel {
583 Rel {
584 rel_type: Some(RelType::Sort(Box::new(self))),
585 }
586 }
587
588 fn parse_pair_with_context(
589 extensions: &SimpleExtensions,
590 pair: Pair<Rule>,
591 input_children: Vec<Box<Rel>>,
592 _input_field_count: usize,
593 ) -> Result<Self, MessageParseError> {
594 assert_eq!(pair.as_rule(), Self::rule());
595 let input = expect_one_child(Self::message(), &pair, input_children)?;
596 let mut iter = RuleIter::from(pair.into_inner());
597 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
598 let reference_list_pair = iter.pop(Rule::reference_list);
599 let mut sorts = Vec::new();
600 for sort_field_pair in sort_field_list_pair.into_inner() {
601 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
602 sorts.push(sort_field);
603 }
604 let emit = parse_reference_emit(reference_list_pair);
605 let common = RelCommon {
606 emit_kind: Some(emit),
607 ..Default::default()
608 };
609 iter.done();
610 Ok(SortRel {
611 input: Some(input),
612 sorts,
613 common: Some(common),
614 advanced_extension: None,
615 })
616 }
617}
618
619impl ScopedParsePair for CountMode {
620 fn rule() -> Rule {
621 Rule::fetch_value
622 }
623 fn message() -> &'static str {
624 "CountMode"
625 }
626 fn parse_pair(
627 extensions: &SimpleExtensions,
628 pair: Pair<Rule>,
629 ) -> Result<Self, MessageParseError> {
630 assert_eq!(pair.as_rule(), Self::rule());
631 let mut arg_inner = RuleIter::from(pair.into_inner());
632 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
633 int_pair
634 } else {
635 arg_inner.pop(Rule::expression)
636 };
637 match value_pair.as_rule() {
638 Rule::integer => {
639 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
640 MessageParseError::invalid(
641 Self::message(),
642 value_pair.as_span(),
643 format!("Invalid integer: {e}"),
644 )
645 })?;
646 if value < 0 {
647 return Err(MessageParseError::invalid(
648 Self::message(),
649 value_pair.as_span(),
650 format!("Fetch limit must be non-negative, got: {value}"),
651 ));
652 }
653 Ok(CountMode::CountExpr(i64_literal_expr(value)))
654 }
655 Rule::expression => {
656 let expr = Expression::parse_pair(extensions, value_pair)?;
657 Ok(CountMode::CountExpr(Box::new(expr)))
658 }
659 _ => Err(MessageParseError::invalid(
660 Self::message(),
661 value_pair.as_span(),
662 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
663 )),
664 }
665 }
666}
667
668fn i64_literal_expr(value: i64) -> Box<Expression> {
669 Box::new(Expression {
670 rex_type: Some(RexType::Literal(Literal {
671 nullable: false,
672 type_variation_reference: 0,
673 literal_type: Some(LiteralType::I64(value)),
674 })),
675 })
676}
677
678impl ScopedParsePair for OffsetMode {
679 fn rule() -> Rule {
680 Rule::fetch_value
681 }
682 fn message() -> &'static str {
683 "OffsetMode"
684 }
685 fn parse_pair(
686 extensions: &SimpleExtensions,
687 pair: Pair<Rule>,
688 ) -> Result<Self, MessageParseError> {
689 assert_eq!(pair.as_rule(), Self::rule());
690 let mut arg_inner = RuleIter::from(pair.into_inner());
691 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
692 int_pair
693 } else {
694 arg_inner.pop(Rule::expression)
695 };
696 match value_pair.as_rule() {
697 Rule::integer => {
698 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
699 MessageParseError::invalid(
700 Self::message(),
701 value_pair.as_span(),
702 format!("Invalid integer: {e}"),
703 )
704 })?;
705 if value < 0 {
706 return Err(MessageParseError::invalid(
707 Self::message(),
708 value_pair.as_span(),
709 format!("Fetch offset must be non-negative, got: {value}"),
710 ));
711 }
712 Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
713 }
714 Rule::expression => {
715 let expr = Expression::parse_pair(extensions, value_pair)?;
716 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
717 }
718 _ => Err(MessageParseError::invalid(
719 Self::message(),
720 value_pair.as_span(),
721 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
722 )),
723 }
724 }
725}
726
727impl RelationParsePair for FetchRel {
728 fn rule() -> Rule {
729 Rule::fetch_relation
730 }
731
732 fn message() -> &'static str {
733 "FetchRel"
734 }
735
736 fn into_rel(self) -> Rel {
737 Rel {
738 rel_type: Some(RelType::Fetch(Box::new(self))),
739 }
740 }
741
742 fn parse_pair_with_context(
743 extensions: &SimpleExtensions,
744 pair: Pair<Rule>,
745 input_children: Vec<Box<Rel>>,
746 _input_field_count: usize,
747 ) -> Result<Self, MessageParseError> {
748 assert_eq!(pair.as_rule(), Self::rule());
749 let input = expect_one_child(Self::message(), &pair, input_children)?;
750 let mut iter = RuleIter::from(pair.into_inner());
751
752 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
754 None => {
755 iter.pop(Rule::empty);
757 (None, None)
758 }
759 Some(fetch_args_pair) => {
760 let extractor =
761 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
762 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
763 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
764 extractor.done()?;
765 (limit_pair, offset_pair)
766 }
767 };
768
769 let reference_list_pair = iter.pop(Rule::reference_list);
770 let emit = parse_reference_emit(reference_list_pair);
771 let common = RelCommon {
772 emit_kind: Some(emit),
773 ..Default::default()
774 };
775 iter.done();
776
777 let count_mode = limit_pair
779 .map(|pair| CountMode::parse_pair(extensions, pair))
780 .transpose()?;
781 let offset_mode = offset_pair
782 .map(|pair| OffsetMode::parse_pair(extensions, pair))
783 .transpose()?;
784 Ok(FetchRel {
785 input: Some(input),
786 common: Some(common),
787 advanced_extension: None,
788 offset_mode,
789 count_mode,
790 })
791 }
792}
793
794impl ParsePair for join_rel::JoinType {
795 fn rule() -> Rule {
796 Rule::join_type
797 }
798
799 fn message() -> &'static str {
800 "JoinType"
801 }
802
803 fn parse_pair(pair: Pair<Rule>) -> Self {
804 assert_eq!(pair.as_rule(), Self::rule());
805 let join_type_str = pair.as_str().trim_start_matches('&');
806 match join_type_str {
807 "Inner" => join_rel::JoinType::Inner,
808 "Left" => join_rel::JoinType::Left,
809 "Right" => join_rel::JoinType::Right,
810 "Outer" => join_rel::JoinType::Outer,
811 "LeftSemi" => join_rel::JoinType::LeftSemi,
812 "RightSemi" => join_rel::JoinType::RightSemi,
813 "LeftAnti" => join_rel::JoinType::LeftAnti,
814 "RightAnti" => join_rel::JoinType::RightAnti,
815 "LeftSingle" => join_rel::JoinType::LeftSingle,
816 "RightSingle" => join_rel::JoinType::RightSingle,
817 "LeftMark" => join_rel::JoinType::LeftMark,
818 "RightMark" => join_rel::JoinType::RightMark,
819 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
820 }
821 }
822}
823
824impl RelationParsePair for JoinRel {
825 fn rule() -> Rule {
826 Rule::join_relation
827 }
828
829 fn message() -> &'static str {
830 "JoinRel"
831 }
832
833 fn into_rel(self) -> Rel {
834 Rel {
835 rel_type: Some(RelType::Join(Box::new(self))),
836 }
837 }
838
839 fn parse_pair_with_context(
840 extensions: &SimpleExtensions,
841 pair: Pair<Rule>,
842 input_children: Vec<Box<Rel>>,
843 _input_field_count: usize,
844 ) -> Result<Self, MessageParseError> {
845 assert_eq!(pair.as_rule(), Self::rule());
846
847 if input_children.len() != 2 {
849 return Err(MessageParseError::invalid(
850 Self::message(),
851 pair.as_span(),
852 format!(
853 "JoinRel should have exactly 2 input children, got {}",
854 input_children.len()
855 ),
856 ));
857 }
858
859 let mut children_iter = input_children.into_iter();
860 let left = children_iter.next().unwrap();
861 let right = children_iter.next().unwrap();
862
863 let mut iter = RuleIter::from(pair.into_inner());
864
865 let join_type = iter.parse_next::<join_rel::JoinType>();
867
868 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
870
871 let reference_list_pair = iter.pop(Rule::reference_list);
873 iter.done();
874
875 let emit = parse_reference_emit(reference_list_pair);
876 let common = RelCommon {
877 emit_kind: Some(emit),
878 ..Default::default()
879 };
880
881 Ok(JoinRel {
882 common: Some(common),
883 left: Some(left),
884 right: Some(right),
885 expression: Some(Box::new(condition)),
886 post_join_filter: None, r#type: join_type as i32,
888 advanced_extension: None,
889 })
890 }
891}
892
893#[cfg(test)]
894mod tests {
895 use pest::Parser;
896
897 use super::*;
898 use crate::fixtures::TestContext;
899 use crate::parser::{ExpressionParser, Rule};
900
901 #[test]
902 fn test_parse_relation() {
903 }
905
906 #[test]
907 fn test_parse_read_relation() {
908 let extensions = SimpleExtensions::default();
909 let read = ReadRel::parse_pair_with_context(
910 &extensions,
911 parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
912 vec![],
913 0,
914 )
915 .unwrap();
916 let names = match &read.read_type {
917 Some(read_rel::ReadType::NamedTable(table)) => &table.names,
918 _ => panic!("Expected NamedTable"),
919 };
920 assert_eq!(names, &["ab", "cd", "ef"]);
921 let columns = &read
922 .base_schema
923 .as_ref()
924 .unwrap()
925 .r#struct
926 .as_ref()
927 .unwrap()
928 .types;
929 assert_eq!(columns.len(), 2);
930 }
931
932 fn example_read_relation() -> ReadRel {
934 let extensions = SimpleExtensions::default();
935 ReadRel::parse_pair_with_context(
936 &extensions,
937 parse_exact(
938 Rule::read_relation,
939 "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
940 ),
941 vec![],
942 0,
943 )
944 .unwrap()
945 }
946
947 #[test]
948 fn test_parse_filter_relation() {
949 let extensions = SimpleExtensions::default();
950 let filter = FilterRel::parse_pair_with_context(
951 &extensions,
952 parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
953 vec![Box::new(example_read_relation().into_rel())],
954 3,
955 )
956 .unwrap();
957 let emit_kind = &filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
958 let emit = match emit_kind {
959 EmitKind::Emit(emit) => &emit.output_mapping,
960 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
961 };
962 assert_eq!(emit, &[0, 1, 2]);
963 }
964
965 #[test]
966 fn test_parse_project_relation() {
967 let extensions = SimpleExtensions::default();
968 let project = ProjectRel::parse_pair_with_context(
969 &extensions,
970 parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
971 vec![Box::new(example_read_relation().into_rel())],
972 3,
973 )
974 .unwrap();
975
976 assert_eq!(project.expressions.len(), 1);
978
979 let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
980 let emit = match emit_kind {
981 EmitKind::Emit(emit) => &emit.output_mapping,
982 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
983 };
984 assert_eq!(emit, &[0, 1, 3]);
986 }
987
988 #[test]
989 fn test_parse_project_relation_complex() {
990 let extensions = SimpleExtensions::default();
991 let project = ProjectRel::parse_pair_with_context(
992 &extensions,
993 parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
994 vec![Box::new(example_read_relation().into_rel())],
995 5, )
997 .unwrap();
998
999 assert_eq!(project.expressions.len(), 2);
1001
1002 let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1003 let emit = match emit_kind {
1004 EmitKind::Emit(emit) => &emit.output_mapping,
1005 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1006 };
1007 assert_eq!(emit, &[5, 0, 6, 2, 1]);
1010 }
1011
1012 #[test]
1013 fn test_parse_aggregate_relation() {
1014 let extensions = TestContext::new()
1015 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1016 .with_function(1, 10, "sum")
1017 .with_function(1, 11, "count")
1018 .extensions;
1019
1020 let aggregate = AggregateRel::parse_pair_with_context(
1021 &extensions,
1022 parse_exact(
1023 Rule::aggregate_relation,
1024 "Aggregate[$0, $1 => sum($2), $0, count($2)]",
1025 ),
1026 vec![Box::new(example_read_relation().into_rel())],
1027 3,
1028 )
1029 .unwrap();
1030
1031 assert_eq!(aggregate.grouping_expressions.len(), 2);
1033 assert_eq!(aggregate.measures.len(), 2);
1034
1035 let emit_kind = &aggregate
1036 .common
1037 .as_ref()
1038 .unwrap()
1039 .emit_kind
1040 .as_ref()
1041 .unwrap();
1042 let emit = match emit_kind {
1043 EmitKind::Emit(emit) => &emit.output_mapping,
1044 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1045 };
1046 assert_eq!(emit, &[2, 0, 3]);
1049 }
1050
1051 #[test]
1052 fn test_parse_aggregate_relation_simple() {
1053 let extensions = TestContext::new()
1054 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1055 .with_function(1, 10, "sum")
1056 .with_function(1, 11, "count")
1057 .extensions;
1058
1059 let aggregate = AggregateRel::parse_pair_with_context(
1060 &extensions,
1061 parse_exact(
1062 Rule::aggregate_relation,
1063 "Aggregate[$0 => sum($1), count($1)]",
1064 ),
1065 vec![Box::new(example_read_relation().into_rel())],
1066 3,
1067 )
1068 .unwrap();
1069
1070 assert_eq!(aggregate.grouping_expressions.len(), 1);
1072 assert_eq!(aggregate.measures.len(), 2);
1073
1074 let emit_kind = &aggregate
1075 .common
1076 .as_ref()
1077 .unwrap()
1078 .emit_kind
1079 .as_ref()
1080 .unwrap();
1081 let emit = match emit_kind {
1082 EmitKind::Emit(emit) => &emit.output_mapping,
1083 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1084 };
1085 assert_eq!(emit, &[1, 2]);
1087 }
1088
1089 #[test]
1090 fn test_parse_aggregate_relation_no_group_by() {
1091 let extensions = TestContext::new()
1092 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1093 .with_function(1, 10, "sum")
1094 .with_function(1, 11, "count")
1095 .extensions;
1096
1097 let aggregate = AggregateRel::parse_pair_with_context(
1098 &extensions,
1099 parse_exact(
1100 Rule::aggregate_relation,
1101 "Aggregate[_ => sum($0), count($1)]",
1102 ),
1103 vec![Box::new(example_read_relation().into_rel())],
1104 3,
1105 )
1106 .unwrap();
1107
1108 assert_eq!(aggregate.grouping_expressions.len(), 0);
1110 assert_eq!(aggregate.measures.len(), 2);
1111
1112 let emit_kind = &aggregate
1113 .common
1114 .as_ref()
1115 .unwrap()
1116 .emit_kind
1117 .as_ref()
1118 .unwrap();
1119 let emit = match emit_kind {
1120 EmitKind::Emit(emit) => &emit.output_mapping,
1121 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1122 };
1123 assert_eq!(emit, &[0, 1]);
1125 }
1126
1127 #[test]
1128 fn test_parse_aggregate_relation_empty_group_by() {
1129 let extensions = TestContext::new()
1130 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1131 .with_function(1, 10, "sum")
1132 .with_function(1, 11, "count")
1133 .extensions;
1134
1135 let aggregate = AggregateRel::parse_pair_with_context(
1136 &extensions,
1137 parse_exact(
1138 Rule::aggregate_relation,
1139 "Aggregate[_ => sum($0), count($1)]",
1140 ),
1141 vec![Box::new(example_read_relation().into_rel())],
1142 3,
1143 )
1144 .unwrap();
1145
1146 assert_eq!(aggregate.grouping_expressions.len(), 0);
1148 assert_eq!(aggregate.measures.len(), 2);
1149
1150 let emit_kind = &aggregate
1151 .common
1152 .as_ref()
1153 .unwrap()
1154 .emit_kind
1155 .as_ref()
1156 .unwrap();
1157 let emit = match emit_kind {
1158 EmitKind::Emit(emit) => &emit.output_mapping,
1159 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1160 };
1161 assert_eq!(emit, &[0, 1]);
1163 }
1164
1165 #[test]
1166 fn test_fetch_relation_positive_values() {
1167 let extensions = SimpleExtensions::default();
1168
1169 let fetch_rel = FetchRel::parse_pair_with_context(
1171 &extensions,
1172 parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
1173 vec![Box::new(example_read_relation().into_rel())],
1174 3,
1175 )
1176 .unwrap();
1177
1178 assert_eq!(
1180 fetch_rel.count_mode,
1181 Some(CountMode::CountExpr(i64_literal_expr(10)))
1182 );
1183 assert_eq!(
1184 fetch_rel.offset_mode,
1185 Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
1186 );
1187 }
1188
1189 #[test]
1190 fn test_fetch_relation_negative_limit_rejected() {
1191 let extensions = SimpleExtensions::default();
1192
1193 let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
1195 if let Ok(mut pairs) = parsed_result {
1196 let pair = pairs.next().unwrap();
1197 if pair.as_str() == "Fetch[limit=-5 => $0]" {
1198 let result = FetchRel::parse_pair_with_context(
1200 &extensions,
1201 pair,
1202 vec![Box::new(example_read_relation().into_rel())],
1203 3,
1204 );
1205 assert!(result.is_err());
1206 let error_msg = result.unwrap_err().to_string();
1207 assert!(error_msg.contains("Fetch limit must be non-negative"));
1208 } else {
1209 println!("Grammar prevents negative limit values at parse time");
1212 }
1213 } else {
1214 println!("Grammar prevents negative limit values at parse time");
1216 }
1217 }
1218
1219 #[test]
1220 fn test_fetch_relation_negative_offset_rejected() {
1221 let extensions = SimpleExtensions::default();
1222
1223 let parsed_result =
1225 ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
1226 if let Ok(mut pairs) = parsed_result {
1227 let pair = pairs.next().unwrap();
1228 if pair.as_str() == "Fetch[offset=-10 => $0]" {
1229 let result = FetchRel::parse_pair_with_context(
1231 &extensions,
1232 pair,
1233 vec![Box::new(example_read_relation().into_rel())],
1234 3,
1235 );
1236 assert!(result.is_err());
1237 let error_msg = result.unwrap_err().to_string();
1238 assert!(error_msg.contains("Fetch offset must be non-negative"));
1239 } else {
1240 println!("Grammar prevents negative offset values at parse time");
1242 }
1243 } else {
1244 println!("Grammar prevents negative offset values at parse time");
1246 }
1247 }
1248
1249 #[test]
1250 fn test_parse_join_relation() {
1251 let extensions = TestContext::new()
1252 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1253 .with_function(1, 10, "eq")
1254 .extensions;
1255
1256 let left_rel = example_read_relation().into_rel();
1257 let right_rel = example_read_relation().into_rel();
1258
1259 let join = JoinRel::parse_pair_with_context(
1260 &extensions,
1261 parse_exact(
1262 Rule::join_relation,
1263 "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
1264 ),
1265 vec![Box::new(left_rel), Box::new(right_rel)],
1266 6, )
1268 .unwrap();
1269
1270 assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
1272
1273 assert!(join.left.is_some());
1275 assert!(join.right.is_some());
1276
1277 assert!(join.expression.is_some());
1279
1280 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1281 let emit = match emit_kind {
1282 EmitKind::Emit(emit) => &emit.output_mapping,
1283 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1284 };
1285 assert_eq!(emit, &[0, 1, 3, 4]);
1287 }
1288
1289 #[test]
1290 fn test_parse_join_relation_left_outer() {
1291 let extensions = TestContext::new()
1292 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1293 .with_function(1, 10, "eq")
1294 .extensions;
1295
1296 let left_rel = example_read_relation().into_rel();
1297 let right_rel = example_read_relation().into_rel();
1298
1299 let join = JoinRel::parse_pair_with_context(
1300 &extensions,
1301 parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
1302 vec![Box::new(left_rel), Box::new(right_rel)],
1303 6,
1304 )
1305 .unwrap();
1306
1307 assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
1309
1310 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1311 let emit = match emit_kind {
1312 EmitKind::Emit(emit) => &emit.output_mapping,
1313 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1314 };
1315 assert_eq!(emit, &[0, 1, 2]);
1317 }
1318
1319 #[test]
1320 fn test_parse_join_relation_left_semi() {
1321 let extensions = TestContext::new()
1322 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1323 .with_function(1, 10, "eq")
1324 .extensions;
1325
1326 let left_rel = example_read_relation().into_rel();
1327 let right_rel = example_read_relation().into_rel();
1328
1329 let join = JoinRel::parse_pair_with_context(
1330 &extensions,
1331 parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
1332 vec![Box::new(left_rel), Box::new(right_rel)],
1333 6,
1334 )
1335 .unwrap();
1336
1337 assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
1339
1340 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1341 let emit = match emit_kind {
1342 EmitKind::Emit(emit) => &emit.output_mapping,
1343 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1344 };
1345 assert_eq!(emit, &[0, 1]);
1347 }
1348
1349 #[test]
1350 fn test_parse_join_relation_requires_two_children() {
1351 let extensions = SimpleExtensions::default();
1352
1353 let result = JoinRel::parse_pair_with_context(
1355 &extensions,
1356 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1357 vec![],
1358 0,
1359 );
1360 assert!(result.is_err());
1361
1362 let result = JoinRel::parse_pair_with_context(
1364 &extensions,
1365 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1366 vec![Box::new(example_read_relation().into_rel())],
1367 3,
1368 );
1369 assert!(result.is_err());
1370
1371 let result = JoinRel::parse_pair_with_context(
1373 &extensions,
1374 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1375 vec![
1376 Box::new(example_read_relation().into_rel()),
1377 Box::new(example_read_relation().into_rel()),
1378 Box::new(example_read_relation().into_rel()),
1379 ],
1380 9,
1381 );
1382 assert!(result.is_err());
1383 }
1384
1385 fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
1386 let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1387 assert_eq!(pairs.as_str(), input);
1388 let pair = pairs.next().unwrap();
1389 assert_eq!(pairs.next(), None);
1390 pair
1391 }
1392}