1use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
2use substrait::proto::aggregate_rel::Measure;
3use substrait::proto::expression::field_reference::ReferenceType;
4use substrait::proto::expression::if_then::IfClause;
5use substrait::proto::expression::literal::LiteralType;
6use substrait::proto::expression::{
7 FieldReference, IfThen, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
8};
9use substrait::proto::function_argument::ArgType;
10use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
11use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
12
13use super::types::get_and_validate_anchor;
14use super::{
15 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
16 unwrap_single_pair,
17};
18use crate::extensions::SimpleExtensions;
19use crate::extensions::simple::ExtensionKind;
20use crate::parser::ErrorKind;
21
/// Positional index of a field, as written in a reference such as `$1`.
///
/// The wrapped integer is the index itself; use
/// [`FieldIndex::to_field_reference`] to obtain the equivalent Substrait
/// [`FieldReference`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FieldIndex(pub i32);
25
26impl FieldIndex {
27 pub fn to_field_reference(self) -> FieldReference {
29 FieldReference {
32 reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
33 reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
34 reference_segment::StructField {
35 field: self.0,
36 child: None,
37 },
38 ))),
39 })),
40 root_type: None,
41 }
42 }
43}
44
45impl ParsePair for FieldIndex {
46 fn rule() -> Rule {
47 Rule::reference
48 }
49
50 fn message() -> &'static str {
51 "FieldIndex"
52 }
53
54 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
55 assert_eq!(pair.as_rule(), Self::rule());
56 let inner = unwrap_single_pair(pair);
57 let index: i32 = inner.as_str().parse().unwrap();
58 FieldIndex(index)
59 }
60}
61
62impl ParsePair for FieldReference {
63 fn rule() -> Rule {
64 Rule::reference
65 }
66
67 fn message() -> &'static str {
68 "FieldReference"
69 }
70
71 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
72 assert_eq!(pair.as_rule(), Self::rule());
73
74 FieldIndex::parse_pair(pair).to_field_reference()
76 }
77}
78
79fn to_int_literal(
80 value: pest::iterators::Pair<Rule>,
81 typ: Option<Type>,
82) -> Result<Literal, MessageParseError> {
83 assert_eq!(value.as_rule(), Rule::integer);
84 let parsed_value: i64 = value.as_str().parse().unwrap();
85
86 const DEFAULT_KIND: Kind = Kind::I64(I64 {
87 type_variation_reference: 0,
88 nullability: Nullability::Required as i32,
89 });
90
91 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
93
94 let (lit, nullability, tvar) = match &kind {
95 Kind::I8(i) => (
97 LiteralType::I8(parsed_value as i32),
98 i.nullability,
99 i.type_variation_reference,
100 ),
101 Kind::I16(i) => (
102 LiteralType::I16(parsed_value as i32),
103 i.nullability,
104 i.type_variation_reference,
105 ),
106 Kind::I32(i) => (
107 LiteralType::I32(parsed_value as i32),
108 i.nullability,
109 i.type_variation_reference,
110 ),
111 Kind::I64(i) => (
112 LiteralType::I64(parsed_value),
113 i.nullability,
114 i.type_variation_reference,
115 ),
116 k => {
117 let pest_error = pest::error::Error::new_from_span(
118 pest::error::ErrorVariant::CustomError {
119 message: format!("Invalid type for integer literal: {k:?}"),
120 },
121 value.as_span(),
122 );
123 let error = MessageParseError {
124 message: "int_literal_type",
125 kind: ErrorKind::InvalidValue,
126 error: Box::new(pest_error),
127 };
128 return Err(error);
129 }
130 };
131
132 Ok(Literal {
133 literal_type: Some(lit),
134 nullable: nullability != Nullability::Required as i32,
135 type_variation_reference: tvar,
136 })
137}
138
139fn to_float_literal(
140 value: pest::iterators::Pair<Rule>,
141 typ: Option<Type>,
142) -> Result<Literal, MessageParseError> {
143 assert_eq!(value.as_rule(), Rule::float);
144 let parsed_value: f64 = value.as_str().parse().unwrap();
145
146 const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
147 type_variation_reference: 0,
148 nullability: Nullability::Required as i32,
149 });
150
151 let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
153
154 let (lit, nullability, tvar) = match &kind {
155 Kind::Fp32(f) => (
156 LiteralType::Fp32(parsed_value as f32),
157 f.nullability,
158 f.type_variation_reference,
159 ),
160 Kind::Fp64(f) => (
161 LiteralType::Fp64(parsed_value),
162 f.nullability,
163 f.type_variation_reference,
164 ),
165 k => {
166 let pest_error = pest::error::Error::new_from_span(
167 pest::error::ErrorVariant::CustomError {
168 message: format!("Invalid type for float literal: {k:?}"),
169 },
170 value.as_span(),
171 );
172 let error = MessageParseError {
173 message: "float_literal_type",
174 kind: ErrorKind::InvalidValue,
175 error: Box::new(pest_error),
176 };
177 return Err(error);
178 }
179 };
180
181 Ok(Literal {
182 literal_type: Some(lit),
183 nullable: nullability != Nullability::Required as i32,
184 type_variation_reference: tvar,
185 })
186}
187
188fn to_boolean_literal(value: pest::iterators::Pair<Rule>) -> Result<Literal, MessageParseError> {
189 assert_eq!(value.as_rule(), Rule::boolean);
190 let parsed_value: bool = value.as_str().parse().unwrap();
191
192 Ok(Literal {
193 literal_type: Some(LiteralType::Boolean(parsed_value)),
194 nullable: false,
195 type_variation_reference: 0,
196 })
197}
198
199fn to_string_literal(
200 value: pest::iterators::Pair<Rule>,
201 typ: Option<Type>,
202) -> Result<Literal, MessageParseError> {
203 assert_eq!(value.as_rule(), Rule::string_literal);
204 let string_value = unescape_string(value.clone());
205
206 let Some(typ) = typ else {
208 return Ok(Literal {
209 literal_type: Some(LiteralType::String(string_value)),
210 nullable: false,
211 type_variation_reference: 0,
212 });
213 };
214
215 let Some(kind) = typ.kind else {
216 return Ok(Literal {
217 literal_type: Some(LiteralType::String(string_value)),
218 nullable: false,
219 type_variation_reference: 0,
220 });
221 };
222
223 match &kind {
224 Kind::Date(d) => {
225 let date_days = parse_date_to_days(&string_value, value.as_span())?;
227 Ok(Literal {
228 literal_type: Some(LiteralType::Date(date_days)),
229 nullable: d.nullability != Nullability::Required as i32,
230 type_variation_reference: d.type_variation_reference,
231 })
232 }
233 Kind::Time(t) => {
234 let time_microseconds = parse_time_to_microseconds(&string_value, value.as_span())?;
236 Ok(Literal {
237 literal_type: Some(LiteralType::Time(time_microseconds)),
238 nullable: t.nullability != Nullability::Required as i32,
239 type_variation_reference: t.type_variation_reference,
240 })
241 }
242 #[allow(deprecated)]
243 Kind::Timestamp(ts) => {
244 let timestamp_microseconds =
246 parse_timestamp_to_microseconds(&string_value, value.as_span())?;
247 Ok(Literal {
248 literal_type: Some(LiteralType::Timestamp(timestamp_microseconds)),
249 nullable: ts.nullability != Nullability::Required as i32,
250 type_variation_reference: ts.type_variation_reference,
251 })
252 }
253 _ => {
254 Ok(Literal {
256 literal_type: Some(LiteralType::String(string_value)),
257 nullable: false,
258 type_variation_reference: 0,
259 })
260 }
261 }
262}
263
264fn parse_date_to_days(date_str: &str, span: pest::Span) -> Result<i32, MessageParseError> {
266 let formats = ["%Y-%m-%d", "%Y/%m/%d"];
268
269 for format in &formats {
270 if let Ok(date) = NaiveDate::parse_from_str(date_str, format) {
271 let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
273 let days = date.signed_duration_since(epoch).num_days();
274 return Ok(days as i32);
275 }
276 }
277
278 Err(MessageParseError {
279 message: "date_parse_format",
280 kind: ErrorKind::InvalidValue,
281 error: Box::new(pest::error::Error::new_from_span(
282 pest::error::ErrorVariant::CustomError {
283 message: format!(
284 "Invalid date format: '{date_str}'. Expected YYYY-MM-DD or YYYY/MM/DD"
285 ),
286 },
287 span,
288 )),
289 })
290}
291
292fn parse_time_to_microseconds(time_str: &str, span: pest::Span) -> Result<i64, MessageParseError> {
294 let formats = ["%H:%M:%S%.f", "%H:%M:%S"];
296
297 for format in &formats {
298 if let Ok(time) = NaiveTime::parse_from_str(time_str, format) {
299 let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
301 let duration = time.signed_duration_since(midnight);
302 return Ok(duration.num_microseconds().unwrap_or(0));
303 }
304 }
305
306 Err(MessageParseError {
307 message: "time_parse_format",
308 kind: ErrorKind::InvalidValue,
309 error: Box::new(pest::error::Error::new_from_span(
310 pest::error::ErrorVariant::CustomError {
311 message: format!(
312 "Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff"
313 ),
314 },
315 span,
316 )),
317 })
318}
319
320fn parse_timestamp_to_microseconds(
322 timestamp_str: &str,
323 span: pest::Span,
324) -> Result<i64, MessageParseError> {
325 let formats = [
327 "%Y-%m-%dT%H:%M:%S%.f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S%.f", "%Y-%m-%d %H:%M:%S", "%Y/%m/%dT%H:%M:%S%.f", "%Y/%m/%dT%H:%M:%S", "%Y/%m/%d %H:%M:%S%.f", "%Y/%m/%d %H:%M:%S", ];
336
337 for format in &formats {
338 if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) {
339 let epoch = DateTime::from_timestamp(0, 0).unwrap().naive_utc();
341 let duration = datetime.signed_duration_since(epoch);
342 return Ok(duration.num_microseconds().unwrap_or(0));
343 }
344 }
345
346 Err(MessageParseError {
347 message: "timestamp_parse_format",
348 kind: ErrorKind::InvalidValue,
349 error: Box::new(pest::error::Error::new_from_span(
350 pest::error::ErrorVariant::CustomError {
351 message: format!(
352 "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS"
353 ),
354 },
355 span,
356 )),
357 })
358}
359
360impl ScopedParsePair for Literal {
361 fn rule() -> Rule {
362 Rule::literal
363 }
364
365 fn message() -> &'static str {
366 "Literal"
367 }
368
369 fn parse_pair(
370 extensions: &SimpleExtensions,
371 pair: pest::iterators::Pair<Rule>,
372 ) -> Result<Self, MessageParseError> {
373 assert_eq!(pair.as_rule(), Self::rule());
374 let mut pairs = pair.into_inner();
375 let value = pairs.next().unwrap(); let typ = pairs.next(); assert!(pairs.next().is_none());
378 let typ = match typ {
379 Some(t) => Some(Type::parse_pair(extensions, t)?),
380 None => None,
381 };
382 match value.as_rule() {
383 Rule::integer => to_int_literal(value, typ),
384 Rule::float => to_float_literal(value, typ),
385 Rule::boolean => to_boolean_literal(value),
386 Rule::string_literal => to_string_literal(value, typ),
387 _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
388 }
389 }
390}
391
392impl ScopedParsePair for ScalarFunction {
393 fn rule() -> Rule {
394 Rule::function_call
395 }
396
397 fn message() -> &'static str {
398 "ScalarFunction"
399 }
400
401 fn parse_pair(
402 extensions: &SimpleExtensions,
403 pair: pest::iterators::Pair<Rule>,
404 ) -> Result<Self, MessageParseError> {
405 assert_eq!(pair.as_rule(), Self::rule());
406 let span = pair.as_span();
407 let mut iter = RuleIter::from(pair.into_inner());
408
409 let name = iter.parse_next::<Name>();
411
412 let anchor = iter
414 .try_pop(Rule::anchor)
415 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
416
417 let _urn_anchor = iter
419 .try_pop(Rule::urn_anchor)
420 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
421
422 let argument_list = iter.pop(Rule::argument_list);
424 let mut arguments = Vec::new();
425 for e in argument_list.into_inner() {
426 arguments.push(FunctionArgument {
427 arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
428 });
429 }
430
431 let output_type = match iter.try_pop(Rule::r#type) {
433 Some(t) => Some(Type::parse_pair(extensions, t)?),
434 None => None,
435 };
436
437 iter.done();
438 let anchor =
439 get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
440 Ok(ScalarFunction {
441 function_reference: anchor,
442 arguments,
443 options: vec![], output_type,
445 #[allow(deprecated)]
446 args: vec![],
447 })
448 }
449}
450
451impl ScopedParsePair for Expression {
452 fn rule() -> Rule {
453 Rule::expression
454 }
455
456 fn message() -> &'static str {
457 "Expression"
458 }
459
460 fn parse_pair(
461 extensions: &SimpleExtensions,
462 pair: pest::iterators::Pair<Rule>,
463 ) -> Result<Self, MessageParseError> {
464 assert_eq!(pair.as_rule(), Self::rule());
465 let inner = unwrap_single_pair(pair);
466 match inner.as_rule() {
467 Rule::literal => Ok(Expression {
468 rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
469 }),
470 Rule::function_call => Ok(Expression {
471 rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
472 extensions, inner,
473 )?)),
474 }),
475 Rule::reference => Ok(Expression {
476 rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
477 inner,
478 )))),
479 }),
480 Rule::if_then => Ok(Expression {
481 rex_type: Some(RexType::IfThen(Box::new(IfThen::parse_pair(
482 extensions, inner,
483 )?))),
484 }),
485 _ => unreachable!(
486 "Grammar guarantees expression can only be literal, function_call, reference, or if_then, got: {:?}",
487 inner.as_rule()
488 ),
489 }
490 }
491}
492
493impl ScopedParsePair for IfClause {
494 fn rule() -> Rule {
495 Rule::if_clause
496 }
497
498 fn message() -> &'static str {
499 "IfClause"
500 }
501
502 fn parse_pair(
503 extensions: &SimpleExtensions,
504 pair: pest::iterators::Pair<Rule>,
505 ) -> Result<Self, MessageParseError> {
506 assert_eq!(pair.as_rule(), Self::rule());
507 let mut pairs = pair.into_inner(); let condition = pairs.next().unwrap();
510 let result = pairs.next().unwrap();
511 assert!(pairs.next().is_none());
512
513 let ex1 = Some(Expression::parse_pair(extensions, condition)?);
514 let ex2 = Some(Expression::parse_pair(extensions, result)?);
515
516 Ok(IfClause {
517 r#if: ex1,
518 then: ex2,
519 })
520 }
521}
522
523impl ScopedParsePair for IfThen {
524 fn rule() -> Rule {
525 Rule::if_then
526 }
527 fn message() -> &'static str {
528 "IfThen"
529 }
530
531 fn parse_pair(
532 extensions: &SimpleExtensions,
533 pair: pest::iterators::Pair<Rule>,
534 ) -> Result<Self, MessageParseError> {
535 assert_eq!(pair.as_rule(), Self::rule());
536
537 let mut iter = RuleIter::from(pair.into_inner()); let mut ifs: Vec<IfClause> = Vec::new();
540
541 while let Some(p) = iter.try_pop(Rule::if_clause) {
543 let if_clause = IfClause::parse_pair(extensions, p)?;
544 ifs.push(if_clause);
545 }
546
547 let pair = iter.try_pop(Rule::expression).unwrap(); iter.done();
549 let else_clause = Some(Box::new(Expression::parse_pair(extensions, pair)?));
550
551 Ok(IfThen {
552 ifs,
553 r#else: else_clause,
554 })
555 }
556}
/// A name taken from the expression text.
///
/// The wrapped string is the name's text: either a bare identifier exactly as
/// written, or the unescaped contents of a quoted name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Name(pub String);
558
559impl ParsePair for Name {
560 fn rule() -> Rule {
561 Rule::name
562 }
563
564 fn message() -> &'static str {
565 "Name"
566 }
567
568 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
569 assert_eq!(pair.as_rule(), Self::rule());
570 let inner = unwrap_single_pair(pair);
571 match inner.as_rule() {
572 Rule::identifier => Name(inner.as_str().to_string()),
573 Rule::quoted_name => Name(unescape_string(inner)),
574 _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
575 }
576 }
577}
578
579impl ScopedParsePair for Measure {
580 fn rule() -> Rule {
581 Rule::aggregate_measure
582 }
583
584 fn message() -> &'static str {
585 "Measure"
586 }
587
588 fn parse_pair(
589 extensions: &SimpleExtensions,
590 pair: pest::iterators::Pair<Rule>,
591 ) -> Result<Self, MessageParseError> {
592 assert_eq!(pair.as_rule(), Self::rule());
593
594 let function_call_pair = unwrap_single_pair(pair);
596 assert_eq!(function_call_pair.as_rule(), Rule::function_call);
597
598 let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
600 Ok(Measure {
601 measure: Some(AggregateFunction {
602 function_reference: scalar.function_reference,
603 arguments: scalar.arguments,
604 options: scalar.options,
605 output_type: scalar.output_type,
606 invocation: 0, phase: 0, sorts: vec![], #[allow(deprecated)]
610 args: scalar.args,
611 }),
612 filter: None, })
614 }
615}
616
#[cfg(test)]
mod tests {
    use pest::Parser as PestParser;

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, asserting that the whole input is consumed
    /// and that exactly one top-level pair is produced.
    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        assert_eq!(pairs.as_str(), input);
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Asserts that `input` parses (without extensions) to `expected`.
    fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
        let pair = parse_exact(T::rule(), input);
        let actual = T::parse_pair(pair);
        assert_eq!(actual, expected);
    }

    /// Asserts that `input` parses successfully with `ext` to `expected`.
    fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
        ext: &SimpleExtensions,
        input: &str,
        expected: T,
    ) {
        let pair = parse_exact(T::rule(), input);
        let actual = T::parse_pair(ext, pair).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_parse_field_reference() {
        assert_parses_to("$1", FieldIndex(1).to_field_reference());
    }

    #[test]
    fn test_parse_integer_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::I64(1)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "1", expected);
    }

    #[test]
    fn test_parse_float_literal() {
        // The bare `float` rule must consume the whole token.
        let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
        let parsed_text = pairs.as_str();
        assert_eq!(parsed_text, "3.82");

        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Fp64(3.82)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "3.82", expected);
    }

    #[test]
    fn test_parse_negative_float_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Fp64(-2.5)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "-2.5", expected);
    }

    #[test]
    fn test_parse_boolean_true_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Boolean(true)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "true", expected);
    }

    #[test]
    fn test_parse_boolean_false_literal() {
        let extensions = SimpleExtensions::default();
        let expected = Literal {
            literal_type: Some(LiteralType::Boolean(false)),
            nullable: false,
            type_variation_reference: 0,
        };
        assert_parses_with(&extensions, "false", expected);
    }

    #[test]
    fn test_parse_float_literal_with_fp32_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "3.82:fp32");
        let result = Literal::parse_pair(&extensions, pair).unwrap();

        match result.literal_type {
            Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON),
            _ => panic!("Expected Fp32 literal type"),
        }
    }

    #[test]
    fn test_parse_date_literal() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "'2023-12-25':date");
        let result = Literal::parse_pair(&extensions, pair).unwrap();

        match result.literal_type {
            Some(LiteralType::Date(days)) => {
                // 2023-12-25 is 19716 days after 1970-01-01.
                assert_eq!(days, 19_716, "Unexpected days since epoch");
            }
            _ => panic!("Expected Date literal type, got: {:?}", result.literal_type),
        }
    }

    #[test]
    fn test_parse_time_literal() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "'14:30:45':time");
        let result = Literal::parse_pair(&extensions, pair).unwrap();

        match result.literal_type {
            Some(LiteralType::Time(microseconds)) => {
                let expected = (14 * 3600 + 30 * 60 + 45) * 1_000_000;
                assert_eq!(microseconds, expected);
            }
            _ => panic!("Expected Time literal type, got: {:?}", result.literal_type),
        }
    }

    /// Microseconds from 1970-01-01T00:00:00 to 2023-01-01T12:00:00.
    const NOON_2023_01_01_MICROS: i64 = 1_672_574_400_000_000;

    #[test]
    fn test_parse_timestamp_literal_with_t() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "'2023-01-01T12:00:00':timestamp");
        let result = Literal::parse_pair(&extensions, pair).unwrap();

        match result.literal_type {
            #[allow(deprecated)]
            Some(LiteralType::Timestamp(microseconds)) => {
                assert_eq!(
                    microseconds, NOON_2023_01_01_MICROS,
                    "Unexpected microseconds since epoch"
                );
            }
            _ => panic!(
                "Expected Timestamp literal type, got: {:?}",
                result.literal_type
            ),
        }
    }

    #[test]
    fn test_parse_timestamp_literal_with_space() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::literal, "'2023-01-01 12:00:00':timestamp");
        let result = Literal::parse_pair(&extensions, pair).unwrap();

        match result.literal_type {
            #[allow(deprecated)]
            Some(LiteralType::Timestamp(microseconds)) => {
                assert_eq!(
                    microseconds, NOON_2023_01_01_MICROS,
                    "Unexpected microseconds since epoch"
                );
            }
            _ => panic!(
                "Expected Timestamp literal type, got: {:?}",
                result.literal_type
            ),
        }
    }

    /// Builds the expression a bare `true`/`false` literal parses to.
    fn make_literal_bool(value: bool) -> Expression {
        Expression {
            rex_type: Some(RexType::Literal(Literal {
                literal_type: Some(LiteralType::Boolean(value)),
                nullable: false,
                type_variation_reference: 0,
            })),
        }
    }

    #[test]
    fn test_parse_if_then_single_clause() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> 42, _ -> 0)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 1);
        assert!(result.r#else.is_some());
    }

    #[test]
    fn test_parse_if_then_with_typed_literals() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> 100:i32, _ -> -100:i32)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 1);
        assert!(result.r#else.is_some());
    }

    #[test]
    fn test_parse_if_then_with_date_literals() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> '2023-12-25':date, _ -> '1970-01-01':date)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 1);
        assert!(result.r#else.is_some());
    }

    #[test]
    fn test_parse_if_then_with_time_literals() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> '14:30:45':time, _ -> '00:00:00':time)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 1);
        assert!(result.r#else.is_some());
    }

    #[test]
    fn test_parse_if_then_with_timestamp_literals() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> '2023-01-01T12:00:00':timestamp, _ -> '1970-01-01T00:00:00':timestamp)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 1);
        assert!(result.r#else.is_some());
    }

    #[test]
    fn test_parse_if_clause_with_whitespace_variations() {
        let extensions = SimpleExtensions::default();

        let inputs = vec!["true->false", "true -> false", "true -> false"];

        for input in inputs {
            let pair = parse_exact(Rule::if_clause, input);
            let result = IfClause::parse_pair(&extensions, pair).unwrap();
            assert!(result.r#if.is_some());
            assert!(result.then.is_some());
        }
    }

    #[test]
    fn test_if_clause_structure() {
        let extensions = SimpleExtensions::default();
        let pair = parse_exact(Rule::if_clause, "42 -> 100");
        let result = IfClause::parse_pair(&extensions, pair).unwrap();

        let if_expr = result.r#if.as_ref().unwrap();
        let then_expr = result.then.as_ref().unwrap();

        assert!(
            matches!(
                (&if_expr.rex_type, &then_expr.rex_type),
                (Some(RexType::Literal(_)), Some(RexType::Literal(_)))
            ),
            "Expected both if and then to be literals"
        );
    }

    #[test]
    fn test_if_then_structure() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> 1, false -> 2, _ -> 0)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 2);

        for clause in &result.ifs {
            assert!(clause.r#if.is_some(), "If clause condition should exist");
            assert!(clause.then.is_some(), "If clause result should exist");
        }

        assert!(result.r#else.is_some(), "Else clause should exist");
    }

    #[test]
    fn test_parse_if_then_mixed_types_in_conditions() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(true -> 1, true -> 'yes', 'yes' -> true, 42 -> 2, $0 -> 3, _ -> 0)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 5);
        assert!(result.r#else.is_some());
    }

    #[test]
    fn test_if_then_preserves_clause_order() {
        let extensions = SimpleExtensions::default();
        let input = "if_then(1 -> 10, 2 -> 20, 3 -> 30, _ -> 0)";
        let pair = parse_exact(Rule::if_then, input);
        let result = IfThen::parse_pair(&extensions, pair).unwrap();

        assert_eq!(result.ifs.len(), 3);

        // Fail loudly if a clause is not the expected untyped integer literal,
        // rather than silently skipping the ordering check.
        for (i, clause) in result.ifs.iter().enumerate() {
            let Some(Expression {
                rex_type: Some(RexType::Literal(condition)),
            }) = &clause.r#if
            else {
                panic!(
                    "Expected clause {i} condition to be a literal, got: {:?}",
                    clause.r#if
                );
            };
            let Some(Expression {
                rex_type: Some(RexType::Literal(outcome)),
            }) = &clause.then
            else {
                panic!(
                    "Expected clause {i} result to be a literal, got: {:?}",
                    clause.then
                );
            };

            let position = (i as i64) + 1;
            assert_eq!(condition.literal_type, Some(LiteralType::I64(position)));
            assert_eq!(outcome.literal_type, Some(LiteralType::I64(position * 10)));
        }
    }

    #[test]
    fn test_parse_if_then() {
        let extensions = SimpleExtensions::default();

        let c1 = IfClause {
            r#if: Some(make_literal_bool(true)),
            then: Some(make_literal_bool(true)),
        };

        let c2 = IfClause {
            r#if: Some(make_literal_bool(false)),
            then: Some(make_literal_bool(false)),
        };

        let if_clause = IfThen {
            ifs: vec![c1, c2],
            r#else: Some(Box::new(make_literal_bool(false))),
        };
        assert_parses_with(
            &extensions,
            "if_then(true -> true , false -> false, _ -> false)",
            if_clause,
        );
    }
}