substrait_explain/parser/
expressions.rs1use substrait::proto::aggregate_rel::Measure;
2use substrait::proto::expression::field_reference::ReferenceType;
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{
5 FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
6};
7use substrait::proto::function_argument::ArgType;
8use substrait::proto::r#type::{I64, Kind, Nullability};
9use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
10
11use super::types::get_and_validate_anchor;
12use super::{
13 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
14 unwrap_single_pair,
15};
16use crate::extensions::SimpleExtensions;
17use crate::extensions::simple::ExtensionKind;
18use crate::parser::ErrorKind;
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22pub struct FieldIndex(pub i32);
23
24impl FieldIndex {
25 pub fn to_field_reference(self) -> FieldReference {
27 FieldReference {
30 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
31 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
32 reference_segment::StructField {
33 field: self.0,
34 child: None,
35 },
36 ))),
37 })),
38 root_type: None,
39 }
40 }
41}
42
43impl ParsePair for FieldIndex {
44 fn rule() -> Rule {
45 Rule::reference
46 }
47
48 fn message() -> &'static str {
49 "FieldIndex"
50 }
51
52 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
53 assert_eq!(pair.as_rule(), Self::rule());
54 let inner = unwrap_single_pair(pair);
55 let index: i32 = inner.as_str().parse().unwrap();
56 FieldIndex(index)
57 }
58}
59
60impl ParsePair for FieldReference {
61 fn rule() -> Rule {
62 Rule::reference
63 }
64
65 fn message() -> &'static str {
66 "FieldReference"
67 }
68
69 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
70 assert_eq!(pair.as_rule(), Self::rule());
71
72 FieldIndex::parse_pair(pair).to_field_reference()
74 }
75}
76
77fn to_int_literal(
78 value: pest::iterators::Pair<Rule>,
79 typ: Option<Type>,
80) -> Result<Literal, MessageParseError> {
81 assert_eq!(value.as_rule(), Rule::integer);
82 let parsed_value: i64 = value.as_str().parse().unwrap();
83
84 const DEFAULT_KIND: Kind = Kind::I64(I64 {
85 type_variation_reference: 0,
86 nullability: Nullability::Required as i32,
87 });
88
89 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
91
92 let (lit, nullability, tvar) = match &kind {
93 Kind::I8(i) => (
95 LiteralType::I8(parsed_value as i32),
96 i.nullability,
97 i.type_variation_reference,
98 ),
99 Kind::I16(i) => (
100 LiteralType::I16(parsed_value as i32),
101 i.nullability,
102 i.type_variation_reference,
103 ),
104 Kind::I32(i) => (
105 LiteralType::I32(parsed_value as i32),
106 i.nullability,
107 i.type_variation_reference,
108 ),
109 Kind::I64(i) => (
110 LiteralType::I64(parsed_value),
111 i.nullability,
112 i.type_variation_reference,
113 ),
114 k => {
115 let pest_error = pest::error::Error::new_from_span(
116 pest::error::ErrorVariant::CustomError {
117 message: format!("Invalid type for integer literal: {k:?}"),
118 },
119 value.as_span(),
120 );
121 let error = MessageParseError {
122 message: "int_literal_type",
123 kind: ErrorKind::InvalidValue,
124 error: Box::new(pest_error),
125 };
126 return Err(error);
127 }
128 };
129
130 Ok(Literal {
131 literal_type: Some(lit),
132 nullable: nullability != Nullability::Required as i32,
133 type_variation_reference: tvar,
134 })
135}
136
137impl ScopedParsePair for Literal {
138 fn rule() -> Rule {
139 Rule::literal
140 }
141
142 fn message() -> &'static str {
143 "Literal"
144 }
145
146 fn parse_pair(
147 extensions: &SimpleExtensions,
148 pair: pest::iterators::Pair<Rule>,
149 ) -> Result<Self, MessageParseError> {
150 assert_eq!(pair.as_rule(), Self::rule());
151 let mut pairs = pair.into_inner();
152 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
155 let typ = match typ {
156 Some(t) => Some(Type::parse_pair(extensions, t)?),
157 None => None,
158 };
159 match value.as_rule() {
160 Rule::integer => to_int_literal(value, typ),
161 Rule::string_literal => Ok(Literal {
162 literal_type: Some(LiteralType::String(unescape_string(value))),
163 nullable: false,
164 type_variation_reference: 0,
165 }),
166 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
167 }
168 }
169}
170
171impl ScopedParsePair for ScalarFunction {
172 fn rule() -> Rule {
173 Rule::function_call
174 }
175
176 fn message() -> &'static str {
177 "ScalarFunction"
178 }
179
180 fn parse_pair(
181 extensions: &SimpleExtensions,
182 pair: pest::iterators::Pair<Rule>,
183 ) -> Result<Self, MessageParseError> {
184 assert_eq!(pair.as_rule(), Self::rule());
185 let span = pair.as_span();
186 let mut iter = RuleIter::from(pair.into_inner());
187
188 let name = iter.parse_next::<Name>();
190
191 let anchor = iter
193 .try_pop(Rule::anchor)
194 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
195
196 let _uri_anchor = iter
198 .try_pop(Rule::uri_anchor)
199 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
200
201 let argument_list = iter.pop(Rule::argument_list);
203 let mut arguments = Vec::new();
204 for e in argument_list.into_inner() {
205 arguments.push(FunctionArgument {
206 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
207 });
208 }
209
210 let output_type = match iter.try_pop(Rule::r#type) {
212 Some(t) => Some(Type::parse_pair(extensions, t)?),
213 None => None,
214 };
215
216 iter.done();
217 let anchor =
218 get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
219 Ok(ScalarFunction {
220 function_reference: anchor,
221 arguments,
222 options: vec![], output_type,
224 #[allow(deprecated)]
225 args: vec![],
226 })
227 }
228}
229
230impl ScopedParsePair for Expression {
231 fn rule() -> Rule {
232 Rule::expression
233 }
234
235 fn message() -> &'static str {
236 "Expression"
237 }
238
239 fn parse_pair(
240 extensions: &SimpleExtensions,
241 pair: pest::iterators::Pair<Rule>,
242 ) -> Result<Self, MessageParseError> {
243 assert_eq!(pair.as_rule(), Self::rule());
244 let inner = unwrap_single_pair(pair);
245
246 match inner.as_rule() {
247 Rule::literal => Ok(Expression {
248 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
249 }),
250 Rule::function_call => Ok(Expression {
251 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
252 extensions, inner,
253 )?)),
254 }),
255 Rule::reference => Ok(Expression {
256 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
257 inner,
258 )))),
259 }),
260 _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
261 }
262 }
263}
264
265pub struct Name(pub String);
266
267impl ParsePair for Name {
268 fn rule() -> Rule {
269 Rule::name
270 }
271
272 fn message() -> &'static str {
273 "Name"
274 }
275
276 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
277 assert_eq!(pair.as_rule(), Self::rule());
278 let inner = unwrap_single_pair(pair);
279 match inner.as_rule() {
280 Rule::identifier => Name(inner.as_str().to_string()),
281 Rule::quoted_name => Name(unescape_string(inner)),
282 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
283 }
284 }
285}
286
287impl ScopedParsePair for Measure {
288 fn rule() -> Rule {
289 Rule::aggregate_measure
290 }
291
292 fn message() -> &'static str {
293 "Measure"
294 }
295
296 fn parse_pair(
297 extensions: &SimpleExtensions,
298 pair: pest::iterators::Pair<Rule>,
299 ) -> Result<Self, MessageParseError> {
300 assert_eq!(pair.as_rule(), Self::rule());
301
302 let function_call_pair = unwrap_single_pair(pair);
304 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
305
306 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
308 Ok(Measure {
309 measure: Some(AggregateFunction {
310 function_reference: scalar.function_reference,
311 arguments: scalar.arguments,
312 options: scalar.options,
313 output_type: scalar.output_type,
314 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
318 args: scalar.args,
319 }),
320 filter: None, })
322 }
323}
324
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, requiring that the whole input is consumed and
    /// that exactly one top-level pair is produced.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let only = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        only
    }

    /// Asserts that an unscoped parser turns `input` into `expected`.
    fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
        let parsed = T::parse_pair(parse_exact(T::rule(), input));
        assert_eq!(parsed, expected);
    }

    /// Asserts that a scoped parser, given `ext`, turns `input` into `expected`.
    fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
        ext: &SimpleExtensions,
        input: &str,
        expected: T,
    ) {
        let parsed = T::parse_pair(ext, parse_exact(T::rule(), input)).unwrap();
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_field_reference() {
        let expected = FieldIndex(1).to_field_reference();
        assert_parses_to("$1", expected);
    }

    #[test]
    fn test_parse_integer_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(
            &extensions,
            "1",
            Literal {
                literal_type: Some(LiteralType::I64(1)),
                nullable: false,
                type_variation_reference: 0,
            },
        );
    }
}