1use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
2use substrait::proto::aggregate_rel::Measure;
3use substrait::proto::expression::field_reference::ReferenceType;
4use substrait::proto::expression::if_then::IfClause;
5use substrait::proto::expression::literal::LiteralType;
6use substrait::proto::expression::{
7 Cast, FieldReference, IfThen, Literal, ReferenceSegment, RexType, ScalarFunction, cast,
8 reference_segment,
9};
10use substrait::proto::function_argument::ArgType;
11use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
12use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
13
14use super::types::get_and_validate_anchor;
15use super::{
16 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
17 unwrap_single_pair,
18};
19use crate::extensions::SimpleExtensions;
20use crate::extensions::simple::{CompoundName, ExtensionKind};
21use crate::parser::ErrorKind;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub struct FieldIndex(pub i32);
26
27impl FieldIndex {
28 pub fn to_field_reference(self) -> FieldReference {
30 FieldReference {
33 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
34 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
35 reference_segment::StructField {
36 field: self.0,
37 child: None,
38 },
39 ))),
40 })),
41 root_type: None,
42 }
43 }
44}
45
46impl ParsePair for FieldIndex {
47 fn rule() -> Rule {
48 Rule::reference
49 }
50
51 fn message() -> &'static str {
52 "FieldIndex"
53 }
54
55 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
56 assert_eq!(pair.as_rule(), Self::rule());
57 let inner = unwrap_single_pair(pair);
58 let index: i32 = inner.as_str().parse().unwrap();
59 FieldIndex(index)
60 }
61}
62
63impl ParsePair for FieldReference {
64 fn rule() -> Rule {
65 Rule::reference
66 }
67
68 fn message() -> &'static str {
69 "FieldReference"
70 }
71
72 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
73 assert_eq!(pair.as_rule(), Self::rule());
74
75 FieldIndex::parse_pair(pair).to_field_reference()
77 }
78}
79
80fn to_int_literal(
81 value: pest::iterators::Pair<Rule>,
82 typ: Option<Type>,
83) -> Result<Literal, MessageParseError> {
84 assert_eq!(value.as_rule(), Rule::integer);
85 let parsed_value: i64 = value.as_str().parse().unwrap();
86
87 const DEFAULT_KIND: Kind = Kind::I64(I64 {
88 type_variation_reference: 0,
89 nullability: Nullability::Required as i32,
90 });
91
92 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
94
95 let (lit, nullability, tvar) = match &kind {
96 Kind::I8(i) => (
98 LiteralType::I8(parsed_value as i32),
99 i.nullability,
100 i.type_variation_reference,
101 ),
102 Kind::I16(i) => (
103 LiteralType::I16(parsed_value as i32),
104 i.nullability,
105 i.type_variation_reference,
106 ),
107 Kind::I32(i) => (
108 LiteralType::I32(parsed_value as i32),
109 i.nullability,
110 i.type_variation_reference,
111 ),
112 Kind::I64(i) => (
113 LiteralType::I64(parsed_value),
114 i.nullability,
115 i.type_variation_reference,
116 ),
117 k => {
118 let pest_error = pest::error::Error::new_from_span(
119 pest::error::ErrorVariant::CustomError {
120 message: format!("Invalid type for integer literal: {k:?}"),
121 },
122 value.as_span(),
123 );
124 let error = MessageParseError {
125 message: "int_literal_type",
126 kind: ErrorKind::InvalidValue,
127 error: Box::new(pest_error),
128 };
129 return Err(error);
130 }
131 };
132
133 Ok(Literal {
134 literal_type: Some(lit),
135 nullable: nullability != Nullability::Required as i32,
136 type_variation_reference: tvar,
137 })
138}
139
140fn to_float_literal(
141 value: pest::iterators::Pair<Rule>,
142 typ: Option<Type>,
143) -> Result<Literal, MessageParseError> {
144 assert_eq!(value.as_rule(), Rule::float);
145 let parsed_value: f64 = value.as_str().parse().unwrap();
146
147 const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
148 type_variation_reference: 0,
149 nullability: Nullability::Required as i32,
150 });
151
152 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
154
155 let (lit, nullability, tvar) = match &kind {
156 Kind::Fp32(f) => (
157 LiteralType::Fp32(parsed_value as f32),
158 f.nullability,
159 f.type_variation_reference,
160 ),
161 Kind::Fp64(f) => (
162 LiteralType::Fp64(parsed_value),
163 f.nullability,
164 f.type_variation_reference,
165 ),
166 k => {
167 let pest_error = pest::error::Error::new_from_span(
168 pest::error::ErrorVariant::CustomError {
169 message: format!("Invalid type for float literal: {k:?}"),
170 },
171 value.as_span(),
172 );
173 let error = MessageParseError {
174 message: "float_literal_type",
175 kind: ErrorKind::InvalidValue,
176 error: Box::new(pest_error),
177 };
178 return Err(error);
179 }
180 };
181
182 Ok(Literal {
183 literal_type: Some(lit),
184 nullable: nullability != Nullability::Required as i32,
185 type_variation_reference: tvar,
186 })
187}
188
189fn to_boolean_literal(
190 value: pest::iterators::Pair<Rule>,
191 typ: Option<Type>,
192) -> Result<Literal, MessageParseError> {
193 assert_eq!(value.as_rule(), Rule::boolean);
194 let parsed_value: bool = value.as_str().parse().unwrap();
195
196 let (nullable, tvar) = match typ.and_then(|t| t.kind) {
197 Some(Kind::Bool(b)) => (
198 b.nullability != Nullability::Required as i32,
199 b.type_variation_reference,
200 ),
201 None => (false, 0),
202 Some(k) => {
203 let pest_error = pest::error::Error::new_from_span(
204 pest::error::ErrorVariant::CustomError {
205 message: format!("Invalid type for boolean literal: {k:?}"),
206 },
207 value.as_span(),
208 );
209 return Err(MessageParseError {
210 message: "bool_literal_type",
211 kind: ErrorKind::InvalidValue,
212 error: Box::new(pest_error),
213 });
214 }
215 };
216
217 Ok(Literal {
218 literal_type: Some(LiteralType::Boolean(parsed_value)),
219 nullable,
220 type_variation_reference: tvar,
221 })
222}
223
224fn to_string_literal(
225 value: pest::iterators::Pair<Rule>,
226 typ: Option<Type>,
227) -> Result<Literal, MessageParseError> {
228 assert_eq!(value.as_rule(), Rule::string_literal);
229 let string_value = unescape_string(value.clone());
230
231 let Some(typ) = typ else {
233 return Ok(Literal {
234 literal_type: Some(LiteralType::String(string_value)),
235 nullable: false,
236 type_variation_reference: 0,
237 });
238 };
239
240 let Some(kind) = typ.kind else {
241 return Ok(Literal {
242 literal_type: Some(LiteralType::String(string_value)),
243 nullable: false,
244 type_variation_reference: 0,
245 });
246 };
247
248 match &kind {
249 Kind::Date(d) => {
250 let date_days = parse_date_to_days(&string_value, value.as_span())?;
252 Ok(Literal {
253 literal_type: Some(LiteralType::Date(date_days)),
254 nullable: d.nullability != Nullability::Required as i32,
255 type_variation_reference: d.type_variation_reference,
256 })
257 }
258 Kind::Time(t) => {
259 let time_microseconds = parse_time_to_microseconds(&string_value, value.as_span())?;
261 Ok(Literal {
262 literal_type: Some(LiteralType::Time(time_microseconds)),
263 nullable: t.nullability != Nullability::Required as i32,
264 type_variation_reference: t.type_variation_reference,
265 })
266 }
267 #[allow(deprecated)]
268 Kind::Timestamp(ts) => {
269 let timestamp_microseconds =
271 parse_timestamp_to_microseconds(&string_value, value.as_span())?;
272 Ok(Literal {
273 literal_type: Some(LiteralType::Timestamp(timestamp_microseconds)),
274 nullable: ts.nullability != Nullability::Required as i32,
275 type_variation_reference: ts.type_variation_reference,
276 })
277 }
278 _ => {
279 Ok(Literal {
281 literal_type: Some(LiteralType::String(string_value)),
282 nullable: false,
283 type_variation_reference: 0,
284 })
285 }
286 }
287}
288
289fn parse_date_to_days(date_str: &str, span: pest::Span) -> Result<i32, MessageParseError> {
291 let formats = ["%Y-%m-%d", "%Y/%m/%d"];
293
294 for format in &formats {
295 if let Ok(date) = NaiveDate::parse_from_str(date_str, format) {
296 let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
298 let days = date.signed_duration_since(epoch).num_days();
299 return Ok(days as i32);
300 }
301 }
302
303 Err(MessageParseError {
304 message: "date_parse_format",
305 kind: ErrorKind::InvalidValue,
306 error: Box::new(pest::error::Error::new_from_span(
307 pest::error::ErrorVariant::CustomError {
308 message: format!(
309 "Invalid date format: '{date_str}'. Expected YYYY-MM-DD or YYYY/MM/DD"
310 ),
311 },
312 span,
313 )),
314 })
315}
316
317fn parse_time_to_microseconds(time_str: &str, span: pest::Span) -> Result<i64, MessageParseError> {
319 let formats = ["%H:%M:%S%.f", "%H:%M:%S"];
321
322 for format in &formats {
323 if let Ok(time) = NaiveTime::parse_from_str(time_str, format) {
324 let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
326 let duration = time.signed_duration_since(midnight);
327 return Ok(duration.num_microseconds().unwrap_or(0));
328 }
329 }
330
331 Err(MessageParseError {
332 message: "time_parse_format",
333 kind: ErrorKind::InvalidValue,
334 error: Box::new(pest::error::Error::new_from_span(
335 pest::error::ErrorVariant::CustomError {
336 message: format!(
337 "Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff"
338 ),
339 },
340 span,
341 )),
342 })
343}
344
345fn parse_timestamp_to_microseconds(
347 timestamp_str: &str,
348 span: pest::Span,
349) -> Result<i64, MessageParseError> {
350 let formats = [
352 "%Y-%m-%dT%H:%M:%S%.f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S%.f", "%Y-%m-%d %H:%M:%S", "%Y/%m/%dT%H:%M:%S%.f", "%Y/%m/%dT%H:%M:%S", "%Y/%m/%d %H:%M:%S%.f", "%Y/%m/%d %H:%M:%S", ];
361
362 for format in &formats {
363 if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) {
364 let epoch = DateTime::from_timestamp(0, 0).unwrap().naive_utc();
366 let duration = datetime.signed_duration_since(epoch);
367 return Ok(duration.num_microseconds().unwrap_or(0));
368 }
369 }
370
371 Err(MessageParseError {
372 message: "timestamp_parse_format",
373 kind: ErrorKind::InvalidValue,
374 error: Box::new(pest::error::Error::new_from_span(
375 pest::error::ErrorVariant::CustomError {
376 message: format!(
377 "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS"
378 ),
379 },
380 span,
381 )),
382 })
383}
384
385impl ScopedParsePair for Literal {
386 fn rule() -> Rule {
387 Rule::literal
388 }
389
390 fn message() -> &'static str {
391 "Literal"
392 }
393
394 fn parse_pair(
395 extensions: &SimpleExtensions,
396 pair: pest::iterators::Pair<Rule>,
397 ) -> Result<Self, MessageParseError> {
398 assert_eq!(pair.as_rule(), Self::rule());
399 let mut pairs = pair.into_inner();
400 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
403 let typ = match typ {
404 Some(t) => Some(Type::parse_pair(extensions, t)?),
405 None => None,
406 };
407 match value.as_rule() {
408 Rule::integer => to_int_literal(value, typ),
409 Rule::float => to_float_literal(value, typ),
410 Rule::boolean => to_boolean_literal(value, typ),
411 Rule::string_literal => to_string_literal(value, typ),
412 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
413 }
414 }
415}
416
417impl ScopedParsePair for ScalarFunction {
418 fn rule() -> Rule {
419 Rule::function_call
420 }
421
422 fn message() -> &'static str {
423 "ScalarFunction"
424 }
425
426 fn parse_pair(
427 extensions: &SimpleExtensions,
428 pair: pest::iterators::Pair<Rule>,
429 ) -> Result<Self, MessageParseError> {
430 assert_eq!(pair.as_rule(), Self::rule());
431 let span = pair.as_span();
432 let mut iter = RuleIter::from(pair.into_inner());
433
434 let name = iter.parse_next::<CompoundName>();
436
437 let anchor = iter
439 .try_pop(Rule::anchor)
440 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
441
442 let _urn_anchor = iter
444 .try_pop(Rule::urn_anchor)
445 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
446
447 let argument_list = iter.pop(Rule::argument_list);
449 let mut arguments = Vec::new();
450 for e in argument_list.into_inner() {
451 arguments.push(FunctionArgument {
452 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
453 });
454 }
455
456 let output_type = match iter.try_pop(Rule::r#type) {
458 Some(t) => Some(Type::parse_pair(extensions, t)?),
459 None => None,
460 };
461
462 iter.done();
463 let anchor = get_and_validate_anchor(
464 extensions,
465 ExtensionKind::Function,
466 anchor,
467 name.full(),
468 span,
469 )?;
470 Ok(ScalarFunction {
471 function_reference: anchor,
472 arguments,
473 options: vec![], output_type,
475 #[allow(deprecated)]
476 args: vec![],
477 })
478 }
479}
480
481impl ScopedParsePair for Cast {
482 fn rule() -> Rule {
483 Rule::cast_expression
484 }
485
486 fn message() -> &'static str {
487 "Cast"
488 }
489
490 fn parse_pair(
491 extensions: &SimpleExtensions,
492 pair: pest::iterators::Pair<Rule>,
493 ) -> Result<Self, MessageParseError> {
494 assert_eq!(pair.as_rule(), Self::rule());
495 let mut pairs = pair.into_inner();
496
497 let expr_pair = pairs.next().unwrap();
498
499 let next = pairs.next().unwrap();
501 let (failure_behavior, type_pair) = if next.as_rule() == Rule::cast_failure_behavior {
502 let fb = match next.as_str() {
503 "?" => cast::FailureBehavior::ReturnNull as i32,
504 "!" => cast::FailureBehavior::ThrowException as i32,
505 _ => unreachable!("Grammar guarantees cast_failure_behavior is ? or !"),
506 };
507 (fb, pairs.next().unwrap())
508 } else {
509 (cast::FailureBehavior::Unspecified as i32, next)
510 };
511
512 assert!(pairs.next().is_none());
513
514 let input = Expression::parse_pair(extensions, expr_pair)?;
515 let target_type = Type::parse_pair(extensions, type_pair)?;
516
517 Ok(Cast {
518 r#type: Some(target_type),
519 input: Some(Box::new(input)),
520 failure_behavior,
521 })
522 }
523}
524
525impl ScopedParsePair for Expression {
526 fn rule() -> Rule {
527 Rule::expression
528 }
529
530 fn message() -> &'static str {
531 "Expression"
532 }
533
534 fn parse_pair(
535 extensions: &SimpleExtensions,
536 pair: pest::iterators::Pair<Rule>,
537 ) -> Result<Self, MessageParseError> {
538 assert_eq!(pair.as_rule(), Self::rule());
539 let inner = unwrap_single_pair(pair);
540 match inner.as_rule() {
541 Rule::literal => Ok(Expression {
542 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
543 }),
544 Rule::function_call => Ok(Expression {
545 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
546 extensions, inner,
547 )?)),
548 }),
549 Rule::reference => Ok(Expression {
550 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
551 inner,
552 )))),
553 }),
554 Rule::if_then => Ok(Expression {
555 rex_type: Some(RexType::IfThen(Box::new(IfThen::parse_pair(
556 extensions, inner,
557 )?))),
558 }),
559 Rule::cast_expression => Ok(Expression {
560 rex_type: Some(RexType::Cast(Box::new(Cast::parse_pair(
561 extensions, inner,
562 )?))),
563 }),
564 _ => unreachable!(
565 "Grammar guarantees expression can only be literal, function_call, reference, if_then, or cast_expression, got: {:?}",
566 inner.as_rule()
567 ),
568 }
569 }
570}
571
572impl ScopedParsePair for IfClause {
573 fn rule() -> Rule {
574 Rule::if_clause
575 }
576
577 fn message() -> &'static str {
578 "IfClause"
579 }
580
581 fn parse_pair(
582 extensions: &SimpleExtensions,
583 pair: pest::iterators::Pair<Rule>,
584 ) -> Result<Self, MessageParseError> {
585 assert_eq!(pair.as_rule(), Self::rule());
586 let mut pairs = pair.into_inner(); let condition = pairs.next().unwrap();
589 let result = pairs.next().unwrap();
590 assert!(pairs.next().is_none());
591
592 let ex1 = Some(Expression::parse_pair(extensions, condition)?);
593 let ex2 = Some(Expression::parse_pair(extensions, result)?);
594
595 Ok(IfClause {
596 r#if: ex1,
597 then: ex2,
598 })
599 }
600}
601
602impl ScopedParsePair for IfThen {
603 fn rule() -> Rule {
604 Rule::if_then
605 }
606 fn message() -> &'static str {
607 "IfThen"
608 }
609
610 fn parse_pair(
611 extensions: &SimpleExtensions,
612 pair: pest::iterators::Pair<Rule>,
613 ) -> Result<Self, MessageParseError> {
614 assert_eq!(pair.as_rule(), Self::rule());
615
616 let mut iter = RuleIter::from(pair.into_inner()); let mut ifs: Vec<IfClause> = Vec::new();
619
620 while let Some(p) = iter.try_pop(Rule::if_clause) {
622 let if_clause = IfClause::parse_pair(extensions, p)?;
623 ifs.push(if_clause);
624 }
625
626 let pair = iter.try_pop(Rule::expression).unwrap(); iter.done();
628 let else_clause = Some(Box::new(Expression::parse_pair(extensions, pair)?));
629
630 Ok(IfThen {
631 ifs,
632 r#else: else_clause,
633 })
634 }
635}
636pub struct Name(pub String);
637
638impl ParsePair for Name {
639 fn rule() -> Rule {
640 Rule::name
641 }
642
643 fn message() -> &'static str {
644 "Name"
645 }
646
647 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
648 assert_eq!(pair.as_rule(), Self::rule());
649 let inner = unwrap_single_pair(pair);
650 match inner.as_rule() {
651 Rule::identifier => Name(inner.as_str().to_string()),
652 Rule::quoted_name => Name(unescape_string(inner)),
653 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
654 }
655 }
656}
657
658impl ParsePair for CompoundName {
659 fn rule() -> Rule {
660 Rule::compound_name
661 }
662
663 fn message() -> &'static str {
664 "CompoundName"
665 }
666
667 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
668 assert_eq!(pair.as_rule(), Self::rule());
669 CompoundName::new(pair.as_str())
670 }
671}
672
673impl ScopedParsePair for Measure {
674 fn rule() -> Rule {
675 Rule::aggregate_measure
676 }
677
678 fn message() -> &'static str {
679 "Measure"
680 }
681
682 fn parse_pair(
683 extensions: &SimpleExtensions,
684 pair: pest::iterators::Pair<Rule>,
685 ) -> Result<Self, MessageParseError> {
686 assert_eq!(pair.as_rule(), Self::rule());
687
688 let function_call_pair = unwrap_single_pair(pair);
690 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
691
692 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
694 Ok(Measure {
695 measure: Some(AggregateFunction {
696 function_reference: scalar.function_reference,
697 arguments: scalar.arguments,
698 options: scalar.options,
699 output_type: scalar.output_type,
700 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
704 args: scalar.args,
705 }),
706 filter: None, })
708 }
709}
710
711#[cfg(test)]
712mod tests {
713 use pest::Parser as PestParser;
714
715 use super::*;
716 use crate::parser::ExpressionParser;
717
718 fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
719 let mut pairs = ExpressionParser::parse(rule, input).unwrap();
720 assert_eq!(pairs.as_str(), input);
721 let pair = pairs.next().unwrap();
722 assert_eq!(pairs.next(), None);
723 pair
724 }
725
726 fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
727 let pair = parse_exact(T::rule(), input);
728 let actual = T::parse_pair(pair);
729 assert_eq!(actual, expected);
730 }
731
732 fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
733 ext: &SimpleExtensions,
734 input: &str,
735 expected: T,
736 ) {
737 let pair = parse_exact(T::rule(), input);
738 let actual = T::parse_pair(ext, pair).unwrap();
739 assert_eq!(actual, expected);
740 }
741
742 #[test]
743 fn test_parse_field_reference() {
744 assert_parses_to("$1", FieldIndex(1).to_field_reference());
745 }
746
747 #[test]
748 fn test_parse_integer_literal() {
749 let extensions = SimpleExtensions::default();
750 let expected = Literal {
751 literal_type: Some(LiteralType::I64(1)),
752 nullable: false,
753 type_variation_reference: 0,
754 };
755 assert_parses_with(&extensions, "1", expected);
756 }
757
758 #[test]
759 fn test_parse_float_literal() {
760 let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
762 let parsed_text = pairs.as_str();
763 assert_eq!(parsed_text, "3.82");
764
765 let extensions = SimpleExtensions::default();
766 let expected = Literal {
767 literal_type: Some(LiteralType::Fp64(3.82)),
768 nullable: false,
769 type_variation_reference: 0,
770 };
771 assert_parses_with(&extensions, "3.82", expected);
772 }
773
774 #[test]
775 fn test_parse_negative_float_literal() {
776 let extensions = SimpleExtensions::default();
777 let expected = Literal {
778 literal_type: Some(LiteralType::Fp64(-2.5)),
779 nullable: false,
780 type_variation_reference: 0,
781 };
782 assert_parses_with(&extensions, "-2.5", expected);
783 }
784
785 #[test]
786 fn test_parse_boolean_true_literal() {
787 let extensions = SimpleExtensions::default();
788 let expected = Literal {
789 literal_type: Some(LiteralType::Boolean(true)),
790 nullable: false,
791 type_variation_reference: 0,
792 };
793 assert_parses_with(&extensions, "true", expected);
794 }
795
796 #[test]
797 fn test_parse_boolean_false_literal() {
798 let extensions = SimpleExtensions::default();
799 let expected = Literal {
800 literal_type: Some(LiteralType::Boolean(false)),
801 nullable: false,
802 type_variation_reference: 0,
803 };
804 assert_parses_with(&extensions, "false", expected);
805 }
806
807 #[test]
808 fn test_parse_nullable_boolean_literal() {
809 let extensions = SimpleExtensions::default();
810 let expected_true = Literal {
811 literal_type: Some(LiteralType::Boolean(true)),
812 nullable: true,
813 type_variation_reference: 0,
814 };
815 let expected_false = Literal {
816 literal_type: Some(LiteralType::Boolean(false)),
817 nullable: true,
818 type_variation_reference: 0,
819 };
820 assert_parses_with(&extensions, "true:boolean?", expected_true);
821 assert_parses_with(&extensions, "false:boolean?", expected_false);
822 }
823
824 #[test]
825 fn test_parse_nullable_integer_literal() {
826 let extensions = SimpleExtensions::default();
827 let expected_i32 = Literal {
828 literal_type: Some(LiteralType::I32(78)),
829 nullable: true,
830 type_variation_reference: 0,
831 };
832 let expected_i64 = Literal {
833 literal_type: Some(LiteralType::I64(42)),
834 nullable: true,
835 type_variation_reference: 0,
836 };
837 assert_parses_with(&extensions, "78:i32?", expected_i32);
838 assert_parses_with(&extensions, "42:i64?", expected_i64);
839 }
840
841 #[test]
842 fn test_parse_nullable_float_literal() {
843 let extensions = SimpleExtensions::default();
844 let expected_fp64 = Literal {
845 literal_type: Some(LiteralType::Fp64(3.19)),
846 nullable: true,
847 type_variation_reference: 0,
848 };
849 assert_parses_with(&extensions, "3.19:fp64?", expected_fp64);
850 }
851
852 #[test]
853 fn test_parse_float_literal_with_fp32_type() {
854 let extensions = SimpleExtensions::default();
855 let pair = parse_exact(Rule::literal, "3.82:fp32");
856 let result = Literal::parse_pair(&extensions, pair).unwrap();
857
858 match result.literal_type {
859 Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON),
860 _ => panic!("Expected Fp32 literal type"),
861 }
862 }
863
864 #[test]
865 fn test_parse_date_literal() {
866 let extensions = SimpleExtensions::default();
867 let pair = parse_exact(Rule::literal, "'2023-12-25':date");
868 let result = Literal::parse_pair(&extensions, pair).unwrap();
869
870 match result.literal_type {
871 Some(LiteralType::Date(days)) => {
872 assert!(
874 days > 0,
875 "Expected positive days since epoch, got: {}",
876 days
877 );
878 }
879 _ => panic!("Expected Date literal type, got: {:?}", result.literal_type),
880 }
881 }
882
883 #[test]
884 fn test_parse_time_literal() {
885 let extensions = SimpleExtensions::default();
886 let pair = parse_exact(Rule::literal, "'14:30:45':time");
887 let result = Literal::parse_pair(&extensions, pair).unwrap();
888
889 match result.literal_type {
890 Some(LiteralType::Time(microseconds)) => {
891 let expected = (14 * 3600 + 30 * 60 + 45) * 1_000_000;
893 assert_eq!(microseconds, expected);
894 }
895 _ => panic!("Expected Time literal type, got: {:?}", result.literal_type),
896 }
897 }
898
899 #[test]
900 fn test_parse_timestamp_literal_with_t() {
901 let extensions = SimpleExtensions::default();
902 let pair = parse_exact(Rule::literal, "'2023-01-01T12:00:00':timestamp");
903 let result = Literal::parse_pair(&extensions, pair).unwrap();
904
905 match result.literal_type {
906 #[allow(deprecated)]
907 Some(LiteralType::Timestamp(microseconds)) => {
908 assert!(
909 microseconds > 0,
910 "Expected positive microseconds since epoch"
911 );
912 }
913 _ => panic!(
914 "Expected Timestamp literal type, got: {:?}",
915 result.literal_type
916 ),
917 }
918 }
919
920 #[test]
921 fn test_parse_timestamp_literal_with_space() {
922 let extensions = SimpleExtensions::default();
923 let pair = parse_exact(Rule::literal, "'2023-01-01 12:00:00':timestamp");
924 let result = Literal::parse_pair(&extensions, pair).unwrap();
925
926 match result.literal_type {
927 #[allow(deprecated)]
928 Some(LiteralType::Timestamp(microseconds)) => {
929 assert!(
930 microseconds > 0,
931 "Expected positive microseconds since epoch"
932 );
933 }
934 _ => panic!(
935 "Expected Timestamp literal type, got: {:?}",
936 result.literal_type
937 ),
938 }
939 }
940
941 fn make_literal_bool(value: bool) -> Expression {
943 Expression {
944 rex_type: Some(RexType::Literal(Literal {
945 literal_type: Some(LiteralType::Boolean(value)),
946 nullable: false,
947 type_variation_reference: 0,
948 })),
949 }
950 }
951
952 #[test]
953 fn test_parse_if_then_single_clause() {
954 let extensions = SimpleExtensions::default();
955 let input = "if_then(true -> 42, _ -> 0)";
956 let pair = parse_exact(Rule::if_then, input);
957 let result = IfThen::parse_pair(&extensions, pair).unwrap();
958
959 assert_eq!(result.ifs.len(), 1);
960 assert!(result.r#else.is_some());
961 }
962
963 #[test]
964 fn test_parse_if_then_with_typed_literals() {
965 let extensions = SimpleExtensions::default();
966 let input = "if_then(true -> 100:i32, _ -> -100:i32)";
967 let pair = parse_exact(Rule::if_then, input);
968 let result = IfThen::parse_pair(&extensions, pair).unwrap();
969
970 assert_eq!(result.ifs.len(), 1);
971 assert!(result.r#else.is_some());
972 }
973
974 #[test]
975 fn test_parse_if_then_with_date_literals() {
976 let extensions = SimpleExtensions::default();
977 let input = "if_then(true -> '2023-12-25':date, _ -> '1970-01-01':date)";
978 let pair = parse_exact(Rule::if_then, input);
979 let result = IfThen::parse_pair(&extensions, pair).unwrap();
980
981 assert_eq!(result.ifs.len(), 1);
982 assert!(result.r#else.is_some());
983 }
984
985 #[test]
986 fn test_parse_if_then_with_time_literals() {
987 let extensions = SimpleExtensions::default();
988 let input = "if_then(true -> '14:30:45':time, _ -> '00:00:00':time)";
989 let pair = parse_exact(Rule::if_then, input);
990 let result = IfThen::parse_pair(&extensions, pair).unwrap();
991
992 assert_eq!(result.ifs.len(), 1);
993 assert!(result.r#else.is_some());
994 }
995
996 #[test]
997 fn test_parse_if_then_with_timestamp_literals() {
998 let extensions = SimpleExtensions::default();
999 let input = "if_then(true -> '2023-01-01T12:00:00':timestamp, _ -> '1970-01-01T00:00:00':timestamp)";
1000 let pair = parse_exact(Rule::if_then, input);
1001 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1002
1003 assert_eq!(result.ifs.len(), 1);
1004 assert!(result.r#else.is_some());
1005 }
1006
1007 #[test]
1008 fn test_parse_if_clause_with_whitespace_variations() {
1009 let extensions = SimpleExtensions::default();
1010
1011 let inputs = vec!["true->false", "true -> false", "true -> false"];
1013
1014 for input in inputs {
1015 let pair = parse_exact(Rule::if_clause, input);
1016 let result = IfClause::parse_pair(&extensions, pair).unwrap();
1017 assert!(result.r#if.is_some());
1018 assert!(result.then.is_some());
1019 }
1020 }
1021
1022 #[test]
1023 fn test_if_clause_structure() {
1024 let extensions = SimpleExtensions::default();
1025 let pair = parse_exact(Rule::if_clause, "42 -> 100");
1026 let result = IfClause::parse_pair(&extensions, pair).unwrap();
1027
1028 let if_expr = result.r#if.as_ref().unwrap();
1030 let then_expr = result.then.as_ref().unwrap();
1031
1032 match (&if_expr.rex_type, &then_expr.rex_type) {
1034 (Some(RexType::Literal(_)), Some(RexType::Literal(_))) => {
1035 }
1037 _ => panic!("Expected both if and then to be literals"),
1038 }
1039 }
1040
1041 #[test]
1042 fn test_if_then_structure() {
1043 let extensions = SimpleExtensions::default();
1044 let input = "if_then(true -> 1, false -> 2, _ -> 0)";
1045 let pair = parse_exact(Rule::if_then, input);
1046 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1047
1048 assert_eq!(result.ifs.len(), 2);
1050
1051 for clause in &result.ifs {
1053 assert!(clause.r#if.is_some(), "If clause condition should exist");
1054 assert!(clause.then.is_some(), "If clause result should exist");
1055 }
1056
1057 assert!(result.r#else.is_some(), "Else clause should exist");
1059 }
1060
1061 #[test]
1062 fn test_parse_if_then_mixed_types_in_conditions() {
1063 let extensions = SimpleExtensions::default();
1064 let input = "if_then(true -> 1, true -> 'yes', 'yes' -> true, 42 -> 2, $0 -> 3, _ -> 0)";
1066 let pair = parse_exact(Rule::if_then, input);
1067 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1068
1069 assert_eq!(result.ifs.len(), 5);
1070 assert!(result.r#else.is_some());
1071 }
1072
1073 #[test]
1074 fn test_if_then_preserves_clause_order() {
1075 let extensions = SimpleExtensions::default();
1076 let input = "if_then(1 -> 10, 2 -> 20, 3 -> 30, _ -> 0)";
1077 let pair = parse_exact(Rule::if_then, input);
1078 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1079
1080 assert_eq!(result.ifs.len(), 3);
1081
1082 for (i, clause) in result.ifs.iter().enumerate() {
1084 if let Some(Expression {
1085 rex_type: Some(RexType::Literal(lit)),
1086 }) = &clause.r#if
1087 && let Some(LiteralType::I64(val)) = &lit.literal_type
1088 {
1089 assert_eq!(*val, (i as i64) + 1);
1090 }
1091 }
1092 }
1093
1094 #[test]
1095 fn test_parse_if_then() {
1096 let extensions = SimpleExtensions::default();
1097
1098 let c1 = IfClause {
1099 r#if: Some(make_literal_bool(true)),
1100 then: Some(make_literal_bool(true)),
1101 };
1102
1103 let c2 = IfClause {
1104 r#if: Some(make_literal_bool(false)),
1105 then: Some(make_literal_bool(false)),
1106 };
1107
1108 let if_clause = IfThen {
1109 ifs: vec![c1, c2],
1110 r#else: Some(Box::new(make_literal_bool(false))),
1111 };
1112 assert_parses_with(
1113 &extensions,
1114 "if_then(true -> true , false -> false, _ -> false)",
1115 if_clause,
1116 );
1117 }
1118
1119 fn parse_compound_name(input: &str) -> CompoundName {
1122 let pair = parse_exact(Rule::compound_name, input);
1123 CompoundName::parse_pair(pair)
1124 }
1125
1126 #[test]
1127 fn test_compound_name_plain() {
1128 assert_eq!(parse_compound_name("add").full(), "add");
1129 }
1130
1131 #[test]
1132 fn test_compound_name_with_signature() {
1133 assert_eq!(parse_compound_name("equal:any_any").full(), "equal:any_any");
1134 assert_eq!(
1135 parse_compound_name("regexp_match_substring:str_str_i64").full(),
1136 "regexp_match_substring:str_str_i64"
1137 );
1138 assert_eq!(parse_compound_name("add:i64_i64").full(), "add:i64_i64");
1139 }
1140
1141 #[test]
1142 fn test_compound_name_stops_at_opening_paren() {
1143 let pairs = ExpressionParser::parse(Rule::compound_name, "equal:any_any").unwrap();
1147 assert_eq!(pairs.as_str(), "equal:any_any");
1148 }
1149
1150 fn make_extensions_for_fn_tests() -> SimpleExtensions {
1153 let mut exts = SimpleExtensions::default();
1154 exts.add_extension_urn("urn".to_string(), 1).unwrap();
1155 exts.add_extension(
1156 crate::extensions::simple::ExtensionKind::Function,
1157 1,
1158 1,
1159 "equal:any_any".to_string(),
1160 )
1161 .unwrap();
1162 exts.add_extension(
1163 crate::extensions::simple::ExtensionKind::Function,
1164 1,
1165 2,
1166 "equal:str_str".to_string(),
1167 )
1168 .unwrap();
1169 exts.add_extension(
1170 crate::extensions::simple::ExtensionKind::Function,
1171 1,
1172 3,
1173 "add:i64_i64".to_string(),
1174 )
1175 .unwrap();
1176 exts
1177 }
1178
1179 #[test]
1180 fn test_scalar_function_full_compound_name() {
1181 let exts = make_extensions_for_fn_tests();
1183 let pair = parse_exact(Rule::function_call, "equal:any_any($0, $1)");
1184 let f = ScalarFunction::parse_pair(&exts, pair).unwrap();
1185 assert_eq!(f.function_reference, 1);
1186 assert_eq!(f.arguments.len(), 2);
1187 }
1188
1189 #[test]
1190 fn test_scalar_function_second_overload() {
1191 let exts = make_extensions_for_fn_tests();
1192 let pair = parse_exact(Rule::function_call, "equal:str_str($0, $1)");
1193 let f = ScalarFunction::parse_pair(&exts, pair).unwrap();
1194
1195 assert_eq!(f.arguments.len(), 2);
1196 assert_eq!(f.function_reference, 2);
1197 }
1198
1199 #[test]
1200 fn test_scalar_function_base_name_unique_overload() {
1201 let exts = make_extensions_for_fn_tests();
1203 let pair = parse_exact(Rule::function_call, "add($0, $1)");
1204 let f = ScalarFunction::parse_pair(&exts, pair).unwrap();
1205
1206 assert_eq!(f.arguments.len(), 2);
1207 assert_eq!(f.function_reference, 3);
1208 }
1209
1210 #[test]
1211 fn test_scalar_function_base_name_ambiguous_fails() {
1212 let exts = make_extensions_for_fn_tests();
1214 let pair = parse_exact(Rule::function_call, "equal($0, $1)");
1215 let result = ScalarFunction::parse_pair(&exts, pair);
1216 assert!(result.is_err(), "ambiguous base name should fail");
1217 }
1218
1219 #[test]
1220 fn test_scalar_function_compound_name_with_anchor() {
1221 let exts = make_extensions_for_fn_tests();
1222 let pair = parse_exact(Rule::function_call, "equal:any_any#1($0, $1)");
1223 let f = ScalarFunction::parse_pair(&exts, pair).unwrap();
1224 assert_eq!(f.function_reference, 1);
1225 assert_eq!(f.arguments.len(), 2);
1226 }
1227
1228 #[test]
1229 fn test_scalar_function_base_name_with_anchor() {
1230 let exts = make_extensions_for_fn_tests();
1232 let pair = parse_exact(Rule::function_call, "equal#1($0, $1)");
1233 let f = ScalarFunction::parse_pair(&exts, pair).unwrap();
1234 assert_eq!(f.function_reference, 1);
1235 assert_eq!(f.arguments.len(), 2);
1236 }
1237
1238 #[test]
1239 fn test_scalar_function_wrong_name_for_anchor_fails() {
1240 let exts = make_extensions_for_fn_tests();
1241 let pair = parse_exact(Rule::function_call, "like#1($0)");
1242 let result = ScalarFunction::parse_pair(&exts, pair);
1243 assert!(result.is_err(), "mismatched name/anchor should fail");
1244 }
1245
1246 #[test]
1247 fn test_parse_cast_expression_basic() {
1248 let extensions = SimpleExtensions::default();
1249 let pair = parse_exact(Rule::cast_expression, "(78:i32)::i16");
1250 let result = Cast::parse_pair(&extensions, pair).unwrap();
1251
1252 let input = result.input.as_ref().unwrap();
1254 match &input.rex_type {
1255 Some(RexType::Literal(lit)) => match &lit.literal_type {
1256 Some(LiteralType::I32(v)) => assert_eq!(*v, 78),
1257 other => panic!("Expected I32 literal, got: {:?}", other),
1258 },
1259 other => panic!("Expected literal, got: {:?}", other),
1260 }
1261
1262 let target = result.r#type.as_ref().unwrap();
1264 match &target.kind {
1265 Some(substrait::proto::r#type::Kind::I16(_)) => {}
1266 other => panic!("Expected i16 type, got: {:?}", other),
1267 }
1268
1269 assert_eq!(result.failure_behavior, 0);
1270 }
1271
1272 #[test]
1273 fn test_parse_cast_expression_via_expression_rule() {
1274 let extensions = SimpleExtensions::default();
1275 let pair = parse_exact(Rule::expression, "(78:i32)::i16");
1276 let result = Expression::parse_pair(&extensions, pair).unwrap();
1277
1278 match result.rex_type {
1279 Some(RexType::Cast(_)) => {}
1280 other => panic!("Expected Cast rex type, got: {:?}", other),
1281 }
1282 }
1283
1284 #[test]
1285 fn test_parse_cast_expression_nested() {
1286 let extensions = SimpleExtensions::default();
1287 let pair = parse_exact(Rule::cast_expression, "((78:i32)::i16)::i32");
1288 let result = Cast::parse_pair(&extensions, pair).unwrap();
1289
1290 let input = result.input.as_ref().unwrap();
1292 match &input.rex_type {
1293 Some(RexType::Cast(inner)) => {
1294 let inner_input = inner.input.as_ref().unwrap();
1295 match &inner_input.rex_type {
1296 Some(RexType::Literal(lit)) => match &lit.literal_type {
1297 Some(LiteralType::I32(v)) => assert_eq!(*v, 78),
1298 other => panic!("Expected I32 literal, got: {:?}", other),
1299 },
1300 other => panic!("Expected literal, got: {:?}", other),
1301 }
1302 }
1303 other => panic!("Expected inner Cast, got: {:?}", other),
1304 }
1305
1306 match &result.r#type.as_ref().unwrap().kind {
1307 Some(substrait::proto::r#type::Kind::I32(_)) => {}
1308 other => panic!("Expected i32 outer type, got: {:?}", other),
1309 }
1310 }
1311
1312 #[test]
1313 fn test_parse_cast_expression_with_boolean() {
1314 let extensions = SimpleExtensions::default();
1315 let pair = parse_exact(Rule::cast_expression, "(true)::i32");
1316 let result = Cast::parse_pair(&extensions, pair).unwrap();
1317
1318 let input = result.input.as_ref().unwrap();
1319 match &input.rex_type {
1320 Some(RexType::Literal(lit)) => match &lit.literal_type {
1321 Some(LiteralType::Boolean(v)) => assert!(*v),
1322 other => panic!("Expected Boolean literal, got: {:?}", other),
1323 },
1324 other => panic!("Expected literal, got: {:?}", other),
1325 }
1326 }
1327
1328 #[test]
1329 fn test_parse_cast_expression_with_whitespace() {
1330 let extensions = SimpleExtensions::default();
1331 let pair = parse_exact(Rule::cast_expression, "( 78:i32 ) :: i16");
1333 let result = Cast::parse_pair(&extensions, pair).unwrap();
1334 assert!(result.input.is_some());
1335 assert!(result.r#type.is_some());
1336 }
1337
1338 #[test]
1339 fn test_parse_cast_unspecified_failure_behavior() {
1340 let extensions = SimpleExtensions::default();
1341 let pair = parse_exact(Rule::cast_expression, "(78:i32)::i16");
1342 let result = Cast::parse_pair(&extensions, pair).unwrap();
1343 assert_eq!(
1344 result.failure_behavior,
1345 cast::FailureBehavior::Unspecified as i32
1346 );
1347 }
1348
1349 #[test]
1350 fn test_parse_cast_return_null_failure_behavior() {
1351 let extensions = SimpleExtensions::default();
1352 let pair = parse_exact(Rule::cast_expression, "(78:i32)::?i16");
1353 let result = Cast::parse_pair(&extensions, pair).unwrap();
1354 assert_eq!(
1355 result.failure_behavior,
1356 cast::FailureBehavior::ReturnNull as i32
1357 );
1358 }
1359
1360 #[test]
1361 fn test_parse_cast_throw_exception_failure_behavior() {
1362 let extensions = SimpleExtensions::default();
1363 let pair = parse_exact(Rule::cast_expression, "(78:i32)::!i16");
1364 let result = Cast::parse_pair(&extensions, pair).unwrap();
1365 assert_eq!(
1366 result.failure_behavior,
1367 cast::FailureBehavior::ThrowException as i32
1368 );
1369 }
1370}