1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::ExtensionKind;
8use crate::parser::{ErrorKind, MessageParseError};
9
10pub(crate) fn get_and_validate_anchor(
13 extensions: &SimpleExtensions,
14 kind: ExtensionKind,
15 anchor: Option<u32>,
16 name: &str,
17 span: pest::Span,
18) -> Result<u32, MessageParseError> {
19 match anchor {
20 Some(a) => match extensions.is_name_unique(kind, a, name) {
21 Ok(_) => Ok(a),
22 Err(e) => {
23 let message = "Error matching name to anchor".to_string();
24 let error = MessageParseError {
25 message: kind.name(),
26 kind: ErrorKind::Lookup(e),
27 error: Box::new(pest::error::Error::new_from_span(
28 pest::error::ErrorVariant::CustomError { message },
29 span,
30 )),
31 };
32 Err(error)
33 }
34 },
35 None => match extensions.find_by_name(kind, name) {
36 Ok(a) => Ok(a),
37 Err(e) => {
38 let message = "Error finding extension for name".to_string();
39 let error = MessageParseError {
40 message: kind.name(),
41 kind: ErrorKind::Lookup(e),
42 error: Box::new(pest::error::Error::new_from_span(
43 pest::error::ErrorVariant::CustomError { message },
44 span,
45 )),
46 };
47 Err(error)
48 }
49 },
50 }
51}
52
53impl ParsePair for Nullability {
54 fn rule() -> Rule {
55 Rule::nullability
56 }
57
58 fn message() -> &'static str {
59 "Nullability"
60 }
61
62 fn parse_pair(pair: Pair<Rule>) -> Self {
63 assert_eq!(pair.as_rule(), Rule::nullability);
64 match pair.as_str() {
65 "?" => Nullability::Nullable,
66 "" => Nullability::Required,
67 "⁉" => Nullability::Unspecified,
68 _ => panic!("Invalid nullability: {}", pair.as_str()),
69 }
70 }
71}
72
73impl ScopedParsePair for Parameter {
74 fn rule() -> Rule {
75 Rule::parameter
76 }
77
78 fn message() -> &'static str {
79 "Parameter"
80 }
81
82 fn parse_pair(
83 extensions: &SimpleExtensions,
84 pair: Pair<Rule>,
85 ) -> Result<Self, MessageParseError> {
86 assert_eq!(pair.as_rule(), Rule::parameter);
87 let inner = unwrap_single_pair(pair);
88 match inner.as_rule() {
89 Rule::r#type => Ok(Parameter {
90 parameter: Some(proto::r#type::parameter::Parameter::DataType(
91 Type::parse_pair(extensions, inner)?,
92 )),
93 }),
94 _ => unimplemented!("{:?}", inner.as_rule()),
95 }
96 }
97}
98
99fn parse_simple_type(pair: Pair<Rule>) -> Type {
100 assert_eq!(pair.as_rule(), Rule::simple_type);
101 let mut iter = iter_pairs(pair.into_inner());
102 let name = iter.pop(Rule::simple_type_name).as_str();
103 let nullability = iter.parse_next::<Nullability>();
104 iter.done();
105 let kind = match name {
106 "boolean" => Kind::Bool(proto::r#type::Boolean {
107 nullability: nullability.into(),
108 type_variation_reference: 0,
109 }),
110 "i64" => Kind::I64(proto::r#type::I64 {
111 nullability: nullability.into(),
112 type_variation_reference: 0,
113 }),
114 "i32" => Kind::I32(proto::r#type::I32 {
115 nullability: nullability.into(),
116 type_variation_reference: 0,
117 }),
118 "i16" => Kind::I16(proto::r#type::I16 {
119 nullability: nullability.into(),
120 type_variation_reference: 0,
121 }),
122 "i8" => Kind::I8(proto::r#type::I8 {
123 nullability: nullability.into(),
124 type_variation_reference: 0,
125 }),
126 "fp32" => Kind::Fp32(proto::r#type::Fp32 {
127 nullability: nullability.into(),
128 type_variation_reference: 0,
129 }),
130 "fp64" => Kind::Fp64(proto::r#type::Fp64 {
131 nullability: nullability.into(),
132 type_variation_reference: 0,
133 }),
134 "string" => Kind::String(proto::r#type::String {
135 nullability: nullability.into(),
136 type_variation_reference: 0,
137 }),
138 "binary" => Kind::Binary(proto::r#type::Binary {
139 nullability: nullability.into(),
140 type_variation_reference: 0,
141 }),
142 "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
143 nullability: nullability.into(),
144 type_variation_reference: 0,
145 }),
146 "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
147 nullability: nullability.into(),
148 type_variation_reference: 0,
149 }),
150 "date" => Kind::Date(proto::r#type::Date {
151 nullability: nullability.into(),
152 type_variation_reference: 0,
153 }),
154 "time" => Kind::Time(proto::r#type::Time {
155 nullability: nullability.into(),
156 type_variation_reference: 0,
157 }),
158 "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
159 nullability: nullability.into(),
160 type_variation_reference: 0,
161 }),
162 "uuid" => Kind::Uuid(proto::r#type::Uuid {
163 nullability: nullability.into(),
164 type_variation_reference: 0,
165 }),
166 _ => unreachable!("Type {} exists in parser but not implemented in code", name),
167 };
168 Type { kind: Some(kind) }
169}
170
171fn parse_compound_type(
172 extensions: &SimpleExtensions,
173 pair: Pair<Rule>,
174) -> Result<Type, MessageParseError> {
175 assert_eq!(pair.as_rule(), Rule::compound_type);
176 let inner = unwrap_single_pair(pair);
177 match inner.as_rule() {
178 Rule::list_type => parse_list_type(extensions, inner),
179 _ => unimplemented!("{:?}", inner.as_rule()),
182 }
183}
184
185fn parse_list_type(
186 extensions: &SimpleExtensions,
187 pair: Pair<Rule>,
188) -> Result<Type, MessageParseError> {
189 assert_eq!(pair.as_rule(), Rule::list_type);
190 let mut iter = iter_pairs(pair.into_inner());
191 let nullability = iter.parse_next::<Nullability>();
192 let inner = iter.parse_next_scoped::<Type>(extensions)?;
193 iter.done();
194
195 Ok(Type {
196 kind: Some(Kind::List(Box::new(proto::r#type::List {
197 nullability: nullability.into(),
198 r#type: Some(Box::new(inner)),
199 type_variation_reference: 0,
200 }))),
201 })
202}
203
204fn parse_parameters(
205 extensions: &SimpleExtensions,
206 pair: Pair<Rule>,
207) -> Result<Vec<Parameter>, MessageParseError> {
208 assert_eq!(pair.as_rule(), Rule::parameters);
209 let mut iter = iter_pairs(pair.into_inner());
210 let mut params = Vec::new();
211 while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
212 params.push(param?);
213 }
214 iter.done();
215 Ok(params)
216}
217
218fn parse_user_defined_type(
219 extensions: &SimpleExtensions,
220 pair: Pair<Rule>,
221) -> Result<Type, MessageParseError> {
222 let span = pair.as_span();
223 assert_eq!(pair.as_rule(), Rule::user_defined_type);
224 let mut iter = iter_pairs(pair.into_inner());
225 let name = iter.pop(Rule::name).as_str();
226 let anchor = iter
227 .try_pop(Rule::anchor)
228 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
229
230 let _uri_anchor = iter
232 .try_pop(Rule::uri_anchor)
233 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
234
235 let nullability = iter.parse_next::<Nullability>();
236 let parameters = match iter.try_pop(Rule::parameters) {
237 Some(p) => parse_parameters(extensions, p)?,
238 None => Vec::new(),
239 };
240 iter.done();
241
242 let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
243
244 Ok(Type {
245 kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
246 type_reference: anchor,
247 nullability: nullability.into(),
248 type_parameters: parameters,
249 type_variation_reference: 0,
250 })),
251 })
252}
253
254impl ScopedParsePair for Type {
255 fn rule() -> Rule {
256 Rule::r#type
257 }
258
259 fn message() -> &'static str {
260 "Type"
261 }
262
263 fn parse_pair(
264 extensions: &SimpleExtensions,
265 pair: Pair<Rule>,
266 ) -> Result<Self, MessageParseError> {
267 assert_eq!(pair.as_rule(), Rule::r#type);
268 let inner = unwrap_single_pair(pair);
269 match inner.as_rule() {
270 Rule::simple_type => Ok(parse_simple_type(inner)),
271 Rule::compound_type => parse_compound_type(extensions, inner),
272 Rule::user_defined_type => parse_user_defined_type(extensions, inner),
273 _ => unimplemented!("{:?}", inner.as_rule()),
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule` and returns the only pair produced,
    /// asserting that nothing follows it.
    fn parse_single(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        pair
    }

    /// Expected `i64` type with the given nullability and no variation.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Expected `string` type with the given nullability and no variation.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps a type as a data-type parameter.
    fn data_type_parameter(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let required_i64 = parse_simple_type(parse_single(Rule::simple_type, "i64"));
        assert_eq!(required_i64, i64_type(Nullability::Required));

        let nullable_string = parse_simple_type(parse_single(Rule::simple_type, "string?"));
        assert_eq!(nullable_string, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_single(Rule::r#type, "i64");
        let parsed = Type::parse_pair(&extensions, pair).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_single(Rule::list_type, "list<i64>");
        let parsed = parse_list_type(&extensions, pair).unwrap();

        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let pair = parse_single(Rule::parameters, "<i64?,string>");
        let parsed = parse_parameters(&extensions, pair).unwrap();

        let expected = vec![
            data_type_parameter(i64_type(Nullability::Nullable)),
            data_type_parameter(string_type(Nullability::Required)),
        ];
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_udts() {
        // Register a URI and a type extension so that `udt#42` can be resolved.
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_uri("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let pair = parse_single(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, pair).unwrap();

        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_parameter(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}