substrait_explain/parser/
relations.rs

1use std::collections::HashMap;
2
3use substrait::proto::fetch_rel::{CountMode, OffsetMode};
4use substrait::proto::rel::RelType;
5use substrait::proto::rel_common::{Emit, EmitKind};
6use substrait::proto::sort_field::{SortDirection, SortKind};
7use substrait::proto::{
8    AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct, ProjectRel, ReadRel, Rel,
9    RelCommon, SortField, SortRel, Type, aggregate_rel, join_rel, read_rel, r#type,
10};
11
12use super::{
13    ErrorKind, MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unwrap_single_pair,
14};
15use crate::extensions::SimpleExtensions;
16use crate::parser::expressions::{FieldIndex, Name};
17
18/// A trait for parsing relations with full context needed for tree building.
19/// This includes extensions, the parsed pair, input children, and output field count.
20pub trait RelationParsePair: Sized {
21    fn rule() -> Rule;
22
23    fn message() -> &'static str;
24
25    /// Parse a relation with full context for tree building.
26    ///
27    /// Args:
28    /// - extensions: The extensions context
29    /// - pair: The parsed pest pair
30    /// - input_children: The input relations (for wiring)
31    /// - input_field_count: Number of output fields from input children (for output mapping)
32    fn parse_pair_with_context(
33        extensions: &SimpleExtensions,
34        pair: pest::iterators::Pair<Rule>,
35        input_children: Vec<Box<Rel>>,
36        input_field_count: usize,
37    ) -> Result<Self, MessageParseError>;
38
39    fn into_rel(self) -> Rel;
40}
41
42pub struct TableName(Vec<String>);
43
44impl ParsePair for TableName {
45    fn rule() -> Rule {
46        Rule::table_name
47    }
48
49    fn message() -> &'static str {
50        "TableName"
51    }
52
53    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
54        assert_eq!(pair.as_rule(), Self::rule());
55        let pairs = pair.into_inner();
56        let mut names = Vec::with_capacity(pairs.len());
57        let mut iter = RuleIter::from(pairs);
58        while let Some(name) = iter.parse_if_next::<Name>() {
59            names.push(name.0);
60        }
61        iter.done();
62        Self(names)
63    }
64}
65
66#[derive(Debug, Clone)]
67pub struct Column {
68    pub name: String,
69    pub typ: Type,
70}
71
72impl ScopedParsePair for Column {
73    fn rule() -> Rule {
74        Rule::named_column
75    }
76
77    fn message() -> &'static str {
78        "Column"
79    }
80
81    fn parse_pair(
82        extensions: &SimpleExtensions,
83        pair: pest::iterators::Pair<Rule>,
84    ) -> Result<Self, MessageParseError> {
85        assert_eq!(pair.as_rule(), Self::rule());
86        let mut iter = RuleIter::from(pair.into_inner());
87        let name = iter.parse_next::<Name>().0;
88        let typ = iter.parse_next_scoped(extensions)?;
89        iter.done();
90        Ok(Self { name, typ })
91    }
92}
93
94pub struct NamedColumnList(Vec<Column>);
95
96impl ScopedParsePair for NamedColumnList {
97    fn rule() -> Rule {
98        Rule::named_column_list
99    }
100
101    fn message() -> &'static str {
102        "NamedColumnList"
103    }
104
105    fn parse_pair(
106        extensions: &SimpleExtensions,
107        pair: pest::iterators::Pair<Rule>,
108    ) -> Result<Self, MessageParseError> {
109        assert_eq!(pair.as_rule(), Self::rule());
110        let mut columns = Vec::new();
111        for col in pair.into_inner() {
112            columns.push(Column::parse_pair(extensions, col)?);
113        }
114        Ok(Self(columns))
115    }
116}
117
118/// This is a utility function for extracting a single child from the list of
119/// children, to be used in the RelationParsePair trait. The RelationParsePair
120/// trait passes a Vec of children, because some relations have multiple
121/// children - but most accept exactly one child.
122#[allow(clippy::vec_box)]
123pub(crate) fn expect_one_child(
124    message: &'static str,
125    pair: &pest::iterators::Pair<Rule>,
126    mut input_children: Vec<Box<Rel>>,
127) -> Result<Box<Rel>, MessageParseError> {
128    match input_children.len() {
129        0 => Err(MessageParseError::invalid(
130            message,
131            pair.as_span(),
132            format!("{message} missing child"),
133        )),
134        1 => Ok(input_children.pop().unwrap()),
135        n => Err(MessageParseError::invalid(
136            message,
137            pair.as_span(),
138            format!("{message} should have 1 input child, got {n}"),
139        )),
140    }
141}
142
143/// Parse a reference list Pair and return an EmitKind::Emit.
144fn parse_reference_emit(pair: pest::iterators::Pair<Rule>) -> EmitKind {
145    assert_eq!(pair.as_rule(), Rule::reference_list);
146    let output_mapping = pair
147        .into_inner()
148        .map(|p| FieldIndex::parse_pair(p).0)
149        .collect::<Vec<i32>>();
150    EmitKind::Emit(Emit { output_mapping })
151}
152
153/// Extracts named arguments from pest pairs with duplicate detection and completeness checking.
154///
155/// Usage: `extractor.pop("limit", Rule::fetch_value).0.pop("offset", Rule::fetch_value).0.done()`
156///
157/// The fluent API ensures all arguments are processed exactly once and none are forgotten.
158pub struct ParsedNamedArgs<'a> {
159    map: HashMap<&'a str, pest::iterators::Pair<'a, Rule>>,
160}
161
162impl<'a> ParsedNamedArgs<'a> {
163    pub fn new(
164        pairs: pest::iterators::Pairs<'a, Rule>,
165        rule: Rule,
166    ) -> Result<Self, MessageParseError> {
167        let mut map = HashMap::new();
168        for pair in pairs {
169            assert_eq!(pair.as_rule(), rule);
170            let mut inner = pair.clone().into_inner();
171            let name_pair = inner.next().unwrap();
172            let value_pair = inner.next().unwrap();
173            assert_eq!(inner.next(), None);
174            let name = name_pair.as_str();
175            if map.contains_key(name) {
176                return Err(MessageParseError::invalid(
177                    "NamedArg",
178                    name_pair.as_span(),
179                    format!("Duplicate argument: {name}"),
180                ));
181            }
182            map.insert(name, value_pair);
183        }
184        Ok(Self { map })
185    }
186
187    // Returns the pair if it exists and matches the rule, otherwise None.
188    // Asserts that the rule must match the rule of the pair (and therefore
189    // panics in non-release-mode if not)
190    pub fn pop(
191        mut self,
192        name: &str,
193        rule: Rule,
194    ) -> (Self, Option<pest::iterators::Pair<'a, Rule>>) {
195        let pair = self.map.remove(name).inspect(|pair| {
196            assert_eq!(pair.as_rule(), rule, "Rule mismatch for argument {name}");
197        });
198        (self, pair)
199    }
200
201    // Returns an error if there are any unused arguments.
202    pub fn done(self) -> Result<(), MessageParseError> {
203        if let Some((name, pair)) = self.map.iter().next() {
204            return Err(MessageParseError::invalid(
205                "NamedArgExtractor",
206                // No span available for all unused args; use default.
207                pair.as_span(),
208                format!("Unknown argument: {name}"),
209            ));
210        }
211        Ok(())
212    }
213}
214
215impl RelationParsePair for ReadRel {
216    fn rule() -> Rule {
217        Rule::read_relation
218    }
219
220    fn message() -> &'static str {
221        "ReadRel"
222    }
223
224    fn into_rel(self) -> Rel {
225        Rel {
226            rel_type: Some(RelType::Read(Box::new(self))),
227        }
228    }
229
230    fn parse_pair_with_context(
231        extensions: &SimpleExtensions,
232        pair: pest::iterators::Pair<Rule>,
233        input_children: Vec<Box<Rel>>,
234        input_field_count: usize,
235    ) -> Result<Self, MessageParseError> {
236        assert_eq!(pair.as_rule(), Self::rule());
237        // ReadRel is a leaf node - it should have no input children and 0 input fields
238        if !input_children.is_empty() {
239            return Err(MessageParseError::invalid(
240                Self::message(),
241                pair.as_span(),
242                "ReadRel should have no input children",
243            ));
244        }
245        if input_field_count != 0 {
246            let error = pest::error::Error::new_from_span(
247                pest::error::ErrorVariant::CustomError {
248                    message: "ReadRel should have 0 input fields".to_string(),
249                },
250                pair.as_span(),
251            );
252            return Err(MessageParseError::new(
253                "ReadRel",
254                ErrorKind::InvalidValue,
255                Box::new(error),
256            ));
257        }
258
259        let mut iter = RuleIter::from(pair.into_inner());
260        let table = iter.parse_next::<TableName>().0;
261        let columns = iter.parse_next_scoped::<NamedColumnList>(extensions)?.0;
262        iter.done();
263
264        let (names, types): (Vec<_>, Vec<_>) = columns.into_iter().map(|c| (c.name, c.typ)).unzip();
265        let struct_ = r#type::Struct {
266            types,
267            type_variation_reference: 0,
268            nullability: r#type::Nullability::Required as i32,
269        };
270        let named_struct = NamedStruct {
271            names,
272            r#struct: Some(struct_),
273        };
274
275        let read_rel = ReadRel {
276            base_schema: Some(named_struct),
277            read_type: Some(read_rel::ReadType::NamedTable(read_rel::NamedTable {
278                names: table,
279                advanced_extension: None,
280            })),
281            ..Default::default()
282        };
283
284        Ok(read_rel)
285    }
286}
287
288impl RelationParsePair for FilterRel {
289    fn rule() -> Rule {
290        Rule::filter_relation
291    }
292
293    fn message() -> &'static str {
294        "FilterRel"
295    }
296
297    fn into_rel(self) -> Rel {
298        Rel {
299            rel_type: Some(RelType::Filter(Box::new(self))),
300        }
301    }
302
303    fn parse_pair_with_context(
304        extensions: &SimpleExtensions,
305        pair: pest::iterators::Pair<Rule>,
306        input_children: Vec<Box<Rel>>,
307        _input_field_count: usize,
308    ) -> Result<Self, MessageParseError> {
309        // Form: Filter[condition => references]
310
311        assert_eq!(pair.as_rule(), Self::rule());
312        let input = expect_one_child(Self::message(), &pair, input_children)?;
313        let mut iter = RuleIter::from(pair.into_inner());
314        // condition
315        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
316        // references (which become the emit)
317        let references_pair = iter.pop(Rule::reference_list);
318        iter.done();
319
320        let emit = parse_reference_emit(references_pair);
321        let common = RelCommon {
322            emit_kind: Some(emit),
323            ..Default::default()
324        };
325
326        Ok(FilterRel {
327            input: Some(input),
328            condition: Some(Box::new(condition)),
329            common: Some(common),
330            advanced_extension: None,
331        })
332    }
333}
334
335impl RelationParsePair for ProjectRel {
336    fn rule() -> Rule {
337        Rule::project_relation
338    }
339
340    fn message() -> &'static str {
341        "ProjectRel"
342    }
343
344    fn into_rel(self) -> Rel {
345        Rel {
346            rel_type: Some(RelType::Project(Box::new(self))),
347        }
348    }
349
350    fn parse_pair_with_context(
351        extensions: &SimpleExtensions,
352        pair: pest::iterators::Pair<Rule>,
353        input_children: Vec<Box<Rel>>,
354        input_field_count: usize,
355    ) -> Result<Self, MessageParseError> {
356        assert_eq!(pair.as_rule(), Self::rule());
357        let input = expect_one_child(Self::message(), &pair, input_children)?;
358
359        // Get the argument list (contains references and expressions)
360        let arguments_pair = unwrap_single_pair(pair);
361
362        let mut expressions = Vec::new();
363        let mut output_mapping = Vec::new();
364
365        // Process each argument (can be either a reference or expression)
366        for arg in arguments_pair.into_inner() {
367            let inner_arg = crate::parser::unwrap_single_pair(arg);
368            match inner_arg.as_rule() {
369                Rule::reference => {
370                    // Parse reference like "$0" -> 0
371                    let field_index = FieldIndex::parse_pair(inner_arg);
372                    output_mapping.push(field_index.0);
373                }
374                Rule::expression => {
375                    // Parse as expression (e.g., 42, add($0, $1))
376                    let _expr = Expression::parse_pair(extensions, inner_arg)?;
377                    expressions.push(_expr);
378                    // Expression: index after all input fields
379                    output_mapping.push(input_field_count as i32 + (expressions.len() as i32 - 1));
380                }
381                _ => panic!("Unexpected inner argument rule: {:?}", inner_arg.as_rule()),
382            }
383        }
384
385        let emit = EmitKind::Emit(Emit { output_mapping });
386        let common = RelCommon {
387            emit_kind: Some(emit),
388            ..Default::default()
389        };
390
391        Ok(ProjectRel {
392            input: Some(input),
393            expressions,
394            common: Some(common),
395            advanced_extension: None,
396        })
397    }
398}
399
400impl RelationParsePair for AggregateRel {
401    fn rule() -> Rule {
402        Rule::aggregate_relation
403    }
404
405    fn message() -> &'static str {
406        "AggregateRel"
407    }
408
409    fn into_rel(self) -> Rel {
410        Rel {
411            rel_type: Some(RelType::Aggregate(Box::new(self))),
412        }
413    }
414
415    fn parse_pair_with_context(
416        extensions: &SimpleExtensions,
417        pair: pest::iterators::Pair<Rule>,
418        input_children: Vec<Box<Rel>>,
419        _input_field_count: usize,
420    ) -> Result<Self, MessageParseError> {
421        assert_eq!(pair.as_rule(), Self::rule());
422        let input = expect_one_child(Self::message(), &pair, input_children)?;
423        let mut iter = RuleIter::from(pair.into_inner());
424        let group_by_pair = iter.pop(Rule::aggregate_group_by);
425        let output_pair = iter.pop(Rule::aggregate_output);
426        iter.done();
427        let mut grouping_expressions = Vec::new();
428        for group_by_item in group_by_pair.into_inner() {
429            match group_by_item.as_rule() {
430                Rule::reference => {
431                    let field_index = FieldIndex::parse_pair(group_by_item);
432                    grouping_expressions.push(Expression {
433                        rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
434                            field_index.to_field_reference(),
435                        ))),
436                    });
437                }
438                Rule::empty => {
439                    // No grouping expressions to add
440                }
441                _ => panic!(
442                    "Unexpected group-by item rule: {:?}",
443                    group_by_item.as_rule()
444                ),
445            }
446        }
447
448        // Parse output items (can be references or aggregate measures)
449        let mut measures = Vec::new();
450        let mut output_mapping = Vec::new();
451        let group_by_count = grouping_expressions.len();
452        let mut measure_count = 0;
453
454        for output_item in output_pair.into_inner() {
455            let inner_item = unwrap_single_pair(output_item);
456            match inner_item.as_rule() {
457                Rule::reference => {
458                    let field_index = FieldIndex::parse_pair(inner_item);
459                    output_mapping.push(field_index.0);
460                }
461                Rule::aggregate_measure => {
462                    let measure = aggregate_rel::Measure::parse_pair(extensions, inner_item)?;
463                    measures.push(measure);
464                    output_mapping.push(group_by_count as i32 + measure_count);
465                    measure_count += 1;
466                }
467                _ => panic!(
468                    "Unexpected inner output item rule: {:?}",
469                    inner_item.as_rule()
470                ),
471            }
472        }
473
474        let emit = EmitKind::Emit(Emit { output_mapping });
475        let common = RelCommon {
476            emit_kind: Some(emit),
477            ..Default::default()
478        };
479
480        Ok(AggregateRel {
481            input: Some(input),
482            grouping_expressions,
483            groupings: vec![], // TODO: Create groupings from grouping_expressions for complex grouping scenarios
484            measures,
485            common: Some(common),
486            advanced_extension: None,
487        })
488    }
489}
490
491impl ScopedParsePair for SortField {
492    fn rule() -> Rule {
493        Rule::sort_field
494    }
495
496    fn message() -> &'static str {
497        "SortField"
498    }
499
500    fn parse_pair(
501        _extensions: &SimpleExtensions,
502        pair: pest::iterators::Pair<Rule>,
503    ) -> Result<Self, MessageParseError> {
504        assert_eq!(pair.as_rule(), Self::rule());
505        let mut iter = RuleIter::from(pair.into_inner());
506        let reference_pair = iter.pop(Rule::reference);
507        let field_index = FieldIndex::parse_pair(reference_pair);
508        let direction_pair = iter.pop(Rule::sort_direction);
509        // Strip the '&' prefix from enum syntax (e.g., "&AscNullsFirst" ->
510        // "AscNullsFirst") The grammar includes '&' to distinguish enums from
511        // identifiers, but the enum variant names don't include it
512        let direction = match direction_pair.as_str().trim_start_matches('&') {
513            "AscNullsFirst" => SortDirection::AscNullsFirst,
514            "AscNullsLast" => SortDirection::AscNullsLast,
515            "DescNullsFirst" => SortDirection::DescNullsFirst,
516            "DescNullsLast" => SortDirection::DescNullsLast,
517            other => {
518                return Err(MessageParseError::invalid(
519                    "SortDirection",
520                    direction_pair.as_span(),
521                    format!("Unknown sort direction: {other}"),
522                ));
523            }
524        };
525        iter.done();
526        Ok(SortField {
527            expr: Some(Expression {
528                rex_type: Some(substrait::proto::expression::RexType::Selection(Box::new(
529                    field_index.to_field_reference(),
530                ))),
531            }),
532            // TODO: Add support for SortKind::ComparisonFunctionReference
533            sort_kind: Some(SortKind::Direction(direction as i32)),
534        })
535    }
536}
537
538impl RelationParsePair for SortRel {
539    fn rule() -> Rule {
540        Rule::sort_relation
541    }
542
543    fn message() -> &'static str {
544        "SortRel"
545    }
546
547    fn into_rel(self) -> Rel {
548        Rel {
549            rel_type: Some(RelType::Sort(Box::new(self))),
550        }
551    }
552
553    fn parse_pair_with_context(
554        extensions: &SimpleExtensions,
555        pair: pest::iterators::Pair<Rule>,
556        input_children: Vec<Box<Rel>>,
557        _input_field_count: usize,
558    ) -> Result<Self, MessageParseError> {
559        assert_eq!(pair.as_rule(), Self::rule());
560        let input = expect_one_child(Self::message(), &pair, input_children)?;
561        let mut iter = RuleIter::from(pair.into_inner());
562        let sort_field_list_pair = iter.pop(Rule::sort_field_list);
563        let reference_list_pair = iter.pop(Rule::reference_list);
564        let mut sorts = Vec::new();
565        for sort_field_pair in sort_field_list_pair.into_inner() {
566            let sort_field = SortField::parse_pair(extensions, sort_field_pair)?;
567            sorts.push(sort_field);
568        }
569        let emit = parse_reference_emit(reference_list_pair);
570        let common = RelCommon {
571            emit_kind: Some(emit),
572            ..Default::default()
573        };
574        iter.done();
575        Ok(SortRel {
576            input: Some(input),
577            sorts,
578            common: Some(common),
579            advanced_extension: None,
580        })
581    }
582}
583
584impl ScopedParsePair for CountMode {
585    fn rule() -> Rule {
586        Rule::fetch_value
587    }
588    fn message() -> &'static str {
589        "CountMode"
590    }
591    fn parse_pair(
592        extensions: &SimpleExtensions,
593        pair: pest::iterators::Pair<Rule>,
594    ) -> Result<Self, MessageParseError> {
595        assert_eq!(pair.as_rule(), Self::rule());
596        let mut arg_inner = RuleIter::from(pair.into_inner());
597        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
598            int_pair
599        } else {
600            arg_inner.pop(Rule::expression)
601        };
602        match value_pair.as_rule() {
603            Rule::integer => {
604                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
605                    MessageParseError::invalid(
606                        Self::message(),
607                        value_pair.as_span(),
608                        format!("Invalid integer: {e}"),
609                    )
610                })?;
611                if value < 0 {
612                    return Err(MessageParseError::invalid(
613                        Self::message(),
614                        value_pair.as_span(),
615                        format!("Fetch limit must be non-negative, got: {value}"),
616                    ));
617                }
618                Ok(CountMode::Count(value))
619            }
620            Rule::expression => {
621                let expr = Expression::parse_pair(extensions, value_pair)?;
622                Ok(CountMode::CountExpr(Box::new(expr)))
623            }
624            _ => Err(MessageParseError::invalid(
625                Self::message(),
626                value_pair.as_span(),
627                format!("Unexpected rule for CountMode: {:?}", value_pair.as_rule()),
628            )),
629        }
630    }
631}
632
633impl ScopedParsePair for OffsetMode {
634    fn rule() -> Rule {
635        Rule::fetch_value
636    }
637    fn message() -> &'static str {
638        "OffsetMode"
639    }
640    fn parse_pair(
641        extensions: &SimpleExtensions,
642        pair: pest::iterators::Pair<Rule>,
643    ) -> Result<Self, MessageParseError> {
644        assert_eq!(pair.as_rule(), Self::rule());
645        let mut arg_inner = RuleIter::from(pair.into_inner());
646        let value_pair = if let Some(int_pair) = arg_inner.try_pop(Rule::integer) {
647            int_pair
648        } else {
649            arg_inner.pop(Rule::expression)
650        };
651        match value_pair.as_rule() {
652            Rule::integer => {
653                let value = value_pair.as_str().parse::<i64>().map_err(|e| {
654                    MessageParseError::invalid(
655                        Self::message(),
656                        value_pair.as_span(),
657                        format!("Invalid integer: {e}"),
658                    )
659                })?;
660                if value < 0 {
661                    return Err(MessageParseError::invalid(
662                        Self::message(),
663                        value_pair.as_span(),
664                        format!("Fetch offset must be non-negative, got: {value}"),
665                    ));
666                }
667                Ok(OffsetMode::Offset(value))
668            }
669            Rule::expression => {
670                let expr = Expression::parse_pair(extensions, value_pair)?;
671                Ok(OffsetMode::OffsetExpr(Box::new(expr)))
672            }
673            _ => Err(MessageParseError::invalid(
674                Self::message(),
675                value_pair.as_span(),
676                format!("Unexpected rule for OffsetMode: {:?}", value_pair.as_rule()),
677            )),
678        }
679    }
680}
681
682impl RelationParsePair for FetchRel {
683    fn rule() -> Rule {
684        Rule::fetch_relation
685    }
686
687    fn message() -> &'static str {
688        "FetchRel"
689    }
690
691    fn into_rel(self) -> Rel {
692        Rel {
693            rel_type: Some(RelType::Fetch(Box::new(self))),
694        }
695    }
696
697    fn parse_pair_with_context(
698        extensions: &SimpleExtensions,
699        pair: pest::iterators::Pair<Rule>,
700        input_children: Vec<Box<Rel>>,
701        _input_field_count: usize,
702    ) -> Result<Self, MessageParseError> {
703        assert_eq!(pair.as_rule(), Self::rule());
704        let input = expect_one_child(Self::message(), &pair, input_children)?;
705        let mut iter = RuleIter::from(pair.into_inner());
706
707        // Extract all pairs first, then do validation
708        let (limit_pair, offset_pair) = match iter.try_pop(Rule::fetch_named_arg_list) {
709            None => {
710                // If there are no arguments, it should be empty
711                iter.pop(Rule::empty);
712                (None, None)
713            }
714            Some(fetch_args_pair) => {
715                let extractor =
716                    ParsedNamedArgs::new(fetch_args_pair.into_inner(), Rule::fetch_named_arg)?;
717                let (extractor, limit_pair) = extractor.pop("limit", Rule::fetch_value);
718                let (extractor, offset_pair) = extractor.pop("offset", Rule::fetch_value);
719                extractor.done()?;
720                (limit_pair, offset_pair)
721            }
722        };
723
724        let reference_list_pair = iter.pop(Rule::reference_list);
725        let emit = parse_reference_emit(reference_list_pair);
726        let common = RelCommon {
727            emit_kind: Some(emit),
728            ..Default::default()
729        };
730        iter.done();
731
732        // Now do validation after iterator is fully consumed
733        let count_mode = limit_pair
734            .map(|pair| CountMode::parse_pair(extensions, pair))
735            .transpose()?;
736        let offset_mode = offset_pair
737            .map(|pair| OffsetMode::parse_pair(extensions, pair))
738            .transpose()?;
739        Ok(FetchRel {
740            input: Some(input),
741            common: Some(common),
742            advanced_extension: None,
743            offset_mode,
744            count_mode,
745        })
746    }
747}
748
749impl ParsePair for join_rel::JoinType {
750    fn rule() -> Rule {
751        Rule::join_type
752    }
753
754    fn message() -> &'static str {
755        "JoinType"
756    }
757
758    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
759        assert_eq!(pair.as_rule(), Self::rule());
760        let join_type_str = pair.as_str().trim_start_matches('&');
761        match join_type_str {
762            "Inner" => join_rel::JoinType::Inner,
763            "Left" => join_rel::JoinType::Left,
764            "Right" => join_rel::JoinType::Right,
765            "Outer" => join_rel::JoinType::Outer,
766            "LeftSemi" => join_rel::JoinType::LeftSemi,
767            "RightSemi" => join_rel::JoinType::RightSemi,
768            "LeftAnti" => join_rel::JoinType::LeftAnti,
769            "RightAnti" => join_rel::JoinType::RightAnti,
770            "LeftSingle" => join_rel::JoinType::LeftSingle,
771            "RightSingle" => join_rel::JoinType::RightSingle,
772            "LeftMark" => join_rel::JoinType::LeftMark,
773            "RightMark" => join_rel::JoinType::RightMark,
774            _ => panic!("Unknown join type: {join_type_str} (this should be caught by grammar)"),
775        }
776    }
777}
778
779impl RelationParsePair for JoinRel {
780    fn rule() -> Rule {
781        Rule::join_relation
782    }
783
784    fn message() -> &'static str {
785        "JoinRel"
786    }
787
788    fn into_rel(self) -> Rel {
789        Rel {
790            rel_type: Some(RelType::Join(Box::new(self))),
791        }
792    }
793
794    fn parse_pair_with_context(
795        extensions: &SimpleExtensions,
796        pair: pest::iterators::Pair<Rule>,
797        input_children: Vec<Box<Rel>>,
798        _input_field_count: usize,
799    ) -> Result<Self, MessageParseError> {
800        assert_eq!(pair.as_rule(), Self::rule());
801
802        // Join requires exactly 2 input children
803        if input_children.len() != 2 {
804            return Err(MessageParseError::invalid(
805                Self::message(),
806                pair.as_span(),
807                format!(
808                    "JoinRel should have exactly 2 input children, got {}",
809                    input_children.len()
810                ),
811            ));
812        }
813
814        let mut children_iter = input_children.into_iter();
815        let left = children_iter.next().unwrap();
816        let right = children_iter.next().unwrap();
817
818        let mut iter = RuleIter::from(pair.into_inner());
819
820        // Parse join type
821        let join_type = iter.parse_next::<join_rel::JoinType>();
822
823        // Parse join condition expression
824        let condition = iter.parse_next_scoped::<Expression>(extensions)?;
825
826        // Parse output references (which become the emit)
827        let reference_list_pair = iter.pop(Rule::reference_list);
828        iter.done();
829
830        let emit = parse_reference_emit(reference_list_pair);
831        let common = RelCommon {
832            emit_kind: Some(emit),
833            ..Default::default()
834        };
835
836        Ok(JoinRel {
837            common: Some(common),
838            left: Some(left),
839            right: Some(right),
840            expression: Some(Box::new(condition)),
841            post_join_filter: None, // Not supported in grammar yet
842            r#type: join_type as i32,
843            advanced_extension: None,
844        })
845    }
846}
847
#[cfg(test)]
mod tests {
    use pest::Parser;

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::{ExpressionParser, Rule};

