// substrait_explain/parser/expressions.rs

1use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime};
2use substrait::proto::aggregate_rel::Measure;
3use substrait::proto::expression::field_reference::ReferenceType;
4use substrait::proto::expression::if_then::IfClause;
5use substrait::proto::expression::literal::LiteralType;
6use substrait::proto::expression::{
7    FieldReference, IfThen, Literal, ReferenceSegment, RexType, ScalarFunction, reference_segment,
8};
9use substrait::proto::function_argument::ArgType;
10use substrait::proto::r#type::{Fp64, I64, Kind, Nullability};
11use substrait::proto::{AggregateFunction, Expression, FunctionArgument, Type};
12
13use super::types::get_and_validate_anchor;
14use super::{
15    MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
16    unwrap_single_pair,
17};
18use crate::extensions::SimpleExtensions;
19use crate::extensions::simple::ExtensionKind;
20use crate::parser::ErrorKind;
21
/// The integer position written after `$` in a field reference, e.g. `$0` becomes `FieldIndex(0)`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FieldIndex(pub i32);
25
26impl FieldIndex {
27    /// Convert this field index to a FieldReference for use in expressions.
28    pub fn to_field_reference(self) -> FieldReference {
29        // XXX: Why is it so many layers to make a struct field reference? This is
30        // surprisingly complex
31        FieldReference {
32            reference_type: Some(ReferenceType::DirectReference(ReferenceSegment {
33                reference_type: Some(reference_segment::ReferenceType::StructField(Box::new(
34                    reference_segment::StructField {
35                        field: self.0,
36                        child: None,
37                    },
38                ))),
39            })),
40            root_type: None,
41        }
42    }
43}
44
45impl ParsePair for FieldIndex {
46    fn rule() -> Rule {
47        Rule::reference
48    }
49
50    fn message() -> &'static str {
51        "FieldIndex"
52    }
53
54    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
55        assert_eq!(pair.as_rule(), Self::rule());
56        let inner = unwrap_single_pair(pair);
57        let index: i32 = inner.as_str().parse().unwrap();
58        FieldIndex(index)
59    }
60}
61
62impl ParsePair for FieldReference {
63    fn rule() -> Rule {
64        Rule::reference
65    }
66
67    fn message() -> &'static str {
68        "FieldReference"
69    }
70
71    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
72        assert_eq!(pair.as_rule(), Self::rule());
73
74        // TODO: Other types of references.
75        FieldIndex::parse_pair(pair).to_field_reference()
76    }
77}
78
79fn to_int_literal(
80    value: pest::iterators::Pair<Rule>,
81    typ: Option<Type>,
82) -> Result<Literal, MessageParseError> {
83    assert_eq!(value.as_rule(), Rule::integer);
84    let parsed_value: i64 = value.as_str().parse().unwrap();
85
86    const DEFAULT_KIND: Kind = Kind::I64(I64 {
87        type_variation_reference: 0,
88        nullability: Nullability::Required as i32,
89    });
90
91    // If no type is provided, we assume i64, Nullability::Required.
92    let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
93
94    let (lit, nullability, tvar) = match &kind {
95        // If no type is provided, we assume i64, Nullability::Required.
96        Kind::I8(i) => (
97            LiteralType::I8(parsed_value as i32),
98            i.nullability,
99            i.type_variation_reference,
100        ),
101        Kind::I16(i) => (
102            LiteralType::I16(parsed_value as i32),
103            i.nullability,
104            i.type_variation_reference,
105        ),
106        Kind::I32(i) => (
107            LiteralType::I32(parsed_value as i32),
108            i.nullability,
109            i.type_variation_reference,
110        ),
111        Kind::I64(i) => (
112            LiteralType::I64(parsed_value),
113            i.nullability,
114            i.type_variation_reference,
115        ),
116        k => {
117            let pest_error = pest::error::Error::new_from_span(
118                pest::error::ErrorVariant::CustomError {
119                    message: format!("Invalid type for integer literal: {k:?}"),
120                },
121                value.as_span(),
122            );
123            let error = MessageParseError {
124                message: "int_literal_type",
125                kind: ErrorKind::InvalidValue,
126                error: Box::new(pest_error),
127            };
128            return Err(error);
129        }
130    };
131
132    Ok(Literal {
133        literal_type: Some(lit),
134        nullable: nullability != Nullability::Required as i32,
135        type_variation_reference: tvar,
136    })
137}
138
139fn to_float_literal(
140    value: pest::iterators::Pair<Rule>,
141    typ: Option<Type>,
142) -> Result<Literal, MessageParseError> {
143    assert_eq!(value.as_rule(), Rule::float);
144    let parsed_value: f64 = value.as_str().parse().unwrap();
145
146    const DEFAULT_KIND: Kind = Kind::Fp64(Fp64 {
147        type_variation_reference: 0,
148        nullability: Nullability::Required as i32,
149    });
150
151    // If no type is provided, we assume fp64, Nullability::Required.
152    let kind = typ.and_then(|t| t.kind).unwrap_or(DEFAULT_KIND);
153
154    let (lit, nullability, tvar) = match &kind {
155        Kind::Fp32(f) => (
156            LiteralType::Fp32(parsed_value as f32),
157            f.nullability,
158            f.type_variation_reference,
159        ),
160        Kind::Fp64(f) => (
161            LiteralType::Fp64(parsed_value),
162            f.nullability,
163            f.type_variation_reference,
164        ),
165        k => {
166            let pest_error = pest::error::Error::new_from_span(
167                pest::error::ErrorVariant::CustomError {
168                    message: format!("Invalid type for float literal: {k:?}"),
169                },
170                value.as_span(),
171            );
172            let error = MessageParseError {
173                message: "float_literal_type",
174                kind: ErrorKind::InvalidValue,
175                error: Box::new(pest_error),
176            };
177            return Err(error);
178        }
179    };
180
181    Ok(Literal {
182        literal_type: Some(lit),
183        nullable: nullability != Nullability::Required as i32,
184        type_variation_reference: tvar,
185    })
186}
187
188fn to_boolean_literal(value: pest::iterators::Pair<Rule>) -> Result<Literal, MessageParseError> {
189    assert_eq!(value.as_rule(), Rule::boolean);
190    let parsed_value: bool = value.as_str().parse().unwrap();
191
192    Ok(Literal {
193        literal_type: Some(LiteralType::Boolean(parsed_value)),
194        nullable: false,
195        type_variation_reference: 0,
196    })
197}
198
199fn to_string_literal(
200    value: pest::iterators::Pair<Rule>,
201    typ: Option<Type>,
202) -> Result<Literal, MessageParseError> {
203    assert_eq!(value.as_rule(), Rule::string_literal);
204    let string_value = unescape_string(value.clone());
205
206    // If no type is provided, default to string
207    let Some(typ) = typ else {
208        return Ok(Literal {
209            literal_type: Some(LiteralType::String(string_value)),
210            nullable: false,
211            type_variation_reference: 0,
212        });
213    };
214
215    let Some(kind) = typ.kind else {
216        return Ok(Literal {
217            literal_type: Some(LiteralType::String(string_value)),
218            nullable: false,
219            type_variation_reference: 0,
220        });
221    };
222
223    match &kind {
224        Kind::Date(d) => {
225            // Parse date in ISO 8601 format: YYYY-MM-DD
226            let date_days = parse_date_to_days(&string_value, value.as_span())?;
227            Ok(Literal {
228                literal_type: Some(LiteralType::Date(date_days)),
229                nullable: d.nullability != Nullability::Required as i32,
230                type_variation_reference: d.type_variation_reference,
231            })
232        }
233        Kind::Time(t) => {
234            // Parse time in ISO 8601 format: HH:MM:SS[.fff]
235            let time_microseconds = parse_time_to_microseconds(&string_value, value.as_span())?;
236            Ok(Literal {
237                literal_type: Some(LiteralType::Time(time_microseconds)),
238                nullable: t.nullability != Nullability::Required as i32,
239                type_variation_reference: t.type_variation_reference,
240            })
241        }
242        #[allow(deprecated)]
243        Kind::Timestamp(ts) => {
244            // Parse timestamp in ISO 8601 format: YYYY-MM-DDTHH:MM:SS[.fff] or YYYY-MM-DD HH:MM:SS[.fff]
245            let timestamp_microseconds =
246                parse_timestamp_to_microseconds(&string_value, value.as_span())?;
247            Ok(Literal {
248                literal_type: Some(LiteralType::Timestamp(timestamp_microseconds)),
249                nullable: ts.nullability != Nullability::Required as i32,
250                type_variation_reference: ts.type_variation_reference,
251            })
252        }
253        _ => {
254            // For other types, treat as string
255            Ok(Literal {
256                literal_type: Some(LiteralType::String(string_value)),
257                nullable: false,
258                type_variation_reference: 0,
259            })
260        }
261    }
262}
263
264/// Parse a date string using chrono to days since Unix epoch
265fn parse_date_to_days(date_str: &str, span: pest::Span) -> Result<i32, MessageParseError> {
266    // Try multiple date formats for flexibility
267    let formats = ["%Y-%m-%d", "%Y/%m/%d"];
268
269    for format in &formats {
270        if let Ok(date) = NaiveDate::parse_from_str(date_str, format) {
271            // Calculate days since Unix epoch (1970-01-01)
272            let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
273            let days = date.signed_duration_since(epoch).num_days();
274            return Ok(days as i32);
275        }
276    }
277
278    Err(MessageParseError {
279        message: "date_parse_format",
280        kind: ErrorKind::InvalidValue,
281        error: Box::new(pest::error::Error::new_from_span(
282            pest::error::ErrorVariant::CustomError {
283                message: format!(
284                    "Invalid date format: '{date_str}'. Expected YYYY-MM-DD or YYYY/MM/DD"
285                ),
286            },
287            span,
288        )),
289    })
290}
291
292/// Parse a time string using chrono to microseconds since midnight
293fn parse_time_to_microseconds(time_str: &str, span: pest::Span) -> Result<i64, MessageParseError> {
294    // Try multiple time formats for flexibility
295    let formats = ["%H:%M:%S%.f", "%H:%M:%S"];
296
297    for format in &formats {
298        if let Ok(time) = NaiveTime::parse_from_str(time_str, format) {
299            // Convert to microseconds since midnight
300            let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
301            let duration = time.signed_duration_since(midnight);
302            return Ok(duration.num_microseconds().unwrap_or(0));
303        }
304    }
305
306    Err(MessageParseError {
307        message: "time_parse_format",
308        kind: ErrorKind::InvalidValue,
309        error: Box::new(pest::error::Error::new_from_span(
310            pest::error::ErrorVariant::CustomError {
311                message: format!(
312                    "Invalid time format: '{time_str}'. Expected HH:MM:SS or HH:MM:SS.fff"
313                ),
314            },
315            span,
316        )),
317    })
318}
319
320/// Parse a timestamp string using chrono to microseconds since Unix epoch
321fn parse_timestamp_to_microseconds(
322    timestamp_str: &str,
323    span: pest::Span,
324) -> Result<i64, MessageParseError> {
325    // Try multiple timestamp formats for flexibility
326    let formats = [
327        "%Y-%m-%dT%H:%M:%S%.f", // ISO 8601 with T and fractional seconds
328        "%Y-%m-%dT%H:%M:%S",    // ISO 8601 with T
329        "%Y-%m-%d %H:%M:%S%.f", // Space separator with fractional seconds
330        "%Y-%m-%d %H:%M:%S",    // Space separator
331        "%Y/%m/%dT%H:%M:%S%.f", // Alternative date format with T
332        "%Y/%m/%dT%H:%M:%S",    // Alternative date format with T
333        "%Y/%m/%d %H:%M:%S%.f", // Alternative date format with space
334        "%Y/%m/%d %H:%M:%S",    // Alternative date format with space
335    ];
336
337    for format in &formats {
338        if let Ok(datetime) = NaiveDateTime::parse_from_str(timestamp_str, format) {
339            // Calculate microseconds since Unix epoch (1970-01-01 00:00:00)
340            let epoch = DateTime::from_timestamp(0, 0).unwrap().naive_utc();
341            let duration = datetime.signed_duration_since(epoch);
342            return Ok(duration.num_microseconds().unwrap_or(0));
343        }
344    }
345
346    Err(MessageParseError {
347        message: "timestamp_parse_format",
348        kind: ErrorKind::InvalidValue,
349        error: Box::new(pest::error::Error::new_from_span(
350            pest::error::ErrorVariant::CustomError {
351                message: format!(
352                    "Invalid timestamp format: '{timestamp_str}'. Expected YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS"
353                ),
354            },
355            span,
356        )),
357    })
358}
359
360impl ScopedParsePair for Literal {
361    fn rule() -> Rule {
362        Rule::literal
363    }
364
365    fn message() -> &'static str {
366        "Literal"
367    }
368
369    fn parse_pair(
370        extensions: &SimpleExtensions,
371        pair: pest::iterators::Pair<Rule>,
372    ) -> Result<Self, MessageParseError> {
373        assert_eq!(pair.as_rule(), Self::rule());
374        let mut pairs = pair.into_inner();
375        let value = pairs.next().unwrap(); // First item is always the value
376        let typ = pairs.next(); // Second item is optional type
377        assert!(pairs.next().is_none());
378        let typ = match typ {
379            Some(t) => Some(Type::parse_pair(extensions, t)?),
380            None => None,
381        };
382        match value.as_rule() {
383            Rule::integer => to_int_literal(value, typ),
384            Rule::float => to_float_literal(value, typ),
385            Rule::boolean => to_boolean_literal(value),
386            Rule::string_literal => to_string_literal(value, typ),
387            _ => unreachable!("Literal unexpected rule: {:?}", value.as_rule()),
388        }
389    }
390}
391
392impl ScopedParsePair for ScalarFunction {
393    fn rule() -> Rule {
394        Rule::function_call
395    }
396
397    fn message() -> &'static str {
398        "ScalarFunction"
399    }
400
401    fn parse_pair(
402        extensions: &SimpleExtensions,
403        pair: pest::iterators::Pair<Rule>,
404    ) -> Result<Self, MessageParseError> {
405        assert_eq!(pair.as_rule(), Self::rule());
406        let span = pair.as_span();
407        let mut iter = RuleIter::from(pair.into_inner());
408
409        // Parse function name (required)
410        let name = iter.parse_next::<Name>();
411
412        // Parse optional URN anchor (e.g., #1)
413        let anchor = iter
414            .try_pop(Rule::anchor)
415            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
416
417        // Parse optional URN anchor (e.g., @1)
418        let _urn_anchor = iter
419            .try_pop(Rule::urn_anchor)
420            .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
421
422        // Parse argument list (required)
423        let argument_list = iter.pop(Rule::argument_list);
424        let mut arguments = Vec::new();
425        for e in argument_list.into_inner() {
426            arguments.push(FunctionArgument {
427                arg_type: Some(ArgType::Value(Expression::parse_pair(extensions, e)?)),
428            });
429        }
430
431        // Parse optional output type (e.g., :i64)
432        let output_type = match iter.try_pop(Rule::r#type) {
433            Some(t) => Some(Type::parse_pair(extensions, t)?),
434            None => None,
435        };
436
437        iter.done();
438        let anchor =
439            get_and_validate_anchor(extensions, ExtensionKind::Function, anchor, &name.0, span)?;
440        Ok(ScalarFunction {
441            function_reference: anchor,
442            arguments,
443            options: vec![], // TODO: Function Options
444            output_type,
445            #[allow(deprecated)]
446            args: vec![],
447        })
448    }
449}
450
451impl ScopedParsePair for Expression {
452    fn rule() -> Rule {
453        Rule::expression
454    }
455
456    fn message() -> &'static str {
457        "Expression"
458    }
459
460    fn parse_pair(
461        extensions: &SimpleExtensions,
462        pair: pest::iterators::Pair<Rule>,
463    ) -> Result<Self, MessageParseError> {
464        assert_eq!(pair.as_rule(), Self::rule());
465        let inner = unwrap_single_pair(pair);
466        match inner.as_rule() {
467            Rule::literal => Ok(Expression {
468                rex_type: Some(RexType::Literal(Literal::parse_pair(extensions, inner)?)),
469            }),
470            Rule::function_call => Ok(Expression {
471                rex_type: Some(RexType::ScalarFunction(ScalarFunction::parse_pair(
472                    extensions, inner,
473                )?)),
474            }),
475            Rule::reference => Ok(Expression {
476                rex_type: Some(RexType::Selection(Box::new(FieldReference::parse_pair(
477                    inner,
478                )))),
479            }),
480            Rule::if_then => Ok(Expression {
481                rex_type: Some(RexType::IfThen(Box::new(IfThen::parse_pair(
482                    extensions, inner,
483                )?))),
484            }),
485            _ => unreachable!(
486                "Grammar guarantees expression can only be literal, function_call, reference, or if_then, got: {:?}",
487                inner.as_rule()
488            ),
489        }
490    }
491}
492
493impl ScopedParsePair for IfClause {
494    fn rule() -> Rule {
495        Rule::if_clause
496    }
497
498    fn message() -> &'static str {
499        "IfClause"
500    }
501
502    fn parse_pair(
503        extensions: &SimpleExtensions,
504        pair: pest::iterators::Pair<Rule>,
505    ) -> Result<Self, MessageParseError> {
506        assert_eq!(pair.as_rule(), Self::rule());
507        let mut pairs = pair.into_inner(); // should have 2 children, 2 expressions
508
509        let condition = pairs.next().unwrap();
510        let result = pairs.next().unwrap();
511        assert!(pairs.next().is_none());
512
513        let ex1 = Some(Expression::parse_pair(extensions, condition)?);
514        let ex2 = Some(Expression::parse_pair(extensions, result)?);
515
516        Ok(IfClause {
517            r#if: ex1,
518            then: ex2,
519        })
520    }
521}
522
523impl ScopedParsePair for IfThen {
524    fn rule() -> Rule {
525        Rule::if_then
526    }
527    fn message() -> &'static str {
528        "IfThen"
529    }
530
531    fn parse_pair(
532        extensions: &SimpleExtensions,
533        pair: pest::iterators::Pair<Rule>,
534    ) -> Result<Self, MessageParseError> {
535        assert_eq!(pair.as_rule(), Self::rule());
536
537        let mut iter = RuleIter::from(pair.into_inner()); // should have 2 or more children
538
539        let mut ifs: Vec<IfClause> = Vec::new();
540
541        // gets all of the if clauses
542        while let Some(p) = iter.try_pop(Rule::if_clause) {
543            let if_clause = IfClause::parse_pair(extensions, p)?;
544            ifs.push(if_clause);
545        }
546
547        let pair = iter.try_pop(Rule::expression).unwrap(); // should be else expression
548        iter.done();
549        let else_clause = Some(Box::new(Expression::parse_pair(extensions, pair)?));
550
551        Ok(IfThen {
552            ifs,
553            r#else: else_clause,
554        })
555    }
556}
/// The parsed text of a `name` rule: either a bare identifier or a quoted name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Name(pub String);
558
559impl ParsePair for Name {
560    fn rule() -> Rule {
561        Rule::name
562    }
563
564    fn message() -> &'static str {
565        "Name"
566    }
567
568    fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
569        assert_eq!(pair.as_rule(), Self::rule());
570        let inner = unwrap_single_pair(pair);
571        match inner.as_rule() {
572            Rule::identifier => Name(inner.as_str().to_string()),
573            Rule::quoted_name => Name(unescape_string(inner)),
574            _ => unreachable!("Name unexpected rule: {:?}", inner.as_rule()),
575        }
576    }
577}
578
579impl ScopedParsePair for Measure {
580    fn rule() -> Rule {
581        Rule::aggregate_measure
582    }
583
584    fn message() -> &'static str {
585        "Measure"
586    }
587
588    fn parse_pair(
589        extensions: &SimpleExtensions,
590        pair: pest::iterators::Pair<Rule>,
591    ) -> Result<Self, MessageParseError> {
592        assert_eq!(pair.as_rule(), Self::rule());
593
594        // Extract the inner function_call from aggregate_measure
595        let function_call_pair = unwrap_single_pair(pair);
596        assert_eq!(function_call_pair.as_rule(), Rule::function_call);
597
598        // Parse as ScalarFunction, then convert to AggregateFunction
599        let scalar = ScalarFunction::parse_pair(extensions, function_call_pair)?;
600        Ok(Measure {
601            measure: Some(AggregateFunction {
602                function_reference: scalar.function_reference,
603                arguments: scalar.arguments,
604                options: scalar.options,
605                output_type: scalar.output_type,
606                invocation: 0, // TODO: support invocation (ALL, DISTINCT, etc.)
607                phase: 0, // TODO: support phase (INITIAL_TO_RESULT, PARTIAL_TO_INTERMEDIATE, etc.)
608                sorts: vec![], // TODO: support sorts for ordered aggregates
609                #[allow(deprecated)]
610                args: scalar.args,
611            }),
612            filter: None, // TODO: support filter conditions on aggregate measures
613        })
614    }
615}
616
617#[cfg(test)]
618mod tests {
619    use pest::Parser as PestParser;
620
621    use super::*;
622    use crate::parser::ExpressionParser;
623
624    fn parse_exact(rule: Rule, input: &'_ str) -> pest::iterators::Pair<'_, Rule> {
625        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
626        assert_eq!(pairs.as_str(), input);
627        let pair = pairs.next().unwrap();
628        assert_eq!(pairs.next(), None);
629        pair
630    }
631
632    fn assert_parses_to<T: ParsePair + PartialEq + std::fmt::Debug>(input: &str, expected: T) {
633        let pair = parse_exact(T::rule(), input);
634        let actual = T::parse_pair(pair);
635        assert_eq!(actual, expected);
636    }
637
638    fn assert_parses_with<T: ScopedParsePair + PartialEq + std::fmt::Debug>(
639        ext: &SimpleExtensions,
640        input: &str,
641        expected: T,
642    ) {
643        let pair = parse_exact(T::rule(), input);
644        let actual = T::parse_pair(ext, pair).unwrap();
645        assert_eq!(actual, expected);
646    }
647
648    #[test]
649    fn test_parse_field_reference() {
650        assert_parses_to("$1", FieldIndex(1).to_field_reference());
651    }
652
653    #[test]
654    fn test_parse_integer_literal() {
655        let extensions = SimpleExtensions::default();
656        let expected = Literal {
657            literal_type: Some(LiteralType::I64(1)),
658            nullable: false,
659            type_variation_reference: 0,
660        };
661        assert_parses_with(&extensions, "1", expected);
662    }
663
664    #[test]
665    fn test_parse_float_literal() {
666        // First test that the grammar can parse floats
667        let pairs = ExpressionParser::parse(Rule::float, "3.82").unwrap();
668        let parsed_text = pairs.as_str();
669        assert_eq!(parsed_text, "3.82");
670
671        let extensions = SimpleExtensions::default();
672        let expected = Literal {
673            literal_type: Some(LiteralType::Fp64(3.82)),
674            nullable: false,
675            type_variation_reference: 0,
676        };
677        assert_parses_with(&extensions, "3.82", expected);
678    }
679
680    #[test]
681    fn test_parse_negative_float_literal() {
682        let extensions = SimpleExtensions::default();
683        let expected = Literal {
684            literal_type: Some(LiteralType::Fp64(-2.5)),
685            nullable: false,
686            type_variation_reference: 0,
687        };
688        assert_parses_with(&extensions, "-2.5", expected);
689    }
690
691    #[test]
692    fn test_parse_boolean_true_literal() {
693        let extensions = SimpleExtensions::default();
694        let expected = Literal {
695            literal_type: Some(LiteralType::Boolean(true)),
696            nullable: false,
697            type_variation_reference: 0,
698        };
699        assert_parses_with(&extensions, "true", expected);
700    }
701
702    #[test]
703    fn test_parse_boolean_false_literal() {
704        let extensions = SimpleExtensions::default();
705        let expected = Literal {
706            literal_type: Some(LiteralType::Boolean(false)),
707            nullable: false,
708            type_variation_reference: 0,
709        };
710        assert_parses_with(&extensions, "false", expected);
711    }
712
713    #[test]
714    fn test_parse_float_literal_with_fp32_type() {
715        let extensions = SimpleExtensions::default();
716        let pair = parse_exact(Rule::literal, "3.82:fp32");
717        let result = Literal::parse_pair(&extensions, pair).unwrap();
718
719        match result.literal_type {
720            Some(LiteralType::Fp32(val)) => assert!((val - 3.82).abs() < f32::EPSILON),
721            _ => panic!("Expected Fp32 literal type"),
722        }
723    }
724
725    #[test]
726    fn test_parse_date_literal() {
727        let extensions = SimpleExtensions::default();
728        let pair = parse_exact(Rule::literal, "'2023-12-25':date");
729        let result = Literal::parse_pair(&extensions, pair).unwrap();
730
731        match result.literal_type {
732            Some(LiteralType::Date(days)) => {
733                // 2023-12-25 should be a positive number of days since 1970-01-01
734                assert!(
735                    days > 0,
736                    "Expected positive days since epoch, got: {}",
737                    days
738                );
739            }
740            _ => panic!("Expected Date literal type, got: {:?}", result.literal_type),
741        }
742    }
743
744    #[test]
745    fn test_parse_time_literal() {
746        let extensions = SimpleExtensions::default();
747        let pair = parse_exact(Rule::literal, "'14:30:45':time");
748        let result = Literal::parse_pair(&extensions, pair).unwrap();
749
750        match result.literal_type {
751            Some(LiteralType::Time(microseconds)) => {
752                // 14:30:45 = (14*3600 + 30*60 + 45) * 1_000_000 microseconds
753                let expected = (14 * 3600 + 30 * 60 + 45) * 1_000_000;
754                assert_eq!(microseconds, expected);
755            }
756            _ => panic!("Expected Time literal type, got: {:?}", result.literal_type),
757        }
758    }
759
760    #[test]
761    fn test_parse_timestamp_literal_with_t() {
762        let extensions = SimpleExtensions::default();
763        let pair = parse_exact(Rule::literal, "'2023-01-01T12:00:00':timestamp");
764        let result = Literal::parse_pair(&extensions, pair).unwrap();
765
766        match result.literal_type {
767            #[allow(deprecated)]
768            Some(LiteralType::Timestamp(microseconds)) => {
769                assert!(
770                    microseconds > 0,
771                    "Expected positive microseconds since epoch"
772                );
773            }
774            _ => panic!(
775                "Expected Timestamp literal type, got: {:?}",
776                result.literal_type
777            ),
778        }
779    }
780
781    #[test]
782    fn test_parse_timestamp_literal_with_space() {
783        let extensions = SimpleExtensions::default();
784        let pair = parse_exact(Rule::literal, "'2023-01-01 12:00:00':timestamp");
785        let result = Literal::parse_pair(&extensions, pair).unwrap();
786
787        match result.literal_type {
788            #[allow(deprecated)]
789            Some(LiteralType::Timestamp(microseconds)) => {
790                assert!(
791                    microseconds > 0,
792                    "Expected positive microseconds since epoch"
793                );
794            }
795            _ => panic!(
796                "Expected Timestamp literal type, got: {:?}",
797                result.literal_type
798            ),
799        }
800    }
801
802    /// Helper function to create a literal boolean expression
803    fn make_literal_bool(value: bool) -> Expression {
804        Expression {
805            rex_type: Some(RexType::Literal(Literal {
806                literal_type: Some(LiteralType::Boolean(value)),
807                nullable: false,
808                type_variation_reference: 0,
809            })),
810        }
811    }
812
813    #[test]
814    fn test_parse_if_then_single_clause() {
815        let extensions = SimpleExtensions::default();
816        let input = "if_then(true -> 42, _ -> 0)";
817        let pair = parse_exact(Rule::if_then, input);
818        let result = IfThen::parse_pair(&extensions, pair).unwrap();
819
820        assert_eq!(result.ifs.len(), 1);
821        assert!(result.r#else.is_some());
822    }
823
824    #[test]
825    fn test_parse_if_then_with_typed_literals() {
826        let extensions = SimpleExtensions::default();
827        let input = "if_then(true -> 100:i32, _ -> -100:i32)";
828        let pair = parse_exact(Rule::if_then, input);
829        let result = IfThen::parse_pair(&extensions, pair).unwrap();
830
831        assert_eq!(result.ifs.len(), 1);
832        assert!(result.r#else.is_some());
833    }
834
835    #[test]
836    fn test_parse_if_then_with_date_literals() {
837        let extensions = SimpleExtensions::default();
838        let input = "if_then(true -> '2023-12-25':date, _ -> '1970-01-01':date)";
839        let pair = parse_exact(Rule::if_then, input);
840        let result = IfThen::parse_pair(&extensions, pair).unwrap();
841
842        assert_eq!(result.ifs.len(), 1);
843        assert!(result.r#else.is_some());
844    }
845
846    #[test]
847    fn test_parse_if_then_with_time_literals() {
848        let extensions = SimpleExtensions::default();
849        let input = "if_then(true -> '14:30:45':time, _ -> '00:00:00':time)";
850        let pair = parse_exact(Rule::if_then, input);
851        let result = IfThen::parse_pair(&extensions, pair).unwrap();
852
853        assert_eq!(result.ifs.len(), 1);
854        assert!(result.r#else.is_some());
855    }
856
857    #[test]
858    fn test_parse_if_then_with_timestamp_literals() {
859        let extensions = SimpleExtensions::default();
860        let input = "if_then(true -> '2023-01-01T12:00:00':timestamp, _ -> '1970-01-01T00:00:00':timestamp)";
861        let pair = parse_exact(Rule::if_then, input);
862        let result = IfThen::parse_pair(&extensions, pair).unwrap();
863
864        assert_eq!(result.ifs.len(), 1);
865        assert!(result.r#else.is_some());
866    }
867
868    #[test]
869    fn test_parse_if_clause_with_whitespace_variations() {
870        let extensions = SimpleExtensions::default();
871
872        // Test with various whitespace patterns
873        let inputs = vec!["true->false", "true -> false", "true  ->  false"];
874
875        for input in inputs {
876            let pair = parse_exact(Rule::if_clause, input);
877            let result = IfClause::parse_pair(&extensions, pair).unwrap();
878            assert!(result.r#if.is_some());
879            assert!(result.then.is_some());
880        }
881    }
882
883    #[test]
884    fn test_if_clause_structure() {
885        let extensions = SimpleExtensions::default();
886        let pair = parse_exact(Rule::if_clause, "42 -> 100");
887        let result = IfClause::parse_pair(&extensions, pair).unwrap();
888
889        // Verify the if clause has both condition and result
890        let if_expr = result.r#if.as_ref().unwrap();
891        let then_expr = result.then.as_ref().unwrap();
892
893        // Check that they are literal expressions
894        match (&if_expr.rex_type, &then_expr.rex_type) {
895            (Some(RexType::Literal(_)), Some(RexType::Literal(_))) => {
896                // Success - both are literals as expected
897            }
898            _ => panic!("Expected both if and then to be literals"),
899        }
900    }
901
902    #[test]
903    fn test_if_then_structure() {
904        let extensions = SimpleExtensions::default();
905        let input = "if_then(true -> 1, false -> 2, _ -> 0)";
906        let pair = parse_exact(Rule::if_then, input);
907        let result = IfThen::parse_pair(&extensions, pair).unwrap();
908
909        // Verify structure
910        assert_eq!(result.ifs.len(), 2);
911
912        // Check each if clause
913        for clause in &result.ifs {
914            assert!(clause.r#if.is_some(), "If clause condition should exist");
915            assert!(clause.then.is_some(), "If clause result should exist");
916        }
917
918        // Check else clause
919        assert!(result.r#else.is_some(), "Else clause should exist");
920    }
921
922    #[test]
923    fn test_parse_if_then_mixed_types_in_conditions() {
924        let extensions = SimpleExtensions::default();
925        // Different types in conditions (not results)
926        let input = "if_then(true -> 1, true -> 'yes', 'yes' -> true, 42 -> 2, $0 -> 3, _ -> 0)";
927        let pair = parse_exact(Rule::if_then, input);
928        let result = IfThen::parse_pair(&extensions, pair).unwrap();
929
930        assert_eq!(result.ifs.len(), 5);
931        assert!(result.r#else.is_some());
932    }
933
934    #[test]
935    fn test_if_then_preserves_clause_order() {
936        let extensions = SimpleExtensions::default();
937        let input = "if_then(1 -> 10, 2 -> 20, 3 -> 30, _ -> 0)";
938        let pair = parse_exact(Rule::if_then, input);
939        let result = IfThen::parse_pair(&extensions, pair).unwrap();
940
941        assert_eq!(result.ifs.len(), 3);
942
943        // Verify the clauses are in order by checking the literal values
944        for (i, clause) in result.ifs.iter().enumerate() {
945            if let Some(Expression {
946                rex_type: Some(RexType::Literal(lit)),
947            }) = &clause.r#if
948            {
949                if let Some(LiteralType::I64(val)) = &lit.literal_type {
950                    assert_eq!(*val, (i as i64) + 1);
951                }
952            }
953        }
954    }
955
956    #[test]
957    fn test_parse_if_then() {
958        let extensions = SimpleExtensions::default();
959
960        let c1 = IfClause {
961            r#if: Some(make_literal_bool(true)),
962            then: Some(make_literal_bool(true)),
963        };
964
965        let c2 = IfClause {
966            r#if: Some(make_literal_bool(false)),
967            then: Some(make_literal_bool(false)),
968        };
969
970        let if_clause = IfThen {
971            ifs: vec![c1, c2],
972            r#else: Some(Box::new(make_literal_bool(false))),
973        };
974        assert_parses_with(
975            &extensions,
976            "if_then(true -> true , false -> false, _ -> false)",
977            if_clause,
978        );
979    }
980}