1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType, nested};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::{CountMode, OffsetMode};
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::{Direct, Emit, EmitKind};
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
15 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
16};
17
18use super::{MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair};
19use crate::extensions::any::Any;
20use crate::extensions::registry::{ExtensionError, ExtensionType};
21use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions};
22use crate::parser::errors::{ParseContext, ParseError};
23use crate::parser::expressions::{FieldIndex, Name};
24
25pub struct RelationParsingContext<'a> {
27 pub extensions: &'a SimpleExtensions,
28 pub registry: &'a ExtensionRegistry,
29 pub line_no: i64,
30 pub line: &'a str,
31}
32
33impl<'a> RelationParsingContext<'a> {
34 pub fn resolve_extension_detail(
36 &self,
37 extension_name: &str,
38 extension_args: &ExtensionArgs,
39 ) -> Result<Option<Any>, ParseError> {
40 let detail = self
41 .registry
42 .parse_extension(extension_name, extension_args);
43
44 match detail {
45 Ok(any) => Ok(Some(any)),
46 Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
47 name: extension_name.to_string(),
48 context: ParseContext::new(self.line_no, self.line.to_string()),
49 }),
50 Err(err) => Err(ParseError::ExtensionDetail(
51 ParseContext::new(self.line_no, self.line.to_string()),
52 err,
53 )),
54 }
55 }
56
57 pub fn resolve_adv_ext_detail(
63 &self,
64 ext_type: ExtensionType,
65 name: &str,
66 args: &ExtensionArgs,
67 ) -> Result<Any, ParseError> {
68 let result = match ext_type {
69 ExtensionType::Enhancement => self.registry.parse_enhancement(name, args),
70 ExtensionType::Optimization => self.registry.parse_optimization(name, args),
71 ExtensionType::Relation => unreachable!("Relation is not an advanced extension type"),
72 };
73 result.map_err(|err| match err {
74 ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
75 name: name.to_string(),
76 context: ParseContext::new(self.line_no, self.line.to_string()),
77 },
78 err => ParseError::ExtensionDetail(
79 ParseContext::new(self.line_no, self.line.to_string()),
80 err,
81 ),
82 })
83 }
84}
85
86pub trait RelationParsePair: Sized {
88 fn rule() -> Rule;
89 fn message() -> &'static str;
90
91 fn parse_pair_with_context(
97 extensions: &SimpleExtensions,
98 pair: Pair<Rule>,
99 input_children: Vec<Box<Rel>>,
100 input_field_count: usize,
101 ) -> Result<(Self, usize), MessageParseError>;
102
103 fn into_rel(self, adv_ext: Option<AdvancedExtension>) -> Rel;
106}
107
108pub struct TableName(Vec<String>);
109
110impl ParsePair for TableName {
111 fn rule() -> Rule {
112 Rule::table_name
113 }
114
115 fn message() -> &'static str {
116 "TableName"
117 }
118
119 fn parse_pair(pair: Pair<Rule>) -> Self {
120 assert_eq!(pair.as_rule(), Self::rule());
121 let pairs = pair.into_inner();
122 let mut names = Vec::with_capacity(pairs.len());
123 let mut iter = RuleIter::from(pairs);
124 while let Some(name) = iter.parse_if_next::<Name>() {
125 names.push(name.0);
126 }
127 iter.done();
128 Self(names)
129 }
130}
131
132#[derive(Debug, Clone)]
133pub struct Column {
134 pub name: String,
135 pub typ: Type,
136}
137
138impl ScopedParsePair for Column {
139 fn rule() -> Rule {
140 Rule::named_column
141 }
142
143 fn message() -> &'static str {
144 "Column"
145 }
146
147 fn parse_pair(
148 extensions: &SimpleExtensions,
149 pair: Pair<Rule>,
150 ) -> Result<Self, MessageParseError> {
151 assert_eq!(pair.as_rule(), Self::rule());
152 let mut iter = RuleIter::from(pair.into_inner());
153 let name = iter.parse_next::<Name>().0;
154 let typ = iter.parse_next_scoped(extensions)?;
155 iter.done();
156 Ok(Self { name, typ })
157 }
158}
159
160pub struct NamedColumnList(Vec<Column>);
161
162impl ScopedParsePair for NamedColumnList {
163 fn rule() -> Rule {
164 Rule::named_column_list
165 }
166
167 fn message() -> &'static str {
168 "NamedColumnList"
169 }
170
171 fn parse_pair(
172 extensions: &SimpleExtensions,
173 pair: Pair<Rule>,
174 ) -> Result<Self, MessageParseError> {
175 assert_eq!(pair.as_rule(), Self::rule());
176 let mut columns = Vec::new();
177 for col in pair.into_inner() {
178 columns.push(Column::parse_pair(extensions, col)?);
179 }
180 Ok(Self(columns))
181 }
182}
183
184#[allow(clippy::vec_box)]
189pub(crate) fn expect_one_child(
190 message: &'static str,
191 pair: &Pair<Rule>,
192 mut input_children: Vec<Box<Rel>>,
193) -> Result<Box<Rel>, MessageParseError> {
194 match input_children.len() {
195 0 => Err(MessageParseError::invalid(
196 message,
197 pair.as_span(),
198 format!("{message} missing child"),
199 )),
200 1 => Ok(input_children.pop().unwrap()),
201 n => Err(MessageParseError::invalid(
202 message,
203 pair.as_span(),
204 format!("{message} should have 1 input child, got {n}"),
205 )),
206 }
207}
208
209fn parse_output_mapping(pair: Pair<Rule>) -> Vec<i32> {
212 assert_eq!(pair.as_rule(), Rule::reference_list);
213 pair.into_inner()
214 .map(|p| FieldIndex::parse_pair(p).0)
215 .collect()
216}
217
218fn make_emit(output_mapping: Vec<i32>, direct_output_count: usize) -> EmitKind {
221 let is_identity = output_mapping.len() == direct_output_count
222 && output_mapping
223 .iter()
224 .enumerate()
225 .all(|(i, &v)| v == i as i32);
226 if is_identity {
227 EmitKind::Direct(Direct {})
228 } else {
229 EmitKind::Emit(Emit { output_mapping })
230 }
231}
232
233fn parse_emit(reference_list: Pair<Rule>, direct_output_count: usize) -> (EmitKind, usize) {
235 let output_mapping = parse_output_mapping(reference_list);
236 let output_count = output_mapping.len();
237 let emit = make_emit(output_mapping, direct_output_count);
238 (emit, output_count)
239}
240
241pub struct ParsedNamedArgs<'a> {
247 map: HashMap<&'a str, Pair<'a, Rule>>,
248}
249
250impl<'a> ParsedNamedArgs<'a> {
251 pub fn new(
252 pairs: pest::iterators::Pairs<'a, Rule>,
253 rule: Rule,
254 ) -> Result<Self, MessageParseError> {
255 let mut map = HashMap::new();
256 for pair in pairs {
257 assert_eq!(pair.as_rule(), rule);
258 let mut inner = pair.clone().into_inner();
259 let name_pair = inner.next().unwrap();
260 let value_pair = inner.next().unwrap();
261 assert_eq!(inner.next(), None);
262 let name = name_pair.as_str();
263 if map.contains_key(name) {
264 return Err(MessageParseError::invalid(
265 "NamedArg",
266 name_pair.as_span(),
267 format!("Duplicate argument: {name}"),
268 ));
269 }
270 map.insert(name, value_pair);
271 }
272 Ok(Self { map })
273 }
274
275 pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
279 let pair = self.map.remove(name).inspect(|pair| {
280 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
281 });
282 (self, pair)
283 }
284
285 pub fn done(self) -> Result<(), MessageParseError> {
287 if let Some((name, pair)) = self.map.iter().next() {
288 return Err(MessageParseError::invalid(
289 "NamedArgExtractor",
290 pair.as_span(),
292 format!("Unknown argument: {name}"),
293 ));
294 }
295 Ok(())
296 }
297}
298
299impl RelationParsePair for ReadRel {
300 fn rule() -> Rule {
301 Rule::read_relation
302 }
303
304 fn message() -> &'static str {
305 "ReadRel"
306 }
307
308 fn parse_pair_with_context(
309 extensions: &SimpleExtensions,
310 pair: Pair<Rule>,
311 input_children: Vec<Box<Rel>>,
312 input_field_count: usize,
313 ) -> Result<(Self, usize), MessageParseError> {
314 assert_eq!(pair.as_rule(), Self::rule());
315 if !input_children.is_empty() {
317 return Err(MessageParseError::invalid(
318 Self::message(),
319 pair.as_span(),
320 "ReadRel should have no input children",
321 ));
322 }
323 if input_field_count != 0 {
324 return Err(MessageParseError::invalid(
325 "ReadRel",
326 pair.as_span(),
327 "ReadRel should have 0 input fields",
328 ));
329 }
330
331 let mut iter = RuleIter::from(pair.into_inner());
332 let table = iter.parse_next::<TableName>().0;
333 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
334 iter.done();
335
336 let output_count = columns.len();
337 Ok((
338 ReadRel {
339 base_schema: Some(build_named_struct(columns)),
340 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
341 names: table,
342 advanced_extension: None,
343 })),
344 ..Default::default()
345 },
346 output_count,
347 ))
348 }
349
350 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
351 self.advanced_extension = adv_ext;
352 Rel {
353 rel_type: Some(RelType::Read(Box::new(self))),
354 }
355 }
356}
357
358pub(crate) struct VirtualReadRel(ReadRel);
362
363impl RelationParsePair for VirtualReadRel {
364 fn rule() -> Rule {
365 Rule::virtual_read_relation
366 }
367
368 fn message() -> &'static str {
369 "VirtualReadRel"
370 }
371
372 fn parse_pair_with_context(
373 extensions: &SimpleExtensions,
374 pair: Pair<Rule>,
375 input_children: Vec<Box<Rel>>,
376 _input_field_count: usize,
377 ) -> Result<(Self, usize), MessageParseError> {
378 assert_eq!(pair.as_rule(), Self::rule());
379 if !input_children.is_empty() {
380 return Err(MessageParseError::invalid(
381 Self::message(),
382 pair.as_span(),
383 "Read:Virtual should have no input children",
384 ));
385 }
386
387 let mut iter = RuleIter::from(pair.into_inner());
388 let args_pair = iter.pop(Rule::virtual_read_args);
389 let columns_pair = iter.pop(Rule::named_column_list);
390 iter.done();
391
392 let rows = parse_virtual_read_args(extensions, args_pair)?;
393 let columns = NamedColumnList::parse_pair(extensions, columns_pair)?.0;
394
395 let output_count = columns.len();
400 Ok((
401 VirtualReadRel(ReadRel {
402 base_schema: Some(build_named_struct(columns)),
403 read_type: Some(read_rel::ReadType::VirtualTable(read_rel::VirtualTable {
404 expressions: rows,
405 ..Default::default()
406 })),
407 ..Default::default()
408 }),
409 output_count,
410 ))
411 }
412
413 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
414 self.0.advanced_extension = adv_ext;
415 Rel {
416 rel_type: Some(RelType::Read(Box::new(self.0))),
417 }
418 }
419}
420
421fn build_named_struct(columns: Vec<Column>) -> NamedStruct {
423 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
424 NamedStruct {
425 names,
426 r#struct: Some(r#type::Struct {
427 types,
428 type_variation_reference: 0,
429 nullability: r#type::Nullability::Required as i32,
430 }),
431 }
432}
433
434fn parse_virtual_read_args(
436 extensions: &SimpleExtensions,
437 pair: Pair<Rule>,
438) -> Result<Vec<nested::Struct>, MessageParseError> {
439 assert_eq!(pair.as_rule(), Rule::virtual_read_args);
440 let inner = unwrap_single_pair(pair);
441 match inner.as_rule() {
442 Rule::empty => Ok(vec![]),
443 Rule::virtual_row_list => inner
444 .into_inner()
445 .map(|row| parse_virtual_row(extensions, row))
446 .collect(),
447 _ => unreachable!(
448 "Unexpected rule in virtual_read_args: {:?}",
449 inner.as_rule()
450 ),
451 }
452}
453
454fn parse_virtual_row(
456 extensions: &SimpleExtensions,
457 pair: Pair<Rule>,
458) -> Result<nested::Struct, MessageParseError> {
459 assert_eq!(pair.as_rule(), Rule::virtual_row);
460 let fields = match pair.into_inner().next() {
461 Some(expression_list) => {
462 assert_eq!(expression_list.as_rule(), Rule::expression_list);
463 parse_expression_list(extensions, expression_list)?
464 }
465 None => vec![],
467 };
468 Ok(nested::Struct { fields })
469}
470
471impl RelationParsePair for FilterRel {
472 fn rule() -> Rule {
473 Rule::filter_relation
474 }
475
476 fn message() -> &'static str {
477 "FilterRel"
478 }
479
480 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
481 self.advanced_extension = adv_ext;
482 Rel {
483 rel_type: Some(RelType::Filter(Box::new(self))),
484 }
485 }
486
487 fn parse_pair_with_context(
488 extensions: &SimpleExtensions,
489 pair: Pair<Rule>,
490 input_children: Vec<Box<Rel>>,
491 input_field_count: usize,
492 ) -> Result<(Self, usize), MessageParseError> {
493 assert_eq!(pair.as_rule(), Self::rule());
494 let input = expect_one_child(Self::message(), &pair, input_children)?;
495 let mut iter = RuleIter::from(pair.into_inner());
496 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
497 let references_pair = iter.pop(Rule::reference_list);
499 iter.done();
500
501 let (emit, output_count) = parse_emit(references_pair, input_field_count);
502 let common = RelCommon {
503 emit_kind: Some(emit),
504 ..Default::default()
505 };
506
507 Ok((
508 FilterRel {
509 input: Some(input),
510 condition: Some(Box::new(condition)),
511 common: Some(common),
512 advanced_extension: None,
513 },
514 output_count,
515 ))
516 }
517}
518
519impl RelationParsePair for ProjectRel {
520 fn rule() -> Rule {
521 Rule::project_relation
522 }
523
524 fn message() -> &'static str {
525 "ProjectRel"
526 }
527
528 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
529 self.advanced_extension = adv_ext;
530 Rel {
531 rel_type: Some(RelType::Project(Box::new(self))),
532 }
533 }
534
535 fn parse_pair_with_context(
536 extensions: &SimpleExtensions,
537 pair: Pair<Rule>,
538 input_children: Vec<Box<Rel>>,
539 input_field_count: usize,
540 ) -> Result<(Self, usize), MessageParseError> {
541 assert_eq!(pair.as_rule(), Self::rule());
542 let input = expect_one_child(Self::message(), &pair, input_children)?;
543
544 let arguments_pair = unwrap_single_pair(pair);
545
546 let mut expressions = Vec::new();
547 let mut output_mapping = Vec::new();
548
549 for arg in arguments_pair.into_inner() {
550 let inner_arg = unwrap_single_pair(arg);
551 match inner_arg.as_rule() {
552 Rule::reference => {
553 let field_index = FieldIndex::parse_pair(inner_arg);
554 output_mapping.push(field_index.0);
555 }
556 Rule::expression => {
557 let expr = Expression::parse_pair(extensions, inner_arg)?;
558 expressions.push(expr);
559 output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
561 }
562 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
563 }
564 }
565
566 let output_count = output_mapping.len();
567 let direct_count = input_field_count + expressions.len();
568 let emit = make_emit(output_mapping, direct_count);
569 let common = RelCommon {
570 emit_kind: Some(emit),
571 ..Default::default()
572 };
573
574 Ok((
575 ProjectRel {
576 input: Some(input),
577 expressions,
578 common: Some(common),
579 advanced_extension: None,
580 },
581 output_count,
582 ))
583 }
584}
585
586impl RelationParsePair for AggregateRel {
587 fn rule() -> Rule {
588 Rule::aggregate_relation
589 }
590
591 fn message() -> &'static str {
592 "AggregateRel"
593 }
594
595 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
596 self.advanced_extension = adv_ext;
597 Rel {
598 rel_type: Some(RelType::Aggregate(Box::new(self))),
599 }
600 }
601
602 fn parse_pair_with_context(
603 extensions: &SimpleExtensions,
604 pair: Pair<Rule>,
605 input_children: Vec<Box<Rel>>,
606 _input_field_count: usize,
609 ) -> Result<(Self, usize), MessageParseError> {
610 assert_eq!(pair.as_rule(), Self::rule());
611 let input = expect_one_child(Self::message(), &pair, input_children)?;
612 let mut iter = RuleIter::from(pair.into_inner());
613 let group_by_pair = iter.pop(Rule::aggregate_group_by);
614 let output_pair = iter.pop(Rule::aggregate_output);
615 iter.done();
616
617 let inner = group_by_pair
618 .into_inner()
619 .next()
620 .expect("aggregate_group_by must have one inner item");
621
622 let grouping_sets = parse_grouping_sets(extensions, inner)?;
623 let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
624
625 let (measures, output_mapping) =
626 parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
627
628 let output_count = output_mapping.len();
629 let direct_count = grouping_expressions.len() + measures.len();
630 let emit = make_emit(output_mapping, direct_count);
631 let common = RelCommon {
632 emit_kind: Some(emit),
633 ..Default::default()
634 };
635
636 Ok((
637 AggregateRel {
638 input: Some(input),
639 grouping_expressions,
640 groupings,
641 measures,
642 common: Some(common),
643 advanced_extension: None,
644 },
645 output_count,
646 ))
647 }
648}
649
650fn parse_aggregate_measures(
655 extensions: &SimpleExtensions,
656 output_pair: Pair<'_, Rule>,
657 grouping_expressions: &[Expression],
658) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
659 assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
660 let mut measures = Vec::new();
661 let mut output_mapping = Vec::new();
662
663 for aggregate_output_item in output_pair.into_inner() {
664 let inner_item = unwrap_single_pair(aggregate_output_item);
665 match inner_item.as_rule() {
666 Rule::reference => {
667 let field_index = FieldIndex::parse_pair(inner_item);
668 output_mapping.push(field_index.0);
669 }
670 Rule::aggregate_measure => {
671 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
672 output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
673 measures.push(measure);
674 }
675 _ => panic!(
676 "Unexpected inner output item rule: {:?}",
677 inner_item.as_rule()
678 ),
679 }
680 }
681
682 Ok((measures, output_mapping))
683}
684
685fn parse_grouping_sets(
694 extensions: &SimpleExtensions,
695 inner: Pair<'_, Rule>,
696) -> Result<Vec<Vec<Expression>>, MessageParseError> {
697 assert!(
698 matches!(
699 inner.as_rule(),
700 Rule::expression_list | Rule::grouping_set_list
701 ),
702 "Expected expression_list or grouping_set_list, got {:?}",
703 inner.as_rule()
704 );
705 match inner.as_rule() {
706 Rule::expression_list => Ok(vec![parse_expression_list(extensions, inner)?]),
707 Rule::grouping_set_list => inner
708 .into_inner()
709 .map(|pair| parse_grouping_set(extensions, pair))
710 .collect(),
711 _ => unreachable!(
712 "Unexpected rule in aggregate_group_by: {:?}",
713 inner.as_rule()
714 ),
715 }
716}
717
718fn parse_grouping_set(
722 extensions: &SimpleExtensions,
723 pair: Pair<'_, Rule>,
724) -> Result<Vec<Expression>, MessageParseError> {
725 assert_eq!(pair.as_rule(), Rule::grouping_set);
726 let inner = pair
727 .into_inner()
728 .next()
729 .expect("grouping_set must have one inner item");
730 match inner.as_rule() {
731 Rule::empty => Ok(vec![]),
732 Rule::expression_list => parse_expression_list(extensions, inner),
733 _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
734 }
735}
736
737pub(crate) fn parse_expression_list(
739 extensions: &SimpleExtensions,
740 pair: Pair<'_, Rule>,
741) -> Result<Vec<Expression>, MessageParseError> {
742 pair.into_inner()
743 .map(|expr_pair| Expression::parse_pair(extensions, expr_pair))
744 .collect()
745}
746
747fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
751 let mut expressions: Vec<Expression> = Vec::new();
752 let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
753
754 let groupings = expression_sets
755 .iter()
756 .map(|set| {
757 let expression_references = set
758 .iter()
759 .map(|exp| {
760 let key = exp.encode_to_vec();
764 let next_idx = expressions.len() as u32;
765 *seen.entry(key).or_insert_with(|| {
766 expressions.push(exp.clone());
767 next_idx
768 })
769 })
770 .collect();
771 Grouping {
772 expression_references,
773 #[allow(deprecated)]
774 grouping_expressions: vec![],
775 }
776 })
777 .collect();
778
779 (groupings, expressions)
780}
781
782impl ScopedParsePair for SortField {
783 fn rule() -> Rule {
784 Rule::sort_field
785 }
786
787 fn message() -> &'static str {
788 "SortField"
789 }
790
791 fn parse_pair(
792 _extensions: &SimpleExtensions,
793 pair: Pair<Rule>,
794 ) -> Result<Self, MessageParseError> {
795 assert_eq!(pair.as_rule(), Self::rule());
796 let mut iter = RuleIter::from(pair.into_inner());
797 let reference_pair = iter.pop(Rule::reference);
798 let field_index = FieldIndex::parse_pair(reference_pair);
799 let direction_pair = iter.pop(Rule::sort_direction);
800 let direction = match direction_pair.as_str().trim_start_matches('&') {
804 "AscNullsFirst" => SortDirection::AscNullsFirst,
805 "AscNullsLast" => SortDirection::AscNullsLast,
806 "DescNullsFirst" => SortDirection::DescNullsFirst,
807 "DescNullsLast" => SortDirection::DescNullsLast,
808 other => {
809 return Err(MessageParseError::invalid(
810 "SortDirection",
811 direction_pair.as_span(),
812 format!("Unknown sort direction: {other}"),
813 ));
814 }
815 };
816 iter.done();
817 Ok(SortField {
818 expr: Some(Expression {
819 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
820 field_index.to_field_reference(),
821 ))),
822 }),
823 sort_kind: Some(SortKind::Direction(direction as i32)),
825 })
826 }
827}
828
829impl RelationParsePair for SortRel {
830 fn rule() -> Rule {
831 Rule::sort_relation
832 }
833
834 fn message() -> &'static str {
835 "SortRel"
836 }
837
838 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
839 self.advanced_extension = adv_ext;
840 Rel {
841 rel_type: Some(RelType::Sort(Box::new(self))),
842 }
843 }
844
845 fn parse_pair_with_context(
846 extensions: &SimpleExtensions,
847 pair: Pair<Rule>,
848 input_children: Vec<Box<Rel>>,
849 input_field_count: usize,
850 ) -> Result<(Self, usize), MessageParseError> {
851 assert_eq!(pair.as_rule(), Self::rule());
852 let input = expect_one_child(Self::message(), &pair, input_children)?;
853 let mut iter = RuleIter::from(pair.into_inner());
854 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
855 let reference_list_pair = iter.pop(Rule::reference_list);
856 let mut sorts = Vec::new();
857 for sort_field_pair in sort_field_list_pair.into_inner() {
858 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
859 sorts.push(sort_field);
860 }
861 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
862 let common = RelCommon {
863 emit_kind: Some(emit),
864 ..Default::default()
865 };
866 iter.done();
867 Ok((
868 SortRel {
869 input: Some(input),
870 sorts,
871 common: Some(common),
872 advanced_extension: None,
873 },
874 output_count,
875 ))
876 }
877}
878
879impl ScopedParsePair for CountMode {
880 fn rule() -> Rule {
881 Rule::fetch_value
882 }
883 fn message() -> &'static str {
884 "CountMode"
885 }
886 fn parse_pair(
887 extensions: &SimpleExtensions,
888 pair: Pair<Rule>,
889 ) -> Result<Self, MessageParseError> {
890 assert_eq!(pair.as_rule(), Self::rule());
891 let mut arg_inner = RuleIter::from(pair.into_inner());
892 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
893 int_pair
894 } else {
895 arg_inner.pop(Rule::expression)
896 };
897 match value_pair.as_rule() {
898 Rule::integer => {
899 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
900 MessageParseError::invalid(
901 Self::message(),
902 value_pair.as_span(),
903 format!("Invalid integer: {e}"),
904 )
905 })?;
906 if value < 0 {
907 return Err(MessageParseError::invalid(
908 Self::message(),
909 value_pair.as_span(),
910 format!("Fetch limit must be non-negative, got: {value}"),
911 ));
912 }
913 Ok(CountMode::CountExpr(i64_literal_expr(value)))
914 }
915 Rule::expression => {
916 let expr = Expression::parse_pair(extensions, value_pair)?;
917 Ok(CountMode::CountExpr(Box::new(expr)))
918 }
919 _ => Err(MessageParseError::invalid(
920 Self::message(),
921 value_pair.as_span(),
922 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
923 )),
924 }
925 }
926}
927
928fn i64_literal_expr(value: i64) -> Box<Expression> {
929 Box::new(Expression {
930 rex_type: Some(RexType::Literal(Literal {
931 nullable: false,
932 type_variation_reference: 0,
933 literal_type: Some(LiteralType::I64(value)),
934 })),
935 })
936}
937
938impl ScopedParsePair for OffsetMode {
939 fn rule() -> Rule {
940 Rule::fetch_value
941 }
942 fn message() -> &'static str {
943 "OffsetMode"
944 }
945 fn parse_pair(
946 extensions: &SimpleExtensions,
947 pair: Pair<Rule>,
948 ) -> Result<Self, MessageParseError> {
949 assert_eq!(pair.as_rule(), Self::rule());
950 let mut arg_inner = RuleIter::from(pair.into_inner());
951 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
952 int_pair
953 } else {
954 arg_inner.pop(Rule::expression)
955 };
956 match value_pair.as_rule() {
957 Rule::integer => {
958 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
959 MessageParseError::invalid(
960 Self::message(),
961 value_pair.as_span(),
962 format!("Invalid integer: {e}"),
963 )
964 })?;
965 if value < 0 {
966 return Err(MessageParseError::invalid(
967 Self::message(),
968 value_pair.as_span(),
969 format!("Fetch offset must be non-negative, got: {value}"),
970 ));
971 }
972 Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
973 }
974 Rule::expression => {
975 let expr = Expression::parse_pair(extensions, value_pair)?;
976 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
977 }
978 _ => Err(MessageParseError::invalid(
979 Self::message(),
980 value_pair.as_span(),
981 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
982 )),
983 }
984 }
985}
986
987impl RelationParsePair for FetchRel {
988 fn rule() -> Rule {
989 Rule::fetch_relation
990 }
991
992 fn message() -> &'static str {
993 "FetchRel"
994 }
995
996 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
997 self.advanced_extension = adv_ext;
998 Rel {
999 rel_type: Some(RelType::Fetch(Box::new(self))),
1000 }
1001 }
1002
1003 fn parse_pair_with_context(
1004 extensions: &SimpleExtensions,
1005 pair: Pair<Rule>,
1006 input_children: Vec<Box<Rel>>,
1007 input_field_count: usize,
1008 ) -> Result<(Self, usize), MessageParseError> {
1009 assert_eq!(pair.as_rule(), Self::rule());
1010 let input = expect_one_child(Self::message(), &pair, input_children)?;
1011 let mut iter = RuleIter::from(pair.into_inner());
1012
1013 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
1017 None => {
1018 iter.pop(Rule::empty);
1019 (None, None)
1020 }
1021 Some(fetch_args_pair) => {
1022 let extractor =
1023 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
1024 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
1025 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
1026 extractor.done()?;
1027 (limit_pair, offset_pair)
1028 }
1029 };
1030
1031 let reference_list_pair = iter.pop(Rule::reference_list);
1032 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1033 let common = RelCommon {
1034 emit_kind: Some(emit),
1035 ..Default::default()
1036 };
1037 iter.done();
1038
1039 let count_mode = limit_pair
1040 .map(|pair| CountMode::parse_pair(extensions, pair))
1041 .transpose()?;
1042 let offset_mode = offset_pair
1043 .map(|pair| OffsetMode::parse_pair(extensions, pair))
1044 .transpose()?;
1045 Ok((
1046 FetchRel {
1047 input: Some(input),
1048 common: Some(common),
1049 advanced_extension: None,
1050 offset_mode,
1051 count_mode,
1052 },
1053 output_count,
1054 ))
1055 }
1056}
1057
1058impl ParsePair for join_rel::JoinType {
1059 fn rule() -> Rule {
1060 Rule::join_type
1061 }
1062
1063 fn message() -> &'static str {
1064 "JoinType"
1065 }
1066
1067 fn parse_pair(pair: Pair<Rule>) -> Self {
1068 assert_eq!(pair.as_rule(), Self::rule());
1069 let join_type_str = pair.as_str().trim_start_matches('&');
1070 match join_type_str {
1071 "Inner" => join_rel::JoinType::Inner,
1072 "Left" => join_rel::JoinType::Left,
1073 "Right" => join_rel::JoinType::Right,
1074 "Outer" => join_rel::JoinType::Outer,
1075 "LeftSemi" => join_rel::JoinType::LeftSemi,
1076 "RightSemi" => join_rel::JoinType::RightSemi,
1077 "LeftAnti" => join_rel::JoinType::LeftAnti,
1078 "RightAnti" => join_rel::JoinType::RightAnti,
1079 "LeftSingle" => join_rel::JoinType::LeftSingle,
1080 "RightSingle" => join_rel::JoinType::RightSingle,
1081 "LeftMark" => join_rel::JoinType::LeftMark,
1082 "RightMark" => join_rel::JoinType::RightMark,
1083 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
1084 }
1085 }
1086}
1087
1088impl RelationParsePair for JoinRel {
1089 fn rule() -> Rule {
1090 Rule::join_relation
1091 }
1092
1093 fn message() -> &'static str {
1094 "JoinRel"
1095 }
1096
1097 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1098 self.advanced_extension = adv_ext;
1099 Rel {
1100 rel_type: Some(RelType::Join(Box::new(self))),
1101 }
1102 }
1103
1104 fn parse_pair_with_context(
1105 extensions: &SimpleExtensions,
1106 pair: Pair<Rule>,
1107 input_children: Vec<Box<Rel>>,
1108 input_field_count: usize,
1109 ) -> Result<(Self, usize), MessageParseError> {
1110 assert_eq!(pair.as_rule(), Self::rule());
1111
1112 if input_children.len() != 2 {
1113 return Err(MessageParseError::invalid(
1114 Self::message(),
1115 pair.as_span(),
1116 format!(
1117 "JoinRel should have exactly 2 input children, got {}",
1118 input_children.len()
1119 ),
1120 ));
1121 }
1122
1123 let mut children_iter = input_children.into_iter();
1124 let left = children_iter.next().unwrap();
1125 let right = children_iter.next().unwrap();
1126
1127 let mut iter = RuleIter::from(pair.into_inner());
1128 let join_type = iter.parse_next::<join_rel::JoinType>();
1129 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
1130 let reference_list_pair = iter.pop(Rule::reference_list);
1131 iter.done();
1132
1133 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1137 let common = RelCommon {
1138 emit_kind: Some(emit),
1139 ..Default::default()
1140 };
1141
1142 Ok((
1143 JoinRel {
1144 common: Some(common),
1145 left: Some(left),
1146 right: Some(right),
1147 expression: Some(Box::new(condition)),
1148 post_join_filter: None, r#type: join_type as i32,
1150 advanced_extension: None,
1151 },
1152 output_count,
1153 ))
1154 }
1155}
1156
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    const AGGREGATE_FUNCTIONS_URN: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml";
    const COMPARISON_FUNCTIONS_URN: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml";

    /// Extensions registering the aggregate functions used by these tests:
    /// `sum` at anchor 10 and `count` at anchor 11, both under URN anchor 1.
    fn aggregate_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_urn(1, AGGREGATE_FUNCTIONS_URN)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions
    }

    /// Extensions registering the comparison function `eq` at anchor 10 under
    /// URN anchor 1.
    fn comparison_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_urn(1, COMPARISON_FUNCTIONS_URN)
            .with_function(1, 10, "eq")
            .extensions
    }

    /// Returns the emit kind stored in a relation's `RelCommon`.
    ///
    /// Panics if `common` is `None` or carries no emit kind.
    fn emit_kind_of(common: Option<&RelCommon>) -> &EmitKind {
        common
            .expect("relation should have a RelCommon")
            .emit_kind
            .as_ref()
            .expect("RelCommon should have an emit kind")
    }

    /// Returns the explicit output mapping of a relation.
    ///
    /// Panics unless the relation's emit kind is `EmitKind::Emit`.
    fn output_mapping_of(common: Option<&RelCommon>) -> &[i32] {
        match emit_kind_of(common) {
            EmitKind::Emit(emit) => &emit.output_mapping,
            other => panic!("Expected EmitKind::Emit, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_relation() {
        // A parsed relation must convert into a generic `Rel` that still carries
        // the variant it was parsed as.
        let rel = example_read_relation().into_rel(None);
        assert!(
            matches!(rel.rel_type, Some(RelType::Read(_))),
            "Expected RelType::Read, got {:?}",
            rel.rel_type
        );
    }

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap()
        .0;

        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);

        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    /// A three-column read relation (`a:i32, b:string?, c:i64`) used as the
    /// input for most single-input relation tests.
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
        .0
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Emitting every input column in order is an identity projection.
        let emit_kind = emit_kind_of(filter.common.as_ref());
        assert!(
            matches!(emit_kind, EmitKind::Direct(_)),
            "Expected Direct for identity emit, got {emit_kind:?}"
        );
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Only the literal becomes a new expression; it is appended after the
        // three input columns, at index 3.
        assert_eq!(project.expressions.len(), 1);
        assert_eq!(output_mapping_of(project.common.as_ref()), &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            5,
        )
        .unwrap()
        .0;

        assert_eq!(project.expressions.len(), 2);
        assert_eq!(output_mapping_of(project.common.as_ref()), &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = aggregate_extensions();

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[($0, $1), _ => sum($2), $0, count($2)]",
            ),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        // Check the length before indexing so a short vector fails with a clear message.
        assert_eq!(aggregate.groupings.len(), 2);
        assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);

        assert_eq!(output_mapping_of(aggregate.common.as_ref()), &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_maintain_column_order() {
        let extensions = aggregate_extensions();

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), $0, count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        assert_eq!(output_mapping_of(aggregate.common.as_ref()), &[1, 0, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = aggregate_extensions();

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation_global_aggregate() {
        let extensions = aggregate_extensions();

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // `_` is a single, empty grouping set.
        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = emit_kind_of(aggregate.common.as_ref());
        assert!(
            matches!(emit_kind, EmitKind::Direct(_)),
            "Expected Direct for identity emit, got {emit_kind:?}"
        );
    }

    #[test]
    fn test_parse_aggregate_relation_grouping_sets() {
        let extensions = aggregate_extensions();

        let read_rel = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
        .0;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]",
            ),
            vec![Box::new(read_rel.into_rel(None))],
            4,
        )
        .unwrap()
        .0;

        // Each distinct grouping expression is stored once. Grouping sets refer
        // to those expressions by index, in first-seen order.
        assert_eq!(aggregate.grouping_expressions.len(), 3);
        assert_eq!(aggregate.groupings.len(), 4);
        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
        assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
        assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
        assert!(aggregate.groupings[3].expression_references.is_empty());
        assert_eq!(aggregate.measures.len(), 1);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        assert_eq!(
            fetch_rel.count_mode,
            Some(CountMode::CountExpr(i64_literal_expr(10)))
        );
        assert_eq!(
            fetch_rel.offset_mode,
            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
        );
    }

    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        let extensions = SimpleExtensions::default();
        let input = "Fetch[limit=-5 => $0]";

        // A negative limit may be rejected by the grammar or by semantic parsing.
        // Only the second case leaves anything to check here.
        let Some(pair) = try_parse_exact(Rule::fetch_relation, input) else {
            println!("Grammar prevents negative limit values at parse time");
            return;
        };

        let error_msg = FetchRel::parse_pair_with_context(
            &extensions,
            pair,
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .expect_err("negative limit should be rejected")
        .to_string();
        assert!(
            error_msg.contains("Fetch limit must be non-negative"),
            "unexpected error message: {error_msg}"
        );
    }

    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        let extensions = SimpleExtensions::default();
        let input = "Fetch[offset=-10 => $0]";

        // A negative offset may be rejected by the grammar or by semantic parsing.
        // Only the second case leaves anything to check here.
        let Some(pair) = try_parse_exact(Rule::fetch_relation, input) else {
            println!("Grammar prevents negative offset values at parse time");
            return;
        };

        let error_msg = FetchRel::parse_pair_with_context(
            &extensions,
            pair,
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .expect_err("negative offset should be rejected")
        .to_string();
        assert!(
            error_msg.contains("Fetch offset must be non-negative"),
            "unexpected error message: {error_msg}"
        );
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = comparison_extensions();

        let left_rel = example_read_relation().into_rel(None);
        let right_rel = example_read_relation().into_rel(None);

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap()
        .0;

        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
        assert!(join.left.is_some());
        assert!(join.right.is_some());
        assert!(join.expression.is_some());

        assert_eq!(output_mapping_of(join.common.as_ref()), &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = comparison_extensions();

        let left_rel = example_read_relation().into_rel(None);
        let right_rel = example_read_relation().into_rel(None);

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap()
        .0;

        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
        assert_eq!(output_mapping_of(join.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = comparison_extensions();

        let left_rel = example_read_relation().into_rel(None);
        let right_rel = example_read_relation().into_rel(None);

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap()
        .0;

        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
        assert_eq!(output_mapping_of(join.common.as_ref()), &[0, 1]);
    }

    #[test]
    fn test_parse_join_relation_requires_two_children() {
        // Register `eq` so that a failure cannot come from an unresolved function
        // in the join condition. Only the number of children is under test.
        let extensions = comparison_extensions();
        let input = "Join[&Inner, eq($0, $1) => $0, $1]";

        let no_children = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, input),
            vec![],
            0,
        );
        assert!(no_children.is_err(), "Join with no inputs should be rejected");

        let one_child = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, input),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        );
        assert!(one_child.is_err(), "Join with one input should be rejected");

        let three_children = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, input),
            vec![
                Box::new(example_read_relation().into_rel(None)),
                Box::new(example_read_relation().into_rel(None)),
                Box::new(example_read_relation().into_rel(None)),
            ],
            9,
        );
        assert!(
            three_children.is_err(),
            "Join with three inputs should be rejected"
        );
    }

    /// Parses `input` with `rule`, panicking unless the grammar accepts it and
    /// yields exactly one pair spanning the whole input.
    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Parses `input` with `rule` without panicking.
    ///
    /// Returns `None` when the grammar rejects `input` or when the first pair does
    /// not consume all of it. Used by tests that accept either a grammar-level or
    /// a semantic rejection.
    fn try_parse_exact(rule: Rule, input: &str) -> Option<pest::iterators::Pair<'_, Rule>> {
        let mut pairs = ExpressionParser::parse(rule, input).ok()?;
        let pair = pairs.next()?;
        (pair.as_str() == input).then_some(pair)
    }
}