substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{Literal, RexType};
5use substrait::proto::fetch_rel::{CountMode, OffsetMode};
6use substrait::proto::rel::RelType;
7use substrait::proto::rel_common::{Emit, EmitKind};
8use substrait::proto::sort_field::{SortDirection, SortKind};
9use substrait::proto::{
10    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
11    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
12};
13
14use super::{
15    ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
16};
17use crate::extensions::SimpleExtensions;
18use crate::parser::expressions::{FieldIndex, Name};
19
20/// A trait for parsing relations with full context needed for tree building.
21/// This includes extensions, the parsed pair, input children, and output field count.
22pub trait RelationParsePair: Sized {
23    fn rule() -> Rule;
24
25    fn message() -> &'static str;
26
27    /// Parse a relation with full context for tree building.
28    ///
29    /// Args:
30    /// - extensions: The extensions context
31    /// - pair: The parsed pest pair
32    /// - input_children: The input relations (for wiring)
33    /// - input_field_count: Number of output fields from input children (for output mapping)
34    fn parse_pair_with_context(
35        extensions: &SimpleExtensions,
36        pair: pest::iterators::Pair<Rule>,
37        input_children: Vec<Box<Rel>>,
38        input_field_count: usize,
39    ) -> Result<Self, MessageParseError>;
40
41    fn into_rel(self) -> Rel;
42}
43
44pub struct TableName(Vec<String>);
45
46impl ParsePair for TableName {
47    fn rule() -> Rule {
48        Rule::table_name
49    }
50
51    fn message() -> &'static str {
52        "TableName"
53    }
54
55    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
56        assert_eq!(pair.as_rule(), Self::rule());
57        let pairs = pair.into_inner();
58        let mut names = Vec::with_capacity(pairs.len());
59        let mut iter = RuleIter::from(pairs);
60        while let Some(name) = iter.parse_if_next::<Name>() {
61            names.push(name.0);
62        }
63        iter.done();
64        Self(names)
65    }
66}
67
68#[derive(Debug, Clone)]
69pub struct Column {
70    pub name: String,
71    pub typ: Type,
72}
73
74impl ScopedParsePair for Column {
75    fn rule() -> Rule {
76        Rule::named_column
77    }
78
79    fn message() -> &'static str {
80        "Column"
81    }
82
83    fn parse_pair(
84        extensions: &SimpleExtensions,
85        pair: pest::iterators::Pair<Rule>,
86    ) -> Result<Self, MessageParseError> {
87        assert_eq!(pair.as_rule(), Self::rule());
88        let mut iter = RuleIter::from(pair.into_inner());
89        let name = iter.parse_next::<Name>().0;
90        let typ = iter.parse_next_scoped(extensions)?;
91        iter.done();
92        Ok(Self { name, typ })
93    }
94}
95
96pub struct NamedColumnList(Vec<Column>);
97
98impl ScopedParsePair for NamedColumnList {
99    fn rule() -> Rule {
100        Rule::named_column_list
101    }
102
103    fn message() -> &'static str {
104        "NamedColumnList"
105    }
106
107    fn parse_pair(
108        extensions: &SimpleExtensions,
109        pair: pest::iterators::Pair<Rule>,
110    ) -> Result<Self, MessageParseError> {
111        assert_eq!(pair.as_rule(), Self::rule());
112        let mut columns = Vec::new();
113        for col in pair.into_inner() {
114            columns.push(Column::parse_pair(extensions, col)?);
115        }
116        Ok(Self(columns))
117    }
118}
119
120/// This is a utility function for extracting a single child from the list of
121/// children, to be used in the RelationParsePair trait. The RelationParsePair
122/// trait passes a Vec of children, because some relations have multiple
123/// children - but most accept exactly one child.
124#[allow(clippy::vec_box)]
125pub(crate) fn expect_one_child(
126    message: &'static str,
127    pair: &pest::iterators::Pair<Rule>,
128    mut input_children: Vec<Box<Rel>>,
129) -> Result<Box<Rel>, MessageParseError> {
130    match input_children.len() {
131        0 => Err(MessageParseError::invalid(
132            message,
133            pair.as_span(),
134            format!("{message} missing child"),
135        )),
136        1 => Ok(input_children.pop().unwrap()),
137        n => Err(MessageParseError::invalid(
138            message,
139            pair.as_span(),
140            format!("{message} should have 1 input child, got {n}"),
141        )),
142    }
143}
144
145/// Parse a reference list Pair and return an EmitKind::Emit.
146fn parse_reference_emit(pair: pest::iterators::Pair<Rule>) -> EmitKind {
147    assert_eq!(pair.as_rule(), Rule::reference_list);
148    let output_mapping = pair
149        .into_inner()
150        .map(|p| FieldIndex::parse_pair(p).0)
151        .collect::<Vec<i32>>();
152    EmitKind::Emit(Emit { output_mapping })
153}
154
155/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
156///
157/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
158///
159/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
160pub struct ParsedNamedArgs<'a> {
161    map: HashMap<&'a str, pest::iterators::Pair<'a, Rule>>,
162}
163
164impl<'a> ParsedNamedArgs<'a> {
165    pub fn new(
166        pairs: pest::iterators::Pairs<'a, Rule>,
167        rule: Rule,
168    ) -> Result<Self, MessageParseError> {
169        let mut map = HashMap::new();
170        for pair in pairs {
171            assert_eq!(pair.as_rule(), rule);
172            let mut inner = pair.clone().into_inner();
173            let name_pair = inner.next().unwrap();
174            let value_pair = inner.next().unwrap();
175            assert_eq!(inner.next(), None);
176            let name = name_pair.as_str();
177            if map.contains_key(name) {
178                return Err(MessageParseError::invalid(
179                    "NamedArg",
180                    name_pair.as_span(),
181                    format!("Duplicate argument: {name}"),
182                ));
183            }
184            map.insert(name, value_pair);
185        }
186        Ok(Self { map })
187    }
188
189    // Returns the pair if it exists and matches the rule, otherwise None.
190    // Asserts that the rule must match the rule of the pair (and therefore
191    // panics in non-release-mode if not)
192    pub fn pop(
193        mut self,
194        name: &str,
195        rule: Rule,
196    ) -> (Self, Option<pest::iterators::Pair<'a, Rule>>) {
197        let pair = self.map.remove(name).inspect(|pair| {
198            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
199        });
200        (self, pair)
201    }
202
203    // Returns an error if there are any unused arguments.
204    pub fn done(self) -> Result<(), MessageParseError> {
205        if let Some((name, pair)) = self.map.iter().next() {
206            return Err(MessageParseError::invalid(
207                "NamedArgExtractor",
208                // No span available for all unused args; use default.
209                pair.as_span(),
210                format!("Unknown argument: {name}"),
211            ));
212        }
213        Ok(())
214    }
215}
216
217impl RelationParsePair for ReadRel {
218    fn rule() -> Rule {
219        Rule::read_relation
220    }
221
222    fn message() -> &'static str {
223        "ReadRel"
224    }
225
226    fn into_rel(self) -> Rel {
227        Rel {
228            rel_type: Some(RelType::Read(Box::new(self))),
229        }
230    }
231
232    fn parse_pair_with_context(
233        extensions: &SimpleExtensions,
234        pair: pest::iterators::Pair<Rule>,
235        input_children: Vec<Box<Rel>>,
236        input_field_count: usize,
237    ) -> Result<Self, MessageParseError> {
238        assert_eq!(pair.as_rule(), Self::rule());
239        // ReadRel is a leaf node - it should have no input children and 0 input fields
240        if !input_children.is_empty() {
241            return Err(MessageParseError::invalid(
242                Self::message(),
243                pair.as_span(),
244                "ReadRel should have no input children",
245            ));
246        }
247        if input_field_count != 0 {
248            let error = pest::error::Error::new_from_span(
249                pest::error::ErrorVariant::CustomError {
250                    message: "ReadRel should have 0 input fields".to_string(),
251                },
252                pair.as_span(),
253            );
254            return Err(MessageParseError::new(
255                "ReadRel",
256                ErrorKind::InvalidValue,
257                Box::new(error),
258            ));
259        }
260
261        let mut iter = RuleIter::from(pair.into_inner());
262        let table = iter.parse_next::<TableName>().0;
263        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
264        iter.done();
265
266        let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
267        let struct_ = r#type::Struct {
268            types,
269            type_variation_reference: 0,
270            nullability: r#type::Nullability::Required as i32,
271        };
272        let named_struct = NamedStruct {
273            names,
274            r#struct: Some(struct_),
275        };
276
277        let read_rel = ReadRel {
278            base_schema: Some(named_struct),
279            read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
280                names: table,
281                advanced_extension: None,
282            })),
283            ..Default::default()
284        };
285
286        Ok(read_rel)
287    }
288}
289
290impl RelationParsePair for FilterRel {
291    fn rule() -> Rule {
292        Rule::filter_relation
293    }
294
295    fn message() -> &'static str {
296        "FilterRel"
297    }
298
299    fn into_rel(self) -> Rel {
300        Rel {
301            rel_type: Some(RelType::Filter(Box::new(self))),
302        }
303    }
304
305    fn parse_pair_with_context(
306        extensions: &SimpleExtensions,
307        pair: pest::iterators::Pair<Rule>,
308        input_children: Vec<Box<Rel>>,
309        _input_field_count: usize,
310    ) -> Result<Self, MessageParseError> {
311        // Form: Filter[condition => references]
312
313        assert_eq!(pair.as_rule(), Self::rule());
314        let input = expect_one_child(Self::message(), &pair, input_children)?;
315        let mut iter = RuleIter::from(pair.into_inner());
316        // condition
317        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
318        // references (which become the emit)
319        let references_pair = iter.pop(Rule::reference_list);
320        iter.done();
321
322        let emit = parse_reference_emit(references_pair);
323        let common = RelCommon {
324            emit_kind: Some(emit),
325            ..Default::default()
326        };
327
328        Ok(FilterRel {
329            input: Some(input),
330            condition: Some(Box::new(condition)),
331            common: Some(common),
332            advanced_extension: None,
333        })
334    }
335}
336
337impl RelationParsePair for ProjectRel {
338    fn rule() -> Rule {
339        Rule::project_relation
340    }
341
342    fn message() -> &'static str {
343        "ProjectRel"
344    }
345
346    fn into_rel(self) -> Rel {
347        Rel {
348            rel_type: Some(RelType::Project(Box::new(self))),
349        }
350    }
351
352    fn parse_pair_with_context(
353        extensions: &SimpleExtensions,
354        pair: pest::iterators::Pair<Rule>,
355        input_children: Vec<Box<Rel>>,
356        input_field_count: usize,
357    ) -> Result<Self, MessageParseError> {
358        assert_eq!(pair.as_rule(), Self::rule());
359        let input = expect_one_child(Self::message(), &pair, input_children)?;
360
361        // Get the argument list (contains references and expressions)
362        let arguments_pair = unwrap_single_pair(pair);
363
364        let mut expressions = Vec::new();
365        let mut output_mapping = Vec::new();
366
367        // Process each argument (can be either a reference or expression)
368        for arg in arguments_pair.into_inner() {
369            let inner_arg = crate::parser::unwrap_single_pair(arg);
370            match inner_arg.as_rule() {
371                Rule::reference => {
372                    // Parse reference like "$0" -> 0
373                    let field_index = FieldIndex::parse_pair(inner_arg);
374                    output_mapping.push(field_index.0);
375                }
376                Rule::expression => {
377                    // Parse as expression (e.g., 42, add($0, $1))
378                    let _expr = Expression::parse_pair(extensions, inner_arg)?;
379                    expressions.push(_expr);
380                    // Expression: index after all input fields
381                    output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
382                }
383                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
384            }
385        }
386
387        let emit = EmitKind::Emit(Emit { output_mapping });
388        let common = RelCommon {
389            emit_kind: Some(emit),
390            ..Default::default()
391        };
392
393        Ok(ProjectRel {
394            input: Some(input),
395            expressions,
396            common: Some(common),
397            advanced_extension: None,
398        })
399    }
400}
401
402impl RelationParsePair for AggregateRel {
403    fn rule() -> Rule {
404        Rule::aggregate_relation
405    }
406
407    fn message() -> &'static str {
408        "AggregateRel"
409    }
410
411    fn into_rel(self) -> Rel {
412        Rel {
413            rel_type: Some(RelType::Aggregate(Box::new(self))),
414        }
415    }
416
417    fn parse_pair_with_context(
418        extensions: &SimpleExtensions,
419        pair: pest::iterators::Pair<Rule>,
420        input_children: Vec<Box<Rel>>,
421        _input_field_count: usize,
422    ) -> Result<Self, MessageParseError> {
423        assert_eq!(pair.as_rule(), Self::rule());
424        let input = expect_one_child(Self::message(), &pair, input_children)?;
425        let mut iter = RuleIter::from(pair.into_inner());
426        let group_by_pair = iter.pop(Rule::aggregate_group_by);
427        let output_pair = iter.pop(Rule::aggregate_output);
428        iter.done();
429        let mut grouping_expressions = Vec::new();
430        for group_by_item in group_by_pair.into_inner() {
431            match group_by_item.as_rule() {
432                Rule::reference => {
433                    let field_index = FieldIndex::parse_pair(group_by_item);
434                    grouping_expressions.push(Expression {
435                        rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
436                            field_index.to_field_reference(),
437                        ))),
438                    });
439                }
440                Rule::empty => {
441                    // No grouping expressions to add
442                }
443                _ => panic!(
444                    "Unexpected group-by item rule: {:?}",
445                    group_by_item.as_rule()
446                ),
447            }
448        }
449
450        // Parse output items (can be references or aggregate measures)
451        let mut measures = Vec::new();
452        let mut output_mapping = Vec::new();
453        let group_by_count = grouping_expressions.len();
454        let mut measure_count = 0;
455
456        for output_item in output_pair.into_inner() {
457            let inner_item = unwrap_single_pair(output_item);
458            match inner_item.as_rule() {
459                Rule::reference => {
460                    let field_index = FieldIndex::parse_pair(inner_item);
461                    output_mapping.push(field_index.0);
462                }
463                Rule::aggregate_measure => {
464                    let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
465                    measures.push(measure);
466                    output_mapping.push(group_by_count as i32 + measure_count);
467                    measure_count += 1;
468                }
469                _ => panic!(
470                    "Unexpected inner output item rule: {:?}",
471                    inner_item.as_rule()
472                ),
473            }
474        }
475
476        let emit = EmitKind::Emit(Emit { output_mapping });
477        let common = RelCommon {
478            emit_kind: Some(emit),
479            ..Default::default()
480        };
481
482        Ok(AggregateRel {
483            input: Some(input),
484            grouping_expressions,
485            groupings: vec![], // TODO: Create groupings from grouping_expressions for complex grouping scenarios
486            measures,
487            common: Some(common),
488            advanced_extension: None,
489        })
490    }
491}
492
493impl ScopedParsePair for SortField {
494    fn rule() -> Rule {
495        Rule::sort_field
496    }
497
498    fn message() -> &'static str {
499        "SortField"
500    }
501
502    fn parse_pair(
503        _extensions: &SimpleExtensions,
504        pair: pest::iterators::Pair<Rule>,
505    ) -> Result<Self, MessageParseError> {
506        assert_eq!(pair.as_rule(), Self::rule());
507        let mut iter = RuleIter::from(pair.into_inner());
508        let reference_pair = iter.pop(Rule::reference);
509        let field_index = FieldIndex::parse_pair(reference_pair);
510        let direction_pair = iter.pop(Rule::sort_direction);
511        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
512        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
513        // identifiers, but the enum variant names don't include it
514        let direction = match direction_pair.as_str().trim_start_matches('&') {
515            "AscNullsFirst" => SortDirection::AscNullsFirst,
516            "AscNullsLast" => SortDirection::AscNullsLast,
517            "DescNullsFirst" => SortDirection::DescNullsFirst,
518            "DescNullsLast" => SortDirection::DescNullsLast,
519            other => {
520                return Err(MessageParseError::invalid(
521                    "SortDirection",
522                    direction_pair.as_span(),
523                    format!("Unknown sort direction: {other}"),
524                ));
525            }
526        };
527        iter.done();
528        Ok(SortField {
529            expr: Some(Expression {
530                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
531                    field_index.to_field_reference(),
532                ))),
533            }),
534            // TODO: Add support for SortKind::ComparisonFunctionReference
535            sort_kind: Some(SortKind::Direction(direction as i32)),
536        })
537    }
538}
539
540impl RelationParsePair for SortRel {
541    fn rule() -> Rule {
542        Rule::sort_relation
543    }
544
545    fn message() -> &'static str {
546        "SortRel"
547    }
548
549    fn into_rel(self) -> Rel {
550        Rel {
551            rel_type: Some(RelType::Sort(Box::new(self))),
552        }
553    }
554
555    fn parse_pair_with_context(
556        extensions: &SimpleExtensions,
557        pair: pest::iterators::Pair<Rule>,
558        input_children: Vec<Box<Rel>>,
559        _input_field_count: usize,
560    ) -> Result<Self, MessageParseError> {
561        assert_eq!(pair.as_rule(), Self::rule());
562        let input = expect_one_child(Self::message(), &pair, input_children)?;
563        let mut iter = RuleIter::from(pair.into_inner());
564        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
565        let reference_list_pair = iter.pop(Rule::reference_list);
566        let mut sorts = Vec::new();
567        for sort_field_pair in sort_field_list_pair.into_inner() {
568            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
569            sorts.push(sort_field);
570        }
571        let emit = parse_reference_emit(reference_list_pair);
572        let common = RelCommon {
573            emit_kind: Some(emit),
574            ..Default::default()
575        };
576        iter.done();
577        Ok(SortRel {
578            input: Some(input),
579            sorts,
580            common: Some(common),
581            advanced_extension: None,
582        })
583    }
584}
585
586impl ScopedParsePair for CountMode {
587    fn rule() -> Rule {
588        Rule::fetch_value
589    }
590    fn message() -> &'static str {
591        "CountMode"
592    }
593    fn parse_pair(
594        extensions: &SimpleExtensions,
595        pair: pest::iterators::Pair<Rule>,
596    ) -> Result<Self, MessageParseError> {
597        assert_eq!(pair.as_rule(), Self::rule());
598        let mut arg_inner = RuleIter::from(pair.into_inner());
599        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
600            int_pair
601        } else {
602            arg_inner.pop(Rule::expression)
603        };
604        match value_pair.as_rule() {
605            Rule::integer => {
606                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
607                    MessageParseError::invalid(
608                        Self::message(),
609                        value_pair.as_span(),
610                        format!("Invalid integer: {e}"),
611                    )
612                })?;
613                if value < 0 {
614                    return Err(MessageParseError::invalid(
615                        Self::message(),
616                        value_pair.as_span(),
617                        format!("Fetch limit must be non-negative, got: {value}"),
618                    ));
619                }
620                Ok(CountMode::CountExpr(i64_literal_expr(value)))
621            }
622            Rule::expression => {
623                let expr = Expression::parse_pair(extensions, value_pair)?;
624                Ok(CountMode::CountExpr(Box::new(expr)))
625            }
626            _ => Err(MessageParseError::invalid(
627                Self::message(),
628                value_pair.as_span(),
629                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
630            )),
631        }
632    }
633}
634
635fn i64_literal_expr(value: i64) -> Box<Expression> {
636    Box::new(Expression {
637        rex_type: Some(RexType::Literal(Literal {
638            nullable: false,
639            type_variation_reference: 0,
640            literal_type: Some(LiteralType::I64(value)),
641        })),
642    })
643}
644
645impl ScopedParsePair for OffsetMode {
646    fn rule() -> Rule {
647        Rule::fetch_value
648    }
649    fn message() -> &'static str {
650        "OffsetMode"
651    }
652    fn parse_pair(
653        extensions: &SimpleExtensions,
654        pair: pest::iterators::Pair<Rule>,
655    ) -> Result<Self, MessageParseError> {
656        assert_eq!(pair.as_rule(), Self::rule());
657        let mut arg_inner = RuleIter::from(pair.into_inner());
658        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
659            int_pair
660        } else {
661            arg_inner.pop(Rule::expression)
662        };
663        match value_pair.as_rule() {
664            Rule::integer => {
665                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
666                    MessageParseError::invalid(
667                        Self::message(),
668                        value_pair.as_span(),
669                        format!("Invalid integer: {e}"),
670                    )
671                })?;
672                if value < 0 {
673                    return Err(MessageParseError::invalid(
674                        Self::message(),
675                        value_pair.as_span(),
676                        format!("Fetch offset must be non-negative, got: {value}"),
677                    ));
678                }
679                Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
680            }
681            Rule::expression => {
682                let expr = Expression::parse_pair(extensions, value_pair)?;
683                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
684            }
685            _ => Err(MessageParseError::invalid(
686                Self::message(),
687                value_pair.as_span(),
688                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
689            )),
690        }
691    }
692}
693
694impl RelationParsePair for FetchRel {
695    fn rule() -> Rule {
696        Rule::fetch_relation
697    }
698
699    fn message() -> &'static str {
700        "FetchRel"
701    }
702
703    fn into_rel(self) -> Rel {
704        Rel {
705            rel_type: Some(RelType::Fetch(Box::new(self))),
706        }
707    }
708
709    fn parse_pair_with_context(
710        extensions: &SimpleExtensions,
711        pair: pest::iterators::Pair<Rule>,
712        input_children: Vec<Box<Rel>>,
713        _input_field_count: usize,
714    ) -> Result<Self, MessageParseError> {
715        assert_eq!(pair.as_rule(), Self::rule());
716        let input = expect_one_child(Self::message(), &pair, input_children)?;
717        let mut iter = RuleIter::from(pair.into_inner());
718
719        // Extract all pairs first, then do validation
720        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
721            None => {
722                // If there are no arguments, it should be empty
723                iter.pop(Rule::empty);
724                (None, None)
725            }
726            Some(fetch_args_pair) => {
727                let extractor =
728                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
729                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
730                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
731                extractor.done()?;
732                (limit_pair, offset_pair)
733            }
734        };
735
736        let reference_list_pair = iter.pop(Rule::reference_list);
737        let emit = parse_reference_emit(reference_list_pair);
738        let common = RelCommon {
739            emit_kind: Some(emit),
740            ..Default::default()
741        };
742        iter.done();
743
744        // Now do validation after iterator is fully consumed
745        let count_mode = limit_pair
746            .map(|pair| CountMode::parse_pair(extensions, pair))
747            .transpose()?;
748        let offset_mode = offset_pair
749            .map(|pair| OffsetMode::parse_pair(extensions, pair))
750            .transpose()?;
751        Ok(FetchRel {
752            input: Some(input),
753            common: Some(common),
754            advanced_extension: None,
755            offset_mode,
756            count_mode,
757        })
758    }
759}
760
761impl ParsePair for join_rel::JoinType {
762    fn rule() -> Rule {
763        Rule::join_type
764    }
765
766    fn message() -> &'static str {
767        "JoinType"
768    }
769
770    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
771        assert_eq!(pair.as_rule(), Self::rule());
772        let join_type_str = pair.as_str().trim_start_matches('&');
773        match join_type_str {
774            "Inner" => join_rel::JoinType::Inner,
775            "Left" => join_rel::JoinType::Left,
776            "Right" => join_rel::JoinType::Right,
777            "Outer" => join_rel::JoinType::Outer,
778            "LeftSemi" => join_rel::JoinType::LeftSemi,
779            "RightSemi" => join_rel::JoinType::RightSemi,
780            "LeftAnti" => join_rel::JoinType::LeftAnti,
781            "RightAnti" => join_rel::JoinType::RightAnti,
782            "LeftSingle" => join_rel::JoinType::LeftSingle,
783            "RightSingle" => join_rel::JoinType::RightSingle,
784            "LeftMark" => join_rel::JoinType::LeftMark,
785            "RightMark" => join_rel::JoinType::RightMark,
786            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
787        }
788    }
789}
790
791impl RelationParsePair for JoinRel {
792    fn rule() -> Rule {
793        Rule::join_relation
794    }
795
796    fn message() -> &'static str {
797        "JoinRel"
798    }
799
800    fn into_rel(self) -> Rel {
801        Rel {
802            rel_type: Some(RelType::Join(Box::new(self))),
803        }
804    }
805
806    fn parse_pair_with_context(
807        extensions: &SimpleExtensions,
808        pair: pest::iterators::Pair<Rule>,
809        input_children: Vec<Box<Rel>>,
810        _input_field_count: usize,
811    ) -> Result<Self, MessageParseError> {
812        assert_eq!(pair.as_rule(), Self::rule());
813
814        // Join requires exactly 2 input children
815        if input_children.len() != 2 {
816            return Err(MessageParseError::invalid(
817                Self::message(),
818                pair.as_span(),
819                format!(
820                    "JoinRel should have exactly 2 input children, got {}",
821                    input_children.len()
822                ),
823            ));
824        }
825
826        let mut children_iter = input_children.into_iter();
827        let left = children_iter.next().unwrap();
828        let right = children_iter.next().unwrap();
829
830        let mut iter = RuleIter::from(pair.into_inner());
831
832        // Parse join type
833        let join_type = iter.parse_next::<join_rel::JoinType>();
834
835        // Parse join condition expression
836        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
837
838        // Parse output references (which become the emit)
839        let reference_list_pair = iter.pop(Rule::reference_list);
840        iter.done();
841
842        let emit = parse_reference_emit(reference_list_pair);
843        let common = RelCommon {
844            emit_kind: Some(emit),
845            ..Default::default()
846        };
847
848        Ok(JoinRel {
849            common: Some(common),
850            left: Some(left),
851            right: Some(right),
852            expression: Some(Box::new(condition)),
853            post_join_filter: None, // Not supported in grammar yet
854            r#type: join_type as i32,
855            advanced_extension: None,
856        })
857    }
858}
859
860#[cfg(test)]
861mod tests {
862    use pest::Parser;
863
864    use super::*;
865    use crate::fixtures::TestContext;
866    use crate::parser::{ExpressionParser, Rule};
867
    // Intentionally empty placeholder; it asserts nothing and always passes.
    #[test]
    fn test_parse_relation() {
        // Removed: test_parse_relation for old Relation struct
    }
872
873    #[test]
874    fn test_parse_read_relation() {
875        let extensions = SimpleExtensions::default();
876        let read = ReadRel::parse_pair_with_context(
877            &extensions,
878            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
879            vec![],
880            0,
881        )
882        .unwrap();
883        let names = match &read.read_type {
884            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
885            _ => panic!("Expected NamedTable"),
886        };
887        assert_eq!(names, &["ab", "cd", "ef"]);
888        let columns = &read
889            .base_schema
890            .as_ref()
891            .unwrap()
892            .r#struct
893            .as_ref()
894            .unwrap()
895            .types;
896        assert_eq!(columns.len(), 2);
897    }
898
899    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
900    fn example_read_relation() -> ReadRel {
901        let extensions = SimpleExtensions::default();
902        ReadRel::parse_pair_with_context(
903            &extensions,
904            parse_exact(
905                Rule::read_relation,
906                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
907            ),
908            vec![],
909            0,
910        )
911        .unwrap()
912    }
913
914    #[test]
915    fn test_parse_filter_relation() {
916        let extensions = SimpleExtensions::default();
917        let filter = FilterRel::parse_pair_with_context(
918            &extensions,
919            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
920            vec![Box::new(example_read_relation().into_rel())],
921            3,
922        )
923        .unwrap();
924        let emit_kind = &filter.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
925        let emit = match emit_kind {
926            EmitKind::Emit(emit) => &emit.output_mapping,
927            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
928        };
929        assert_eq!(emit, &[0, 1, 2]);
930    }
931
932    #[test]
933    fn test_parse_project_relation() {
934        let extensions = SimpleExtensions::default();
935        let project = ProjectRel::parse_pair_with_context(
936            &extensions,
937            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
938            vec![Box::new(example_read_relation().into_rel())],
939            3,
940        )
941        .unwrap();
942
943        // Should have 1 expression (42) and 2 references ($0, $1)
944        assert_eq!(project.expressions.len(), 1);
945
946        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
947        let emit = match emit_kind {
948            EmitKind::Emit(emit) => &emit.output_mapping,
949            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
950        };
951        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
952        assert_eq!(emit, &[0, 1, 3]);
953    }
954
955    #[test]
956    fn test_parse_project_relation_complex() {
957        let extensions = SimpleExtensions::default();
958        let project = ProjectRel::parse_pair_with_context(
959            &extensions,
960            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
961            vec![Box::new(example_read_relation().into_rel())],
962            5, // Assume 5 input fields
963        )
964        .unwrap();
965
966        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
967        assert_eq!(project.expressions.len(), 2);
968
969        let emit_kind = &project.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
970        let emit = match emit_kind {
971            EmitKind::Emit(emit) => &emit.output_mapping,
972            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
973        };
974        // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions)
975        // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1)
976        assert_eq!(emit, &[5, 0, 6, 2, 1]);
977    }
978
979    #[test]
980    fn test_parse_aggregate_relation() {
981        let extensions = TestContext::new()
982            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
983            .with_function(1, 10, "sum")
984            .with_function(1, 11, "count")
985            .extensions;
986
987        let aggregate = AggregateRel::parse_pair_with_context(
988            &extensions,
989            parse_exact(
990                Rule::aggregate_relation,
991                "Aggregate[$0, $1 => sum($2), $0, count($2)]",
992            ),
993            vec![Box::new(example_read_relation().into_rel())],
994            3,
995        )
996        .unwrap();
997
998        // Should have 2 group-by fields ($0, $1) and 2 measures (sum($2), count($2))
999        assert_eq!(aggregate.grouping_expressions.len(), 2);
1000        assert_eq!(aggregate.measures.len(), 2);
1001
1002        let emit_kind = &aggregate
1003            .common
1004            .as_ref()
1005            .unwrap()
1006            .emit_kind
1007            .as_ref()
1008            .unwrap();
1009        let emit = match emit_kind {
1010            EmitKind::Emit(emit) => &emit.output_mapping,
1011            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1012        };
1013        // Output mapping should be [2, 0, 3] (measures and group-by fields in order)
1014        // sum($2) -> 2, $0 -> 0, count($2) -> 3
1015        assert_eq!(emit, &[2, 0, 3]);
1016    }
1017
1018    #[test]
1019    fn test_parse_aggregate_relation_simple() {
1020        let extensions = TestContext::new()
1021            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1022            .with_function(1, 10, "sum")
1023            .with_function(1, 11, "count")
1024            .extensions;
1025
1026        let aggregate = AggregateRel::parse_pair_with_context(
1027            &extensions,
1028            parse_exact(
1029                Rule::aggregate_relation,
1030                "Aggregate[$0 => sum($1), count($1)]",
1031            ),
1032            vec![Box::new(example_read_relation().into_rel())],
1033            3,
1034        )
1035        .unwrap();
1036
1037        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
1038        assert_eq!(aggregate.grouping_expressions.len(), 1);
1039        assert_eq!(aggregate.measures.len(), 2);
1040
1041        let emit_kind = &aggregate
1042            .common
1043            .as_ref()
1044            .unwrap()
1045            .emit_kind
1046            .as_ref()
1047            .unwrap();
1048        let emit = match emit_kind {
1049            EmitKind::Emit(emit) => &emit.output_mapping,
1050            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1051        };
1052        // Output mapping should be [1, 2] (measures only)
1053        assert_eq!(emit, &[1, 2]);
1054    }
1055
1056    #[test]
1057    fn test_parse_aggregate_relation_no_group_by() {
1058        let extensions = TestContext::new()
1059            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1060            .with_function(1, 10, "sum")
1061            .with_function(1, 11, "count")
1062            .extensions;
1063
1064        let aggregate = AggregateRel::parse_pair_with_context(
1065            &extensions,
1066            parse_exact(
1067                Rule::aggregate_relation,
1068                "Aggregate[_ => sum($0), count($1)]",
1069            ),
1070            vec![Box::new(example_read_relation().into_rel())],
1071            3,
1072        )
1073        .unwrap();
1074
1075        // Should have 0 group-by fields and 2 measures
1076        assert_eq!(aggregate.grouping_expressions.len(), 0);
1077        assert_eq!(aggregate.measures.len(), 2);
1078
1079        let emit_kind = &aggregate
1080            .common
1081            .as_ref()
1082            .unwrap()
1083            .emit_kind
1084            .as_ref()
1085            .unwrap();
1086        let emit = match emit_kind {
1087            EmitKind::Emit(emit) => &emit.output_mapping,
1088            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1089        };
1090        // Output mapping should be [0, 1] (measures only, no group-by fields)
1091        assert_eq!(emit, &[0, 1]);
1092    }
1093
1094    #[test]
1095    fn test_parse_aggregate_relation_empty_group_by() {
1096        let extensions = TestContext::new()
1097            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1098            .with_function(1, 10, "sum")
1099            .with_function(1, 11, "count")
1100            .extensions;
1101
1102        let aggregate = AggregateRel::parse_pair_with_context(
1103            &extensions,
1104            parse_exact(
1105                Rule::aggregate_relation,
1106                "Aggregate[_ => sum($0), count($1)]",
1107            ),
1108            vec![Box::new(example_read_relation().into_rel())],
1109            3,
1110        )
1111        .unwrap();
1112
1113        // Should have 0 group-by fields and 2 measures
1114        assert_eq!(aggregate.grouping_expressions.len(), 0);
1115        assert_eq!(aggregate.measures.len(), 2);
1116
1117        let emit_kind = &aggregate
1118            .common
1119            .as_ref()
1120            .unwrap()
1121            .emit_kind
1122            .as_ref()
1123            .unwrap();
1124        let emit = match emit_kind {
1125            EmitKind::Emit(emit) => &emit.output_mapping,
1126            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1127        };
1128        // Output mapping should be [0, 1] (measures only, no group-by fields)
1129        assert_eq!(emit, &[0, 1]);
1130    }
1131
1132    #[test]
1133    fn test_fetch_relation_positive_values() {
1134        let extensions = SimpleExtensions::default();
1135
1136        // Test valid positive values should work
1137        let fetch_rel = FetchRel::parse_pair_with_context(
1138            &extensions,
1139            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
1140            vec![Box::new(example_read_relation().into_rel())],
1141            3,
1142        )
1143        .unwrap();
1144
1145        // Verify the limit and offset values are correct
1146        assert_eq!(
1147            fetch_rel.count_mode,
1148            Some(CountMode::CountExpr(i64_literal_expr(10)))
1149        );
1150        assert_eq!(
1151            fetch_rel.offset_mode,
1152            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
1153        );
1154    }
1155
1156    #[test]
1157    fn test_fetch_relation_negative_limit_rejected() {
1158        let extensions = SimpleExtensions::default();
1159
1160        // Test that fetch relations with negative limits are properly rejected
1161        let parsed_result = ExpressionParser::parse(Rule::fetch_relation, "Fetch[limit=-5 => $0]");
1162        if let Ok(mut pairs) = parsed_result {
1163            let pair = pairs.next().unwrap();
1164            if pair.as_str() == "Fetch[limit=-5 => $0]" {
1165                // Full parse succeeded, now test that validation catches the negative value
1166                let result = FetchRel::parse_pair_with_context(
1167                    &extensions,
1168                    pair,
1169                    vec![Box::new(example_read_relation().into_rel())],
1170                    3,
1171                );
1172                assert!(result.is_err());
1173                let error_msg = result.unwrap_err().to_string();
1174                assert!(error_msg.contains("Fetch limit must be non-negative"));
1175            } else {
1176                // If grammar doesn't fully support negative values, that's also acceptable
1177                // since it would prevent negative values at parse time
1178                println!("Grammar prevents negative limit values at parse time");
1179            }
1180        } else {
1181            // Grammar doesn't support negative values in fetch context
1182            println!("Grammar prevents negative limit values at parse time");
1183        }
1184    }
1185
1186    #[test]
1187    fn test_fetch_relation_negative_offset_rejected() {
1188        let extensions = SimpleExtensions::default();
1189
1190        // Test that fetch relations with negative offsets are properly rejected
1191        let parsed_result =
1192            ExpressionParser::parse(Rule::fetch_relation, "Fetch[offset=-10 => $0]");
1193        if let Ok(mut pairs) = parsed_result {
1194            let pair = pairs.next().unwrap();
1195            if pair.as_str() == "Fetch[offset=-10 => $0]" {
1196                // Full parse succeeded, now test that validation catches the negative value
1197                let result = FetchRel::parse_pair_with_context(
1198                    &extensions,
1199                    pair,
1200                    vec![Box::new(example_read_relation().into_rel())],
1201                    3,
1202                );
1203                assert!(result.is_err());
1204                let error_msg = result.unwrap_err().to_string();
1205                assert!(error_msg.contains("Fetch offset must be non-negative"));
1206            } else {
1207                // If grammar doesn't fully support negative values, that's also acceptable
1208                println!("Grammar prevents negative offset values at parse time");
1209            }
1210        } else {
1211            // Grammar doesn't support negative values in fetch context
1212            println!("Grammar prevents negative offset values at parse time");
1213        }
1214    }
1215
1216    #[test]
1217    fn test_parse_join_relation() {
1218        let extensions = TestContext::new()
1219            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1220            .with_function(1, 10, "eq")
1221            .extensions;
1222
1223        let left_rel = example_read_relation().into_rel();
1224        let right_rel = example_read_relation().into_rel();
1225
1226        let join = JoinRel::parse_pair_with_context(
1227            &extensions,
1228            parse_exact(
1229                Rule::join_relation,
1230                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
1231            ),
1232            vec![Box::new(left_rel), Box::new(right_rel)],
1233            6, // left (3) + right (3) = 6 total input fields
1234        )
1235        .unwrap();
1236
1237        // Should be an Inner join
1238        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
1239
1240        // Should have left and right relations
1241        assert!(join.left.is_some());
1242        assert!(join.right.is_some());
1243
1244        // Should have a join condition
1245        assert!(join.expression.is_some());
1246
1247        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1248        let emit = match emit_kind {
1249            EmitKind::Emit(emit) => &emit.output_mapping,
1250            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1251        };
1252        // Output mapping should be [0, 1, 3, 4] (selected columns)
1253        assert_eq!(emit, &[0, 1, 3, 4]);
1254    }
1255
1256    #[test]
1257    fn test_parse_join_relation_left_outer() {
1258        let extensions = TestContext::new()
1259            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1260            .with_function(1, 10, "eq")
1261            .extensions;
1262
1263        let left_rel = example_read_relation().into_rel();
1264        let right_rel = example_read_relation().into_rel();
1265
1266        let join = JoinRel::parse_pair_with_context(
1267            &extensions,
1268            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
1269            vec![Box::new(left_rel), Box::new(right_rel)],
1270            6,
1271        )
1272        .unwrap();
1273
1274        // Should be a Left join
1275        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
1276
1277        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1278        let emit = match emit_kind {
1279            EmitKind::Emit(emit) => &emit.output_mapping,
1280            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1281        };
1282        // Output mapping should be [0, 1, 2]
1283        assert_eq!(emit, &[0, 1, 2]);
1284    }
1285
1286    #[test]
1287    fn test_parse_join_relation_left_semi() {
1288        let extensions = TestContext::new()
1289            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
1290            .with_function(1, 10, "eq")
1291            .extensions;
1292
1293        let left_rel = example_read_relation().into_rel();
1294        let right_rel = example_read_relation().into_rel();
1295
1296        let join = JoinRel::parse_pair_with_context(
1297            &extensions,
1298            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
1299            vec![Box::new(left_rel), Box::new(right_rel)],
1300            6,
1301        )
1302        .unwrap();
1303
1304        // Should be a LeftSemi join
1305        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
1306
1307        let emit_kind = &join.common.as_ref().unwrap().emit_kind.as_ref().unwrap();
1308        let emit = match emit_kind {
1309            EmitKind::Emit(emit) => &emit.output_mapping,
1310            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
1311        };
1312        // Output mapping should be [0, 1] (only left columns for semi join)
1313        assert_eq!(emit, &[0, 1]);
1314    }
1315
1316    #[test]
1317    fn test_parse_join_relation_requires_two_children() {
1318        let extensions = SimpleExtensions::default();
1319
1320        // Test with 0 children
1321        let result = JoinRel::parse_pair_with_context(
1322            &extensions,
1323            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1324            vec![],
1325            0,
1326        );
1327        assert!(result.is_err());
1328
1329        // Test with 1 child
1330        let result = JoinRel::parse_pair_with_context(
1331            &extensions,
1332            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1333            vec![Box::new(example_read_relation().into_rel())],
1334            3,
1335        );
1336        assert!(result.is_err());
1337
1338        // Test with 3 children
1339        let result = JoinRel::parse_pair_with_context(
1340            &extensions,
1341            parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
1342            vec![
1343                Box::new(example_read_relation().into_rel()),
1344                Box::new(example_read_relation().into_rel()),
1345                Box::new(example_read_relation().into_rel()),
1346            ],
1347            9,
1348        );
1349        assert!(result.is_err());
1350    }
1351
1352    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
1353        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
1354        assert_eq!(pairs.as_str(), input);
1355        let pair = pairs.next().unwrap();
1356        assert_eq!(pairs.next(), None);
1357        pair
1358    }
1359}