1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::{ExtensionKind, MissingReference};
8use crate::parser::MessageParseError;
9
10pub(crate) fn get_and_validate_anchor(
18 extensions: &SimpleExtensions,
19 kind: ExtensionKind,
20 anchor: Option<u32>,
21 name: &str,
22 span: pest::Span,
23) -> Result<u32, MessageParseError> {
24 if kind == ExtensionKind::Function {
25 return extensions
26 .resolve_function(name, anchor)
27 .map(|r| r.anchor)
28 .map_err(|e| {
29 MessageParseError::lookup(kind.name(), e, span, "Error resolving function")
30 });
31 }
32 match anchor {
34 Some(a) => match extensions.find_by_anchor(kind, a) {
35 Err(e) => Err(MessageParseError::lookup(
36 kind.name(),
37 e,
38 span,
39 "Error matching name to anchor",
40 )),
41 Ok((_, stored)) => {
42 if stored.full() == name || stored.base() == name {
43 Ok(a)
44 } else {
45 Err(MessageParseError::lookup(
46 kind.name(),
47 MissingReference::Mismatched(kind, name.to_string(), a),
48 span,
49 "Error matching name to anchor",
50 ))
51 }
52 }
53 },
54 None => extensions.find_by_name(kind, name).map_err(|e| {
55 MessageParseError::lookup(kind.name(), e, span, "Error finding extension for name")
56 }),
57 }
58}
59
60impl ParsePair for Nullability {
61 fn rule() -> Rule {
62 Rule::nullability
63 }
64
65 fn message() -> &'static str {
66 "Nullability"
67 }
68
69 fn parse_pair(pair: Pair<Rule>) -> Self {
70 assert_eq!(pair.as_rule(), Rule::nullability);
71 match pair.as_str() {
72 "?" => Nullability::Nullable,
73 "" => Nullability::Required,
74 "⁉" => Nullability::Unspecified,
75 _ => panic!("Invalid nullability: {}", pair.as_str()),
76 }
77 }
78}
79
80impl ScopedParsePair for Parameter {
81 fn rule() -> Rule {
82 Rule::parameter
83 }
84
85 fn message() -> &'static str {
86 "Parameter"
87 }
88
89 fn parse_pair(
90 extensions: &SimpleExtensions,
91 pair: Pair<Rule>,
92 ) -> Result<Self, MessageParseError> {
93 assert_eq!(pair.as_rule(), Rule::parameter);
94 let inner = unwrap_single_pair(pair);
95 match inner.as_rule() {
96 Rule::r#type => Ok(Parameter {
97 parameter: Some(proto::r#type::parameter::Parameter::DataType(
98 Type::parse_pair(extensions, inner)?,
99 )),
100 }),
101 _ => unimplemented!("{:?}", inner.as_rule()),
102 }
103 }
104}
105
106fn parse_simple_type(pair: Pair<Rule>) -> Type {
107 assert_eq!(pair.as_rule(), Rule::simple_type);
108 let mut iter = iter_pairs(pair.into_inner());
109 let name = iter.pop(Rule::simple_type_name).as_str();
110 let nullability = iter.parse_next::<Nullability>();
111 iter.done();
112
113 let kind = match name {
114 "boolean" => Kind::Bool(proto::r#type::Boolean {
115 nullability: nullability.into(),
116 type_variation_reference: 0,
117 }),
118 "i64" => Kind::I64(proto::r#type::I64 {
119 nullability: nullability.into(),
120 type_variation_reference: 0,
121 }),
122 "i32" => Kind::I32(proto::r#type::I32 {
123 nullability: nullability.into(),
124 type_variation_reference: 0,
125 }),
126 "i16" => Kind::I16(proto::r#type::I16 {
127 nullability: nullability.into(),
128 type_variation_reference: 0,
129 }),
130 "i8" => Kind::I8(proto::r#type::I8 {
131 nullability: nullability.into(),
132 type_variation_reference: 0,
133 }),
134 "fp32" => Kind::Fp32(proto::r#type::Fp32 {
135 nullability: nullability.into(),
136 type_variation_reference: 0,
137 }),
138 "fp64" => Kind::Fp64(proto::r#type::Fp64 {
139 nullability: nullability.into(),
140 type_variation_reference: 0,
141 }),
142 "string" => Kind::String(proto::r#type::String {
143 nullability: nullability.into(),
144 type_variation_reference: 0,
145 }),
146 "binary" => Kind::Binary(proto::r#type::Binary {
147 nullability: nullability.into(),
148 type_variation_reference: 0,
149 }),
150 #[allow(deprecated)]
151 "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
152 nullability: nullability.into(),
153 type_variation_reference: 0,
154 }),
155 #[allow(deprecated)]
156 "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
157 nullability: nullability.into(),
158 type_variation_reference: 0,
159 }),
160 "date" => Kind::Date(proto::r#type::Date {
161 nullability: nullability.into(),
162 type_variation_reference: 0,
163 }),
164 #[allow(deprecated)]
165 "time" => Kind::Time(proto::r#type::Time {
166 nullability: nullability.into(),
167 type_variation_reference: 0,
168 }),
169 "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
170 nullability: nullability.into(),
171 type_variation_reference: 0,
172 }),
173 "uuid" => Kind::Uuid(proto::r#type::Uuid {
174 nullability: nullability.into(),
175 type_variation_reference: 0,
176 }),
177 _ => unreachable!("Type {} exists in parser but not implemented in code", name),
178 };
179 Type { kind: Some(kind) }
180}
181
182fn parse_compound_type(
183 extensions: &SimpleExtensions,
184 pair: Pair<Rule>,
185) -> Result<Type, MessageParseError> {
186 assert_eq!(pair.as_rule(), Rule::compound_type);
187 let inner = unwrap_single_pair(pair);
188 match inner.as_rule() {
189 Rule::list_type => parse_list_type(extensions, inner),
190 Rule::precision_timestamp_tz_type
193 | Rule::precision_timestamp_type
194 | Rule::precision_time_type => parse_precision_type(inner),
195 _ => unimplemented!("{:?}", inner.as_rule()),
196 }
197}
198
199fn parse_precision_type(pair: Pair<Rule>) -> Result<Type, MessageParseError> {
200 let rule = pair.as_rule();
201 let mut iter = iter_pairs(pair.into_inner());
202 let nullability = iter.parse_next::<Nullability>();
203 let precision_pair = iter.pop(Rule::integer);
204 let precision_span = precision_pair.as_span();
205 let precision = precision_pair.as_str().parse::<i32>().unwrap();
206 if !(0..=12).contains(&precision) {
207 return Err(MessageParseError::invalid(
208 "precision time type",
209 precision_span,
210 format!("precision must be between 0 and 12, got {precision}"),
211 ));
212 }
213 iter.done();
214 let kind = match rule {
215 Rule::precision_timestamp_type => {
216 Kind::PrecisionTimestamp(proto::r#type::PrecisionTimestamp {
217 precision,
218 nullability: nullability.into(),
219 type_variation_reference: 0,
220 })
221 }
222 Rule::precision_timestamp_tz_type => {
223 Kind::PrecisionTimestampTz(proto::r#type::PrecisionTimestampTz {
224 precision,
225 nullability: nullability.into(),
226 type_variation_reference: 0,
227 })
228 }
229 Rule::precision_time_type => Kind::PrecisionTime(proto::r#type::PrecisionTime {
230 precision,
231 nullability: nullability.into(),
232 type_variation_reference: 0,
233 }),
234 _ => unreachable!("parse_precision_type called with rule {:?}", rule),
235 };
236 Ok(Type { kind: Some(kind) })
237}
238
239fn parse_list_type(
240 extensions: &SimpleExtensions,
241 pair: Pair<Rule>,
242) -> Result<Type, MessageParseError> {
243 assert_eq!(pair.as_rule(), Rule::list_type);
244 let mut iter = iter_pairs(pair.into_inner());
245 let nullability = iter.parse_next::<Nullability>();
246 let inner = iter.parse_next_scoped::<Type>(extensions)?;
247 iter.done();
248
249 Ok(Type {
250 kind: Some(Kind::List(Box::new(proto::r#type::List {
251 nullability: nullability.into(),
252 r#type: Some(Box::new(inner)),
253 type_variation_reference: 0,
254 }))),
255 })
256}
257
258fn parse_parameters(
259 extensions: &SimpleExtensions,
260 pair: Pair<Rule>,
261) -> Result<Vec<Parameter>, MessageParseError> {
262 assert_eq!(pair.as_rule(), Rule::parameters);
263 let mut iter = iter_pairs(pair.into_inner());
264 let mut params = Vec::new();
265 while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
266 params.push(param?);
267 }
268 iter.done();
269 Ok(params)
270}
271
272fn parse_user_defined_type(
273 extensions: &SimpleExtensions,
274 pair: Pair<Rule>,
275) -> Result<Type, MessageParseError> {
276 let span = pair.as_span();
277 assert_eq!(pair.as_rule(), Rule::user_defined_type);
278 let mut iter = iter_pairs(pair.into_inner());
279 let name = iter.pop(Rule::name).as_str();
280 let anchor = iter
281 .try_pop(Rule::anchor)
282 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
283
284 let _urn_anchor = iter
286 .try_pop(Rule::urn_anchor)
287 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
288
289 let nullability = iter.parse_next::<Nullability>();
290 let parameters = match iter.try_pop(Rule::parameters) {
291 Some(p) => parse_parameters(extensions, p)?,
292 None => Vec::new(),
293 };
294 iter.done();
295
296 let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
297
298 Ok(Type {
299 kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
300 type_reference: anchor,
301 nullability: nullability.into(),
302 type_parameters: parameters,
303 type_variation_reference: 0,
304 })),
305 })
306}
307
308impl ScopedParsePair for Type {
309 fn rule() -> Rule {
310 Rule::r#type
311 }
312
313 fn message() -> &'static str {
314 "Type"
315 }
316
317 fn parse_pair(
318 extensions: &SimpleExtensions,
319 pair: Pair<Rule>,
320 ) -> Result<Self, MessageParseError> {
321 assert_eq!(pair.as_rule(), Rule::r#type);
322 let inner = unwrap_single_pair(pair);
323 match inner.as_rule() {
324 Rule::simple_type => Ok(parse_simple_type(inner)),
325 Rule::compound_type => parse_compound_type(extensions, inner),
326 Rule::user_defined_type => parse_user_defined_type(extensions, inner),
327 _ => unreachable!(
328 "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}",
329 inner.as_rule()
330 ),
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` with `rule`, asserting that exactly one top-level pair is produced.
    fn single_pair(rule: Rule, input: &str) -> Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let only = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        only
    }

    /// Expected `i64` type with the given nullability and no type variation.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Expected `string` type with the given nullability and no type variation.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps `data_type` as a data-type parameter.
    fn data_type_param(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let parsed = parse_simple_type(single_pair(Rule::simple_type, "i64"));
        assert_eq!(parsed, i64_type(Nullability::Required));

        let parsed = parse_simple_type(single_pair(Rule::simple_type, "string?"));
        assert_eq!(parsed, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let parsed = Type::parse_pair(&extensions, single_pair(Rule::r#type, "i64")).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let pair = single_pair(Rule::list_type, "list<i64>");
        let parsed = parse_list_type(&extensions, pair).unwrap();
        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let pair = single_pair(Rule::parameters, "<i64?,string>");
        let parsed = parse_parameters(&extensions, pair).unwrap();
        assert_eq!(
            parsed,
            vec![
                data_type_param(i64_type(Nullability::Nullable)),
                data_type_param(string_type(Nullability::Required)),
            ]
        );
    }

    #[test]
    fn test_udts() {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_urn("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let pair = single_pair(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, pair).unwrap();
        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_param(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}
479}