Skip to main content

substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType, nested};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::{CountMode, OffsetMode};
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::{Direct, Emit, EmitKind};
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
15    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
16};
17
18use super::{MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair};
19use crate::extensions::any::Any;
20use crate::extensions::registry::ExtensionError;
21use crate::extensions::{AddendumKind, ExtensionArgs, ExtensionRegistry, SimpleExtensions};
22use crate::parser::errors::{ParseContext, ParseError};
23use crate::parser::expressions::{FieldIndex, Name};
24
25/// Parsing context for relations that includes extensions, registry, and optional warning collection
26pub struct RelationParsingContext<'a> {
27    pub registry: &'a ExtensionRegistry,
28    pub line_no: i64,
29    pub line: &'a str,
30}
31
32impl<'a> RelationParsingContext<'a> {
33    /// Resolve extension detail using registry. Any failure is treated as a hard parse error.
34    pub fn resolve_extension_detail(
35        &self,
36        extension_name: &str,
37        extension_args: &ExtensionArgs,
38    ) -> Result<Option<Any>, ParseError> {
39        let detail = self
40            .registry
41            .parse_extension(extension_name, extension_args);
42
43        match detail {
44            Ok(any) => Ok(Some(any)),
45            Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
46                name: extension_name.to_string(),
47                context: ParseContext::new(self.line_no, self.line.to_string()),
48            }),
49            Err(err) => Err(ParseError::ExtensionDetail(
50                ParseContext::new(self.line_no, self.line.to_string()),
51                err,
52            )),
53        }
54    }
55
56    /// Resolve an addendum detail using the registry entry for its kind.
57    /// Any failure is treated as a hard parse error.
58    pub(crate) fn resolve_addendum_detail(
59        &self,
60        kind: AddendumKind,
61        name: &str,
62        args: &ExtensionArgs,
63    ) -> Result<Any, ParseError> {
64        let result = match kind {
65            AddendumKind::Enhancement => self.registry.parse_enhancement(name, args),
66            AddendumKind::Optimization => self.registry.parse_optimization(name, args),
67            AddendumKind::ExtensionTable => self.registry.parse_extension_table(name, args),
68        };
69        result.map_err(|err| match err {
70            ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
71                name: name.to_string(),
72                context: ParseContext::new(self.line_no, self.line.to_string()),
73            },
74            err => ParseError::ExtensionDetail(
75                ParseContext::new(self.line_no, self.line.to_string()),
76                err,
77            ),
78        })
79    }
80}
81
82/// A trait for parsing relations with full context for tree building.
83pub trait RelationParsePair: Sized {
84    fn rule() -> Rule;
85    fn message() -> &'static str;
86
87    /// Parse the grammar pair into this relation type and its output field
88    /// count.
89    ///
90    /// Returns `(Self, usize)` where `usize` is the output field count —
91    /// computed during parsing when `input_field_count` is available.
92    fn parse_pair_with_context(
93        extensions: &SimpleExtensions,
94        pair: Pair<Rule>,
95        input_children: Vec<Rel>,
96        input_field_count: usize,
97    ) -> Result<(Self, usize), MessageParseError>;
98
99    /// Consume this parsed relation, apply the advanced extension, and produce
100    /// the final `Rel`.
101    fn into_rel(self, adv_ext: Option<AdvancedExtension>) -> Rel;
102}
103
104pub struct TableName(Vec<String>);
105
106impl ParsePair for TableName {
107    fn rule() -> Rule {
108        Rule::table_name
109    }
110
111    fn message() -> &'static str {
112        "TableName"
113    }
114
115    fn parse_pair(pair: Pair<Rule>) -> Self {
116        assert_eq!(pair.as_rule(), Self::rule());
117        let pairs = pair.into_inner();
118        let mut names = Vec::with_capacity(pairs.len());
119        let mut iter = RuleIter::from(pairs);
120        while let Some(name) = iter.parse_if_next::<Name>() {
121            names.push(name.0);
122        }
123        iter.done();
124        Self(names)
125    }
126}
127
128#[derive(Debug, Clone)]
129pub struct Column {
130    pub name: String,
131    pub typ: Type,
132}
133
134impl ScopedParsePair for Column {
135    fn rule() -> Rule {
136        Rule::named_column
137    }
138
139    fn message() -> &'static str {
140        "Column"
141    }
142
143    fn parse_pair(
144        extensions: &SimpleExtensions,
145        pair: Pair<Rule>,
146    ) -> Result<Self, MessageParseError> {
147        assert_eq!(pair.as_rule(), Self::rule());
148        let mut iter = RuleIter::from(pair.into_inner());
149        let name = iter.parse_next::<Name>().0;
150        let typ = iter.parse_next_scoped(extensions)?;
151        iter.done();
152        Ok(Self { name, typ })
153    }
154}
155
156pub(crate) struct NamedColumnList(pub(crate) Vec<Column>);
157
158impl ScopedParsePair for NamedColumnList {
159    fn rule() -> Rule {
160        Rule::named_column_list
161    }
162
163    fn message() -> &'static str {
164        "NamedColumnList"
165    }
166
167    fn parse_pair(
168        extensions: &SimpleExtensions,
169        pair: Pair<Rule>,
170    ) -> Result<Self, MessageParseError> {
171        assert_eq!(pair.as_rule(), Self::rule());
172        let mut columns = Vec::new();
173        for col in pair.into_inner() {
174            columns.push(Column::parse_pair(extensions, col)?);
175        }
176        Ok(Self(columns))
177    }
178}
179
180/// This is a utility function for extracting a single child from the list of
181/// children, to be used in the RelationParsePair trait. The RelationParsePair
182/// trait passes a Vec of children, because some relations have multiple
183/// children - but most accept exactly one child.
184pub(crate) fn expect_one_child(
185    message: &'static str,
186    pair: &Pair<Rule>,
187    mut input_children: Vec<Rel>,
188) -> Result<Box<Rel>, MessageParseError> {
189    match input_children.len() {
190        0 => Err(MessageParseError::invalid(
191            message,
192            pair.as_span(),
193            format!("{message} missing child"),
194        )),
195        1 => Ok(Box::new(input_children.pop().unwrap())),
196        n => Err(MessageParseError::invalid(
197            message,
198            pair.as_span(),
199            format!("{message} should have 1 input child, got {n}"),
200        )),
201    }
202}
203
204/// Parse a reference list Pair and return an EmitKind::Emit.
205/// Parse a reference list into field indices for emit mapping.
206fn parse_output_mapping(pair: Pair<Rule>) -> Vec<i32> {
207    assert_eq!(pair.as_rule(), Rule::reference_list);
208    pair.into_inner()
209        .map(|p| FieldIndex::parse_pair(p).0)
210        .collect()
211}
212
213/// Build an emit: `Direct` if the mapping is the identity `[0, 1, ..., N-1]`
214/// (where N = `direct_output_count`), otherwise `Emit` with the explicit mapping.
215fn make_emit(output_mapping: Vec<i32>, direct_output_count: usize) -> EmitKind {
216    let is_identity = output_mapping.len() == direct_output_count
217        && output_mapping
218            .iter()
219            .enumerate()
220            .all(|(i, &v)| v == i as i32);
221    if is_identity {
222        EmitKind::Direct(Direct {})
223    } else {
224        EmitKind::Emit(Emit { output_mapping })
225    }
226}
227
228/// Parse a reference list into an emit and output field count.
229fn parse_emit(reference_list: Pair<Rule>, direct_output_count: usize) -> (EmitKind, usize) {
230    let output_mapping = parse_output_mapping(reference_list);
231    let output_count = output_mapping.len();
232    let emit = make_emit(output_mapping, direct_output_count);
233    (emit, output_count)
234}
235
236/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
237///
238/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
239///
240/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
241pub struct ParsedNamedArgs<'a> {
242    map: HashMap<&'a str, Pair<'a, Rule>>,
243}
244
245impl<'a> ParsedNamedArgs<'a> {
246    pub fn new(
247        pairs: pest::iterators::Pairs<'a, Rule>,
248        rule: Rule,
249    ) -> Result<Self, MessageParseError> {
250        let mut map = HashMap::new();
251        for pair in pairs {
252            assert_eq!(pair.as_rule(), rule);
253            let mut inner = pair.clone().into_inner();
254            let name_pair = inner.next().unwrap();
255            let value_pair = inner.next().unwrap();
256            assert_eq!(inner.next(), None);
257            let name = name_pair.as_str();
258            if map.contains_key(name) {
259                return Err(MessageParseError::invalid(
260                    "NamedArg",
261                    name_pair.as_span(),
262                    format!("Duplicate argument: {name}"),
263                ));
264            }
265            map.insert(name, value_pair);
266        }
267        Ok(Self { map })
268    }
269
270    // Returns the pair if it exists and matches the rule, otherwise None.
271    // Asserts that the rule must match the rule of the pair (and therefore
272    // panics in non-release-mode if not)
273    pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
274        let pair = self.map.remove(name).inspect(|pair| {
275            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
276        });
277        (self, pair)
278    }
279
280    // Returns an error if there are any unused arguments.
281    pub fn done(self) -> Result<(), MessageParseError> {
282        if let Some((name, pair)) = self.map.iter().next() {
283            return Err(MessageParseError::invalid(
284                "NamedArgExtractor",
285                // No span available for all unused args; use default.
286                pair.as_span(),
287                format!("Unknown argument: {name}"),
288            ));
289        }
290        Ok(())
291    }
292}
293
294impl RelationParsePair for ReadRel {
295    fn rule() -> Rule {
296        Rule::read_relation
297    }
298
299    fn message() -> &'static str {
300        "ReadRel"
301    }
302
303    fn parse_pair_with_context(
304        extensions: &SimpleExtensions,
305        pair: Pair<Rule>,
306        input_children: Vec<Rel>,
307        input_field_count: usize,
308    ) -> Result<(Self, usize), MessageParseError> {
309        assert_eq!(pair.as_rule(), Self::rule());
310        // ReadRel is a leaf node - it should have no input children and 0 input fields
311        if !input_children.is_empty() {
312            return Err(MessageParseError::invalid(
313                Self::message(),
314                pair.as_span(),
315                "ReadRel should have no input children",
316            ));
317        }
318        if input_field_count != 0 {
319            return Err(MessageParseError::invalid(
320                "ReadRel",
321                pair.as_span(),
322                "ReadRel should have 0 input fields",
323            ));
324        }
325
326        let mut iter = RuleIter::from(pair.into_inner());
327        let table = iter.parse_next::<TableName>().0;
328        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
329        iter.done();
330
331        let output_count = columns.len();
332        Ok((
333            ReadRel {
334                base_schema: Some(build_named_struct(columns)),
335                read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
336                    names: table,
337                    advanced_extension: None,
338                })),
339                ..Default::default()
340            },
341            output_count,
342        ))
343    }
344
345    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
346        self.advanced_extension = adv_ext;
347        Rel {
348            rel_type: Some(RelType::Read(Box::new(self))),
349        }
350    }
351}
352
353/// Parsed `Read:Virtual[rows => columns]` relation. Needs a newtype because the
354/// proto type is `ReadRel` (same as `NamedTable`), but the grammar and handling
355/// are different.
356pub(crate) struct VirtualReadRel(ReadRel);
357
358impl RelationParsePair for VirtualReadRel {
359    fn rule() -> Rule {
360        Rule::virtual_read_relation
361    }
362
363    fn message() -> &'static str {
364        "VirtualReadRel"
365    }
366
367    fn parse_pair_with_context(
368        extensions: &SimpleExtensions,
369        pair: Pair<Rule>,
370        input_children: Vec<Rel>,
371        _input_field_count: usize,
372    ) -> Result<(Self, usize), MessageParseError> {
373        assert_eq!(pair.as_rule(), Self::rule());
374        if !input_children.is_empty() {
375            return Err(MessageParseError::invalid(
376                Self::message(),
377                pair.as_span(),
378                "Read:Virtual should have no input children",
379            ));
380        }
381
382        let mut iter = RuleIter::from(pair.into_inner());
383        let args_pair = iter.pop(Rule::virtual_read_args);
384        let columns_pair = iter.pop(Rule::named_column_list);
385        iter.done();
386
387        let rows = parse_virtual_read_args(extensions, args_pair)?;
388        let columns = NamedColumnList::parse_pair(extensions, columns_pair)?.0;
389
390        // TODO: Validate that each row has the same number of expressions as
391        // columns. Currently no parser-side warning mechanism exists, and while
392        // this is an invalid plan, it is constructible as Substrait. Consider
393        // adding once a warning collection path is available.
394        let output_count = columns.len();
395        Ok((
396            VirtualReadRel(ReadRel {
397                base_schema: Some(build_named_struct(columns)),
398                read_type: Some(read_rel::ReadType::VirtualTable(read_rel::VirtualTable {
399                    expressions: rows,
400                    ..Default::default()
401                })),
402                ..Default::default()
403            }),
404            output_count,
405        ))
406    }
407
408    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
409        self.0.advanced_extension = adv_ext;
410        Rel {
411            rel_type: Some(RelType::Read(Box::new(self.0))),
412        }
413    }
414}
415
416/// Parsed `Read:Extension[columns]` relation. Needs a newtype because the
417/// proto type is `ReadRel` (same as `NamedTable` and `VirtualTable`), but the
418/// read detail is supplied by a required `+ Ext:` addendum.
419pub(crate) struct ExtensionReadRel(ReadRel);
420
421impl ExtensionReadRel {
422    pub(crate) fn parse_pair_with_detail(
423        extensions: &SimpleExtensions,
424        pair: Pair<Rule>,
425        input_children: Vec<Rel>,
426        input_field_count: usize,
427        detail: Any,
428        advanced_extension: Option<AdvancedExtension>,
429    ) -> Result<(Rel, usize), MessageParseError> {
430        assert_eq!(pair.as_rule(), Rule::extension_read_relation);
431        if !input_children.is_empty() {
432            return Err(MessageParseError::invalid(
433                "ExtensionReadRel",
434                pair.as_span(),
435                "Read:Extension should have no input children",
436            ));
437        }
438        if input_field_count != 0 {
439            return Err(MessageParseError::invalid(
440                "ExtensionReadRel",
441                pair.as_span(),
442                "Read:Extension should have 0 input fields",
443            ));
444        }
445
446        let mut iter = RuleIter::from(pair.into_inner());
447        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
448        iter.done();
449
450        let output_count = columns.len();
451        let rel = ExtensionReadRel(ReadRel {
452            base_schema: Some(build_named_struct(columns)),
453            read_type: Some(read_rel::ReadType::ExtensionTable(
454                read_rel::ExtensionTable {
455                    detail: Some(detail.into()),
456                },
457            )),
458            advanced_extension,
459            ..Default::default()
460        })
461        .into_rel();
462
463        Ok((rel, output_count))
464    }
465
466    fn into_rel(self) -> Rel {
467        Rel {
468            rel_type: Some(RelType::Read(Box::new(self.0))),
469        }
470    }
471}
472
473/// Build a `NamedStruct` from parsed columns.
474pub(crate) fn build_named_struct(columns: Vec<Column>) -> NamedStruct {
475    let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
476    NamedStruct {
477        names,
478        r#struct: Some(r#type::Struct {
479            types,
480            type_variation_reference: 0,
481            nullability: r#type::Nullability::Required as i32,
482        }),
483    }
484}
485
486/// `Read:Virtual` positional args: either `empty` or a list of row tuples.
487fn parse_virtual_read_args(
488    extensions: &SimpleExtensions,
489    pair: Pair<Rule>,
490) -> Result<Vec<nested::Struct>, MessageParseError> {
491    assert_eq!(pair.as_rule(), Rule::virtual_read_args);
492    let inner = unwrap_single_pair(pair);
493    match inner.as_rule() {
494        Rule::empty => Ok(vec![]),
495        Rule::virtual_row_list => inner
496            .into_inner()
497            .map(|row| parse_virtual_row(extensions, row))
498            .collect(),
499        _ => unreachable!(
500            "Unexpected rule in virtual_read_args: {:?}",
501            inner.as_rule()
502        ),
503    }
504}
505
506/// Parse a single `virtual_row` (`(expr, expr, ...)` or `()`) into a `nested::Struct`.
507fn parse_virtual_row(
508    extensions: &SimpleExtensions,
509    pair: Pair<Rule>,
510) -> Result<nested::Struct, MessageParseError> {
511    assert_eq!(pair.as_rule(), Rule::virtual_row);
512    let fields = match pair.into_inner().next() {
513        Some(expression_list) => {
514            assert_eq!(expression_list.as_rule(), Rule::expression_list);
515            parse_expression_list(extensions, expression_list)?
516        }
517        // Empty virtual row. An unusual but valid case.
518        None => vec![],
519    };
520    Ok(nested::Struct { fields })
521}
522
523impl RelationParsePair for FilterRel {
524    fn rule() -> Rule {
525        Rule::filter_relation
526    }
527
528    fn message() -> &'static str {
529        "FilterRel"
530    }
531
532    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
533        self.advanced_extension = adv_ext;
534        Rel {
535            rel_type: Some(RelType::Filter(Box::new(self))),
536        }
537    }
538
539    fn parse_pair_with_context(
540        extensions: &SimpleExtensions,
541        pair: Pair<Rule>,
542        input_children: Vec<Rel>,
543        input_field_count: usize,
544    ) -> Result<(Self, usize), MessageParseError> {
545        assert_eq!(pair.as_rule(), Self::rule());
546        let input = expect_one_child(Self::message(), &pair, input_children)?;
547        let mut iter = RuleIter::from(pair.into_inner());
548        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
549        // references (which become the emit)
550        let references_pair = iter.pop(Rule::reference_list);
551        iter.done();
552
553        let (emit, output_count) = parse_emit(references_pair, input_field_count);
554        let common = RelCommon {
555            emit_kind: Some(emit),
556            ..Default::default()
557        };
558
559        Ok((
560            FilterRel {
561                input: Some(input),
562                condition: Some(Box::new(condition)),
563                common: Some(common),
564                advanced_extension: None,
565            },
566            output_count,
567        ))
568    }
569}
570
571impl RelationParsePair for ProjectRel {
572    fn rule() -> Rule {
573        Rule::project_relation
574    }
575
576    fn message() -> &'static str {
577        "ProjectRel"
578    }
579
580    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
581        self.advanced_extension = adv_ext;
582        Rel {
583            rel_type: Some(RelType::Project(Box::new(self))),
584        }
585    }
586
587    fn parse_pair_with_context(
588        extensions: &SimpleExtensions,
589        pair: Pair<Rule>,
590        input_children: Vec<Rel>,
591        input_field_count: usize,
592    ) -> Result<(Self, usize), MessageParseError> {
593        assert_eq!(pair.as_rule(), Self::rule());
594        let input = expect_one_child(Self::message(), &pair, input_children)?;
595
596        let arguments_pair = unwrap_single_pair(pair);
597
598        let mut expressions = Vec::new();
599        let mut output_mapping = Vec::new();
600
601        for arg in arguments_pair.into_inner() {
602            let inner_arg = unwrap_single_pair(arg);
603            match inner_arg.as_rule() {
604                Rule::reference => {
605                    let field_index = FieldIndex::parse_pair(inner_arg);
606                    output_mapping.push(field_index.0);
607                }
608                Rule::expression => {
609                    let expr = Expression::parse_pair(extensions, inner_arg)?;
610                    expressions.push(expr);
611                    // Index into the combined schema: [input fields][computed expressions].
612                    output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
613                }
614                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
615            }
616        }
617
618        let output_count = output_mapping.len();
619        let direct_count = input_field_count + expressions.len();
620        let emit = make_emit(output_mapping, direct_count);
621        let common = RelCommon {
622            emit_kind: Some(emit),
623            ..Default::default()
624        };
625
626        Ok((
627            ProjectRel {
628                input: Some(input),
629                expressions,
630                common: Some(common),
631                advanced_extension: None,
632            },
633            output_count,
634        ))
635    }
636}
637
638impl RelationParsePair for AggregateRel {
639    fn rule() -> Rule {
640        Rule::aggregate_relation
641    }
642
643    fn message() -> &'static str {
644        "AggregateRel"
645    }
646
647    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
648        self.advanced_extension = adv_ext;
649        Rel {
650            rel_type: Some(RelType::Aggregate(Box::new(self))),
651        }
652    }
653
654    fn parse_pair_with_context(
655        extensions: &SimpleExtensions,
656        pair: Pair<Rule>,
657        input_children: Vec<Rel>,
658        // Aggregate defines its own output schema (grouping keys + measures),
659        // so the input field count isn't needed for emit construction.
660        _input_field_count: usize,
661    ) -> Result<(Self, usize), MessageParseError> {
662        assert_eq!(pair.as_rule(), Self::rule());
663        let input = expect_one_child(Self::message(), &pair, input_children)?;
664        let mut iter = RuleIter::from(pair.into_inner());
665        let group_by_pair = iter.pop(Rule::aggregate_group_by);
666        let output_pair = iter.pop(Rule::aggregate_output);
667        iter.done();
668
669        let inner = group_by_pair
670            .into_inner()
671            .next()
672            .expect("aggregate_group_by must have one inner item");
673
674        let grouping_sets = parse_grouping_sets(extensions, inner)?;
675        let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
676
677        let (measures, output_mapping) =
678            parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
679
680        let output_count = output_mapping.len();
681        let direct_count = grouping_expressions.len() + measures.len();
682        let emit = make_emit(output_mapping, direct_count);
683        let common = RelCommon {
684            emit_kind: Some(emit),
685            ..Default::default()
686        };
687
688        Ok((
689            AggregateRel {
690                input: Some(input),
691                grouping_expressions,
692                groupings,
693                measures,
694                common: Some(common),
695                advanced_extension: None,
696            },
697            output_count,
698        ))
699    }
700}
701
702/// Parses the output section of an aggregate (everything after `=>`).
703///
704/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
705/// this parses `sum($2), $0, count($2)`.
706fn parse_aggregate_measures(
707    extensions: &SimpleExtensions,
708    output_pair: Pair<'_, Rule>,
709    grouping_expressions: &[Expression],
710) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
711    assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
712    let mut measures = Vec::new();
713    let mut output_mapping = Vec::new();
714
715    for aggregate_output_item in output_pair.into_inner() {
716        let inner_item = unwrap_single_pair(aggregate_output_item);
717        match inner_item.as_rule() {
718            Rule::reference => {
719                let field_index = FieldIndex::parse_pair(inner_item);
720                output_mapping.push(field_index.0);
721            }
722            Rule::aggregate_measure => {
723                let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
724                output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
725                measures.push(measure);
726            }
727            _ => panic!(
728                "Unexpected inner output item rule: {:?}",
729                inner_item.as_rule()
730            ),
731        }
732    }
733
734    Ok((measures, output_mapping))
735}
736
737/// Parses the grouping section of an aggregate (everything before `=>`).
738///
739/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
740/// this parses `($0, $1), _`.
741///
742/// Each inner Vec is one grouping set; an empty vec represents no grouping (global aggregate).
743///
744/// Grammar: `aggregate_group_by = { grouping_set_list | expression_list }`
745fn parse_grouping_sets(
746    extensions: &SimpleExtensions,
747    inner: Pair<'_, Rule>,
748) -> Result<Vec<Vec<Expression>>, MessageParseError> {
749    assert!(
750        matches!(
751            inner.as_rule(),
752            Rule::expression_list | Rule::grouping_set_list
753        ),
754        "Expected expression_list or grouping_set_list, got {:?}",
755        inner.as_rule()
756    );
757    match inner.as_rule() {
758        Rule::expression_list => Ok(vec![parse_expression_list(extensions, inner)?]),
759        Rule::grouping_set_list => inner
760            .into_inner()
761            .map(|pair| parse_grouping_set(extensions, pair))
762            .collect(),
763        _ => unreachable!(
764            "Unexpected rule in aggregate_group_by: {:?}",
765            inner.as_rule()
766        ),
767    }
768}
769
770/// Parses a single grouping set, e.g. `($0, $1)` or `_`.
771///
772/// Grammar: `grouping_set = { ("(" ~ expression_list ~ ")") | empty }`
773fn parse_grouping_set(
774    extensions: &SimpleExtensions,
775    pair: Pair<'_, Rule>,
776) -> Result<Vec<Expression>, MessageParseError> {
777    assert_eq!(pair.as_rule(), Rule::grouping_set);
778    let inner = pair
779        .into_inner()
780        .next()
781        .expect("grouping_set must have one inner item");
782    match inner.as_rule() {
783        Rule::empty => Ok(vec![]),
784        Rule::expression_list => parse_expression_list(extensions, inner),
785        _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
786    }
787}
788
789/// Grammar: `expression_list = { expression ~ ("," ~ expression)* }`
790pub(crate) fn parse_expression_list(
791    extensions: &SimpleExtensions,
792    pair: Pair<'_, Rule>,
793) -> Result<Vec<Expression>, MessageParseError> {
794    pair.into_inner()
795        .map(|expr_pair| Expression::parse_pair(extensions, expr_pair))
796        .collect()
797}
798
799/// Deduplicates expressions across all sets and produces the AggregateRel's
800/// protobuf fields: a flat deduplicated expression list and per-set Grouping
801/// messages with index references into that list.
802fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
803    let mut expressions: Vec<Expression> = Vec::new();
804    let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
805
806    let groupings = expression_sets
807        .iter()
808        .map(|set| {
809            let expression_references = set
810                .iter()
811                .map(|exp| {
812                    // TODO: use a better key here than encoding to bytes.
813                    // Ideally, substrait-rs would support `PartialEq` and `Hash`,
814                    // but as there isn't an easy way to do that now, we'll skip.
815                    let key = exp.encode_to_vec();
816                    let next_idx = expressions.len() as u32;
817                    *seen.entry(key).or_insert_with(|| {
818                        expressions.push(exp.clone());
819                        next_idx
820                    })
821                })
822                .collect();
823            Grouping {
824                expression_references,
825                #[allow(deprecated)]
826                grouping_expressions: vec![],
827            }
828        })
829        .collect();
830
831    (groupings, expressions)
832}
833
834impl ScopedParsePair for SortField {
835    fn rule() -> Rule {
836        Rule::sort_field
837    }
838
839    fn message() -> &'static str {
840        "SortField"
841    }
842
843    fn parse_pair(
844        _extensions: &SimpleExtensions,
845        pair: Pair<Rule>,
846    ) -> Result<Self, MessageParseError> {
847        assert_eq!(pair.as_rule(), Self::rule());
848        let mut iter = RuleIter::from(pair.into_inner());
849        let reference_pair = iter.pop(Rule::reference);
850        let field_index = FieldIndex::parse_pair(reference_pair);
851        let direction_pair = iter.pop(Rule::sort_direction);
852        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
853        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
854        // identifiers, but the enum variant names don't include it
855        let direction = match direction_pair.as_str().trim_start_matches('&') {
856            "AscNullsFirst" => SortDirection::AscNullsFirst,
857            "AscNullsLast" => SortDirection::AscNullsLast,
858            "DescNullsFirst" => SortDirection::DescNullsFirst,
859            "DescNullsLast" => SortDirection::DescNullsLast,
860            other => {
861                return Err(MessageParseError::invalid(
862                    "SortDirection",
863                    direction_pair.as_span(),
864                    format!("Unknown sort direction: {other}"),
865                ));
866            }
867        };
868        iter.done();
869        Ok(SortField {
870            expr: Some(Expression {
871                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
872                    field_index.to_field_reference(),
873                ))),
874            }),
875            // TODO: Add support for SortKind::ComparisonFunctionReference
876            sort_kind: Some(SortKind::Direction(direction as i32)),
877        })
878    }
879}
880
881impl RelationParsePair for SortRel {
882    fn rule() -> Rule {
883        Rule::sort_relation
884    }
885
886    fn message() -> &'static str {
887        "SortRel"
888    }
889
890    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
891        self.advanced_extension = adv_ext;
892        Rel {
893            rel_type: Some(RelType::Sort(Box::new(self))),
894        }
895    }
896
897    fn parse_pair_with_context(
898        extensions: &SimpleExtensions,
899        pair: Pair<Rule>,
900        input_children: Vec<Rel>,
901        input_field_count: usize,
902    ) -> Result<(Self, usize), MessageParseError> {
903        assert_eq!(pair.as_rule(), Self::rule());
904        let input = expect_one_child(Self::message(), &pair, input_children)?;
905        let mut iter = RuleIter::from(pair.into_inner());
906        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
907        let reference_list_pair = iter.pop(Rule::reference_list);
908        let mut sorts = Vec::new();
909        for sort_field_pair in sort_field_list_pair.into_inner() {
910            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
911            sorts.push(sort_field);
912        }
913        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
914        let common = RelCommon {
915            emit_kind: Some(emit),
916            ..Default::default()
917        };
918        iter.done();
919        Ok((
920            SortRel {
921                input: Some(input),
922                sorts,
923                common: Some(common),
924                advanced_extension: None,
925            },
926            output_count,
927        ))
928    }
929}
930
931impl ScopedParsePair for CountMode {
932    fn rule() -> Rule {
933        Rule::fetch_value
934    }
935    fn message() -> &'static str {
936        "CountMode"
937    }
938    fn parse_pair(
939        extensions: &SimpleExtensions,
940        pair: Pair<Rule>,
941    ) -> Result<Self, MessageParseError> {
942        assert_eq!(pair.as_rule(), Self::rule());
943        let mut arg_inner = RuleIter::from(pair.into_inner());
944        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
945            int_pair
946        } else {
947            arg_inner.pop(Rule::expression)
948        };
949        match value_pair.as_rule() {
950            Rule::integer => {
951                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
952                    MessageParseError::invalid(
953                        Self::message(),
954                        value_pair.as_span(),
955                        format!("Invalid integer: {e}"),
956                    )
957                })?;
958                if value < 0 {
959                    return Err(MessageParseError::invalid(
960                        Self::message(),
961                        value_pair.as_span(),
962                        format!("Fetch limit must be non-negative, got: {value}"),
963                    ));
964                }
965                Ok(CountMode::CountExpr(i64_literal_expr(value)))
966            }
967            Rule::expression => {
968                let expr = Expression::parse_pair(extensions, value_pair)?;
969                Ok(CountMode::CountExpr(Box::new(expr)))
970            }
971            _ => Err(MessageParseError::invalid(
972                Self::message(),
973                value_pair.as_span(),
974                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
975            )),
976        }
977    }
978}
979
980fn i64_literal_expr(value: i64) -> Box<Expression> {
981    Box::new(Expression {
982        rex_type: Some(RexType::Literal(Literal {
983            nullable: false,
984            type_variation_reference: 0,
985            literal_type: Some(LiteralType::I64(value)),
986        })),
987    })
988}
989
990impl ScopedParsePair for OffsetMode {
991    fn rule() -> Rule {
992        Rule::fetch_value
993    }
994    fn message() -> &'static str {
995        "OffsetMode"
996    }
997    fn parse_pair(
998        extensions: &SimpleExtensions,
999        pair: Pair<Rule>,
1000    ) -> Result<Self, MessageParseError> {
1001        assert_eq!(pair.as_rule(), Self::rule());
1002        let mut arg_inner = RuleIter::from(pair.into_inner());
1003        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
1004            int_pair
1005        } else {
1006            arg_inner.pop(Rule::expression)
1007        };
1008        match value_pair.as_rule() {
1009            Rule::integer => {
1010                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
1011                    MessageParseError::invalid(
1012                        Self::message(),
1013                        value_pair.as_span(),
1014                        format!("Invalid integer: {e}"),
1015                    )
1016                })?;
1017                if value < 0 {
1018                    return Err(MessageParseError::invalid(
1019                        Self::message(),
1020                        value_pair.as_span(),
1021                        format!("Fetch offset must be non-negative, got: {value}"),
1022                    ));
1023                }
1024                Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
1025            }
1026            Rule::expression => {
1027                let expr = Expression::parse_pair(extensions, value_pair)?;
1028                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
1029            }
1030            _ => Err(MessageParseError::invalid(
1031                Self::message(),
1032                value_pair.as_span(),
1033                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
1034            )),
1035        }
1036    }
1037}
1038
1039impl RelationParsePair for FetchRel {
1040    fn rule() -> Rule {
1041        Rule::fetch_relation
1042    }
1043
1044    fn message() -> &'static str {
1045        "FetchRel"
1046    }
1047
1048    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1049        self.advanced_extension = adv_ext;
1050        Rel {
1051            rel_type: Some(RelType::Fetch(Box::new(self))),
1052        }
1053    }
1054
1055    fn parse_pair_with_context(
1056        extensions: &SimpleExtensions,
1057        pair: Pair<Rule>,
1058        input_children: Vec<Rel>,
1059        input_field_count: usize,
1060    ) -> Result<(Self, usize), MessageParseError> {
1061        assert_eq!(pair.as_rule(), Self::rule());
1062        let input = expect_one_child(Self::message(), &pair, input_children)?;
1063        let mut iter = RuleIter::from(pair.into_inner());
1064
1065        // Extract all pairs before any validation: RuleIter's Drop panics on
1066        // incomplete consumption, so we must exhaust the iterator before any
1067        // early return. Validation runs after iter.done() below.
1068        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
1069            None => {
1070                iter.pop(Rule::empty);
1071                (None, None)
1072            }
1073            Some(fetch_args_pair) => {
1074                let extractor =
1075                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
1076                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
1077                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
1078                extractor.done()?;
1079                (limit_pair, offset_pair)
1080            }
1081        };
1082
1083        let reference_list_pair = iter.pop(Rule::reference_list);
1084        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1085        let common = RelCommon {
1086            emit_kind: Some(emit),
1087            ..Default::default()
1088        };
1089        iter.done();
1090
1091        let count_mode = limit_pair
1092            .map(|pair| CountMode::parse_pair(extensions, pair))
1093            .transpose()?;
1094        let offset_mode = offset_pair
1095            .map(|pair| OffsetMode::parse_pair(extensions, pair))
1096            .transpose()?;
1097        Ok((
1098            FetchRel {
1099                input: Some(input),
1100                common: Some(common),
1101                advanced_extension: None,
1102                offset_mode,
1103                count_mode,
1104            },
1105            output_count,
1106        ))
1107    }
1108}
1109
1110impl ParsePair for join_rel::JoinType {
1111    fn rule() -> Rule {
1112        Rule::join_type
1113    }
1114
1115    fn message() -> &'static str {
1116        "JoinType"
1117    }
1118
1119    fn parse_pair(pair: Pair<Rule>) -> Self {
1120        assert_eq!(pair.as_rule(), Self::rule());
1121        let join_type_str = pair.as_str().trim_start_matches('&');
1122        match join_type_str {
1123            "Inner" => join_rel::JoinType::Inner,
1124            "Left" => join_rel::JoinType::Left,
1125            "Right" => join_rel::JoinType::Right,
1126            "Outer" => join_rel::JoinType::Outer,
1127            "LeftSemi" => join_rel::JoinType::LeftSemi,
1128            "RightSemi" => join_rel::JoinType::RightSemi,
1129            "LeftAnti" => join_rel::JoinType::LeftAnti,
1130            "RightAnti" => join_rel::JoinType::RightAnti,
1131            "LeftSingle" => join_rel::JoinType::LeftSingle,
1132            "RightSingle" => join_rel::JoinType::RightSingle,
1133            "LeftMark" => join_rel::JoinType::LeftMark,
1134            "RightMark" => join_rel::JoinType::RightMark,
1135            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
1136        }
1137    }
1138}
1139
1140impl RelationParsePair for JoinRel {
1141    fn rule() -> Rule {
1142        Rule::join_relation
1143    }
1144
1145    fn message() -> &'static str {
1146        "JoinRel"
1147    }
1148
1149    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1150        self.advanced_extension = adv_ext;
1151        Rel {
1152            rel_type: Some(RelType::Join(Box::new(self))),
1153        }
1154    }
1155
1156    fn parse_pair_with_context(
1157        extensions: &SimpleExtensions,
1158        pair: Pair<Rule>,
1159        input_children: Vec<Rel>,
1160        input_field_count: usize,
1161    ) -> Result<(Self, usize), MessageParseError> {
1162        assert_eq!(pair.as_rule(), Self::rule());
1163
1164        if input_children.len() != 2 {
1165            return Err(MessageParseError::invalid(
1166                Self::message(),
1167                pair.as_span(),
1168                format!(
1169                    "JoinRel should have exactly 2 input children, got {}",
1170                    input_children.len()
1171                ),
1172            ));
1173        }
1174
1175        let mut children_iter = input_children.into_iter();
1176        let left = Box::new(children_iter.next().unwrap());
1177        let right = Box::new(children_iter.next().unwrap());
1178
1179        let mut iter = RuleIter::from(pair.into_inner());
1180        let join_type = iter.parse_next::<join_rel::JoinType>();
1181        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
1182        let reference_list_pair = iter.pop(Rule::reference_list);
1183        iter.done();
1184
1185        // TODO: For semi/anti joins, the direct output width differs from
1186        // left+right — `input_field_count` would misclassify the emit as Direct.
1187        // Revisit when those join types are supported.
1188        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1189        let common = RelCommon {
1190            emit_kind: Some(emit),
1191            ..Default::default()
1192        };
1193
1194        Ok((
1195            JoinRel {
1196                common: Some(common),
1197                left: Some(left),
1198                right: Some(right),
1199                expression: Some(Box::new(condition)),
1200                post_join_filter: None, // not yet represented in the grammar
1201                r#type: join_type as i32,
1202                advanced_extension: None,
1203            },
1204            output_count,
1205        ))
1206    }
1207}
1208
1209#[cfg(test)]
1210mod tests {
1211    use pest::Parser;
1212
1213    use super::*;
1214    use crate::fixtures::TestContext;
1215    use crate::parser::{ExpressionParser, Rule};
1216
1217    #[test]
1218    fn test_parse_relation() {
1219        // Removed: test_parse_relation for old Relation struct
1220    }
1221
1222    #[test]
1223    fn test_parse_read_relation() {
1224        let extensions = SimpleExtensions::default();
1225        let read = ReadRel::parse_pair_with_context(
1226            &extensions,
1227            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
1228            vec![],
1229            0,
1230        )
1231        .unwrap()
1232        .0;
1233        let names = match &read.read_type {
1234            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
1235            _ => panic!("Expected NamedTable"),
1236        };
1237        assert_eq!(names, &["ab", "cd", "ef"]);
1238        let columns = &read
1239            .base_schema
1240            .as_ref()
1241            .unwrap()
1242            .r#struct
1243            .as_ref()
1244            .unwrap()
1245            .types;
1246        assert_eq!(columns.len(), 2);
1247    }
1248
1249    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
1250    fn example_read_relation() -> ReadRel {
1251        let extensions = SimpleExtensions::default();
1252        ReadRel::parse_pair_with_context(
1253            &extensions,
1254            parse_exact(
1255                Rule::read_relation,
1256                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
1257            ),
1258            vec![],
1259            0,
1260        )
1261        .unwrap()
1262        .0
1263    }
1264
1265    #[test]
1266    fn test_parse_filter_relation() {
1267        let extensions = SimpleExtensions::default();
1268        let filter = FilterRel::parse_pair_with_context(
1269            &extensions,
1270            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
1271            vec![example_read_relation().into_rel(None)],
1272            3,
1273        )
1274        .unwrap()
1275        .0;
1276        // Identity mapping [0, 1, 2] over 3 inputs → Direct
1277        let emit_kind = filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1278        assert!(
1279            matches!(emit_kind, EmitKind::Direct(_)),
1280            "Expected Direct for identity emit, got {emit_kind:?}"
1281        );
1282    }
1283
1284    #[test]
1285    fn test_parse_project_relation() {
1286        let extensions = SimpleExtensions::default();
1287        let project = ProjectRel::parse_pair_with_context(
1288            &extensions,
1289            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
1290            vec![example_read_relation().into_rel(None)],
1291            3,
1292        )
1293        .unwrap()
1294        .0;
1295
1296        // Should have 1 expression (42) and 2 references ($0, $1)
1297        assert_eq!(project.expressions.len(), 1);
1298
1299        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1300        let emit = match emit_kind {
1301            EmitKind::Emit(emit) => &emit.output_mapping,
1302            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1303        };
1304        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
1305        assert_eq!(emit, &[0, 1, 3]);
1306    }
1307
1308    #[test]
1309    fn test_parse_project_relation_complex() {
1310        let extensions = SimpleExtensions::default();
1311        let project = ProjectRel::parse_pair_with_context(
1312            &extensions,
1313            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
1314            vec![example_read_relation().into_rel(None)],
1315            5, // Assume 5 input fields
1316        )
1317        .unwrap()
1318        .0;
1319
1320        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
1321        assert_eq!(project.expressions.len(), 2);
1322
1323        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1324        let emit = match emit_kind {
1325            EmitKind::Emit(emit) => &emit.output_mapping,
1326            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1327        };
1328        // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions)
1329        // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1)
1330        assert_eq!(emit, &[5, 0, 6, 2, 1]);
1331    }
1332
1333    #[test]
1334    fn test_parse_aggregate_relation() {
1335        let extensions = TestContext::new()
1336            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1337            .with_function(1, 10, "sum")
1338            .with_function(1, 11, "count")
1339            .extensions;
1340
1341        let aggregate = AggregateRel::parse_pair_with_context(
1342            &extensions,
1343            parse_exact(
1344                Rule::aggregate_relation,
1345                "Aggregate[($0, $1), _ => sum($2):i64, $0, count($2):i64]",
1346            ),
1347            vec![example_read_relation().into_rel(None)],
1348            3,
1349        )
1350        .unwrap()
1351        .0;
1352
1353        // Should have 2 group-by sets ($0, $1) and an empty group, and emit 2 measures (sum($2), count($2))
1354        assert_eq!(aggregate.grouping_expressions.len(), 2);
1355        assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
1356        assert_eq!(aggregate.groupings.len(), 2);
1357        assert_eq!(aggregate.measures.len(), 2);
1358
1359        let emit_kind = &aggregate
1360            .common
1361            .as_ref()
1362            .unwrap()
1363            .emit_kind
1364            .as_ref()
1365            .unwrap();
1366        let emit = match emit_kind {
1367            EmitKind::Emit(emit) => &emit.output_mapping,
1368            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1369        };
1370        // Output mapping should be [2, 0, 3] (measures and group-by fields in order)
1371        // sum($2) -> 2, $0 -> 0, count($2) -> 3
1372        assert_eq!(emit, &[2, 0, 3]);
1373    }
1374
1375    #[test]
1376    fn test_parse_aggregate_relation_maintain_column_order() {
1377        let extensions = TestContext::new()
1378            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1379            .with_function(1, 10, "sum")
1380            .with_function(1, 11, "count")
1381            .extensions;
1382
1383        let aggregate = AggregateRel::parse_pair_with_context(
1384            &extensions,
1385            parse_exact(
1386                Rule::aggregate_relation,
1387                "Aggregate[$0 => sum($1):i64, $0, count($1):i64]",
1388            ),
1389            vec![example_read_relation().into_rel(None)],
1390            3,
1391        )
1392        .unwrap()
1393        .0;
1394
1395        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
1396        assert_eq!(aggregate.grouping_expressions.len(), 1);
1397        assert_eq!(aggregate.groupings.len(), 1);
1398        assert_eq!(aggregate.measures.len(), 2);
1399
1400        let emit_kind = &aggregate
1401            .common
1402            .as_ref()
1403            .unwrap()
1404            .emit_kind
1405            .as_ref()
1406            .unwrap();
1407        let emit = match emit_kind {
1408            EmitKind::Emit(emit) => &emit.output_mapping,
1409            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1410        };
1411        // Output mapping should be [1, 0, 2] (grouping fields + measures)
1412        assert_eq!(emit, &[1, 0, 2]);
1413    }
1414
1415    #[test]
1416    fn test_parse_aggregate_relation_simple() {
1417        let extensions = TestContext::new()
1418            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1419            .with_function(1, 10, "sum")
1420            .extensions;
1421
1422        let aggregate = AggregateRel::parse_pair_with_context(
1423            &extensions,
1424            parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1):i64]"),
1425            vec![example_read_relation().into_rel(None)],
1426            3,
1427        )
1428        .unwrap()
1429        .0;
1430
1431        assert_eq!(aggregate.grouping_expressions.len(), 2);
1432        assert_eq!(aggregate.groupings.len(), 1);
1433        // expression_references must be positions [0, 1], not raw field indices [2, 0]
1434        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
1435    }
1436
1437    #[test]
1438    fn test_parse_aggregate_relation_global_aggregate() {
1439        let extensions = TestContext::new()
1440            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1441            .with_function(1, 10, "sum")
1442            .with_function(1, 11, "count")
1443            .extensions;
1444
1445        let aggregate = AggregateRel::parse_pair_with_context(
1446            &extensions,
1447            parse_exact(
1448                Rule::aggregate_relation,
1449                "Aggregate[_ => sum($0):i64, count($1):i64]",
1450            ),
1451            vec![example_read_relation().into_rel(None)],
1452            3,
1453        )
1454        .unwrap()
1455        .0;
1456
1457        // Should have 0 group-by fields and 2 measures
1458        assert_eq!(aggregate.grouping_expressions.len(), 0);
1459        assert_eq!(aggregate.groupings.len(), 1);
1460        assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
1461        assert_eq!(aggregate.measures.len(), 2);
1462
1463        // Identity mapping [0, 1] over 2 outputs (0 grouping + 2 measures) → Direct
1464        let emit_kind = aggregate
1465            .common
1466            .as_ref()
1467            .unwrap()
1468            .emit_kind
1469            .as_ref()
1470            .unwrap();
1471        assert!(
1472            matches!(emit_kind, EmitKind::Direct(_)),
1473            "Expected Direct for identity emit, got {emit_kind:?}"
1474        );
1475    }
1476
1477    #[test]
1478    fn test_parse_aggregate_relation_grouping_sets() {
1479        let extensions = TestContext::new()
1480            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1481            .with_function(1, 11, "count")
1482            .extensions;
1483
1484        let read_rel = ReadRel::parse_pair_with_context(
1485            &extensions,
1486            parse_exact(
1487                Rule::read_relation,
1488                "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
1489            ),
1490            vec![],
1491            0,
1492        )
1493        .unwrap()
1494        .0;
1495
1496        let aggregate = AggregateRel::parse_pair_with_context(
1497            &extensions,
1498            parse_exact(
1499                Rule::aggregate_relation,
1500                "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3):i64]",
1501            ),
1502            vec![read_rel.into_rel(None)],
1503            4,
1504        )
1505        .unwrap()
1506        .0;
1507
1508        assert_eq!(aggregate.grouping_expressions.len(), 3);
1509        assert_eq!(aggregate.groupings.len(), 4);
1510        // ($0, $1, $2) -> [0, 1, 2]
1511        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
1512        // ($2, $0) -> [2, 0] (reuses indices, different order)
1513        assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
1514        // ($1) -> [1]
1515        assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
1516        // _ -> empty
1517        assert!(aggregate.groupings[3].expression_references.is_empty());
1518        assert_eq!(aggregate.measures.len(), 1);
1519    }
1520
1521    #[test]
1522    fn test_fetch_relation_positive_values() {
1523        let extensions = SimpleExtensions::default();
1524
1525        // Test valid positive values should work
1526        let fetch_rel = FetchRel::parse_pair_with_context(
1527            &extensions,
1528            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
1529            vec![example_read_relation().into_rel(None)],
1530            3,
1531        )
1532        .unwrap()
1533        .0;
1534
1535        // Verify the limit and offset values are correct
1536        assert_eq!(
1537            fetch_rel.count_mode,
1538            Some(CountMode::CountExpr(i64_literal_expr(10)))
1539        );
1540        assert_eq!(
1541            fetch_rel.offset_mode,
1542            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
1543        );
1544    }
1545
1546    #[test]
1547    fn test_fetch_relation_negative_limit_rejected() {
1548        let extensions = SimpleExtensions::default();
1549
1550        // Test that fetch relations with negative limits are properly rejected
1551        let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
1552        if let Ok(mut pairs) = parsed_result {
1553            let pair = pairs.next().unwrap();
1554            if pair.as_str() == "Fetch[limit=-5 => $0]" {
1555                // Full parse succeeded, now test that validation catches the negative value
1556                let result = FetchRel::parse_pair_with_context(
1557                    &extensions,
1558                    pair,
1559                    vec![example_read_relation().into_rel(None)],
1560                    3,
1561                );
1562                assert!(result.is_err());
1563                let error_msg = result.unwrap_err().to_string();
1564                assert!(error_msg.contains("Fetch limit must be non-negative"));
1565            } else {
1566                // If grammar doesn't fully support negative values, that's also acceptable
1567                // since it would prevent negative values at parse time
1568                println!("Grammar prevents negative limit values at parse time");
1569            }
1570        } else {
1571            // Grammar doesn't support negative values in fetch context
1572            println!("Grammar prevents negative limit values at parse time");
1573        }
1574    }
1575
1576    #[test]
1577    fn test_fetch_relation_negative_offset_rejected() {
1578        let extensions = SimpleExtensions::default();
1579
1580        // Test that fetch relations with negative offsets are properly rejected
1581        let parsed_result =
1582            ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
1583        if let Ok(mut pairs) = parsed_result {
1584            let pair = pairs.next().unwrap();
1585            if pair.as_str() == "Fetch[offset=-10 => $0]" {
1586                // Full parse succeeded, now test that validation catches the negative value
1587                let result = FetchRel::parse_pair_with_context(
1588                    &extensions,
1589                    pair,
1590                    vec![example_read_relation().into_rel(None)],
1591                    3,
1592                );
1593                assert!(result.is_err());
1594                let error_msg = result.unwrap_err().to_string();
1595                assert!(error_msg.contains("Fetch offset must be non-negative"));
1596            } else {
1597                // If grammar doesn't fully support negative values, that's also acceptable
1598                println!("Grammar prevents negative offset values at parse time");
1599            }
1600        } else {
1601            // Grammar doesn't support negative values in fetch context
1602            println!("Grammar prevents negative offset values at parse time");
1603        }
1604    }
1605
1606    #[test]
1607    fn test_parse_join_relation() {
1608        let extensions = TestContext::new()
1609            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1610            .with_function(1, 10, "eq")
1611            .extensions;
1612
1613        let left_rel = example_read_relation().into_rel(None);
1614        let right_rel = example_read_relation().into_rel(None);
1615
1616        let join = JoinRel::parse_pair_with_context(
1617            &extensions,
1618            parse_exact(
1619                Rule::join_relation,
1620                "Join[&Inner, eq($0, $3):boolean => $0, $1, $3, $4]",
1621            ),
1622            vec![left_rel, right_rel],
1623            6, // left (3) + right (3) = 6 total input fields
1624        )
1625        .unwrap()
1626        .0;
1627
1628        // Should be an Inner join
1629        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
1630
1631        // Should have left and right relations
1632        assert!(join.left.is_some());
1633        assert!(join.right.is_some());
1634
1635        // Should have a join condition
1636        assert!(join.expression.is_some());
1637
1638        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1639        let emit = match emit_kind {
1640            EmitKind::Emit(emit) => &emit.output_mapping,
1641            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1642        };
1643        // Output mapping should be [0, 1, 3, 4] (selected columns)
1644        assert_eq!(emit, &[0, 1, 3, 4]);
1645    }
1646
1647    #[test]
1648    fn test_parse_join_relation_left_outer() {
1649        let extensions = TestContext::new()
1650            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1651            .with_function(1, 10, "eq")
1652            .extensions;
1653
1654        let left_rel = example_read_relation().into_rel(None);
1655        let right_rel = example_read_relation().into_rel(None);
1656
1657        let join = JoinRel::parse_pair_with_context(
1658            &extensions,
1659            parse_exact(
1660                Rule::join_relation,
1661                "Join[&Left, eq($0, $3):boolean => $0, $1, $2]",
1662            ),
1663            vec![left_rel, right_rel],
1664            6,
1665        )
1666        .unwrap()
1667        .0;
1668
1669        // Should be a Left join
1670        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
1671
1672        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1673        let emit = match emit_kind {
1674            EmitKind::Emit(emit) => &emit.output_mapping,
1675            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1676        };
1677        // Output mapping should be [0, 1, 2]
1678        assert_eq!(emit, &[0, 1, 2]);
1679    }
1680
1681    #[test]
1682    fn test_parse_join_relation_left_semi() {
1683        let extensions = TestContext::new()
1684            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1685            .with_function(1, 10, "eq")
1686            .extensions;
1687
1688        let left_rel = example_read_relation().into_rel(None);
1689        let right_rel = example_read_relation().into_rel(None);
1690
1691        let join = JoinRel::parse_pair_with_context(
1692            &extensions,
1693            parse_exact(
1694                Rule::join_relation,
1695                "Join[&LeftSemi, eq($0, $3):boolean => $0, $1]",
1696            ),
1697            vec![left_rel, right_rel],
1698            6,
1699        )
1700        .unwrap()
1701        .0;
1702
1703        // Should be a LeftSemi join
1704        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
1705
1706        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1707        let emit = match emit_kind {
1708            EmitKind::Emit(emit) => &emit.output_mapping,
1709            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1710        };
1711        // Output mapping should be [0, 1] (only left columns for semi join)
1712        assert_eq!(emit, &[0, 1]);
1713    }
1714
1715    #[test]
1716    fn test_parse_join_relation_requires_two_children() {
1717        let extensions = SimpleExtensions::default();
1718
1719        // Test with 0 children
1720        let result = JoinRel::parse_pair_with_context(
1721            &extensions,
1722            parse_exact(
1723                Rule::join_relation,
1724                "Join[&Inner, eq($0, $1):boolean => $0, $1]",
1725            ),
1726            vec![],
1727            0,
1728        );
1729        assert!(result.is_err());
1730
1731        // Test with 1 child
1732        let result = JoinRel::parse_pair_with_context(
1733            &extensions,
1734            parse_exact(
1735                Rule::join_relation,
1736                "Join[&Inner, eq($0, $1):boolean => $0, $1]",
1737            ),
1738            vec![example_read_relation().into_rel(None)],
1739            3,
1740        );
1741        assert!(result.is_err());
1742
1743        // Test with 3 children
1744        let result = JoinRel::parse_pair_with_context(
1745            &extensions,
1746            parse_exact(
1747                Rule::join_relation,
1748                "Join[&Inner, eq($0, $1):boolean => $0, $1]",
1749            ),
1750            vec![
1751                example_read_relation().into_rel(None),
1752                example_read_relation().into_rel(None),
1753                example_read_relation().into_rel(None),
1754            ],
1755            9,
1756        );
1757        assert!(result.is_err());
1758    }
1759
1760    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
1761        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1762        assert_eq!(pairs.as_str(), input);
1763        let pair = pairs.next().unwrap();
1764        assert_eq!(pairs.next(), None);
1765        pair
1766    }
1767}