1use pest::iterators::Pair;
2use substrait::proto::r#type::{Kind, Nullability, Parameter};
3use substrait::proto::{self, Type};
4
5use super::{ParsePair, Rule, ScopedParsePair, iter_pairs, unwrap_single_pair};
6use crate::extensions::SimpleExtensions;
7use crate::extensions::simple::ExtensionKind;
8use crate::parser::{ErrorKind, MessageParseError};
9
10pub(crate) fn get_and_validate_anchor(
13 extensions: &SimpleExtensions,
14 kind: ExtensionKind,
15 anchor: Option<u32>,
16 name: &str,
17 span: pest::Span,
18) -> Result<u32, MessageParseError> {
19 match anchor {
20 Some(a) => match extensions.is_name_unique(kind, a, name) {
21 Ok(_) => Ok(a),
22 Err(e) => {
23 let message = "Error matching name to anchor".to_string();
24 let error = MessageParseError {
25 message: kind.name(),
26 kind: ErrorKind::Lookup(e),
27 error: Box::new(pest::error::Error::new_from_span(
28 pest::error::ErrorVariant::CustomError { message },
29 span,
30 )),
31 };
32 Err(error)
33 }
34 },
35 None => match extensions.find_by_name(kind, name) {
36 Ok(a) => Ok(a),
37 Err(e) => {
38 let message = "Error finding extension for name".to_string();
39 let error = MessageParseError {
40 message: kind.name(),
41 kind: ErrorKind::Lookup(e),
42 error: Box::new(pest::error::Error::new_from_span(
43 pest::error::ErrorVariant::CustomError { message },
44 span,
45 )),
46 };
47 Err(error)
48 }
49 },
50 }
51}
52
53impl ParsePair for Nullability {
54 fn rule() -> Rule {
55 Rule::nullability
56 }
57
58 fn message() -> &'static str {
59 "Nullability"
60 }
61
62 fn parse_pair(pair: Pair<Rule>) -> Self {
63 assert_eq!(pair.as_rule(), Rule::nullability);
64 match pair.as_str() {
65 "?" => Nullability::Nullable,
66 "" => Nullability::Required,
67 "⁉" => Nullability::Unspecified,
68 _ => panic!("Invalid nullability: {}", pair.as_str()),
69 }
70 }
71}
72
73impl ScopedParsePair for Parameter {
74 fn rule() -> Rule {
75 Rule::parameter
76 }
77
78 fn message() -> &'static str {
79 "Parameter"
80 }
81
82 fn parse_pair(
83 extensions: &SimpleExtensions,
84 pair: Pair<Rule>,
85 ) -> Result<Self, MessageParseError> {
86 assert_eq!(pair.as_rule(), Rule::parameter);
87 let inner = unwrap_single_pair(pair);
88 match inner.as_rule() {
89 Rule::r#type => Ok(Parameter {
90 parameter: Some(proto::r#type::parameter::Parameter::DataType(
91 Type::parse_pair(extensions, inner)?,
92 )),
93 }),
94 _ => unimplemented!("{:?}", inner.as_rule()),
95 }
96 }
97}
98
99fn parse_simple_type(pair: Pair<Rule>) -> Type {
100 assert_eq!(pair.as_rule(), Rule::simple_type);
101 let mut iter = iter_pairs(pair.into_inner());
102 let name = iter.pop(Rule::simple_type_name).as_str();
103 let nullability = iter.parse_next::<Nullability>();
104 iter.done();
105 let kind = match name {
106 "boolean" => Kind::Bool(proto::r#type::Boolean {
107 nullability: nullability.into(),
108 type_variation_reference: 0,
109 }),
110 "i64" => Kind::I64(proto::r#type::I64 {
111 nullability: nullability.into(),
112 type_variation_reference: 0,
113 }),
114 "i32" => Kind::I32(proto::r#type::I32 {
115 nullability: nullability.into(),
116 type_variation_reference: 0,
117 }),
118 "i16" => Kind::I16(proto::r#type::I16 {
119 nullability: nullability.into(),
120 type_variation_reference: 0,
121 }),
122 "i8" => Kind::I8(proto::r#type::I8 {
123 nullability: nullability.into(),
124 type_variation_reference: 0,
125 }),
126 "fp32" => Kind::Fp32(proto::r#type::Fp32 {
127 nullability: nullability.into(),
128 type_variation_reference: 0,
129 }),
130 "fp64" => Kind::Fp64(proto::r#type::Fp64 {
131 nullability: nullability.into(),
132 type_variation_reference: 0,
133 }),
134 "string" => Kind::String(proto::r#type::String {
135 nullability: nullability.into(),
136 type_variation_reference: 0,
137 }),
138 _ => unimplemented!("{}", name),
139 };
140 Type { kind: Some(kind) }
141}
142
143fn parse_compound_type(
144 extensions: &SimpleExtensions,
145 pair: Pair<Rule>,
146) -> Result<Type, MessageParseError> {
147 assert_eq!(pair.as_rule(), Rule::compound_type);
148 let inner = unwrap_single_pair(pair);
149 match inner.as_rule() {
150 Rule::list_type => parse_list_type(extensions, inner),
151 _ => unimplemented!("{:?}", inner.as_rule()),
154 }
155}
156
157fn parse_list_type(
158 extensions: &SimpleExtensions,
159 pair: Pair<Rule>,
160) -> Result<Type, MessageParseError> {
161 assert_eq!(pair.as_rule(), Rule::list_type);
162 let mut iter = iter_pairs(pair.into_inner());
163 let nullability = iter.parse_next::<Nullability>();
164 let inner = iter.parse_next_scoped::<Type>(extensions)?;
165 iter.done();
166
167 Ok(Type {
168 kind: Some(Kind::List(Box::new(proto::r#type::List {
169 nullability: nullability.into(),
170 r#type: Some(Box::new(inner)),
171 type_variation_reference: 0,
172 }))),
173 })
174}
175
176fn parse_parameters(
177 extensions: &SimpleExtensions,
178 pair: Pair<Rule>,
179) -> Result<Vec<Parameter>, MessageParseError> {
180 assert_eq!(pair.as_rule(), Rule::parameters);
181 let mut iter = iter_pairs(pair.into_inner());
182 let mut params = Vec::new();
183 while let Some(param) = iter.parse_if_next_scoped::<Parameter>(extensions) {
184 params.push(param?);
185 }
186 iter.done();
187 Ok(params)
188}
189
190fn parse_user_defined_type(
191 extensions: &SimpleExtensions,
192 pair: Pair<Rule>,
193) -> Result<Type, MessageParseError> {
194 let span = pair.as_span();
195 assert_eq!(pair.as_rule(), Rule::user_defined_type);
196 let mut iter = iter_pairs(pair.into_inner());
197 let name = iter.pop(Rule::name).as_str();
198 let anchor = iter
199 .try_pop(Rule::anchor)
200 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
201
202 let _uri_anchor = iter
204 .try_pop(Rule::uri_anchor)
205 .map(|n| unwrap_single_pair(n).as_str().parse::<u32>().unwrap());
206
207 let nullability = iter.parse_next::<Nullability>();
208 let parameters = match iter.try_pop(Rule::parameters) {
209 Some(p) => parse_parameters(extensions, p)?,
210 None => Vec::new(),
211 };
212 iter.done();
213
214 let anchor = get_and_validate_anchor(extensions, ExtensionKind::Type, anchor, name, span)?;
215
216 Ok(Type {
217 kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
218 type_reference: anchor,
219 nullability: nullability.into(),
220 type_parameters: parameters,
221 type_variation_reference: 0,
222 })),
223 })
224}
225
226impl ScopedParsePair for Type {
227 fn rule() -> Rule {
228 Rule::r#type
229 }
230
231 fn message() -> &'static str {
232 "Type"
233 }
234
235 fn parse_pair(
236 extensions: &SimpleExtensions,
237 pair: Pair<Rule>,
238 ) -> Result<Self, MessageParseError> {
239 assert_eq!(pair.as_rule(), Rule::r#type);
240 let inner = unwrap_single_pair(pair);
241 match inner.as_rule() {
242 Rule::simple_type => Ok(parse_simple_type(inner)),
243 Rule::compound_type => parse_compound_type(extensions, inner),
244 Rule::user_defined_type => parse_user_defined_type(extensions, inner),
245 _ => unimplemented!("{:?}", inner.as_rule()),
246 }
247 }
248}
249
// Unit tests for the type parsers. Each test parses a string with a specific
// grammar rule, checks that exactly one top-level pair was produced, and
// compares the resulting protobuf value against a literal.
#[cfg(test)]
mod tests {
    use pest::Parser;
    use substrait::proto::r#type::{I64, Kind, Nullability};