    const AGGREGATE_FUNCTIONS_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml";
    const COMPARISON_FUNCTIONS_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml";

    /// Extensions registering `sum` (anchor 10) and `count` (anchor 11) for the aggregate tests.
    fn aggregate_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_uri(1, AGGREGATE_FUNCTIONS_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count")
            .extensions
    }

    /// Extensions registering `eq` (anchor 10) for the join tests.
    fn comparison_extensions() -> SimpleExtensions {
        TestContext::new()
            .with_uri(1, COMPARISON_FUNCTIONS_URI)
            .with_function(1, 10, "eq")
            .extensions
    }

    /// Returns the `output_mapping` of a relation's `RelCommon`.
    ///
    /// Panics with a descriptive message if the common block or its emit kind is
    /// missing, or if the emit kind is not `EmitKind::Emit`.
    fn output_mapping(common: Option<&RelCommon>) -> &[i32] {
        let emit_kind = common
            .expect("relation should have a RelCommon")
            .emit_kind
            .as_ref()
            .expect("RelCommon should have an emit_kind");
        match emit_kind {
            EmitKind::Emit(emit) => &emit.output_mapping,
            _ => panic!("Expected EmitKind::Emit, got {emit_kind:?}"),
        }
    }

    /// Produces a ReadRel with 3 columns: a:i32, b:string?, c:i64
    fn example_read_relation() -> ReadRel {
        let extensions = SimpleExtensions::default();
        ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::read_relation,
                "Read[ab.cd.ef => a:i32, b:string?, c:i64]",
            ),
            vec![],
            0,
        )
        .unwrap()
    }

    /// The 3-column example read relation, boxed for use as an input child.
    fn read_child() -> Box<Rel> {
        Box::new(example_read_relation().into_rel())
    }

    #[test]
    fn test_parse_read_relation() {
        let extensions = SimpleExtensions::default();
        let read = ReadRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::read_relation, "Read[ab.cd.ef => a:i32, b:string?]"),
            vec![],
            0,
        )
        .unwrap();
        let names = match &read.read_type {
            Some(read_rel::ReadType::NamedTable(table)) => &table.names,
            _ => panic!("Expected NamedTable"),
        };
        assert_eq!(names, &["ab", "cd", "ef"]);
        let columns = &read
            .base_schema
            .as_ref()
            .unwrap()
            .r#struct
            .as_ref()
            .unwrap()
            .types;
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_parse_filter_relation() {
        let extensions = SimpleExtensions::default();
        let filter = FilterRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::filter_relation, "Filter[$1 => $0, $1, $2]"),
            vec![read_child()],
            3,
        )
        .unwrap();
        assert_eq!(output_mapping(filter.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_project_relation() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[$0, $1, 42]"),
            vec![read_child()],
            3,
        )
        .unwrap();

        // Should have 1 expression (42) and 2 references ($0, $1)
        assert_eq!(project.expressions.len(), 1);

        // Output mapping should be [0, 1, 3]. References are 0-2; expression is 3.
        assert_eq!(output_mapping(project.common.as_ref()), &[0, 1, 3]);
    }

    #[test]
    fn test_parse_project_relation_complex() {
        let extensions = SimpleExtensions::default();
        let project = ProjectRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::project_relation, "Project[42, $0, 100, $2, $1]"),
            vec![read_child()],
            // Deliberately larger than the child's 3 columns: expression indices in the
            // output mapping are expected to start at input_field_count.
            5,
        )
        .unwrap();

        // Should have 2 expressions (42, 100) and 3 references ($0, $2, $1)
        assert_eq!(project.expressions.len(), 2);

        // Direct mapping: [input_fields..., 42, 100] (input fields first, then expressions)
        // Output mapping: [5, 0, 6, 2, 1] (to get: 42, $0, 100, $2, $1)
        assert_eq!(output_mapping(project.common.as_ref()), &[5, 0, 6, 2, 1]);
    }

    #[test]
    fn test_parse_aggregate_relation() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0, $1 => sum($2), $0, count($2)]",
            ),
            vec![read_child()],
            3,
        )
        .unwrap();

        // Should have 2 group-by fields ($0, $1) and 2 measures (sum($2), count($2))
        assert_eq!(aggregate.grouping_expressions.len(), 2);
        assert_eq!(aggregate.measures.len(), 2);

        // Group-by fields occupy indices 0-1 and measures follow at 2-3:
        // sum($2) -> 2, $0 -> 0, count($2) -> 3
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[2, 0, 3]);
    }

    #[test]
    fn test_parse_aggregate_relation_simple() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[$0 => sum($1), count($1)]",
            ),
            vec![read_child()],
            3,
        )
        .unwrap();

        // Should have 1 group-by field ($0) and 2 measures (sum($1), count($1))
        assert_eq!(aggregate.grouping_expressions.len(), 1);
        assert_eq!(aggregate.measures.len(), 2);

        // Output mapping should be [1, 2] (measures only)
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[1, 2]);
    }

    #[test]
    fn test_parse_aggregate_relation_no_group_by() {
        let extensions = aggregate_extensions();
        let aggregate = AggregateRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::aggregate_relation,
                "Aggregate[_ => sum($0), count($1)]",
            ),
            vec![read_child()],
            3,
        )
        .unwrap();

        // Should have 0 group-by fields and 2 measures
        assert_eq!(aggregate.grouping_expressions.len(), 0);
        assert_eq!(aggregate.measures.len(), 2);

        // Output mapping should be [0, 1] (measures only, no group-by fields)
        assert_eq!(output_mapping(aggregate.common.as_ref()), &[0, 1]);
    }

    #[test]
    fn test_fetch_relation_positive_values() {
        let extensions = SimpleExtensions::default();

        // Test valid positive values should work
        let fetch_rel = FetchRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::fetch_relation, "Fetch[limit=10, offset=5 => $0]"),
            vec![read_child()],
            3,
        )
        .unwrap();

        // Verify the limit and offset values are correct
        assert_eq!(fetch_rel.count_mode, Some(CountMode::Count(10)));
        assert_eq!(fetch_rel.offset_mode, Some(OffsetMode::Offset(5)));
    }

    /// Checks that a negative fetch `field` (`"limit"` or `"offset"`) in `input` is rejected.
    ///
    /// If the grammar accepts the whole input, validation must fail with an error
    /// containing "Fetch {field} must be non-negative". If the grammar rejects the
    /// input (or matches only a prefix of it), the negative value never reaches
    /// validation, which is also acceptable.
    fn assert_negative_fetch_rejected(input: &str, field: &str) {
        let Ok(mut pairs) = ExpressionParser::parse(Rule::fetch_relation, input) else {
            println!("Grammar prevents negative {field} values at parse time");
            return;
        };
        let pair = pairs.next().expect("a successful parse yields a pair");
        if pair.as_str() != input {
            println!("Grammar prevents negative {field} values at parse time");
            return;
        }

        // Full parse succeeded, so validation must catch the negative value.
        let extensions = SimpleExtensions::default();
        let result =
            FetchRel::parse_pair_with_context(&extensions, pair, vec![read_child()], 3);
        let error_msg = match result {
            Ok(_) => panic!("expected {input:?} to be rejected"),
            Err(e) => e.to_string(),
        };
        let expected = format!("Fetch {field} must be non-negative");
        assert!(
            error_msg.contains(&expected),
            "unexpected error for {input:?}: {error_msg}"
        );
    }

    #[test]
    fn test_fetch_relation_negative_limit_rejected() {
        assert_negative_fetch_rejected("Fetch[limit=-5 => $0]", "limit");
    }

    #[test]
    fn test_fetch_relation_negative_offset_rejected() {
        assert_negative_fetch_rejected("Fetch[offset=-10 => $0]", "offset");
    }

    #[test]
    fn test_parse_join_relation() {
        let extensions = comparison_extensions();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(
                Rule::join_relation,
                "Join[&Inner, eq($0, $3) => $0, $1, $3, $4]",
            ),
            vec![read_child(), read_child()],
            6, // left (3) + right (3) = 6 total input fields
        )
        .unwrap();

        // Should be an Inner join
        assert_eq!(join.r#type, join_rel::JoinType::Inner as i32);

        // Should have left and right relations
        assert!(join.left.is_some());
        assert!(join.right.is_some());

        // Should have a join condition
        assert!(join.expression.is_some());

        // Output mapping should be [0, 1, 3, 4] (selected columns)
        assert_eq!(output_mapping(join.common.as_ref()), &[0, 1, 3, 4]);
    }

    #[test]
    fn test_parse_join_relation_left_outer() {
        let extensions = comparison_extensions();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&Left, eq($0, $3) => $0, $1, $2]"),
            vec![read_child(), read_child()],
            6,
        )
        .unwrap();

        // Should be a Left join
        assert_eq!(join.r#type, join_rel::JoinType::Left as i32);

        // Output mapping should be [0, 1, 2]
        assert_eq!(output_mapping(join.common.as_ref()), &[0, 1, 2]);
    }

    #[test]
    fn test_parse_join_relation_left_semi() {
        let extensions = comparison_extensions();

        let join = JoinRel::parse_pair_with_context(
            &extensions,
            parse_exact(Rule::join_relation, "Join[&LeftSemi, eq($0, $3) => $0, $1]"),
            vec![read_child(), read_child()],
            6,
        )
        .unwrap();

        // Should be a LeftSemi join
        assert_eq!(join.r#type, join_rel::JoinType::LeftSemi as i32);

        // Output mapping should be [0, 1] (only left columns for semi join)
        assert_eq!(output_mapping(join.common.as_ref()), &[0, 1]);
    }

    #[test]
    fn test_parse_join_relation_requires_two_children() {
        // Register `eq` so the join condition resolves; otherwise this test could pass
        // because of an unknown function rather than because of the child count.
        let extensions = comparison_extensions();
        let input = "Join[&Inner, eq($0, $1) => $0, $1]";

        // (children, total input field count): 0, 1 and 3 children are all invalid.
        let cases: [(Vec<Box<Rel>>, usize); 3] = [
            (vec![], 0),
            (vec![read_child()], 3),
            (vec![read_child(), read_child(), read_child()], 9),
        ];
        for (children, field_count) in cases {
            let child_count = children.len();
            let result = JoinRel::parse_pair_with_context(
                &extensions,
                parse_exact(Rule::join_relation, input),
                children,
                field_count,
            );
            assert!(
                result.is_err(),
                "expected join with {child_count} children to be rejected"
            );
        }
    }

    /// Parses `input` with `rule`, asserting the whole input is consumed by a single pair.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }
}