1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType, nested};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::{CountMode, OffsetMode};
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::{Direct, Emit, EmitKind};
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
15 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
16};
17
18use super::{MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair};
19use crate::extensions::any::Any;
20use crate::extensions::registry::ExtensionError;
21use crate::extensions::{AddendumKind, ExtensionArgs, ExtensionRegistry, SimpleExtensions};
22use crate::parser::errors::{ParseContext, ParseError};
23use crate::parser::expressions::{FieldIndex, Name};
24
/// State needed to resolve extension details while parsing one line.
///
/// Bundles the simple-extension declarations and the extension registry with
/// the number and text of the line being parsed; the line information is what
/// the `resolve_*` methods attach to the `ParseError`s they build.
pub struct RelationParsingContext<'a> {
    /// Simple extensions available to the relation being parsed.
    pub extensions: &'a SimpleExtensions,
    /// Registry consulted to parse extension and addendum arguments.
    pub registry: &'a ExtensionRegistry,
    /// Number of the line being parsed, copied into error contexts.
    pub line_no: i64,
    /// Text of the line being parsed, copied into error contexts.
    pub line: &'a str,
}
32
33impl<'a> RelationParsingContext<'a> {
34 pub fn resolve_extension_detail(
36 &self,
37 extension_name: &str,
38 extension_args: &ExtensionArgs,
39 ) -> Result<Option<Any>, ParseError> {
40 let detail = self
41 .registry
42 .parse_extension(extension_name, extension_args);
43
44 match detail {
45 Ok(any) => Ok(Some(any)),
46 Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
47 name: extension_name.to_string(),
48 context: ParseContext::new(self.line_no, self.line.to_string()),
49 }),
50 Err(err) => Err(ParseError::ExtensionDetail(
51 ParseContext::new(self.line_no, self.line.to_string()),
52 err,
53 )),
54 }
55 }
56
57 pub(crate) fn resolve_addendum_detail(
60 &self,
61 kind: AddendumKind,
62 name: &str,
63 args: &ExtensionArgs,
64 ) -> Result<Any, ParseError> {
65 let result = match kind {
66 AddendumKind::Enhancement => self.registry.parse_enhancement(name, args),
67 AddendumKind::Optimization => self.registry.parse_optimization(name, args),
68 AddendumKind::ExtensionTable => self.registry.parse_extension_table(name, args),
69 };
70 result.map_err(|err| match err {
71 ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
72 name: name.to_string(),
73 context: ParseContext::new(self.line_no, self.line.to_string()),
74 },
75 err => ParseError::ExtensionDetail(
76 ParseContext::new(self.line_no, self.line.to_string()),
77 err,
78 ),
79 })
80 }
81}
82
/// A relation message that can be built from its parsed grammar pair.
pub trait RelationParsePair: Sized {
    /// The grammar rule this relation is parsed from.
    fn rule() -> Rule;
    /// Name of the message, used when building parse errors.
    fn message() -> &'static str;

    /// Parses `pair` into the relation.
    ///
    /// `input_children` are the already-built input relations and
    /// `input_field_count` is the field count supplied by the caller for
    /// those inputs. On success returns the relation together with the
    /// number of fields it outputs.
    fn parse_pair_with_context(
        extensions: &SimpleExtensions,
        pair: Pair<Rule>,
        input_children: Vec<Box<Rel>>,
        input_field_count: usize,
    ) -> Result<(Self, usize), MessageParseError>;

