substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use prost::Message;
5use substrait::proto::aggregate_rel::Grouping;
6use substrait::proto::expression::literal::LiteralType;
7use substrait::proto::expression::{Literal, RexType};
8use substrait::proto::fetch_rel::{CountMode, OffsetMode};
9use substrait::proto::rel::RelType;
10use substrait::proto::rel_common::{Emit, EmitKind};
11use substrait::proto::sort_field::{SortDirection, SortKind};
12use substrait::proto::{
13    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
14    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
15};
16
17use super::{
18    ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
19};
20use crate::extensions::any::Any;
21use crate::extensions::registry::{ExtensionError, ExtensionType};
22use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions};
23use crate::parser::errors::{ParseContext, ParseError};
24use crate::parser::expressions::{FieldIndex, Name};
25
26/// Parsing context for relations that includes extensions, registry, and optional warning collection
27pub struct RelationParsingContext<'a> {
28    pub extensions: &'a SimpleExtensions,
29    pub registry: &'a ExtensionRegistry,
30    pub line_no: i64,
31    pub line: &'a str,
32}
33
34impl<'a> RelationParsingContext<'a> {
35    /// Resolve extension detail using registry. Any failure is treated as a hard parse error.
36    pub fn resolve_extension_detail(
37        &self,
38        extension_name: &str,
39        extension_args: &ExtensionArgs,
40    ) -> Result<Option<Any>, ParseError> {
41        let detail = self
42            .registry
43            .parse_extension(extension_name, extension_args);
44
45        match detail {
46            Ok(any) => Ok(Some(any)),
47            Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
48                name: extension_name.to_string(),
49                context: ParseContext::new(self.line_no, self.line.to_string()),
50            }),
51            Err(err) => Err(ParseError::ExtensionDetail(
52                ParseContext::new(self.line_no, self.line.to_string()),
53                err,
54            )),
55        }
56    }
57
58    /// Resolve an advanced-extension detail (enhancement or optimization) using the registry.
59    /// Any failure is treated as a hard parse error.
60    ///
61    /// `ext_type` must be `Enhancement` or `Optimization`; `Relation` is not valid here and
62    /// will panic via the unreachable branch in the dispatch below.
63    pub fn resolve_adv_ext_detail(
64        &self,
65        ext_type: ExtensionType,
66        name: &str,
67        args: &ExtensionArgs,
68    ) -> Result<Any, ParseError> {
69        let result = match ext_type {
70            ExtensionType::Enhancement => self.registry.parse_enhancement(name, args),
71            ExtensionType::Optimization => self.registry.parse_optimization(name, args),
72            ExtensionType::Relation => unreachable!("Relation is not an advanced extension type"),
73        };
74        result.map_err(|err| match err {
75            ExtensionError::NotFound { .. } => ParseError::UnregisteredExtension {
76                name: name.to_string(),
77                context: ParseContext::new(self.line_no, self.line.to_string()),
78            },
79            err => ParseError::ExtensionDetail(
80                ParseContext::new(self.line_no, self.line.to_string()),
81                err,
82            ),
83        })
84    }
85}
86
87/// A trait for parsing relations with full context needed for tree building.
88/// This includes extensions, the parsed pair, input children, and output field count.
89pub trait RelationParsePair: Sized {
90    fn rule() -> Rule;
91
92    fn message() -> &'static str;
93
94    /// Parse a relation with full context for tree building.
95    ///
96    /// Args:
97    /// - extensions: The extensions context
98    /// - pair: The parsed pest pair
99    /// - input_children: The input relations (for wiring)
100    /// - input_field_count: Number of output fields from input children (for output mapping)
101    fn parse_pair_with_context(
102        extensions: &SimpleExtensions,
103        pair: Pair<Rule>,
104        input_children: Vec<Box<Rel>>,
105        _input_field_count: usize,
106    ) -> Result<Self, MessageParseError>;
107
108    fn into_rel(self) -> Rel;
109}
110
111pub struct TableName(Vec<String>);
112
113impl ParsePair for TableName {
114    fn rule() -> Rule {
115        Rule::table_name
116    }
117
118    fn message() -> &'static str {
119        "TableName"
120    }
121
122    fn parse_pair(pair: Pair<Rule>) -> Self {
123        assert_eq!(pair.as_rule(), Self::rule());
124        let pairs = pair.into_inner();
125        let mut names = Vec::with_capacity(pairs.len());
126        let mut iter = RuleIter::from(pairs);
127        while let Some(name) = iter.parse_if_next::<Name>() {
128            names.push(name.0);
129        }
130        iter.done();
131        Self(names)
132    }
133}
134
135#[derive(Debug, Clone)]
136pub struct Column {
137    pub name: String,
138    pub typ: Type,
139}
140
141impl ScopedParsePair for Column {
142    fn rule() -> Rule {
143        Rule::named_column
144    }
145
146    fn message() -> &'static str {
147        "Column"
148    }
149
150    fn parse_pair(
151        extensions: &SimpleExtensions,
152        pair: Pair<Rule>,
153    ) -> Result<Self, MessageParseError> {
154        assert_eq!(pair.as_rule(), Self::rule());
155        let mut iter = RuleIter::from(pair.into_inner());
156        let name = iter.parse_next::<Name>().0;
157        let typ = iter.parse_next_scoped(extensions)?;
158        iter.done();
159        Ok(Self { name, typ })
160    }
161}
162
163pub struct NamedColumnList(Vec<Column>);
164
165impl ScopedParsePair for NamedColumnList {
166    fn rule() -> Rule {
167        Rule::named_column_list
168    }
169
170    fn message() -> &'static str {
171        "NamedColumnList"
172    }
173
174    fn parse_pair(
175        extensions: &SimpleExtensions,
176        pair: Pair<Rule>,
177    ) -> Result<Self, MessageParseError> {
178        assert_eq!(pair.as_rule(), Self::rule());
179        let mut columns = Vec::new();
180        for col in pair.into_inner() {
181            columns.push(Column::parse_pair(extensions, col)?);
182        }
183        Ok(Self(columns))
184    }
185}
186
187/// This is a utility function for extracting a single child from the list of
188/// children, to be used in the RelationParsePair trait. The RelationParsePair
189/// trait passes a Vec of children, because some relations have multiple
190/// children - but most accept exactly one child.
191#[allow(clippy::vec_box)]
192pub(crate) fn expect_one_child(
193    message: &'static str,
194    pair: &Pair<Rule>,
195    mut input_children: Vec<Box<Rel>>,
196) -> Result<Box<Rel>, MessageParseError> {
197    match input_children.len() {
198        0 => Err(MessageParseError::invalid(
199            message,
200            pair.as_span(),
201            format!("{message} missing child"),
202        )),
203        1 => Ok(input_children.pop().unwrap()),
204        n => Err(MessageParseError::invalid(
205            message,
206            pair.as_span(),
207            format!("{message} should have 1 input child, got {n}"),
208        )),
209    }
210}
211
212/// Parse a reference list Pair and return an EmitKind::Emit.
213fn parse_reference_emit(pair: Pair<Rule>) -> EmitKind {
214    assert_eq!(pair.as_rule(), Rule::reference_list);
215    let output_mapping = pair
216        .into_inner()
217        .map(|p| FieldIndex::parse_pair(p).0)
218        .collect::<Vec<i32>>();
219    EmitKind::Emit(Emit { output_mapping })
220}
221
222/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
223///
224/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
225///
226/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
227pub struct ParsedNamedArgs<'a> {
228    map: HashMap<&'a str, Pair<'a, Rule>>,
229}
230
231impl<'a> ParsedNamedArgs<'a> {
232    pub fn new(
233        pairs: pest::iterators::Pairs<'a, Rule>,
234        rule: Rule,
235    ) -> Result<Self, MessageParseError> {
236        let mut map = HashMap::new();
237        for pair in pairs {
238            assert_eq!(pair.as_rule(), rule);
239            let mut inner = pair.clone().into_inner();
240            let name_pair = inner.next().unwrap();
241            let value_pair = inner.next().unwrap();
242            assert_eq!(inner.next(), None);
243            let name = name_pair.as_str();
244            if map.contains_key(name) {
245                return Err(MessageParseError::invalid(
246                    "NamedArg",
247                    name_pair.as_span(),
248                    format!("Duplicate argument: {name}"),
249                ));
250            }
251            map.insert(name, value_pair);
252        }
253        Ok(Self { map })
254    }
255
256    // Returns the pair if it exists and matches the rule, otherwise None.
257    // Asserts that the rule must match the rule of the pair (and therefore
258    // panics in non-release-mode if not)
259    pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
260        let pair = self.map.remove(name).inspect(|pair| {
261            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
262        });
263        (self, pair)
264    }
265
266    // Returns an error if there are any unused arguments.
267    pub fn done(self) -> Result<(), MessageParseError> {
268        if let Some((name, pair)) = self.map.iter().next() {
269            return Err(MessageParseError::invalid(
270                "NamedArgExtractor",
271                // No span available for all unused args; use default.
272                pair.as_span(),
273                format!("Unknown argument: {name}"),
274            ));
275        }
276        Ok(())
277    }
278}
279
280impl RelationParsePair for ReadRel {
281    fn rule() -> Rule {
282        Rule::read_relation
283    }
284
285    fn message() -> &'static str {
286        "ReadRel"
287    }
288
289    fn into_rel(self) -> Rel {
290        Rel {
291            rel_type: Some(RelType::Read(Box::new(self))),
292        }
293    }
294
295    fn parse_pair_with_context(
296        extensions: &SimpleExtensions,
297        pair: Pair<Rule>,
298        input_children: Vec<Box<Rel>>,
299        _input_field_count: usize,
300    ) -> Result<Self, MessageParseError> {
301        assert_eq!(pair.as_rule(), Self::rule());
302        // ReadRel is a leaf node - it should have no input children and 0 input fields
303        if !input_children.is_empty() {
304            return Err(MessageParseError::invalid(
305                Self::message(),
306                pair.as_span(),
307                "ReadRel should have no input children",
308            ));
309        }
310        if _input_field_count != 0 {
311            let error = pest::error::Error::new_from_span(
312                pest::error::ErrorVariant::CustomError {
313                    message: "ReadRel should have 0 input fields".to_string(),
314                },
315                pair.as_span(),
316            );
317            return Err(MessageParseError::new(
318                "ReadRel",
319                ErrorKind::InvalidValue,
320                Box::new(error),
321            ));
322        }
323
324        let mut iter = RuleIter::from(pair.into_inner());
325        let table = iter.parse_next::<TableName>().0;
326        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
327        iter.done();
328
329        let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
330        let struct_ = r#type::Struct {
331            types,
332            type_variation_reference: 0,
333            nullability: r#type::Nullability::Required as i32,
334        };
335        let named_struct = NamedStruct {
336            names,
337            r#struct: Some(struct_),
338        };
339
340        let read_rel = ReadRel {
341            base_schema: Some(named_struct),
342            read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
343                names: table,
344                advanced_extension: None,
345            })),
346            ..Default::default()
347        };
348
349        Ok(read_rel)
350    }
351}
352
353impl RelationParsePair for FilterRel {
354    fn rule() -> Rule {
355        Rule::filter_relation
356    }
357
358    fn message() -> &'static str {
359        "FilterRel"
360    }
361
362    fn into_rel(self) -> Rel {
363        Rel {
364            rel_type: Some(RelType::Filter(Box::new(self))),
365        }
366    }
367
368    fn parse_pair_with_context(
369        extensions: &SimpleExtensions,
370        pair: Pair<Rule>,
371        input_children: Vec<Box<Rel>>,
372        _input_field_count: usize,
373    ) -> Result<Self, MessageParseError> {
374        // Form: Filter[condition => references]
375
376        assert_eq!(pair.as_rule(), Self::rule());
377        let input = expect_one_child(Self::message(), &pair, input_children)?;
378        let mut iter = RuleIter::from(pair.into_inner());
379        // condition
380        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
381        // references (which become the emit)
382        let references_pair = iter.pop(Rule::reference_list);
383        iter.done();
384
385        let emit = parse_reference_emit(references_pair);
386        let common = RelCommon {
387            emit_kind: Some(emit),
388            ..Default::default()
389        };
390
391        Ok(FilterRel {
392            input: Some(input),
393            condition: Some(Box::new(condition)),
394            common: Some(common),
395            advanced_extension: None,
396        })
397    }
398}
399
400impl RelationParsePair for ProjectRel {
401    fn rule() -> Rule {
402        Rule::project_relation
403    }
404
405    fn message() -> &'static str {
406        "ProjectRel"
407    }
408
409    fn into_rel(self) -> Rel {
410        Rel {
411            rel_type: Some(RelType::Project(Box::new(self))),
412        }
413    }
414
415    fn parse_pair_with_context(
416        extensions: &SimpleExtensions,
417        pair: Pair<Rule>,
418        input_children: Vec<Box<Rel>>,
419        _input_field_count: usize,
420    ) -> Result<Self, MessageParseError> {
421        assert_eq!(pair.as_rule(), Self::rule());
422        let input = expect_one_child(Self::message(), &pair, input_children)?;
423
424        // Get the argument list (contains references and expressions)
425        let arguments_pair = unwrap_single_pair(pair);
426
427        let mut expressions = Vec::new();
428        let mut output_mapping = Vec::new();
429
430        // Process each argument (can be either a reference or expression)
431        for arg in arguments_pair.into_inner() {
432            let inner_arg = unwrap_single_pair(arg);
433            match inner_arg.as_rule() {
434                Rule::reference => {
435                    // Parse reference like "$0" -> 0
436                    let field_index = FieldIndex::parse_pair(inner_arg);
437                    output_mapping.push(field_index.0);
438                }
439                Rule::expression => {
440                    // Parse as expression (e.g., 42, add($0, $1))
441                    let _expr = Expression::parse_pair(extensions, inner_arg)?;
442                    expressions.push(_expr);
443                    // Expression: index after all input fields
444                    output_mapping.push(_input_field_count as i32 + (expressions.len() as i32 - 1));
445                }
446                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
447            }
448        }
449
450        let emit = EmitKind::Emit(Emit { output_mapping });
451        let common = RelCommon {
452            emit_kind: Some(emit),
453            ..Default::default()
454        };
455
456        Ok(ProjectRel {
457            input: Some(input),
458            expressions,
459            common: Some(common),
460            advanced_extension: None,
461        })
462    }
463}
464
465impl RelationParsePair for AggregateRel {
466    fn rule() -> Rule {
467        Rule::aggregate_relation
468    }
469
470    fn message() -> &'static str {
471        "AggregateRel"
472    }
473
474    fn into_rel(self) -> Rel {
475        Rel {
476            rel_type: Some(RelType::Aggregate(Box::new(self))),
477        }
478    }
479
480    fn parse_pair_with_context(
481        extensions: &SimpleExtensions,
482        pair: Pair<Rule>,
483        input_children: Vec<Box<Rel>>,
484        _input_field_count: usize,
485    ) -> Result<Self, MessageParseError> {
486        assert_eq!(pair.as_rule(), Self::rule());
487        let input = expect_one_child(Self::message(), &pair, input_children)?;
488        let mut iter = RuleIter::from(pair.into_inner());
489        let group_by_pair = iter.pop(Rule::aggregate_group_by);
490        let output_pair = iter.pop(Rule::aggregate_output);
491        iter.done();
492
493        let inner = group_by_pair
494            .into_inner()
495            .next()
496            .expect("aggregate_group_by must have one inner item");
497
498        let grouping_sets = parse_grouping_sets(extensions, inner);
499        let (groupings, grouping_expressions) = build_grouping_fields(&grouping_sets);
500
501        let (measures, output_mapping) =
502            parse_aggregate_measures(extensions, output_pair, &grouping_expressions)?;
503
504        let emit = EmitKind::Emit(Emit { output_mapping });
505        let common = RelCommon {
506            emit_kind: Some(emit),
507            ..Default::default()
508        };
509
510        Ok(AggregateRel {
511            input: Some(input),
512            grouping_expressions,
513            groupings,
514            measures,
515            common: Some(common),
516            advanced_extension: None,
517        })
518    }
519}
520
521/// Parses the output section of an aggregate (everything after `=>`).
522///
523/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
524/// this parses `sum($2), $0, count($2)`.
525fn parse_aggregate_measures(
526    extensions: &SimpleExtensions,
527    output_pair: Pair<'_, Rule>,
528    grouping_expressions: &[Expression],
529) -> Result<(Vec<aggregate_rel::Measure>, Vec<i32>), MessageParseError> {
530    assert_eq!(output_pair.as_rule(), Rule::aggregate_output);
531    let mut measures = Vec::new();
532    let mut output_mapping = Vec::new();
533
534    for aggregate_output_item in output_pair.into_inner() {
535        let inner_item = unwrap_single_pair(aggregate_output_item);
536        match inner_item.as_rule() {
537            Rule::reference => {
538                let field_index = FieldIndex::parse_pair(inner_item);
539                output_mapping.push(field_index.0);
540            }
541            Rule::aggregate_measure => {
542                let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
543                output_mapping.push(grouping_expressions.len() as i32 + measures.len() as i32);
544                measures.push(measure);
545            }
546            _ => panic!(
547                "Unexpected inner output item rule: {:?}",
548                inner_item.as_rule()
549            ),
550        }
551    }
552
553    Ok((measures, output_mapping))
554}
555
556/// Parses the grouping section of an aggregate (everything before `=>`).
557///
558/// For example, in `Aggregate[($0, $1), _ => sum($2), $0, count($2)]`,
559/// this parses `($0, $1), _`.
560///
561/// Each inner Vec is one grouping set; an empty vec represents no grouping (global aggregate).
562///
563/// Grammar: `aggregate_group_by = { grouping_set_list | expression_list }`
564fn parse_grouping_sets(
565    extensions: &SimpleExtensions,
566    inner: Pair<'_, Rule>,
567) -> Vec<Vec<Expression>> {
568    assert!(
569        matches!(
570            inner.as_rule(),
571            Rule::expression_list | Rule::grouping_set_list
572        ),
573        "Expected expression_list or grouping_set_list, got {:?}",
574        inner.as_rule()
575    );
576    match inner.as_rule() {
577        Rule::expression_list => {
578            vec![parse_expression_list(extensions, inner)]
579        }
580        Rule::grouping_set_list => inner
581            .into_inner()
582            .map(|pair| parse_grouping_set(extensions, pair))
583            .collect(),
584        _ => unreachable!(
585            "Unexpected rule in aggregate_group_by: {:?}",
586            inner.as_rule()
587        ),
588    }
589}
590
591/// Parses a single grouping set, e.g. `($0, $1)` or `_`.
592///
593/// Grammar: `grouping_set = { ("(" ~ expression_list ~ ")") | empty }`
594fn parse_grouping_set(extensions: &SimpleExtensions, pair: Pair<'_, Rule>) -> Vec<Expression> {
595    assert_eq!(pair.as_rule(), Rule::grouping_set);
596    let inner = pair
597        .into_inner()
598        .next()
599        .expect("grouping_set must have one inner item");
600    match inner.as_rule() {
601        Rule::empty => vec![],
602        Rule::expression_list => parse_expression_list(extensions, inner),
603        _ => unreachable!("Unexpected item in grouping_set: {:?}", inner.as_rule()),
604    }
605}
606
607/// Grammar: `expression_list = { expression ~ ("," ~ expression)* }`
608fn parse_expression_list(extensions: &SimpleExtensions, pair: Pair<'_, Rule>) -> Vec<Expression> {
609    pair.into_inner()
610        .map(|expr_pair| {
611            Expression::parse_pair(extensions, expr_pair)
612                .expect("By the grammar rule, only expressions should be parsed")
613        })
614        .collect()
615}
616
617/// Deduplicates expressions across all sets and produces the AggregateRel's
618/// protobuf fields: a flat deduplicated expression list and per-set Grouping
619/// messages with index references into that list.
620fn build_grouping_fields(expression_sets: &[Vec<Expression>]) -> (Vec<Grouping>, Vec<Expression>) {
621    let mut expressions: Vec<Expression> = Vec::new();
622    let mut seen: HashMap<Vec<u8>, u32> = HashMap::new();
623
624    let groupings = expression_sets
625        .iter()
626        .map(|set| {
627            let expression_references = set
628                .iter()
629                .map(|exp| {
630                    // TODO: use a better key here than encoding to bytes.
631                    // Ideally, substrait-rs would support `PartialEq` and `Hash`,
632                    // but as there isn't an easy way to do that now, we'll skip.
633                    let key = exp.encode_to_vec();
634                    let next_idx = expressions.len() as u32;
635                    *seen.entry(key).or_insert_with(|| {
636                        expressions.push(exp.clone());
637                        next_idx
638                    })
639                })
640                .collect();
641            Grouping {
642                expression_references,
643                #[allow(deprecated)]
644                grouping_expressions: vec![],
645            }
646        })
647        .collect();
648
649    (groupings, expressions)
650}
651
652impl ScopedParsePair for SortField {
653    fn rule() -> Rule {
654        Rule::sort_field
655    }
656
657    fn message() -> &'static str {
658        "SortField"
659    }
660
661    fn parse_pair(
662        _extensions: &SimpleExtensions,
663        pair: Pair<Rule>,
664    ) -> Result<Self, MessageParseError> {
665        assert_eq!(pair.as_rule(), Self::rule());
666        let mut iter = RuleIter::from(pair.into_inner());
667        let reference_pair = iter.pop(Rule::reference);
668        let field_index = FieldIndex::parse_pair(reference_pair);
669        let direction_pair = iter.pop(Rule::sort_direction);
670        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
671        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
672        // identifiers, but the enum variant names don't include it
673        let direction = match direction_pair.as_str().trim_start_matches('&') {
674            "AscNullsFirst" => SortDirection::AscNullsFirst,
675            "AscNullsLast" => SortDirection::AscNullsLast,
676            "DescNullsFirst" => SortDirection::DescNullsFirst,
677            "DescNullsLast" => SortDirection::DescNullsLast,
678            other => {
679                return Err(MessageParseError::invalid(
680                    "SortDirection",
681                    direction_pair.as_span(),
682                    format!("Unknown sort direction: {other}"),
683                ));
684            }
685        };
686        iter.done();
687        Ok(SortField {
688            expr: Some(Expression {
689                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
690                    field_index.to_field_reference(),
691                ))),
692            }),
693            // TODO: Add support for SortKind::ComparisonFunctionReference
694            sort_kind: Some(SortKind::Direction(direction as i32)),
695        })
696    }
697}
698
699impl RelationParsePair for SortRel {
700    fn rule() -> Rule {
701        Rule::sort_relation
702    }
703
704    fn message() -> &'static str {
705        "SortRel"
706    }
707
708    fn into_rel(self) -> Rel {
709        Rel {
710            rel_type: Some(RelType::Sort(Box::new(self))),
711        }
712    }
713
714    fn parse_pair_with_context(
715        extensions: &SimpleExtensions,
716        pair: Pair<Rule>,
717        input_children: Vec<Box<Rel>>,
718        _input_field_count: usize,
719    ) -> Result<Self, MessageParseError> {
720        assert_eq!(pair.as_rule(), Self::rule());
721        let input = expect_one_child(Self::message(), &pair, input_children)?;
722        let mut iter = RuleIter::from(pair.into_inner());
723        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
724        let reference_list_pair = iter.pop(Rule::reference_list);
725        let mut sorts = Vec::new();
726        for sort_field_pair in sort_field_list_pair.into_inner() {
727            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
728            sorts.push(sort_field);
729        }
730        let emit = parse_reference_emit(reference_list_pair);
731        let common = RelCommon {
732            emit_kind: Some(emit),
733            ..Default::default()
734        };
735        iter.done();
736        Ok(SortRel {
737            input: Some(input),
738            sorts,
739            common: Some(common),
740            advanced_extension: None,
741        })
742    }
743}
744
745impl ScopedParsePair for CountMode {
746    fn rule() -> Rule {
747        Rule::fetch_value
748    }
749    fn message() -> &'static str {
750        "CountMode"
751    }
752    fn parse_pair(
753        extensions: &SimpleExtensions,
754        pair: Pair<Rule>,
755    ) -> Result<Self, MessageParseError> {
756        assert_eq!(pair.as_rule(), Self::rule());
757        let mut arg_inner = RuleIter::from(pair.into_inner());
758        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
759            int_pair
760        } else {
761            arg_inner.pop(Rule::expression)
762        };
763        match value_pair.as_rule() {
764            Rule::integer => {
765                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
766                    MessageParseError::invalid(
767                        Self::message(),
768                        value_pair.as_span(),
769                        format!("Invalid integer: {e}"),
770                    )
771                })?;
772                if value < 0 {
773                    return Err(MessageParseError::invalid(
774                        Self::message(),
775                        value_pair.as_span(),
776                        format!("Fetch limit must be non-negative, got: {value}"),
777                    ));
778                }
779                Ok(CountMode::CountExpr(i64_literal_expr(value)))
780            }
781            Rule::expression => {
782                let expr = Expression::parse_pair(extensions, value_pair)?;
783                Ok(CountMode::CountExpr(Box::new(expr)))
784            }
785            _ => Err(MessageParseError::invalid(
786                Self::message(),
787                value_pair.as_span(),
788                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
789            )),
790        }
791    }
792}
793
794fn i64_literal_expr(value: i64) -> Box<Expression> {
795    Box::new(Expression {
796        rex_type: Some(RexType::Literal(Literal {
797            nullable: false,
798            type_variation_reference: 0,
799            literal_type: Some(LiteralType::I64(value)),
800        })),
801    })
802}
803
804impl ScopedParsePair for OffsetMode {
805    fn rule() -> Rule {
806        Rule::fetch_value
807    }
808    fn message() -> &'static str {
809        "OffsetMode"
810    }
811    fn parse_pair(
812        extensions: &SimpleExtensions,
813        pair: Pair<Rule>,
814    ) -> Result<Self, MessageParseError> {
815        assert_eq!(pair.as_rule(), Self::rule());
816        let mut arg_inner = RuleIter::from(pair.into_inner());
817        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
818            int_pair
819        } else {
820            arg_inner.pop(Rule::expression)
821        };
822        match value_pair.as_rule() {
823            Rule::integer => {
824                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
825                    MessageParseError::invalid(
826                        Self::message(),
827                        value_pair.as_span(),
828                        format!("Invalid integer: {e}"),
829                    )
830                })?;
831                if value < 0 {
832                    return Err(MessageParseError::invalid(
833                        Self::message(),
834                        value_pair.as_span(),
835                        format!("Fetch offset must be non-negative, got: {value}"),
836                    ));
837                }
838                Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
839            }
840            Rule::expression => {
841                let expr = Expression::parse_pair(extensions, value_pair)?;
842                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
843            }
844            _ => Err(MessageParseError::invalid(
845                Self::message(),
846                value_pair.as_span(),
847                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
848            )),
849        }
850    }
851}
852
853impl RelationParsePair for FetchRel {
854    fn rule() -> Rule {
855        Rule::fetch_relation
856    }
857
858    fn message() -> &'static str {
859        "FetchRel"
860    }
861
862    fn into_rel(self) -> Rel {
863        Rel {
864            rel_type: Some(RelType::Fetch(Box::new(self))),
865        }
866    }
867
868    fn parse_pair_with_context(
869        extensions: &SimpleExtensions,
870        pair: Pair<Rule>,
871        input_children: Vec<Box<Rel>>,
872        _input_field_count: usize,
873    ) -> Result<Self, MessageParseError> {
874        assert_eq!(pair.as_rule(), Self::rule());
875        let input = expect_one_child(Self::message(), &pair, input_children)?;
876        let mut iter = RuleIter::from(pair.into_inner());
877
878        // Extract all pairs first, then do validation
879        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
880            None => {
881                // If there are no arguments, it should be empty
882                iter.pop(Rule::empty);
883                (None, None)
884            }
885            Some(fetch_args_pair) => {
886                let extractor =
887                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
888                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
889                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
890                extractor.done()?;
891                (limit_pair, offset_pair)
892            }
893        };
894
895        let reference_list_pair = iter.pop(Rule::reference_list);
896        let emit = parse_reference_emit(reference_list_pair);
897        let common = RelCommon {
898            emit_kind: Some(emit),
899            ..Default::default()
900        };
901        iter.done();
902
903        // Now do validation after iterator is fully consumed
904        let count_mode = limit_pair
905            .map(|pair| CountMode::parse_pair(extensions, pair))
906            .transpose()?;
907        let offset_mode = offset_pair
908            .map(|pair| OffsetMode::parse_pair(extensions, pair))
909            .transpose()?;
910        Ok(FetchRel {
911            input: Some(input),
912            common: Some(common),
913            advanced_extension: None,
914            offset_mode,
915            count_mode,
916        })
917    }
918}
919
920impl ParsePair for join_rel::JoinType {
921    fn rule() -> Rule {
922        Rule::join_type
923    }
924
925    fn message() -> &'static str {
926        "JoinType"
927    }
928
929    fn parse_pair(pair: Pair<Rule>) -> Self {
930        assert_eq!(pair.as_rule(), Self::rule());
931        let join_type_str = pair.as_str().trim_start_matches('&');
932        match join_type_str {
933            "Inner" => join_rel::JoinType::Inner,
934            "Left" => join_rel::JoinType::Left,
935            "Right" => join_rel::JoinType::Right,
936            "Outer" => join_rel::JoinType::Outer,
937            "LeftSemi" => join_rel::JoinType::LeftSemi,
938            "RightSemi" => join_rel::JoinType::RightSemi,
939            "LeftAnti" => join_rel::JoinType::LeftAnti,
940            "RightAnti" => join_rel::JoinType::RightAnti,
941            "LeftSingle" => join_rel::JoinType::LeftSingle,
942            "RightSingle" => join_rel::JoinType::RightSingle,
943            "LeftMark" => join_rel::JoinType::LeftMark,
944            "RightMark" => join_rel::JoinType::RightMark,
945            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
946        }
947    }
948}
949
950impl RelationParsePair for JoinRel {
951    fn rule() -> Rule {
952        Rule::join_relation
953    }
954
955    fn message() -> &'static str {
956        "JoinRel"
957    }
958
959    fn into_rel(self) -> Rel {
960        Rel {
961            rel_type: Some(RelType::Join(Box::new(self))),
962        }
963    }
964
965    fn parse_pair_with_context(
966        extensions: &SimpleExtensions,
967        pair: Pair<Rule>,
968        input_children: Vec<Box<Rel>>,
969        _input_field_count: usize,
970    ) -> Result<Self, MessageParseError> {
971        assert_eq!(pair.as_rule(), Self::rule());
972
973        // Join requires exactly 2 input children
974        if input_children.len() != 2 {
975            return Err(MessageParseError::invalid(
976                Self::message(),
977                pair.as_span(),
978                format!(
979                    "JoinRel should have exactly 2 input children, got {}",
980                    input_children.len()
981                ),
982            ));
983        }
984
985        let mut children_iter = input_children.into_iter();
986        let left = children_iter.next().unwrap();
987        let right = children_iter.next().unwrap();
988
989        let mut iter = RuleIter::from(pair.into_inner());
990
991        // Parse join type
992        let join_type = iter.parse_next::<join_rel::JoinType>();
993
994        // Parse join condition expression
995        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
996
997        // Parse output references (which become the emit)
998        let reference_list_pair = iter.pop(Rule::reference_list);
999        iter.done();
1000
1001        let emit = parse_reference_emit(reference_list_pair);
1002        let common = RelCommon {
1003            emit_kind: Some(emit),
1004            ..Default::default()
1005        };
1006
1007        Ok(JoinRel {
1008            common: Some(common),
1009            left: Some(left),
1010            right: Some(right),
1011            expression: Some(Box::new(condition)),
1012            post_join_filter: None, // Not supported in grammar yet
1013            r#type: join_type as i32,
1014            advanced_extension: None,
1015        })
1016    }
1017}
1018
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();
        let emit_kind = &filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        assert_eq!(emit, &[0, 1, 2]);
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // Should have 1 expression (42) and 2 references ($0, $1)
        assert_eq!(project.expressions.len(), 1);

        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
        assert_eq!(emit, &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![Box::new(example_read_relation().into_rel())],
            5, // Assume 5 input fields
        )
        .unwrap();

        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
        assert_eq!(project.expressions.len(), 2);

        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions)
        // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1)
        assert_eq!(emit, &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[($0, $1), _ => sum($2), $0, count($2)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // 2 grouping expressions ($0, $1); 2 groupings (the set ($0, $1) and the
        // empty set `_`); 2 measures (sum($2), count($2)).
        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.groupings[0].expression_references.len(), 2);
        assert_eq!(aggregate.groupings.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [2, 0, 3] (measures and group-by fields in order)
        // sum($2) -> 2, $0 -> 0, count($2) -> 3
        assert_eq!(emit, &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_maintain_column_order() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), $0, count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [1, 0, 2] (grouping fields + measures)
        assert_eq!(emit, &[1, 0, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::aggregate_relation, "Aggregate[$2, $0 => sum($1)]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.groupings.len(), 1);
        // expression_references must be positions [0, 1], not raw field indices [2, 0]
        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation_global_aggregate() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions;

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // Should have 0 group-by fields and 2 measures
        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.groupings.len(), 1);
        assert_eq!(aggregate.groupings[0].expression_references.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        let emit_kind = &aggregate
            .common
            .as_ref()
            .unwrap()
            .emit_kind
            .as_ref()
            .unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1] (measures only, no group-by fields)
        assert_eq!(emit, &[0, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation_grouping_sets() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 11, "count")
            .extensions;

        let read_rel = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64, d:i64]",
            ),
            vec![],
            0,
        )
        .unwrap();

        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[($0, $1, $2), ($2, $0), ($1), _ => $0, $1, $2, count($3)]",
            ),
            vec![Box::new(read_rel.into_rel())],
            4,
        )
        .unwrap();

        assert_eq!(aggregate.grouping_expressions.len(), 3);
        assert_eq!(aggregate.groupings.len(), 4);
        // ($0, $1, $2) -> [0, 1, 2]
        assert_eq!(aggregate.groupings[0].expression_references, vec![0, 1, 2]);
        // ($2, $0) -> [2, 0] (reuses indices, different order)
        assert_eq!(aggregate.groupings[1].expression_references, vec![2, 0]);
        // ($1) -> [1]
        assert_eq!(aggregate.groupings[2].expression_references, vec![1]);
        // _ -> empty
        assert!(aggregate.groupings[3].expression_references.is_empty());
        assert_eq!(aggregate.measures.len(), 1);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        // Test valid positive values should work
        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            vec![Box::new(example_read_relation().into_rel())],
            3,
        )
        .unwrap();

        // Verify the limit and offset values are correct
        assert_eq!(
            fetch_rel.count_mode,
            Some(CountMode::CountExpr(i64_literal_expr(10)))
        );
        assert_eq!(
            fetch_rel.offset_mode,
            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
        );
    }

    /// A negative `limit` must never produce a `FetchRel`.
    ///
    /// Either the grammar rejects it outright (the test then only logs), or the
    /// parse succeeds and `FetchRel` validation must return an error.
    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        let extensions = SimpleExtensions::default();

        // Test that fetch relations with negative limits are properly rejected
        let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
        if let Ok(mut pairs) = parsed_result {
            let pair = pairs.next().unwrap();
            if pair.as_str() == "Fetch[limit=-5 => $0]" {
                // Full parse succeeded, now test that validation catches the negative value
                let result = FetchRel::parse_pair_with_context(
                    &extensions,
                    pair,
                    vec![Box::new(example_read_relation().into_rel())],
                    3,
                );
                assert!(result.is_err());
                let error_msg = result.unwrap_err().to_string();
                assert!(error_msg.contains("Fetch limit must be non-negative"));
            } else {
                // If grammar doesn't fully support negative values, that's also acceptable
                // since it would prevent negative values at parse time
                println!("Grammar prevents negative limit values at parse time");
            }
        } else {
            // Grammar doesn't support negative values in fetch context
            println!("Grammar prevents negative limit values at parse time");
        }
    }

    /// A negative `offset` must never produce a `FetchRel`.
    ///
    /// Either the grammar rejects it outright (the test then only logs), or the
    /// parse succeeds and `OffsetMode` validation must return an error.
    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        let extensions = SimpleExtensions::default();

        // Test that fetch relations with negative offsets are properly rejected
        let parsed_result =
            ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
        if let Ok(mut pairs) = parsed_result {
            let pair = pairs.next().unwrap();
            if pair.as_str() == "Fetch[offset=-10 => $0]" {
                // Full parse succeeded, now test that validation catches the negative value
                let result = FetchRel::parse_pair_with_context(
                    &extensions,
                    pair,
                    vec![Box::new(example_read_relation().into_rel())],
                    3,
                );
                assert!(result.is_err());
                let error_msg = result.unwrap_err().to_string();
                assert!(error_msg.contains("Fetch offset must be non-negative"));
            } else {
                // If grammar doesn't fully support negative values, that's also acceptable
                println!("Grammar prevents negative offset values at parse time");
            }
        } else {
            // Grammar doesn't support negative values in fetch context
            println!("Grammar prevents negative offset values at parse time");
        }
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6, // left (3) + right (3) = 6 total input fields
        )
        .unwrap();

        // Should be an Inner join
        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);

        // Should have left and right relations
        assert!(join.left.is_some());
        assert!(join.right.is_some());

        // Should have a join condition
        assert!(join.expression.is_some());

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1, 3, 4] (selected columns)
        assert_eq!(emit, &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        // Should be a Left join
        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1, 2]
        assert_eq!(emit, &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions;

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        // Should be a LeftSemi join
        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);

        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
        let emit = match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        };
        // Output mapping should be [0, 1] (only left columns for semi join)
        assert_eq!(emit, &[0, 1]);
    }

    /// `JoinRel` must reject any input-child count other than exactly two.
    ///
    /// The default extensions do not register `eq`, so a bare `is_err()` check
    /// would also pass if the condition failed to resolve. Asserting on the
    /// arity message pins the error to the child-count check itself.
    #[test]
    fn test_parse_join_relation_requires_two_children() {
        let extensions = SimpleExtensions::default();
        let input = "Join[&Inner, eq($0, $1) => $0, $1]";

        // (number of input children, total input field count)
        for (child_count, field_count) in [(0usize, 0usize), (1, 3), (3, 9)] {
            let children: Vec<Box<Rel>> = (0..child_count)
                .map(|_| Box::new(example_read_relation().into_rel()))
                .collect();
            let result = JoinRel::parse_pair_with_context(
                &extensions,
                parse_exact(Rule::join_relation, input),
                children,
                field_count,
            );
            let error_msg = match result {
                Ok(_) => panic!("expected an error for {child_count} input children"),
                Err(e) => e.to_string(),
            };
            assert!(
                error_msg.contains("exactly 2 input children"),
                "unexpected error for {child_count} input children: {error_msg}"
            );
        }
    }

    /// Parses `input` with `rule`, asserting the whole input is consumed and
    /// produces exactly one top-level pair, which is returned.
    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }
}