substrait_explain/parser/
types.rs

1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::{ExtensionKind, MissingReference};
8use crate::parser::MessageParseError;
9
10/// Given a name (plain or compound) and an optional anchor, resolve and
11/// validate the extension anchor.
12///
13/// For `Function` kind, delegates to [`SimpleExtensions::resolve_function`]
14/// which encapsulates the full base-name fallback logic.
15///
16/// For other kinds, performs a direct anchor/name validation.
17pub(crate) fn get_and_validate_anchor(
18    extensions: &SimpleExtensions,
19    kind: ExtensionKind,
20    anchor: Option<u32>,
21    name: &str,
22    span: pest::Span,
23) -> Result<u32, MessageParseError> {
24    if kind == ExtensionKind::Function {
25        return extensions
26            .resolve_function(name, anchor)
27            .map(|r| r.anchor)
28            .map_err(|e| {
29                MessageParseError::lookup(kind.name(), e, span, "Error resolving function")
30            });
31    }
32    // For non-function kinds, validate the anchor/name pair directly.
33    match anchor {
34        Some(a) => match extensions.find_by_anchor(kind, a) {
35            Err(e) => Err(MessageParseError::lookup(
36                kind.name(),
37                e,
38                span,
39                "Error matching name to anchor",
40            )),
41            Ok((_, stored)) => {
42                if stored.full() == name || stored.base() == name {
43                    Ok(a)
44                } else {
45                    Err(MessageParseError::lookup(
46                        kind.name(),
47                        MissingReference::Mismatched(kind, name.to_string(), a),
48                        span,
49                        "Error matching name to anchor",
50                    ))
51                }
52            }
53        },
54        None => extensions.find_by_name(kind, name).map_err(|e| {
55            MessageParseError::lookup(kind.name(), e, span, "Error finding extension for name")
56        }),
57    }
58}
59
60impl ParsePair for Nullability {
61    fn rule() -> Rule {
62        Rule::nullability
63    }
64
65    fn message() -> &'static str {
66        "Nullability"
67    }
68
69    fn parse_pair(pair: Pair<Rule>) -> Self {
70        assert_eq!(pair.as_rule(), Rule::nullability);
71        match pair.as_str() {
72            "?" => Nullability::Nullable,
73            "" => Nullability::Required,
74            "⁉" => Nullability::Unspecified,
75            _ => panic!("Invalid nullability: {}", pair.as_str()),
76        }
77    }
78}
79
80impl ScopedParsePair for Parameter {
81    fn rule() -> Rule {
82        Rule::parameter
83    }
84
85    fn message() -> &'static str {
86        "Parameter"
87    }
88
89    fn parse_pair(
90        extensions: &SimpleExtensions,
91        pair: Pair<Rule>,
92    ) -> Result<Self, MessageParseError> {
93        assert_eq!(pair.as_rule(), Rule::parameter);
94        let inner = unwrap_single_pair(pair);
95        match inner.as_rule() {
96            Rule::r#type => Ok(Parameter {
97                parameter: Some(proto::r#type::parameter::Parameter::DataType(
98                    Type::parse_pair(extensions, inner)?,
99                )),
100            }),
101            _ => unimplemented!("{:?}", inner.as_rule()),
102        }
103    }
104}
105
106fn parse_simple_type(pair: Pair<Rule>) -> Type {
107    assert_eq!(pair.as_rule(), Rule::simple_type);
108    let mut iter = iter_pairs(pair.into_inner());
109    let name = iter.pop(Rule::simple_type_name).as_str();
110    let nullability = iter.parse_next::<Nullability>();
111    iter.done();
112
113    let kind = match name {
114        "boolean" => Kind::Bool(proto::r#type::Boolean {
115            nullability: nullability.into(),
116            type_variation_reference: 0,
117        }),
118        "i64" => Kind::I64(proto::r#type::I64 {
119            nullability: nullability.into(),
120            type_variation_reference: 0,
121        }),
122        "i32" => Kind::I32(proto::r#type::I32 {
123            nullability: nullability.into(),
124            type_variation_reference: 0,
125        }),
126        "i16" => Kind::I16(proto::r#type::I16 {
127            nullability: nullability.into(),
128            type_variation_reference: 0,
129        }),
130        "i8" => Kind::I8(proto::r#type::I8 {
131            nullability: nullability.into(),
132            type_variation_reference: 0,
133        }),
134        "fp32" => Kind::Fp32(proto::r#type::Fp32 {
135            nullability: nullability.into(),
136            type_variation_reference: 0,
137        }),
138        "fp64" => Kind::Fp64(proto::r#type::Fp64 {
139            nullability: nullability.into(),
140            type_variation_reference: 0,
141        }),
142        "string" => Kind::String(proto::r#type::String {
143            nullability: nullability.into(),
144            type_variation_reference: 0,
145        }),
146        "binary" => Kind::Binary(proto::r#type::Binary {
147            nullability: nullability.into(),
148            type_variation_reference: 0,
149        }),
150        #[allow(deprecated)]
151        "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
152            nullability: nullability.into(),
153            type_variation_reference: 0,
154        }),
155        #[allow(deprecated)]
156        "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
157            nullability: nullability.into(),
158            type_variation_reference: 0,
159        }),
160        "date" => Kind::Date(proto::r#type::Date {
161            nullability: nullability.into(),
162            type_variation_reference: 0,
163        }),
164        "time" => Kind::Time(proto::r#type::Time {
165            nullability: nullability.into(),
166            type_variation_reference: 0,
167        }),
168        "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
169            nullability: nullability.into(),
170            type_variation_reference: 0,
171        }),
172        "uuid" => Kind::Uuid(proto::r#type::Uuid {
173            nullability: nullability.into(),
174            type_variation_reference: 0,
175        }),
176        _ => unreachable!("Type {} exists in parser but not implemented in code", name),
177    };
178    Type { kind: Some(kind) }
179}
180
181fn parse_compound_type(
182    extensions: &SimpleExtensions,
183    pair: Pair<Rule>,
184) -> Result<Type, MessageParseError> {
185    assert_eq!(pair.as_rule(), Rule::compound_type);
186    let inner = unwrap_single_pair(pair);
187    match inner.as_rule() {
188        Rule::list_type => parse_list_type(extensions, inner),
189        // Rule::map_type => parse_map_type(inner),
190        // Rule::struct_type => parse_struct_type(inner),
191        _ => unimplemented!("{:?}", inner.as_rule()),
192    }
193}
194
195fn parse_list_type(
196    extensions: &SimpleExtensions,
197    pair: Pair<Rule>,
198) -> Result<Type, MessageParseError> {
199    assert_eq!(pair.as_rule(), Rule::list_type);
200    let mut iter = iter_pairs(pair.into_inner());
201    let nullability = iter.parse_next::<Nullability>();
202    let inner = iter.parse_next_scoped::<Type>(extensions)?;
203    iter.done();
204
205    Ok(Type {
206        kind: Some(Kind::List(Box::new(proto::r#type::List {
207            nullability: nullability.into(),
208            r#type: Some(Box::new(inner)),
209            type_variation_reference: 0,
210        }))),
211    })
212}
213
214fn parse_parameters(
215    extensions: &SimpleExtensions,
216    pair: Pair<Rule>,
217) -> Result<Vec<Parameter>, MessageParseError> {
218    assert_eq!(pair.as_rule(), Rule::parameters);
219    let mut iter = iter_pairs(pair.into_inner());
220    let mut params = Vec::new();
221    while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
222        params.push(param?);
223    }
224    iter.done();
225    Ok(params)
226}
227
228fn parse_user_defined_type(
229    extensions: &SimpleExtensions,
230    pair: Pair<Rule>,
231) -> Result<Type, MessageParseError> {
232    let span = pair.as_span();
233    assert_eq!(pair.as_rule(), Rule::user_defined_type);
234    let mut iter = iter_pairs(pair.into_inner());
235    let name = iter.pop(Rule::name).as_str();
236    let anchor = iter
237        .try_pop(Rule::anchor)
238        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
239
240    // TODO: Handle urn_anchor; validate that it matches the anchor
241    let _urn_anchor = iter
242        .try_pop(Rule::urn_anchor)
243        .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
244
245    let nullability = iter.parse_next::<Nullability>();
246    let parameters = match iter.try_pop(Rule::parameters) {
247        Some(p) => parse_parameters(extensions, p)?,
248        None => Vec::new(),
249    };
250    iter.done();
251
252    let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
253
254    Ok(Type {
255        kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
256            type_reference: anchor,
257            nullability: nullability.into(),
258            type_parameters: parameters,
259            type_variation_reference: 0,
260        })),
261    })
262}
263
264impl ScopedParsePair for Type {
265    fn rule() -> Rule {
266        Rule::r#type
267    }
268
269    fn message() -> &'static str {
270        "Type"
271    }
272
273    fn parse_pair(
274        extensions: &SimpleExtensions,
275        pair: Pair<Rule>,
276    ) -> Result<Self, MessageParseError> {
277        assert_eq!(pair.as_rule(), Rule::r#type);
278        let inner = unwrap_single_pair(pair);
279        match inner.as_rule() {
280            Rule::simple_type => Ok(parse_simple_type(inner)),
281            Rule::compound_type => parse_compound_type(extensions, inner),
282            Rule::user_defined_type => parse_user_defined_type(extensions, inner),
283            _ => unreachable!(
284                "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}",
285                inner.as_rule()
286            ),
287        }
288    }
289}
290
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` as `rule` and returns the single top-level pair,
    /// failing the test if parsing fails or yields more than one pair.
    fn single_pair(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Extensions containing one user-defined type, `udt`, registered at
    /// anchor 42 under URN anchor 4.
    fn udt_extensions() -> SimpleExtensions {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_urn("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();
        extensions
    }

    #[test]
    fn test_parse_simple_type() {
        let t = parse_simple_type(single_pair(Rule::simple_type, "i64"));
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::I64(I64 {
                    nullability: Nullability::Required as i32,
                    type_variation_reference: 0,
                })),
            }
        );

        let t = parse_simple_type(single_pair(Rule::simple_type, "string?"));
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::String(proto::r#type::String {
                    nullability: Nullability::Nullable as i32,
                    type_variation_reference: 0,
                })),
            }
        );
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let t = Type::parse_pair(&extensions, single_pair(Rule::r#type, "i64")).unwrap();
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::I64(I64 {
                    nullability: Nullability::Required as i32,
                    type_variation_reference: 0,
                }))
            }
        );
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let t = parse_list_type(&extensions, single_pair(Rule::list_type, "list<i64>")).unwrap();
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::List(Box::new(proto::r#type::List {
                    nullability: Nullability::Required as i32,
                    r#type: Some(Box::new(Type {
                        kind: Some(Kind::I64(I64 {
                            nullability: Nullability::Required as i32,
                            type_variation_reference: 0,
                        }))
                    })),
                    type_variation_reference: 0,
                })))
            }
        );
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let t = parse_parameters(&extensions, single_pair(Rule::parameters, "<i64?,string>"))
            .unwrap();
        assert_eq!(
            t,
            vec![
                Parameter {
                    parameter: Some(proto::r#type::parameter::Parameter::DataType(Type {
                        kind: Some(Kind::I64(proto::r#type::I64 {
                            nullability: Nullability::Nullable as i32,
                            type_variation_reference: 0,
                        })),
                    })),
                },
                Parameter {
                    parameter: Some(proto::r#type::parameter::Parameter::DataType(Type {
                        kind: Some(Kind::String(proto::r#type::String {
                            nullability: Nullability::Required as i32,
                            type_variation_reference: 0,
                        })),
                    })),
                },
            ]
        );
    }

    #[test]
    fn test_udts() {
        let extensions = udt_extensions();
        let pair = single_pair(Rule::user_defined_type, "udt#42<i64?>");

        let t = parse_user_defined_type(&extensions, pair).unwrap();
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                    type_reference: 42,
                    type_variation_reference: 0,
                    nullability: Nullability::Required as i32,
                    type_parameters: vec![Parameter {
                        parameter: Some(proto::r#type::parameter::Parameter::DataType(Type {
                            kind: Some(Kind::I64(proto::r#type::I64 {
                                nullability: Nullability::Nullable as i32,
                                type_variation_reference: 0,
                            })),
                        })),
                    }],
                }))
            }
        );
    }

    /// Without an explicit `#anchor`, the anchor is found by the type's name.
    #[test]
    fn test_udt_resolved_by_name() {
        let extensions = udt_extensions();
        let pair = single_pair(Rule::user_defined_type, "udt");

        let t = parse_user_defined_type(&extensions, pair).unwrap();
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                    type_reference: 42,
                    type_variation_reference: 0,
                    nullability: Nullability::Required as i32,
                    type_parameters: vec![],
                }))
            }
        );
    }

    /// An explicit anchor must exist and must refer to a type with this name.
    #[test]
    fn test_udt_anchor_errors() {
        let extensions = udt_extensions();

        // Anchor 42 exists but is registered as `udt`, not `other`.
        let mismatched = single_pair(Rule::user_defined_type, "other#42");
        assert!(parse_user_defined_type(&extensions, mismatched).is_err());

        // Anchor 7 is not registered at all.
        let unknown = single_pair(Rule::user_defined_type, "udt#7");
        assert!(parse_user_defined_type(&extensions, unknown).is_err());
    }
}