    use super::*;
    use crate::parser::ExpressionParser;

    // A bare simple type is Required; a trailing `?` makes it Nullable.
    #[test]
    fn test_parse_simple_type() {
        let mut pairs = ExpressionParser::parse(Rule::simple_type, "i64").unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        let t = parse_simple_type(pair);
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::I64(I64 {
                    nullability: Nullability::Required as i32,
                    type_variation_reference: 0,
                })),
            }
        );

        let mut pairs = ExpressionParser::parse(Rule::simple_type, "string?").unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        let t = parse_simple_type(pair);
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::String(proto::r#type::String {
                    nullability: Nullability::Nullable as i32,
                    type_variation_reference: 0,
                })),
            }
        );
    }

    // The generic `Type` entry point dispatches a simple type correctly and
    // needs no registered extensions.
    #[test]
    fn test_parse_type() {
        let extensions = SimpleExtensions::default();
        let mut pairs = ExpressionParser::parse(Rule::r#type, "i64").unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        let t = Type::parse_pair(&extensions, pair).unwrap();
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::I64(I64 {
                    nullability: Nullability::Required as i32,
                    type_variation_reference: 0,
                }))
            }
        );
    }

    // `list<i64>`: both the list and its element are Required, and the element
    // type is boxed inside the `List` message.
    #[test]
    fn test_parse_list_type() {
        let extensions = SimpleExtensions::default();
        let mut pairs = ExpressionParser::parse(Rule::list_type, "list<i64>").unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        let t = parse_list_type(&extensions, pair).unwrap();
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::List(Box::new(proto::r#type::List {
                    nullability: Nullability::Required as i32,
                    r#type: Some(Box::new(Type {
                        kind: Some(Kind::I64(I64 {
                            nullability: Nullability::Required as i32,
                            type_variation_reference: 0,
                        }))
                    })),
                    type_variation_reference: 0,
                })))
            }
        );
    }

    // A parameter list keeps source order and preserves each parameter's own
    // nullability.
    #[test]
    fn test_parse_parameters() {
        let extensions = SimpleExtensions::default();
        let mut pairs = ExpressionParser::parse(Rule::parameters, "<i64?,string>").unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);
        let t = parse_parameters(&extensions, pair).unwrap();
        assert_eq!(
            t,
            vec![
                Parameter {
                    parameter: Some(proto::r#type::parameter::Parameter::DataType(Type {
                        kind: Some(Kind::I64(proto::r#type::I64 {
                            nullability: Nullability::Nullable as i32,
                            type_variation_reference: 0,
                        })),
                    })),
                },
                Parameter {
                    parameter: Some(proto::r#type::parameter::Parameter::DataType(Type {
                        kind: Some(Kind::String(proto::r#type::String {
                            nullability: Nullability::Required as i32,
                            type_variation_reference: 0,
                        })),
                    })),
                },
            ]
        );
    }

    // A user-defined type with an explicit anchor (`#42`) resolves against a
    // type extension registered under that anchor with the same name.
    #[test]
    fn test_udts() {
        let mut extensions = SimpleExtensions::default();
        extensions
            .add_extension_uri("some_source".to_string(), 4)
            .unwrap();
        extensions
            .add_extension(ExtensionKind::Type, 4, 42, "udt".to_string())
            .unwrap();
        let mut pairs = ExpressionParser::parse(Rule::user_defined_type, "udt#42<i64?>").unwrap();
        let pair = pairs.next().unwrap();
        assert_eq!(pairs.next(), None);

        let t = parse_user_defined_type(&extensions, pair).unwrap();
        // The UDT has no nullability suffix of its own, so it is Required. The
        // `?` belongs to the `i64` parameter.
        assert_eq!(
            t,
            Type {
                kind: Some(Kind::UserDefined(proto::r#type::UserDefined {
                    type_reference: 42,
                    type_variation_reference: 0,
                    nullability: Nullability::Required as i32,
                    type_parameters: vec![Parameter {
                        parameter: Some(proto::r#type::parameter::Parameter::DataType(Type {
                            kind: Some(Kind::I64(proto::r#type::I64 {
                                nullability: Nullability::Nullable as i32,
                                type_variation_reference: 0,
                            })),
                        })),
                    }],
                }))
            }
        );
    }
}