1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType, nested};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::{CountMode, OffsetMode};
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::{Direct, Emit, EmitKind};
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14 AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
15 RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
16};
17
18use super::{MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair};
19use crate::extensions::any::Any;
20use crate::extensions::registry::ExtensionError;
21use crate::extensions::{AddendumKind, ExtensionArgs, ExtensionRegistry, SimpleExtensions};
22use crate::parser::errors::{ParseContext, ParseError};
23use crate::parser::expressions::{FieldIndex, Name};
24
25pub struct RelationParsingContext<'a> {
27 pub registry: &'a ExtensionRegistry,
28 pub line_no: i64,
29 pub line: &'a str,
30}
31
32impl<'a> RelationParsingContext<'a> {
33 pub fn resolve_extension_detail(
35 &self,
36 extension_name: &str,
37 extension_args: &ExtensionArgs,
38 ) -> Result<Option<Any>, ParseError> {
39 let detail = self
40 .registry
41 .parse_extension(extension_name, extension_args);
42
43 match detail {
44 Ok(any) => Ok(Some(any)),
45 Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
46 name: extension_name.to_string(),
47 context: ParseContext::new(self.line_no, self.line.to_string()),
48 }),
49 Err(err) => Err(ParseError::ExtensionDetail(
50 ParseContext::new(self.line_no, self.line.to_string()),
51 err,
52 )),
53 }
54 }
55
56 pub(crate) fn resolve_addendum_detail(
59 &self,
60 kind: AddendumKind,
61 name: &str,
62 args: &ExtensionArgs,
63 ) -> Result<Any, ParseError> {
64 let result = match kind {
65 AddendumKind::Enhancement => self.registry.parse_enhancement(name, args),
66 AddendumKind::Optimization => self.registry.parse_optimization(name, args),
67 AddendumKind::ExtensionTable => self.registry.parse_extension_table(name, args),
68 };
69 result.map_err(|err| match err {
70 ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
71 name: name.to_string(),
72 context: ParseContext::new(self.line_no, self.line.to_string()),
73 },
74 err => ParseError::ExtensionDetail(
75 ParseContext::new(self.line_no, self.line.to_string()),
76 err,
77 ),
78 })
79 }
80}
81
/// Parsing interface for relation messages that are built from a grammar pair
/// together with their already-parsed input relations.
pub trait RelationParsePair: Sized {
    /// The grammar rule this relation is parsed from; implementations assert
    /// that the pair they receive has this rule.
    fn rule() -> Rule;
    /// Name of the message, used when reporting parse errors.
    fn message() -> &'static str;

    /// Parses `pair` into the relation.
    ///
    /// `input_children` are the relation's input relations.
    /// `input_field_count` is the field count of those inputs
    /// (NOTE(review): presumably the combined output width of all children —
    /// confirm against the caller).
    ///
    /// On success returns the relation together with a `usize`; the
    /// implementations in this file return the number of fields the relation
    /// outputs.
    fn parse_pair_with_context(
        extensions: &SimpleExtensions,
        pair: Pair<Rule>,
        input_children: Vec<Rel>,
        input_field_count: usize,
    ) -> Result<(Self, usize), MessageParseError>;

