Skip to main content

substrait_explain/parser/
types.rs

1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::{ExtensionKind, MissingReference};
8use crate::parser::MessageParseError;
9
10/// Given a name (plain or compound) and an optional anchor, resolve and
11/// validate the extension anchor.
12///
13/// For `Function` kind, delegates to [`SimpleExtensions::resolve_function`]
14/// which encapsulates the full base-name fallback logic.
15///
16/// For other kinds, performs a direct anchor/name validation.
17pub(crate) fn get_and_validate_anchor(
18    extensions: &SimpleExtensions,
19    kind: ExtensionKind,
20    anchor: Option<u32>,
21    name: &str,
22    span: pest::Span,
23) -> Result<u32, MessageParseError> {
24    if kind == ExtensionKind::Function {
25        return extensions
26            .resolve_function(name, anchor)
27            .map(|r| r.anchor)
28            .map_err(|e| {
29                MessageParseError::lookup(kind.name(), e, span, "Error resolving function")
30            });
31    }
32    // For non-function kinds, validate the anchor/name pair directly.
33    match anchor {
34        Some(a) => match extensions.find_by_anchor(kind, a) {
35            Err(e) => Err(MessageParseError::lookup(
36                kind.name(),
37                e,
38                span,
39                "Error matching name to anchor",
40            )),
41            Ok((_, stored)) => {
42                if stored.full() == name || stored.base() == name {
43                    Ok(a)
44                } else {
45                    Err(MessageParseError::lookup(
46                        kind.name(),
47                        MissingReference::Mismatched(kind, name.to_string(), a),
48                        span,
49                        "Error matching name to anchor",
50                    ))
51                }
52            }
53        },
54        None => extensions.find_by_name(kind, name).map_err(|e| {
55            MessageParseError::lookup(kind.name(), e, span, "Error finding extension for name")
56        }),
57    }
58}
59
60impl ParsePair for Nullability {
61    fn rule() -> Rule {
62        Rule::nullability
63    }
64
65    fn message() -> &'static str {
66        "Nullability"
67    }
68
69    fn parse_pair(pair: Pair<Rule>) -> Self {
70        assert_eq!(pair.as_rule(), Rule::nullability);
71        match pair.as_str() {
72            "?" => Nullability::Nullable,
73            "" => Nullability::Required,
74            "⁉" => Nullability::Unspecified,
75            _ => panic!("Invalid nullability: {}", pair.as_str()),
76        }
77    }
78}
79
80impl ScopedParsePair for Parameter {
81    fn rule() -> Rule {
82        Rule::parameter
83    }
84
85    fn message() -> &'static str {
86        "Parameter"
87    }
88
89    fn parse_pair(
90        extensions: &SimpleExtensions,
91        pair: Pair<Rule>,
92    ) -> Result<Self, MessageParseError> {
93        assert_eq!(pair.as_rule(), Rule::parameter);
94        let inner = unwrap_single_pair(pair);
95        match inner.as_rule() {
96            Rule::r#type => Ok(Parameter {
97                parameter: Some(proto::r#type::parameter::Parameter::DataType(
98                    Type::parse_pair(extensions, inner)?,
99                )),
100            }),
101            _ => unimplemented!("{:?}", inner.as_rule()),
102        }
103    }
104}
105
106fn parse_simple_type(pair: Pair<Rule>) -> Type {
107    assert_eq!(pair.as_rule(), Rule::simple_type);
108    let mut iter = iter_pairs(pair.into_inner());
109    let name = iter.pop(Rule::simple_type_name).as_str();
110    let nullability = iter.parse_next::<Nullability>();
111    iter.done();
112
113    let kind = match name {
114        "boolean" => Kind::Bool(proto::r#type::Boolean {
115            nullability: nullability.into(),
116            type_variation_reference: 0,
117        }),
118        "i64" => Kind::I64(proto::r#type::I64 {
119            nullability: nullability.into(),
120            type_variation_reference: 0,
121        }),
122        "i32" => Kind::I32(proto::r#type::I32 {
123            nullability: nullability.into(),
124            type_variation_reference: 0,
125        }),
126        "i16" => Kind::I16(proto::r#type::I16 {
127            nullability: nullability.into(),
128            type_variation_reference: 0,
129        }),
130        "i8" => Kind::I8(proto::r#type::I8 {
131            nullability: nullability.into(),
132            type_variation_reference: 0,
133        }),
134        "fp32" => Kind::Fp32(proto::r#type::Fp32 {
135            nullability: nullability.into(),
136            type_variation_reference: 0,
137        }),
138        "fp64" => Kind::Fp64(proto::r#type::Fp64 {
139            nullability: nullability.into(),
140            type_variation_reference: 0,
141        }),
142        "string" => Kind::String(proto::r#type::String {
143            nullability: nullability.into(),
144            type_variation_reference: 0,
145        }),
146        "binary" => Kind::Binary(proto::r#type::Binary {
147            nullability: nullability.into(),
148            type_variation_reference: 0,
149        }),
150        #[allow(deprecated)]
151        "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
152            nullability: nullability.into(),
153            type_variation_reference: 0,
154        }),
155        #[allow(deprecated)]
156        "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
157            nullability: nullability.into(),
158            type_variation_reference: 0,
159        }),
160        "date" => Kind::Date(proto::r#type::Date {
161            nullability: nullability.into(),
162            type_variation_reference: 0,
163        }),
164        #[allow(deprecated)]
165        "time" => Kind::Time(proto::r#type::Time {
166            nullability: nullability.into(),
167            type_variation_reference: 0,
168        }),
169        "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
170            nullability: nullability.into(),
171            type_variation_reference: 0,
172        }),
173        "uuid" => Kind::Uuid(proto::r#type::Uuid {
174            nullability: nullability.into(),
175            type_variation_reference: 0,
176        }),
177        _ => unreachable!("Type {} exists in parser but not implemented in code", name),
178    };
179    Type { kind: Some(kind) }
180}
181
182fn parse_compound_type(
183    extensions: &SimpleExtensions,
184    pair: Pair<Rule>,
185) -> Result<Type, MessageParseError> {
186    assert_eq!(pair.as_rule(), Rule::compound_type);
187    let inner = unwrap_single_pair(pair);
188    match inner.as_rule() {
189        Rule::list_type => parse_list_type(extensions, inner),
190        // Rule::map_type => parse_map_type(inner),
191        // Rule::struct_type => parse_struct_type(inner),
192        Rule::precision_timestamp_tz_type
193        | Rule::precision_timestamp_type
194        | Rule::precision_time_type => parse_precision_type(inner),
195        _ => unimplemented!("{:?}", inner.as_rule()),
196    }
197}
198
199fn parse_precision_type(pair: Pair<Rule>) -> Result<Type, MessageParseError> {
200    let rule = pair.as_rule();
201    let mut iter = iter_pairs(pair.into_inner());
202    let nullability = iter.parse_next::<Nullability>();
203    let precision_pair = iter.pop(Rule::integer);
204    let precision_span = precision_pair.as_span();
205    let precision = precision_pair.as_str().parse::<i32>().unwrap();
206    if !(0..=12).contains(&precision) {
207        return Err(MessageParseError::invalid(
208            "precision time type",
209            precision_span,
210            format!("precision must be between 0 and 12, got {precision}"),
211        ));
212    }
213    iter.done();
214    let kind = match rule {
215        Rule::precision_timestamp_type => {
216            Kind::PrecisionTimestamp(proto::r#type::PrecisionTimestamp {
217                precision,
218                nullability: nullability.into(),
219                type_variation_reference: 0,
220            })
221        }
222        Rule::precision_timestamp_tz_type => {
223            Kind::PrecisionTimestampTz(proto::r#type::PrecisionTimestampTz {
224                precision,
225                nullability: nullability.into(),
226                type_variation_reference: 0,
227            })
228        }
229        Rule::precision_time_type => Kind::PrecisionTime(proto::r#type::PrecisionTime {
230            precision,
231            nullability: nullability.into(),
232            type_variation_reference: 0,
233        }),
234        _ => unreachable!("parse_precision_type called with rule {:?}", rule),
235    };
236    Ok(Type { kind: Some(kind) })
237}
238
239fn parse_list_type(
240    extensions: &SimpleExtensions,
241    pair: Pair<Rule>,
242) -> Result<Type, MessageParseError> {
243    assert_eq!(pair.as_rule(), Rule::list_type);
244    let mut iter = iter_pairs(pair.into_inner());
245    let nullability = iter.parse_next::<Nullability>();
246    let inner = iter.parse_next_scoped::<Type>(extensions)?;
247    iter.done();
248
249    Ok(Type {
250        kind: Some(Kind::List(Box::new(proto::r#type::List {
251            nullability: nullability.into(),
252            r#type: Some(Box::new(inner)),
253            type_variation_reference: 0,
254        }))),
255    })
256}
257
258fn parse_parameters(
259    extensions: &SimpleExtensions,
260    pair: Pair<Rule>,
261) -> Result<Vec<Parameter>, MessageParseError> {
262    assert_eq!(pair.as_rule(), Rule::parameters);
263    let mut iter = iter_pairs(pair.into_inner());
264    let mut params = Vec::new();
265    while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
266        params.push(param?);
267    }
268    iter.done();
269    Ok(params)
270}
271
272fn parse_user_defined_type(
273    extensions: &SimpleExtensions,
274    pair: Pair<Rule>,
275) -> Result<Type, MessageParseError> {
276    let span = pair.as_span();
277    assert_eq!(pair.as_rule(), Rule::user_defined_type);
278    let mut iter = iter_pairs(pair.into_inner());
279    let name = iter.pop(Rule::name).as_str();
280    let anchor = iter
281        .try_pop(Rule::anchor)
282        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
283
284    // TODO: Handle urn_anchor; validate that it matches the anchor
285    let _urn_anchor = iter
286        .try_pop(Rule::urn_anchor)
287        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
288
289    let nullability = iter.parse_next::<Nullability>();
290    let parameters = match iter.try_pop(Rule::parameters) {
291        Some(p) => parse_parameters(extensions, p)?,
292        None => Vec::new(),
293    };
294    iter.done();
295
296    let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
297
298    Ok(Type {
299        kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
300            type_reference: anchor,
301            nullability: nullability.into(),
302            type_parameters: parameters,
303            type_variation_reference: 0,
304        })),
305    })
306}
307
308impl ScopedParsePair for Type {
309    fn rule() -> Rule {
310        Rule::r#type
311    }
312
313    fn message() -> &'static str {
314        "Type"
315    }
316
317    fn parse_pair(
318        extensions: &SimpleExtensions,
319        pair: Pair<Rule>,
320    ) -> Result<Self, MessageParseError> {
321        assert_eq!(pair.as_rule(), Rule::r#type);
322        let inner = unwrap_single_pair(pair);
323        match inner.as_rule() {
324            Rule::simple_type => Ok(parse_simple_type(inner)),
325            Rule::compound_type => parse_compound_type(extensions, inner),
326            Rule::user_defined_type => parse_user_defined_type(extensions, inner),
327            _ => unreachable!(
328                "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}",
329                inner.as_rule()
330            ),
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule` and returns the one pair it must produce.
    fn parse_single(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Expected `i64` type with the given nullability.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Expected `string` type with the given nullability.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps a type as a data-type parameter.
    fn data_type_parameter(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let required_i64 = parse_simple_type(parse_single(Rule::simple_type, "i64"));
        assert_eq!(required_i64, i64_type(Nullability::Required));

        let nullable_string = parse_simple_type(parse_single(Rule::simple_type, "string?"));
        assert_eq!(nullable_string, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_single(Rule::r#type, "i64");
        let parsed = Type::parse_pair(&extensions, pair).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_single(Rule::list_type, "list<i64>");
        let parsed = parse_list_type(&extensions, pair).unwrap();

        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let pair = parse_single(Rule::parameters, "<i64?,string>");
        let parsed = parse_parameters(&extensions, pair).unwrap();

        let expected = vec![
            data_type_parameter(i64_type(Nullability::Nullable)),
            data_type_parameter(string_type(Nullability::Required)),
        ];
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_udts() {
        // Register URN anchor 4 and, under it, the type `udt` with anchor 42.
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_urn("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let pair = parse_single(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, pair).unwrap();

        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_parameter(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}