substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType, nested};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::{CountMode, OffsetMode};
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::{Direct, Emit, EmitKind};
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
15    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
16};
17
18use super::{MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair};
19use crate::extensions::any::Any;
20use crate::extensions::registry::{ExtensionError, ExtensionType};
21use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions};
22use crate::parser::errors::{ParseContext, ParseError};
23use crate::parser::expressions::{FieldIndex, Name};
24
25/// Parsing context for relations that includes extensions, registry, and optional warning collection
26pub struct RelationParsingContext<'a> {
27    pub extensions: &'a SimpleExtensions,
28    pub registry: &'a ExtensionRegistry,
29    pub line_no: i64,
30    pub line: &'a str,
31}
32
33impl<'a> RelationParsingContext<'a> {
34    /// Resolve extension detail using registry. Any failure is treated as a hard parse error.
35    pub fn resolve_extension_detail(
36        &self,
37        extension_name: &str,
38        extension_args: &ExtensionArgs,
39    ) -> Result<Option<Any>, ParseError> {
40        let detail = self
41            .registry
42            .parse_extension(extension_name, extension_args);
43
44        match detail {
45            Ok(any) => Ok(Some(any)),
46            Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
47                name: extension_name.to_string(),
48                context: ParseContext::new(self.line_no, self.line.to_string()),
49            }),
50            Err(err) => Err(ParseError::ExtensionDetail(
51                ParseContext::new(self.line_no, self.line.to_string()),
52                err,
53            )),
54        }
55    }
56
57    /// Resolve an advanced-extension detail (enhancement or optimization) using the registry.
58    /// Any failure is treated as a hard parse error.
59    ///
60    /// `ext_type` must be `Enhancement` or `Optimization`; `Relation` is not valid here and
61    /// will panic via the unreachable branch in the dispatch below.
62    pub fn resolve_adv_ext_detail(
63        &self,
64        ext_type: ExtensionType,
65        name: &str,
66        args: &ExtensionArgs,
67    ) -> Result<Any, ParseError> {
68        let result = match ext_type {
69            ExtensionType::Enhancement => self.registry.parse_enhancement(name, args),
70            ExtensionType::Optimization => self.registry.parse_optimization(name, args),
71            ExtensionType::Relation => unreachable!("Relation is not an advanced extension type"),
72        };
73        result.map_err(|err| match err {
74            ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
75                name: name.to_string(),
76                context: ParseContext::new(self.line_no, self.line.to_string()),
77            },
78            err => ParseError::ExtensionDetail(
79                ParseContext::new(self.line_no, self.line.to_string()),
80                err,
81            ),
82        })
83    }
84}
85
86/// A trait for parsing relations with full context for tree building.
87pub trait RelationParsePair: Sized {
88    fn rule() -> Rule;
89    fn message() -> &'static str;
90
91    /// Parse the grammar pair into this relation type and its output field
92    /// count.
93    ///
94    /// Returns `(Self, usize)` where `usize` is the output field count —
95    /// computed during parsing when `input_field_count` is available.
96    fn parse_pair_with_context(
97        extensions: &SimpleExtensions,
98        pair: Pair<Rule>,
99        input_children: Vec<Box<Rel>>,
100        input_field_count: usize,
101    ) -> Result<(Self, usize), MessageParseError>;
102
103    /// Consume this parsed relation, apply the advanced extension, and produce
104    /// the final `Rel`.
105    fn into_rel(self, adv_ext: Option<AdvancedExtension>) -> Rel;
106}
107
108pub struct TableName(Vec<String>);
109
110impl ParsePair for TableName {
111    fn rule() -> Rule {
112        Rule::table_name
113    }
114
115    fn message() -> &'static str {
116        "TableName"
117    }
118
119    fn parse_pair(pair: Pair<Rule>) -> Self {
120        assert_eq!(pair.as_rule(), Self::rule());
121        let pairs = pair.into_inner();
122        let mut names = Vec::with_capacity(pairs.len());
123        let mut iter = RuleIter::from(pairs);
124        while let Some(name) = iter.parse_if_next::<Name>() {
125            names.push(name.0);
126        }
127        iter.done();
128        Self(names)
129    }
130}
131
132#[derive(Debug, Clone)]
133pub struct Column {
134    pub name: String,
135    pub typ: Type,
136}
137
138impl ScopedParsePair for Column {
139    fn rule() -> Rule {
140        Rule::named_column
141    }
142
143    fn message() -> &'static str {
144        "Column"
145    }
146
147    fn parse_pair(
148        extensions: &SimpleExtensions,
149        pair: Pair<Rule>,
150    ) -> Result<Self, MessageParseError> {
151        assert_eq!(pair.as_rule(), Self::rule());
152        let mut iter = RuleIter::from(pair.into_inner());
153        let name = iter.parse_next::<Name>().0;
154        let typ = iter.parse_next_scoped(extensions)?;
155        iter.done();
156        Ok(Self { name, typ })
157    }
158}
159
160pub struct NamedColumnList(Vec<Column>);
161
162impl ScopedParsePair for NamedColumnList {
163    fn rule() -> Rule {
164        Rule::named_column_list
165    }
166
167    fn message() -> &'static str {
168        "NamedColumnList"
169    }
170
171    fn parse_pair(
172        extensions: &SimpleExtensions,
173        pair: Pair<Rule>,
174    ) -> Result<Self, MessageParseError> {
175        assert_eq!(pair.as_rule(), Self::rule());
176        let mut columns = Vec::new();
177        for col in pair.into_inner() {
178            columns.push(Column::parse_pair(extensions, col)?);
179        }
180        Ok(Self(columns))
181    }
182}
183
184/// This is a utility function for extracting a single child from the list of
185/// children, to be used in the RelationParsePair trait. The RelationParsePair
186/// trait passes a Vec of children, because some relations have multiple
187/// children - but most accept exactly one child.
188#[allow(clippy::vec_box)]
189pub(crate) fn expect_one_child(
190    message: &'static str,
191    pair: &Pair<Rule>,
192    mut input_children: Vec<Box<Rel>>,
193) -> Result<Box<Rel>, MessageParseError> {
194    match input_children.len() {
195        0 => Err(MessageParseError::invalid(
196            message,
197            pair.as_span(),
198            format!("{message} missing child"),
199        )),
200        1 => Ok(input_children.pop().unwrap()),
201        n => Err(MessageParseError::invalid(
202            message,
203            pair.as_span(),
204            format!("{message} should have 1 input child, got {n}"),
205        )),
206    }
207}
208
209/// Parse a reference list Pair and return an EmitKind::Emit.
210/// Parse a reference list into field indices for emit mapping.
211fn parse_output_mapping(pair: Pair<Rule>) -> Vec<i32> {
212    assert_eq!(pair.as_rule(), Rule::reference_list);
213    pair.into_inner()
214        .map(|p| FieldIndex::parse_pair(p).0)
215        .collect()
216}
217
218/// Build an emit: `Direct` if the mapping is the identity `[0, 1, ..., N-1]`
219/// (where N = `direct_output_count`), otherwise `Emit` with the explicit mapping.
220fn make_emit(output_mapping: Vec<i32>, direct_output_count: usize) -> EmitKind {
221    let is_identity = output_mapping.len() == direct_output_count
222        && output_mapping
223            .iter()
224            .enumerate()
225            .all(|(i, &v)| v == i as i32);
226    if is_identity {
227        EmitKind::Direct(Direct {})
228    } else {
229        EmitKind::Emit(Emit { output_mapping })
230    }
231}
232
233/// Parse a reference list into an emit and output field count.
234fn parse_emit(reference_list: Pair<Rule>, direct_output_count: usize) -> (EmitKind, usize) {
235    let output_mapping = parse_output_mapping(reference_list);
236    let output_count = output_mapping.len();
237    let emit = make_emit(output_mapping, direct_output_count);
238    (emit, output_count)
239}
240
241/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
242///
243/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
244///
245/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
246pub struct ParsedNamedArgs<'a> {
247    map: HashMap<&'a str, Pair<'a, Rule>>,
248}
249
250impl<'a> ParsedNamedArgs<'a> {
251    pub fn new(
252        pairs: pest::iterators::Pairs<'a, Rule>,
253        rule: Rule,
254    ) -> Result<Self, MessageParseError> {
255        let mut map = HashMap::new();
256        for pair in pairs {
257            assert_eq!(pair.as_rule(), rule);
258            let mut inner = pair.clone().into_inner();
259            let name_pair = inner.next().unwrap();
260            let value_pair = inner.next().unwrap();
261            assert_eq!(inner.next(), None);
262            let name = name_pair.as_str();
263            if map.contains_key(name) {
264                return Err(MessageParseError::invalid(
265                    "NamedArg",
266                    name_pair.as_span(),
267                    format!("Duplicate argument: {name}"),
268                ));
269            }
270            map.insert(name, value_pair);
271        }
272        Ok(Self { map })
273    }
274
275    // Returns the pair if it exists and matches the rule, otherwise None.
276    // Asserts that the rule must match the rule of the pair (and therefore
277    // panics in non-release-mode if not)
278    pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
279        let pair = self.map.remove(name).inspect(|pair| {
280            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
281        });
282        (self, pair)
283    }
284
285    // Returns an error if there are any unused arguments.
286    pub fn done(self) -> Result<(), MessageParseError> {
287        if let Some((name, pair)) = self.map.iter().next() {
288            return Err(MessageParseError::invalid(
289                "NamedArgExtractor",
290                // No span available for all unused args; use default.
291                pair.as_span(),
292                format!("Unknown argument: {name}"),
293            ));
294        }
295        Ok(())
296    }
297}
298
299impl RelationParsePair for ReadRel {
300    fn rule() -> Rule {
301        Rule::read_relation
302    }
303
304    fn message() -> &'static str {
305        "ReadRel"
306    }
307
308    fn parse_pair_with_context(
309        extensions: &SimpleExtensions,
310        pair: Pair<Rule>,
311        input_children: Vec<Box<Rel>>,
312        input_field_count: usize,
313    ) -> Result<(Self, usize), MessageParseError> {
314        assert_eq!(pair.as_rule(), Self::rule());
315        // ReadRel is a leaf node - it should have no input children and 0 input fields
316        if !input_children.is_empty() {
317            return Err(MessageParseError::invalid(
318                Self::message(),
319                pair.as_span(),
320                "ReadRel should have no input children",
321            ));
322        }
323        if input_field_count != 0 {
324            return Err(MessageParseError::invalid(
325                "ReadRel",
326                pair.as_span(),
327                "ReadRel should have 0 input fields",
328            ));
329        }
330
331        let mut iter = RuleIter::from(pair.into_inner());
332        let table = iter.parse_next::<TableName>().0;
333        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
334        iter.done();
335
336        let output_count = columns.len();
337        Ok((
338            ReadRel {
339                base_schema: Some(build_named_struct(columns)),
340                read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
341                    names: table,
342                    advanced_extension: None,
343                })),
344                ..Default::default()
345            },
346            output_count,
347        ))
348    }
349
350    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
351        self.advanced_extension = adv_ext;
352        Rel {
353            rel_type: Some(RelType::Read(Box::new(self))),
354        }
355    }
356}
357
358/// Parsed `Read:Virtual[rows => columns]` relation. Needs a newtype because the
359/// proto type is `ReadRel` (same as `NamedTable`), but the grammar and handling
360/// are different.
361pub(crate) struct VirtualReadRel(ReadRel);
362
363impl RelationParsePair for VirtualReadRel {
364    fn rule() -> Rule {
365        Rule::virtual_read_relation
366    }
367
368    fn message() -> &'static str {
369        "VirtualReadRel"
370    }
371
372    fn parse_pair_with_context(
373        extensions: &SimpleExtensions,
374        pair: Pair<Rule>,
375        input_children: Vec<Box<Rel>>,
376        _input_field_count: usize,
377    ) -> Result<(Self, usize), MessageParseError> {
378        assert_eq!(pair.as_rule(), Self::rule());
379        if !input_children.is_empty() {
380            return Err(MessageParseError::invalid(
381                Self::message(),
382                pair.as_span(),
383                "Read:Virtual should have no input children",
384            ));
385        }
386
387        let mut iter = RuleIter::from(pair.into_inner());
388        let args_pair = iter.pop(Rule::virtual_read_args);
389        let columns_pair = iter.pop(Rule::named_column_list);
390        iter.done();
391
392        let rows = parse_virtual_read_args(extensions, args_pair)?;
393        let columns = NamedColumnList::parse_pair(extensions, columns_pair)?.0;
394
395        // TODO: Validate that each row has the same number of expressions as
396        // columns. Currently no parser-side warning mechanism exists, and while
397        // this is an invalid plan, it is constructible as Substrait. Consider
398        // adding once a warning collection path is available.
399        let output_count = columns.len();
400        Ok((
401            VirtualReadRel(ReadRel {
402                base_schema: Some(build_named_struct(columns)),
403                read_type: Some(read_rel::ReadType::VirtualTable(read_rel::VirtualTable {
404                    expressions: rows,
405                    ..Default::default()
406                })),
407                ..Default::default()
408            }),
409            output_count,
410        ))
411    }
412
413    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
414        self.0.advanced_extension = adv_ext;
415        Rel {
416            rel_type: Some(RelType::Read(Box::new(self.0))),
417        }
418    }
419}
420
421/// Build a `NamedStruct` from parsed columns.
422fn build_named_struct(columns: Vec<Column>) -> NamedStruct {
423    let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
424    NamedStruct {
425        names,
426        r#struct: Some(r#type::Struct {
427            types,
428            type_variation_reference: 0,
429            nullability: r#type::Nullability::Required as i32,
430        }),
431    }
432}
433
434/// `Read:Virtual` positional args: either `empty` or a list of row tuples.
435fn parse_virtual_read_args(
436    extensions: &SimpleExtensions,
437    pair: Pair<Rule>,
438) -> Result<Vec<nested::Struct>, MessageParseError> {
439    assert_eq!(pair.as_rule(), Rule::virtual_read_args);
440    let inner = unwrap_single_pair(pair);
441    match inner.as_rule() {
442        Rule::empty => Ok(vec![]),
443        Rule::virtual_row_list => inner
444            .into_inner()
445            .map(|row| parse_virtual_row(extensions, row))
446            .collect(),
447        _ => unreachable!(
448            "Unexpected rule in virtual_read_args: {:?}",
449            inner.as_rule()
450        ),
451    }
452}
453
454/// Parse a single `virtual_row` (`(expr, expr, ...)` or `()`) into a `nested::Struct`.
455fn parse_virtual_row(
456    extensions: &SimpleExtensions,
457    pair: Pair<Rule>,
458) -> Result<nested::Struct, MessageParseError> {
459    assert_eq!(pair.as_rule(), Rule::virtual_row);
460    let fields = match pair.into_inner().next() {
461        Some(expression_list) => {
462            assert_eq!(expression_list.as_rule(), Rule::expression_list);
463            parse_expression_list(extensions, expression_list)?
464        }
465        // Empty virtual row. An unusual but valid case.
466        None => vec![],
467    };
468    Ok(nested::Struct { fields })
469}
470
471impl RelationParsePair for FilterRel {
472    fn rule() -> Rule {
473        Rule::filter_relation
474    }
475
476    fn message() -> &'static str {
477        "FilterRel"
478    }
479
480    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
481        self.advanced_extension = adv_ext;
482        Rel {
483            rel_type: Some(RelType::Filter(Box::new(self))),
484        }
485    }
486
487    fn parse_pair_with_context(
488        extensions: &SimpleExtensions,
489        pair: Pair<Rule>,
490        input_children: Vec<Box<Rel>>,
491        input_field_count: usize,
492    ) -> Result<(Self, usize), MessageParseError> {
493        assert_eq!(pair.as_rule(), Self::rule());
494        let input = expect_one_child(Self::message(), &pair, input_children)?;
495        let mut iter = RuleIter::from(pair.into_inner());
496        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
497        // references (which become the emit)
498        let references_pair = iter.pop(Rule::reference_list);
499        iter.done();
500
501        let (emit, output_count) = parse_emit(references_pair, input_field_count);
502        let common = RelCommon {
503            emit_kind: Some(emit),
504            ..Default::default()
505        };
506
507        Ok((
508            FilterRel {
509                input: Some(input),
510                condition: Some(Box::new(condition)),
511                common: Some(common),
512                advanced_extension: None,
513            },
514            output_count,
515        ))
516    }
517}
518
519impl RelationParsePair for ProjectRel {
520    fn rule() -> Rule {
521        Rule::project_relation
522    }
523
524    fn message() -> &'static str {
525        "ProjectRel"
526    }
527
528    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
529        self.advanced_extension = adv_ext;
530        Rel {
531            rel_type: Some(RelType::Project(Box::new(self))),
532        }
533    }
534
535    fn parse_pair_with_context(
536        extensions: &SimpleExtensions,
537        pair: Pair<Rule>,
538        input_children: Vec<Box<Rel>>,
539        input_field_count: usize,
540    ) -> Result<(Self, usize), MessageParseError> {
541        assert_eq!(pair.as_rule(), Self::rule());
542        let input = expect_one_child(Self::message(), &pair, input_children)?;
543
544        let arguments_pair = unwrap_single_pair(pair);
545
546        let mut expressions = Vec::new();
547        let mut output_mapping = Vec::new();
548
549        for arg in arguments_pair.into_inner() {
550            let inner_arg = unwrap_single_pair(arg);
551            match inner_arg.as_rule() {
552                Rule::reference => {
553                    let field_index = FieldIndex::parse_pair(inner_arg);
554                    output_mapping.push(field_index.0);
555                }
556                Rule::expression => {
557                    let expr = Expression::parse_pair(extensions, inner_arg)?;
558                    expressions.push(expr);
559                    // Index into the combined schema: [input fields][computed expressions].
560                    output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
561                }
562                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
563            }
564        }
565
566        let output_count = output_mapping.len();
567        let direct_count = input_field_count + expressions.len();
568        let emit = make_emit(output_mapping, direct_count);
569        let common = RelCommon {
570            emit_kind: Some(emit),
571            ..Default::default()
572        };
573
574        Ok((
575            ProjectRel {
576                input: Some(input),
577                expressions,
578                common: Some(common),
579                advanced_extension: None,
580            },
581            output_count,
582        ))
583    }
584}
585
586impl RelationParsePair for AggregateRel {
587    fn rule() -> Rule {
588        Rule::aggregate_relation
589    }
590
591    fn message() -> &'static str {
592        "AggregateRel"
593    }
594
595    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
596        self.advanced_extension = adv_ext;
597        Rel {
598            rel_type: Some(RelType::Aggregate(Box::new(self))),
599        }
600    }
601
602    fn parse_pair_with_context(
603        extensions: &SimpleExtensions,
604        pair: Pair<Rule>,
605        input_children: Vec<Box<Rel>>,
606        // Aggregate defines its own output schema (grouping keys + measures),
607        // so the input field count isn't needed for emit construction.
608        _input_field_count: usize,
609    ) -> Result<(Self, usize), MessageParseError> {
610        assert_eq!(pair.as_rule(), Self::rule());
611        let input = expect_one_child(Self::message(), &pair, input_children)?;
612        let mut iter = RuleIter::from(pair.into_inner());
613        let group_by_pair = iter.pop(Rule::aggregate_group_by);
614        let output_pair = iter.pop(Rule::aggregate_output);
615        iter.done();
616
617        let inner = group_by_pair
618            .into_inner()
619            .next()
620            .expect("aggregate_group_by must have one inner item");
621
622        let grouping_sets = parse_grouping_sets(extensions, inner)?;
623        let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
624
625        let (measures, output_mapping) =
626            parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
627
628        let output_count = output_mapping.len();
629        let direct_count = grouping_expressions.len() + measures.len();
630        let emit = make_emit(output_mapping, direct_count);
631        let common = RelCommon {
632            emit_kind: Some(emit),
633            ..Default::default()
634        };
635
636        Ok((
637            AggregateRel {
638                input: Some(input),
639                grouping_expressions,
640                groupings,
641                measures,
642                common: Some(common),
643                advanced_extension: None,
644            },
645            output_count,
646        ))
647    }
648}
649
650/// Parses the output section of an aggregate (everything after `=>`).
651///
652/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
653/// this parses `sum($2), $0, count($2)`.
654fn parse_aggregate_measures(
655    extensions: &SimpleExtensions,
656    output_pair: Pair<'_, Rule>,
657    grouping_expressions: &[Expression],
658) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
659    assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
660    let mut measures = Vec::new();
661    let mut output_mapping = Vec::new();
662
663    for aggregate_output_item in output_pair.into_inner() {
664        let inner_item = unwrap_single_pair(aggregate_output_item);
665        match inner_item.as_rule() {
666            Rule::reference => {
667                let field_index = FieldIndex::parse_pair(inner_item);
668                output_mapping.push(field_index.0);
669            }
670            Rule::aggregate_measure => {
671                let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
672                output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
673                measures.push(measure);
674            }
675            _ => panic!(
676                "Unexpected inner output item rule: {:?}",
677                inner_item.as_rule()
678            ),
679        }
680    }
681
682    Ok((measures, output_mapping))
683}
684
685/// Parses the grouping section of an aggregate (everything before `=>`).
686///
687/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
688/// this parses `($0, $1), _`.
689///
690/// Each inner Vec is one grouping set; an empty vec represents no grouping (global aggregate).
691///
692/// Grammar: `aggregate_group_by = { grouping_set_list | expression_list }`
693fn parse_grouping_sets(
694    extensions: &SimpleExtensions,
695    inner: Pair<'_, Rule>,
696) -> Result<Vec<Vec<Expression>>, MessageParseError> {
697    assert!(
698        matches!(
699            inner.as_rule(),
700            Rule::expression_list | Rule::grouping_set_list
701        ),
702        "Expected expression_list or grouping_set_list, got {:?}",
703        inner.as_rule()
704    );
705    match inner.as_rule() {
706        Rule::expression_list => Ok(vec![parse_expression_list(extensions, inner)?]),
707        Rule::grouping_set_list => inner
708            .into_inner()
709            .map(|pair| parse_grouping_set(extensions, pair))
710            .collect(),
711        _ => unreachable!(
712            "Unexpected rule in aggregate_group_by: {:?}",
713            inner.as_rule()
714        ),
715    }
716}
717
718/// Parses a single grouping set, e.g. `($0, $1)` or `_`.
719///
720/// Grammar: `grouping_set = { ("(" ~ expression_list ~ ")") | empty }`
721fn parse_grouping_set(
722    extensions: &SimpleExtensions,
723    pair: Pair<'_, Rule>,
724) -> Result<Vec<Expression>, MessageParseError> {
725    assert_eq!(pair.as_rule(), Rule::grouping_set);
726    let inner = pair
727        .into_inner()
728        .next()
729        .expect("grouping_set must have one inner item");
730    match inner.as_rule() {
731        Rule::empty => Ok(vec![]),
732        Rule::expression_list => parse_expression_list(extensions, inner),
733        _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
734    }
735}
736
737/// Grammar: `expression_list = { expression ~ ("," ~ expression)* }`
738pub(crate) fn parse_expression_list(
739    extensions: &SimpleExtensions,
740    pair: Pair<'_, Rule>,
741) -> Result<Vec<Expression>, MessageParseError> {
742    pair.into_inner()
743        .map(|expr_pair| Expression::parse_pair(extensions, expr_pair))
744        .collect()
745}
746
747/// Deduplicates expressions across all sets and produces the AggregateRel's
748/// protobuf fields: a flat deduplicated expression list and per-set Grouping
749/// messages with index references into that list.
750fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
751    let mut expressions: Vec<Expression> = Vec::new();
752    let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
753
754    let groupings = expression_sets
755        .iter()
756        .map(|set| {
757            let expression_references = set
758                .iter()
759                .map(|exp| {
760                    // TODO: use a better key here than encoding to bytes.
761                    // Ideally, substrait-rs would support `PartialEq` and `Hash`,
762                    // but as there isn't an easy way to do that now, we'll skip.
763                    let key = exp.encode_to_vec();
764                    let next_idx = expressions.len() as u32;
765                    *seen.entry(key).or_insert_with(|| {
766                        expressions.push(exp.clone());
767                        next_idx
768                    })
769                })
770                .collect();
771            Grouping {
772                expression_references,
773                #[allow(deprecated)]
774                grouping_expressions: vec![],
775            }
776        })
777        .collect();
778
779    (groupings, expressions)
780}
781
782impl ScopedParsePair for SortField {
783    fn rule() -> Rule {
784        Rule::sort_field
785    }
786
787    fn message() -> &'static str {
788        "SortField"
789    }
790
791    fn parse_pair(
792        _extensions: &SimpleExtensions,
793        pair: Pair<Rule>,
794    ) -> Result<Self, MessageParseError> {
795        assert_eq!(pair.as_rule(), Self::rule());
796        let mut iter = RuleIter::from(pair.into_inner());
797        let reference_pair = iter.pop(Rule::reference);
798        let field_index = FieldIndex::parse_pair(reference_pair);
799        let direction_pair = iter.pop(Rule::sort_direction);
800        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
801        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
802        // identifiers, but the enum variant names don't include it
803        let direction = match direction_pair.as_str().trim_start_matches('&') {
804            "AscNullsFirst" => SortDirection::AscNullsFirst,
805            "AscNullsLast" => SortDirection::AscNullsLast,
806            "DescNullsFirst" => SortDirection::DescNullsFirst,
807            "DescNullsLast" => SortDirection::DescNullsLast,
808            other => {
809                return Err(MessageParseError::invalid(
810                    "SortDirection",
811                    direction_pair.as_span(),
812                    format!("Unknown sort direction: {other}"),
813                ));
814            }
815        };
816        iter.done();
817        Ok(SortField {
818            expr: Some(Expression {
819                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
820                    field_index.to_field_reference(),
821                ))),
822            }),
823            // TODO: Add support for SortKind::ComparisonFunctionReference
824            sort_kind: Some(SortKind::Direction(direction as i32)),
825        })
826    }
827}
828
829impl RelationParsePair for SortRel {
830    fn rule() -> Rule {
831        Rule::sort_relation
832    }
833
834    fn message() -> &'static str {
835        "SortRel"
836    }
837
838    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
839        self.advanced_extension = adv_ext;
840        Rel {
841            rel_type: Some(RelType::Sort(Box::new(self))),
842        }
843    }
844
845    fn parse_pair_with_context(
846        extensions: &SimpleExtensions,
847        pair: Pair<Rule>,
848        input_children: Vec<Box<Rel>>,
849        input_field_count: usize,
850    ) -> Result<(Self, usize), MessageParseError> {
851        assert_eq!(pair.as_rule(), Self::rule());
852        let input = expect_one_child(Self::message(), &pair, input_children)?;
853        let mut iter = RuleIter::from(pair.into_inner());
854        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
855        let reference_list_pair = iter.pop(Rule::reference_list);
856        let mut sorts = Vec::new();
857        for sort_field_pair in sort_field_list_pair.into_inner() {
858            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
859            sorts.push(sort_field);
860        }
861        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
862        let common = RelCommon {
863            emit_kind: Some(emit),
864            ..Default::default()
865        };
866        iter.done();
867        Ok((
868            SortRel {
869                input: Some(input),
870                sorts,
871                common: Some(common),
872                advanced_extension: None,
873            },
874            output_count,
875        ))
876    }
877}
878
879impl ScopedParsePair for CountMode {
880    fn rule() -> Rule {
881        Rule::fetch_value
882    }
883    fn message() -> &'static str {
884        "CountMode"
885    }
886    fn parse_pair(
887        extensions: &SimpleExtensions,
888        pair: Pair<Rule>,
889    ) -> Result<Self, MessageParseError> {
890        assert_eq!(pair.as_rule(), Self::rule());
891        let mut arg_inner = RuleIter::from(pair.into_inner());
892        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
893            int_pair
894        } else {
895            arg_inner.pop(Rule::expression)
896        };
897        match value_pair.as_rule() {
898            Rule::integer => {
899                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
900                    MessageParseError::invalid(
901                        Self::message(),
902                        value_pair.as_span(),
903                        format!("Invalid integer: {e}"),
904                    )
905                })?;
906                if value < 0 {
907                    return Err(MessageParseError::invalid(
908                        Self::message(),
909                        value_pair.as_span(),
910                        format!("Fetch limit must be non-negative, got: {value}"),
911                    ));
912                }
913                Ok(CountMode::CountExpr(i64_literal_expr(value)))
914            }
915            Rule::expression => {
916                let expr = Expression::parse_pair(extensions, value_pair)?;
917                Ok(CountMode::CountExpr(Box::new(expr)))
918            }
919            _ => Err(MessageParseError::invalid(
920                Self::message(),
921                value_pair.as_span(),
922                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
923            )),
924        }
925    }
926}
927
928fn i64_literal_expr(value: i64) -> Box<Expression> {
929    Box::new(Expression {
930        rex_type: Some(RexType::Literal(Literal {
931            nullable: false,
932            type_variation_reference: 0,
933            literal_type: Some(LiteralType::I64(value)),
934        })),
935    })
936}
937
938impl ScopedParsePair for OffsetMode {
939    fn rule() -> Rule {
940        Rule::fetch_value
941    }
942    fn message() -> &'static str {
943        "OffsetMode"
944    }
945    fn parse_pair(
946        extensions: &SimpleExtensions,
947        pair: Pair<Rule>,
948    ) -> Result<Self, MessageParseError> {
949        assert_eq!(pair.as_rule(), Self::rule());
950        let mut arg_inner = RuleIter::from(pair.into_inner());
951        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
952            int_pair
953        } else {
954            arg_inner.pop(Rule::expression)
955        };
956        match value_pair.as_rule() {
957            Rule::integer => {
958                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
959                    MessageParseError::invalid(
960                        Self::message(),
961                        value_pair.as_span(),
962                        format!("Invalid integer: {e}"),
963                    )
964                })?;
965                if value < 0 {
966                    return Err(MessageParseError::invalid(
967                        Self::message(),
968                        value_pair.as_span(),
969                        format!("Fetch offset must be non-negative, got: {value}"),
970                    ));
971                }
972                Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
973            }
974            Rule::expression => {
975                let expr = Expression::parse_pair(extensions, value_pair)?;
976                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
977            }
978            _ => Err(MessageParseError::invalid(
979                Self::message(),
980                value_pair.as_span(),
981                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
982            )),
983        }
984    }
985}
986
987impl RelationParsePair for FetchRel {
988    fn rule() -> Rule {
989        Rule::fetch_relation
990    }
991
992    fn message() -> &'static str {
993        "FetchRel"
994    }
995
996    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
997        self.advanced_extension = adv_ext;
998        Rel {
999            rel_type: Some(RelType::Fetch(Box::new(self))),
1000        }
1001    }
1002
1003    fn parse_pair_with_context(
1004        extensions: &SimpleExtensions,
1005        pair: Pair<Rule>,
1006        input_children: Vec<Box<Rel>>,
1007        input_field_count: usize,
1008    ) -> Result<(Self, usize), MessageParseError> {
1009        assert_eq!(pair.as_rule(), Self::rule());
1010        let input = expect_one_child(Self::message(), &pair, input_children)?;
1011        let mut iter = RuleIter::from(pair.into_inner());
1012
1013        // Extract all pairs before any validation: RuleIter's Drop panics on
1014        // incomplete consumption, so we must exhaust the iterator before any
1015        // early return. Validation runs after iter.done() below.
1016        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
1017            None => {
1018                iter.pop(Rule::empty);
1019                (None, None)
1020            }
1021            Some(fetch_args_pair) => {
1022                let extractor =
1023                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
1024                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
1025                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
1026                extractor.done()?;
1027                (limit_pair, offset_pair)
1028            }
1029        };
1030
1031        let reference_list_pair = iter.pop(Rule::reference_list);
1032        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1033        let common = RelCommon {
1034            emit_kind: Some(emit),
1035            ..Default::default()
1036        };
1037        iter.done();
1038
1039        let count_mode = limit_pair
1040            .map(|pair| CountMode::parse_pair(extensions, pair))
1041            .transpose()?;
1042        let offset_mode = offset_pair
1043            .map(|pair| OffsetMode::parse_pair(extensions, pair))
1044            .transpose()?;
1045        Ok((
1046            FetchRel {
1047                input: Some(input),
1048                common: Some(common),
1049                advanced_extension: None,
1050                offset_mode,
1051                count_mode,
1052            },
1053            output_count,
1054        ))
1055    }
1056}
1057
1058impl ParsePair for join_rel::JoinType {
1059    fn rule() -> Rule {
1060        Rule::join_type
1061    }
1062
1063    fn message() -> &'static str {
1064        "JoinType"
1065    }
1066
1067    fn parse_pair(pair: Pair<Rule>) -> Self {
1068        assert_eq!(pair.as_rule(), Self::rule());
1069        let join_type_str = pair.as_str().trim_start_matches('&');
1070        match join_type_str {
1071            "Inner" => join_rel::JoinType::Inner,
1072            "Left" => join_rel::JoinType::Left,
1073            "Right" => join_rel::JoinType::Right,
1074            "Outer" => join_rel::JoinType::Outer,
1075            "LeftSemi" => join_rel::JoinType::LeftSemi,
1076            "RightSemi" => join_rel::JoinType::RightSemi,
1077            "LeftAnti" => join_rel::JoinType::LeftAnti,
1078            "RightAnti" => join_rel::JoinType::RightAnti,
1079            "LeftSingle" => join_rel::JoinType::LeftSingle,
1080            "RightSingle" => join_rel::JoinType::RightSingle,
1081            "LeftMark" => join_rel::JoinType::LeftMark,
1082            "RightMark" => join_rel::JoinType::RightMark,
1083            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
1084        }
1085    }
1086}
1087
1088impl RelationParsePair for JoinRel {
1089    fn rule() -> Rule {
1090        Rule::join_relation
1091    }
1092
1093    fn message() -> &'static str {
1094        "JoinRel"
1095    }
1096
1097    fn into_rel(mut self, adv_ext: Option<AdvancedExtension>) -> Rel {
1098        self.advanced_extension = adv_ext;
1099        Rel {
1100            rel_type: Some(RelType::Join(Box::new(self))),
1101        }
1102    }
1103
1104    fn parse_pair_with_context(
1105        extensions: &SimpleExtensions,
1106        pair: Pair<Rule>,
1107        input_children: Vec<Box<Rel>>,
1108        input_field_count: usize,
1109    ) -> Result<(Self, usize), MessageParseError> {
1110        assert_eq!(pair.as_rule(), Self::rule());
1111
1112        if input_children.len() != 2 {
1113            return Err(MessageParseError::invalid(
1114                Self::message(),
1115                pair.as_span(),
1116                format!(
1117                    "JoinRel should have exactly 2 input children, got {}",
1118                    input_children.len()
1119                ),
1120            ));
1121        }
1122
1123        let mut children_iter = input_children.into_iter();
1124        let left = children_iter.next().unwrap();
1125        let right = children_iter.next().unwrap();
1126
1127        let mut iter = RuleIter::from(pair.into_inner());
1128        let join_type = iter.parse_next::<join_rel::JoinType>();
1129        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
1130        let reference_list_pair = iter.pop(Rule::reference_list);
1131        iter.done();
1132
1133        // TODO: For semi/anti joins, the direct output width differs from
1134        // left+right — `input_field_count` would misclassify the emit as Direct.
1135        // Revisit when those join types are supported.
1136        let (emit, output_count) = parse_emit(reference_list_pair, input_field_count);
1137        let common = RelCommon {
1138            emit_kind: Some(emit),
1139            ..Default::default()
1140        };
1141
1142        Ok((
1143            JoinRel {
1144                common: Some(common),
1145                left: Some(left),
1146                right: Some(right),
1147                expression: Some(Box::new(condition)),
1148                post_join_filter: None, // not yet represented in the grammar
1149                r#type: join_type as i32,
1150                advanced_extension: None,
1151            },
1152            output_count,
1153        ))
1154    }
1155}
1156
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap()
        .0;
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
        .0
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;
        // Identity mapping [0, 1, 2] over 3 inputs → Direct
        let emit_kind = filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        assert!(
            matches!(emit_kind, EmitKind::Direct(_)),
            "Expected Direct for identity emit, got {emit_kind:?}"
        );
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Should have 1 expression (42) and 2 references ($0, $1)
        assert_eq!(project.expressions.len(), 1);

        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
        assert_eq!(emit, &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3, // matches the 3 columns of example_read_relation
        )
        .unwrap()
        .0;

        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
        assert_eq!(project.expressions.len(), 2);

        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Direct mapping: [$0, $1, $2, 42, 100] (input fields first, then expressions)
        // Output mapping: [3, 0, 4, 2, 1] (to get: 42, $0, 100, $2, $1)
        assert_eq!(emit, &[3, 0, 4, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[($0, $1), _ => sum($2), $0, count($2)]",
            ),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Two grouping sets — ($0, $1) and the empty set `_` — over two
        // distinct grouping expressions, plus 2 measures (sum($2), count($2)).
        // Check the lengths before indexing so a failure reports a mismatch, not a panic.
        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.groupings.len(), 2);
        assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
        assert!(aggregate.groupings[1].expression_references.is_empty());
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [2, 0, 3] (measures and group-by fields in order)
        // sum($2) -> 2, $0 -> 0, count($2) -> 3
        assert_eq!(emit, &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_maintain_column_order() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), $0, count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [1, 0, 2] (grouping fields + measures)
        assert_eq!(emit, &[1, 0, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.groupings.len(), 1);
        // expression_references must be positions [0, 1], not raw field indices [2, 0]
        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation_global_aggregate() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Should have 0 group-by fields and 2 measures
        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        // Identity mapping [0, 1] over 2 outputs (0 grouping + 2 measures) → Direct
        let emit_kind = aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        assert!(
            matches!(emit_kind, EmitKind::Direct(_)),
            "Expected Direct for identity emit, got {emit_kind:?}"
        );
    }

    #[test]
    fn test_parse_aggregate_relation_grouping_sets() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 11, "count")
            .extensions;

        let read_rel = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
        .0;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]",
            ),
            vec![Box::new(read_rel.into_rel(None))],
            4,
        )
        .unwrap()
        .0;

        assert_eq!(aggregate.grouping_expressions.len(), 3);
        assert_eq!(aggregate.groupings.len(), 4);
        // ($0, $1, $2) -> [0, 1, 2]
        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
        // ($2, $0) -> [2, 0] (reuses indices, different order)
        assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
        // ($1) -> [1]
        assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
        // _ -> empty
        assert!(aggregate.groupings[3].expression_references.is_empty());
        assert_eq!(aggregate.measures.len(), 1);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        // Test valid positive values should work
        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        )
        .unwrap()
        .0;

        // Verify the limit and offset values are correct
        assert_eq!(
            fetch_rel.count_mode,
            Some(CountMode::CountExpr(i64_literal_expr(10)))
        );
        assert_eq!(
            fetch_rel.offset_mode,
            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
        );
    }

    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        let extensions = SimpleExtensions::default();

        // Test that fetch relations with negative limits are properly rejected
        let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
        if let Ok(mut pairs) = parsed_result {
            let pair = pairs.next().unwrap();
            if pair.as_str() == "Fetch[limit=-5 => $0]" {
                // Full parse succeeded, now test that validation catches the negative value
                let result = FetchRel::parse_pair_with_context(
                    &extensions,
                    pair,
                    vec![Box::new(example_read_relation().into_rel(None))],
                    3,
                );
                assert!(result.is_err());
                let error_msg = result.unwrap_err().to_string();
                assert!(error_msg.contains("Fetch limit must be non-negative"));
            } else {
                // If grammar doesn't fully support negative values, that's also acceptable
                // since it would prevent negative values at parse time
                println!("Grammar prevents negative limit values at parse time");
            }
        } else {
            // Grammar doesn't support negative values in fetch context
            println!("Grammar prevents negative limit values at parse time");
        }
    }

    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        let extensions = SimpleExtensions::default();

        // Test that fetch relations with negative offsets are properly rejected
        let parsed_result =
            ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
        if let Ok(mut pairs) = parsed_result {
            let pair = pairs.next().unwrap();
            if pair.as_str() == "Fetch[offset=-10 => $0]" {
                // Full parse succeeded, now test that validation catches the negative value
                let result = FetchRel::parse_pair_with_context(
                    &extensions,
                    pair,
                    vec![Box::new(example_read_relation().into_rel(None))],
                    3,
                );
                assert!(result.is_err());
                let error_msg = result.unwrap_err().to_string();
                assert!(error_msg.contains("Fetch offset must be non-negative"));
            } else {
                // If grammar doesn't fully support negative values, that's also acceptable
                println!("Grammar prevents negative offset values at parse time");
            }
        } else {
            // Grammar doesn't support negative values in fetch context
            println!("Grammar prevents negative offset values at parse time");
        }
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel(None);
        let right_rel = example_read_relation().into_rel(None);

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6, // left (3) + right (3) = 6 total input fields
        )
        .unwrap()
        .0;

        // Should be an Inner join
        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);

        // Should have left and right relations
        assert!(join.left.is_some());
        assert!(join.right.is_some());

        // Should have a join condition
        assert!(join.expression.is_some());

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1, 3, 4] (selected columns)
        assert_eq!(emit, &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel(None);
        let right_rel = example_read_relation().into_rel(None);

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap()
        .0;

        // Should be a Left join
        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1, 2]
        assert_eq!(emit, &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel(None);
        let right_rel = example_read_relation().into_rel(None);

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap()
        .0;

        // Should be a LeftSemi join
        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1] (only left columns for semi join)
        assert_eq!(emit, &[0, 1]);
    }

    #[test]
    fn test_parse_join_relation_requires_two_children() {
        let extensions = SimpleExtensions::default();

        // Test with 0 children
        let result = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
            vec![],
            0,
        );
        assert!(result.is_err());

        // Test with 1 child
        let result = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
            vec![Box::new(example_read_relation().into_rel(None))],
            3,
        );
        assert!(result.is_err());

        // Test with 3 children
        let result = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
            vec![
                Box::new(example_read_relation().into_rel(None)),
                Box::new(example_read_relation().into_rel(None)),
                Box::new(example_read_relation().into_rel(None)),
            ],
            9,
        );
        assert!(result.is_err());
    }

    /// Parses `input` with `rule` and asserts that the whole input was
    /// consumed by exactly one top-level pair, which is returned.
    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }
}