    /// Wraps the relation in a `Rel`, with `adv_ext` as its advanced
    /// extension.
    fn into_rel(self, adv_ext: Option<AdvancedExtension>) -> Rel;
}
104
105pub struct TableName(Vec<String>);
106
107impl ParsePair for TableName {
108 fn rule() -> Rule {
109 Rule::table_name
110 }
111
112 fn message() -> &'static str {
113 "TableName"
114 }
115
116 fn parse_pair(pair: Pair<Rule>) -> Self {
117 assert_eq!(pair.as_rule(), Self::rule());
118 let pairs = pair.into_inner();
119 let mut names = Vec::with_capacity(pairs.len());
120 let mut iter = RuleIter::from(pairs);
121 while let Some(name) = iter.parse_if_next::<Name>() {
122 names.push(name.0);
123 }
124 iter.done();
125 Self(names)
126 }
127}
128
129#[derive(Debug, Clone)]
130pub struct Column {
131 pub name: String,
132 pub typ: Type,
133}
134
135impl ScopedParsePair for Column {
136 fn rule() -> Rule {
137 Rule::named_column
138 }
139
140 fn message() -> &'static str {
141 "Column"
142 }
143
144 fn parse_pair(
145 extensions: &SimpleExtensions,
146 pair: Pair<Rule>,
147 ) -> Result<Self, MessageParseError> {
148 assert_eq!(pair.as_rule(), Self::rule());
149 let mut iter = RuleIter::from(pair.into_inner());
150 let name = iter.parse_next::<Name>().0;
151 let typ = iter.parse_next_scoped(extensions)?;
152 iter.done();
153 Ok(Self { name, typ })
154 }
155}
156
157pub(crate) struct NamedColumnList(pub(crate) Vec<Column>);
158
159impl ScopedParsePair for NamedColumnList {
160 fn rule() -> Rule {
161 Rule::named_column_list
162 }
163
164 fn message() -> &'static str {
165 "NamedColumnList"
166 }
167
168 fn parse_pair(
169 extensions: &SimpleExtensions,
170 pair: Pair<Rule>,
171 ) -> Result<Self, MessageParseError> {
172 assert_eq!(pair.as_rule(), Self::rule());
173 let mut columns = Vec::new();
174 for col in pair.into_inner() {
175 columns.push(Column::parse_pair(extensions, col)?);
176 }
177 Ok(Self(columns))
178 }
179}
180
181#[allow(clippy::vec_box)]
186pub(crate) fn expect_one_child(
187 message: &'static str,
188 pair: &Pair<Rule>,
189 mut input_children: Vec<Box<Rel>>,
190) -> Result<Box<Rel>, MessageParseError> {
191 match input_children.len() {
192 0 => Err(MessageParseError::invalid(
193 message,
194 pair.as_span(),
195 format!("{message} missing child"),
196 )),
197 1 => Ok(input_children.pop().unwrap()),
198 n => Err(MessageParseError::invalid(
199 message,
200 pair.as_span(),
201 format!("{message} should have 1 input child, got {n}"),
202 )),
203 }
204}
205
206fn parse_output_mapping(pair: Pair<Rule>) -> Vec<i32> {
209 assert_eq!(pair.as_rule(), Rule::reference_list);
210 pair.into_inner()
211 .map(|p| FieldIndex::parse_pair(p).0)
212 .collect()
213}
214
215fn make_emit(output_mapping: Vec<i32>, direct_output_count: usize) -> EmitKind {
218 let is_identity = output_mapping.len() == direct_output_count
219 && output_mapping
220 .iter()
221 .enumerate()
222 .all(|(i, &v)| v == i as i32);
223 if is_identity {
224 EmitKind::Direct(Direct {})
225 } else {
226 EmitKind::Emit(Emit { output_mapping })
227 }
228}
229
230fn parse_emit(reference_list: Pair<Rule>, direct_output_count: usize) -> (EmitKind, usize) {
232 let output_mapping = parse_output_mapping(reference_list);
233 let output_count = output_mapping.len();
234 let emit = make_emit(output_mapping, direct_output_count);
235 (emit, output_count)
236}
237
238pub struct ParsedNamedArgs<'a> {
244 map: HashMap<&'a str, Pair<'a, Rule>>,
245}
246
247impl<'a> ParsedNamedArgs<'a> {
248 pub fn new(
249 pairs: pest::iterators::Pairs<'a, Rule>,
250 rule: Rule,
251 ) -> Result<Self, MessageParseError> {
252 let mut map = HashMap::new();
253 for pair in pairs {
254 assert_eq!(pair.as_rule(), rule);
255 let mut inner = pair.clone().into_inner();
256 let name_pair = inner.next().unwrap();
257 let value_pair = inner.next().unwrap();
258 assert_eq!(inner.next(), None);
259 let name = name_pair.as_str();
260 if map.contains_key(name) {
261 return Err(MessageParseError::invalid(
262 "NamedArg",
263 name_pair.as_span(),
264 format!("Duplicate argument: {name}"),
265 ));
266 }
267 map.insert(name, value_pair);
268 }
269 Ok(Self { map })
270 }
271
272 pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
276 let pair = self.map.remove(name).inspect(|pair| {
277 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
278 });
279 (self, pair)
280 }
281
282 pub fn done(self) -> Result<(), MessageParseError> {
284 if let Some((name, pair)) = self.map.iter().next() {
285 return Err(MessageParseError::invalid(
286 "NamedArgExtractor",
287 pair.as_span(),
289 format!("Unknown argument: {name}"),
290 ));
291 }
292 Ok(())
293 }
294}
295
296impl RelationParsePair for ReadRel {
297 fn rule() -> Rule {
298 Rule::read_relation
299 }
300
301 fn message() -> &'static str {
302 "ReadRel"
303 }
304
305 fn parse_pair_with_context(
306 extensions: &SimpleExtensions,
307 pair: Pair<Rule>,
308 input_children: Vec<Box<Rel>>,
309 input_field_count: usize,
310 ) -> Result<(Self, usize), MessageParseError> {
311 assert_eq!(pair.as_rule(), Self::rule());
312 if !input_children.is_empty() {
314 return Err(MessageParseError::invalid(
315 Self::message(),
316 pair.as_span(),
317 "ReadRel should have no input children",
318 ));
319 }
320 if input_field_count != 0 {
321 return Err(MessageParseError::invalid(
322 "ReadRel",
323 pair.as_span(),
324 "ReadRel should have 0 input fields",
325 ));
326 }
327
328 let mut iter = RuleIter::from(pair.into_inner());
329 let table = iter.parse_next::<TableName>().0;
330 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
331 iter.done();
332
333 let output_count = columns.len();
334 Ok((
335 ReadRel {
336 base_schema: Some(build_named_struct(columns)),
337 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
338 names: table,
339 advanced_extension: None,
340 })),
341 ..Default::default()
342 },
343 output_count,
344 ))
345 }
346
347 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
348 self.advanced_extension = adv_ext;
349 Rel {
350 rel_type: Some(RelType::Read(Box::new(self))),
351 }
352 }
353}
354
355pub(crate) struct VirtualReadRel(ReadRel);
359
360impl RelationParsePair for VirtualReadRel {
361 fn rule() -> Rule {
362 Rule::virtual_read_relation
363 }
364
365 fn message() -> &'static str {
366 "VirtualReadRel"
367 }
368
369 fn parse_pair_with_context(
370 extensions: &SimpleExtensions,
371 pair: Pair<Rule>,
372 input_children: Vec<Box<Rel>>,
373 _input_field_count: usize,
374 ) -> Result<(Self, usize), MessageParseError> {
375 assert_eq!(pair.as_rule(), Self::rule());
376 if !input_children.is_empty() {
377 return Err(MessageParseError::invalid(
378 Self::message(),
379 pair.as_span(),
380 "Read:Virtual should have no input children",
381 ));
382 }
383
384 let mut iter = RuleIter::from(pair.into_inner());
385 let args_pair = iter.pop(Rule::virtual_read_args);
386 let columns_pair = iter.pop(Rule::named_column_list);
387 iter.done();
388
389 let rows = parse_virtual_read_args(extensions, args_pair)?;
390 let columns = NamedColumnList::parse_pair(extensions, columns_pair)?.0;
391
392 let output_count = columns.len();
397 Ok((
398 VirtualReadRel(ReadRel {
399 base_schema: Some(build_named_struct(columns)),
400 read_type: Some(read_rel::ReadType::VirtualTable(read_rel::VirtualTable {
401 expressions: rows,
402 ..Default::default()
403 })),
404 ..Default::default()
405 }),
406 output_count,
407 ))
408 }
409
410 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
411 self.0.advanced_extension = adv_ext;
412 Rel {
413 rel_type: Some(RelType::Read(Box::new(self.0))),
414 }
415 }
416}
417
418pub(crate) struct ExtensionReadRel(ReadRel);
422
423impl ExtensionReadRel {
424 #[allow(clippy::vec_box)]
430 pub(crate) fn parse_pair_with_detail(
431 extensions: &SimpleExtensions,
432 pair: Pair<Rule>,
433 input_children: Vec<Box<Rel>>,
434 input_field_count: usize,
435 detail: Any,
436 advanced_extension: Option<AdvancedExtension>,
437 ) -> Result<(Rel, usize), MessageParseError> {
438 assert_eq!(pair.as_rule(), Rule::extension_read_relation);
439 if !input_children.is_empty() {
440 return Err(MessageParseError::invalid(
441 "ExtensionReadRel",
442 pair.as_span(),
443 "Read:Extension should have no input children",
444 ));
445 }
446 if input_field_count != 0 {
447 return Err(MessageParseError::invalid(
448 "ExtensionReadRel",
449 pair.as_span(),
450 "Read:Extension should have 0 input fields",
451 ));
452 }
453
454 let mut iter = RuleIter::from(pair.into_inner());
455 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
456 iter.done();
457
458 let output_count = columns.len();
459 let rel = ExtensionReadRel(ReadRel {
460 base_schema: Some(build_named_struct(columns)),
461 read_type: Some(read_rel::ReadType::ExtensionTable(
462 read_rel::ExtensionTable {
463 detail: Some(detail.into()),
464 },
465 )),
466 advanced_extension,
467 ..Default::default()
468 })
469 .into_rel();
470
471 Ok((rel, output_count))
472 }
473
474 fn into_rel(self) -> Rel {
475 Rel {
476 rel_type: Some(RelType::Read(Box::new(self.0))),
477 }
478 }
479}
480
481pub(crate) fn build_named_struct(columns: Vec<Column>) -> NamedStruct {
483 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
484 NamedStruct {
485 names,
486 r#struct: Some(r#type::Struct {
487 types,
488 type_variation_reference: 0,
489 nullability: r#type::Nullability::Required as i32,
490 }),
491 }
492}
493
494fn parse_virtual_read_args(
496 extensions: &SimpleExtensions,
497 pair: Pair<Rule>,
498) -> Result<Vec<nested::Struct>, MessageParseError> {
499 assert_eq!(pair.as_rule(), Rule::virtual_read_args);
500 let inner = unwrap_single_pair(pair);
501 match inner.as_rule() {
502 Rule::empty => Ok(vec![]),
503 Rule::virtual_row_list => inner
504 .into_inner()
505 .map(|row| parse_virtual_row(extensions, row))
506 .collect(),
507 _ => unreachable!(
508 "Unexpected rule in virtual_read_args: {:?}",
509 inner.as_rule()
510 ),
511 }
512}
513
514fn parse_virtual_row(
516 extensions: &SimpleExtensions,
517 pair: Pair<Rule>,
518) -> Result<nested::Struct, MessageParseError> {
519 assert_eq!(pair.as_rule(), Rule::virtual_row);
520 let fields = match pair.into_inner().next() {
521 Some(expression_list) => {
522 assert_eq!(expression_list.as_rule(), Rule::expression_list);
523 parse_expression_list(extensions, expression_list)?
524 }
525 None => vec![],
527 };
528 Ok(nested::Struct { fields })
529}
530
531impl RelationParsePair for FilterRel {
532 fn rule() -> Rule {
533 Rule::filter_relation
534 }
535
536 fn message() -> &'static str {
537 "FilterRel"
538 }
539
540 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
541 self.advanced_extension = adv_ext;
542 Rel {
543 rel_type: Some(RelType::Filter(Box::new(self))),
544 }
545 }
546
547 fn parse_pair_with_context(
548 extensions: &SimpleExtensions,
549 pair: Pair<Rule>,
550 input_children: Vec<Box<Rel>>,
551 input_field_count: usize,
552 ) -> Result<(Self, usize), MessageParseError> {
553 assert_eq!(pair.as_rule(), Self::rule());
554 let input = expect_one_child(Self::message(), &pair, input_children)?;
555 let mut iter = RuleIter::from(pair.into_inner());
556 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
557 let references_pair = iter.pop(Rule::reference_list);
559 iter.done();
560
561 let (emit, output_count) = parse_emit(references_pair, input_field_count);
562 let common = RelCommon {
563 emit_kind: Some(emit),
564 ..Default::default()
565 };
566
567 Ok((
568 FilterRel {
569 input: Some(input),
570 condition: Some(Box::new(condition)),
571 common: Some(common),
572 advanced_extension: None,
573 },
574 output_count,
575 ))
576 }
577}
578
579impl RelationParsePair for ProjectRel {
580 fn rule() -> Rule {
581 Rule::project_relation
582 }
583
584 fn message() -> &'static str {
585 "ProjectRel"
586 }
587
588 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
589 self.advanced_extension = adv_ext;
590 Rel {
591 rel_type: Some(RelType::Project(Box::new(self))),
592 }
593 }
594
595 fn parse_pair_with_context(
596 extensions: &SimpleExtensions,
597 pair: Pair<Rule>,
598 input_children: Vec<Box<Rel>>,
599 input_field_count: usize,
600 ) -> Result<(Self, usize), MessageParseError> {
601 assert_eq!(pair.as_rule(), Self::rule());
602 let input = expect_one_child(Self::message(), &pair, input_children)?;
603
604 let arguments_pair = unwrap_single_pair(pair);
605
606 let mut expressions = Vec::new();
607 let mut output_mapping = Vec::new();
608
609 for arg in arguments_pair.into_inner() {
610 let inner_arg = unwrap_single_pair(arg);
611 match inner_arg.as_rule() {
612 Rule::reference => {
613 let field_index = FieldIndex::parse_pair(inner_arg);
614 output_mapping.push(field_index.0);
615 }
616 Rule::expression => {
617 let expr = Expression::parse_pair(extensions, inner_arg)?;
618 expressions.push(expr);
619 output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
621 }
622 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
623 }
624 }
625
626 let output_count = output_mapping.len();
627 let direct_count = input_field_count + expressions.len();
628 let emit = make_emit(output_mapping, direct_count);
629 let common = RelCommon {
630 emit_kind: Some(emit),
631 ..Default::default()
632 };
633
634 Ok((
635 ProjectRel {
636 input: Some(input),
637 expressions,
638 common: Some(common),
639 advanced_extension: None,
640 },
641 output_count,
642 ))
643 }
644}
645
646impl RelationParsePair for AggregateRel {
647 fn rule() -> Rule {
648 Rule::aggregate_relation
649 }
650
651 fn message() -> &'static str {
652 "AggregateRel"
653 }
654
655 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
656 self.advanced_extension = adv_ext;
657 Rel {
658 rel_type: Some(RelType::Aggregate(Box::new(self))),
659 }
660 }
661
662 fn parse_pair_with_context(
663 extensions: &SimpleExtensions,
664 pair: Pair<Rule>,
665 input_children: Vec<Box<Rel>>,
666 _input_field_count: usize,
669 ) -> Result<(Self, usize), MessageParseError> {
670 assert_eq!(pair.as_rule(), Self::rule());
671 let input = expect_one_child(Self::message(), &pair, input_children)?;
672 let mut iter = RuleIter::from(pair.into_inner());
673 let group_by_pair = iter.pop(Rule::aggregate_group_by);
674 let output_pair = iter.pop(Rule::aggregate_output);
675 iter.done();
676
677 let inner = group_by_pair
678 .into_inner()
679 .next()
680 .expect("aggregate_group_by must have one inner item");
681
682 let grouping_sets = parse_grouping_sets(extensions, inner)?;
683 let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
684
685 let (measures, output_mapping) =
686 parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
687
688 let output_count = output_mapping.len();
689 let direct_count = grouping_expressions.len() + measures.len();
690 let emit = make_emit(output_mapping, direct_count);
691 let common = RelCommon {
692 emit_kind: Some(emit),
693 ..Default::default()
694 };
695
696 Ok((
697 AggregateRel {
698 input: Some(input),
699 grouping_expressions,
700 groupings,
701 measures,
702 common: Some(common),
703 advanced_extension: None,
704 },
705 output_count,
706 ))
707 }
708}
709
710fn parse_aggregate_measures(
715 extensions: &SimpleExtensions,
716 output_pair: Pair<'_, Rule>,
717 grouping_expressions: &[Expression],
718) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
719 assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
720 let mut measures = Vec::new();
721 let mut output_mapping = Vec::new();
722
723 for aggregate_output_item in output_pair.into_inner() {
724 let inner_item = unwrap_single_pair(aggregate_output_item);
725 match inner_item.as_rule() {
726 Rule::reference => {
727 let field_index = FieldIndex::parse_pair(inner_item);
728 output_mapping.push(field_index.0);
729 }
730 Rule::aggregate_measure => {
731 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
732 output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
733 measures.push(measure);
734 }
735 _ => panic!(
736 "Unexpected inner output item rule: {:?}",
737 inner_item.as_rule()
738 ),
739 }
740 }
741
742 Ok((measures, output_mapping))
743}
744
745fn parse_grouping_sets(
754 extensions: &SimpleExtensions,
755 inner: Pair<'_, Rule>,
756) -> Result<Vec<Vec<Expression>>, MessageParseError> {
757 assert!(
758 matches!(
759 inner.as_rule(),
760 Rule::expression_list | Rule::grouping_set_list
761 ),
762 "Expected expression_list or grouping_set_list, got {:?}",
763 inner.as_rule()
764 );
765 match inner.as_rule() {
766 Rule::expression_list => Ok(vec![parse_expression_list(extensions, inner)?]),
767 Rule::grouping_set_list => inner
768 .into_inner()
769 .map(|pair| parse_grouping_set(extensions, pair))
770 .collect(),
771 _ => unreachable!(
772 "Unexpected rule in aggregate_group_by: {:?}",
773 inner.as_rule()
774 ),
775 }
776}
777
778fn parse_grouping_set(
782 extensions: &SimpleExtensions,
783 pair: Pair<'_, Rule>,
784) -> Result<Vec<Expression>, MessageParseError> {
785 assert_eq!(pair.as_rule(), Rule::grouping_set);
786 let inner = pair
787 .into_inner()
788 .next()
789 .expect("grouping_set must have one inner item");
790 match inner.as_rule() {
791 Rule::empty => Ok(vec![]),
792 Rule::expression_list => parse_expression_list(extensions, inner),
793 _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
794 }
795}
796
797pub(crate) fn parse_expression_list(
799 extensions: &SimpleExtensions,
800 pair: Pair<'_, Rule>,
801) -> Result<Vec<Expression>, MessageParseError> {
802 pair.into_inner()
803 .map(|expr_pair| Expression::parse_pair(extensions, expr_pair))
804 .collect()
805}
806
807fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
811 let mut expressions: Vec<Expression> = Vec::new();
812 let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
813
814 let groupings = expression_sets
815 .iter()
816 .map(|set| {
817 let expression_references = set
818 .iter()
819 .map(|exp| {
820 let key = exp.encode_to_vec();
824 let next_idx = expressions.len() as u32;
825 *seen.entry(key).or_insert_with(|| {
826 expressions.push(exp.clone());
827 next_idx
828 })
829 })
830 .collect();
831 Grouping {
832 expression_references,
833 #[allow(deprecated)]
834 grouping_expressions: vec![],
835 }
836 })
837 .collect();
838
839 (groupings, expressions)
840}
841
842impl ScopedParsePair for SortField {
843 fn rule() -> Rule {
844 Rule::sort_field
845 }
846
847 fn message() -> &'static str {
848 "SortField"
849 }
850
851 fn parse_pair(
852 _extensions: &SimpleExtensions,
853 pair: Pair<Rule>,
854 ) -> Result<Self, MessageParseError> {
855 assert_eq!(pair.as_rule(), Self::rule());
856 let mut iter = RuleIter::from(pair.into_inner());
857 let reference_pair = iter.pop(Rule::reference);
858 let field_index = FieldIndex::parse_pair(reference_pair);
859 let direction_pair = iter.pop(Rule::sort_direction);
860 let direction = match direction_pair.as_str().trim_start_matches('&') {
864 "AscNullsFirst" => SortDirection::AscNullsFirst,
865 "AscNullsLast" => SortDirection::AscNullsLast,
866 "DescNullsFirst" => SortDirection::DescNullsFirst,
867 "DescNullsLast" => SortDirection::DescNullsLast,
868 other => {
869 return Err(MessageParseError::invalid(
870 "SortDirection",
871 direction_pair.as_span(),
872 format!("Unknown sort direction: {other}"),
873 ));
874 }
875 };
876 iter.done();
877 Ok(SortField {
878 expr: Some(Expression {
879 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
880 field_index.to_field_reference(),
881 ))),
882 }),
883 sort_kind: Some(SortKind::Direction(direction as i32)),
885 })
886 }
887}
888
889impl RelationParsePair for SortRel {
890 fn rule() -> Rule {
891 Rule::sort_relation
892 }
893
894 fn message() -> &'static str {
895 "SortRel"
896 }
897
898 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
899 self.advanced_extension = adv_ext;
900 Rel {
901 rel_type: Some(RelType::Sort(Box::new(self))),
902 }
903 }
904
905 fn parse_pair_with_context(
906 extensions: &SimpleExtensions,
907 pair: Pair<Rule>,
908 input_children: Vec<Box<Rel>>,
909 input_field_count: usize,
910 ) -> Result<(Self, usize), MessageParseError> {
911 assert_eq!(pair.as_rule(), Self::rule());
912 let input = expect_one_child(Self::message(), &pair, input_children)?;
913 let mut iter = RuleIter::from(pair.into_inner());
914 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
915 let reference_list_pair = iter.pop(Rule::reference_list);
916 let mut sorts = Vec::new();
917 for sort_field_pair in sort_field_list_pair.into_inner() {
918 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
919 sorts.push(sort_field);
920 }
921 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
922 let common = RelCommon {
923 emit_kind: Some(emit),
924 ..Default::default()
925 };
926 iter.done();
927 Ok((
928 SortRel {
929 input: Some(input),
930 sorts,
931 common: Some(common),
932 advanced_extension: None,
933 },
934 output_count,
935 ))
936 }
937}
938
939impl ScopedParsePair for CountMode {
940 fn rule() -> Rule {
941 Rule::fetch_value
942 }
943 fn message() -> &'static str {
944 "CountMode"
945 }
946 fn parse_pair(
947 extensions: &SimpleExtensions,
948 pair: Pair<Rule>,
949 ) -> Result<Self, MessageParseError> {
950 assert_eq!(pair.as_rule(), Self::rule());
951 let mut arg_inner = RuleIter::from(pair.into_inner());
952 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
953 int_pair
954 } else {
955 arg_inner.pop(Rule::expression)
956 };
957 match value_pair.as_rule() {
958 Rule::integer => {
959 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
960 MessageParseError::invalid(
961 Self::message(),
962 value_pair.as_span(),
963 format!("Invalid integer: {e}"),
964 )
965 })?;
966 if value < 0 {
967 return Err(MessageParseError::invalid(
968 Self::message(),
969 value_pair.as_span(),
970 format!("Fetch limit must be non-negative, got: {value}"),
971 ));
972 }
973 Ok(CountMode::CountExpr(i64_literal_expr(value)))
974 }
975 Rule::expression => {
976 let expr = Expression::parse_pair(extensions, value_pair)?;
977 Ok(CountMode::CountExpr(Box::new(expr)))
978 }
979 _ => Err(MessageParseError::invalid(
980 Self::message(),
981 value_pair.as_span(),
982 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
983 )),
984 }
985 }
986}
987
988fn i64_literal_expr(value: i64) -> Box<Expression> {
989 Box::new(Expression {
990 rex_type: Some(RexType::Literal(Literal {
991 nullable: false,
992 type_variation_reference: 0,
993 literal_type: Some(LiteralType::I64(value)),
994 })),
995 })
996}
997
998impl ScopedParsePair for OffsetMode {
999 fn rule() -> Rule {
1000 Rule::fetch_value
1001 }
1002 fn message() -> &'static str {
1003 "OffsetMode"
1004 }
1005 fn parse_pair(
1006 extensions: &SimpleExtensions,
1007 pair: Pair<Rule>,
1008 ) -> Result<Self, MessageParseError> {
1009 assert_eq!(pair.as_rule(), Self::rule());
1010 let mut arg_inner = RuleIter::from(pair.into_inner());
1011 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
1012 int_pair
1013 } else {
1014 arg_inner.pop(Rule::expression)
1015 };
1016 match value_pair.as_rule() {
1017 Rule::integer => {
1018 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
1019 MessageParseError::invalid(
1020 Self::message(),
1021 value_pair.as_span(),
1022 format!("Invalid integer: {e}"),
1023 )
1024 })?;
1025 if value < 0 {
1026 return Err(MessageParseError::invalid(
1027 Self::message(),
1028 value_pair.as_span(),
1029 format!("Fetch offset must be non-negative, got: {value}"),
1030 ));
1031 }
1032 Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
1033 }
1034 Rule::expression => {
1035 let expr = Expression::parse_pair(extensions, value_pair)?;
1036 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
1037 }
1038 _ => Err(MessageParseError::invalid(
1039 Self::message(),
1040 value_pair.as_span(),
1041 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
1042 )),
1043 }
1044 }
1045}
1046
1047impl RelationParsePair for FetchRel {
1048 fn rule() -> Rule {
1049 Rule::fetch_relation
1050 }
1051
1052 fn message() -> &'static str {
1053 "FetchRel"
1054 }
1055
1056 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1057 self.advanced_extension = adv_ext;
1058 Rel {
1059 rel_type: Some(RelType::Fetch(Box::new(self))),
1060 }
1061 }
1062
1063 fn parse_pair_with_context(
1064 extensions: &SimpleExtensions,
1065 pair: Pair<Rule>,
1066 input_children: Vec<Box<Rel>>,
1067 input_field_count: usize,
1068 ) -> Result<(Self, usize), MessageParseError> {
1069 assert_eq!(pair.as_rule(), Self::rule());
1070 let input = expect_one_child(Self::message(), &pair, input_children)?;
1071 let mut iter = RuleIter::from(pair.into_inner());
1072
1073 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
1077 None => {
1078 iter.pop(Rule::empty);
1079 (None, None)
1080 }
1081 Some(fetch_args_pair) => {
1082 let extractor =
1083 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
1084 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
1085 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
1086 extractor.done()?;
1087 (limit_pair, offset_pair)
1088 }
1089 };
1090
1091 let reference_list_pair = iter.pop(Rule::reference_list);
1092 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1093 let common = RelCommon {
1094 emit_kind: Some(emit),
1095 ..Default::default()
1096 };
1097 iter.done();
1098
1099 let count_mode = limit_pair
1100 .map(|pair| CountMode::parse_pair(extensions, pair))
1101 .transpose()?;
1102 let offset_mode = offset_pair
1103 .map(|pair| OffsetMode::parse_pair(extensions, pair))
1104 .transpose()?;
1105 Ok((
1106 FetchRel {
1107 input: Some(input),
1108 common: Some(common),
1109 advanced_extension: None,
1110 offset_mode,
1111 count_mode,
1112 },
1113 output_count,
1114 ))
1115 }
1116}
1117
1118impl ParsePair for join_rel::JoinType {
1119 fn rule() -> Rule {
1120 Rule::join_type
1121 }
1122
1123 fn message() -> &'static str {
1124 "JoinType"
1125 }
1126
1127 fn parse_pair(pair: Pair<Rule>) -> Self {
1128 assert_eq!(pair.as_rule(), Self::rule());
1129 let join_type_str = pair.as_str().trim_start_matches('&');
1130 match join_type_str {
1131 "Inner" => join_rel::JoinType::Inner,
1132 "Left" => join_rel::JoinType::Left,
1133 "Right" => join_rel::JoinType::Right,
1134 "Outer" => join_rel::JoinType::Outer,
1135 "LeftSemi" => join_rel::JoinType::LeftSemi,
1136 "RightSemi" => join_rel::JoinType::RightSemi,
1137 "LeftAnti" => join_rel::JoinType::LeftAnti,
1138 "RightAnti" => join_rel::JoinType::RightAnti,
1139 "LeftSingle" => join_rel::JoinType::LeftSingle,
1140 "RightSingle" => join_rel::JoinType::RightSingle,
1141 "LeftMark" => join_rel::JoinType::LeftMark,
1142 "RightMark" => join_rel::JoinType::RightMark,
1143 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
1144 }
1145 }
1146}
1147
1148impl RelationParsePair for JoinRel {
1149 fn rule() -> Rule {
1150 Rule::join_relation
1151 }
1152
1153 fn message() -> &'static str {
1154 "JoinRel"
1155 }
1156
1157 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1158 self.advanced_extension = adv_ext;
1159 Rel {
1160 rel_type: Some(RelType::Join(Box::new(self))),
1161 }
1162 }
1163
1164 fn parse_pair_with_context(
1165 extensions: &SimpleExtensions,
1166 pair: Pair<Rule>,
1167 input_children: Vec<Box<Rel>>,
1168 input_field_count: usize,
1169 ) -> Result<(Self, usize), MessageParseError> {
1170 assert_eq!(pair.as_rule(), Self::rule());
1171
1172 if input_children.len() != 2 {
1173 return Err(MessageParseError::invalid(
1174 Self::message(),
1175 pair.as_span(),
1176 format!(
1177 "JoinRel should have exactly 2 input children, got {}",
1178 input_children.len()
1179 ),
1180 ));
1181 }
1182
1183 let mut children_iter = input_children.into_iter();
1184 let left = children_iter.next().unwrap();
1185 let right = children_iter.next().unwrap();
1186
1187 let mut iter = RuleIter::from(pair.into_inner());
1188 let join_type = iter.parse_next::<join_rel::JoinType>();
1189 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
1190 let reference_list_pair = iter.pop(Rule::reference_list);
1191 iter.done();
1192
1193 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1197 let common = RelCommon {
1198 emit_kind: Some(emit),
1199 ..Default::default()
1200 };
1201
1202 Ok((
1203 JoinRel {
1204 common: Some(common),
1205 left: Some(left),
1206 right: Some(right),
1207 expression: Some(Box::new(condition)),
1208 post_join_filter: None, r#type: join_type as i32,
1210 advanced_extension: None,
1211 },
1212 output_count,
1213 ))
1214 }
1215}
1216
1217#[cfg(test)]
1218mod tests {
1219 use pest::Parser;
1220
1221 use super::*;
1222 use crate::fixtures::TestContext;
1223 use crate::parser::{ExpressionParser, Rule};
1224
    // NOTE(review): this test has an empty body, so it passes without checking anything.
    // It looks like a leftover placeholder — confirm whether it should be removed or
    // filled in with real assertions.
    #[test]
    fn test_parse_relation() {
    }
1229
1230 #[test]
1231 fn test_parse_read_relation() {
1232 let extensions = SimpleExtensions::default();
1233 let read = ReadRel::parse_pair_with_context(
1234 &extensions,
1235 parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
1236 vec![],
1237 0,
1238 )
1239 .unwrap()
1240 .0;
1241 let names = match &read.read_type {
1242 Some(read_rel::ReadType::NamedTable(table)) => &table.names,
1243 _ => panic!("Expected NamedTable"),
1244 };
1245 assert_eq!(names, &["ab", "cd", "ef"]);
1246 let columns = &read
1247 .base_schema
1248 .as_ref()
1249 .unwrap()
1250 .r#struct
1251 .as_ref()
1252 .unwrap()
1253 .types;
1254 assert_eq!(columns.len(), 2);
1255 }
1256
1257 fn example_read_relation() -> ReadRel {
1259 let extensions = SimpleExtensions::default();
1260 ReadRel::parse_pair_with_context(
1261 &extensions,
1262 parse_exact(
1263 Rule::read_relation,
1264 "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
1265 ),
1266 vec![],
1267 0,
1268 )
1269 .unwrap()
1270 .0
1271 }
1272
1273 #[test]
1274 fn test_parse_filter_relation() {
1275 let extensions = SimpleExtensions::default();
1276 let filter = FilterRel::parse_pair_with_context(
1277 &extensions,
1278 parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
1279 vec![Box::new(example_read_relation().into_rel(None))],
1280 3,
1281 )
1282 .unwrap()
1283 .0;
1284 let emit_kind = filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1286 assert!(
1287 matches!(emit_kind, EmitKind::Direct(_)),
1288 "Expected Direct for identity emit, got {emit_kind:?}"
1289 );
1290 }
1291
1292 #[test]
1293 fn test_parse_project_relation() {
1294 let extensions = SimpleExtensions::default();
1295 let project = ProjectRel::parse_pair_with_context(
1296 &extensions,
1297 parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
1298 vec![Box::new(example_read_relation().into_rel(None))],
1299 3,
1300 )
1301 .unwrap()
1302 .0;
1303
1304 assert_eq!(project.expressions.len(), 1);
1306
1307 let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1308 let emit = match emit_kind {
1309 EmitKind::Emit(emit) => &emit.output_mapping,
1310 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1311 };
1312 assert_eq!(emit, &[0, 1, 3]);
1314 }
1315
1316 #[test]
1317 fn test_parse_project_relation_complex() {
1318 let extensions = SimpleExtensions::default();
1319 let project = ProjectRel::parse_pair_with_context(
1320 &extensions,
1321 parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
1322 vec![Box::new(example_read_relation().into_rel(None))],
1323 5, )
1325 .unwrap()
1326 .0;
1327
1328 assert_eq!(project.expressions.len(), 2);
1330
1331 let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1332 let emit = match emit_kind {
1333 EmitKind::Emit(emit) => &emit.output_mapping,
1334 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1335 };
1336 assert_eq!(emit, &[5, 0, 6, 2, 1]);
1339 }
1340
1341 #[test]
1342 fn test_parse_aggregate_relation() {
1343 let extensions = TestContext::new()
1344 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1345 .with_function(1, 10, "sum")
1346 .with_function(1, 11, "count")
1347 .extensions;
1348
1349 let aggregate = AggregateRel::parse_pair_with_context(
1350 &extensions,
1351 parse_exact(
1352 Rule::aggregate_relation,
1353 "Aggregate[($0, $1), _ => sum($2), $0, count($2)]",
1354 ),
1355 vec![Box::new(example_read_relation().into_rel(None))],
1356 3,
1357 )
1358 .unwrap()
1359 .0;
1360
1361 assert_eq!(aggregate.grouping_expressions.len(), 2);
1363 assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
1364 assert_eq!(aggregate.groupings.len(), 2);
1365 assert_eq!(aggregate.measures.len(), 2);
1366
1367 let emit_kind = &aggregate
1368 .common
1369 .as_ref()
1370 .unwrap()
1371 .emit_kind
1372 .as_ref()
1373 .unwrap();
1374 let emit = match emit_kind {
1375 EmitKind::Emit(emit) => &emit.output_mapping,
1376 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1377 };
1378 assert_eq!(emit, &[2, 0, 3]);
1381 }
1382
1383 #[test]
1384 fn test_parse_aggregate_relation_maintain_column_order() {
1385 let extensions = TestContext::new()
1386 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1387 .with_function(1, 10, "sum")
1388 .with_function(1, 11, "count")
1389 .extensions;
1390
1391 let aggregate = AggregateRel::parse_pair_with_context(
1392 &extensions,
1393 parse_exact(
1394 Rule::aggregate_relation,
1395 "Aggregate[$0 => sum($1), $0, count($1)]",
1396 ),
1397 vec![Box::new(example_read_relation().into_rel(None))],
1398 3,
1399 )
1400 .unwrap()
1401 .0;
1402
1403 assert_eq!(aggregate.grouping_expressions.len(), 1);
1405 assert_eq!(aggregate.groupings.len(), 1);
1406 assert_eq!(aggregate.measures.len(), 2);
1407
1408 let emit_kind = &aggregate
1409 .common
1410 .as_ref()
1411 .unwrap()
1412 .emit_kind
1413 .as_ref()
1414 .unwrap();
1415 let emit = match emit_kind {
1416 EmitKind::Emit(emit) => &emit.output_mapping,
1417 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1418 };
1419 assert_eq!(emit, &[1, 0, 2]);
1421 }
1422
1423 #[test]
1424 fn test_parse_aggregate_relation_simple() {
1425 let extensions = TestContext::new()
1426 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1427 .with_function(1, 10, "sum")
1428 .extensions;
1429
1430 let aggregate = AggregateRel::parse_pair_with_context(
1431 &extensions,
1432 parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"),
1433 vec![Box::new(example_read_relation().into_rel(None))],
1434 3,
1435 )
1436 .unwrap()
1437 .0;
1438
1439 assert_eq!(aggregate.grouping_expressions.len(), 2);
1440 assert_eq!(aggregate.groupings.len(), 1);
1441 assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
1443 }
1444
1445 #[test]
1446 fn test_parse_aggregate_relation_global_aggregate() {
1447 let extensions = TestContext::new()
1448 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1449 .with_function(1, 10, "sum")
1450 .with_function(1, 11, "count")
1451 .extensions;
1452
1453 let aggregate = AggregateRel::parse_pair_with_context(
1454 &extensions,
1455 parse_exact(
1456 Rule::aggregate_relation,
1457 "Aggregate[_ => sum($0), count($1)]",
1458 ),
1459 vec![Box::new(example_read_relation().into_rel(None))],
1460 3,
1461 )
1462 .unwrap()
1463 .0;
1464
1465 assert_eq!(aggregate.grouping_expressions.len(), 0);
1467 assert_eq!(aggregate.groupings.len(), 1);
1468 assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
1469 assert_eq!(aggregate.measures.len(), 2);
1470
1471 let emit_kind = aggregate
1473 .common
1474 .as_ref()
1475 .unwrap()
1476 .emit_kind
1477 .as_ref()
1478 .unwrap();
1479 assert!(
1480 matches!(emit_kind, EmitKind::Direct(_)),
1481 "Expected Direct for identity emit, got {emit_kind:?}"
1482 );
1483 }
1484
1485 #[test]
1486 fn test_parse_aggregate_relation_grouping_sets() {
1487 let extensions = TestContext::new()
1488 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1489 .with_function(1, 11, "count")
1490 .extensions;
1491
1492 let read_rel = ReadRel::parse_pair_with_context(
1493 &extensions,
1494 parse_exact(
1495 Rule::read_relation,
1496 "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
1497 ),
1498 vec![],
1499 0,
1500 )
1501 .unwrap()
1502 .0;
1503
1504 let aggregate = AggregateRel::parse_pair_with_context(
1505 &extensions,
1506 parse_exact(
1507 Rule::aggregate_relation,
1508 "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]",
1509 ),
1510 vec![Box::new(read_rel.into_rel(None))],
1511 4,
1512 )
1513 .unwrap()
1514 .0;
1515
1516 assert_eq!(aggregate.grouping_expressions.len(), 3);
1517 assert_eq!(aggregate.groupings.len(), 4);
1518 assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
1520 assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
1522 assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
1524 assert!(aggregate.groupings[3].expression_references.is_empty());
1526 assert_eq!(aggregate.measures.len(), 1);
1527 }
1528
1529 #[test]
1530 fn test_fetch_relation_positive_values() {
1531 let extensions = SimpleExtensions::default();
1532
1533 let fetch_rel = FetchRel::parse_pair_with_context(
1535 &extensions,
1536 parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
1537 vec![Box::new(example_read_relation().into_rel(None))],
1538 3,
1539 )
1540 .unwrap()
1541 .0;
1542
1543 assert_eq!(
1545 fetch_rel.count_mode,
1546 Some(CountMode::CountExpr(i64_literal_expr(10)))
1547 );
1548 assert_eq!(
1549 fetch_rel.offset_mode,
1550 Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
1551 );
1552 }
1553
1554 #[test]
1555 fn test_fetch_relation_negative_limit_rejected() {
1556 let extensions = SimpleExtensions::default();
1557
1558 let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
1560 if let Ok(mut pairs) = parsed_result {
1561 let pair = pairs.next().unwrap();
1562 if pair.as_str() == "Fetch[limit=-5 => $0]" {
1563 let result = FetchRel::parse_pair_with_context(
1565 &extensions,
1566 pair,
1567 vec![Box::new(example_read_relation().into_rel(None))],
1568 3,
1569 );
1570 assert!(result.is_err());
1571 let error_msg = result.unwrap_err().to_string();
1572 assert!(error_msg.contains("Fetch limit must be non-negative"));
1573 } else {
1574 println!("Grammar prevents negative limit values at parse time");
1577 }
1578 } else {
1579 println!("Grammar prevents negative limit values at parse time");
1581 }
1582 }
1583
1584 #[test]
1585 fn test_fetch_relation_negative_offset_rejected() {
1586 let extensions = SimpleExtensions::default();
1587
1588 let parsed_result =
1590 ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
1591 if let Ok(mut pairs) = parsed_result {
1592 let pair = pairs.next().unwrap();
1593 if pair.as_str() == "Fetch[offset=-10 => $0]" {
1594 let result = FetchRel::parse_pair_with_context(
1596 &extensions,
1597 pair,
1598 vec![Box::new(example_read_relation().into_rel(None))],
1599 3,
1600 );
1601 assert!(result.is_err());
1602 let error_msg = result.unwrap_err().to_string();
1603 assert!(error_msg.contains("Fetch offset must be non-negative"));
1604 } else {
1605 println!("Grammar prevents negative offset values at parse time");
1607 }
1608 } else {
1609 println!("Grammar prevents negative offset values at parse time");
1611 }
1612 }
1613
1614 #[test]
1615 fn test_parse_join_relation() {
1616 let extensions = TestContext::new()
1617 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1618 .with_function(1, 10, "eq")
1619 .extensions;
1620
1621 let left_rel = example_read_relation().into_rel(None);
1622 let right_rel = example_read_relation().into_rel(None);
1623
1624 let join = JoinRel::parse_pair_with_context(
1625 &extensions,
1626 parse_exact(
1627 Rule::join_relation,
1628 "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
1629 ),
1630 vec![Box::new(left_rel), Box::new(right_rel)],
1631 6, )
1633 .unwrap()
1634 .0;
1635
1636 assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
1638
1639 assert!(join.left.is_some());
1641 assert!(join.right.is_some());
1642
1643 assert!(join.expression.is_some());
1645
1646 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1647 let emit = match emit_kind {
1648 EmitKind::Emit(emit) => &emit.output_mapping,
1649 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1650 };
1651 assert_eq!(emit, &[0, 1, 3, 4]);
1653 }
1654
1655 #[test]
1656 fn test_parse_join_relation_left_outer() {
1657 let extensions = TestContext::new()
1658 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1659 .with_function(1, 10, "eq")
1660 .extensions;
1661
1662 let left_rel = example_read_relation().into_rel(None);
1663 let right_rel = example_read_relation().into_rel(None);
1664
1665 let join = JoinRel::parse_pair_with_context(
1666 &extensions,
1667 parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
1668 vec![Box::new(left_rel), Box::new(right_rel)],
1669 6,
1670 )
1671 .unwrap()
1672 .0;
1673
1674 assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
1676
1677 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1678 let emit = match emit_kind {
1679 EmitKind::Emit(emit) => &emit.output_mapping,
1680 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1681 };
1682 assert_eq!(emit, &[0, 1, 2]);
1684 }
1685
1686 #[test]
1687 fn test_parse_join_relation_left_semi() {
1688 let extensions = TestContext::new()
1689 .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1690 .with_function(1, 10, "eq")
1691 .extensions;
1692
1693 let left_rel = example_read_relation().into_rel(None);
1694 let right_rel = example_read_relation().into_rel(None);
1695
1696 let join = JoinRel::parse_pair_with_context(
1697 &extensions,
1698 parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
1699 vec![Box::new(left_rel), Box::new(right_rel)],
1700 6,
1701 )
1702 .unwrap()
1703 .0;
1704
1705 assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
1707
1708 let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1709 let emit = match emit_kind {
1710 EmitKind::Emit(emit) => &emit.output_mapping,
1711 _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1712 };
1713 assert_eq!(emit, &[0, 1]);
1715 }
1716
1717 #[test]
1718 fn test_parse_join_relation_requires_two_children() {
1719 let extensions = SimpleExtensions::default();
1720
1721 let result = JoinRel::parse_pair_with_context(
1723 &extensions,
1724 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1725 vec![],
1726 0,
1727 );
1728 assert!(result.is_err());
1729
1730 let result = JoinRel::parse_pair_with_context(
1732 &extensions,
1733 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1734 vec![Box::new(example_read_relation().into_rel(None))],
1735 3,
1736 );
1737 assert!(result.is_err());
1738
1739 let result = JoinRel::parse_pair_with_context(
1741 &extensions,
1742 parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1743 vec![
1744 Box::new(example_read_relation().into_rel(None)),
1745 Box::new(example_read_relation().into_rel(None)),
1746 Box::new(example_read_relation().into_rel(None)),
1747 ],
1748 9,
1749 );
1750 assert!(result.is_err());
1751 }
1752
1753 fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
1754 let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1755 assert_eq!(pairs.as_str(), input);
1756 let pair = pairs.next().unwrap();
1757 assert_eq!(pairs.next(), None);
1758 pair
1759 }
1760}