substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use pest::iterators::Pair;
4use substrait::proto::expression::literal::LiteralType;
5use substrait::proto::expression::{Literal, RexType};
6use substrait::proto::fetch_rel::{CountMode, OffsetMode};
7use substrait::proto::rel::RelType;
8use substrait::proto::rel_common::{Emit, EmitKind};
9use substrait::proto::sort_field::{SortDirection, SortKind};
10use substrait::proto::{
11    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
12    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
13};
14
15use super::{
16    ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
17};
18use crate::extensions::any::Any;
19use crate::extensions::registry::ExtensionError;
20use crate::extensions::{ExtensionArgs, ExtensionRegistry, SimpleExtensions};
21use crate::parser::errors::{ParseContext, ParseError};
22use crate::parser::expressions::{FieldIndex, Name};
23
24/// Parsing context for relations that includes extensions, registry, and optional warning collection
25pub struct RelationParsingContext<'a> {
26    pub extensions: &'a SimpleExtensions,
27    pub registry: &'a ExtensionRegistry,
28    pub line_no: i64,
29    pub line: &'a str,
30}
31
32impl<'a> RelationParsingContext<'a> {
33    /// Resolve extension detail using registry. Any failure is treated as a hard parse error.
34    pub fn resolve_extension_detail(
35        &self,
36        extension_name: &str,
37        extension_args: &ExtensionArgs,
38    ) -> Result<Option<Any>, ParseError> {
39        let detail = self
40            .registry
41            .parse_extension(extension_name, extension_args);
42
43        match detail {
44            Ok(any) => Ok(Some(any)),
45            Err(ExtensionError::NotFound { .. }) => Err(ParseError::UnregisteredExtension {
46                name: extension_name.to_string(),
47                context: ParseContext::new(self.line_no, self.line.to_string()),
48            }),
49            Err(err) => Err(ParseError::ExtensionDetail(
50                ParseContext::new(self.line_no, self.line.to_string()),
51                err,
52            )),
53        }
54    }
55}
56
57/// A trait for parsing relations with full context needed for tree building.
58/// This includes extensions, the parsed pair, input children, and output field count.
59pub trait RelationParsePair: Sized {
60    fn rule() -> Rule;
61
62    fn message() -> &'static str;
63
64    /// Parse a relation with full context for tree building.
65    ///
66    /// Args:
67    /// - extensions: The extensions context
68    /// - pair: The parsed pest pair
69    /// - input_children: The input relations (for wiring)
70    /// - input_field_count: Number of output fields from input children (for output mapping)
71    fn parse_pair_with_context(
72        extensions: &SimpleExtensions,
73        pair: Pair<Rule>,
74        input_children: Vec<Box<Rel>>,
75        _input_field_count: usize,
76    ) -> Result<Self, MessageParseError>;
77
78    fn into_rel(self) -> Rel;
79}
80
81pub struct TableName(Vec<String>);
82
83impl ParsePair for TableName {
84    fn rule() -> Rule {
85        Rule::table_name
86    }
87
88    fn message() -> &'static str {
89        "TableName"
90    }
91
92    fn parse_pair(pair: Pair<Rule>) -> Self {
93        assert_eq!(pair.as_rule(), Self::rule());
94        let pairs = pair.into_inner();
95        let mut names = Vec::with_capacity(pairs.len());
96        let mut iter = RuleIter::from(pairs);
97        while let Some(name) = iter.parse_if_next::<Name>() {
98            names.push(name.0);
99        }
100        iter.done();
101        Self(names)
102    }
103}
104
105#[derive(Debug, Clone)]
106pub struct Column {
107    pub name: String,
108    pub typ: Type,
109}
110
111impl ScopedParsePair for Column {
112    fn rule() -> Rule {
113        Rule::named_column
114    }
115
116    fn message() -> &'static str {
117        "Column"
118    }
119
120    fn parse_pair(
121        extensions: &SimpleExtensions,
122        pair: Pair<Rule>,
123    ) -> Result<Self, MessageParseError> {
124        assert_eq!(pair.as_rule(), Self::rule());
125        let mut iter = RuleIter::from(pair.into_inner());
126        let name = iter.parse_next::<Name>().0;
127        let typ = iter.parse_next_scoped(extensions)?;
128        iter.done();
129        Ok(Self { name, typ })
130    }
131}
132
133pub struct NamedColumnList(Vec<Column>);
134
135impl ScopedParsePair for NamedColumnList {
136    fn rule() -> Rule {
137        Rule::named_column_list
138    }
139
140    fn message() -> &'static str {
141        "NamedColumnList"
142    }
143
144    fn parse_pair(
145        extensions: &SimpleExtensions,
146        pair: Pair<Rule>,
147    ) -> Result<Self, MessageParseError> {
148        assert_eq!(pair.as_rule(), Self::rule());
149        let mut columns = Vec::new();
150        for col in pair.into_inner() {
151            columns.push(Column::parse_pair(extensions, col)?);
152        }
153        Ok(Self(columns))
154    }
155}
156
157/// This is a utility function for extracting a single child from the list of
158/// children, to be used in the RelationParsePair trait. The RelationParsePair
159/// trait passes a Vec of children, because some relations have multiple
160/// children - but most accept exactly one child.
161#[allow(clippy::vec_box)]
162pub(crate) fn expect_one_child(
163    message: &'static str,
164    pair: &Pair<Rule>,
165    mut input_children: Vec<Box<Rel>>,
166) -> Result<Box<Rel>, MessageParseError> {
167    match input_children.len() {
168        0 => Err(MessageParseError::invalid(
169            message,
170            pair.as_span(),
171            format!("{message} missing child"),
172        )),
173        1 => Ok(input_children.pop().unwrap()),
174        n => Err(MessageParseError::invalid(
175            message,
176            pair.as_span(),
177            format!("{message} should have 1 input child, got {n}"),
178        )),
179    }
180}
181
182/// Parse a reference list Pair and return an EmitKind::Emit.
183fn parse_reference_emit(pair: Pair<Rule>) -> EmitKind {
184    assert_eq!(pair.as_rule(), Rule::reference_list);
185    let output_mapping = pair
186        .into_inner()
187        .map(|p| FieldIndex::parse_pair(p).0)
188        .collect::<Vec<i32>>();
189    EmitKind::Emit(Emit { output_mapping })
190}
191
192/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
193///
194/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
195///
196/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
197pub struct ParsedNamedArgs<'a> {
198    map: HashMap<&'a str, Pair<'a, Rule>>,
199}
200
201impl<'a> ParsedNamedArgs<'a> {
202    pub fn new(
203        pairs: pest::iterators::Pairs<'a, Rule>,
204        rule: Rule,
205    ) -> Result<Self, MessageParseError> {
206        let mut map = HashMap::new();
207        for pair in pairs {
208            assert_eq!(pair.as_rule(), rule);
209            let mut inner = pair.clone().into_inner();
210            let name_pair = inner.next().unwrap();
211            let value_pair = inner.next().unwrap();
212            assert_eq!(inner.next(), None);
213            let name = name_pair.as_str();
214            if map.contains_key(name) {
215                return Err(MessageParseError::invalid(
216                    "NamedArg",
217                    name_pair.as_span(),
218                    format!("Duplicate argument: {name}"),
219                ));
220            }
221            map.insert(name, value_pair);
222        }
223        Ok(Self { map })
224    }
225
226    // Returns the pair if it exists and matches the rule, otherwise None.
227    // Asserts that the rule must match the rule of the pair (and therefore
228    // panics in non-release-mode if not)
229    pub fn pop(mut self, name: &str, rule: Rule) -> (Self, Option<Pair<'a, Rule>>) {
230        let pair = self.map.remove(name).inspect(|pair| {
231            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
232        });
233        (self, pair)
234    }
235
236    // Returns an error if there are any unused arguments.
237    pub fn done(self) -> Result<(), MessageParseError> {
238        if let Some((name, pair)) = self.map.iter().next() {
239            return Err(MessageParseError::invalid(
240                "NamedArgExtractor",
241                // No span available for all unused args; use default.
242                pair.as_span(),
243                format!("Unknown argument: {name}"),
244            ));
245        }
246        Ok(())
247    }
248}
249
250impl RelationParsePair for ReadRel {
251    fn rule() -> Rule {
252        Rule::read_relation
253    }
254
255    fn message() -> &'static str {
256        "ReadRel"
257    }
258
259    fn into_rel(self) -> Rel {
260        Rel {
261            rel_type: Some(RelType::Read(Box::new(self))),
262        }
263    }
264
265    fn parse_pair_with_context(
266        extensions: &SimpleExtensions,
267        pair: Pair<Rule>,
268        input_children: Vec<Box<Rel>>,
269        _input_field_count: usize,
270    ) -> Result<Self, MessageParseError> {
271        assert_eq!(pair.as_rule(), Self::rule());
272        // ReadRel is a leaf node - it should have no input children and 0 input fields
273        if !input_children.is_empty() {
274            return Err(MessageParseError::invalid(
275                Self::message(),
276                pair.as_span(),
277                "ReadRel should have no input children",
278            ));
279        }
280        if _input_field_count != 0 {
281            let error = pest::error::Error::new_from_span(
282                pest::error::ErrorVariant::CustomError {
283                    message: "ReadRel should have 0 input fields".to_string(),
284                },
285                pair.as_span(),
286            );
287            return Err(MessageParseError::new(
288                "ReadRel",
289                ErrorKind::InvalidValue,
290                Box::new(error),
291            ));
292        }
293
294        let mut iter = RuleIter::from(pair.into_inner());
295        let table = iter.parse_next::<TableName>().0;
296        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
297        iter.done();
298
299        let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
300        let struct_ = r#type::Struct {
301            types,
302            type_variation_reference: 0,
303            nullability: r#type::Nullability::Required as i32,
304        };
305        let named_struct = NamedStruct {
306            names,
307            r#struct: Some(struct_),
308        };
309
310        let read_rel = ReadRel {
311            base_schema: Some(named_struct),
312            read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
313                names: table,
314                advanced_extension: None,
315            })),
316            ..Default::default()
317        };
318
319        Ok(read_rel)
320    }
321}
322
323impl RelationParsePair for FilterRel {
324    fn rule() -> Rule {
325        Rule::filter_relation
326    }
327
328    fn message() -> &'static str {
329        "FilterRel"
330    }
331
332    fn into_rel(self) -> Rel {
333        Rel {
334            rel_type: Some(RelType::Filter(Box::new(self))),
335        }
336    }
337
338    fn parse_pair_with_context(
339        extensions: &SimpleExtensions,
340        pair: Pair<Rule>,
341        input_children: Vec<Box<Rel>>,
342        _input_field_count: usize,
343    ) -> Result<Self, MessageParseError> {
344        // Form: Filter[condition => references]
345
346        assert_eq!(pair.as_rule(), Self::rule());
347        let input = expect_one_child(Self::message(), &pair, input_children)?;
348        let mut iter = RuleIter::from(pair.into_inner());
349        // condition
350        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
351        // references (which become the emit)
352        let references_pair = iter.pop(Rule::reference_list);
353        iter.done();
354
355        let emit = parse_reference_emit(references_pair);
356        let common = RelCommon {
357            emit_kind: Some(emit),
358            ..Default::default()
359        };
360
361        Ok(FilterRel {
362            input: Some(input),
363            condition: Some(Box::new(condition)),
364            common: Some(common),
365            advanced_extension: None,
366        })
367    }
368}
369
370impl RelationParsePair for ProjectRel {
371    fn rule() -> Rule {
372        Rule::project_relation
373    }
374
375    fn message() -> &'static str {
376        "ProjectRel"
377    }
378
379    fn into_rel(self) -> Rel {
380        Rel {
381            rel_type: Some(RelType::Project(Box::new(self))),
382        }
383    }
384
385    fn parse_pair_with_context(
386        extensions: &SimpleExtensions,
387        pair: Pair<Rule>,
388        input_children: Vec<Box<Rel>>,
389        _input_field_count: usize,
390    ) -> Result<Self, MessageParseError> {
391        assert_eq!(pair.as_rule(), Self::rule());
392        let input = expect_one_child(Self::message(), &pair, input_children)?;
393
394        // Get the argument list (contains references and expressions)
395        let arguments_pair = unwrap_single_pair(pair);
396
397        let mut expressions = Vec::new();
398        let mut output_mapping = Vec::new();
399
400        // Process each argument (can be either a reference or expression)
401        for arg in arguments_pair.into_inner() {
402            let inner_arg = unwrap_single_pair(arg);
403            match inner_arg.as_rule() {
404                Rule::reference => {
405                    // Parse reference like "$0" -> 0
406                    let field_index = FieldIndex::parse_pair(inner_arg);
407                    output_mapping.push(field_index.0);
408                }
409                Rule::expression => {
410                    // Parse as expression (e.g., 42, add($0, $1))
411                    let _expr = Expression::parse_pair(extensions, inner_arg)?;
412                    expressions.push(_expr);
413                    // Expression: index after all input fields
414                    output_mapping.push(_input_field_count as i32 + (expressions.len() as i32 - 1));
415                }
416                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
417            }
418        }
419
420        let emit = EmitKind::Emit(Emit { output_mapping });
421        let common = RelCommon {
422            emit_kind: Some(emit),
423            ..Default::default()
424        };
425
426        Ok(ProjectRel {
427            input: Some(input),
428            expressions,
429            common: Some(common),
430            advanced_extension: None,
431        })
432    }
433}
434
435impl RelationParsePair for AggregateRel {
436    fn rule() -> Rule {
437        Rule::aggregate_relation
438    }
439
440    fn message() -> &'static str {
441        "AggregateRel"
442    }
443
444    fn into_rel(self) -> Rel {
445        Rel {
446            rel_type: Some(RelType::Aggregate(Box::new(self))),
447        }
448    }
449
450    fn parse_pair_with_context(
451        extensions: &SimpleExtensions,
452        pair: Pair<Rule>,
453        input_children: Vec<Box<Rel>>,
454        _input_field_count: usize,
455    ) -> Result<Self, MessageParseError> {
456        assert_eq!(pair.as_rule(), Self::rule());
457        let input = expect_one_child(Self::message(), &pair, input_children)?;
458        let mut iter = RuleIter::from(pair.into_inner());
459        let group_by_pair = iter.pop(Rule::aggregate_group_by);
460        let output_pair = iter.pop(Rule::aggregate_output);
461        iter.done();
462        let mut grouping_expressions = Vec::new();
463        for group_by_item in group_by_pair.into_inner() {
464            match group_by_item.as_rule() {
465                Rule::reference => {
466                    let field_index = FieldIndex::parse_pair(group_by_item);
467                    grouping_expressions.push(Expression {
468                        rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
469                            field_index.to_field_reference(),
470                        ))),
471                    });
472                }
473                Rule::empty => {
474                    // No grouping expressions to add
475                }
476                _ => panic!(
477                    "Unexpected group-by item rule: {:?}",
478                    group_by_item.as_rule()
479                ),
480            }
481        }
482
483        // Parse output items (can be references or aggregate measures)
484        let mut measures = Vec::new();
485        let mut output_mapping = Vec::new();
486        let group_by_count = grouping_expressions.len();
487        let mut measure_count = 0;
488
489        for output_item in output_pair.into_inner() {
490            let inner_item = unwrap_single_pair(output_item);
491            match inner_item.as_rule() {
492                Rule::reference => {
493                    let field_index = FieldIndex::parse_pair(inner_item);
494                    output_mapping.push(field_index.0);
495                }
496                Rule::aggregate_measure => {
497                    let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
498                    measures.push(measure);
499                    output_mapping.push(group_by_count as i32 + measure_count);
500                    measure_count += 1;
501                }
502                _ => panic!(
503                    "Unexpected inner output item rule: {:?}",
504                    inner_item.as_rule()
505                ),
506            }
507        }
508
509        let emit = EmitKind::Emit(Emit { output_mapping });
510        let common = RelCommon {
511            emit_kind: Some(emit),
512            ..Default::default()
513        };
514
515        Ok(AggregateRel {
516            input: Some(input),
517            grouping_expressions,
518            groupings: vec![], // TODO: Create groupings from grouping_expressions for complex grouping scenarios
519            measures,
520            common: Some(common),
521            advanced_extension: None,
522        })
523    }
524}
525
526impl ScopedParsePair for SortField {
527    fn rule() -> Rule {
528        Rule::sort_field
529    }
530
531    fn message() -> &'static str {
532        "SortField"
533    }
534
535    fn parse_pair(
536        _extensions: &SimpleExtensions,
537        pair: Pair<Rule>,
538    ) -> Result<Self, MessageParseError> {
539        assert_eq!(pair.as_rule(), Self::rule());
540        let mut iter = RuleIter::from(pair.into_inner());
541        let reference_pair = iter.pop(Rule::reference);
542        let field_index = FieldIndex::parse_pair(reference_pair);
543        let direction_pair = iter.pop(Rule::sort_direction);
544        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
545        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
546        // identifiers, but the enum variant names don't include it
547        let direction = match direction_pair.as_str().trim_start_matches('&') {
548            "AscNullsFirst" => SortDirection::AscNullsFirst,
549            "AscNullsLast" => SortDirection::AscNullsLast,
550            "DescNullsFirst" => SortDirection::DescNullsFirst,
551            "DescNullsLast" => SortDirection::DescNullsLast,
552            other => {
553                return Err(MessageParseError::invalid(
554                    "SortDirection",
555                    direction_pair.as_span(),
556                    format!("Unknown sort direction: {other}"),
557                ));
558            }
559        };
560        iter.done();
561        Ok(SortField {
562            expr: Some(Expression {
563                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
564                    field_index.to_field_reference(),
565                ))),
566            }),
567            // TODO: Add support for SortKind::ComparisonFunctionReference
568            sort_kind: Some(SortKind::Direction(direction as i32)),
569        })
570    }
571}
572
573impl RelationParsePair for SortRel {
574    fn rule() -> Rule {
575        Rule::sort_relation
576    }
577
578    fn message() -> &'static str {
579        "SortRel"
580    }
581
582    fn into_rel(self) -> Rel {
583        Rel {
584            rel_type: Some(RelType::Sort(Box::new(self))),
585        }
586    }
587
588    fn parse_pair_with_context(
589        extensions: &SimpleExtensions,
590        pair: Pair<Rule>,
591        input_children: Vec<Box<Rel>>,
592        _input_field_count: usize,
593    ) -> Result<Self, MessageParseError> {
594        assert_eq!(pair.as_rule(), Self::rule());
595        let input = expect_one_child(Self::message(), &pair, input_children)?;
596        let mut iter = RuleIter::from(pair.into_inner());
597        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
598        let reference_list_pair = iter.pop(Rule::reference_list);
599        let mut sorts = Vec::new();
600        for sort_field_pair in sort_field_list_pair.into_inner() {
601            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
602            sorts.push(sort_field);
603        }
604        let emit = parse_reference_emit(reference_list_pair);
605        let common = RelCommon {
606            emit_kind: Some(emit),
607            ..Default::default()
608        };
609        iter.done();
610        Ok(SortRel {
611            input: Some(input),
612            sorts,
613            common: Some(common),
614            advanced_extension: None,
615        })
616    }
617}
618
619impl ScopedParsePair for CountMode {
620    fn rule() -> Rule {
621        Rule::fetch_value
622    }
623    fn message() -> &'static str {
624        "CountMode"
625    }
626    fn parse_pair(
627        extensions: &SimpleExtensions,
628        pair: Pair<Rule>,
629    ) -> Result<Self, MessageParseError> {
630        assert_eq!(pair.as_rule(), Self::rule());
631        let mut arg_inner = RuleIter::from(pair.into_inner());
632        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
633            int_pair
634        } else {
635            arg_inner.pop(Rule::expression)
636        };
637        match value_pair.as_rule() {
638            Rule::integer => {
639                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
640                    MessageParseError::invalid(
641                        Self::message(),
642                        value_pair.as_span(),
643                        format!("Invalid integer: {e}"),
644                    )
645                })?;
646                if value < 0 {
647                    return Err(MessageParseError::invalid(
648                        Self::message(),
649                        value_pair.as_span(),
650                        format!("Fetch limit must be non-negative, got: {value}"),
651                    ));
652                }
653                Ok(CountMode::CountExpr(i64_literal_expr(value)))
654            }
655            Rule::expression => {
656                let expr = Expression::parse_pair(extensions, value_pair)?;
657                Ok(CountMode::CountExpr(Box::new(expr)))
658            }
659            _ => Err(MessageParseError::invalid(
660                Self::message(),
661                value_pair.as_span(),
662                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
663            )),
664        }
665    }
666}
667
668fn i64_literal_expr(value: i64) -> Box<Expression> {
669    Box::new(Expression {
670        rex_type: Some(RexType::Literal(Literal {
671            nullable: false,
672            type_variation_reference: 0,
673            literal_type: Some(LiteralType::I64(value)),
674        })),
675    })
676}
677
678impl ScopedParsePair for OffsetMode {
679    fn rule() -> Rule {
680        Rule::fetch_value
681    }
682    fn message() -> &'static str {
683        "OffsetMode"
684    }
685    fn parse_pair(
686        extensions: &SimpleExtensions,
687        pair: Pair<Rule>,
688    ) -> Result<Self, MessageParseError> {
689        assert_eq!(pair.as_rule(), Self::rule());
690        let mut arg_inner = RuleIter::from(pair.into_inner());
691        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
692            int_pair
693        } else {
694            arg_inner.pop(Rule::expression)
695        };
696        match value_pair.as_rule() {
697            Rule::integer => {
698                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
699                    MessageParseError::invalid(
700                        Self::message(),
701                        value_pair.as_span(),
702                        format!("Invalid integer: {e}"),
703                    )
704                })?;
705                if value < 0 {
706                    return Err(MessageParseError::invalid(
707                        Self::message(),
708                        value_pair.as_span(),
709                        format!("Fetch offset must be non-negative, got: {value}"),
710                    ));
711                }
712                Ok(OffsetMode::OffsetExpr(i64_literal_expr(value)))
713            }
714            Rule::expression => {
715                let expr = Expression::parse_pair(extensions, value_pair)?;
716                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
717            }
718            _ => Err(MessageParseError::invalid(
719                Self::message(),
720                value_pair.as_span(),
721                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
722            )),
723        }
724    }
725}
726
727impl RelationParsePair for FetchRel {
728    fn rule() -> Rule {
729        Rule::fetch_relation
730    }
731
732    fn message() -> &'static str {
733        "FetchRel"
734    }
735
736    fn into_rel(self) -> Rel {
737        Rel {
738            rel_type: Some(RelType::Fetch(Box::new(self))),
739        }
740    }
741
742    fn parse_pair_with_context(
743        extensions: &SimpleExtensions,
744        pair: Pair<Rule>,
745        input_children: Vec<Box<Rel>>,
746        _input_field_count: usize,
747    ) -> Result<Self, MessageParseError> {
748        assert_eq!(pair.as_rule(), Self::rule());
749        let input = expect_one_child(Self::message(), &pair, input_children)?;
750        let mut iter = RuleIter::from(pair.into_inner());
751
752        // Extract all pairs first, then do validation
753        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
754            None => {
755                // If there are no arguments, it should be empty
756                iter.pop(Rule::empty);
757                (None, None)
758            }
759            Some(fetch_args_pair) => {
760                let extractor =
761                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
762                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
763                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
764                extractor.done()?;
765                (limit_pair, offset_pair)
766            }
767        };
768
769        let reference_list_pair = iter.pop(Rule::reference_list);
770        let emit = parse_reference_emit(reference_list_pair);
771        let common = RelCommon {
772            emit_kind: Some(emit),
773            ..Default::default()
774        };
775        iter.done();
776
777        // Now do validation after iterator is fully consumed
778        let count_mode = limit_pair
779            .map(|pair| CountMode::parse_pair(extensions, pair))
780            .transpose()?;
781        let offset_mode = offset_pair
782            .map(|pair| OffsetMode::parse_pair(extensions, pair))
783            .transpose()?;
784        Ok(FetchRel {
785            input: Some(input),
786            common: Some(common),
787            advanced_extension: None,
788            offset_mode,
789            count_mode,
790        })
791    }
792}
793
794impl ParsePair for join_rel::JoinType {
795    fn rule() -> Rule {
796        Rule::join_type
797    }
798
799    fn message() -> &'static str {
800        "JoinType"
801    }
802
803    fn parse_pair(pair: Pair<Rule>) -> Self {
804        assert_eq!(pair.as_rule(), Self::rule());
805        let join_type_str = pair.as_str().trim_start_matches('&');
806        match join_type_str {
807            "Inner" => join_rel::JoinType::Inner,
808            "Left" => join_rel::JoinType::Left,
809            "Right" => join_rel::JoinType::Right,
810            "Outer" => join_rel::JoinType::Outer,
811            "LeftSemi" => join_rel::JoinType::LeftSemi,
812            "RightSemi" => join_rel::JoinType::RightSemi,
813            "LeftAnti" => join_rel::JoinType::LeftAnti,
814            "RightAnti" => join_rel::JoinType::RightAnti,
815            "LeftSingle" => join_rel::JoinType::LeftSingle,
816            "RightSingle" => join_rel::JoinType::RightSingle,
817            "LeftMark" => join_rel::JoinType::LeftMark,
818            "RightMark" => join_rel::JoinType::RightMark,
819            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
820        }
821    }
822}
823
824impl RelationParsePair for JoinRel {
825    fn rule() -> Rule {
826        Rule::join_relation
827    }
828
829    fn message() -> &'static str {
830        "JoinRel"
831    }
832
833    fn into_rel(self) -> Rel {
834        Rel {
835            rel_type: Some(RelType::Join(Box::new(self))),
836        }
837    }
838
839    fn parse_pair_with_context(
840        extensions: &SimpleExtensions,
841        pair: Pair<Rule>,
842        input_children: Vec<Box<Rel>>,
843        _input_field_count: usize,
844    ) -> Result<Self, MessageParseError> {
845        assert_eq!(pair.as_rule(), Self::rule());
846
847        // Join requires exactly 2 input children
848        if input_children.len() != 2 {
849            return Err(MessageParseError::invalid(
850                Self::message(),
851                pair.as_span(),
852                format!(
853                    "JoinRel should have exactly 2 input children, got {}",
854                    input_children.len()
855                ),
856            ));
857        }
858
859        let mut children_iter = input_children.into_iter();
860        let left = children_iter.next().unwrap();
861        let right = children_iter.next().unwrap();
862
863        let mut iter = RuleIter::from(pair.into_inner());
864
865        // Parse join type
866        let join_type = iter.parse_next::<join_rel::JoinType>();
867
868        // Parse join condition expression
869        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
870
871        // Parse output references (which become the emit)
872        let reference_list_pair = iter.pop(Rule::reference_list);
873        iter.done();
874
875        let emit = parse_reference_emit(reference_list_pair);
876        let common = RelCommon {
877            emit_kind: Some(emit),
878            ..Default::default()
879        };
880
881        Ok(JoinRel {
882            common: Some(common),
883            left: Some(left),
884            right: Some(right),
885            expression: Some(Box::new(condition)),
886            post_join_filter: None, // Not supported in grammar yet
887            r#type: join_type as i32,
888            advanced_extension: None,
889        })
890    }
891}
892
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    /// Parses `input` with `rule`, asserting that the rule consumes the whole
    /// string and yields exactly one top-level pair.
    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    /// A single-element input list holding [`example_read_relation`] (3 columns).
    fn single_read_input() -> Vec<Box<Rel>> {
        vec![Box::new(example_read_relation().into_rel())]
    }

    /// Extracts the emit output mapping from a relation's `common` field.
    ///
    /// Panics if `common` or its `emit_kind` is missing, or if the emit kind is
    /// not `EmitKind::Emit`.
    fn output_mapping(common: Option<&RelCommon>) -> &[i32] {
        let emit_kind = common
            .and_then(|common| common.emit_kind.as_ref())
            .expect("relation should carry RelCommon with an emit_kind");
        match emit_kind {
            EmitKind::Emit(emit) => emit.output_mapping.as_slice(),
            other => panic!("Expected EmitKind::Emit, got {other:?}"),
        }
    }

    /// Extensions declaring `sum` (anchor 10) and `count` (anchor 11).
    fn aggregate_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions
    }

    /// Extensions declaring `eq` (anchor 10).
    fn comparison_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml")
            .with_function(1, 10, "eq")
            .extensions
    }

    /// Asserts that the fetch `input` is never accepted.
    ///
    /// There are two acceptable outcomes. Either the grammar does not match the
    /// whole input, or `FetchRel` validation rejects it with an error
    /// containing `expected_msg`.
    fn assert_fetch_rejected(input: &str, expected_msg: &str) {
        let extensions = SimpleExtensions::default();
        let Ok(mut pairs) = ExpressionParser::parse(Rule::fetch_relation, input) else {
            // The grammar rejects the input outright.
            return;
        };
        let pair = pairs.next().unwrap();
        if pair.as_str() != input {
            // The grammar matched only a prefix, so the full input is rejected.
            return;
        }
        // The full parse succeeded, so validation must catch the negative value.
        let err = FetchRel::parse_pair_with_context(&extensions, pair, single_read_input(), 3)
            .expect_err("negative fetch value should be rejected");
        let msg = err.to_string();
        assert!(msg.contains(expected_msg), "unexpected error: {msg}");
    }

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            single_read_input(),
            3,
        )
        .unwrap();
        assert_eq!(output_mapping(filter.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            single_read_input(),
            3,
        )
        .unwrap();

        // Should have 1 expression (42) and 2 references ($0, $1)
        assert_eq!(project.expressions.len(), 1);

        // References are 0-2; the expression is numbered after them, as 3.
        assert_eq!(output_mapping(project.common.as_ref()), &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            single_read_input(),
            // Deliberately larger than the 3-column child. Expression outputs
            // are numbered after `input_field_count`, so this pins that offset.
            5,
        )
        .unwrap();

        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
        assert_eq!(project.expressions.len(), 2);

        // Direct output: [input_fields..., 42, 100] (inputs first, then expressions).
        // Output mapping [5, 0, 6, 2, 1] yields: 42, $0, 100, $2, $1
        assert_eq!(output_mapping(project.common.as_ref()), &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0, $1 => sum($2), $0, count($2)]",
            ),
            single_read_input(),
            3,
        )
        .unwrap();

        // Should have 2 group-by fields ($0, $1) and 2 measures (sum($2), count($2))
        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);

        // Measures and group-by fields in order:
        // sum($2) -> 2, $0 -> 0, count($2) -> 3
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), count($1)]",
            ),
            single_read_input(),
            3,
        )
        .unwrap();

        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        // Measures only, numbered after the single group-by field.
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[1, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_no_group_by() {
        let extensions = aggregate_extensions();
        // `_` before `=>` denotes an empty group-by list.
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            single_read_input(),
            3,
        )
        .unwrap();

        // Should have 0 group-by fields and 2 measures
        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        // Measures only, no group-by fields
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[0, 1]);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            single_read_input(),
            3,
        )
        .unwrap();

        // The limit and offset are stored as i64 literal expressions.
        assert_eq!(
            fetch_rel.count_mode,
            Some(CountMode::CountExpr(i64_literal_expr(10)))
        );
        assert_eq!(
            fetch_rel.offset_mode,
            Some(OffsetMode::OffsetExpr(i64_literal_expr(5)))
        );
    }

    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        assert_fetch_rejected("Fetch[limit=-5 => $0]", "Fetch limit must be non-negative");
    }

    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        assert_fetch_rejected("Fetch[offset=-10 => $0]", "Fetch offset must be non-negative");
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = comparison_extensions();

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6, // left (3) + right (3) = 6 total input fields
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);
        assert!(join.left.is_some());
        assert!(join.right.is_some());
        assert!(join.expression.is_some());

        // Selected columns from the combined left+right output.
        assert_eq!(output_mapping(join.common.as_ref()), &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = comparison_extensions();

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);
        assert_eq!(output_mapping(join.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = comparison_extensions();

        let left_rel = example_read_relation().into_rel();
        let right_rel = example_read_relation().into_rel();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![Box::new(left_rel), Box::new(right_rel)],
            6,
        )
        .unwrap();

        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);
        // Only left columns for a semi join.
        assert_eq!(output_mapping(join.common.as_ref()), &[0, 1]);
    }

    #[test]
    fn test_parse_join_relation_requires_two_children() {
        // Register `eq` so an unresolved function cannot be what makes the
        // parse fail. Otherwise the 3-children case could pass even without
        // the child-count check.
        let extensions = comparison_extensions();

        for child_count in [0usize, 1, 3] {
            let children: Vec<Box<Rel>> = (0..child_count)
                .map(|_| Box::new(example_read_relation().into_rel()))
                .collect();
            let result = JoinRel::parse_pair_with_context(
                &extensions,
                parse_exact(Rule::join_relation, "Join[&Inner, eq($0, $1) => $0, $1]"),
                children,
                child_count * 3,
            );
            assert!(
                result.is_err(),
                "expected an error for {child_count} input children"
            );
        }
    }
}