1use substrait::proto::aggregate_rel::Measure;
2use substrait::proto::expression::field_reference::ReferenceType;
3use substrait::proto::expression::literal::LiteralType;
4use substrait::proto::expression::{
5 FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
6};
7use substrait::proto::function_argument::ArgType;
8use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
9use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
10
11use super::types::get_and_validate_anchor;
12use super::{
13 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
14 unwrap_single_pair,
15};
16use crate::extensions::SimpleExtensions;
17use crate::extensions::simple::ExtensionKind;
18use crate::parser::ErrorKind;
19
/// The numeric index carried by a field reference such as `$1`.
///
/// The wrapped value is the raw index as written in the source text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FieldIndex(pub i32);
23
24impl FieldIndex {
25 pub fn to_field_reference(self) -> FieldReference {
27 FieldReference {
30 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
31 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
32 reference_segment::StructField {
33 field: self.0,
34 child: None,
35 },
36 ))),
37 })),
38 root_type: None,
39 }
40 }
41}
42
43impl ParsePair for FieldIndex {
44 fn rule() -> Rule {
45 Rule::reference
46 }
47
48 fn message() -> &'static str {
49 "FieldIndex"
50 }
51
52 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
53 assert_eq!(pair.as_rule(), Self::rule());
54 let inner = unwrap_single_pair(pair);
55 let index: i32 = inner.as_str().parse().unwrap();
56 FieldIndex(index)
57 }
58}
59
60impl ParsePair for FieldReference {
61 fn rule() -> Rule {
62 Rule::reference
63 }
64
65 fn message() -> &'static str {
66 "FieldReference"
67 }
68
69 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
70 assert_eq!(pair.as_rule(), Self::rule());
71
72 FieldIndex::parse_pair(pair).to_field_reference()
74 }
75}
76
77fn to_int_literal(
78 value: pest::iterators::Pair<Rule>,
79 typ: Option<Type>,
80) -> Result<Literal, MessageParseError> {
81 assert_eq!(value.as_rule(), Rule::integer);
82 let parsed_value: i64 = value.as_str().parse().unwrap();
83
84 const DEFAULT_KIND: Kind = Kind::I64(I64 {
85 type_variation_reference: 0,
86 nullability: Nullability::Required as i32,
87 });
88
89 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
91
92 let (lit, nullability, tvar) = match &kind {
93 Kind::I8(i) => (
95 LiteralType::I8(parsed_value as i32),
96 i.nullability,
97 i.type_variation_reference,
98 ),
99 Kind::I16(i) => (
100 LiteralType::I16(parsed_value as i32),
101 i.nullability,
102 i.type_variation_reference,
103 ),
104 Kind::I32(i) => (
105 LiteralType::I32(parsed_value as i32),
106 i.nullability,
107 i.type_variation_reference,
108 ),
109 Kind::I64(i) => (
110 LiteralType::I64(parsed_value),
111 i.nullability,
112 i.type_variation_reference,
113 ),
114 k => {
115 let pest_error = pest::error::Error::new_from_span(
116 pest::error::ErrorVariant::CustomError {
117 message: format!("Invalid type for integer literal: {k:?}"),
118 },
119 value.as_span(),
120 );
121 let error = MessageParseError {
122 message: "int_literal_type",
123 kind: ErrorKind::InvalidValue,
124 error: Box::new(pest_error),
125 };
126 return Err(error);
127 }
128 };
129
130 Ok(Literal {
131 literal_type: Some(lit),
132 nullable: nullability != Nullability::Required as i32,
133 type_variation_reference: tvar,
134 })
135}
136
137fn to_float_literal(
138 value: pest::iterators::Pair<Rule>,
139 typ: Option<Type>,
140) -> Result<Literal, MessageParseError> {
141 assert_eq!(value.as_rule(), Rule::float);
142 let parsed_value: f64 = value.as_str().parse().unwrap();
143
144 const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
145 type_variation_reference: 0,
146 nullability: Nullability::Required as i32,
147 });
148
149 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
151
152 let (lit, nullability, tvar) = match &kind {
153 Kind::Fp32(f) => (
154 LiteralType::Fp32(parsed_value as f32),
155 f.nullability,
156 f.type_variation_reference,
157 ),
158 Kind::Fp64(f) => (
159 LiteralType::Fp64(parsed_value),
160 f.nullability,
161 f.type_variation_reference,
162 ),
163 k => {
164 let pest_error = pest::error::Error::new_from_span(
165 pest::error::ErrorVariant::CustomError {
166 message: format!("Invalid type for float literal: {k:?}"),
167 },
168 value.as_span(),
169 );
170 let error = MessageParseError {
171 message: "float_literal_type",
172 kind: ErrorKind::InvalidValue,
173 error: Box::new(pest_error),
174 };
175 return Err(error);
176 }
177 };
178
179 Ok(Literal {
180 literal_type: Some(lit),
181 nullable: nullability != Nullability::Required as i32,
182 type_variation_reference: tvar,
183 })
184}
185
186fn to_boolean_literal(value: pest::iterators::Pair<Rule>) -> Result<Literal, MessageParseError> {
187 assert_eq!(value.as_rule(), Rule::boolean);
188 let parsed_value: bool = value.as_str().parse().unwrap();
189
190 Ok(Literal {
191 literal_type: Some(LiteralType::Boolean(parsed_value)),
192 nullable: false,
193 type_variation_reference: 0,
194 })
195}
196
197impl ScopedParsePair for Literal {
198 fn rule() -> Rule {
199 Rule::literal
200 }
201
202 fn message() -> &'static str {
203 "Literal"
204 }
205
206 fn parse_pair(
207 extensions: &SimpleExtensions,
208 pair: pest::iterators::Pair<Rule>,
209 ) -> Result<Self, MessageParseError> {
210 assert_eq!(pair.as_rule(), Self::rule());
211 let mut pairs = pair.into_inner();
212 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
215 let typ = match typ {
216 Some(t) => Some(Type::parse_pair(extensions, t)?),
217 None => None,
218 };
219 match value.as_rule() {
220 Rule::integer => to_int_literal(value, typ),
221 Rule::float => to_float_literal(value, typ),
222 Rule::boolean => to_boolean_literal(value),
223 Rule::string_literal => Ok(Literal {
224 literal_type: Some(LiteralType::String(unescape_string(value))),
225 nullable: false,
226 type_variation_reference: 0,
227 }),
228 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
229 }
230 }
231}
232
233impl ScopedParsePair for ScalarFunction {
234 fn rule() -> Rule {
235 Rule::function_call
236 }
237
238 fn message() -> &'static str {
239 "ScalarFunction"
240 }
241
242 fn parse_pair(
243 extensions: &SimpleExtensions,
244 pair: pest::iterators::Pair<Rule>,
245 ) -> Result<Self, MessageParseError> {
246 assert_eq!(pair.as_rule(), Self::rule());
247 let span = pair.as_span();
248 let mut iter = RuleIter::from(pair.into_inner());
249
250 let name = iter.parse_next::<Name>();
252
253 let anchor = iter
255 .try_pop(Rule::anchor)
256 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
257
258 let _uri_anchor = iter
260 .try_pop(Rule::uri_anchor)
261 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
262
263 let argument_list = iter.pop(Rule::argument_list);
265 let mut arguments = Vec::new();
266 for e in argument_list.into_inner() {
267 arguments.push(FunctionArgument {
268 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
269 });
270 }
271
272 let output_type = match iter.try_pop(Rule::r#type) {
274 Some(t) => Some(Type::parse_pair(extensions, t)?),
275 None => None,
276 };
277
278 iter.done();
279 let anchor =
280 get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
281 Ok(ScalarFunction {
282 function_reference: anchor,
283 arguments,
284 options: vec![], output_type,
286 #[allow(deprecated)]
287 args: vec![],
288 })
289 }
290}
291
292impl ScopedParsePair for Expression {
293 fn rule() -> Rule {
294 Rule::expression
295 }
296
297 fn message() -> &'static str {
298 "Expression"
299 }
300
301 fn parse_pair(
302 extensions: &SimpleExtensions,
303 pair: pest::iterators::Pair<Rule>,
304 ) -> Result<Self, MessageParseError> {
305 assert_eq!(pair.as_rule(), Self::rule());
306 let inner = unwrap_single_pair(pair);
307
308 match inner.as_rule() {
309 Rule::literal => Ok(Expression {
310 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
311 }),
312 Rule::function_call => Ok(Expression {
313 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
314 extensions, inner,
315 )?)),
316 }),
317 Rule::reference => Ok(Expression {
318 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
319 inner,
320 )))),
321 }),
322 _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
323 }
324 }
325}
326
/// A parsed name: either a bare identifier or the unescaped contents of a
/// quoted name.
pub struct Name(pub String);
328
329impl ParsePair for Name {
330 fn rule() -> Rule {
331 Rule::name
332 }
333
334 fn message() -> &'static str {
335 "Name"
336 }
337
338 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
339 assert_eq!(pair.as_rule(), Self::rule());
340 let inner = unwrap_single_pair(pair);
341 match inner.as_rule() {
342 Rule::identifier => Name(inner.as_str().to_string()),
343 Rule::quoted_name => Name(unescape_string(inner)),
344 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
345 }
346 }
347}
348
349impl ScopedParsePair for Measure {
350 fn rule() -> Rule {
351 Rule::aggregate_measure
352 }
353
354 fn message() -> &'static str {
355 "Measure"
356 }
357
358 fn parse_pair(
359 extensions: &SimpleExtensions,
360 pair: pest::iterators::Pair<Rule>,
361 ) -> Result<Self, MessageParseError> {
362 assert_eq!(pair.as_rule(), Self::rule());
363
364 let function_call_pair = unwrap_single_pair(pair);
366 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
367
368 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
370 Ok(Measure {
371 measure: Some(AggregateFunction {
372 function_reference: scalar.function_reference,
373 arguments: scalar.arguments,
374 options: scalar.options,
375 output_type: scalar.output_type,
376 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
380 args: scalar.args,
381 }),
382 filter: None, })
384 }
385}
386
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, requires the whole input to be consumed by
    /// exactly one top-level pair, and returns that pair.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut parsed = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(parsed.as_str(), input);
        let only = parsed.next().unwrap();
        assert_eq!(parsed.next(), None);
        only
    }

    /// Asserts that an unscoped [`ParsePair`] type parses `input` to `expected`.
    fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
        let actual = T::parse_pair(parse_exact(T::rule(), input));
        assert_eq!(actual, expected);
    }

    /// Asserts that a [`ScopedParsePair`] type parses `input` to `expected`
    /// using the extensions in `ext`.
    fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
        ext: &SimpleExtensions,
        input: &str,
        expected: T,
    ) {
        let actual = T::parse_pair(ext, parse_exact(T::rule(), input)).unwrap();
        assert_eq!(actual, expected);
    }

    /// Builds a non-nullable literal with no type variation.
    fn required(literal_type: LiteralType) -> Literal {
        Literal {
            literal_type: Some(literal_type),
            nullable: false,
            type_variation_reference: 0,
        }
    }

    #[test]
    fn test_parse_field_reference() {
        assert_parses_to("$1", FieldIndex(1).to_field_reference());
    }

    #[test]
    fn test_parse_integer_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "1", required(LiteralType::I64(1)));
    }

    #[test]
    fn test_parse_float_literal() {
        // The bare `float` rule must consume the whole input.
        let float_pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
        assert_eq!(float_pairs.as_str(), "3.82");

        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "3.82", required(LiteralType::Fp64(3.82)));
    }

    #[test]
    fn test_parse_negative_float_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "-2.5", required(LiteralType::Fp64(-2.5)));
    }

    #[test]
    fn test_parse_boolean_true_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "true", required(LiteralType::Boolean(true)));
    }

    #[test]
    fn test_parse_boolean_false_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "false", required(LiteralType::Boolean(false)));
    }

    #[test]
    fn test_parse_float_literal_with_fp32_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "3.82:fp32");
        let parsed = Literal::parse_pair(&extensions, pair).unwrap();

        let Some(LiteralType::Fp32(val)) = parsed.literal_type else {
            panic!("Expected Fp32 literal type");
        };
        assert!((val - 3.82).abs() < f32::EPSILON);
    }
}