substrait_explain/parser/
types.rs

1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::ExtensionKind;
8use crate::parser::{ErrorKind, MessageParseError};
9
10// Given a name and an optional anchor, get the anchor and validate it. Errors will be pushed to the Scope error accumulator,
11// and the anchor will be returned if it is valid.
12pub(crate) fn get_and_validate_anchor(
13    extensions: &SimpleExtensions,
14    kind: ExtensionKind,
15    anchor: Option<u32>,
16    name: &str,
17    span: pest::Span,
18) -> Result<u32, MessageParseError> {
19    match anchor {
20        Some(a) => match extensions.is_name_unique(kind, a, name) {
21            Ok(_) => Ok(a),
22            Err(e) => {
23                let message = "Error matching name to anchor".to_string();
24                let error = MessageParseError {
25                    message: kind.name(),
26                    kind: ErrorKind::Lookup(e),
27                    error: Box::new(pest::error::Error::new_from_span(
28                        pest::error::ErrorVariant::CustomError { message },
29                        span,
30                    )),
31                };
32                Err(error)
33            }
34        },
35        None => match extensions.find_by_name(kind, name) {
36            Ok(a) => Ok(a),
37            Err(e) => {
38                let message = "Error finding extension for name".to_string();
39                let error = MessageParseError {
40                    message: kind.name(),
41                    kind: ErrorKind::Lookup(e),
42                    error: Box::new(pest::error::Error::new_from_span(
43                        pest::error::ErrorVariant::CustomError { message },
44                        span,
45                    )),
46                };
47                Err(error)
48            }
49        },
50    }
51}
52
53impl ParsePair for Nullability {
54    fn rule() -> Rule {
55        Rule::nullability
56    }
57
58    fn message() -> &'static str {
59        "Nullability"
60    }
61
62    fn parse_pair(pair: Pair<Rule>) -> Self {
63        assert_eq!(pair.as_rule(), Rule::nullability);
64        match pair.as_str() {
65            "?" => Nullability::Nullable,
66            "" => Nullability::Required,
67            "⁉" => Nullability::Unspecified,
68            _ => panic!("Invalid nullability: {}", pair.as_str()),
69        }
70    }
71}
72
73impl ScopedParsePair for Parameter {
74    fn rule() -> Rule {
75        Rule::parameter
76    }
77
78    fn message() -> &'static str {
79        "Parameter"
80    }
81
82    fn parse_pair(
83        extensions: &SimpleExtensions,
84        pair: Pair<Rule>,
85    ) -> Result<Self, MessageParseError> {
86        assert_eq!(pair.as_rule(), Rule::parameter);
87        let inner = unwrap_single_pair(pair);
88        match inner.as_rule() {
89            Rule::r#type => Ok(Parameter {
90                parameter: Some(proto::r#type::parameter::Parameter::DataType(
91                    Type::parse_pair(extensions, inner)?,
92                )),
93            }),
94            _ => unimplemented!("{:?}", inner.as_rule()),
95        }
96    }
97}
98
99fn parse_simple_type(pair: Pair<Rule>) -> Type {
100    assert_eq!(pair.as_rule(), Rule::simple_type);
101    let mut iter = iter_pairs(pair.into_inner());
102    let name = iter.pop(Rule::simple_type_name).as_str();
103    let nullability = iter.parse_next::<Nullability>();
104    iter.done();
105    let kind = match name {
106        "boolean" => Kind::Bool(proto::r#type::Boolean {
107            nullability: nullability.into(),
108            type_variation_reference: 0,
109        }),
110        "i64" => Kind::I64(proto::r#type::I64 {
111            nullability: nullability.into(),
112            type_variation_reference: 0,
113        }),
114        "i32" => Kind::I32(proto::r#type::I32 {
115            nullability: nullability.into(),
116            type_variation_reference: 0,
117        }),
118        "i16" => Kind::I16(proto::r#type::I16 {
119            nullability: nullability.into(),
120            type_variation_reference: 0,
121        }),
122        "i8" => Kind::I8(proto::r#type::I8 {
123            nullability: nullability.into(),
124            type_variation_reference: 0,
125        }),
126        "fp32" => Kind::Fp32(proto::r#type::Fp32 {
127            nullability: nullability.into(),
128            type_variation_reference: 0,
129        }),
130        "fp64" => Kind::Fp64(proto::r#type::Fp64 {
131            nullability: nullability.into(),
132            type_variation_reference: 0,
133        }),
134        "string" => Kind::String(proto::r#type::String {
135            nullability: nullability.into(),
136            type_variation_reference: 0,
137        }),
138        "binary" => Kind::Binary(proto::r#type::Binary {
139            nullability: nullability.into(),
140            type_variation_reference: 0,
141        }),
142        "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
143            nullability: nullability.into(),
144            type_variation_reference: 0,
145        }),
146        "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
147            nullability: nullability.into(),
148            type_variation_reference: 0,
149        }),
150        "date" => Kind::Date(proto::r#type::Date {
151            nullability: nullability.into(),
152            type_variation_reference: 0,
153        }),
154        "time" => Kind::Time(proto::r#type::Time {
155            nullability: nullability.into(),
156            type_variation_reference: 0,
157        }),
158        "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
159            nullability: nullability.into(),
160            type_variation_reference: 0,
161        }),
162        "uuid" => Kind::Uuid(proto::r#type::Uuid {
163            nullability: nullability.into(),
164            type_variation_reference: 0,
165        }),
166        _ => unreachable!("Type {} exists in parser but not implemented in code", name),
167    };
168    Type { kind: Some(kind) }
169}
170
171fn parse_compound_type(
172    extensions: &SimpleExtensions,
173    pair: Pair<Rule>,
174) -> Result<Type, MessageParseError> {
175    assert_eq!(pair.as_rule(), Rule::compound_type);
176    let inner = unwrap_single_pair(pair);
177    match inner.as_rule() {
178        Rule::list_type => parse_list_type(extensions, inner),
179        // Rule::map_type => parse_map_type(inner),
180        // Rule::struct_type => parse_struct_type(inner),
181        _ => unimplemented!("{:?}", inner.as_rule()),
182    }
183}
184
185fn parse_list_type(
186    extensions: &SimpleExtensions,
187    pair: Pair<Rule>,
188) -> Result<Type, MessageParseError> {
189    assert_eq!(pair.as_rule(), Rule::list_type);
190    let mut iter = iter_pairs(pair.into_inner());
191    let nullability = iter.parse_next::<Nullability>();
192    let inner = iter.parse_next_scoped::<Type>(extensions)?;
193    iter.done();
194
195    Ok(Type {
196        kind: Some(Kind::List(Box::new(proto::r#type::List {
197            nullability: nullability.into(),
198            r#type: Some(Box::new(inner)),
199            type_variation_reference: 0,
200        }))),
201    })
202}
203
204fn parse_parameters(
205    extensions: &SimpleExtensions,
206    pair: Pair<Rule>,
207) -> Result<Vec<Parameter>, MessageParseError> {
208    assert_eq!(pair.as_rule(), Rule::parameters);
209    let mut iter = iter_pairs(pair.into_inner());
210    let mut params = Vec::new();
211    while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
212        params.push(param?);
213    }
214    iter.done();
215    Ok(params)
216}
217
218fn parse_user_defined_type(
219    extensions: &SimpleExtensions,
220    pair: Pair<Rule>,
221) -> Result<Type, MessageParseError> {
222    let span = pair.as_span();
223    assert_eq!(pair.as_rule(), Rule::user_defined_type);
224    let mut iter = iter_pairs(pair.into_inner());
225    let name = iter.pop(Rule::name).as_str();
226    let anchor = iter
227        .try_pop(Rule::anchor)
228        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
229
230    // TODO: Handle uri_anchor; validate that it matches the anchor
231    let _uri_anchor = iter
232        .try_pop(Rule::uri_anchor)
233        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
234
235    let nullability = iter.parse_next::<Nullability>();
236    let parameters = match iter.try_pop(Rule::parameters) {
237        Some(p) => parse_parameters(extensions, p)?,
238        None => Vec::new(),
239    };
240    iter.done();
241
242    let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
243
244    Ok(Type {
245        kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
246            type_reference: anchor,
247            nullability: nullability.into(),
248            type_parameters: parameters,
249            type_variation_reference: 0,
250        })),
251    })
252}
253
254impl ScopedParsePair for Type {
255    fn rule() -> Rule {
256        Rule::r#type
257    }
258
259    fn message() -> &'static str {
260        "Type"
261    }
262
263    fn parse_pair(
264        extensions: &SimpleExtensions,
265        pair: Pair<Rule>,
266    ) -> Result<Self, MessageParseError> {
267        assert_eq!(pair.as_rule(), Rule::r#type);
268        let inner = unwrap_single_pair(pair);
269        match inner.as_rule() {
270            Rule::simple_type => Ok(parse_simple_type(inner)),
271            Rule::compound_type => parse_compound_type(extensions, inner),
272            Rule::user_defined_type => parse_user_defined_type(extensions, inner),
273            _ => unimplemented!("{:?}", inner.as_rule()),
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` as `rule`, asserts exactly one top-level pair comes
    /// back, and returns that pair.
    fn single_pair(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Expected `i64` type with the given nullability and no variation.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Expected `string` type with the given nullability and no variation.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps a type as a data-type parameter.
    fn data_type_param(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let required_i64 = parse_simple_type(single_pair(Rule::simple_type, "i64"));
        assert_eq!(required_i64, i64_type(Nullability::Required));

        let nullable_string = parse_simple_type(single_pair(Rule::simple_type, "string?"));
        assert_eq!(nullable_string, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let parsed = Type::parse_pair(&extensions, single_pair(Rule::r#type, "i64")).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let parsed =
            parse_list_type(&extensions, single_pair(Rule::list_type, "list<i64>")).unwrap();
        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let parsed =
            parse_parameters(&extensions, single_pair(Rule::parameters, "<i64?,string>")).unwrap();
        assert_eq!(
            parsed,
            vec![
                data_type_param(i64_type(Nullability::Nullable)),
                data_type_param(string_type(Nullability::Required)),
            ]
        );
    }

    #[test]
    fn test_udts() {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_uri("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let pair = single_pair(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, pair).unwrap();
        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_param(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}