substrait_explain/parser/
expressions.rs

1use substrait::proto::aggregate_rel::Measure;
2use substrait::proto::expression::field_reference::ReferenceType;
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{
5    FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
6};
7use substrait::proto::function_argument::ArgType;
8use substrait::proto::r#type::{I64, Kind, Nullability};
9use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
10
11use super::types::get_and_validate_anchor;
12use super::{
13    MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
14    unwrap_single_pair,
15};
16use crate::extensions::SimpleExtensions;
17use crate::extensions::simple::ExtensionKind;
18use crate::parser::ErrorKind;
19
20/// A field index (e.g., parsed from "$0" -> 0).
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22pub struct FieldIndex(pub i32);
23
24impl FieldIndex {
25    /// Convert this field index to a FieldReference for use in expressions.
26    pub fn to_field_reference(self) -> FieldReference {
27        // XXX: Why is it so many layers to make a struct field reference? This is
28        // surprisingly complex
29        FieldReference {
30            reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
31                reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
32                    reference_segment::StructField {
33                        field: self.0,
34                        child: None,
35                    },
36                ))),
37            })),
38            root_type: None,
39        }
40    }
41}
42
43impl ParsePair for FieldIndex {
44    fn rule() -> Rule {
45        Rule::reference
46    }
47
48    fn message() -> &'static str {
49        "FieldIndex"
50    }
51
52    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
53        assert_eq!(pair.as_rule(), Self::rule());
54        let inner = unwrap_single_pair(pair);
55        let index: i32 = inner.as_str().parse().unwrap();
56        FieldIndex(index)
57    }
58}
59
60impl ParsePair for FieldReference {
61    fn rule() -> Rule {
62        Rule::reference
63    }
64
65    fn message() -> &'static str {
66        "FieldReference"
67    }
68
69    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
70        assert_eq!(pair.as_rule(), Self::rule());
71
72        // TODO: Other types of references.
73        FieldIndex::parse_pair(pair).to_field_reference()
74    }
75}
76
77fn to_int_literal(
78    value: pest::iterators::Pair<Rule>,
79    typ: Option<Type>,
80) -> Result<Literal, MessageParseError> {
81    assert_eq!(value.as_rule(), Rule::integer);
82    let parsed_value: i64 = value.as_str().parse().unwrap();
83
84    const DEFAULT_KIND: Kind = Kind::I64(I64 {
85        type_variation_reference: 0,
86        nullability: Nullability::Required as i32,
87    });
88
89    // If no type is provided, we assume i64, Nullability::Required.
90    let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
91
92    let (lit, nullability, tvar) = match &kind {
93        // If no type is provided, we assume i64, Nullability::Required.
94        Kind::I8(i) => (
95            LiteralType::I8(parsed_value as i32),
96            i.nullability,
97            i.type_variation_reference,
98        ),
99        Kind::I16(i) => (
100            LiteralType::I16(parsed_value as i32),
101            i.nullability,
102            i.type_variation_reference,
103        ),
104        Kind::I32(i) => (
105            LiteralType::I32(parsed_value as i32),
106            i.nullability,
107            i.type_variation_reference,
108        ),
109        Kind::I64(i) => (
110            LiteralType::I64(parsed_value),
111            i.nullability,
112            i.type_variation_reference,
113        ),
114        k => {
115            let pest_error = pest::error::Error::new_from_span(
116                pest::error::ErrorVariant::CustomError {
117                    message: format!("Invalid type for integer literal: {k:?}"),
118                },
119                value.as_span(),
120            );
121            let error = MessageParseError {
122                message: "int_literal_type",
123                kind: ErrorKind::InvalidValue,
124                error: Box::new(pest_error),
125            };
126            return Err(error);
127        }
128    };
129
130    Ok(Literal {
131        literal_type: Some(lit),
132        nullable: nullability != Nullability::Required as i32,
133        type_variation_reference: tvar,
134    })
135}
136
137impl ScopedParsePair for Literal {
138    fn rule() -> Rule {
139        Rule::literal
140    }
141
142    fn message() -> &'static str {
143        "Literal"
144    }
145
146    fn parse_pair(
147        extensions: &SimpleExtensions,
148        pair: pest::iterators::Pair<Rule>,
149    ) -> Result<Self, MessageParseError> {
150        assert_eq!(pair.as_rule(), Self::rule());
151        let mut pairs = pair.into_inner();
152        let value = pairs.next().unwrap(); // First item is always the value
153        let typ = pairs.next(); // Second item is optional type
154        assert!(pairs.next().is_none());
155        let typ = match typ {
156            Some(t) => Some(Type::parse_pair(extensions, t)?),
157            None => None,
158        };
159        match value.as_rule() {
160            Rule::integer => to_int_literal(value, typ),
161            Rule::string_literal => Ok(Literal {
162                literal_type: Some(LiteralType::String(unescape_string(value))),
163                nullable: false,
164                type_variation_reference: 0,
165            }),
166            _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
167        }
168    }
169}
170
171impl ScopedParsePair for ScalarFunction {
172    fn rule() -> Rule {
173        Rule::function_call
174    }
175
176    fn message() -> &'static str {
177        "ScalarFunction"
178    }
179
180    fn parse_pair(
181        extensions: &SimpleExtensions,
182        pair: pest::iterators::Pair<Rule>,
183    ) -> Result<Self, MessageParseError> {
184        assert_eq!(pair.as_rule(), Self::rule());
185        let span = pair.as_span();
186        let mut iter = RuleIter::from(pair.into_inner());
187
188        // Parse function name (required)
189        let name = iter.parse_next::<Name>();
190
191        // Parse optional anchor (e.g., #1)
192        let anchor = iter
193            .try_pop(Rule::anchor)
194            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
195
196        // Parse optional URI anchor (e.g., @1)
197        let _uri_anchor = iter
198            .try_pop(Rule::uri_anchor)
199            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
200
201        // Parse argument list (required)
202        let argument_list = iter.pop(Rule::argument_list);
203        let mut arguments = Vec::new();
204        for e in argument_list.into_inner() {
205            arguments.push(FunctionArgument {
206                arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
207            });
208        }
209
210        // Parse optional output type (e.g., :i64)
211        let output_type = match iter.try_pop(Rule::r#type) {
212            Some(t) => Some(Type::parse_pair(extensions, t)?),
213            None => None,
214        };
215
216        iter.done();
217        let anchor =
218            get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
219        Ok(ScalarFunction {
220            function_reference: anchor,
221            arguments,
222            options: vec![], // TODO: Function Options
223            output_type,
224            #[allow(deprecated)]
225            args: vec![],
226        })
227    }
228}
229
230impl ScopedParsePair for Expression {
231    fn rule() -> Rule {
232        Rule::expression
233    }
234
235    fn message() -> &'static str {
236        "Expression"
237    }
238
239    fn parse_pair(
240        extensions: &SimpleExtensions,
241        pair: pest::iterators::Pair<Rule>,
242    ) -> Result<Self, MessageParseError> {
243        assert_eq!(pair.as_rule(), Self::rule());
244        let inner = unwrap_single_pair(pair);
245
246        match inner.as_rule() {
247            Rule::literal => Ok(Expression {
248                rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
249            }),
250            Rule::function_call => Ok(Expression {
251                rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
252                    extensions, inner,
253                )?)),
254            }),
255            Rule::reference => Ok(Expression {
256                rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
257                    inner,
258                )))),
259            }),
260            _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
261        }
262    }
263}
264
265pub struct Name(pub String);
266
267impl ParsePair for Name {
268    fn rule() -> Rule {
269        Rule::name
270    }
271
272    fn message() -> &'static str {
273        "Name"
274    }
275
276    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
277        assert_eq!(pair.as_rule(), Self::rule());
278        let inner = unwrap_single_pair(pair);
279        match inner.as_rule() {
280            Rule::identifier => Name(inner.as_str().to_string()),
281            Rule::quoted_name => Name(unescape_string(inner)),
282            _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
283        }
284    }
285}
286
287impl ScopedParsePair for Measure {
288    fn rule() -> Rule {
289        Rule::aggregate_measure
290    }
291
292    fn message() -> &'static str {
293        "Measure"
294    }
295
296    fn parse_pair(
297        extensions: &SimpleExtensions,
298        pair: pest::iterators::Pair<Rule>,
299    ) -> Result<Self, MessageParseError> {
300        assert_eq!(pair.as_rule(), Self::rule());
301
302        // Extract the inner function_call from aggregate_measure
303        let function_call_pair = unwrap_single_pair(pair);
304        assert_eq!(function_call_pair.as_rule(), Rule::function_call);
305
306        // Parse as ScalarFunction, then convert to AggregateFunction
307        let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
308        Ok(Measure {
309            measure: Some(AggregateFunction {
310                function_reference: scalar.function_reference,
311                arguments: scalar.arguments,
312                options: scalar.options,
313                output_type: scalar.output_type,
314                invocation: 0, // TODO: support invocation (ALL, DISTINCT, etc.)
315                phase: 0, // TODO: support phase (INITIAL_TO_RESULT, PARTIAL_TO_INTERMEDIATE, etc.)
316                sorts: vec![], // TODO: support sorts for ordered aggregates
317                #[allow(deprecated)]
318                args: scalar.args,
319            }),
320            filter: None, // TODO: support filter conditions on aggregate measures
321        })
322    }
323}
324
325#[cfg(test)]
326mod tests {
327    use pest::Parser as PestParser;
328
329    use super::*;
330    use crate::parser::ExpressionParser;
331
332    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
333        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
334        assert_eq!(pairs.as_str(), input);
335        let pair = pairs.next().unwrap();
336        assert_eq!(pairs.next(), None);
337        pair
338    }
339
340    fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
341        let pair = parse_exact(T::rule(), input);
342        let actual = T::parse_pair(pair);
343        assert_eq!(actual, expected);
344    }
345
346    fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
347        ext: &SimpleExtensions,
348        input: &str,
349        expected: T,
350    ) {
351        let pair = parse_exact(T::rule(), input);
352        let actual = T::parse_pair(ext, pair).unwrap();
353        assert_eq!(actual, expected);
354    }
355
356    #[test]
357    fn test_parse_field_reference() {
358        assert_parses_to("$1", FieldIndex(1).to_field_reference());
359    }
360
361    #[test]
362    fn test_parse_integer_literal() {
363        let extensions = SimpleExtensions::default();
364        let expected = Literal {
365            literal_type: Some(LiteralType::I64(1)),
366            nullable: false,
367            type_variation_reference: 0,
368        };
369        assert_parses_with(&extensions, "1", expected);
370    }
371
372    // #[test]
373    // fn test_parse_string_literal() {
374    //     assert_parses_to("'hello'", Literal::String("hello".to_string()));
375    // }
376
377    // #[test]
378    // fn test_parse_function_call_simple() {
379    //     assert_parses_to(
380    //         "add()",
381    //         FunctionCall {
382    //             name: "add".to_string(),
383    //             parameters: None,
384    //             anchor: None,
385    //             uri_anchor: None,
386    //             arguments: vec![],
387    //         },
388    //     );
389    // }
390
391    // #[test]
392    // fn test_parse_function_call_with_parameters() {
393    //     assert_parses_to(
394    //         "add<param1, param2>()",
395    //         FunctionCall {
396    //             name: "add".to_string(),
397    //             parameters: Some(vec!["param1".to_string(), "param2".to_string()]),
398    //             anchor: None,
399    //             uri_anchor: None,
400    //             arguments: vec![],
401    //         },
402    //     );
403    // }
404
405    // #[test]
406    // fn test_parse_function_call_with_anchor() {
407    //     assert_parses_to(
408    //         "add#1()",
409    //         FunctionCall {
410    //             name: "add".to_string(),
411    //             parameters: None,
412    //             anchor: Some(1),
413    //             uri_anchor: None,
414    //             arguments: vec![],
415    //         },
416    //     );
417    // }
418
419    // #[test]
420    // fn test_parse_function_call_with_uri_anchor() {
421    //     assert_parses_to(
422    //         "add@1()",
423    //         FunctionCall {
424    //             name: "add".to_string(),
425    //             parameters: None,
426    //             anchor: None,
427    //             uri_anchor: Some(1),
428    //             arguments: vec![],
429    //         },
430    //     );
431    // }
432
433    // #[test]
434    // fn test_parse_function_call_all_optionals() {
435    //     assert_parses_to(
436    //         "add<param1, param2>#1@2()",
437    //         FunctionCall {
438    //             name: "add".to_string(),
439    //             parameters: Some(vec!["param1".to_string(), "param2".to_string()]),
440    //             anchor: Some(1),
441    //             uri_anchor: Some(2),
442    //             arguments: vec![],
443    //         },
444    //     );
445    // }
446
447    // #[test]
448    // fn test_parse_function_call_with_simple_arguments() {
449    //     assert_parses_to(
450    //         "add(1, 2)",
451    //         FunctionCall {
452    //             name: "add".to_string(),
453    //             parameters: None,
454    //             anchor: None,
455    //             uri_anchor: None,
456    //             arguments: vec![
457    //                 Expression::Literal(Literal::Integer(1)),
458    //                 Expression::Literal(Literal::Integer(2)),
459    //             ],
460    //         },
461    //     );
462    // }
463
464    // #[test]
465    // fn test_parse_function_call_with_nested_function() {
466    //     assert_parses_to(
467    //         "outer_func(inner_func(), $1)",
468    //         Expression::FunctionCall(Box::new(FunctionCall {
469    //             name: "outer_func".to_string(),
470    //             parameters: None,
471    //             anchor: None,
472    //             uri_anchor: None,
473    //             arguments: vec![
474    //                 Expression::FunctionCall(Box::new(FunctionCall {
475    //                     name: "inner_func".to_string(),
476    //                     parameters: None,
477    //                     anchor: None,
478    //                     uri_anchor: None,
479    //                     arguments: vec![],
480    //                 })),
481    //                 Expression::Reference(Reference(1)),
482    //             ],
483    //         })),
484    //     );
485    // }
486
487    // #[test]
488    // fn test_parse_function_call_funny_names() {
489    //     assert_parses_to(
490    //         "'funny name'<param1, param2>#1@2()",
491    //         FunctionCall {
492    //             name: "funny name".to_string(),
493    //             parameters: Some(vec!["param1".to_string(), "param2".to_string()]),
494    //             anchor: Some(1),
495    //             uri_anchor: Some(2),
496    //             arguments: vec![],
497    //         },
498    //     );
499    // }
500
501    // #[test]
502    // fn test_parse_empty_string_literal() {
503    //     assert_parses_to("''", Literal::String("".to_string()));
504    // }
505}