Skip to main content

substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType, nested};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::{CountMode, OffsetMode};
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::{Direct, Emit, EmitKind};
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
15    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
16};
17
18use super::{MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair};
19use crate::extensions::any::Any;
20use crate::extensions::registry::ExtensionError;
21use crate::extensions::{AddendumKind, ExtensionArgs, ExtensionRegistry, SimpleExtensions};
22use crate::parser::errors::{ParseContext, ParseError};
23use crate::parser::expressions::{FieldIndex, Name};
24
25/// Parsing context for relations that includes extensions, registry, and optional warning collection
26pub struct RelationParsingContext<'a> {
27    pub extensions: &'a SimpleExtensions,
28    pub registry: &'a ExtensionRegistry,
29    pub line_no: i64,
30    pub line: &'a str,
31}
32
33impl<'a> RelationParsingContext<'a> {
34    /// Resolve extension detail using registry. Any failure is treated as a hard parse error.
35    pub fn resolve_extension_detail(
36        &self,
37        extension_name: &str,
38        extension_args: &ExtensionArgs,
39    ) -> Result<Option<Any>, ParseError> {
40        let detail = self
41            .registry
42            .parse_extension(extension_name, extension_args);
43
44        match detail {
45            Ok(any) => Ok(Some(any)),
46            Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
47                name: extension_name.to_string(),
48                context: ParseContext::new(self.line_no, self.line.to_string()),
49            }),
50            Err(err) => Err(ParseError::ExtensionDetail(
51                ParseContext::new(self.line_no, self.line.to_string()),
52                err,
53            )),
54        }
55    }
56
57    /// Resolve an addendum detail using the registry entry for its kind.
58    /// Any failure is treated as a hard parse error.
59    pub(crate) fn resolve_addendum_detail(
60        &self,
61        kind: AddendumKind,
62        name: &str,
63        args: &ExtensionArgs,
64    ) -> Result<Any, ParseError> {
65        let result = match kind {
66            AddendumKind::Enhancement => self.registry.parse_enhancement(name, args),
67            AddendumKind::Optimization => self.registry.parse_optimization(name, args),
68            AddendumKind::ExtensionTable => self.registry.parse_extension_table(name, args),
69        };
70        result.map_err(|err| match err {
71            ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
72                name: name.to_string(),
73                context: ParseContext::new(self.line_no, self.line.to_string()),
74            },
75            err => ParseError::ExtensionDetail(
76                ParseContext::new(self.line_no, self.line.to_string()),
77                err,
78            ),
79        })
80    }
81}
82
83/// A trait for parsing relations with full context for tree building.
84pub trait RelationParsePair: Sized {
85    fn rule() -> Rule;
86    fn message() -> &'static str;
87
88    /// Parse the grammar pair into this relation type and its output field
89    /// count.
90    ///
91    /// Returns `(Self, usize)` where `usize` is the output field count —
92    /// computed during parsing when `input_field_count` is available.
93    fn parse_pair_with_context(
94        extensions: &SimpleExtensions,
95        pair: Pair<Rule>,
96        input_children: Vec<Box<Rel>>,
97        input_field_count: usize,
98    ) -> Result<(Self, usize), MessageParseError>;
99
100    /// Consume this parsed relation, apply the advanced extension, and produce
101    /// the final `Rel`.
102    fn into_rel(self, adv_ext: Option<AdvancedExtension>) -> Rel;
103}
104
105pub struct TableName(Vec<String>);
106
107impl ParsePair for TableName {
108    fn rule() -> Rule {
109        Rule::table_name
110    }
111
112    fn message() -> &'static str {
113        "TableName"
114    }
115
116    fn parse_pair(pair: Pair<Rule>) -> Self {
117        assert_eq!(pair.as_rule(), Self::rule());
118        let pairs = pair.into_inner();
119        let mut names = Vec::with_capacity(pairs.len());
120        let mut iter = RuleIter::from(pairs);
121        while let Some(name) = iter.parse_if_next::<Name>() {
122            names.push(name.0);
123        }
124        iter.done();
125        Self(names)
126    }
127}
128
129#[derive(Debug, Clone)]
130pub struct Column {
131    pub name: String,
132    pub typ: Type,
133}
134
135impl ScopedParsePair for Column {
136    fn rule() -> Rule {
137        Rule::named_column
138    }
139
140    fn message() -> &'static str {
141        "Column"
142    }
143
144    fn parse_pair(
145        extensions: &SimpleExtensions,
146        pair: Pair<Rule>,
147    ) -> Result<Self, MessageParseError> {
148        assert_eq!(pair.as_rule(), Self::rule());
149        let mut iter = RuleIter::from(pair.into_inner());
150        let name = iter.parse_next::<Name>().0;
151        let typ = iter.parse_next_scoped(extensions)?;
152        iter.done();
153        Ok(Self { name, typ })
154    }
155}
156
157pub(crate) struct NamedColumnList(pub(crate) Vec<Column>);
158
159impl ScopedParsePair for NamedColumnList {
160    fn rule() -> Rule {
161        Rule::named_column_list
162    }
163
164    fn message() -> &'static str {
165        "NamedColumnList"
166    }
167
168    fn parse_pair(
169        extensions: &SimpleExtensions,
170        pair: Pair<Rule>,
171    ) -> Result<Self, MessageParseError> {
172        assert_eq!(pair.as_rule(), Self::rule());
173        let mut columns = Vec::new();
174        for col in pair.into_inner() {
175            columns.push(Column::parse_pair(extensions, col)?);
176        }
177        Ok(Self(columns))
178    }
179}
180
181/// This is a utility function for extracting a single child from the list of
182/// children, to be used in the RelationParsePair trait. The RelationParsePair
183/// trait passes a Vec of children, because some relations have multiple
184/// children - but most accept exactly one child.
185#[allow(clippy::vec_box)]
186pub(crate) fn expect_one_child(
187    message: &'static str,
188    pair: &Pair<Rule>,
189    mut input_children: Vec<Box<Rel>>,
190) -> Result<Box<Rel>, MessageParseError> {
191    match input_children.len() {
192        0 => Err(MessageParseError::invalid(
193            message,
194            pair.as_span(),
195            format!("{message} missing child"),
196        )),
197        1 => Ok(input_children.pop().unwrap()),
198        n => Err(MessageParseError::invalid(
199            message,
200            pair.as_span(),
201            format!("{message} should have 1 input child, got {n}"),
202        )),
203    }
204}
205
206/// Parse a reference list Pair and return an EmitKind::Emit.
207/// Parse a reference list into field indices for emit mapping.
208fn parse_output_mapping(pair: Pair<Rule>) -> Vec<i32> {
209    assert_eq!(pair.as_rule(), Rule::reference_list);
210    pair.into_inner()
211        .map(|p| FieldIndex::parse_pair(p).0)
212        .collect()
213}
214
215/// Build an emit: `Direct` if the mapping is the identity `[0, 1, ..., N-1]`
216/// (where N = `direct_output_count`), otherwise `Emit` with the explicit mapping.
217fn make_emit(output_mapping: Vec<i32>, direct_output_count: usize) -> EmitKind {
218    let is_identity = output_mapping.len() == direct_output_count
219        && output_mapping
220            .iter()
221            .enumerate()
222            .all(|(i, &v)| v == i as i32);
223    if is_identity {
224        EmitKind::Direct(Direct {})
225    } else {
226        EmitKind::Emit(Emit { output_mapping })
227    }
228}
229
230/// Parse a reference list into an emit and output field count.
231fn parse_emit(reference_list: Pair<Rule>, direct_output_count: usize) -> (EmitKind, usize) {
232    let output_mapping = parse_output_mapping(reference_list);
233    let output_count = output_mapping.len();
234    let emit = make_emit(output_mapping, direct_output_count);
235    (emit, output_count)
236}
237
238/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
239///
240/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
241///
242/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
243pub struct ParsedNamedArgs<'a> {
244    map: HashMap<&'a str, Pair<'a, Rule>>,
245}
246
247impl<'a> ParsedNamedArgs<'a> {
248    pub fn new(
249        pairs: pest::iterators::Pairs<'a, Rule>,
250        rule: Rule,
251    ) -> Result<Self, MessageParseError> {
252        let mut map = HashMap::new();
253        for pair in pairs {
254            assert_eq!(pair.as_rule(), rule);
255            let mut inner = pair.clone().into_inner();
256            let name_pair = inner.next().unwrap();
257            let value_pair = inner.next().unwrap();
258            assert_eq!(inner.next(), None);
259            let name = name_pair.as_str();
260            if map.contains_key(name) {
261                return Err(MessageParseError::invalid(
262                    "NamedArg",
263                    name_pair.as_span(),
264                    format!("Duplicate argument: {name}"),
265                ));
266            }
267            map.insert(name, value_pair);
268        }
269        Ok(Self { map })
270    }
271
272    // Returns the pair if it exists and matches the rule, otherwise None.
273    // Asserts that the rule must match the rule of the pair (and therefore
274    // panics in non-release-mode if not)
275    pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
276        let pair = self.map.remove(name).inspect(|pair| {
277            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
278        });
279        (self, pair)
280    }
281
282    // Returns an error if there are any unused arguments.
283    pub fn done(self) -> Result<(), MessageParseError> {
284        if let Some((name, pair)) = self.map.iter().next() {
285            return Err(MessageParseError::invalid(
286                "NamedArgExtractor",
287                // No span available for all unused args; use default.
288                pair.as_span(),
289                format!("Unknown argument: {name}"),
290            ));
291        }
292        Ok(())
293    }
294}
295
296impl RelationParsePair for ReadRel {
297    fn rule() -> Rule {
298        Rule::read_relation
299    }
300
301    fn message() -> &'static str {
302        "ReadRel"
303    }
304
305    fn parse_pair_with_context(
306        extensions: &SimpleExtensions,
307        pair: Pair<Rule>,
308        input_children: Vec<Box<Rel>>,
309        input_field_count: usize,
310    ) -> Result<(Self, usize), MessageParseError> {
311        assert_eq!(pair.as_rule(), Self::rule());
312        // ReadRel is a leaf node - it should have no input children and 0 input fields
313        if !input_children.is_empty() {
314            return Err(MessageParseError::invalid(
315                Self::message(),
316                pair.as_span(),
317                "ReadRel should have no input children",
318            ));
319        }
320        if input_field_count != 0 {
321            return Err(MessageParseError::invalid(
322                "ReadRel",
323                pair.as_span(),
324                "ReadRel should have 0 input fields",
325            ));
326        }
327
328        let mut iter = RuleIter::from(pair.into_inner());
329        let table = iter.parse_next::<TableName>().0;
330        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
331        iter.done();
332
333        let output_count = columns.len();
334        Ok((
335            ReadRel {
336                base_schema: Some(build_named_struct(columns)),
337                read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
338                    names: table,
339                    advanced_extension: None,
340                })),
341                ..Default::default()
342            },
343            output_count,
344        ))
345    }
346
347    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
348        self.advanced_extension = adv_ext;
349        Rel {
350            rel_type: Some(RelType::Read(Box::new(self))),
351        }
352    }
353}
354
355/// Parsed `Read:Virtual[rows => columns]` relation. Needs a newtype because the
356/// proto type is `ReadRel` (same as `NamedTable`), but the grammar and handling
357/// are different.
358pub(crate) struct VirtualReadRel(ReadRel);
359
360impl RelationParsePair for VirtualReadRel {
361    fn rule() -> Rule {
362        Rule::virtual_read_relation
363    }
364
365    fn message() -> &'static str {
366        "VirtualReadRel"
367    }
368
369    fn parse_pair_with_context(
370        extensions: &SimpleExtensions,
371        pair: Pair<Rule>,
372        input_children: Vec<Box<Rel>>,
373        _input_field_count: usize,
374    ) -> Result<(Self, usize), MessageParseError> {
375        assert_eq!(pair.as_rule(), Self::rule());
376        if !input_children.is_empty() {
377            return Err(MessageParseError::invalid(
378                Self::message(),
379                pair.as_span(),
380                "Read:Virtual should have no input children",
381            ));
382        }
383
384        let mut iter = RuleIter::from(pair.into_inner());
385        let args_pair = iter.pop(Rule::virtual_read_args);
386        let columns_pair = iter.pop(Rule::named_column_list);
387        iter.done();
388
389        let rows = parse_virtual_read_args(extensions, args_pair)?;
390        let columns = NamedColumnList::parse_pair(extensions, columns_pair)?.0;
391
392        // TODO: Validate that each row has the same number of expressions as
393        // columns. Currently no parser-side warning mechanism exists, and while
394        // this is an invalid plan, it is constructible as Substrait. Consider
395        // adding once a warning collection path is available.
396        let output_count = columns.len();
397        Ok((
398            VirtualReadRel(ReadRel {
399                base_schema: Some(build_named_struct(columns)),
400                read_type: Some(read_rel::ReadType::VirtualTable(read_rel::VirtualTable {
401                    expressions: rows,
402                    ..Default::default()
403                })),
404                ..Default::default()
405            }),
406            output_count,
407        ))
408    }
409
410    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
411        self.0.advanced_extension = adv_ext;
412        Rel {
413            rel_type: Some(RelType::Read(Box::new(self.0))),
414        }
415    }
416}
417
418/// Parsed `Read:Extension[columns]` relation. Needs a newtype because the
419/// proto type is `ReadRel` (same as `NamedTable` and `VirtualTable`), but the
420/// read detail is supplied by a required `+ Ext:` addendum.
421pub(crate) struct ExtensionReadRel(ReadRel);
422
423impl ExtensionReadRel {
424    // Clippy allow: normally Vec<Box<…>> is a warning, because Vec is already
425    // on the heap, so Vec<Box<…>> just adds a layer of indirection. Here, the
426    // protobuf generated code already comes with boxes, so not using boxes
427    // would mean copying things around, which would be more wasteful. So we
428    // allow(…) it.
429    #[allow(clippy::vec_box)]
430    pub(crate) fn parse_pair_with_detail(
431        extensions: &SimpleExtensions,
432        pair: Pair<Rule>,
433        input_children: Vec<Box<Rel>>,
434        input_field_count: usize,
435        detail: Any,
436        advanced_extension: Option<AdvancedExtension>,
437    ) -> Result<(Rel, usize), MessageParseError> {
438        assert_eq!(pair.as_rule(), Rule::extension_read_relation);
439        if !input_children.is_empty() {
440            return Err(MessageParseError::invalid(
441                "ExtensionReadRel",
442                pair.as_span(),
443                "Read:Extension should have no input children",
444            ));
445        }
446        if input_field_count != 0 {
447            return Err(MessageParseError::invalid(
448                "ExtensionReadRel",
449                pair.as_span(),
450                "Read:Extension should have 0 input fields",
451            ));
452        }
453
454        let mut iter = RuleIter::from(pair.into_inner());
455        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
456        iter.done();
457
458        let output_count = columns.len();
459        let rel = ExtensionReadRel(ReadRel {
460            base_schema: Some(build_named_struct(columns)),
461            read_type: Some(read_rel::ReadType::ExtensionTable(
462                read_rel::ExtensionTable {
463                    detail: Some(detail.into()),
464                },
465            )),
466            advanced_extension,
467            ..Default::default()
468        })
469        .into_rel();
470
471        Ok((rel, output_count))
472    }
473
474    fn into_rel(self) -> Rel {
475        Rel {
476            rel_type: Some(RelType::Read(Box::new(self.0))),
477        }
478    }
479}
480
481/// Build a `NamedStruct` from parsed columns.
482pub(crate) fn build_named_struct(columns: Vec<Column>) -> NamedStruct {
483    let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
484    NamedStruct {
485        names,
486        r#struct: Some(r#type::Struct {
487            types,
488            type_variation_reference: 0,
489            nullability: r#type::Nullability::Required as i32,
490        }),
491    }
492}
493
494/// `Read:Virtual` positional args: either `empty` or a list of row tuples.
495fn parse_virtual_read_args(
496    extensions: &SimpleExtensions,
497    pair: Pair<Rule>,
498) -> Result<Vec<nested::Struct>, MessageParseError> {
499    assert_eq!(pair.as_rule(), Rule::virtual_read_args);
500    let inner = unwrap_single_pair(pair);
501    match inner.as_rule() {
502        Rule::empty => Ok(vec![]),
503        Rule::virtual_row_list => inner
504            .into_inner()
505            .map(|row| parse_virtual_row(extensions, row))
506            .collect(),
507        _ => unreachable!(
508            "Unexpected rule in virtual_read_args: {:?}",
509            inner.as_rule()
510        ),
511    }
512}
513
514/// Parse a single `virtual_row` (`(expr, expr, ...)` or `()`) into a `nested::Struct`.
515fn parse_virtual_row(
516    extensions: &SimpleExtensions,
517    pair: Pair<Rule>,
518) -> Result<nested::Struct, MessageParseError> {
519    assert_eq!(pair.as_rule(), Rule::virtual_row);
520    let fields = match pair.into_inner().next() {
521        Some(expression_list) => {
522            assert_eq!(expression_list.as_rule(), Rule::expression_list);
523            parse_expression_list(extensions, expression_list)?
524        }
525        // Empty virtual row. An unusual but valid case.
526        None => vec![],
527    };
528    Ok(nested::Struct { fields })
529}
530
531impl RelationParsePair for FilterRel {
532    fn rule() -> Rule {
533        Rule::filter_relation
534    }
535
536    fn message() -> &'static str {
537        "FilterRel"
538    }
539
540    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
541        self.advanced_extension = adv_ext;
542        Rel {
543            rel_type: Some(RelType::Filter(Box::new(self))),
544        }
545    }
546
547    fn parse_pair_with_context(
548        extensions: &SimpleExtensions,
549        pair: Pair<Rule>,
550        input_children: Vec<Box<Rel>>,
551        input_field_count: usize,
552    ) -> Result<(Self, usize), MessageParseError> {
553        assert_eq!(pair.as_rule(), Self::rule());
554        let input = expect_one_child(Self::message(), &pair, input_children)?;
555        let mut iter = RuleIter::from(pair.into_inner());
556        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
557        // references (which become the emit)
558        let references_pair = iter.pop(Rule::reference_list);
559        iter.done();
560
561        let (emit, output_count) = parse_emit(references_pair, input_field_count);
562        let common = RelCommon {
563            emit_kind: Some(emit),
564            ..Default::default()
565        };
566
567        Ok((
568            FilterRel {
569                input: Some(input),
570                condition: Some(Box::new(condition)),
571                common: Some(common),
572                advanced_extension: None,
573            },
574            output_count,
575        ))
576    }
577}
578
579impl RelationParsePair for ProjectRel {
580    fn rule() -> Rule {
581        Rule::project_relation
582    }
583
584    fn message() -> &'static str {
585        "ProjectRel"
586    }
587
588    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
589        self.advanced_extension = adv_ext;
590        Rel {
591            rel_type: Some(RelType::Project(Box::new(self))),
592        }
593    }
594
595    fn parse_pair_with_context(
596        extensions: &SimpleExtensions,
597        pair: Pair<Rule>,
598        input_children: Vec<Box<Rel>>,
599        input_field_count: usize,
600    ) -> Result<(Self, usize), MessageParseError> {
601        assert_eq!(pair.as_rule(), Self::rule());
602        let input = expect_one_child(Self::message(), &pair, input_children)?;
603
604        let arguments_pair = unwrap_single_pair(pair);
605
606        let mut expressions = Vec::new();
607        let mut output_mapping = Vec::new();
608
609        for arg in arguments_pair.into_inner() {
610            let inner_arg = unwrap_single_pair(arg);
611            match inner_arg.as_rule() {
612                Rule::reference => {
613                    let field_index = FieldIndex::parse_pair(inner_arg);
614                    output_mapping.push(field_index.0);
615                }
616                Rule::expression => {
617                    let expr = Expression::parse_pair(extensions, inner_arg)?;
618                    expressions.push(expr);
619                    // Index into the combined schema: [input fields][computed expressions].
620                    output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
621                }
622                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
623            }
624        }
625
626        let output_count = output_mapping.len();
627        let direct_count = input_field_count + expressions.len();
628        let emit = make_emit(output_mapping, direct_count);
629        let common = RelCommon {
630            emit_kind: Some(emit),
631            ..Default::default()
632        };
633
634        Ok((
635            ProjectRel {
636                input: Some(input),
637                expressions,
638                common: Some(common),
639                advanced_extension: None,
640            },
641            output_count,
642        ))
643    }
644}
645
646impl RelationParsePair for AggregateRel {
647    fn rule() -> Rule {
648        Rule::aggregate_relation
649    }
650
651    fn message() -> &'static str {
652        "AggregateRel"
653    }
654
655    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
656        self.advanced_extension = adv_ext;
657        Rel {
658            rel_type: Some(RelType::Aggregate(Box::new(self))),
659        }
660    }
661
662    fn parse_pair_with_context(
663        extensions: &SimpleExtensions,
664        pair: Pair<Rule>,
665        input_children: Vec<Box<Rel>>,
666        // Aggregate defines its own output schema (grouping keys + measures),
667        // so the input field count isn't needed for emit construction.
668        _input_field_count: usize,
669    ) -> Result<(Self, usize), MessageParseError> {
670        assert_eq!(pair.as_rule(), Self::rule());
671        let input = expect_one_child(Self::message(), &pair, input_children)?;
672        let mut iter = RuleIter::from(pair.into_inner());
673        let group_by_pair = iter.pop(Rule::aggregate_group_by);
674        let output_pair = iter.pop(Rule::aggregate_output);
675        iter.done();
676
677        let inner = group_by_pair
678            .into_inner()
679            .next()
680            .expect("aggregate_group_by must have one inner item");
681
682        let grouping_sets = parse_grouping_sets(extensions, inner)?;
683        let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
684
685        let (measures, output_mapping) =
686            parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
687
688        let output_count = output_mapping.len();
689        let direct_count = grouping_expressions.len() + measures.len();
690        let emit = make_emit(output_mapping, direct_count);
691        let common = RelCommon {
692            emit_kind: Some(emit),
693            ..Default::default()
694        };
695
696        Ok((
697            AggregateRel {
698                input: Some(input),
699                grouping_expressions,
700                groupings,
701                measures,
702                common: Some(common),
703                advanced_extension: None,
704            },
705            output_count,
706        ))
707    }
708}
709
710/// Parses the output section of an aggregate (everything after `=>`).
711///
712/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
713/// this parses `sum($2), $0, count($2)`.
714fn parse_aggregate_measures(
715    extensions: &SimpleExtensions,
716    output_pair: Pair<'_, Rule>,
717    grouping_expressions: &[Expression],
718) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
719    assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
720    let mut measures = Vec::new();
721    let mut output_mapping = Vec::new();
722
723    for aggregate_output_item in output_pair.into_inner() {
724        let inner_item = unwrap_single_pair(aggregate_output_item);
725        match inner_item.as_rule() {
726            Rule::reference => {
727                let field_index = FieldIndex::parse_pair(inner_item);
728                output_mapping.push(field_index.0);
729            }
730            Rule::aggregate_measure => {
731                let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
732                output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
733                measures.push(measure);
734            }
735            _ => panic!(
736                "Unexpected inner output item rule: {:?}",
737                inner_item.as_rule()
738            ),
739        }
740    }
741
742    Ok((measures, output_mapping))
743}
744
745/// Parses the grouping section of an aggregate (everything before `=>`).
746///
747/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
748/// this parses `($0, $1), _`.
749///
750/// Each inner Vec is one grouping set; an empty vec represents no grouping (global aggregate).
751///
752/// Grammar: `aggregate_group_by = { grouping_set_list | expression_list }`
753fn parse_grouping_sets(
754    extensions: &SimpleExtensions,
755    inner: Pair<'_, Rule>,
756) -> Result<Vec<Vec<Expression>>, MessageParseError> {
757    assert!(
758        matches!(
759            inner.as_rule(),
760            Rule::expression_list | Rule::grouping_set_list
761        ),
762        "Expected expression_list or grouping_set_list, got {:?}",
763        inner.as_rule()
764    );
765    match inner.as_rule() {
766        Rule::expression_list => Ok(vec![parse_expression_list(extensions, inner)?]),
767        Rule::grouping_set_list => inner
768            .into_inner()
769            .map(|pair| parse_grouping_set(extensions, pair))
770            .collect(),
771        _ => unreachable!(
772            "Unexpected rule in aggregate_group_by: {:?}",
773            inner.as_rule()
774        ),
775    }
776}
777
778/// Parses a single grouping set, e.g. `($0, $1)` or `_`.
779///
780/// Grammar: `grouping_set = { ("(" ~ expression_list ~ ")") | empty }`
781fn parse_grouping_set(
782    extensions: &SimpleExtensions,
783    pair: Pair<'_, Rule>,
784) -> Result<Vec<Expression>, MessageParseError> {
785    assert_eq!(pair.as_rule(), Rule::grouping_set);
786    let inner = pair
787        .into_inner()
788        .next()
789        .expect("grouping_set must have one inner item");
790    match inner.as_rule() {
791        Rule::empty => Ok(vec![]),
792        Rule::expression_list => parse_expression_list(extensions, inner),
793        _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
794    }
795}
796
797/// Grammar: `expression_list = { expression ~ ("," ~ expression)* }`
798pub(crate) fn parse_expression_list(
799    extensions: &SimpleExtensions,
800    pair: Pair<'_, Rule>,
801) -> Result<Vec<Expression>, MessageParseError> {
802    pair.into_inner()
803        .map(|expr_pair| Expression::parse_pair(extensions, expr_pair))
804        .collect()
805}
806
807/// Deduplicates expressions across all sets and produces the AggregateRel's
808/// protobuf fields: a flat deduplicated expression list and per-set Grouping
809/// messages with index references into that list.
810fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
811    let mut expressions: Vec<Expression> = Vec::new();
812    let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
813
814    let groupings = expression_sets
815        .iter()
816        .map(|set| {
817            let expression_references = set
818                .iter()
819                .map(|exp| {
820                    // TODO: use a better key here than encoding to bytes.
821                    // Ideally, substrait-rs would support `PartialEq` and `Hash`,
822                    // but as there isn't an easy way to do that now, we'll skip.
823                    let key = exp.encode_to_vec();
824                    let next_idx = expressions.len() as u32;
825                    *seen.entry(key).or_insert_with(|| {
826                        expressions.push(exp.clone());
827                        next_idx
828                    })
829                })
830                .collect();
831            Grouping {
832                expression_references,
833                #[allow(deprecated)]
834                grouping_expressions: vec![],
835            }
836        })
837        .collect();
838
839    (groupings, expressions)
840}
841
842impl ScopedParsePair for SortField {
843    fn rule() -> Rule {
844        Rule::sort_field
845    }
846
847    fn message() -> &'static str {
848        "SortField"
849    }
850
851    fn parse_pair(
852        _extensions: &SimpleExtensions,
853        pair: Pair<Rule>,
854    ) -> Result<Self, MessageParseError> {
855        assert_eq!(pair.as_rule(), Self::rule());
856        let mut iter = RuleIter::from(pair.into_inner());
857        let reference_pair = iter.pop(Rule::reference);
858        let field_index = FieldIndex::parse_pair(reference_pair);
859        let direction_pair = iter.pop(Rule::sort_direction);
860        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
861        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
862        // identifiers, but the enum variant names don't include it
863        let direction = match direction_pair.as_str().trim_start_matches('&') {
864            "AscNullsFirst" => SortDirection::AscNullsFirst,
865            "AscNullsLast" => SortDirection::AscNullsLast,
866            "DescNullsFirst" => SortDirection::DescNullsFirst,
867            "DescNullsLast" => SortDirection::DescNullsLast,
868            other => {
869                return Err(MessageParseError::invalid(
870                    "SortDirection",
871                    direction_pair.as_span(),
872                    format!("Unknown sort direction: {other}"),
873                ));
874            }
875        };
876        iter.done();
877        Ok(SortField {
878            expr: Some(Expression {
879                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
880                    field_index.to_field_reference(),
881                ))),
882            }),
883            // TODO: Add support for SortKind::ComparisonFunctionReference
884            sort_kind: Some(SortKind::Direction(direction as i32)),
885        })
886    }
887}
888
889impl RelationParsePair for SortRel {
890    fn rule() -> Rule {
891        Rule::sort_relation
892    }
893
894    fn message() -> &'static str {
895        "SortRel"
896    }
897
898    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
899        self.advanced_extension = adv_ext;
900        Rel {
901            rel_type: Some(RelType::Sort(Box::new(self))),
902        }
903    }
904
905    fn parse_pair_with_context(
906        extensions: &SimpleExtensions,
907        pair: Pair<Rule>,
908        input_children: Vec<Box<Rel>>,
909        input_field_count: usize,
910    ) -> Result<(Self, usize), MessageParseError> {
911        assert_eq!(pair.as_rule(), Self::rule());
912        let input = expect_one_child(Self::message(), &pair, input_children)?;
913        let mut iter = RuleIter::from(pair.into_inner());
914        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
915        let reference_list_pair = iter.pop(Rule::reference_list);
916        let mut sorts = Vec::new();
917        for sort_field_pair in sort_field_list_pair.into_inner() {
918            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
919            sorts.push(sort_field);
920        }
921        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
922        let common = RelCommon {
923            emit_kind: Some(emit),
924            ..Default::default()
925        };
926        iter.done();
927        Ok((
928            SortRel {
929                input: Some(input),
930                sorts,
931                common: Some(common),
932                advanced_extension: None,
933            },
934            output_count,
935        ))
936    }
937}
938
939impl ScopedParsePair for CountMode {
940    fn rule() -> Rule {
941        Rule::fetch_value
942    }
943    fn message() -> &'static str {
944        "CountMode"
945    }
946    fn parse_pair(
947        extensions: &SimpleExtensions,
948        pair: Pair<Rule>,
949    ) -> Result<Self, MessageParseError> {
950        assert_eq!(pair.as_rule(), Self::rule());
951        let mut arg_inner = RuleIter::from(pair.into_inner());
952        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
953            int_pair
954        } else {
955            arg_inner.pop(Rule::expression)
956        };
957        match value_pair.as_rule() {
958            Rule::integer => {
959                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
960                    MessageParseError::invalid(
961                        Self::message(),
962                        value_pair.as_span(),
963                        format!("Invalid integer: {e}"),
964                    )
965                })?;
966                if value < 0 {
967                    return Err(MessageParseError::invalid(
968                        Self::message(),
969                        value_pair.as_span(),
970                        format!("Fetch limit must be non-negative, got: {value}"),
971                    ));
972                }
973                Ok(CountMode::CountExpr(i64_literal_expr(value)))
974            }
975            Rule::expression => {
976                let expr = Expression::parse_pair(extensions, value_pair)?;
977                Ok(CountMode::CountExpr(Box::new(expr)))
978            }
979            _ => Err(MessageParseError::invalid(
980                Self::message(),
981                value_pair.as_span(),
982                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
983            )),
984        }
985    }
986}
987
988fn i64_literal_expr(value: i64) -> Box<Expression> {
989    Box::new(Expression {
990        rex_type: Some(RexType::Literal(Literal {
991            nullable: false,
992            type_variation_reference: 0,
993            literal_type: Some(LiteralType::I64(value)),
994        })),
995    })
996}
997
998impl ScopedParsePair for OffsetMode {
999    fn rule() -> Rule {
1000        Rule::fetch_value
1001    }
1002    fn message() -> &'static str {
1003        "OffsetMode"
1004    }
1005    fn parse_pair(
1006        extensions: &SimpleExtensions,
1007        pair: Pair<Rule>,
1008    ) -> Result<Self, MessageParseError> {
1009        assert_eq!(pair.as_rule(), Self::rule());
1010        let mut arg_inner = RuleIter::from(pair.into_inner());
1011        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
1012            int_pair
1013        } else {
1014            arg_inner.pop(Rule::expression)
1015        };
1016        match value_pair.as_rule() {
1017            Rule::integer => {
1018                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
1019                    MessageParseError::invalid(
1020                        Self::message(),
1021                        value_pair.as_span(),
1022                        format!("Invalid integer: {e}"),
1023                    )
1024                })?;
1025                if value < 0 {
1026                    return Err(MessageParseError::invalid(
1027                        Self::message(),
1028                        value_pair.as_span(),
1029                        format!("Fetch offset must be non-negative, got: {value}"),
1030                    ));
1031                }
1032                Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
1033            }
1034            Rule::expression => {
1035                let expr = Expression::parse_pair(extensions, value_pair)?;
1036                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
1037            }
1038            _ => Err(MessageParseError::invalid(
1039                Self::message(),
1040                value_pair.as_span(),
1041                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
1042            )),
1043        }
1044    }
1045}
1046
1047impl RelationParsePair for FetchRel {
1048    fn rule() -> Rule {
1049        Rule::fetch_relation
1050    }
1051
1052    fn message() -> &'static str {
1053        "FetchRel"
1054    }
1055
1056    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1057        self.advanced_extension = adv_ext;
1058        Rel {
1059            rel_type: Some(RelType::Fetch(Box::new(self))),
1060        }
1061    }
1062
1063    fn parse_pair_with_context(
1064        extensions: &SimpleExtensions,
1065        pair: Pair<Rule>,
1066        input_children: Vec<Box<Rel>>,
1067        input_field_count: usize,
1068    ) -> Result<(Self, usize), MessageParseError> {
1069        assert_eq!(pair.as_rule(), Self::rule());
1070        let input = expect_one_child(Self::message(), &pair, input_children)?;
1071        let mut iter = RuleIter::from(pair.into_inner());
1072
1073        // Extract all pairs before any validation: RuleIter's Drop panics on
1074        // incomplete consumption, so we must exhaust the iterator before any
1075        // early return. Validation runs after iter.done() below.
1076        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
1077            None => {
1078                iter.pop(Rule::empty);
1079                (None, None)
1080            }
1081            Some(fetch_args_pair) => {
1082                let extractor =
1083                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
1084                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
1085                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
1086                extractor.done()?;
1087                (limit_pair, offset_pair)
1088            }
1089        };
1090
1091        let reference_list_pair = iter.pop(Rule::reference_list);
1092        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1093        let common = RelCommon {
1094            emit_kind: Some(emit),
1095            ..Default::default()
1096        };
1097        iter.done();
1098
1099        let count_mode = limit_pair
1100            .map(|pair| CountMode::parse_pair(extensions, pair))
1101            .transpose()?;
1102        let offset_mode = offset_pair
1103            .map(|pair| OffsetMode::parse_pair(extensions, pair))
1104            .transpose()?;
1105        Ok((
1106            FetchRel {
1107                input: Some(input),
1108                common: Some(common),
1109                advanced_extension: None,
1110                offset_mode,
1111                count_mode,
1112            },
1113            output_count,
1114        ))
1115    }
1116}
1117
1118impl ParsePair for join_rel::JoinType {
1119    fn rule() -> Rule {
1120        Rule::join_type
1121    }
1122
1123    fn message() -> &'static str {
1124        "JoinType"
1125    }
1126
1127    fn parse_pair(pair: Pair<Rule>) -> Self {
1128        assert_eq!(pair.as_rule(), Self::rule());
1129        let join_type_str = pair.as_str().trim_start_matches('&');
1130        match join_type_str {
1131            "Inner" => join_rel::JoinType::Inner,
1132            "Left" => join_rel::JoinType::Left,
1133            "Right" => join_rel::JoinType::Right,
1134            "Outer" => join_rel::JoinType::Outer,
1135            "LeftSemi" => join_rel::JoinType::LeftSemi,
1136            "RightSemi" => join_rel::JoinType::RightSemi,
1137            "LeftAnti" => join_rel::JoinType::LeftAnti,
1138            "RightAnti" => join_rel::JoinType::RightAnti,
1139            "LeftSingle" => join_rel::JoinType::LeftSingle,
1140            "RightSingle" => join_rel::JoinType::RightSingle,
1141            "LeftMark" => join_rel::JoinType::LeftMark,
1142            "RightMark" => join_rel::JoinType::RightMark,
1143            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
1144        }
1145    }
1146}
1147
1148impl RelationParsePair for JoinRel {
1149    fn rule() -> Rule {
1150        Rule::join_relation
1151    }
1152
1153    fn message() -> &'static str {
1154        "JoinRel"
1155    }
1156
1157    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1158        self.advanced_extension = adv_ext;
1159        Rel {
1160            rel_type: Some(RelType::Join(Box::new(self))),
1161        }
1162    }
1163
1164    fn parse_pair_with_context(
1165        extensions: &SimpleExtensions,
1166        pair: Pair<Rule>,
1167        input_children: Vec<Box<Rel>>,
1168        input_field_count: usize,
1169    ) -> Result<(Self, usize), MessageParseError> {
1170        assert_eq!(pair.as_rule(), Self::rule());
1171
1172        if input_children.len() != 2 {
1173            return Err(MessageParseError::invalid(
1174                Self::message(),
1175                pair.as_span(),
1176                format!(
1177                    "JoinRel should have exactly 2 input children, got {}",
1178                    input_children.len()
1179                ),
1180            ));
1181        }
1182
1183        let mut children_iter = input_children.into_iter();
1184        let left = children_iter.next().unwrap();
1185        let right = children_iter.next().unwrap();
1186
1187        let mut iter = RuleIter::from(pair.into_inner());
1188        let join_type = iter.parse_next::<join_rel::JoinType>();
1189        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
1190        let reference_list_pair = iter.pop(Rule::reference_list);
1191        iter.done();
1192
1193        // TODO: For semi/anti joins, the direct output width differs from
1194        // left+right — `input_field_count` would misclassify the emit as Direct.
1195        // Revisit when those join types are supported.
1196        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1197        let common = RelCommon {
1198            emit_kind: Some(emit),
1199            ..Default::default()
1200        };
1201
1202        Ok((
1203            JoinRel {
1204                common: Some(common),
1205                left: Some(left),
1206                right: Some(right),
1207                expression: Some(Box::new(condition)),
1208                post_join_filter: None, // not yet represented in the grammar
1209                r#type: join_type as i32,
1210                advanced_extension: None,
1211            },
1212            output_count,
1213        ))
1214    }
1215}
1216
1217#[cfg(test)]
1218mod tests {
1219    use pest::Parser;
1220
1221    use super::*;
1222    use crate::fixtures::TestContext;
1223    use crate::parser::{ExpressionParser, Rule};
1224
1225    #[test]
1226    fn test_parse_relation() {
1227        // Removed: test_parse_relation for old Relation struct
1228    }
1229
1230    #[test]
1231    fn test_parse_read_relation() {
1232        let extensions = SimpleExtensions::default();
1233        let read = ReadRel::parse_pair_with_context(
1234            &extensions,
1235            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
1236            vec![],
1237            0,
1238        )
1239        .unwrap()
1240        .0;
1241        let names = match &read.read_type {
1242            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
1243            _ => panic!("Expected NamedTable"),
1244        };
1245        assert_eq!(names, &["ab", "cd", "ef"]);
1246        let columns = &read
1247            .base_schema
1248            .as_ref()
1249            .unwrap()
1250            .r#struct
1251            .as_ref()
1252            .unwrap()
1253            .types;
1254        assert_eq!(columns.len(), 2);
1255    }
1256
1257    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
1258    fn example_read_relation() -> ReadRel {
1259        let extensions = SimpleExtensions::default();
1260        ReadRel::parse_pair_with_context(
1261            &extensions,
1262            parse_exact(
1263                Rule::read_relation,
1264                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
1265            ),
1266            vec![],
1267            0,
1268        )
1269        .unwrap()
1270        .0
1271    }
1272
1273    #[test]
1274    fn test_parse_filter_relation() {
1275        let extensions = SimpleExtensions::default();
1276        let filter = FilterRel::parse_pair_with_context(
1277            &extensions,
1278            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
1279            vec![Box::new(example_read_relation().into_rel(None))],
1280            3,
1281        )
1282        .unwrap()
1283        .0;
1284        // Identity mapping [0, 1, 2] over 3 inputs → Direct
1285        let emit_kind = filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1286        assert!(
1287            matches!(emit_kind, EmitKind::Direct(_)),
1288            "Expected Direct for identity emit, got {emit_kind:?}"
1289        );
1290    }
1291
1292    #[test]
1293    fn test_parse_project_relation() {
1294        let extensions = SimpleExtensions::default();
1295        let project = ProjectRel::parse_pair_with_context(
1296            &extensions,
1297            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
1298            vec![Box::new(example_read_relation().into_rel(None))],
1299            3,
1300        )
1301        .unwrap()
1302        .0;
1303
1304        // Should have 1 expression (42) and 2 references ($0, $1)
1305        assert_eq!(project.expressions.len(), 1);
1306
1307        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1308        let emit = match emit_kind {
1309            EmitKind::Emit(emit) => &emit.output_mapping,
1310            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1311        };
1312        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
1313        assert_eq!(emit, &[0, 1, 3]);
1314    }
1315
1316    #[test]
1317    fn test_parse_project_relation_complex() {
1318        let extensions = SimpleExtensions::default();
1319        let project = ProjectRel::parse_pair_with_context(
1320            &extensions,
1321            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
1322            vec![Box::new(example_read_relation().into_rel(None))],
1323            5, // Assume 5 input fields
1324        )
1325        .unwrap()
1326        .0;
1327
1328        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
1329        assert_eq!(project.expressions.len(), 2);
1330
1331        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1332        let emit = match emit_kind {
1333            EmitKind::Emit(emit) => &emit.output_mapping,
1334            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1335        };
1336        // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions)
1337        // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1)
1338        assert_eq!(emit, &[5, 0, 6, 2, 1]);
1339    }
1340
1341    #[test]
1342    fn test_parse_aggregate_relation() {
1343        let extensions = TestContext::new()
1344            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1345            .with_function(1, 10, "sum")
1346            .with_function(1, 11, "count")
1347            .extensions;
1348
1349        let aggregate = AggregateRel::parse_pair_with_context(
1350            &extensions,
1351            parse_exact(
1352                Rule::aggregate_relation,
1353                "Aggregate[($0, $1), _ => sum($2), $0, count($2)]",
1354            ),
1355            vec![Box::new(example_read_relation().into_rel(None))],
1356            3,
1357        )
1358        .unwrap()
1359        .0;
1360
1361        // Should have 2 group-by sets ($0, $1) and an empty group, and emit 2 measures (sum($2), count($2))
1362        assert_eq!(aggregate.grouping_expressions.len(), 2);
1363        assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
1364        assert_eq!(aggregate.groupings.len(), 2);
1365        assert_eq!(aggregate.measures.len(), 2);
1366
1367        let emit_kind = &aggregate
1368            .common
1369            .as_ref()
1370            .unwrap()
1371            .emit_kind
1372            .as_ref()
1373            .unwrap();
1374        let emit = match emit_kind {
1375            EmitKind::Emit(emit) => &emit.output_mapping,
1376            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1377        };
1378        // Output mapping should be [2, 0, 3] (measures and group-by fields in order)
1379        // sum($2) -> 2, $0 -> 0, count($2) -> 3
1380        assert_eq!(emit, &[2, 0, 3]);
1381    }
1382
1383    #[test]
1384    fn test_parse_aggregate_relation_maintain_column_order() {
1385        let extensions = TestContext::new()
1386            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1387            .with_function(1, 10, "sum")
1388            .with_function(1, 11, "count")
1389            .extensions;
1390
1391        let aggregate = AggregateRel::parse_pair_with_context(
1392            &extensions,
1393            parse_exact(
1394                Rule::aggregate_relation,
1395                "Aggregate[$0 => sum($1), $0, count($1)]",
1396            ),
1397            vec![Box::new(example_read_relation().into_rel(None))],
1398            3,
1399        )
1400        .unwrap()
1401        .0;
1402
1403        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
1404        assert_eq!(aggregate.grouping_expressions.len(), 1);
1405        assert_eq!(aggregate.groupings.len(), 1);
1406        assert_eq!(aggregate.measures.len(), 2);
1407
1408        let emit_kind = &aggregate
1409            .common
1410            .as_ref()
1411            .unwrap()
1412            .emit_kind
1413            .as_ref()
1414            .unwrap();
1415        let emit = match emit_kind {
1416            EmitKind::Emit(emit) => &emit.output_mapping,
1417            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1418        };
1419        // Output mapping should be [1, 0, 2] (grouping fields + measures)
1420        assert_eq!(emit, &[1, 0, 2]);
1421    }
1422
1423    #[test]
1424    fn test_parse_aggregate_relation_simple() {
1425        let extensions = TestContext::new()
1426            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1427            .with_function(1, 10, "sum")
1428            .extensions;
1429
1430        let aggregate = AggregateRel::parse_pair_with_context(
1431            &extensions,
1432            parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"),
1433            vec![Box::new(example_read_relation().into_rel(None))],
1434            3,
1435        )
1436        .unwrap()
1437        .0;
1438
1439        assert_eq!(aggregate.grouping_expressions.len(), 2);
1440        assert_eq!(aggregate.groupings.len(), 1);
1441        // expression_references must be positions [0, 1], not raw field indices [2, 0]
1442        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
1443    }
1444
1445    #[test]
1446    fn test_parse_aggregate_relation_global_aggregate() {
1447        let extensions = TestContext::new()
1448            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1449            .with_function(1, 10, "sum")
1450            .with_function(1, 11, "count")
1451            .extensions;
1452
1453        let aggregate = AggregateRel::parse_pair_with_context(
1454            &extensions,
1455            parse_exact(
1456                Rule::aggregate_relation,
1457                "Aggregate[_ => sum($0), count($1)]",
1458            ),
1459            vec![Box::new(example_read_relation().into_rel(None))],
1460            3,
1461        )
1462        .unwrap()
1463        .0;
1464
1465        // Should have 0 group-by fields and 2 measures
1466        assert_eq!(aggregate.grouping_expressions.len(), 0);
1467        assert_eq!(aggregate.groupings.len(), 1);
1468        assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
1469        assert_eq!(aggregate.measures.len(), 2);
1470
1471        // Identity mapping [0, 1] over 2 outputs (0 grouping + 2 measures) → Direct
1472        let emit_kind = aggregate
1473            .common
1474            .as_ref()
1475            .unwrap()
1476            .emit_kind
1477            .as_ref()
1478            .unwrap();
1479        assert!(
1480            matches!(emit_kind, EmitKind::Direct(_)),
1481            "Expected Direct for identity emit, got {emit_kind:?}"
1482        );
1483    }
1484
1485    #[test]
1486    fn test_parse_aggregate_relation_grouping_sets() {
1487        let extensions = TestContext::new()
1488            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1489            .with_function(1, 11, "count")
1490            .extensions;
1491
1492        let read_rel = ReadRel::parse_pair_with_context(
1493            &extensions,
1494            parse_exact(
1495                Rule::read_relation,
1496                "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
1497            ),
1498            vec![],
1499            0,
1500        )
1501        .unwrap()
1502        .0;
1503
1504        let aggregate = AggregateRel::parse_pair_with_context(
1505            &extensions,
1506            parse_exact(
1507                Rule::aggregate_relation,
1508                "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]",
1509            ),
1510            vec![Box::new(read_rel.into_rel(None))],
1511            4,
1512        )
1513        .unwrap()
1514        .0;
1515
1516        assert_eq!(aggregate.grouping_expressions.len(), 3);
1517        assert_eq!(aggregate.groupings.len(), 4);
1518        // ($0, $1, $2) -> [0, 1, 2]
1519        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
1520        // ($2, $0) -> [2, 0] (reuses indices, different order)
1521        assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
1522        // ($1) -> [1]
1523        assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
1524        // _ -> empty
1525        assert!(aggregate.groupings[3].expression_references.is_empty());
1526        assert_eq!(aggregate.measures.len(), 1);
1527    }
1528
1529    #[test]
1530    fn test_fetch_relation_positive_values() {
1531        let extensions = SimpleExtensions::default();
1532
1533        // Test valid positive values should work
1534        let fetch_rel = FetchRel::parse_pair_with_context(
1535            &extensions,
1536            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
1537            vec![Box::new(example_read_relation().into_rel(None))],
1538            3,
1539        )
1540        .unwrap()
1541        .0;
1542
1543        // Verify the limit and offset values are correct
1544        assert_eq!(
1545            fetch_rel.count_mode,
1546            Some(CountMode::CountExpr(i64_literal_expr(10)))
1547        );
1548        assert_eq!(
1549            fetch_rel.offset_mode,
1550            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
1551        );
1552    }
1553
1554    #[test]
1555    fn test_fetch_relation_negative_limit_rejected() {
1556        let extensions = SimpleExtensions::default();
1557
1558        // Test that fetch relations with negative limits are properly rejected
1559        let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
1560        if let Ok(mut pairs) = parsed_result {
1561            let pair = pairs.next().unwrap();
1562            if pair.as_str() == "Fetch[limit=-5 => $0]" {
1563                // Full parse succeeded, now test that validation catches the negative value
1564                let result = FetchRel::parse_pair_with_context(
1565                    &extensions,
1566                    pair,
1567                    vec![Box::new(example_read_relation().into_rel(None))],
1568                    3,
1569                );
1570                assert!(result.is_err());
1571                let error_msg = result.unwrap_err().to_string();
1572                assert!(error_msg.contains("Fetch limit must be non-negative"));
1573            } else {
1574                // If grammar doesn't fully support negative values, that's also acceptable
1575                // since it would prevent negative values at parse time
1576                println!("Grammar prevents negative limit values at parse time");
1577            }
1578        } else {
1579            // Grammar doesn't support negative values in fetch context
1580            println!("Grammar prevents negative limit values at parse time");
1581        }
1582    }
1583
1584    #[test]
1585    fn test_fetch_relation_negative_offset_rejected() {
1586        let extensions = SimpleExtensions::default();
1587
1588        // Test that fetch relations with negative offsets are properly rejected
1589        let parsed_result =
1590            ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
1591        if let Ok(mut pairs) = parsed_result {
1592            let pair = pairs.next().unwrap();
1593            if pair.as_str() == "Fetch[offset=-10 => $0]" {
1594                // Full parse succeeded, now test that validation catches the negative value
1595                let result = FetchRel::parse_pair_with_context(
1596                    &extensions,
1597                    pair,
1598                    vec![Box::new(example_read_relation().into_rel(None))],
1599                    3,
1600                );
1601                assert!(result.is_err());
1602                let error_msg = result.unwrap_err().to_string();
1603                assert!(error_msg.contains("Fetch offset must be non-negative"));
1604            } else {
1605                // If grammar doesn't fully support negative values, that's also acceptable
1606                println!("Grammar prevents negative offset values at parse time");
1607            }
1608        } else {
1609            // Grammar doesn't support negative values in fetch context
1610            println!("Grammar prevents negative offset values at parse time");
1611        }
1612    }
1613
1614    #[test]
1615    fn test_parse_join_relation() {
1616        let extensions = TestContext::new()
1617            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1618            .with_function(1, 10, "eq")
1619            .extensions;
1620
1621        let left_rel = example_read_relation().into_rel(None);
1622        let right_rel = example_read_relation().into_rel(None);
1623
1624        let join = JoinRel::parse_pair_with_context(
1625            &extensions,
1626            parse_exact(
1627                Rule::join_relation,
1628                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
1629            ),
1630            vec![Box::new(left_rel), Box::new(right_rel)],
1631            6, // left (3) + right (3) = 6 total input fields
1632        )
1633        .unwrap()
1634        .0;
1635
1636        // Should be an Inner join
1637        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
1638
1639        // Should have left and right relations
1640        assert!(join.left.is_some());
1641        assert!(join.right.is_some());
1642
1643        // Should have a join condition
1644        assert!(join.expression.is_some());
1645
1646        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1647        let emit = match emit_kind {
1648            EmitKind::Emit(emit) => &emit.output_mapping,
1649            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1650        };
1651        // Output mapping should be [0, 1, 3, 4] (selected columns)
1652        assert_eq!(emit, &[0, 1, 3, 4]);
1653    }
1654
1655    #[test]
1656    fn test_parse_join_relation_left_outer() {
1657        let extensions = TestContext::new()
1658            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1659            .with_function(1, 10, "eq")
1660            .extensions;
1661
1662        let left_rel = example_read_relation().into_rel(None);
1663        let right_rel = example_read_relation().into_rel(None);
1664
1665        let join = JoinRel::parse_pair_with_context(
1666            &extensions,
1667            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
1668            vec![Box::new(left_rel), Box::new(right_rel)],
1669            6,
1670        )
1671        .unwrap()
1672        .0;
1673
1674        // Should be a Left join
1675        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
1676
1677        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1678        let emit = match emit_kind {
1679            EmitKind::Emit(emit) => &emit.output_mapping,
1680            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1681        };
1682        // Output mapping should be [0, 1, 2]
1683        assert_eq!(emit, &[0, 1, 2]);
1684    }
1685
1686    #[test]
1687    fn test_parse_join_relation_left_semi() {
1688        let extensions = TestContext::new()
1689            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1690            .with_function(1, 10, "eq")
1691            .extensions;
1692
1693        let left_rel = example_read_relation().into_rel(None);
1694        let right_rel = example_read_relation().into_rel(None);
1695
1696        let join = JoinRel::parse_pair_with_context(
1697            &extensions,
1698            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
1699            vec![Box::new(left_rel), Box::new(right_rel)],
1700            6,
1701        )
1702        .unwrap()
1703        .0;
1704
1705        // Should be a LeftSemi join
1706        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
1707
1708        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1709        let emit = match emit_kind {
1710            EmitKind::Emit(emit) => &emit.output_mapping,
1711            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1712        };
1713        // Output mapping should be [0, 1] (only left columns for semi join)
1714        assert_eq!(emit, &[0, 1]);
1715    }
1716
1717    #[test]
1718    fn test_parse_join_relation_requires_two_children() {
1719        let extensions = SimpleExtensions::default();
1720
1721        // Test with 0 children
1722        let result = JoinRel::parse_pair_with_context(
1723            &extensions,
1724            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1725            vec![],
1726            0,
1727        );
1728        assert!(result.is_err());
1729
1730        // Test with 1 child
1731        let result = JoinRel::parse_pair_with_context(
1732            &extensions,
1733            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1734            vec![Box::new(example_read_relation().into_rel(None))],
1735            3,
1736        );
1737        assert!(result.is_err());
1738
1739        // Test with 3 children
1740        let result = JoinRel::parse_pair_with_context(
1741            &extensions,
1742            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1743            vec![
1744                Box::new(example_read_relation().into_rel(None)),
1745                Box::new(example_read_relation().into_rel(None)),
1746                Box::new(example_read_relation().into_rel(None)),
1747            ],
1748            9,
1749        );
1750        assert!(result.is_err());
1751    }
1752
1753    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
1754        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1755        assert_eq!(pairs.as_str(), input);
1756        let pair = pairs.next().unwrap();
1757        assert_eq!(pairs.next(), None);
1758        pair
1759    }
1760}