1use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
2use substrait::proto::aggregate_rel::Measure;
3use substrait::proto::expression::field_reference::{ReferenceType, RootReference, RootType};
4use substrait::proto::expression::if_then::IfClause;
5use substrait::proto::expression::literal::LiteralType;
6use substrait::proto::expression::{
7 Cast, FieldReference, IfThen, Literal, ReferenceSegment, RexType, ScalarFunction, cast,
8 reference_segment,
9};
10use substrait::proto::function_argument::ArgType;
11use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
12use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
13
14use super::types::get_and_validate_anchor;
15use super::{
16 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
17 unwrap_single_pair,
18};
19use crate::extensions::SimpleExtensions;
20use crate::extensions::simple::{CompoundName, ExtensionKind};
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
24pub struct FieldIndex(pub i32);
25
26impl FieldIndex {
27 pub fn to_field_reference(self) -> FieldReference {
29 FieldReference {
32 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
33 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
34 reference_segment::StructField {
35 field: self.0,
36 child: None,
37 },
38 ))),
39 })),
40 root_type: Some(RootType::RootReference(RootReference {})),
41 }
42 }
43}
44
45impl ParsePair for FieldIndex {
46 fn rule() -> Rule {
47 Rule::reference
48 }
49
50 fn message() -> &'static str {
51 "FieldIndex"
52 }
53
54 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
55 assert_eq!(pair.as_rule(), Self::rule());
56 let inner = unwrap_single_pair(pair);
57 let index: i32 = inner.as_str().parse().unwrap();
58 FieldIndex(index)
59 }
60}
61
62impl ParsePair for FieldReference {
63 fn rule() -> Rule {
64 Rule::reference
65 }
66
67 fn message() -> &'static str {
68 "FieldReference"
69 }
70
71 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
72 assert_eq!(pair.as_rule(), Self::rule());
73
74 FieldIndex::parse_pair(pair).to_field_reference()
76 }
77}
78
79fn to_int_literal(
80 value: pest::iterators::Pair<Rule>,
81 typ: Option<Type>,
82) -> Result<Literal, MessageParseError> {
83 assert_eq!(value.as_rule(), Rule::integer);
84 let parsed_value: i64 = value.as_str().parse().unwrap();
85
86 const DEFAULT_KIND: Kind = Kind::I64(I64 {
87 type_variation_reference: 0,
88 nullability: Nullability::Required as i32,
89 });
90
91 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
93
94 let (lit, nullability, tvar) = match &kind {
95 Kind::I8(i) => (
97 LiteralType::I8(parsed_value as i32),
98 i.nullability,
99 i.type_variation_reference,
100 ),
101 Kind::I16(i) => (
102 LiteralType::I16(parsed_value as i32),
103 i.nullability,
104 i.type_variation_reference,
105 ),
106 Kind::I32(i) => (
107 LiteralType::I32(parsed_value as i32),
108 i.nullability,
109 i.type_variation_reference,
110 ),
111 Kind::I64(i) => (
112 LiteralType::I64(parsed_value),
113 i.nullability,
114 i.type_variation_reference,
115 ),
116 k => {
117 return Err(MessageParseError::invalid(
118 "int_literal_type",
119 value.as_span(),
120 format!("Invalid type for integer literal: {k:?}"),
121 ));
122 }
123 };
124
125 Ok(Literal {
126 literal_type: Some(lit),
127 nullable: nullability != Nullability::Required as i32,
128 type_variation_reference: tvar,
129 })
130}
131
132fn to_float_literal(
133 value: pest::iterators::Pair<Rule>,
134 typ: Option<Type>,
135) -> Result<Literal, MessageParseError> {
136 assert_eq!(value.as_rule(), Rule::float);
137 let parsed_value: f64 = value.as_str().parse().unwrap();
138
139 const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
140 type_variation_reference: 0,
141 nullability: Nullability::Required as i32,
142 });
143
144 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
146
147 let (lit, nullability, tvar) = match &kind {
148 Kind::Fp32(f) => (
149 LiteralType::Fp32(parsed_value as f32),
150 f.nullability,
151 f.type_variation_reference,
152 ),
153 Kind::Fp64(f) => (
154 LiteralType::Fp64(parsed_value),
155 f.nullability,
156 f.type_variation_reference,
157 ),
158 k => {
159 return Err(MessageParseError::invalid(
160 "float_literal_type",
161 value.as_span(),
162 format!("Invalid type for float literal: {k:?}"),
163 ));
164 }
165 };
166
167 Ok(Literal {
168 literal_type: Some(lit),
169 nullable: nullability != Nullability::Required as i32,
170 type_variation_reference: tvar,
171 })
172}
173
174fn to_boolean_literal(
175 value: pest::iterators::Pair<Rule>,
176 typ: Option<Type>,
177) -> Result<Literal, MessageParseError> {
178 assert_eq!(value.as_rule(), Rule::boolean);
179 let parsed_value: bool = value.as_str().parse().unwrap();
180
181 let (nullable, tvar) = match typ.and_then(|t| t.kind) {
182 Some(Kind::Bool(b)) => (
183 b.nullability != Nullability::Required as i32,
184 b.type_variation_reference,
185 ),
186 None => (false, 0),
187 Some(k) => {
188 return Err(MessageParseError::invalid(
189 "bool_literal_type",
190 value.as_span(),
191 format!("Invalid type for boolean literal: {k:?}"),
192 ));
193 }
194 };
195
196 Ok(Literal {
197 literal_type: Some(LiteralType::Boolean(parsed_value)),
198 nullable,
199 type_variation_reference: tvar,
200 })
201}
202
203fn to_string_literal(
204 value: pest::iterators::Pair<Rule>,
205 typ: Option<Type>,
206) -> Result<Literal, MessageParseError> {
207 assert_eq!(value.as_rule(), Rule::string_literal);
208 let string_value = unescape_string(value.clone());
209
210 let Some(typ) = typ else {
212 return Ok(Literal {
213 literal_type: Some(LiteralType::String(string_value)),
214 nullable: false,
215 type_variation_reference: 0,
216 });
217 };
218
219 let Some(kind) = typ.kind else {
220 return Ok(Literal {
221 literal_type: Some(LiteralType::String(string_value)),
222 nullable: false,
223 type_variation_reference: 0,
224 });
225 };
226
227 match &kind {
228 Kind::Date(d) => {
229 let date_days = parse_date_to_days(&string_value, value.as_span())?;
231 Ok(Literal {
232 literal_type: Some(LiteralType::Date(date_days)),
233 nullable: d.nullability != Nullability::Required as i32,
234 type_variation_reference: d.type_variation_reference,
235 })
236 }
237 #[allow(deprecated)]
238 Kind::Time(t) => {
239 let time_microseconds = parse_time_to_microseconds(&string_value, value.as_span())?;
241 Ok(Literal {
242 literal_type: Some(LiteralType::Time(time_microseconds)),
243 nullable: t.nullability != Nullability::Required as i32,
244 type_variation_reference: t.type_variation_reference,
245 })
246 }
247 #[allow(deprecated)]
248 Kind::Timestamp(ts) => {
249 let timestamp_microseconds =
251 parse_timestamp_to_microseconds(&string_value, value.as_span())?;
252 Ok(Literal {
253 literal_type: Some(LiteralType::Timestamp(timestamp_microseconds)),
254 nullable: ts.nullability != Nullability::Required as i32,
255 type_variation_reference: ts.type_variation_reference,
256 })
257 }
258 _ => {
259 Ok(Literal {
261 literal_type: Some(LiteralType::String(string_value)),
262 nullable: false,
263 type_variation_reference: 0,
264 })
265 }
266 }
267}
268
269fn to_null_literal(
270 value: pest::iterators::Pair<Rule>,
271 typ: Option<Type>,
272) -> Result<Literal, MessageParseError> {
273 assert_eq!(value.as_rule(), Rule::null);
274 let typ = typ.ok_or_else(|| {
275 MessageParseError::invalid(
276 "null_literal_type",
277 value.as_span(),
278 "Null literals require an explicit type annotation, e.g. null:i64?",
279 )
280 })?;
281
282 Ok(Literal {
283 literal_type: Some(LiteralType::Null(typ)),
284 nullable: false,
285 type_variation_reference: 0,
286 })
287}
288
289fn parse_date_to_days(date_str: &str, span: pest::Span) -> Result<i32, MessageParseError> {
291 let formats = ["%Y-%m-%d", "%Y/%m/%d"];
293
294 for format in &formats {
295 if let Ok(date) = NaiveDate::parse_from_str(date_str, format) {
296 let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
298 let days = date.signed_duration_since(epoch).num_days();
299 return Ok(days as i32);
300 }
301 }
302
303 Err(MessageParseError::invalid(
304 "date_parse_format",
305 span,
306 format!("Invalid date format: '{date_str}'. Expected YYYY-MM-DD or YYYY/MM/DD"),
307 ))
308}
309
310fn parse_time_to_microseconds(time_str: &str, span: pest::Span) -> Result<i64, MessageParseError> {
312 let formats = ["%H:%M:%S%.f", "%H:%M:%S"];
314
315 for format in &formats {
316 if let Ok(time) = NaiveTime::parse_from_str(time_str, format) {
317 let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
319 let duration = time.signed_duration_since(midnight);
320 return Ok(duration.num_microseconds().unwrap_or(0));
321 }
322 }
323
324 Err(MessageParseError::invalid(
325 "time_parse_format",
326 span,
327 format!("Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff"),
328 ))
329}
330
331fn parse_timestamp_to_microseconds(
333 timestamp_str: &str,
334 span: pest::Span,
335) -> Result<i64, MessageParseError> {
336 let formats = [
338 "%Y-%m-%dT%H:%M:%S%.f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S%.f", "%Y-%m-%d %H:%M:%S", "%Y/%m/%dT%H:%M:%S%.f", "%Y/%m/%dT%H:%M:%S", "%Y/%m/%d %H:%M:%S%.f", "%Y/%m/%d %H:%M:%S", ];
347
348 for format in &formats {
349 if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) {
350 let epoch = DateTime::from_timestamp(0, 0).unwrap().naive_utc();
352 let duration = datetime.signed_duration_since(epoch);
353 return Ok(duration.num_microseconds().unwrap_or(0));
354 }
355 }
356
357 Err(MessageParseError::invalid(
358 "timestamp_parse_format",
359 span,
360 format!(
361 "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS"
362 ),
363 ))
364}
365
366impl ScopedParsePair for Literal {
367 fn rule() -> Rule {
368 Rule::literal
369 }
370
371 fn message() -> &'static str {
372 "Literal"
373 }
374
375 fn parse_pair(
376 extensions: &SimpleExtensions,
377 pair: pest::iterators::Pair<Rule>,
378 ) -> Result<Self, MessageParseError> {
379 assert_eq!(pair.as_rule(), Self::rule());
380 let mut pairs = pair.into_inner();
381 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
384 let typ = match typ {
385 Some(t) => Some(Type::parse_pair(extensions, t)?),
386 None => None,
387 };
388 match value.as_rule() {
389 Rule::integer => to_int_literal(value, typ),
390 Rule::float => to_float_literal(value, typ),
391 Rule::boolean => to_boolean_literal(value, typ),
392 Rule::string_literal => to_string_literal(value, typ),
393 Rule::null => to_null_literal(value, typ),
394 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
395 }
396 }
397}
398
399impl ScopedParsePair for ScalarFunction {
400 fn rule() -> Rule {
401 Rule::function_call
402 }
403
404 fn message() -> &'static str {
405 "ScalarFunction"
406 }
407
408 fn parse_pair(
409 extensions: &SimpleExtensions,
410 pair: pest::iterators::Pair<Rule>,
411 ) -> Result<Self, MessageParseError> {
412 assert_eq!(pair.as_rule(), Self::rule());
413 let span = pair.as_span();
414 let mut iter = RuleIter::from(pair.into_inner());
415
416 let name = iter.parse_next::<CompoundName>();
418
419 let anchor = iter
421 .try_pop(Rule::anchor)
422 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
423
424 let _urn_anchor = iter
426 .try_pop(Rule::urn_anchor)
427 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
428
429 let argument_list = iter.pop(Rule::argument_list);
431 let mut arguments = Vec::new();
432 for e in argument_list.into_inner() {
433 arguments.push(FunctionArgument {
434 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
435 });
436 }
437
438 let output_type = Some(Type::parse_pair(extensions, iter.pop(Rule::r#type))?);
441
442 iter.done();
443 let anchor = get_and_validate_anchor(
444 extensions,
445 ExtensionKind::Function,
446 anchor,
447 name.full(),
448 span,
449 )?;
450 Ok(ScalarFunction {
451 function_reference: anchor,
452 arguments,
453 options: vec![], output_type,
455 #[allow(deprecated)]
456 args: vec![],
457 })
458 }
459}
460
461impl ScopedParsePair for Cast {
462 fn rule() -> Rule {
463 Rule::cast_expression
464 }
465
466 fn message() -> &'static str {
467 "Cast"
468 }
469
470 fn parse_pair(
471 extensions: &SimpleExtensions,
472 pair: pest::iterators::Pair<Rule>,
473 ) -> Result<Self, MessageParseError> {
474 assert_eq!(pair.as_rule(), Self::rule());
475 let mut pairs = pair.into_inner();
476
477 let expr_pair = pairs.next().unwrap();
478
479 let next = pairs.next().unwrap();
481 let (failure_behavior, type_pair) = if next.as_rule() == Rule::cast_failure_behavior {
482 let fb = match next.as_str() {
483 "?" => cast::FailureBehavior::ReturnNull as i32,
484 "!" => cast::FailureBehavior::ThrowException as i32,
485 _ => unreachable!("Grammar guarantees cast_failure_behavior is ? or !"),
486 };
487 (fb, pairs.next().unwrap())
488 } else {
489 (cast::FailureBehavior::Unspecified as i32, next)
490 };
491
492 assert!(pairs.next().is_none());
493
494 let input = Expression::parse_pair(extensions, expr_pair)?;
495 let target_type = Type::parse_pair(extensions, type_pair)?;
496
497 Ok(Cast {
498 r#type: Some(target_type),
499 input: Some(Box::new(input)),
500 failure_behavior,
501 })
502 }
503}
504
505impl ScopedParsePair for Expression {
506 fn rule() -> Rule {
507 Rule::expression
508 }
509
510 fn message() -> &'static str {
511 "Expression"
512 }
513
514 fn parse_pair(
515 extensions: &SimpleExtensions,
516 pair: pest::iterators::Pair<Rule>,
517 ) -> Result<Self, MessageParseError> {
518 assert_eq!(pair.as_rule(), Self::rule());
519 let inner = unwrap_single_pair(pair);
520 match inner.as_rule() {
521 Rule::literal => Ok(Expression {
522 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
523 }),
524 Rule::function_call => Ok(Expression {
525 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
526 extensions, inner,
527 )?)),
528 }),
529 Rule::reference => Ok(Expression {
530 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
531 inner,
532 )))),
533 }),
534 Rule::if_then => Ok(Expression {
535 rex_type: Some(RexType::IfThen(Box::new(IfThen::parse_pair(
536 extensions, inner,
537 )?))),
538 }),
539 Rule::cast_expression => Ok(Expression {
540 rex_type: Some(RexType::Cast(Box::new(Cast::parse_pair(
541 extensions, inner,
542 )?))),
543 }),
544 _ => unreachable!(
545 "Grammar guarantees expression can only be literal, function_call, reference, if_then, or cast_expression, got: {:?}",
546 inner.as_rule()
547 ),
548 }
549 }
550}
551
552impl ScopedParsePair for IfClause {
553 fn rule() -> Rule {
554 Rule::if_clause
555 }
556
557 fn message() -> &'static str {
558 "IfClause"
559 }
560
561 fn parse_pair(
562 extensions: &SimpleExtensions,
563 pair: pest::iterators::Pair<Rule>,
564 ) -> Result<Self, MessageParseError> {
565 assert_eq!(pair.as_rule(), Self::rule());
566 let mut pairs = pair.into_inner(); let condition = pairs.next().unwrap();
569 let result = pairs.next().unwrap();
570 assert!(pairs.next().is_none());
571
572 let ex1 = Some(Expression::parse_pair(extensions, condition)?);
573 let ex2 = Some(Expression::parse_pair(extensions, result)?);
574
575 Ok(IfClause {
576 r#if: ex1,
577 then: ex2,
578 })
579 }
580}
581
582impl ScopedParsePair for IfThen {
583 fn rule() -> Rule {
584 Rule::if_then
585 }
586 fn message() -> &'static str {
587 "IfThen"
588 }
589
590 fn parse_pair(
591 extensions: &SimpleExtensions,
592 pair: pest::iterators::Pair<Rule>,
593 ) -> Result<Self, MessageParseError> {
594 assert_eq!(pair.as_rule(), Self::rule());
595
596 let mut iter = RuleIter::from(pair.into_inner()); let mut ifs: Vec<IfClause> = Vec::new();
599
600 while let Some(p) = iter.try_pop(Rule::if_clause) {
602 let if_clause = IfClause::parse_pair(extensions, p)?;
603 ifs.push(if_clause);
604 }
605
606 let pair = iter.try_pop(Rule::expression).unwrap(); iter.done();
608 let else_clause = Some(Box::new(Expression::parse_pair(extensions, pair)?));
609
610 Ok(IfThen {
611 ifs,
612 r#else: else_clause,
613 })
614 }
615}
616pub struct Name(pub String);
617
618impl ParsePair for Name {
619 fn rule() -> Rule {
620 Rule::name
621 }
622
623 fn message() -> &'static str {
624 "Name"
625 }
626
627 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
628 assert_eq!(pair.as_rule(), Self::rule());
629 let inner = unwrap_single_pair(pair);
630 match inner.as_rule() {
631 Rule::identifier => Name(inner.as_str().to_string()),
632 Rule::quoted_name => Name(unescape_string(inner)),
633 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
634 }
635 }
636}
637
638impl ParsePair for CompoundName {
639 fn rule() -> Rule {
640 Rule::compound_name
641 }
642
643 fn message() -> &'static str {
644 "CompoundName"
645 }
646
647 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
648 assert_eq!(pair.as_rule(), Self::rule());
649 CompoundName::new(pair.as_str())
650 }
651}
652
653impl ScopedParsePair for Measure {
654 fn rule() -> Rule {
655 Rule::aggregate_measure
656 }
657
658 fn message() -> &'static str {
659 "Measure"
660 }
661
662 fn parse_pair(
663 extensions: &SimpleExtensions,
664 pair: pest::iterators::Pair<Rule>,
665 ) -> Result<Self, MessageParseError> {
666 assert_eq!(pair.as_rule(), Self::rule());
667
668 let function_call_pair = unwrap_single_pair(pair);
670 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
671
672 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
674 Ok(Measure {
675 measure: Some(AggregateFunction {
676 function_reference: scalar.function_reference,
677 arguments: scalar.arguments,
678 options: scalar.options,
679 output_type: scalar.output_type,
680 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
684 args: scalar.args,
685 }),
686 filter: None, })
688 }
689}
690
691#[cfg(test)]
692mod tests {
693 use pest::Parser as PestParser;
694
695 use super::*;
696 use crate::parser::ExpressionParser;
697
698 fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
699 let mut pairs = ExpressionParser::parse(rule, input).unwrap();
700 assert_eq!(pairs.as_str(), input);
701 let pair = pairs.next().unwrap();
702 assert_eq!(pairs.next(), None);
703 pair
704 }
705
706 fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
707 let pair = parse_exact(T::rule(), input);
708 let actual = T::parse_pair(pair);
709 assert_eq!(actual, expected);
710 }
711
712 fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
713 ext: &SimpleExtensions,
714 input: &str,
715 expected: T,
716 ) {
717 let pair = parse_exact(T::rule(), input);
718 let actual = T::parse_pair(ext, pair).unwrap();
719 assert_eq!(actual, expected);
720 }
721
722 #[test]
723 fn test_parse_field_reference() {
724 assert_parses_to("$1", FieldIndex(1).to_field_reference());
725 }
726
727 #[test]
728 fn test_parse_integer_literal() {
729 let extensions = SimpleExtensions::default();
730 let expected = Literal {
731 literal_type: Some(LiteralType::I64(1)),
732 nullable: false,
733 type_variation_reference: 0,
734 };
735 assert_parses_with(&extensions, "1", expected);
736 }
737
738 #[test]
739 fn test_parse_float_literal() {
740 let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
742 let parsed_text = pairs.as_str();
743 assert_eq!(parsed_text, "3.82");
744
745 let extensions = SimpleExtensions::default();
746 let expected = Literal {
747 literal_type: Some(LiteralType::Fp64(3.82)),
748 nullable: false,
749 type_variation_reference: 0,
750 };
751 assert_parses_with(&extensions, "3.82", expected);
752 }
753
754 #[test]
755 fn test_parse_negative_float_literal() {
756 let extensions = SimpleExtensions::default();
757 let expected = Literal {
758 literal_type: Some(LiteralType::Fp64(-2.5)),
759 nullable: false,
760 type_variation_reference: 0,
761 };
762 assert_parses_with(&extensions, "-2.5", expected);
763 }
764
765 #[test]
766 fn test_parse_boolean_true_literal() {
767 let extensions = SimpleExtensions::default();
768 let expected = Literal {
769 literal_type: Some(LiteralType::Boolean(true)),
770 nullable: false,
771 type_variation_reference: 0,
772 };
773 assert_parses_with(&extensions, "true", expected);
774 }
775
776 #[test]
777 fn test_parse_boolean_false_literal() {
778 let extensions = SimpleExtensions::default();
779 let expected = Literal {
780 literal_type: Some(LiteralType::Boolean(false)),
781 nullable: false,
782 type_variation_reference: 0,
783 };
784 assert_parses_with(&extensions, "false", expected);
785 }
786
787 #[test]
788 fn test_parse_nullable_boolean_literal() {
789 let extensions = SimpleExtensions::default();
790 let expected_true = Literal {
791 literal_type: Some(LiteralType::Boolean(true)),
792 nullable: true,
793 type_variation_reference: 0,
794 };
795 let expected_false = Literal {
796 literal_type: Some(LiteralType::Boolean(false)),
797 nullable: true,
798 type_variation_reference: 0,
799 };
800 assert_parses_with(&extensions, "true:boolean?", expected_true);
801 assert_parses_with(&extensions, "false:boolean?", expected_false);
802 }
803
804 #[test]
805 fn test_parse_nullable_integer_literal() {
806 let extensions = SimpleExtensions::default();
807 let expected_i32 = Literal {
808 literal_type: Some(LiteralType::I32(78)),
809 nullable: true,
810 type_variation_reference: 0,
811 };
812 let expected_i64 = Literal {
813 literal_type: Some(LiteralType::I64(42)),
814 nullable: true,
815 type_variation_reference: 0,
816 };
817 assert_parses_with(&extensions, "78:i32?", expected_i32);
818 assert_parses_with(&extensions, "42:i64?", expected_i64);
819 }
820
821 #[test]
822 fn test_parse_nullable_float_literal() {
823 let extensions = SimpleExtensions::default();
824 let expected_fp64 = Literal {
825 literal_type: Some(LiteralType::Fp64(3.19)),
826 nullable: true,
827 type_variation_reference: 0,
828 };
829 assert_parses_with(&extensions, "3.19:fp64?", expected_fp64);
830 }
831
832 #[test]
833 fn test_parse_float_literal_with_fp32_type() {
834 let extensions = SimpleExtensions::default();
835 let pair = parse_exact(Rule::literal, "3.82:fp32");
836 let result = Literal::parse_pair(&extensions, pair).unwrap();
837
838 match result.literal_type {
839 Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON),
840 _ => panic!("Expected Fp32 literal type"),
841 }
842 }
843
844 #[test]
845 fn test_parse_date_literal() {
846 let extensions = SimpleExtensions::default();
847 let pair = parse_exact(Rule::literal, "'2023-12-25':date");
848 let result = Literal::parse_pair(&extensions, pair).unwrap();
849
850 match result.literal_type {
851 Some(LiteralType::Date(days)) => {
852 assert!(
854 days > 0,
855 "Expected positive days since epoch, got: {}",
856 days
857 );
858 }
859 _ => panic!("Expected Date literal type, got: {:?}", result.literal_type),
860 }
861 }
862
863 #[test]
864 fn test_parse_time_literal() {
865 let extensions = SimpleExtensions::default();
866 let pair = parse_exact(Rule::literal, "'14:30:45':time");
867 let result = Literal::parse_pair(&extensions, pair).unwrap();
868
869 match result.literal_type {
870 #[allow(deprecated)]
871 Some(LiteralType::Time(microseconds)) => {
872 let expected = (14 * 3600 + 30 * 60 + 45) * 1_000_000;
874 assert_eq!(microseconds, expected);
875 }
876 _ => panic!("Expected Time literal type, got: {:?}", result.literal_type),
877 }
878 }
879
880 #[test]
881 fn test_parse_timestamp_literal_with_t() {
882 let extensions = SimpleExtensions::default();
883 let pair = parse_exact(Rule::literal, "'2023-01-01T12:00:00':timestamp");
884 let result = Literal::parse_pair(&extensions, pair).unwrap();
885
886 match result.literal_type {
887 #[allow(deprecated)]
888 Some(LiteralType::Timestamp(microseconds)) => {
889 assert!(
890 microseconds > 0,
891 "Expected positive microseconds since epoch"
892 );
893 }
894 _ => panic!(
895 "Expected Timestamp literal type, got: {:?}",
896 result.literal_type
897 ),
898 }
899 }
900
901 #[test]
902 fn test_parse_timestamp_literal_with_space() {
903 let extensions = SimpleExtensions::default();
904 let pair = parse_exact(Rule::literal, "'2023-01-01 12:00:00':timestamp");
905 let result = Literal::parse_pair(&extensions, pair).unwrap();
906
907 match result.literal_type {
908 #[allow(deprecated)]
909 Some(LiteralType::Timestamp(microseconds)) => {
910 assert!(
911 microseconds > 0,
912 "Expected positive microseconds since epoch"
913 );
914 }
915 _ => panic!(
916 "Expected Timestamp literal type, got: {:?}",
917 result.literal_type
918 ),
919 }
920 }
921
922 fn make_literal_bool(value: bool) -> Expression {
924 Expression {
925 rex_type: Some(RexType::Literal(Literal {
926 literal_type: Some(LiteralType::Boolean(value)),
927 nullable: false,
928 type_variation_reference: 0,
929 })),
930 }
931 }
932
933 #[test]
934 fn test_parse_if_then_single_clause() {
935 let extensions = SimpleExtensions::default();
936 let input = "if_then(true -> 42, _ -> 0)";
937 let pair = parse_exact(Rule::if_then, input);
938 let result = IfThen::parse_pair(&extensions, pair).unwrap();
939
940 assert_eq!(result.ifs.len(), 1);
941 assert!(result.r#else.is_some());
942 }
943
944 #[test]
945 fn test_parse_if_then_with_typed_literals() {
946 let extensions = SimpleExtensions::default();
947 let input = "if_then(true -> 100:i32, _ -> -100:i32)";
948 let pair = parse_exact(Rule::if_then, input);
949 let result = IfThen::parse_pair(&extensions, pair).unwrap();
950
951 assert_eq!(result.ifs.len(), 1);
952 assert!(result.r#else.is_some());
953 }
954
955 #[test]
956 fn test_parse_if_then_with_date_literals() {
957 let extensions = SimpleExtensions::default();
958 let input = "if_then(true -> '2023-12-25':date, _ -> '1970-01-01':date)";
959 let pair = parse_exact(Rule::if_then, input);
960 let result = IfThen::parse_pair(&extensions, pair).unwrap();
961
962 assert_eq!(result.ifs.len(), 1);
963 assert!(result.r#else.is_some());
964 }
965
966 #[test]
967 fn test_parse_if_then_with_time_literals() {
968 let extensions = SimpleExtensions::default();
969 let input = "if_then(true -> '14:30:45':time, _ -> '00:00:00':time)";
970 let pair = parse_exact(Rule::if_then, input);
971 let result = IfThen::parse_pair(&extensions, pair).unwrap();
972
973 assert_eq!(result.ifs.len(), 1);
974 assert!(result.r#else.is_some());
975 }
976
977 #[test]
978 fn test_parse_if_then_with_timestamp_literals() {
979 let extensions = SimpleExtensions::default();
980 let input = "if_then(true -> '2023-01-01T12:00:00':timestamp, _ -> '1970-01-01T00:00:00':timestamp)";
981 let pair = parse_exact(Rule::if_then, input);
982 let result = IfThen::parse_pair(&extensions, pair).unwrap();
983
984 assert_eq!(result.ifs.len(), 1);
985 assert!(result.r#else.is_some());
986 }
987
988 #[test]
989 fn test_parse_if_clause_with_whitespace_variations() {
990 let extensions = SimpleExtensions::default();
991
992 let inputs = vec!["true->false", "true -> false", "true -> false"];
994
995 for input in inputs {
996 let pair = parse_exact(Rule::if_clause, input);
997 let result = IfClause::parse_pair(&extensions, pair).unwrap();
998 assert!(result.r#if.is_some());
999 assert!(result.then.is_some());
1000 }
1001 }
1002
1003 #[test]
1004 fn test_if_clause_structure() {
1005 let extensions = SimpleExtensions::default();
1006 let pair = parse_exact(Rule::if_clause, "42 -> 100");
1007 let result = IfClause::parse_pair(&extensions, pair).unwrap();
1008
1009 let if_expr = result.r#if.as_ref().unwrap();
1011 let then_expr = result.then.as_ref().unwrap();
1012
1013 match (&if_expr.rex_type, &then_expr.rex_type) {
1015 (Some(RexType::Literal(_)), Some(RexType::Literal(_))) => {
1016 }
1018 _ => panic!("Expected both if and then to be literals"),
1019 }
1020 }
1021
1022 #[test]
1023 fn test_if_then_structure() {
1024 let extensions = SimpleExtensions::default();
1025 let input = "if_then(true -> 1, false -> 2, _ -> 0)";
1026 let pair = parse_exact(Rule::if_then, input);
1027 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1028
1029 assert_eq!(result.ifs.len(), 2);
1031
1032 for clause in &result.ifs {
1034 assert!(clause.r#if.is_some(), "If clause condition should exist");
1035 assert!(clause.then.is_some(), "If clause result should exist");
1036 }
1037
1038 assert!(result.r#else.is_some(), "Else clause should exist");
1040 }
1041
1042 #[test]
1043 fn test_parse_if_then_mixed_types_in_conditions() {
1044 let extensions = SimpleExtensions::default();
1045 let input = "if_then(true -> 1, true -> 'yes', 'yes' -> true, 42 -> 2, $0 -> 3, _ -> 0)";
1047 let pair = parse_exact(Rule::if_then, input);
1048 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1049
1050 assert_eq!(result.ifs.len(), 5);
1051 assert!(result.r#else.is_some());
1052 }
1053
1054 #[test]
1055 fn test_if_then_preserves_clause_order() {
1056 let extensions = SimpleExtensions::default();
1057 let input = "if_then(1 -> 10, 2 -> 20, 3 -> 30, _ -> 0)";
1058 let pair = parse_exact(Rule::if_then, input);
1059 let result = IfThen::parse_pair(&extensions, pair).unwrap();
1060
1061 assert_eq!(result.ifs.len(), 3);
1062
1063 for (i, clause) in result.ifs.iter().enumerate() {
1065 if let Some(Expression {
1066 rex_type: Some(RexType::Literal(lit)),
1067 }) = &clause.r#if
1068 && let Some(LiteralType::I64(val)) = &lit.literal_type
1069 {
1070 assert_eq!(*val, (i as i64) + 1);
1071 }
1072 }
1073 }
1074
1075 #[test]
1076 fn test_parse_if_then() {
1077 let extensions = SimpleExtensions::default();
1078
1079 let c1 = IfClause {
1080 r#if: Some(make_literal_bool(true)),
1081 then: Some(make_literal_bool(true)),
1082 };
1083
1084 let c2 = IfClause {
1085 r#if: Some(make_literal_bool(false)),
1086 then: Some(make_literal_bool(false)),
1087 };
1088
1089 let if_clause = IfThen {
1090 ifs: vec![c1, c2],
1091 r#else: Some(Box::new(make_literal_bool(false))),
1092 };
1093 assert_parses_with(
1094 &extensions,
1095 "if_then(true -> true , false -> false, _ -> false)",
1096 if_clause,
1097 );
1098 }
1099
1100 fn parse_compound_name(input: &str) -> CompoundName {
1103 let pair = parse_exact(Rule::compound_name, input);
1104 CompoundName::parse_pair(pair)
1105 }
1106
1107 #[test]
1108 fn test_compound_name_plain() {
1109 assert_eq!(parse_compound_name("add").full(), "add");
1110 }
1111
1112 #[test]
1113 fn test_compound_name_full_zero_arg_type_signature() {
1114 let n = parse_compound_name("add:");
1116 assert_eq!(n.full(), "add:");
1117 assert_eq!(n.base(), "add");
1118 assert!(n.matches("add:"));
1119 assert!(!n.matches("add:i64_i64"));
1120 assert!(n.matches("add"));
1121 }
1122
1123 #[test]
1124 fn test_compound_name_with_signature() {
1125 assert_eq!(parse_compound_name("equal:any_any").full(), "equal:any_any");
1126 assert_eq!(
1127 parse_compound_name("regexp_match_substring:str_str_i64").full(),
1128 "regexp_match_substring:str_str_i64"
1129 );
1130 assert_eq!(parse_compound_name("add:i64_i64").full(), "add:i64_i64");
1131 }
1132
1133 #[test]
1134 fn test_compound_name_stops_at_opening_paren() {
1135 let pairs = ExpressionParser::parse(Rule::compound_name, "equal:any_any").unwrap();
1139 assert_eq!(pairs.as_str(), "equal:any_any");
1140 }
1141
1142 fn make_extensions_for_fn_tests() -> SimpleExtensions {
1145 let mut exts = SimpleExtensions::default();
1146 exts.add_extension_urn("urn".to_string(), 1).unwrap();
1147 exts.add_extension(
1148 crate::extensions::simple::ExtensionKind::Function,
1149 1,
1150 1,
1151 "equal:any_any".to_string(),
1152 )
1153 .unwrap();
1154 exts.add_extension(
1155 crate::extensions::simple::ExtensionKind::Function,
1156 1,
1157 2,
1158 "equal:str_str".to_string(),
1159 )
1160 .unwrap();
1161 exts.add_extension(
1162 crate::extensions::simple::ExtensionKind::Function,
1163 1,
1164 3,
1165 "add:i64_i64".to_string(),
1166 )
1167 .unwrap();
1168 exts
1169 }
1170
#[test]
fn test_scalar_function_full_compound_name() {
    // Calling by the full compound name selects that exact overload (anchor 1).
    let extensions = make_extensions_for_fn_tests();
    let call = parse_exact(Rule::function_call, "equal:any_any($0, $1):boolean");
    let func = ScalarFunction::parse_pair(&extensions, call).unwrap();

    assert_eq!(func.function_reference, 1);
    assert_eq!(func.arguments.len(), 2);
    assert!(
        func.output_type.is_some(),
        "output_type must be set after parsing"
    );
}
1184
#[test]
fn test_scalar_function_second_overload() {
    // `equal:str_str` is the second overload of `equal` (anchor 2). Naming it in
    // full must select it rather than the first-registered `equal:any_any`.
    let exts = make_extensions_for_fn_tests();
    let pair = parse_exact(Rule::function_call, "equal:str_str($0, $1):boolean");
    let f = ScalarFunction::parse_pair(&exts, pair).unwrap();

    assert_eq!(f.arguments.len(), 2);
    assert_eq!(f.function_reference, 2);
    // The other successful-parse tests also check this: the `:boolean`
    // annotation must populate the output type.
    assert!(
        f.output_type.is_some(),
        "output_type must be set after parsing"
    );
}
1194
#[test]
fn test_scalar_function_base_name_unique_overload() {
    // `add` has exactly one registered overload (`add:i64_i64`, anchor 3), so
    // the bare base name resolves to it unambiguously.
    let registry = make_extensions_for_fn_tests();
    let parsed = ScalarFunction::parse_pair(
        &registry,
        parse_exact(Rule::function_call, "add($0, $1):i64"),
    )
    .unwrap();

    assert_eq!(parsed.arguments.len(), 2);
    assert_eq!(parsed.function_reference, 3);
    assert!(
        parsed.output_type.is_some(),
        "output_type must be set after parsing"
    );
}
1209
#[test]
fn test_scalar_function_base_name_ambiguous_fails() {
    // `equal` has two overloads (`equal:any_any` and `equal:str_str`), so the
    // bare base name cannot be resolved and parsing must return an error.
    let exts = make_extensions_for_fn_tests();
    let call = parse_exact(Rule::function_call, "equal($0, $1):boolean");
    assert!(
        ScalarFunction::parse_pair(&exts, call).is_err(),
        "ambiguous base name should fail"
    );
}
1218
#[test]
fn test_scalar_function_compound_name_with_anchor() {
    // An explicit `#1` anchor that agrees with the full compound name resolves
    // to that anchor.
    let exts = make_extensions_for_fn_tests();
    let func = ScalarFunction::parse_pair(
        &exts,
        parse_exact(Rule::function_call, "equal:any_any#1($0, $1):boolean"),
    )
    .unwrap();

    assert_eq!(func.function_reference, 1);
    assert_eq!(func.arguments.len(), 2);
}
1227
#[test]
fn test_scalar_function_base_name_with_anchor() {
    // The bare name `equal` is ambiguous on its own, but the `#1` anchor pins
    // the overload, so parsing succeeds.
    let exts = make_extensions_for_fn_tests();
    let func = ScalarFunction::parse_pair(
        &exts,
        parse_exact(Rule::function_call, "equal#1($0, $1):boolean"),
    )
    .unwrap();

    assert_eq!(func.function_reference, 1);
    assert_eq!(func.arguments.len(), 2);
}
1237
#[test]
fn test_scalar_function_wrong_name_for_anchor_fails() {
    // Anchor 1 is registered as `equal:any_any`. Referring to it as `like` is a
    // name/anchor mismatch and must be rejected.
    let registry = make_extensions_for_fn_tests();
    let call = parse_exact(Rule::function_call, "like#1($0):boolean");
    let outcome = ScalarFunction::parse_pair(&registry, call);
    assert!(outcome.is_err(), "mismatched name/anchor should fail");
}
1245
#[test]
fn test_scalar_function_missing_type_fails_to_parse() {
    // Only the grammar is exercised here (no extensions are involved): a call
    // without the trailing `:type` output annotation must not match
    // `function_call`.
    let parsed = ExpressionParser::parse(Rule::function_call, "add($0, $1)");
    assert!(
        parsed.is_err(),
        "function call without type annotation should fail to parse"
    );
}
1255
#[test]
fn test_parse_cast_expression_basic() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::cast_expression, "(78:i32)::i16");
    let result = Cast::parse_pair(&extensions, pair).unwrap();

    // The parenthesised operand becomes the cast input: the i32 literal 78.
    let input = result.input.as_ref().unwrap();
    match &input.rex_type {
        Some(RexType::Literal(lit)) => match &lit.literal_type {
            Some(LiteralType::I32(v)) => assert_eq!(*v, 78),
            other => panic!("Expected I32 literal, got: {:?}", other),
        },
        other => panic!("Expected literal, got: {:?}", other),
    }

    // The type after `::` becomes the cast's target type.
    let target = result.r#type.as_ref().unwrap();
    match &target.kind {
        Some(substrait::proto::r#type::Kind::I16(_)) => {}
        other => panic!("Expected i16 type, got: {:?}", other),
    }

    // A plain `::` (no `?` or `!` marker) leaves the failure behaviour
    // unspecified. Compare against the named enum variant rather than a bare
    // `0` so the intent is explicit and matches the other failure-behaviour
    // tests.
    assert_eq!(
        result.failure_behavior,
        cast::FailureBehavior::Unspecified as i32
    );
}
1281
#[test]
fn test_parse_cast_expression_via_expression_rule() {
    // A cast written where a general expression is expected must be dispatched
    // to the cast parser and come back as `RexType::Cast`.
    let parsed = Expression::parse_pair(
        &SimpleExtensions::default(),
        parse_exact(Rule::expression, "(78:i32)::i16"),
    )
    .unwrap();

    assert!(
        matches!(parsed.rex_type, Some(RexType::Cast(_))),
        "Expected Cast rex type, got: {:?}",
        parsed.rex_type
    );
}
1293
#[test]
fn test_parse_cast_expression_nested() {
    // `((78:i32)::i16)::i32` is a cast of a cast. The outer cast targets i32,
    // and its input is the inner i16 cast of the literal 78.
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::cast_expression, "((78:i32)::i16)::i32");
    let outer = Cast::parse_pair(&extensions, pair).unwrap();

    // Peel the outer cast to reach the inner one.
    let outer_input = outer.input.as_ref().unwrap();
    let Some(RexType::Cast(inner)) = &outer_input.rex_type else {
        panic!("Expected inner Cast, got: {:?}", outer_input.rex_type);
    };

    // The inner cast's input is the original i32 literal.
    let inner_input = inner.input.as_ref().unwrap();
    let Some(RexType::Literal(lit)) = &inner_input.rex_type else {
        panic!("Expected literal, got: {:?}", inner_input.rex_type);
    };
    let Some(LiteralType::I32(value)) = &lit.literal_type else {
        panic!("Expected I32 literal, got: {:?}", lit.literal_type);
    };
    assert_eq!(*value, 78);

    // Only the outermost target type is checked here.
    let outer_kind = &outer.r#type.as_ref().unwrap().kind;
    assert!(
        matches!(outer_kind, Some(substrait::proto::r#type::Kind::I32(_))),
        "Expected i32 outer type, got: {:?}",
        outer_kind
    );
}
1321
#[test]
fn test_parse_cast_expression_with_boolean() {
    let extensions = SimpleExtensions::default();
    let pair = parse_exact(Rule::cast_expression, "(true)::i32");
    let result = Cast::parse_pair(&extensions, pair).unwrap();

    // The operand is the boolean literal `true`.
    let input = result.input.as_ref().unwrap();
    match &input.rex_type {
        Some(RexType::Literal(lit)) => match &lit.literal_type {
            Some(LiteralType::Boolean(v)) => assert!(*v),
            other => panic!("Expected Boolean literal, got: {:?}", other),
        },
        other => panic!("Expected literal, got: {:?}", other),
    }

    // Also verify the target: `::i32` must produce an i32 kind. Otherwise a
    // cast that dropped or mis-parsed its target type would go unnoticed.
    match &result.r#type.as_ref().unwrap().kind {
        Some(substrait::proto::r#type::Kind::I32(_)) => {}
        other => panic!("Expected i32 type, got: {:?}", other),
    }
}
1337
#[test]
fn test_parse_cast_expression_with_whitespace() {
    // Whitespace inside the parentheses and around `::` is accepted; both the
    // input and the target type must still be populated.
    let extensions = SimpleExtensions::default();
    let parsed = Cast::parse_pair(
        &extensions,
        parse_exact(Rule::cast_expression, "( 78:i32 ) :: i16"),
    )
    .unwrap();

    assert!(parsed.input.is_some());
    assert!(parsed.r#type.is_some());
}
1347
#[test]
fn test_parse_cast_unspecified_failure_behavior() {
    // A plain `::` with no marker leaves the failure behaviour unspecified.
    let expected = cast::FailureBehavior::Unspecified as i32;
    let extensions = SimpleExtensions::default();
    let parsed = Cast::parse_pair(
        &extensions,
        parse_exact(Rule::cast_expression, "(78:i32)::i16"),
    )
    .unwrap();
    assert_eq!(parsed.failure_behavior, expected);
}
1358
#[test]
fn test_parse_cast_return_null_failure_behavior() {
    // A `?` before the target type requests return-null on cast failure.
    let expected = cast::FailureBehavior::ReturnNull as i32;
    let extensions = SimpleExtensions::default();
    let parsed = Cast::parse_pair(
        &extensions,
        parse_exact(Rule::cast_expression, "(78:i32)::?i16"),
    )
    .unwrap();
    assert_eq!(parsed.failure_behavior, expected);
}
1369
#[test]
fn test_parse_cast_throw_exception_failure_behavior() {
    // A `!` before the target type requests throw-exception on cast failure.
    let expected = cast::FailureBehavior::ThrowException as i32;
    let extensions = SimpleExtensions::default();
    let parsed = Cast::parse_pair(
        &extensions,
        parse_exact(Rule::cast_expression, "(78:i32)::!i16"),
    )
    .unwrap();
    assert_eq!(parsed.failure_behavior, expected);
}
1380}