    /// Wraps the relation in a [`Rel`]; the implementations in this file store
    /// `adv_ext` as the relation's advanced extension.
    fn into_rel(self, adv_ext: Option<AdvancedExtension>) -> Rel;
}
103
104pub struct TableName(Vec<String>);
105
106impl ParsePair for TableName {
107 fn rule() -> Rule {
108 Rule::table_name
109 }
110
111 fn message() -> &'static str {
112 "TableName"
113 }
114
115 fn parse_pair(pair: Pair<Rule>) -> Self {
116 assert_eq!(pair.as_rule(), Self::rule());
117 let pairs = pair.into_inner();
118 let mut names = Vec::with_capacity(pairs.len());
119 let mut iter = RuleIter::from(pairs);
120 while let Some(name) = iter.parse_if_next::<Name>() {
121 names.push(name.0);
122 }
123 iter.done();
124 Self(names)
125 }
126}
127
128#[derive(Debug, Clone)]
129pub struct Column {
130 pub name: String,
131 pub typ: Type,
132}
133
134impl ScopedParsePair for Column {
135 fn rule() -> Rule {
136 Rule::named_column
137 }
138
139 fn message() -> &'static str {
140 "Column"
141 }
142
143 fn parse_pair(
144 extensions: &SimpleExtensions,
145 pair: Pair<Rule>,
146 ) -> Result<Self, MessageParseError> {
147 assert_eq!(pair.as_rule(), Self::rule());
148 let mut iter = RuleIter::from(pair.into_inner());
149 let name = iter.parse_next::<Name>().0;
150 let typ = iter.parse_next_scoped(extensions)?;
151 iter.done();
152 Ok(Self { name, typ })
153 }
154}
155
156pub(crate) struct NamedColumnList(pub(crate) Vec<Column>);
157
158impl ScopedParsePair for NamedColumnList {
159 fn rule() -> Rule {
160 Rule::named_column_list
161 }
162
163 fn message() -> &'static str {
164 "NamedColumnList"
165 }
166
167 fn parse_pair(
168 extensions: &SimpleExtensions,
169 pair: Pair<Rule>,
170 ) -> Result<Self, MessageParseError> {
171 assert_eq!(pair.as_rule(), Self::rule());
172 let mut columns = Vec::new();
173 for col in pair.into_inner() {
174 columns.push(Column::parse_pair(extensions, col)?);
175 }
176 Ok(Self(columns))
177 }
178}
179
180pub(crate) fn expect_one_child(
185 message: &'static str,
186 pair: &Pair<Rule>,
187 mut input_children: Vec<Rel>,
188) -> Result<Box<Rel>, MessageParseError> {
189 match input_children.len() {
190 0 => Err(MessageParseError::invalid(
191 message,
192 pair.as_span(),
193 format!("{message} missing child"),
194 )),
195 1 => Ok(Box::new(input_children.pop().unwrap())),
196 n => Err(MessageParseError::invalid(
197 message,
198 pair.as_span(),
199 format!("{message} should have 1 input child, got {n}"),
200 )),
201 }
202}
203
204fn parse_output_mapping(pair: Pair<Rule>) -> Vec<i32> {
207 assert_eq!(pair.as_rule(), Rule::reference_list);
208 pair.into_inner()
209 .map(|p| FieldIndex::parse_pair(p).0)
210 .collect()
211}
212
213fn make_emit(output_mapping: Vec<i32>, direct_output_count: usize) -> EmitKind {
216 let is_identity = output_mapping.len() == direct_output_count
217 && output_mapping
218 .iter()
219 .enumerate()
220 .all(|(i, &v)| v == i as i32);
221 if is_identity {
222 EmitKind::Direct(Direct {})
223 } else {
224 EmitKind::Emit(Emit { output_mapping })
225 }
226}
227
228fn parse_emit(reference_list: Pair<Rule>, direct_output_count: usize) -> (EmitKind, usize) {
230 let output_mapping = parse_output_mapping(reference_list);
231 let output_count = output_mapping.len();
232 let emit = make_emit(output_mapping, direct_output_count);
233 (emit, output_count)
234}
235
236pub struct ParsedNamedArgs<'a> {
242 map: HashMap<&'a str, Pair<'a, Rule>>,
243}
244
245impl<'a> ParsedNamedArgs<'a> {
246 pub fn new(
247 pairs: pest::iterators::Pairs<'a, Rule>,
248 rule: Rule,
249 ) -> Result<Self, MessageParseError> {
250 let mut map = HashMap::new();
251 for pair in pairs {
252 assert_eq!(pair.as_rule(), rule);
253 let mut inner = pair.clone().into_inner();
254 let name_pair = inner.next().unwrap();
255 let value_pair = inner.next().unwrap();
256 assert_eq!(inner.next(), None);
257 let name = name_pair.as_str();
258 if map.contains_key(name) {
259 return Err(MessageParseError::invalid(
260 "NamedArg",
261 name_pair.as_span(),
262 format!("Duplicate argument: {name}"),
263 ));
264 }
265 map.insert(name, value_pair);
266 }
267 Ok(Self { map })
268 }
269
270 pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
274 let pair = self.map.remove(name).inspect(|pair| {
275 assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
276 });
277 (self, pair)
278 }
279
280 pub fn done(self) -> Result<(), MessageParseError> {
282 if let Some((name, pair)) = self.map.iter().next() {
283 return Err(MessageParseError::invalid(
284 "NamedArgExtractor",
285 pair.as_span(),
287 format!("Unknown argument: {name}"),
288 ));
289 }
290 Ok(())
291 }
292}
293
294impl RelationParsePair for ReadRel {
295 fn rule() -> Rule {
296 Rule::read_relation
297 }
298
299 fn message() -> &'static str {
300 "ReadRel"
301 }
302
303 fn parse_pair_with_context(
304 extensions: &SimpleExtensions,
305 pair: Pair<Rule>,
306 input_children: Vec<Rel>,
307 input_field_count: usize,
308 ) -> Result<(Self, usize), MessageParseError> {
309 assert_eq!(pair.as_rule(), Self::rule());
310 if !input_children.is_empty() {
312 return Err(MessageParseError::invalid(
313 Self::message(),
314 pair.as_span(),
315 "ReadRel should have no input children",
316 ));
317 }
318 if input_field_count != 0 {
319 return Err(MessageParseError::invalid(
320 "ReadRel",
321 pair.as_span(),
322 "ReadRel should have 0 input fields",
323 ));
324 }
325
326 let mut iter = RuleIter::from(pair.into_inner());
327 let table = iter.parse_next::<TableName>().0;
328 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
329 iter.done();
330
331 let output_count = columns.len();
332 Ok((
333 ReadRel {
334 base_schema: Some(build_named_struct(columns)),
335 read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
336 names: table,
337 advanced_extension: None,
338 })),
339 ..Default::default()
340 },
341 output_count,
342 ))
343 }
344
345 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
346 self.advanced_extension = adv_ext;
347 Rel {
348 rel_type: Some(RelType::Read(Box::new(self))),
349 }
350 }
351}
352
353pub(crate) struct VirtualReadRel(ReadRel);
357
358impl RelationParsePair for VirtualReadRel {
359 fn rule() -> Rule {
360 Rule::virtual_read_relation
361 }
362
363 fn message() -> &'static str {
364 "VirtualReadRel"
365 }
366
367 fn parse_pair_with_context(
368 extensions: &SimpleExtensions,
369 pair: Pair<Rule>,
370 input_children: Vec<Rel>,
371 _input_field_count: usize,
372 ) -> Result<(Self, usize), MessageParseError> {
373 assert_eq!(pair.as_rule(), Self::rule());
374 if !input_children.is_empty() {
375 return Err(MessageParseError::invalid(
376 Self::message(),
377 pair.as_span(),
378 "Read:Virtual should have no input children",
379 ));
380 }
381
382 let mut iter = RuleIter::from(pair.into_inner());
383 let args_pair = iter.pop(Rule::virtual_read_args);
384 let columns_pair = iter.pop(Rule::named_column_list);
385 iter.done();
386
387 let rows = parse_virtual_read_args(extensions, args_pair)?;
388 let columns = NamedColumnList::parse_pair(extensions, columns_pair)?.0;
389
390 let output_count = columns.len();
395 Ok((
396 VirtualReadRel(ReadRel {
397 base_schema: Some(build_named_struct(columns)),
398 read_type: Some(read_rel::ReadType::VirtualTable(read_rel::VirtualTable {
399 expressions: rows,
400 ..Default::default()
401 })),
402 ..Default::default()
403 }),
404 output_count,
405 ))
406 }
407
408 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
409 self.0.advanced_extension = adv_ext;
410 Rel {
411 rel_type: Some(RelType::Read(Box::new(self.0))),
412 }
413 }
414}
415
416pub(crate) struct ExtensionReadRel(ReadRel);
420
421impl ExtensionReadRel {
422 pub(crate) fn parse_pair_with_detail(
423 extensions: &SimpleExtensions,
424 pair: Pair<Rule>,
425 input_children: Vec<Rel>,
426 input_field_count: usize,
427 detail: Any,
428 advanced_extension: Option<AdvancedExtension>,
429 ) -> Result<(Rel, usize), MessageParseError> {
430 assert_eq!(pair.as_rule(), Rule::extension_read_relation);
431 if !input_children.is_empty() {
432 return Err(MessageParseError::invalid(
433 "ExtensionReadRel",
434 pair.as_span(),
435 "Read:Extension should have no input children",
436 ));
437 }
438 if input_field_count != 0 {
439 return Err(MessageParseError::invalid(
440 "ExtensionReadRel",
441 pair.as_span(),
442 "Read:Extension should have 0 input fields",
443 ));
444 }
445
446 let mut iter = RuleIter::from(pair.into_inner());
447 let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
448 iter.done();
449
450 let output_count = columns.len();
451 let rel = ExtensionReadRel(ReadRel {
452 base_schema: Some(build_named_struct(columns)),
453 read_type: Some(read_rel::ReadType::ExtensionTable(
454 read_rel::ExtensionTable {
455 detail: Some(detail.into()),
456 },
457 )),
458 advanced_extension,
459 ..Default::default()
460 })
461 .into_rel();
462
463 Ok((rel, output_count))
464 }
465
466 fn into_rel(self) -> Rel {
467 Rel {
468 rel_type: Some(RelType::Read(Box::new(self.0))),
469 }
470 }
471}
472
473pub(crate) fn build_named_struct(columns: Vec<Column>) -> NamedStruct {
475 let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
476 NamedStruct {
477 names,
478 r#struct: Some(r#type::Struct {
479 types,
480 type_variation_reference: 0,
481 nullability: r#type::Nullability::Required as i32,
482 }),
483 }
484}
485
486fn parse_virtual_read_args(
488 extensions: &SimpleExtensions,
489 pair: Pair<Rule>,
490) -> Result<Vec<nested::Struct>, MessageParseError> {
491 assert_eq!(pair.as_rule(), Rule::virtual_read_args);
492 let inner = unwrap_single_pair(pair);
493 match inner.as_rule() {
494 Rule::empty => Ok(vec![]),
495 Rule::virtual_row_list => inner
496 .into_inner()
497 .map(|row| parse_virtual_row(extensions, row))
498 .collect(),
499 _ => unreachable!(
500 "Unexpected rule in virtual_read_args: {:?}",
501 inner.as_rule()
502 ),
503 }
504}
505
506fn parse_virtual_row(
508 extensions: &SimpleExtensions,
509 pair: Pair<Rule>,
510) -> Result<nested::Struct, MessageParseError> {
511 assert_eq!(pair.as_rule(), Rule::virtual_row);
512 let fields = match pair.into_inner().next() {
513 Some(expression_list) => {
514 assert_eq!(expression_list.as_rule(), Rule::expression_list);
515 parse_expression_list(extensions, expression_list)?
516 }
517 None => vec![],
519 };
520 Ok(nested::Struct { fields })
521}
522
523impl RelationParsePair for FilterRel {
524 fn rule() -> Rule {
525 Rule::filter_relation
526 }
527
528 fn message() -> &'static str {
529 "FilterRel"
530 }
531
532 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
533 self.advanced_extension = adv_ext;
534 Rel {
535 rel_type: Some(RelType::Filter(Box::new(self))),
536 }
537 }
538
539 fn parse_pair_with_context(
540 extensions: &SimpleExtensions,
541 pair: Pair<Rule>,
542 input_children: Vec<Rel>,
543 input_field_count: usize,
544 ) -> Result<(Self, usize), MessageParseError> {
545 assert_eq!(pair.as_rule(), Self::rule());
546 let input = expect_one_child(Self::message(), &pair, input_children)?;
547 let mut iter = RuleIter::from(pair.into_inner());
548 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
549 let references_pair = iter.pop(Rule::reference_list);
551 iter.done();
552
553 let (emit, output_count) = parse_emit(references_pair, input_field_count);
554 let common = RelCommon {
555 emit_kind: Some(emit),
556 ..Default::default()
557 };
558
559 Ok((
560 FilterRel {
561 input: Some(input),
562 condition: Some(Box::new(condition)),
563 common: Some(common),
564 advanced_extension: None,
565 },
566 output_count,
567 ))
568 }
569}
570
571impl RelationParsePair for ProjectRel {
572 fn rule() -> Rule {
573 Rule::project_relation
574 }
575
576 fn message() -> &'static str {
577 "ProjectRel"
578 }
579
580 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
581 self.advanced_extension = adv_ext;
582 Rel {
583 rel_type: Some(RelType::Project(Box::new(self))),
584 }
585 }
586
587 fn parse_pair_with_context(
588 extensions: &SimpleExtensions,
589 pair: Pair<Rule>,
590 input_children: Vec<Rel>,
591 input_field_count: usize,
592 ) -> Result<(Self, usize), MessageParseError> {
593 assert_eq!(pair.as_rule(), Self::rule());
594 let input = expect_one_child(Self::message(), &pair, input_children)?;
595
596 let arguments_pair = unwrap_single_pair(pair);
597
598 let mut expressions = Vec::new();
599 let mut output_mapping = Vec::new();
600
601 for arg in arguments_pair.into_inner() {
602 let inner_arg = unwrap_single_pair(arg);
603 match inner_arg.as_rule() {
604 Rule::reference => {
605 let field_index = FieldIndex::parse_pair(inner_arg);
606 output_mapping.push(field_index.0);
607 }
608 Rule::expression => {
609 let expr = Expression::parse_pair(extensions, inner_arg)?;
610 expressions.push(expr);
611 output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
613 }
614 _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
615 }
616 }
617
618 let output_count = output_mapping.len();
619 let direct_count = input_field_count + expressions.len();
620 let emit = make_emit(output_mapping, direct_count);
621 let common = RelCommon {
622 emit_kind: Some(emit),
623 ..Default::default()
624 };
625
626 Ok((
627 ProjectRel {
628 input: Some(input),
629 expressions,
630 common: Some(common),
631 advanced_extension: None,
632 },
633 output_count,
634 ))
635 }
636}
637
638impl RelationParsePair for AggregateRel {
639 fn rule() -> Rule {
640 Rule::aggregate_relation
641 }
642
643 fn message() -> &'static str {
644 "AggregateRel"
645 }
646
647 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
648 self.advanced_extension = adv_ext;
649 Rel {
650 rel_type: Some(RelType::Aggregate(Box::new(self))),
651 }
652 }
653
654 fn parse_pair_with_context(
655 extensions: &SimpleExtensions,
656 pair: Pair<Rule>,
657 input_children: Vec<Rel>,
658 _input_field_count: usize,
661 ) -> Result<(Self, usize), MessageParseError> {
662 assert_eq!(pair.as_rule(), Self::rule());
663 let input = expect_one_child(Self::message(), &pair, input_children)?;
664 let mut iter = RuleIter::from(pair.into_inner());
665 let group_by_pair = iter.pop(Rule::aggregate_group_by);
666 let output_pair = iter.pop(Rule::aggregate_output);
667 iter.done();
668
669 let inner = group_by_pair
670 .into_inner()
671 .next()
672 .expect("aggregate_group_by must have one inner item");
673
674 let grouping_sets = parse_grouping_sets(extensions, inner)?;
675 let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
676
677 let (measures, output_mapping) =
678 parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
679
680 let output_count = output_mapping.len();
681 let direct_count = grouping_expressions.len() + measures.len();
682 let emit = make_emit(output_mapping, direct_count);
683 let common = RelCommon {
684 emit_kind: Some(emit),
685 ..Default::default()
686 };
687
688 Ok((
689 AggregateRel {
690 input: Some(input),
691 grouping_expressions,
692 groupings,
693 measures,
694 common: Some(common),
695 advanced_extension: None,
696 },
697 output_count,
698 ))
699 }
700}
701
702fn parse_aggregate_measures(
707 extensions: &SimpleExtensions,
708 output_pair: Pair<'_, Rule>,
709 grouping_expressions: &[Expression],
710) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
711 assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
712 let mut measures = Vec::new();
713 let mut output_mapping = Vec::new();
714
715 for aggregate_output_item in output_pair.into_inner() {
716 let inner_item = unwrap_single_pair(aggregate_output_item);
717 match inner_item.as_rule() {
718 Rule::reference => {
719 let field_index = FieldIndex::parse_pair(inner_item);
720 output_mapping.push(field_index.0);
721 }
722 Rule::aggregate_measure => {
723 let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
724 output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
725 measures.push(measure);
726 }
727 _ => panic!(
728 "Unexpected inner output item rule: {:?}",
729 inner_item.as_rule()
730 ),
731 }
732 }
733
734 Ok((measures, output_mapping))
735}
736
737fn parse_grouping_sets(
746 extensions: &SimpleExtensions,
747 inner: Pair<'_, Rule>,
748) -> Result<Vec<Vec<Expression>>, MessageParseError> {
749 assert!(
750 matches!(
751 inner.as_rule(),
752 Rule::expression_list | Rule::grouping_set_list
753 ),
754 "Expected expression_list or grouping_set_list, got {:?}",
755 inner.as_rule()
756 );
757 match inner.as_rule() {
758 Rule::expression_list => Ok(vec![parse_expression_list(extensions, inner)?]),
759 Rule::grouping_set_list => inner
760 .into_inner()
761 .map(|pair| parse_grouping_set(extensions, pair))
762 .collect(),
763 _ => unreachable!(
764 "Unexpected rule in aggregate_group_by: {:?}",
765 inner.as_rule()
766 ),
767 }
768}
769
770fn parse_grouping_set(
774 extensions: &SimpleExtensions,
775 pair: Pair<'_, Rule>,
776) -> Result<Vec<Expression>, MessageParseError> {
777 assert_eq!(pair.as_rule(), Rule::grouping_set);
778 let inner = pair
779 .into_inner()
780 .next()
781 .expect("grouping_set must have one inner item");
782 match inner.as_rule() {
783 Rule::empty => Ok(vec![]),
784 Rule::expression_list => parse_expression_list(extensions, inner),
785 _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
786 }
787}
788
789pub(crate) fn parse_expression_list(
791 extensions: &SimpleExtensions,
792 pair: Pair<'_, Rule>,
793) -> Result<Vec<Expression>, MessageParseError> {
794 pair.into_inner()
795 .map(|expr_pair| Expression::parse_pair(extensions, expr_pair))
796 .collect()
797}
798
799fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
803 let mut expressions: Vec<Expression> = Vec::new();
804 let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
805
806 let groupings = expression_sets
807 .iter()
808 .map(|set| {
809 let expression_references = set
810 .iter()
811 .map(|exp| {
812 let key = exp.encode_to_vec();
816 let next_idx = expressions.len() as u32;
817 *seen.entry(key).or_insert_with(|| {
818 expressions.push(exp.clone());
819 next_idx
820 })
821 })
822 .collect();
823 Grouping {
824 expression_references,
825 #[allow(deprecated)]
826 grouping_expressions: vec![],
827 }
828 })
829 .collect();
830
831 (groupings, expressions)
832}
833
834impl ScopedParsePair for SortField {
835 fn rule() -> Rule {
836 Rule::sort_field
837 }
838
839 fn message() -> &'static str {
840 "SortField"
841 }
842
843 fn parse_pair(
844 _extensions: &SimpleExtensions,
845 pair: Pair<Rule>,
846 ) -> Result<Self, MessageParseError> {
847 assert_eq!(pair.as_rule(), Self::rule());
848 let mut iter = RuleIter::from(pair.into_inner());
849 let reference_pair = iter.pop(Rule::reference);
850 let field_index = FieldIndex::parse_pair(reference_pair);
851 let direction_pair = iter.pop(Rule::sort_direction);
852 let direction = match direction_pair.as_str().trim_start_matches('&') {
856 "AscNullsFirst" => SortDirection::AscNullsFirst,
857 "AscNullsLast" => SortDirection::AscNullsLast,
858 "DescNullsFirst" => SortDirection::DescNullsFirst,
859 "DescNullsLast" => SortDirection::DescNullsLast,
860 other => {
861 return Err(MessageParseError::invalid(
862 "SortDirection",
863 direction_pair.as_span(),
864 format!("Unknown sort direction: {other}"),
865 ));
866 }
867 };
868 iter.done();
869 Ok(SortField {
870 expr: Some(Expression {
871 rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
872 field_index.to_field_reference(),
873 ))),
874 }),
875 sort_kind: Some(SortKind::Direction(direction as i32)),
877 })
878 }
879}
880
881impl RelationParsePair for SortRel {
882 fn rule() -> Rule {
883 Rule::sort_relation
884 }
885
886 fn message() -> &'static str {
887 "SortRel"
888 }
889
890 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
891 self.advanced_extension = adv_ext;
892 Rel {
893 rel_type: Some(RelType::Sort(Box::new(self))),
894 }
895 }
896
897 fn parse_pair_with_context(
898 extensions: &SimpleExtensions,
899 pair: Pair<Rule>,
900 input_children: Vec<Rel>,
901 input_field_count: usize,
902 ) -> Result<(Self, usize), MessageParseError> {
903 assert_eq!(pair.as_rule(), Self::rule());
904 let input = expect_one_child(Self::message(), &pair, input_children)?;
905 let mut iter = RuleIter::from(pair.into_inner());
906 let sort_field_list_pair = iter.pop(Rule::sort_field_list);
907 let reference_list_pair = iter.pop(Rule::reference_list);
908 let mut sorts = Vec::new();
909 for sort_field_pair in sort_field_list_pair.into_inner() {
910 let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
911 sorts.push(sort_field);
912 }
913 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
914 let common = RelCommon {
915 emit_kind: Some(emit),
916 ..Default::default()
917 };
918 iter.done();
919 Ok((
920 SortRel {
921 input: Some(input),
922 sorts,
923 common: Some(common),
924 advanced_extension: None,
925 },
926 output_count,
927 ))
928 }
929}
930
931impl ScopedParsePair for CountMode {
932 fn rule() -> Rule {
933 Rule::fetch_value
934 }
935 fn message() -> &'static str {
936 "CountMode"
937 }
938 fn parse_pair(
939 extensions: &SimpleExtensions,
940 pair: Pair<Rule>,
941 ) -> Result<Self, MessageParseError> {
942 assert_eq!(pair.as_rule(), Self::rule());
943 let mut arg_inner = RuleIter::from(pair.into_inner());
944 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
945 int_pair
946 } else {
947 arg_inner.pop(Rule::expression)
948 };
949 match value_pair.as_rule() {
950 Rule::integer => {
951 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
952 MessageParseError::invalid(
953 Self::message(),
954 value_pair.as_span(),
955 format!("Invalid integer: {e}"),
956 )
957 })?;
958 if value < 0 {
959 return Err(MessageParseError::invalid(
960 Self::message(),
961 value_pair.as_span(),
962 format!("Fetch limit must be non-negative, got: {value}"),
963 ));
964 }
965 Ok(CountMode::CountExpr(i64_literal_expr(value)))
966 }
967 Rule::expression => {
968 let expr = Expression::parse_pair(extensions, value_pair)?;
969 Ok(CountMode::CountExpr(Box::new(expr)))
970 }
971 _ => Err(MessageParseError::invalid(
972 Self::message(),
973 value_pair.as_span(),
974 format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
975 )),
976 }
977 }
978}
979
980fn i64_literal_expr(value: i64) -> Box<Expression> {
981 Box::new(Expression {
982 rex_type: Some(RexType::Literal(Literal {
983 nullable: false,
984 type_variation_reference: 0,
985 literal_type: Some(LiteralType::I64(value)),
986 })),
987 })
988}
989
990impl ScopedParsePair for OffsetMode {
991 fn rule() -> Rule {
992 Rule::fetch_value
993 }
994 fn message() -> &'static str {
995 "OffsetMode"
996 }
997 fn parse_pair(
998 extensions: &SimpleExtensions,
999 pair: Pair<Rule>,
1000 ) -> Result<Self, MessageParseError> {
1001 assert_eq!(pair.as_rule(), Self::rule());
1002 let mut arg_inner = RuleIter::from(pair.into_inner());
1003 let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
1004 int_pair
1005 } else {
1006 arg_inner.pop(Rule::expression)
1007 };
1008 match value_pair.as_rule() {
1009 Rule::integer => {
1010 let value = value_pair.as_str().parse::<i64>().map_err(|e| {
1011 MessageParseError::invalid(
1012 Self::message(),
1013 value_pair.as_span(),
1014 format!("Invalid integer: {e}"),
1015 )
1016 })?;
1017 if value < 0 {
1018 return Err(MessageParseError::invalid(
1019 Self::message(),
1020 value_pair.as_span(),
1021 format!("Fetch offset must be non-negative, got: {value}"),
1022 ));
1023 }
1024 Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
1025 }
1026 Rule::expression => {
1027 let expr = Expression::parse_pair(extensions, value_pair)?;
1028 Ok(OffsetMode::OffsetExpr(Box::new(expr)))
1029 }
1030 _ => Err(MessageParseError::invalid(
1031 Self::message(),
1032 value_pair.as_span(),
1033 format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
1034 )),
1035 }
1036 }
1037}
1038
1039impl RelationParsePair for FetchRel {
1040 fn rule() -> Rule {
1041 Rule::fetch_relation
1042 }
1043
1044 fn message() -> &'static str {
1045 "FetchRel"
1046 }
1047
1048 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1049 self.advanced_extension = adv_ext;
1050 Rel {
1051 rel_type: Some(RelType::Fetch(Box::new(self))),
1052 }
1053 }
1054
1055 fn parse_pair_with_context(
1056 extensions: &SimpleExtensions,
1057 pair: Pair<Rule>,
1058 input_children: Vec<Rel>,
1059 input_field_count: usize,
1060 ) -> Result<(Self, usize), MessageParseError> {
1061 assert_eq!(pair.as_rule(), Self::rule());
1062 let input = expect_one_child(Self::message(), &pair, input_children)?;
1063 let mut iter = RuleIter::from(pair.into_inner());
1064
1065 let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
1069 None => {
1070 iter.pop(Rule::empty);
1071 (None, None)
1072 }
1073 Some(fetch_args_pair) => {
1074 let extractor =
1075 ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
1076 let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
1077 let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
1078 extractor.done()?;
1079 (limit_pair, offset_pair)
1080 }
1081 };
1082
1083 let reference_list_pair = iter.pop(Rule::reference_list);
1084 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1085 let common = RelCommon {
1086 emit_kind: Some(emit),
1087 ..Default::default()
1088 };
1089 iter.done();
1090
1091 let count_mode = limit_pair
1092 .map(|pair| CountMode::parse_pair(extensions, pair))
1093 .transpose()?;
1094 let offset_mode = offset_pair
1095 .map(|pair| OffsetMode::parse_pair(extensions, pair))
1096 .transpose()?;
1097 Ok((
1098 FetchRel {
1099 input: Some(input),
1100 common: Some(common),
1101 advanced_extension: None,
1102 offset_mode,
1103 count_mode,
1104 },
1105 output_count,
1106 ))
1107 }
1108}
1109
1110impl ParsePair for join_rel::JoinType {
1111 fn rule() -> Rule {
1112 Rule::join_type
1113 }
1114
1115 fn message() -> &'static str {
1116 "JoinType"
1117 }
1118
1119 fn parse_pair(pair: Pair<Rule>) -> Self {
1120 assert_eq!(pair.as_rule(), Self::rule());
1121 let join_type_str = pair.as_str().trim_start_matches('&');
1122 match join_type_str {
1123 "Inner" => join_rel::JoinType::Inner,
1124 "Left" => join_rel::JoinType::Left,
1125 "Right" => join_rel::JoinType::Right,
1126 "Outer" => join_rel::JoinType::Outer,
1127 "LeftSemi" => join_rel::JoinType::LeftSemi,
1128 "RightSemi" => join_rel::JoinType::RightSemi,
1129 "LeftAnti" => join_rel::JoinType::LeftAnti,
1130 "RightAnti" => join_rel::JoinType::RightAnti,
1131 "LeftSingle" => join_rel::JoinType::LeftSingle,
1132 "RightSingle" => join_rel::JoinType::RightSingle,
1133 "LeftMark" => join_rel::JoinType::LeftMark,
1134 "RightMark" => join_rel::JoinType::RightMark,
1135 _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
1136 }
1137 }
1138}
1139
1140impl RelationParsePair for JoinRel {
1141 fn rule() -> Rule {
1142 Rule::join_relation
1143 }
1144
1145 fn message() -> &'static str {
1146 "JoinRel"
1147 }
1148
1149 fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1150 self.advanced_extension = adv_ext;
1151 Rel {
1152 rel_type: Some(RelType::Join(Box::new(self))),
1153 }
1154 }
1155
1156 fn parse_pair_with_context(
1157 extensions: &SimpleExtensions,
1158 pair: Pair<Rule>,
1159 input_children: Vec<Rel>,
1160 input_field_count: usize,
1161 ) -> Result<(Self, usize), MessageParseError> {
1162 assert_eq!(pair.as_rule(), Self::rule());
1163
1164 if input_children.len() != 2 {
1165 return Err(MessageParseError::invalid(
1166 Self::message(),
1167 pair.as_span(),
1168 format!(
1169 "JoinRel should have exactly 2 input children, got {}",
1170 input_children.len()
1171 ),
1172 ));
1173 }
1174
1175 let mut children_iter = input_children.into_iter();
1176 let left = Box::new(children_iter.next().unwrap());
1177 let right = Box::new(children_iter.next().unwrap());
1178
1179 let mut iter = RuleIter::from(pair.into_inner());
1180 let join_type = iter.parse_next::<join_rel::JoinType>();
1181 let condition = iter.parse_next_scoped::<Expression>(extensions)?;
1182 let reference_list_pair = iter.pop(Rule::reference_list);
1183 iter.done();
1184
1185 let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1189 let common = RelCommon {
1190 emit_kind: Some(emit),
1191 ..Default::default()
1192 };
1193
1194 Ok((
1195 JoinRel {
1196 common: Some(common),
1197 left: Some(left),
1198 right: Some(right),
1199 expression: Some(Box::new(condition)),
1200 post_join_filter: None, r#type: join_type as i32,
1202 advanced_extension: None,
1203 },
1204 output_count,
1205 ))
1206 }
1207}
1208
1209#[cfg(test)]
1210mod tests {
1211 use pest::Parser;
1212
1213 use super::*;
1214 use crate::fixtures::TestContext;
1215 use crate::parser::{ExpressionParser, Rule};
1216
/// Placeholder test: the body is empty, so it performs no checks and always
/// passes.
#[test]
fn test_parse_relation() {}
1221
/// A `Read[...]` line should yield a named-table read whose dotted table name
/// is split into parts and whose schema has one type per declared column.
#[test]
fn test_parse_read_relation() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]");
    let (read, _) = ReadRel::parse_pair_with_context(&extensions, pair, vec![], 0).unwrap();

    // The table reference must be a named table with each name segment kept.
    let Some(read_rel::ReadType::NamedTable(table)) = &read.read_type else {
        panic!("Expected NamedTable");
    };
    assert_eq!(&table.names, &["ab", "cd", "ef"]);

    // Two columns were declared, so the struct schema must list two types.
    let schema = read.base_schema.as_ref().unwrap();
    let fields = schema.r#struct.as_ref().unwrap();
    assert_eq!(fields.types.len(), 2);
}
1248
1249 fn example_read_relation() -> ReadRel {
1251 let extensions = SimpleExtensions::default();
1252 ReadRel::parse_pair_with_context(
1253 &extensions,
1254 parse_exact(
1255 Rule::read_relation,
1256 "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
1257 ),
1258 vec![],
1259 0,
1260 )
1261 .unwrap()
1262 .0
1263 }
1264
/// A filter that lists every input column in order should be emitted as
/// `Direct` rather than as an explicit output mapping.
#[test]
fn test_parse_filter_relation() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]");
    let input = example_read_relation().into_rel(None);
    let (filter, _) =
        FilterRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    let common = filter.common.as_ref().unwrap();
    let emit_kind = common.emit_kind.as_ref().unwrap();
    assert!(
        matches!(emit_kind, EmitKind::Direct(_)),
        "Expected Direct for identity emit, got {emit_kind:?}"
    );
}
1283
/// `Project[$0, $1, 42]` over 3 input fields: only the literal is stored as
/// an expression, and the expected output mapping is `[0, 1, 3]`.
#[test]
fn test_parse_project_relation() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::project_relation, "Project[$0, $1, 42]");
    let input = example_read_relation().into_rel(None);
    let (project, _) =
        ProjectRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    // Field references are not expressions; only `42` is.
    assert_eq!(project.expressions.len(), 1);

    let emit_kind = project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[0, 1, 3]);
}
1307
/// Interleaved literals and field references, parsed with an input field
/// count of 5: the two literals are expected at output indices 5 and 6.
#[test]
fn test_parse_project_relation_complex() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]");
    let input = example_read_relation().into_rel(None);
    let (project, _) =
        ProjectRel::parse_pair_with_context(&extensions, pair, vec![input], 5).unwrap();

    // Only `42` and `100` are expressions.
    assert_eq!(project.expressions.len(), 2);

    let emit_kind = project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[5, 0, 6, 2, 1]);
}
1332
/// Aggregate with one explicit grouping set plus the empty set (`_`), two
/// measures, and an output list that reorders measures and a grouping column.
#[test]
fn test_parse_aggregate_relation() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 10, "sum")
        .with_function(1, 11, "count");
    let extensions = ctx.extensions;

    let pair = parse_exact(
        Rule::aggregate_relation,
        "Aggregate[($0, $1), _ => sum($2):i64, $0, count($2):i64]",
    );
    let input = example_read_relation().into_rel(None);
    let (aggregate, _) =
        AggregateRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    assert_eq!(aggregate.grouping_expressions.len(), 2);
    assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
    assert_eq!(aggregate.groupings.len(), 2);
    assert_eq!(aggregate.measures.len(), 2);

    let common = aggregate.common.as_ref().unwrap();
    let emit_kind = common.emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[2, 0, 3]);
}
1374
/// The output list `sum, $0, count` must be preserved through the emit
/// mapping, which the test expects to be `[1, 0, 2]`.
#[test]
fn test_parse_aggregate_relation_maintain_column_order() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 10, "sum")
        .with_function(1, 11, "count");
    let extensions = ctx.extensions;

    let pair = parse_exact(
        Rule::aggregate_relation,
        "Aggregate[$0 => sum($1):i64, $0, count($1):i64]",
    );
    let input = example_read_relation().into_rel(None);
    let (aggregate, _) =
        AggregateRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    assert_eq!(aggregate.grouping_expressions.len(), 1);
    assert_eq!(aggregate.groupings.len(), 1);
    assert_eq!(aggregate.measures.len(), 2);

    let common = aggregate.common.as_ref().unwrap();
    let emit_kind = common.emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[1, 0, 2]);
}
1414
/// Bare (unparenthesised) grouping columns form a single grouping set whose
/// references index the grouping expressions in the order they were written.
#[test]
fn test_parse_aggregate_relation_simple() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 10, "sum");
    let extensions = ctx.extensions;

    let pair = parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1):i64]");
    let input = example_read_relation().into_rel(None);
    let (aggregate, _) =
        AggregateRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    assert_eq!(aggregate.grouping_expressions.len(), 2);
    assert_eq!(aggregate.groupings.len(), 1);
    assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
}
1436
/// A global aggregate (`_` as the only grouping) has no grouping expressions,
/// one empty grouping set, and is expected to use a `Direct` emit.
#[test]
fn test_parse_aggregate_relation_global_aggregate() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 10, "sum")
        .with_function(1, 11, "count");
    let extensions = ctx.extensions;

    let pair = parse_exact(
        Rule::aggregate_relation,
        "Aggregate[_ => sum($0):i64, count($1):i64]",
    );
    let input = example_read_relation().into_rel(None);
    let (aggregate, _) =
        AggregateRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    assert_eq!(aggregate.grouping_expressions.len(), 0);
    assert_eq!(aggregate.groupings.len(), 1);
    assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
    assert_eq!(aggregate.measures.len(), 2);

    let common = aggregate.common.as_ref().unwrap();
    let emit_kind = common.emit_kind.as_ref().unwrap();
    assert!(
        matches!(emit_kind, EmitKind::Direct(_)),
        "Expected Direct for identity emit, got {emit_kind:?}"
    );
}
1476
/// Several grouping sets that share columns: the test expects three grouping
/// expressions, with each set referencing them by index in its written order.
#[test]
fn test_parse_aggregate_relation_grouping_sets() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 11, "count");
    let extensions = ctx.extensions;

    // A four-column input, so that `$3` is available for the measure.
    let read_pair = parse_exact(
        Rule::read_relation,
        "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
    );
    let (read_rel, _) =
        ReadRel::parse_pair_with_context(&extensions, read_pair, vec![], 0).unwrap();

    let aggregate_pair = parse_exact(
        Rule::aggregate_relation,
        "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3):i64]",
    );
    let (aggregate, _) = AggregateRel::parse_pair_with_context(
        &extensions,
        aggregate_pair,
        vec![read_rel.into_rel(None)],
        4,
    )
    .unwrap();

    assert_eq!(aggregate.grouping_expressions.len(), 3);
    assert_eq!(aggregate.groupings.len(), 4);

    let sets = &aggregate.groupings;
    assert_eq!(sets[0].expression_references, vec![0, 1, 2]);
    assert_eq!(sets[1].expression_references, vec![2, 0]);
    assert_eq!(sets[2].expression_references, vec![1]);
    assert!(sets[3].expression_references.is_empty());

    assert_eq!(aggregate.measures.len(), 1);
}
1520
/// Positive `limit` and `offset` values should be stored as i64 literal
/// expressions in the fetch relation's count and offset modes.
#[test]
fn test_fetch_relation_positive_values() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]");
    let input = example_read_relation().into_rel(None);
    let (fetch_rel, _) =
        FetchRel::parse_pair_with_context(&extensions, pair, vec![input], 3).unwrap();

    let expected_count = Some(CountMode::CountExpr(i64_literal_expr(10)));
    assert_eq!(fetch_rel.count_mode, expected_count);

    let expected_offset = Some(OffsetMode::OffsetExpr(i64_literal_expr(5)));
    assert_eq!(fetch_rel.offset_mode, expected_offset);
}
1545
/// A negative limit must never produce a fetch relation.
///
/// Rejection may happen in the grammar (the text fails to parse, or is only
/// partially consumed) or in relation validation; both outcomes are accepted.
#[test]
fn test_fetch_relation_negative_limit_rejected() {
    let extensions = SimpleExtensions::default();
    let text = "Fetch[limit=-5 => $0]";

    // Guard 1: the grammar refuses the input outright.
    let pair = match ExpressionParser::parse(Rule::fetch_relation, text) {
        Ok(mut pairs) => pairs.next().unwrap(),
        Err(_) => {
            println!("Grammar prevents negative limit values at parse time");
            return;
        }
    };
    // Guard 2: the grammar matched only a prefix of the input.
    if pair.as_str() != text {
        println!("Grammar prevents negative limit values at parse time");
        return;
    }

    // The grammar accepted the whole text, so validation has to reject it.
    let result = FetchRel::parse_pair_with_context(
        &extensions,
        pair,
        vec![example_read_relation().into_rel(None)],
        3,
    );
    assert!(result.is_err());
    let error_msg = result.unwrap_err().to_string();
    assert!(error_msg.contains("Fetch limit must be non-negative"));
}
1575
/// A negative offset must never produce a fetch relation.
///
/// Rejection may happen in the grammar (the text fails to parse, or is only
/// partially consumed) or in relation validation; both outcomes are accepted.
#[test]
fn test_fetch_relation_negative_offset_rejected() {
    let extensions = SimpleExtensions::default();
    let text = "Fetch[offset=-10 => $0]";

    // Guard 1: the grammar refuses the input outright.
    let pair = match ExpressionParser::parse(Rule::fetch_relation, text) {
        Ok(mut pairs) => pairs.next().unwrap(),
        Err(_) => {
            println!("Grammar prevents negative offset values at parse time");
            return;
        }
    };
    // Guard 2: the grammar matched only a prefix of the input.
    if pair.as_str() != text {
        println!("Grammar prevents negative offset values at parse time");
        return;
    }

    // The grammar accepted the whole text, so validation has to reject it.
    let result = FetchRel::parse_pair_with_context(
        &extensions,
        pair,
        vec![example_read_relation().into_rel(None)],
        3,
    );
    assert!(result.is_err());
    let error_msg = result.unwrap_err().to_string();
    assert!(error_msg.contains("Fetch offset must be non-negative"));
}
1605
/// Inner join over two three-column reads (6 combined fields) that emits a
/// subset of columns drawn from both sides.
#[test]
fn test_parse_join_relation() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
        .with_function(1, 10, "eq");
    let extensions = ctx.extensions;

    let lhs = example_read_relation().into_rel(None);
    let rhs = example_read_relation().into_rel(None);

    let pair = parse_exact(
        Rule::join_relation,
        "Join[&Inner, eq($0, $3):boolean => $0, $1, $3, $4]",
    );
    let (join, _) =
        JoinRel::parse_pair_with_context(&extensions, pair, vec![lhs, rhs], 6).unwrap();

    assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);

    // Both inputs and the join condition must be populated.
    assert!(join.left.is_some());
    assert!(join.right.is_some());
    assert!(join.expression.is_some());

    let emit_kind = join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[0, 1, 3, 4]);
}
1646
/// `&Left` should map to `JoinType::Left`, and emitting only the first three
/// of six fields should yield an explicit output mapping.
#[test]
fn test_parse_join_relation_left_outer() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
        .with_function(1, 10, "eq");
    let extensions = ctx.extensions;

    let lhs = example_read_relation().into_rel(None);
    let rhs = example_read_relation().into_rel(None);

    let pair = parse_exact(
        Rule::join_relation,
        "Join[&Left, eq($0, $3):boolean => $0, $1, $2]",
    );
    let (join, _) =
        JoinRel::parse_pair_with_context(&extensions, pair, vec![lhs, rhs], 6).unwrap();

    assert_eq!(join.r#type, join_rel::JoinType::Left as i32);

    let emit_kind = join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[0, 1, 2]);
}
1680
/// `&LeftSemi` should map to `JoinType::LeftSemi`, with the two listed
/// columns carried through as the output mapping.
#[test]
fn test_parse_join_relation_left_semi() {
    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
        .with_function(1, 10, "eq");
    let extensions = ctx.extensions;

    let lhs = example_read_relation().into_rel(None);
    let rhs = example_read_relation().into_rel(None);

    let pair = parse_exact(
        Rule::join_relation,
        "Join[&LeftSemi, eq($0, $3):boolean => $0, $1]",
    );
    let (join, _) =
        JoinRel::parse_pair_with_context(&extensions, pair, vec![lhs, rhs], 6).unwrap();

    assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);

    let emit_kind = join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
    let EmitKind::Emit(emit) = emit_kind else {
        panic!("Expected EmitKind::Emit, got {emit_kind:?}");
    };
    assert_eq!(&emit.output_mapping, &[0, 1]);
}
1714
/// Parsing a join must fail unless exactly two input children are supplied;
/// zero, one, and three children are each tried.
#[test]
fn test_parse_join_relation_requires_two_children() {
    let extensions = SimpleExtensions::default();

    // Parses the same join text against a caller-supplied set of children.
    let parse_join = |children: Vec<Rel>, field_count: usize| {
        JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $1):boolean => $0, $1]",
            ),
            children,
            field_count,
        )
    };
    let child = || example_read_relation().into_rel(None);

    // No children at all.
    let result = parse_join(vec![], 0);
    assert!(result.is_err());

    // A single child.
    let result = parse_join(vec![child()], 3);
    assert!(result.is_err());

    // One child too many.
    let result = parse_join(vec![child(), child(), child()], 9);
    assert!(result.is_err());
}
1759
1760 fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
1761 let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1762 assert_eq!(pairs.as_str(), input);
1763 let pair = pairs.next().unwrap();
1764 assert_eq!(pairs.next(), None);
1765 pair
1766 }
1767}