1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::ExtensionKind;
8use crate::parser::{ErrorKind, MessageParseError};
9
10pub(crate) fn get_and_validate_anchor(
13 extensions: &SimpleExtensions,
14 kind: ExtensionKind,
15 anchor: Option<u32>,
16 name: &str,
17 span: pest::Span,
18) -> Result<u32, MessageParseError> {
19 match anchor {
20 Some(a) => match extensions.is_name_unique(kind, a, name) {
21 Ok(_) => Ok(a),
22 Err(e) => {
23 let message = "Error matching name to anchor".to_string();
24 let error = MessageParseError {
25 message: kind.name(),
26 kind: ErrorKind::Lookup(e),
27 error: Box::new(pest::error::Error::new_from_span(
28 pest::error::ErrorVariant::CustomError { message },
29 span,
30 )),
31 };
32 Err(error)
33 }
34 },
35 None => match extensions.find_by_name(kind, name) {
36 Ok(a) => Ok(a),
37 Err(e) => {
38 let message = "Error finding extension for name".to_string();
39 let error = MessageParseError {
40 message: kind.name(),
41 kind: ErrorKind::Lookup(e),
42 error: Box::new(pest::error::Error::new_from_span(
43 pest::error::ErrorVariant::CustomError { message },
44 span,
45 )),
46 };
47 Err(error)
48 }
49 },
50 }
51}
52
53impl ParsePair for Nullability {
54 fn rule() -> Rule {
55 Rule::nullability
56 }
57
58 fn message() -> &'static str {
59 "Nullability"
60 }
61
62 fn parse_pair(pair: Pair<Rule>) -> Self {
63 assert_eq!(pair.as_rule(), Rule::nullability);
64 match pair.as_str() {
65 "?" => Nullability::Nullable,
66 "" => Nullability::Required,
67 "⁉" => Nullability::Unspecified,
68 _ => panic!("Invalid nullability: {}", pair.as_str()),
69 }
70 }
71}
72
73impl ScopedParsePair for Parameter {
74 fn rule() -> Rule {
75 Rule::parameter
76 }
77
78 fn message() -> &'static str {
79 "Parameter"
80 }
81
82 fn parse_pair(
83 extensions: &SimpleExtensions,
84 pair: Pair<Rule>,
85 ) -> Result<Self, MessageParseError> {
86 assert_eq!(pair.as_rule(), Rule::parameter);
87 let inner = unwrap_single_pair(pair);
88 match inner.as_rule() {
89 Rule::r#type => Ok(Parameter {
90 parameter: Some(proto::r#type::parameter::Parameter::DataType(
91 Type::parse_pair(extensions, inner)?,
92 )),
93 }),
94 _ => unimplemented!("{:?}", inner.as_rule()),
95 }
96 }
97}
98
99fn parse_simple_type(pair: Pair<Rule>) -> Type {
100 assert_eq!(pair.as_rule(), Rule::simple_type);
101 let mut iter = iter_pairs(pair.into_inner());
102 let name = iter.pop(Rule::simple_type_name).as_str();
103 let nullability = iter.parse_next::<Nullability>();
104 iter.done();
105
106 let kind = match name {
107 "boolean" => Kind::Bool(proto::r#type::Boolean {
108 nullability: nullability.into(),
109 type_variation_reference: 0,
110 }),
111 "i64" => Kind::I64(proto::r#type::I64 {
112 nullability: nullability.into(),
113 type_variation_reference: 0,
114 }),
115 "i32" => Kind::I32(proto::r#type::I32 {
116 nullability: nullability.into(),
117 type_variation_reference: 0,
118 }),
119 "i16" => Kind::I16(proto::r#type::I16 {
120 nullability: nullability.into(),
121 type_variation_reference: 0,
122 }),
123 "i8" => Kind::I8(proto::r#type::I8 {
124 nullability: nullability.into(),
125 type_variation_reference: 0,
126 }),
127 "fp32" => Kind::Fp32(proto::r#type::Fp32 {
128 nullability: nullability.into(),
129 type_variation_reference: 0,
130 }),
131 "fp64" => Kind::Fp64(proto::r#type::Fp64 {
132 nullability: nullability.into(),
133 type_variation_reference: 0,
134 }),
135 "string" => Kind::String(proto::r#type::String {
136 nullability: nullability.into(),
137 type_variation_reference: 0,
138 }),
139 "binary" => Kind::Binary(proto::r#type::Binary {
140 nullability: nullability.into(),
141 type_variation_reference: 0,
142 }),
143 #[allow(deprecated)]
144 "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
145 nullability: nullability.into(),
146 type_variation_reference: 0,
147 }),
148 #[allow(deprecated)]
149 "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
150 nullability: nullability.into(),
151 type_variation_reference: 0,
152 }),
153 "date" => Kind::Date(proto::r#type::Date {
154 nullability: nullability.into(),
155 type_variation_reference: 0,
156 }),
157 "time" => Kind::Time(proto::r#type::Time {
158 nullability: nullability.into(),
159 type_variation_reference: 0,
160 }),
161 "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
162 nullability: nullability.into(),
163 type_variation_reference: 0,
164 }),
165 "uuid" => Kind::Uuid(proto::r#type::Uuid {
166 nullability: nullability.into(),
167 type_variation_reference: 0,
168 }),
169 _ => unreachable!("Type {} exists in parser but not implemented in code", name),
170 };
171 Type { kind: Some(kind) }
172}
173
174fn parse_compound_type(
175 extensions: &SimpleExtensions,
176 pair: Pair<Rule>,
177) -> Result<Type, MessageParseError> {
178 assert_eq!(pair.as_rule(), Rule::compound_type);
179 let inner = unwrap_single_pair(pair);
180 match inner.as_rule() {
181 Rule::list_type => parse_list_type(extensions, inner),
182 _ => unimplemented!("{:?}", inner.as_rule()),
185 }
186}
187
188fn parse_list_type(
189 extensions: &SimpleExtensions,
190 pair: Pair<Rule>,
191) -> Result<Type, MessageParseError> {
192 assert_eq!(pair.as_rule(), Rule::list_type);
193 let mut iter = iter_pairs(pair.into_inner());
194 let nullability = iter.parse_next::<Nullability>();
195 let inner = iter.parse_next_scoped::<Type>(extensions)?;
196 iter.done();
197
198 Ok(Type {
199 kind: Some(Kind::List(Box::new(proto::r#type::List {
200 nullability: nullability.into(),
201 r#type: Some(Box::new(inner)),
202 type_variation_reference: 0,
203 }))),
204 })
205}
206
207fn parse_parameters(
208 extensions: &SimpleExtensions,
209 pair: Pair<Rule>,
210) -> Result<Vec<Parameter>, MessageParseError> {
211 assert_eq!(pair.as_rule(), Rule::parameters);
212 let mut iter = iter_pairs(pair.into_inner());
213 let mut params = Vec::new();
214 while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
215 params.push(param?);
216 }
217 iter.done();
218 Ok(params)
219}
220
221fn parse_user_defined_type(
222 extensions: &SimpleExtensions,
223 pair: Pair<Rule>,
224) -> Result<Type, MessageParseError> {
225 let span = pair.as_span();
226 assert_eq!(pair.as_rule(), Rule::user_defined_type);
227 let mut iter = iter_pairs(pair.into_inner());
228 let name = iter.pop(Rule::name).as_str();
229 let anchor = iter
230 .try_pop(Rule::anchor)
231 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
232
233 let _urn_anchor = iter
235 .try_pop(Rule::urn_anchor)
236 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
237
238 let nullability = iter.parse_next::<Nullability>();
239 let parameters = match iter.try_pop(Rule::parameters) {
240 Some(p) => parse_parameters(extensions, p)?,
241 None => Vec::new(),
242 };
243 iter.done();
244
245 let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
246
247 Ok(Type {
248 kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
249 type_reference: anchor,
250 nullability: nullability.into(),
251 type_parameters: parameters,
252 type_variation_reference: 0,
253 })),
254 })
255}
256
257impl ScopedParsePair for Type {
258 fn rule() -> Rule {
259 Rule::r#type
260 }
261
262 fn message() -> &'static str {
263 "Type"
264 }
265
266 fn parse_pair(
267 extensions: &SimpleExtensions,
268 pair: Pair<Rule>,
269 ) -> Result<Self, MessageParseError> {
270 assert_eq!(pair.as_rule(), Rule::r#type);
271 let inner = unwrap_single_pair(pair);
272 match inner.as_rule() {
273 Rule::simple_type => Ok(parse_simple_type(inner)),
274 Rule::compound_type => parse_compound_type(extensions, inner),
275 Rule::user_defined_type => parse_user_defined_type(extensions, inner),
276 _ => unreachable!(
277 "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}",
278 inner.as_rule()
279 ),
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, asserting that exactly one top-level pair results.
    fn single_pair(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let first = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        first
    }

    /// Expected `i64` type with the given nullability and no type variation.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Expected `string` type with the given nullability and no type variation.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps a type as a data-type [`Parameter`].
    fn data_type_param(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let required_i64 = parse_simple_type(single_pair(Rule::simple_type, "i64"));
        assert_eq!(required_i64, i64_type(Nullability::Required));

        let nullable_string = parse_simple_type(single_pair(Rule::simple_type, "string?"));
        assert_eq!(nullable_string, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let parsed = Type::parse_pair(&extensions, single_pair(Rule::r#type, "i64")).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let parsed =
            parse_list_type(&extensions, single_pair(Rule::list_type, "list<i64>")).unwrap();
        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let parsed =
            parse_parameters(&extensions, single_pair(Rule::parameters, "<i64?,string>")).unwrap();
        assert_eq!(
            parsed,
            vec![
                data_type_param(i64_type(Nullability::Nullable)),
                data_type_param(string_type(Nullability::Required)),
            ]
        );
    }

    #[test]
    fn test_udts() {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_urn("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let udt_pair = single_pair(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, udt_pair).unwrap();

        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_param(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}