1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::{ExtensionKind, MissingReference};
8use crate::parser::MessageParseError;
9
10pub(crate) fn get_and_validate_anchor(
18 extensions: &SimpleExtensions,
19 kind: ExtensionKind,
20 anchor: Option<u32>,
21 name: &str,
22 span: pest::Span,
23) -> Result<u32, MessageParseError> {
24 if kind == ExtensionKind::Function {
25 return extensions
26 .resolve_function(name, anchor)
27 .map(|r| r.anchor)
28 .map_err(|e| {
29 MessageParseError::lookup(kind.name(), e, span, "Error resolving function")
30 });
31 }
32 match anchor {
34 Some(a) => match extensions.find_by_anchor(kind, a) {
35 Err(e) => Err(MessageParseError::lookup(
36 kind.name(),
37 e,
38 span,
39 "Error matching name to anchor",
40 )),
41 Ok((_, stored)) => {
42 if stored.full() == name || stored.base() == name {
43 Ok(a)
44 } else {
45 Err(MessageParseError::lookup(
46 kind.name(),
47 MissingReference::Mismatched(kind, name.to_string(), a),
48 span,
49 "Error matching name to anchor",
50 ))
51 }
52 }
53 },
54 None => extensions.find_by_name(kind, name).map_err(|e| {
55 MessageParseError::lookup(kind.name(), e, span, "Error finding extension for name")
56 }),
57 }
58}
59
60impl ParsePair for Nullability {
61 fn rule() -> Rule {
62 Rule::nullability
63 }
64
65 fn message() -> &'static str {
66 "Nullability"
67 }
68
69 fn parse_pair(pair: Pair<Rule>) -> Self {
70 assert_eq!(pair.as_rule(), Rule::nullability);
71 match pair.as_str() {
72 "?" => Nullability::Nullable,
73 "" => Nullability::Required,
74 "⁉" => Nullability::Unspecified,
75 _ => panic!("Invalid nullability: {}", pair.as_str()),
76 }
77 }
78}
79
80impl ScopedParsePair for Parameter {
81 fn rule() -> Rule {
82 Rule::parameter
83 }
84
85 fn message() -> &'static str {
86 "Parameter"
87 }
88
89 fn parse_pair(
90 extensions: &SimpleExtensions,
91 pair: Pair<Rule>,
92 ) -> Result<Self, MessageParseError> {
93 assert_eq!(pair.as_rule(), Rule::parameter);
94 let inner = unwrap_single_pair(pair);
95 match inner.as_rule() {
96 Rule::r#type => Ok(Parameter {
97 parameter: Some(proto::r#type::parameter::Parameter::DataType(
98 Type::parse_pair(extensions, inner)?,
99 )),
100 }),
101 _ => unimplemented!("{:?}", inner.as_rule()),
102 }
103 }
104}
105
106fn parse_simple_type(pair: Pair<Rule>) -> Type {
107 assert_eq!(pair.as_rule(), Rule::simple_type);
108 let mut iter = iter_pairs(pair.into_inner());
109 let name = iter.pop(Rule::simple_type_name).as_str();
110 let nullability = iter.parse_next::<Nullability>();
111 iter.done();
112
113 let kind = match name {
114 "boolean" => Kind::Bool(proto::r#type::Boolean {
115 nullability: nullability.into(),
116 type_variation_reference: 0,
117 }),
118 "i64" => Kind::I64(proto::r#type::I64 {
119 nullability: nullability.into(),
120 type_variation_reference: 0,
121 }),
122 "i32" => Kind::I32(proto::r#type::I32 {
123 nullability: nullability.into(),
124 type_variation_reference: 0,
125 }),
126 "i16" => Kind::I16(proto::r#type::I16 {
127 nullability: nullability.into(),
128 type_variation_reference: 0,
129 }),
130 "i8" => Kind::I8(proto::r#type::I8 {
131 nullability: nullability.into(),
132 type_variation_reference: 0,
133 }),
134 "fp32" => Kind::Fp32(proto::r#type::Fp32 {
135 nullability: nullability.into(),
136 type_variation_reference: 0,
137 }),
138 "fp64" => Kind::Fp64(proto::r#type::Fp64 {
139 nullability: nullability.into(),
140 type_variation_reference: 0,
141 }),
142 "string" => Kind::String(proto::r#type::String {
143 nullability: nullability.into(),
144 type_variation_reference: 0,
145 }),
146 "binary" => Kind::Binary(proto::r#type::Binary {
147 nullability: nullability.into(),
148 type_variation_reference: 0,
149 }),
150 #[allow(deprecated)]
151 "timestamp" => Kind::Timestamp(proto::r#type::Timestamp {
152 nullability: nullability.into(),
153 type_variation_reference: 0,
154 }),
155 #[allow(deprecated)]
156 "timestamp_tz" => Kind::TimestampTz(proto::r#type::TimestampTz {
157 nullability: nullability.into(),
158 type_variation_reference: 0,
159 }),
160 "date" => Kind::Date(proto::r#type::Date {
161 nullability: nullability.into(),
162 type_variation_reference: 0,
163 }),
164 "time" => Kind::Time(proto::r#type::Time {
165 nullability: nullability.into(),
166 type_variation_reference: 0,
167 }),
168 "interval_year" => Kind::IntervalYear(proto::r#type::IntervalYear {
169 nullability: nullability.into(),
170 type_variation_reference: 0,
171 }),
172 "uuid" => Kind::Uuid(proto::r#type::Uuid {
173 nullability: nullability.into(),
174 type_variation_reference: 0,
175 }),
176 _ => unreachable!("Type {} exists in parser but not implemented in code", name),
177 };
178 Type { kind: Some(kind) }
179}
180
181fn parse_compound_type(
182 extensions: &SimpleExtensions,
183 pair: Pair<Rule>,
184) -> Result<Type, MessageParseError> {
185 assert_eq!(pair.as_rule(), Rule::compound_type);
186 let inner = unwrap_single_pair(pair);
187 match inner.as_rule() {
188 Rule::list_type => parse_list_type(extensions, inner),
189 _ => unimplemented!("{:?}", inner.as_rule()),
192 }
193}
194
195fn parse_list_type(
196 extensions: &SimpleExtensions,
197 pair: Pair<Rule>,
198) -> Result<Type, MessageParseError> {
199 assert_eq!(pair.as_rule(), Rule::list_type);
200 let mut iter = iter_pairs(pair.into_inner());
201 let nullability = iter.parse_next::<Nullability>();
202 let inner = iter.parse_next_scoped::<Type>(extensions)?;
203 iter.done();
204
205 Ok(Type {
206 kind: Some(Kind::List(Box::new(proto::r#type::List {
207 nullability: nullability.into(),
208 r#type: Some(Box::new(inner)),
209 type_variation_reference: 0,
210 }))),
211 })
212}
213
214fn parse_parameters(
215 extensions: &SimpleExtensions,
216 pair: Pair<Rule>,
217) -> Result<Vec<Parameter>, MessageParseError> {
218 assert_eq!(pair.as_rule(), Rule::parameters);
219 let mut iter = iter_pairs(pair.into_inner());
220 let mut params = Vec::new();
221 while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
222 params.push(param?);
223 }
224 iter.done();
225 Ok(params)
226}
227
228fn parse_user_defined_type(
229 extensions: &SimpleExtensions,
230 pair: Pair<Rule>,
231) -> Result<Type, MessageParseError> {
232 let span = pair.as_span();
233 assert_eq!(pair.as_rule(), Rule::user_defined_type);
234 let mut iter = iter_pairs(pair.into_inner());
235 let name = iter.pop(Rule::name).as_str();
236 let anchor = iter
237 .try_pop(Rule::anchor)
238 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
239
240 let _urn_anchor = iter
242 .try_pop(Rule::urn_anchor)
243 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
244
245 let nullability = iter.parse_next::<Nullability>();
246 let parameters = match iter.try_pop(Rule::parameters) {
247 Some(p) => parse_parameters(extensions, p)?,
248 None => Vec::new(),
249 };
250 iter.done();
251
252 let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
253
254 Ok(Type {
255 kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
256 type_reference: anchor,
257 nullability: nullability.into(),
258 type_parameters: parameters,
259 type_variation_reference: 0,
260 })),
261 })
262}
263
264impl ScopedParsePair for Type {
265 fn rule() -> Rule {
266 Rule::r#type
267 }
268
269 fn message() -> &'static str {
270 "Type"
271 }
272
273 fn parse_pair(
274 extensions: &SimpleExtensions,
275 pair: Pair<Rule>,
276 ) -> Result<Self, MessageParseError> {
277 assert_eq!(pair.as_rule(), Rule::r#type);
278 let inner = unwrap_single_pair(pair);
279 match inner.as_rule() {
280 Rule::simple_type => Ok(parse_simple_type(inner)),
281 Rule::compound_type => parse_compound_type(extensions, inner),
282 Rule::user_defined_type => parse_user_defined_type(extensions, inner),
283 _ => unreachable!(
284 "Grammar guarantees type can only be simple_type, compound_type, or user_defined_type, got: {:?}",
285 inner.as_rule()
286 ),
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    /// Parses `input` as `rule`, asserts the parser yields exactly one
    /// top-level pair, and returns that pair.
    fn parse_one(rule: Rule, input: &str) -> pest::iterators::Pair<'_, Rule> {
        let mut pairs = ExpressionParser::parse(rule, input).unwrap();
        let first = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        first
    }

    /// Expected `Type` for a bare `i64` with the given nullability.
    fn i64_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::I64(I64 {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Expected `Type` for a bare `string` with the given nullability.
    fn string_type(nullability: Nullability) -> Type {
        Type {
            kind: Some(Kind::String(proto::r#type::String {
                nullability: nullability as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Wraps `data_type` as a data-type `Parameter`.
    fn data_type_param(data_type: Type) -> Parameter {
        Parameter {
            parameter: Some(proto::r#type::parameter::Parameter::DataType(data_type)),
        }
    }

    #[test]
    fn test_parse_simple_type() {
        let parsed = parse_simple_type(parse_one(Rule::simple_type, "i64"));
        assert_eq!(parsed, i64_type(Nullability::Required));

        let parsed = parse_simple_type(parse_one(Rule::simple_type, "string?"));
        assert_eq!(parsed, string_type(Nullability::Nullable));
    }

    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_one(Rule::r#type, "i64");
        let parsed = Type::parse_pair(&extensions, pair).unwrap();
        assert_eq!(parsed, i64_type(Nullability::Required));
    }

    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let pair = parse_one(Rule::list_type, "list<i64>");
        let parsed = parse_list_type(&extensions, pair).unwrap();

        let expected = Type {
            kind: Some(Kind::List(Box::new(proto::r#type::List {
                nullability: Nullability::Required as i32,
                r#type: Some(Box::new(i64_type(Nullability::Required))),
                type_variation_reference: 0,
            }))),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let pair = parse_one(Rule::parameters, "<i64?,string>");
        let parsed = parse_parameters(&extensions, pair).unwrap();

        let expected = vec![
            data_type_param(i64_type(Nullability::Nullable)),
            data_type_param(string_type(Nullability::Required)),
        ];
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_udts() {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_urn("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();

        let pair = parse_one(Rule::user_defined_type, "udt#42<i64?>");
        let parsed = parse_user_defined_type(&extensions, pair).unwrap();

        let expected = Type {
            kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                type_reference: 42,
                type_variation_reference: 0,
                nullability: Nullability::Required as i32,
                type_parameters: vec![data_type_param(i64_type(Nullability::Nullable))],
            })),
        };
        assert_eq!(parsed, expected);
    }
}