1use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
2use substrait::proto::aggregate_rel::Measure;
3use substrait::proto::expression::field_reference::ReferenceType;
4use substrait::proto::expression::literal::LiteralType;
5use substrait::proto::expression::{
6 FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
7};
8use substrait::proto::function_argument::ArgType;
9use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
10use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
11
12use super::types::get_and_validate_anchor;
13use super::{
14 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
15 unwrap_single_pair,
16};
17use crate::extensions::SimpleExtensions;
18use crate::extensions::simple::ExtensionKind;
19use crate::parser::ErrorKind;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23pub struct FieldIndex(pub i32);
24
25impl FieldIndex {
26 pub fn to_field_reference(self) -> FieldReference {
28 FieldReference {
31 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
32 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
33 reference_segment::StructField {
34 field: self.0,
35 child: None,
36 },
37 ))),
38 })),
39 root_type: None,
40 }
41 }
42}
43
44impl ParsePair for FieldIndex {
45 fn rule() -> Rule {
46 Rule::reference
47 }
48
49 fn message() -> &'static str {
50 "FieldIndex"
51 }
52
53 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
54 assert_eq!(pair.as_rule(), Self::rule());
55 let inner = unwrap_single_pair(pair);
56 let index: i32 = inner.as_str().parse().unwrap();
57 FieldIndex(index)
58 }
59}
60
61impl ParsePair for FieldReference {
62 fn rule() -> Rule {
63 Rule::reference
64 }
65
66 fn message() -> &'static str {
67 "FieldReference"
68 }
69
70 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
71 assert_eq!(pair.as_rule(), Self::rule());
72
73 FieldIndex::parse_pair(pair).to_field_reference()
75 }
76}
77
78fn to_int_literal(
79 value: pest::iterators::Pair<Rule>,
80 typ: Option<Type>,
81) -> Result<Literal, MessageParseError> {
82 assert_eq!(value.as_rule(), Rule::integer);
83 let parsed_value: i64 = value.as_str().parse().unwrap();
84
85 const DEFAULT_KIND: Kind = Kind::I64(I64 {
86 type_variation_reference: 0,
87 nullability: Nullability::Required as i32,
88 });
89
90 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
92
93 let (lit, nullability, tvar) = match &kind {
94 Kind::I8(i) => (
96 LiteralType::I8(parsed_value as i32),
97 i.nullability,
98 i.type_variation_reference,
99 ),
100 Kind::I16(i) => (
101 LiteralType::I16(parsed_value as i32),
102 i.nullability,
103 i.type_variation_reference,
104 ),
105 Kind::I32(i) => (
106 LiteralType::I32(parsed_value as i32),
107 i.nullability,
108 i.type_variation_reference,
109 ),
110 Kind::I64(i) => (
111 LiteralType::I64(parsed_value),
112 i.nullability,
113 i.type_variation_reference,
114 ),
115 k => {
116 let pest_error = pest::error::Error::new_from_span(
117 pest::error::ErrorVariant::CustomError {
118 message: format!("Invalid type for integer literal: {k:?}"),
119 },
120 value.as_span(),
121 );
122 let error = MessageParseError {
123 message: "int_literal_type",
124 kind: ErrorKind::InvalidValue,
125 error: Box::new(pest_error),
126 };
127 return Err(error);
128 }
129 };
130
131 Ok(Literal {
132 literal_type: Some(lit),
133 nullable: nullability != Nullability::Required as i32,
134 type_variation_reference: tvar,
135 })
136}
137
138fn to_float_literal(
139 value: pest::iterators::Pair<Rule>,
140 typ: Option<Type>,
141) -> Result<Literal, MessageParseError> {
142 assert_eq!(value.as_rule(), Rule::float);
143 let parsed_value: f64 = value.as_str().parse().unwrap();
144
145 const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
146 type_variation_reference: 0,
147 nullability: Nullability::Required as i32,
148 });
149
150 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
152
153 let (lit, nullability, tvar) = match &kind {
154 Kind::Fp32(f) => (
155 LiteralType::Fp32(parsed_value as f32),
156 f.nullability,
157 f.type_variation_reference,
158 ),
159 Kind::Fp64(f) => (
160 LiteralType::Fp64(parsed_value),
161 f.nullability,
162 f.type_variation_reference,
163 ),
164 k => {
165 let pest_error = pest::error::Error::new_from_span(
166 pest::error::ErrorVariant::CustomError {
167 message: format!("Invalid type for float literal: {k:?}"),
168 },
169 value.as_span(),
170 );
171 let error = MessageParseError {
172 message: "float_literal_type",
173 kind: ErrorKind::InvalidValue,
174 error: Box::new(pest_error),
175 };
176 return Err(error);
177 }
178 };
179
180 Ok(Literal {
181 literal_type: Some(lit),
182 nullable: nullability != Nullability::Required as i32,
183 type_variation_reference: tvar,
184 })
185}
186
187fn to_boolean_literal(value: pest::iterators::Pair<Rule>) -> Result<Literal, MessageParseError> {
188 assert_eq!(value.as_rule(), Rule::boolean);
189 let parsed_value: bool = value.as_str().parse().unwrap();
190
191 Ok(Literal {
192 literal_type: Some(LiteralType::Boolean(parsed_value)),
193 nullable: false,
194 type_variation_reference: 0,
195 })
196}
197
198fn to_string_literal(
199 value: pest::iterators::Pair<Rule>,
200 typ: Option<Type>,
201) -> Result<Literal, MessageParseError> {
202 assert_eq!(value.as_rule(), Rule::string_literal);
203 let string_value = unescape_string(value.clone());
204
205 let Some(typ) = typ else {
207 return Ok(Literal {
208 literal_type: Some(LiteralType::String(string_value)),
209 nullable: false,
210 type_variation_reference: 0,
211 });
212 };
213
214 let Some(kind) = typ.kind else {
215 return Ok(Literal {
216 literal_type: Some(LiteralType::String(string_value)),
217 nullable: false,
218 type_variation_reference: 0,
219 });
220 };
221
222 match &kind {
223 Kind::Date(d) => {
224 let date_days = parse_date_to_days(&string_value, value.as_span())?;
226 Ok(Literal {
227 literal_type: Some(LiteralType::Date(date_days)),
228 nullable: d.nullability != Nullability::Required as i32,
229 type_variation_reference: d.type_variation_reference,
230 })
231 }
232 Kind::Time(t) => {
233 let time_microseconds = parse_time_to_microseconds(&string_value, value.as_span())?;
235 Ok(Literal {
236 literal_type: Some(LiteralType::Time(time_microseconds)),
237 nullable: t.nullability != Nullability::Required as i32,
238 type_variation_reference: t.type_variation_reference,
239 })
240 }
241 Kind::Timestamp(ts) => {
242 let timestamp_microseconds =
244 parse_timestamp_to_microseconds(&string_value, value.as_span())?;
245 Ok(Literal {
246 literal_type: Some(LiteralType::Timestamp(timestamp_microseconds)),
247 nullable: ts.nullability != Nullability::Required as i32,
248 type_variation_reference: ts.type_variation_reference,
249 })
250 }
251 _ => {
252 Ok(Literal {
254 literal_type: Some(LiteralType::String(string_value)),
255 nullable: false,
256 type_variation_reference: 0,
257 })
258 }
259 }
260}
261
262fn parse_date_to_days(date_str: &str, span: pest::Span) -> Result<i32, MessageParseError> {
264 let formats = ["%Y-%m-%d", "%Y/%m/%d"];
266
267 for format in &formats {
268 if let Ok(date) = NaiveDate::parse_from_str(date_str, format) {
269 let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
271 let days = date.signed_duration_since(epoch).num_days();
272 return Ok(days as i32);
273 }
274 }
275
276 Err(MessageParseError {
277 message: "date_parse_format",
278 kind: ErrorKind::InvalidValue,
279 error: Box::new(pest::error::Error::new_from_span(
280 pest::error::ErrorVariant::CustomError {
281 message: format!(
282 "Invalid date format: '{date_str}'. Expected YYYY-MM-DD or YYYY/MM/DD"
283 ),
284 },
285 span,
286 )),
287 })
288}
289
290fn parse_time_to_microseconds(time_str: &str, span: pest::Span) -> Result<i64, MessageParseError> {
292 let formats = ["%H:%M:%S%.f", "%H:%M:%S"];
294
295 for format in &formats {
296 if let Ok(time) = NaiveTime::parse_from_str(time_str, format) {
297 let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
299 let duration = time.signed_duration_since(midnight);
300 return Ok(duration.num_microseconds().unwrap_or(0));
301 }
302 }
303
304 Err(MessageParseError {
305 message: "time_parse_format",
306 kind: ErrorKind::InvalidValue,
307 error: Box::new(pest::error::Error::new_from_span(
308 pest::error::ErrorVariant::CustomError {
309 message: format!(
310 "Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff"
311 ),
312 },
313 span,
314 )),
315 })
316}
317
318fn parse_timestamp_to_microseconds(
320 timestamp_str: &str,
321 span: pest::Span,
322) -> Result<i64, MessageParseError> {
323 let formats = [
325 "%Y-%m-%dT%H:%M:%S%.f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S%.f", "%Y-%m-%d %H:%M:%S", "%Y/%m/%dT%H:%M:%S%.f", "%Y/%m/%dT%H:%M:%S", "%Y/%m/%d %H:%M:%S%.f", "%Y/%m/%d %H:%M:%S", ];
334
335 for format in &formats {
336 if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) {
337 let epoch = DateTime::from_timestamp(0, 0).unwrap().naive_utc();
339 let duration = datetime.signed_duration_since(epoch);
340 return Ok(duration.num_microseconds().unwrap_or(0));
341 }
342 }
343
344 Err(MessageParseError {
345 message: "timestamp_parse_format",
346 kind: ErrorKind::InvalidValue,
347 error: Box::new(pest::error::Error::new_from_span(
348 pest::error::ErrorVariant::CustomError {
349 message: format!(
350 "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS"
351 ),
352 },
353 span,
354 )),
355 })
356}
357
358impl ScopedParsePair for Literal {
359 fn rule() -> Rule {
360 Rule::literal
361 }
362
363 fn message() -> &'static str {
364 "Literal"
365 }
366
367 fn parse_pair(
368 extensions: &SimpleExtensions,
369 pair: pest::iterators::Pair<Rule>,
370 ) -> Result<Self, MessageParseError> {
371 assert_eq!(pair.as_rule(), Self::rule());
372 let mut pairs = pair.into_inner();
373 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
376 let typ = match typ {
377 Some(t) => Some(Type::parse_pair(extensions, t)?),
378 None => None,
379 };
380 match value.as_rule() {
381 Rule::integer => to_int_literal(value, typ),
382 Rule::float => to_float_literal(value, typ),
383 Rule::boolean => to_boolean_literal(value),
384 Rule::string_literal => to_string_literal(value, typ),
385 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
386 }
387 }
388}
389
390impl ScopedParsePair for ScalarFunction {
391 fn rule() -> Rule {
392 Rule::function_call
393 }
394
395 fn message() -> &'static str {
396 "ScalarFunction"
397 }
398
399 fn parse_pair(
400 extensions: &SimpleExtensions,
401 pair: pest::iterators::Pair<Rule>,
402 ) -> Result<Self, MessageParseError> {
403 assert_eq!(pair.as_rule(), Self::rule());
404 let span = pair.as_span();
405 let mut iter = RuleIter::from(pair.into_inner());
406
407 let name = iter.parse_next::<Name>();
409
410 let anchor = iter
412 .try_pop(Rule::anchor)
413 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
414
415 let _uri_anchor = iter
417 .try_pop(Rule::uri_anchor)
418 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
419
420 let argument_list = iter.pop(Rule::argument_list);
422 let mut arguments = Vec::new();
423 for e in argument_list.into_inner() {
424 arguments.push(FunctionArgument {
425 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
426 });
427 }
428
429 let output_type = match iter.try_pop(Rule::r#type) {
431 Some(t) => Some(Type::parse_pair(extensions, t)?),
432 None => None,
433 };
434
435 iter.done();
436 let anchor =
437 get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
438 Ok(ScalarFunction {
439 function_reference: anchor,
440 arguments,
441 options: vec![], output_type,
443 #[allow(deprecated)]
444 args: vec![],
445 })
446 }
447}
448
449impl ScopedParsePair for Expression {
450 fn rule() -> Rule {
451 Rule::expression
452 }
453
454 fn message() -> &'static str {
455 "Expression"
456 }
457
458 fn parse_pair(
459 extensions: &SimpleExtensions,
460 pair: pest::iterators::Pair<Rule>,
461 ) -> Result<Self, MessageParseError> {
462 assert_eq!(pair.as_rule(), Self::rule());
463 let inner = unwrap_single_pair(pair);
464
465 match inner.as_rule() {
466 Rule::literal => Ok(Expression {
467 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
468 }),
469 Rule::function_call => Ok(Expression {
470 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
471 extensions, inner,
472 )?)),
473 }),
474 Rule::reference => Ok(Expression {
475 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
476 inner,
477 )))),
478 }),
479 _ => unimplemented!("Expression unexpected rule: {:?}", inner.as_rule()),
480 }
481 }
482}
483
484pub struct Name(pub String);
485
486impl ParsePair for Name {
487 fn rule() -> Rule {
488 Rule::name
489 }
490
491 fn message() -> &'static str {
492 "Name"
493 }
494
495 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
496 assert_eq!(pair.as_rule(), Self::rule());
497 let inner = unwrap_single_pair(pair);
498 match inner.as_rule() {
499 Rule::identifier => Name(inner.as_str().to_string()),
500 Rule::quoted_name => Name(unescape_string(inner)),
501 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
502 }
503 }
504}
505
506impl ScopedParsePair for Measure {
507 fn rule() -> Rule {
508 Rule::aggregate_measure
509 }
510
511 fn message() -> &'static str {
512 "Measure"
513 }
514
515 fn parse_pair(
516 extensions: &SimpleExtensions,
517 pair: pest::iterators::Pair<Rule>,
518 ) -> Result<Self, MessageParseError> {
519 assert_eq!(pair.as_rule(), Self::rule());
520
521 let function_call_pair = unwrap_single_pair(pair);
523 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
524
525 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
527 Ok(Measure {
528 measure: Some(AggregateFunction {
529 function_reference: scalar.function_reference,
530 arguments: scalar.arguments,
531 options: scalar.options,
532 output_type: scalar.output_type,
533 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
537 args: scalar.args,
538 }),
539 filter: None, })
541 }
542}
543
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, asserting that the whole input was
    /// consumed and that exactly one top-level pair was produced.
    fn parse_exact(rule: Rule, input: &str) -> pest::iterators::Pair<Rule> {
        let mut parsed = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(parsed.as_str(), input);
        let only = parsed.next().unwrap();
        assert_eq!(parsed.next(), None);
        only
    }

    /// Asserts that `input` parses (without extensions) to `expected`.
    fn assert_parses_to<T>(input: &str, expected: T)
    where
        T: ParsePair + PartialEq + std::fmt::Debug,
    {
        let actual = T::parse_pair(parse_exact(T::rule(), input));
        assert_eq!(actual, expected);
    }

    /// Asserts that `input` parses successfully with `ext` to `expected`.
    fn assert_parses_with<T>(ext: &SimpleExtensions, input: &str, expected: T)
    where
        T: ScopedParsePair + PartialEq + std::fmt::Debug,
    {
        let actual = T::parse_pair(ext, parse_exact(T::rule(), input)).unwrap();
        assert_eq!(actual, expected);
    }

    /// Builds the non-nullable, variation-0 literal expected for untyped input.
    fn required(literal_type: LiteralType) -> Literal {
        Literal {
            literal_type: Some(literal_type),
            nullable: false,
            type_variation_reference: 0,
        }
    }

    /// Parses `input` as a literal using default extensions.
    fn parse_literal(input: &str) -> Literal {
        let extensions = SimpleExtensions::default();
        Literal::parse_pair(&extensions, parse_exact(Rule::literal, input)).unwrap()
    }

    #[test]
    fn test_parse_field_reference() {
        assert_parses_to("$1", FieldIndex(1).to_field_reference());
    }

    #[test]
    fn test_parse_integer_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "1", required(LiteralType::I64(1)));
    }

    #[test]
    fn test_parse_float_literal() {
        // Sanity-check that the bare `float` rule consumes the whole token.
        let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
        assert_eq!(pairs.as_str(), "3.82");

        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "3.82", required(LiteralType::Fp64(3.82)));
    }

    #[test]
    fn test_parse_negative_float_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "-2.5", required(LiteralType::Fp64(-2.5)));
    }

    #[test]
    fn test_parse_boolean_true_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "true", required(LiteralType::Boolean(true)));
    }

    #[test]
    fn test_parse_boolean_false_literal() {
        let extensions = SimpleExtensions::default();
        assert_parses_with(&extensions, "false", required(LiteralType::Boolean(false)));
    }

    #[test]
    fn test_parse_float_literal_with_fp32_type() {
        let literal = parse_literal("3.82:fp32");

        match literal.literal_type {
            Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON),
            _ => panic!("Expected Fp32 literal type"),
        }
    }

    #[test]
    fn test_parse_date_literal() {
        let literal = parse_literal("'2023-12-25':date");

        match literal.literal_type {
            Some(LiteralType::Date(days)) => {
                assert!(
                    days > 0,
                    "Expected positive days since epoch, got: {}",
                    days
                );
            }
            other => panic!("Expected Date literal type, got: {:?}", other),
        }
    }

    #[test]
    fn test_parse_time_literal() {
        let literal = parse_literal("'14:30:45':time");

        match literal.literal_type {
            Some(LiteralType::Time(microseconds)) => {
                let expected = (14 * 3600 + 30 * 60 + 45) * 1_000_000;
                assert_eq!(microseconds, expected);
            }
            other => panic!("Expected Time literal type, got: {:?}", other),
        }
    }

    #[test]
    fn test_parse_timestamp_literal_with_t() {
        let literal = parse_literal("'2023-01-01T12:00:00':timestamp");

        match literal.literal_type {
            Some(LiteralType::Timestamp(microseconds)) => {
                assert!(
                    microseconds > 0,
                    "Expected positive microseconds since epoch"
                );
            }
            other => panic!("Expected Timestamp literal type, got: {:?}", other),
        }
    }

    #[test]
    fn test_parse_timestamp_literal_with_space() {
        let literal = parse_literal("'2023-01-01 12:00:00':timestamp");

        match literal.literal_type {
            Some(LiteralType::Timestamp(microseconds)) => {
                assert!(
                    microseconds > 0,
                    "Expected positive microseconds since epoch"
                );
            }
            other => panic!("Expected Timestamp literal type, got: {:?}", other),
        }
    }
}