substrait_explain/parser/
types.rs

1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::ExtensionKind;
8use crate::parser::{ErrorKind, MessageParseError};
9
10// Given a name and an optional anchor, get the anchor and validate it. Errors will be pushed to the Scope error accumulator,
11// and the anchor will be returned if it is valid.
12pub(crate) fn get_and_validate_anchor(
13    extensions: &SimpleExtensions,
14    kind: ExtensionKind,
15    anchor: Option<u32>,
16    name: &str,
17    span: pest::Span,
18) -> Result<u32, MessageParseError> {
19    match anchor {
20        Some(a) => match extensions.is_name_unique(kind, a, name) {
21            Ok(_) => Ok(a),
22            Err(e) => {
23                let message = "Error matching name to anchor".to_string();
24                let error = MessageParseError {
25                    message: kind.name(),
26                    kind: ErrorKind::Lookup(e),
27                    error: Box::new(pest::error::Error::new_from_span(
28                        pest::error::ErrorVariant::CustomError { message },
29                        span,
30                    )),
31                };
32                Err(error)
33            }
34        },
35        None => match extensions.find_by_name(kind, name) {
36            Ok(a) => Ok(a),
37            Err(e) => {
38                let message = "Error finding extension for name".to_string();
39                let error = MessageParseError {
40                    message: kind.name(),
41                    kind: ErrorKind::Lookup(e),
42                    error: Box::new(pest::error::Error::new_from_span(
43                        pest::error::ErrorVariant::CustomError { message },
44                        span,
45                    )),
46                };
47                Err(error)
48            }
49        },
50    }
51}
52
53impl ParsePair for Nullability {
54    fn rule() -> Rule {
55        Rule::nullability
56    }
57
58    fn message() -> &'static str {
59        "Nullability"
60    }
61
62    fn parse_pair(pair: Pair<Rule>) -> Self {
63        assert_eq!(pair.as_rule(), Rule::nullability);
64        match pair.as_str() {
65            "?" => Nullability::Nullable,
66            "" => Nullability::Required,
67            "⁉" => Nullability::Unspecified,
68            _ => panic!("Invalid nullability: {}", pair.as_str()),
69        }
70    }
71}
72
73impl ScopedParsePair for Parameter {
74    fn rule() -> Rule {
75        Rule::parameter
76    }
77
78    fn message() -> &'static str {
79        "Parameter"
80    }
81
82    fn parse_pair(
83        extensions: &SimpleExtensions,
84        pair: Pair<Rule>,
85    ) -> Result<Self, MessageParseError> {
86        assert_eq!(pair.as_rule(), Rule::parameter);
87        let inner = unwrap_single_pair(pair);
88        match inner.as_rule() {
89            Rule::r#type => Ok(Parameter {
90                parameter: Some(proto::r#type::parameter::Parameter::DataType(
91                    Type::parse_pair(extensions, inner)?,
92                )),
93            }),
94            _ => unimplemented!("{:?}", inner.as_rule()),
95        }
96    }
97}
98
99fn parse_simple_type(pair: Pair<Rule>) -> Type {
100    assert_eq!(pair.as_rule(), Rule::simple_type);
101    let mut iter = iter_pairs(pair.into_inner());
102    let name = iter.pop(Rule::simple_type_name).as_str();
103    let nullability = iter.parse_next::<Nullability>();
104    iter.done();
105
106    let kind = match name {
107        "boolean" => Kind::Bool(proto::r#type::Boolean {
108            nullability: nullability.into(),
109            type_variation_reference: 0,
110        }),
111        "i64" => Kind::I64(proto::r#type::I64 {
112            nullability: nullability.into(),
113            type_variation_reference: 0,
114        }),
115        "i32" => Kind::I32(proto::r#type::I32 {
116            nullability: nullability.into(),
117            type_variation_reference: 0,
118        }),
119        "i16" => Kind::I16(proto::r#type::I16 {
120            nullability: nullability.into(),
121            type_variation_reference: 0,
122        }),
123        "i8" => Kind::I8(proto::r#type::I8 {
124            nullability: nullability.into(),
125            type_variation_reference: 0,
126        }),
127        "fp32" => Kind::Fp32(proto::r#type::Fp32 {
128            nullability: nullability.into(),
129            type_variation_reference: 0,
130        }),
131        "fp64" => Kind::Fp64(proto::r#type::Fp64 {
132            nullability: nullability.into(),
133            type_variation_reference: 0,
134        }),
135        "string" => Kind::String(proto::r#type::String {
136            nullability: nullability.into(),
137            type_variation_reference: 0,
138        }),
139        "binary" => Kind::Binary(proto::r#type::Binary {
140            nullability: nullability.into(),
141            type_variation_reference: 0,
142        }),
143        #[allow(deprecated)]
144        "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
145            nullability: nullability.into(),
146            type_variation_reference: 0,
147        }),
148        #[allow(deprecated)]
149        "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
150            nullability: nullability.into(),
151            type_variation_reference: 0,
152        }),
153        "date" => Kind::Date(proto::r#type::Date {
154            nullability: nullability.into(),
155            type_variation_reference: 0,
156        }),
157        "time" => Kind::Time(proto::r#type::Time {
158            nullability: nullability.into(),
159            type_variation_reference: 0,
160        }),
161        "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
162            nullability: nullability.into(),
163            type_variation_reference: 0,
164        }),
165        "uuid" => Kind::Uuid(proto::r#type::Uuid {
166            nullability: nullability.into(),
167            type_variation_reference: 0,
168        }),
169        _ => unreachable!("Type {} exists in parser but not implemented in code", name),
170    };
171    Type { kind: Some(kind) }
172}
173
174fn parse_compound_type(
175    extensions: &SimpleExtensions,
176    pair: Pair<Rule>,
177) -> Result<Type, MessageParseError> {
178    assert_eq!(pair.as_rule(), Rule::compound_type);
179    let inner = unwrap_single_pair(pair);
180    match inner.as_rule() {
181        Rule::list_type => parse_list_type(extensions, inner),
182        // Rule::map_type => parse_map_type(inner),
183        // Rule::struct_type => parse_struct_type(inner),
184        _ => unimplemented!("{:?}", inner.as_rule()),
185    }
186}
187
188fn parse_list_type(
189    extensions: &SimpleExtensions,
190    pair: Pair<Rule>,
191) -> Result<Type, MessageParseError> {
192    assert_eq!(pair.as_rule(), Rule::list_type);
193    let mut iter = iter_pairs(pair.into_inner());
194    let nullability = iter.parse_next::<Nullability>();
195    let inner = iter.parse_next_scoped::<Type>(extensions)?;
196    iter.done();
197
198    Ok(Type {
199        kind: Some(Kind::List(Box::new(proto::r#type::List {
200            nullability: nullability.into(),
201            r#type: Some(Box::new(inner)),
202            type_variation_reference: 0,
203        }))),
204    })
205}
206
207fn parse_parameters(
208    extensions: &SimpleExtensions,
209    pair: Pair<Rule>,
210) -> Result<Vec<Parameter>, MessageParseError> {
211    assert_eq!(pair.as_rule(), Rule::parameters);
212    let mut iter = iter_pairs(pair.into_inner());
213    let mut params = Vec::new();
214    while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
215        params.push(param?);
216    }
217    iter.done();
218    Ok(params)
219}
220
221fn parse_user_defined_type(
222    extensions: &SimpleExtensions,
223    pair: Pair<Rule>,
224) -> Result<Type, MessageParseError> {
225    let span = pair.as_span();
226    assert_eq!(pair.as_rule(), Rule::user_defined_type);
227    let mut iter = iter_pairs(pair.into_inner());
228    let name = iter.pop(Rule::name).as_str();
229    let anchor = iter
230        .try_pop(Rule::anchor)
231        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
232
233    // TODO: Handle urn_anchor; validate that it matches the anchor
234    let _urn_anchor = iter
235        .try_pop(Rule::urn_anchor)
236        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
237
238    let nullability = iter.parse_next::<Nullability>();
239    let parameters = match iter.try_pop(Rule::parameters) {
240        Some(p) => parse_parameters(extensions, p)?,
241        None => Vec::new(),
242    };
243    iter.done();
244
245    let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
246
247    Ok(Type {
248        kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
249            type_reference: anchor,
250            nullability: nullability.into(),
251            type_parameters: parameters,
252            type_variation_reference: 0,
253        })),
254    })
255}
256
257impl ScopedParsePair for Type {
258    fn rule() -> Rule {
259        Rule::r#type
260    }
261
262    fn message() -> &'static str {
263        "Type"
264    }
265
266    fn parse_pair(
267        extensions: &SimpleExtensions,
268        pair: Pair<Rule>,
269    ) -> Result<Self, MessageParseError> {
270        assert_eq!(pair.as_rule(), Rule::r#type);
271        let inner = unwrap_single_pair(pair);
272        match inner.as_rule() {
273            Rule::simple_type => Ok(parse_simple_type(inner)),
274            Rule::compound_type => parse_compound_type(extensions, inner),
275            Rule::user_defined_type => parse_user_defined_type(extensions, inner),
276            _ => unreachable!(
277                "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}",
278                inner.as_rule()
279            ),
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, asserting that exactly one top-level pair results.
    fn parse_single(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// An `i64` type with the given nullability and no type variation.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// A `string` type with the given nullability and no type variation.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps a type as a data-type parameter.
    fn data_type_param(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let parsed = parse_simple_type(parse_single(Rule::simple_type, "i64"));
        assert_eq!(parsed, i64_type(Nullability::Required));

        let parsed = parse_simple_type(parse_single(Rule::simple_type, "string?"));
        assert_eq!(parsed, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let parsed = Type::parse_pair(&extensions, parse_single(Rule::r#type, "i64")).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let parsed =
            parse_list_type(&extensions, parse_single(Rule::list_type, "list<i64>")).unwrap();
        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let parsed =
            parse_parameters(&extensions, parse_single(Rule::parameters, "<i64?,string>"))
                .unwrap();
        assert_eq!(
            parsed,
            vec![
                data_type_param(i64_type(Nullability::Nullable)),
                data_type_param(string_type(Nullability::Required)),
            ]
        );
    }

    #[test]
    fn test_udts() {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_urn("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let pair = parse_single(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, pair).unwrap();
        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_param(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}