substrait_explain/parser/
expressions.rs

1use substrait::proto::aggregate_rel::Measure;
2use substrait::proto::expression::field_reference::ReferenceType;
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{
5    FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
6};
7use substrait::proto::function_argument::ArgType;
8use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
9use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
10
11use super::types::get_and_validate_anchor;
12use super::{
13    MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
14    unwrap_single_pair,
15};
16use crate::extensions::SimpleExtensions;
17use crate::extensions::simple::ExtensionKind;
18use crate::parser::ErrorKind;
19
20/// A field index (e.g., parsed from "$0" -> 0).
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22pub struct FieldIndex(pub i32);
23
24impl FieldIndex {
25    /// Convert this field index to a FieldReference for use in expressions.
26    pub fn to_field_reference(self) -> FieldReference {
27        // XXX: Why is it so many layers to make a struct field reference? This is
28        // surprisingly complex
29        FieldReference {
30            reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
31                reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
32                    reference_segment::StructField {
33                        field: self.0,
34                        child: None,
35                    },
36                ))),
37            })),
38            root_type: None,
39        }
40    }
41}
42
43impl ParsePair for FieldIndex {
44    fn rule() -> Rule {
45        Rule::reference
46    }
47
48    fn message() -> &'static str {
49        "FieldIndex"
50    }
51
52    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
53        assert_eq!(pair.as_rule(), Self::rule());
54        let inner = unwrap_single_pair(pair);
55        let index: i32 = inner.as_str().parse().unwrap();
56        FieldIndex(index)
57    }
58}
59
60impl ParsePair for FieldReference {
61    fn rule() -> Rule {
62        Rule::reference
63    }
64
65    fn message() -> &'static str {
66        "FieldReference"
67    }
68
69    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
70        assert_eq!(pair.as_rule(), Self::rule());
71
72        // TODO: Other types of references.
73        FieldIndex::parse_pair(pair).to_field_reference()
74    }
75}
76
77fn to_int_literal(
78    value: pest::iterators::Pair<Rule>,
79    typ: Option<Type>,
80) -> Result<Literal, MessageParseError> {
81    assert_eq!(value.as_rule(), Rule::integer);
82    let parsed_value: i64 = value.as_str().parse().unwrap();
83
84    const DEFAULT_KIND: Kind = Kind::I64(I64 {
85        type_variation_reference: 0,
86        nullability: Nullability::Required as i32,
87    });
88
89    // If no type is provided, we assume i64, Nullability::Required.
90    let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
91
92    let (lit, nullability, tvar) = match &kind {
93        // If no type is provided, we assume i64, Nullability::Required.
94        Kind::I8(i) => (
95            LiteralType::I8(parsed_value as i32),
96            i.nullability,
97            i.type_variation_reference,
98        ),
99        Kind::I16(i) => (
100            LiteralType::I16(parsed_value as i32),
101            i.nullability,
102            i.type_variation_reference,
103        ),
104        Kind::I32(i) => (
105            LiteralType::I32(parsed_value as i32),
106            i.nullability,
107            i.type_variation_reference,
108        ),
109        Kind::I64(i) => (
110            LiteralType::I64(parsed_value),
111            i.nullability,
112            i.type_variation_reference,
113        ),
114        k => {
115            let pest_error = pest::error::Error::new_from_span(
116                pest::error::ErrorVariant::CustomError {
117                    message: format!("Invalid type for integer literal: {k:?}"),
118                },
119                value.as_span(),
120            );
121            let error = MessageParseError {
122                message: "int_literal_type",
123                kind: ErrorKind::InvalidValue,
124                error: Box::new(pest_error),
125            };
126            return Err(error);
127        }
128    };
129
130    Ok(Literal {
131        literal_type: Some(lit),
132        nullable: nullability != Nullability::Required as i32,
133        type_variation_reference: tvar,
134    })
135}
136
137fn to_float_literal(
138    value: pest::iterators::Pair<Rule>,
139    typ: Option<Type>,
140) -> Result<Literal, MessageParseError> {
141    assert_eq!(value.as_rule(), Rule::float);
142    let parsed_value: f64 = value.as_str().parse().unwrap();
143
144    const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
145        type_variation_reference: 0,
146        nullability: Nullability::Required as i32,
147    });
148
149    // If no type is provided, we assume fp64, Nullability::Required.
150    let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
151
152    let (lit, nullability, tvar) = match &kind {
153        Kind::Fp32(f) => (
154            LiteralType::Fp32(parsed_value as f32),
155            f.nullability,
156            f.type_variation_reference,
157        ),
158        Kind::Fp64(f) => (
159            LiteralType::Fp64(parsed_value),
160            f.nullability,
161            f.type_variation_reference,
162        ),
163        k => {
164            let pest_error = pest::error::Error::new_from_span(
165                pest::error::ErrorVariant::CustomError {
166                    message: format!("Invalid type for float literal: {k:?}"),
167                },
168                value.as_span(),
169            );
170            let error = MessageParseError {
171                message: "float_literal_type",
172                kind: ErrorKind::InvalidValue,
173                error: Box::new(pest_error),
174            };
175            return Err(error);
176        }
177    };
178
179    Ok(Literal {
180        literal_type: Some(lit),
181        nullable: nullability != Nullability::Required as i32,
182        type_variation_reference: tvar,
183    })
184}
185
186fn to_boolean_literal(value: pest::iterators::Pair<Rule>) -> Result<Literal, MessageParseError> {
187    assert_eq!(value.as_rule(), Rule::boolean);
188    let parsed_value: bool = value.as_str().parse().unwrap();
189
190    Ok(Literal {
191        literal_type: Some(LiteralType::Boolean(parsed_value)),
192        nullable: false,
193        type_variation_reference: 0,
194    })
195}
196
197impl ScopedParsePair for Literal {
198    fn rule() -> Rule {
199        Rule::literal
200    }
201
202    fn message() -> &'static str {
203        "Literal"
204    }
205
206    fn parse_pair(
207        extensions: &SimpleExtensions,
208        pair: pest::iterators::Pair<Rule>,
209    ) -> Result<Self, MessageParseError> {
210        assert_eq!(pair.as_rule(), Self::rule());
211        let mut pairs = pair.into_inner();
212        let value = pairs.next().unwrap(); // First item is always the value
213        let typ = pairs.next(); // Second item is optional type
214        assert!(pairs.next().is_none());
215        let typ = match typ {
216            Some(t) => Some(Type::parse_pair(extensions, t)?),
217            None => None,
218        };
219        match value.as_rule() {
220            Rule::integer => to_int_literal(value, typ),
221            Rule::float => to_float_literal(value, typ),
222            Rule::boolean => to_boolean_literal(value),
223            Rule::string_literal => Ok(Literal {
224                literal_type: Some(LiteralType::String(unescape_string(value))),
225                nullable: false,
226                type_variation_reference: 0,
227            }),
228            _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
229        }
230    }
231}
232
233impl ScopedParsePair for ScalarFunction {
234    fn rule() -> Rule {
235        Rule::function_call
236    }
237
238    fn message() -> &'static str {
239        "ScalarFunction"
240    }
241
242    fn parse_pair(
243        extensions: &SimpleExtensions,
244        pair: pest::iterators::Pair<Rule>,
245    ) -> Result<Self, MessageParseError> {
246        assert_eq!(pair.as_rule(), Self::rule());
247        let span = pair.as_span();
248        let mut iter = RuleIter::from(pair.into_inner());
249
250        // Parse function name (required)
251        let name = iter.parse_next::<Name>();
252
253        // Parse optional anchor (e.g., #1)
254        let anchor = iter
255            .try_pop(Rule::anchor)
256            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
257
258        // Parse optional URI anchor (e.g., @1)
259        let _uri_anchor = iter
260            .try_pop(Rule::uri_anchor)
261            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
262
263        // Parse argument list (required)
264        let argument_list = iter.pop(Rule::argument_list);
265        let mut arguments = Vec::new();
266        for e in argument_list.into_inner() {
267            arguments.push(FunctionArgument {
268                arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
269            });
270        }
271
272        // Parse optional output type (e.g., :i64)
273        let output_type = match iter.try_pop(Rule::r#type) {
274            Some(t) => Some(Type::parse_pair(extensions, t)?),
275            None => None,
276        };
277
278        iter.done();
279        let anchor =
280            get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
281        Ok(ScalarFunction {
282            function_reference: anchor,
283            arguments,
284            options: vec![], // TODO: Function Options
285            output_type,
286            #[allow(deprecated)]
287            args: vec![],
288        })
289    }
290}
291
292impl ScopedParsePair for Expression {
293    fn rule() -> Rule {
294        Rule::expression
295    }
296
297    fn message() -> &'static str {
298        "Expression"
299    }
300
301    fn parse_pair(
302        extensions: &SimpleExtensions,
303        pair: pest::iterators::Pair<Rule>,
304    ) -> Result<Self, MessageParseError> {
305        assert_eq!(pair.as_rule(), Self::rule());
306        let inner = unwrap_single_pair(pair);
307
308        match inner.as_rule() {
309            Rule::literal => Ok(Expression {
310                rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
311            }),
312            Rule::function_call => Ok(Expression {
313                rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
314                    extensions, inner,
315                )?)),
316            }),
317            Rule::reference => Ok(Expression {
318                rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
319                    inner,
320                )))),
321            }),
322            _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
323        }
324    }
325}
326
327pub struct Name(pub String);
328
329impl ParsePair for Name {
330    fn rule() -> Rule {
331        Rule::name
332    }
333
334    fn message() -> &'static str {
335        "Name"
336    }
337
338    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
339        assert_eq!(pair.as_rule(), Self::rule());
340        let inner = unwrap_single_pair(pair);
341        match inner.as_rule() {
342            Rule::identifier => Name(inner.as_str().to_string()),
343            Rule::quoted_name => Name(unescape_string(inner)),
344            _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
345        }
346    }
347}
348
349impl ScopedParsePair for Measure {
350    fn rule() -> Rule {
351        Rule::aggregate_measure
352    }
353
354    fn message() -> &'static str {
355        "Measure"
356    }
357
358    fn parse_pair(
359        extensions: &SimpleExtensions,
360        pair: pest::iterators::Pair<Rule>,
361    ) -> Result<Self, MessageParseError> {
362        assert_eq!(pair.as_rule(), Self::rule());
363
364        // Extract the inner function_call from aggregate_measure
365        let function_call_pair = unwrap_single_pair(pair);
366        assert_eq!(function_call_pair.as_rule(), Rule::function_call);
367
368        // Parse as ScalarFunction, then convert to AggregateFunction
369        let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
370        Ok(Measure {
371            measure: Some(AggregateFunction {
372                function_reference: scalar.function_reference,
373                arguments: scalar.arguments,
374                options: scalar.options,
375                output_type: scalar.output_type,
376                invocation: 0, // TODO: support invocation (ALL, DISTINCT, etc.)
377                phase: 0, // TODO: support phase (INITIAL_TO_RESULT, PARTIAL_TO_INTERMEDIATE, etc.)
378                sorts: vec![], // TODO: support sorts for ordered aggregates
379                #[allow(deprecated)]
380                args: scalar.args,
381            }),
382            filter: None, // TODO: support filter conditions on aggregate measures
383        })
384    }
385}
386
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, asserting the whole input is consumed and
    /// exactly one top-level pair is produced.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Asserts that an unscoped parser turns `input` into `expected`.
    fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
        let pair = parse_exact(T::rule(), input);
        let actual = T::parse_pair(pair);
        assert_eq!(actual, expected);
    }

    /// Asserts that a scoped parser, given `ext`, turns `input` into
    /// `expected` without error.
    fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
        ext: &SimpleExtensions,
        input: &str,
        expected: T,
    ) {
        let pair = parse_exact(T::rule(), input);
        let actual = T::parse_pair(ext, pair).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_parse_field_reference() {
        assert_parses_to("$1", FieldIndex(1).to_field_reference());
    }

    #[test]
    fn test_parse_integer_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::I64(1)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "1", expected);
    }

    #[test]
    fn test_parse_integer_literal_with_i32_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "42:i32");
        let result = Literal::parse_pair(&extensions, pair).unwrap();
        assert_eq!(result.literal_type, Some(LiteralType::I32(42)));
    }

    #[test]
    fn test_parse_float_literal() {
        // First test that the grammar can parse floats
        let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
        let parsed_text = pairs.as_str();
        assert_eq!(parsed_text, "3.82");

        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Fp64(3.82)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "3.82", expected);
    }

    #[test]
    fn test_parse_negative_float_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Fp64(-2.5)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "-2.5", expected);
    }

    #[test]
    fn test_parse_boolean_true_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Boolean(true)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "true", expected);
    }

    #[test]
    fn test_parse_boolean_false_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Boolean(false)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "false", expected);
    }

    #[test]
    fn test_parse_float_literal_with_fp32_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "3.82:fp32");
        let result = Literal::parse_pair(&extensions, pair).unwrap();

        match result.literal_type {
            Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON),
            _ => panic!("Expected Fp32 literal type"),
        }
    }

    #[test]
    fn test_parse_reference_expression() {
        let extensions = SimpleExtensions::default();
        let expected = Expression {
            rex_type: Some(RexType::Selection(Box::new(
                FieldIndex(2).to_field_reference(),
            ))),
        };
        assert_parses_with(&extensions, "$2", expected);
    }

    #[test]
    fn test_parse_literal_expression() {
        let extensions = SimpleExtensions::default();
        let expected = Expression {
            rex_type: Some(RexType::Literal(Literal {
                literal_type: Some(LiteralType::I64(7)),
                nullable: false,
                type_variation_reference: 0,
            })),
        };
        assert_parses_with(&extensions, "7", expected);
    }

    // TODO: add coverage for syntax not yet exercised here:
    // - string literals, e.g. `'hello'` and `''`;
    // - function calls with parameters (`add<param1, param2>()`);
    // - function anchors (`add#1()`) and URI anchors (`add@1()`);
    // - arguments (`add(1, 2)`) and nested calls (`outer_func(inner_func(), $1)`);
    // - quoted names (`'funny name'<param1, param2>#1@2()`).
    // Function-call tests presumably need a `SimpleExtensions` with the
    // referenced functions registered, since the parser validates the name and
    // anchor against it.
}