1use std::fmt;
2use std::str::FromStr;
3
4use substrait::proto::{Expression, Type};
5use thiserror::Error;
6
7use super::{
8 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
9 unwrap_single_pair,
10};
11use crate::extensions::simple::{self, ExtensionKind};
12use crate::extensions::{
13 AddendumKind, ExtensionArgs, ExtensionColumn, ExtensionValue, InsertError, SimpleExtensions,
14 TupleValue,
15};
16use crate::parser::structural::IndentedLine;
17
18#[derive(Debug, Clone, Error)]
19pub enum ExtensionParseError {
20 #[error("Unexpected line, expected {0}")]
21 UnexpectedLine(ExpectedExtensionLine),
22 #[error("Error adding extension: {0}")]
23 ExtensionError(#[from] InsertError),
24 #[error("Error parsing message: {0}")]
25 Message(#[from] super::MessageParseError),
26}
27
28#[derive(Clone, Copy, Debug, PartialEq, Eq)]
33pub enum ExpectedExtensionLine {
34 Extensions,
37 ExtensionUrns,
40 ExtensionDeclarations(ExtensionKind),
43}
44
45impl fmt::Display for ExpectedExtensionLine {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 ExpectedExtensionLine::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
49 ExpectedExtensionLine::ExtensionUrns => write!(f, "Extension URNs"),
50 ExpectedExtensionLine::ExtensionDeclarations(kind) => {
51 write!(f, "Extension Declaration for {kind}")
52 }
53 }
54 }
55}
56
57#[derive(Debug)]
64pub struct ExtensionParser {
65 state: ExpectedExtensionLine,
66 extensions: SimpleExtensions,
67}
68
69impl Default for ExtensionParser {
70 fn default() -> Self {
71 Self {
72 state: ExpectedExtensionLine::Extensions,
73 extensions: SimpleExtensions::new(),
74 }
75 }
76}
77
78impl ExtensionParser {
79 pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
80 if line.1.is_empty() {
81 self.state = ExpectedExtensionLine::Extensions;
84 return Ok(());
85 }
86
87 match self.state {
88 ExpectedExtensionLine::Extensions => self.parse_subsection(line),
89 ExpectedExtensionLine::ExtensionUrns => self.parse_extension_urns(line),
90 ExpectedExtensionLine::ExtensionDeclarations(extension_kind) => {
91 self.parse_declarations(line, extension_kind)
92 }
93 }
94 }
95
96 fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
97 match line {
98 IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
99 self.state = ExpectedExtensionLine::ExtensionUrns;
100 Ok(())
101 }
102 IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
103 self.state = ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::Function);
104 Ok(())
105 }
106 IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
107 self.state = ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::Type);
108 Ok(())
109 }
110 IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
111 self.state =
112 ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::TypeVariation);
113 Ok(())
114 }
115 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
116 }
117 }
118
119 fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
120 match line {
121 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
123 let urn =
124 URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
125 self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
126 Ok(())
127 }
128 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
129 }
130 }
131
132 fn parse_declarations(
133 &mut self,
134 line: IndentedLine,
135 extension_kind: ExtensionKind,
136 ) -> Result<(), ExtensionParseError> {
137 match line {
138 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
140 let decl = SimpleExtensionDeclaration::from_str(s)?;
141 self.extensions.add_extension(
142 extension_kind,
143 decl.urn_anchor,
144 decl.anchor,
145 decl.name,
146 )?;
147 Ok(())
148 }
149 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
150 }
151 }
152
153 pub fn extensions(&self) -> &SimpleExtensions {
154 &self.extensions
155 }
156
157 #[cfg(test)]
158 pub(crate) fn state(&self) -> ExpectedExtensionLine {
159 self.state
160 }
161}
162
/// One URN declaration line, e.g. `@1: /my/urn1`.
#[derive(Debug, Clone, PartialEq)]
pub struct URNExtensionDeclaration {
    /// Numeric anchor written after `@`.
    pub anchor: u32,
    /// The URN text as written.
    pub urn: String,
}
168
/// One declaration line from a functions / types / type-variations subsection,
/// e.g. `#5@2: my_function_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// Anchor of this declaration (the number after `#`).
    pub anchor: u32,
    /// Anchor of the URN this declaration refers to (the number after `@`).
    pub urn_anchor: u32,
    /// Declared name; may be compound, e.g. `equal:any_any`.
    pub name: String,
}
175
176impl ParsePair for URNExtensionDeclaration {
177 fn rule() -> Rule {
178 Rule::extension_urn_declaration
179 }
180
181 fn message() -> &'static str {
182 "URNExtensionDeclaration"
183 }
184
185 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
186 assert_eq!(pair.as_rule(), Self::rule());
187
188 let mut iter = RuleIter::from(pair.into_inner());
189 let anchor_pair = iter.pop(Rule::urn_anchor);
190 let anchor = unwrap_single_pair(anchor_pair)
191 .as_str()
192 .parse::<u32>()
193 .unwrap();
194 let urn = iter.pop(Rule::urn).as_str().to_string();
195 iter.done();
196
197 URNExtensionDeclaration { anchor, urn }
198 }
199}
200
201impl FromStr for URNExtensionDeclaration {
202 type Err = super::MessageParseError;
203
204 fn from_str(s: &str) -> Result<Self, Self::Err> {
205 Self::parse_str(s)
206 }
207}
208
209impl ParsePair for SimpleExtensionDeclaration {
210 fn rule() -> Rule {
211 Rule::simple_extension
212 }
213
214 fn message() -> &'static str {
215 "SimpleExtensionDeclaration"
216 }
217
218 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
219 assert_eq!(pair.as_rule(), Self::rule());
220 let mut iter = RuleIter::from(pair.into_inner());
221 let anchor_pair = iter.pop(Rule::anchor);
222 let anchor = unwrap_single_pair(anchor_pair)
223 .as_str()
224 .parse::<u32>()
225 .unwrap();
226 let urn_anchor_pair = iter.pop(Rule::urn_anchor);
227 let urn_anchor = unwrap_single_pair(urn_anchor_pair)
228 .as_str()
229 .parse::<u32>()
230 .unwrap();
231 let name_pair = iter.pop(Rule::compound_name);
233 let name = name_pair.as_str().to_string();
234 iter.done();
235
236 SimpleExtensionDeclaration {
237 anchor,
238 urn_anchor,
239 name,
240 }
241 }
242}
243
244impl FromStr for SimpleExtensionDeclaration {
245 type Err = super::MessageParseError;
246
247 fn from_str(s: &str) -> Result<Self, Self::Err> {
248 Self::parse_str(s)
249 }
250}
251
252use crate::extensions::any::Any;
256use crate::parser::expressions::{FieldIndex, Name};
257use crate::textify::expressions::Reference;
258
259impl ScopedParsePair for ExtensionValue {
260 fn rule() -> Rule {
261 Rule::extension_argument
262 }
263
264 fn message() -> &'static str {
265 "ExtensionValue"
266 }
267
268 fn parse_pair(
269 extensions: &SimpleExtensions,
270 pair: pest::iterators::Pair<Rule>,
271 ) -> Result<Self, MessageParseError> {
272 assert_eq!(pair.as_rule(), Self::rule());
273
274 let inner = unwrap_single_pair(pair); Ok(match inner.as_rule() {
277 Rule::enum_value => {
278 let s = inner.as_str().trim_start_matches('&').to_string();
280 ExtensionValue::Enum(s)
281 }
282 Rule::reference => {
283 let field_index = FieldIndex::parse_pair(inner);
285 ExtensionValue::from(Reference(field_index.0))
286 }
287 Rule::untyped_literal => {
288 let value_pair = unwrap_single_pair(inner);
290 match value_pair.as_rule() {
291 Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
292 Rule::integer => {
293 ExtensionValue::Integer(value_pair.as_str().parse::<i64>().unwrap())
294 }
295 Rule::float => {
296 ExtensionValue::Float(value_pair.as_str().parse::<f64>().unwrap())
297 }
298 Rule::boolean => ExtensionValue::Boolean(value_pair.as_str() == "true"),
299 _ => panic!(
300 "Unexpected extension scalar literal type: {:?}",
301 value_pair.as_rule()
302 ),
303 }
304 }
305 Rule::tuple => {
306 let tv = inner
307 .into_inner()
308 .map(|pair| ExtensionValue::parse_pair(extensions, pair))
309 .collect::<Result<TupleValue, MessageParseError>>()?;
310 ExtensionValue::Tuple(tv)
311 }
312 Rule::expression => {
313 let expr = Expression::parse_pair(extensions, inner)?;
314 ExtensionValue::from(expr)
315 }
316 _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
317 })
318 }
319}
320
321impl ScopedParsePair for ExtensionColumn {
322 fn rule() -> Rule {
323 Rule::extension_column
324 }
325
326 fn message() -> &'static str {
327 "ExtensionColumn"
328 }
329
330 fn parse_pair(
331 extensions: &SimpleExtensions,
332 pair: pest::iterators::Pair<Rule>,
333 ) -> Result<Self, MessageParseError> {
334 assert_eq!(pair.as_rule(), Self::rule());
335
336 let inner = unwrap_single_pair(pair); Ok(match inner.as_rule() {
339 Rule::named_column => {
340 let mut iter = inner.into_inner();
341 let name_pair = iter.next().unwrap(); let type_pair = iter.next().unwrap(); let name = Name::parse_pair(name_pair).0.to_string(); let ty = Type::parse_pair(extensions, type_pair)?;
346
347 ExtensionColumn::Named { name, r#type: ty }
348 }
349 Rule::reference => {
350 let field_index = FieldIndex::parse_pair(inner);
352 ExtensionColumn::Expr(Reference(field_index.0).into())
353 }
354 Rule::expression => {
355 let expr = Expression::parse_pair(extensions, inner)?;
356 ExtensionColumn::Expr(expr.into())
357 }
358 _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
359 })
360 }
361}
362
/// Which Substrait extension relation an `Extension*` node in the text format denotes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExtensionRelationKind {
    /// `ExtensionLeaf`: takes no input relations.
    Leaf,
    /// `ExtensionSingle`: takes exactly one input relation.
    Single,
    /// `ExtensionMulti`: takes any number of input relations.
    Multi,
}

impl FromStr for ExtensionRelationKind {
    type Err = String;

    /// Maps `ExtensionLeaf` / `ExtensionSingle` / `ExtensionMulti` to their kind.
    ///
    /// Any other text yields an error message that names the input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let kind = match s {
            "ExtensionLeaf" => Self::Leaf,
            "ExtensionSingle" => Self::Single,
            "ExtensionMulti" => Self::Multi,
            unknown => return Err(format!("Unknown extension relation type: {unknown}")),
        };
        Ok(kind)
    }
}
384
385impl ExtensionRelationKind {
386 pub(crate) fn validate_child_count(self, child_count: usize) -> Result<(), String> {
387 match self {
388 ExtensionRelationKind::Leaf => {
389 if child_count == 0 {
390 Ok(())
391 } else {
392 Err(format!(
393 "ExtensionLeaf should have no input children, got {child_count}"
394 ))
395 }
396 }
397 ExtensionRelationKind::Single => {
398 if child_count == 1 {
399 Ok(())
400 } else {
401 Err(format!(
402 "ExtensionSingle should have exactly 1 input child, got {child_count}"
403 ))
404 }
405 }
406 ExtensionRelationKind::Multi => Ok(()),
407 }
408 }
409
410 pub(crate) fn create_rel(
412 self,
413 detail: Option<Any>,
414 children: Vec<substrait::proto::Rel>,
415 ) -> substrait::proto::Rel {
416 use substrait::proto::rel::RelType;
417 use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
418
419 let rel_type = match self {
420 ExtensionRelationKind::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
421 common: None,
422 detail: detail.map(Into::into),
423 }),
424 ExtensionRelationKind::Single => {
425 let input = children.into_iter().next();
426 RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
427 common: None,
428 detail: detail.map(Into::into),
429 input: input.map(Box::new),
430 }))
431 }
432 ExtensionRelationKind::Multi => RelType::ExtensionMulti(ExtensionMultiRel {
433 common: None,
434 detail: detail.map(Into::into),
435 inputs: children,
436 }),
437 };
438
439 substrait::proto::Rel {
440 rel_type: Some(rel_type),
441 }
442 }
443}
444
445#[derive(Debug, Clone)]
448pub(crate) struct ExtensionInvocation {
449 pub(crate) relation_kind: ExtensionRelationKind,
450 pub(crate) name: String,
451 pub(crate) args: ExtensionArgs,
452}
453
454impl ScopedParsePair for ExtensionInvocation {
455 fn rule() -> Rule {
456 Rule::extension_relation
457 }
458
459 fn message() -> &'static str {
460 "ExtensionInvocation"
461 }
462
463 fn parse_pair(
464 extensions: &SimpleExtensions,
465 pair: pest::iterators::Pair<Rule>,
466 ) -> Result<Self, MessageParseError> {
467 assert_eq!(pair.as_rule(), Self::rule());
468
469 let mut iter = pair.into_inner();
470
471 let extension_name_pair = iter.next().unwrap(); let full_extension_name = extension_name_pair.as_str();
474
475 let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
478 let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
479 (parts[0], parts[1].to_string())
480 } else {
481 (full_extension_name, "UnknownExtension".to_string())
482 };
483
484 let relation_kind = ExtensionRelationKind::from_str(relation_type_str).unwrap();
485 let mut args = ExtensionArgs::default();
486
487 let ext_arguments = iter.next().unwrap();
489 match ext_arguments.as_rule() {
490 Rule::arguments => {
491 arguments_rule_parsing(extensions, ext_arguments, &mut args)?;
492 }
493 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
494 }
495
496 let extension_columns = iter.next();
498 if let Some(value) = extension_columns {
499 match value.as_rule() {
500 Rule::extension_columns => {
501 for col_pair in value.into_inner() {
502 if col_pair.as_rule() == Rule::extension_column {
503 let column = ExtensionColumn::parse_pair(extensions, col_pair)?;
504 args.output_columns.push(column);
505 }
506 }
507 }
508 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
509 }
510 }
511
512 Ok(ExtensionInvocation {
513 relation_kind,
514 name: custom_name,
515 args,
516 })
517 }
518}
519
520#[derive(Debug, Clone)]
522pub(crate) struct AddendumInvocation {
523 pub(crate) kind: AddendumKind,
524 pub(crate) name: String,
525 pub(crate) args: ExtensionArgs,
526}
527
528impl ScopedParsePair for AddendumInvocation {
529 fn rule() -> Rule {
530 Rule::addendum
531 }
532
533 fn message() -> &'static str {
534 "AddendumInvocation"
535 }
536
537 fn parse_pair(
538 extensions: &SimpleExtensions,
539 pair: pest::iterators::Pair<Rule>,
540 ) -> Result<Self, MessageParseError> {
541 assert_eq!(pair.as_rule(), Self::rule());
542
543 let mut iter = pair.into_inner();
544
545 let type_pair = iter.next().unwrap(); let kind = match type_pair.as_str() {
548 "Enh" => AddendumKind::Enhancement,
549 "Opt" => AddendumKind::Optimization,
550 "Ext" => AddendumKind::ExtensionTable,
551 other => unreachable!("Unexpected addendum_type: {other}"),
552 };
553
554 let name_pair = iter.next().unwrap();
556 let name = Name::parse_pair(name_pair).0.to_string();
557
558 let mut args = ExtensionArgs::default();
560
561 let arguments_pair = iter.next().unwrap();
562 match arguments_pair.as_rule() {
563 Rule::arguments => {
564 arguments_rule_parsing(extensions, arguments_pair, &mut args)?;
565 }
566 r => unreachable!("Unexpected rule in AddendumInvocation args: {r:?}"),
567 }
568
569 Ok(AddendumInvocation { kind, name, args })
570 }
571}
572
573fn arguments_rule_parsing(
574 extensions: &SimpleExtensions,
575 inner_pair: pest::iterators::Pair<'_, Rule>,
576 args: &mut ExtensionArgs,
577) -> Result<(), MessageParseError> {
578 for arg in inner_pair.into_inner() {
579 match arg.as_rule() {
580 Rule::extension_arguments => {
581 for arg_pair in arg.into_inner() {
582 assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
583 args.push(ExtensionValue::parse_pair(extensions, arg_pair)?);
584 }
585 }
586 Rule::extension_named_arguments => {
587 for arg_pair in arg.into_inner() {
588 assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
589 let mut arg_iter = arg_pair.into_inner();
590 let name_p = arg_iter.next().unwrap();
591 let value_p = arg_iter.next().unwrap();
592 let key = Name::parse_pair(name_p).0.to_string();
593 let val = ExtensionValue::parse_pair(extensions, value_p)?;
594 args.insert(key, val);
595 }
596 }
597 Rule::empty => {}
598 r => unreachable!("Unexpected rule in extension args: {r:?}"),
599 }
600 }
601 Ok(())
602}
603
#[cfg(test)]
mod tests {
    use substrait::proto;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::literal::LiteralType;

    use super::*;
    use crate::OutputOptions;
    use crate::extensions::{Expr, ExtensionValue};
    use crate::fixtures::TestContext;
    use crate::parser::Parser;
    use crate::parser::common::test_support::ScopedParse;

    /// Parses `text` as a single extension argument against an empty registry.
    fn parse_extension_value(text: &str) -> ExtensionValue {
        ExtensionValue::parse(&SimpleExtensions::default(), text).unwrap()
    }

    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    #[test]
    fn test_parse_simple_extension_declaration() {
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        let line2 = "#10 @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    /// Exercises the `FromStr` impl; the test above calls `parse_str` directly.
    #[test]
    fn test_parse_urn_extension_declaration_str() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");

        // `str::parse` dispatches to the same impl.
        let parsed: URNExtensionDeclaration = line.parse().unwrap();
        assert_eq!(parsed, urn);
    }

    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
 @ 1: /urn/common
 @ 2: /urn/specific_funcs
Functions:
 # 10 @ 1: func_a
 # 11 @ 2: func_b_special
Types:
 # 20 @ 1: SomeType
Type Variations:
 # 30 @ 2: VarX
"#
        .trim_start();

        let plan = Parser::parse(input).unwrap();

        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        let output = extensions.to_string(" ");

        assert_eq!(output, input);
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
 @ 1: extension:io.substrait:functions_string
 @ 2: extension:io.substrait:functions_comparison
Functions:
 # 1 @ 2: equal:any_any
 # 2 @ 1: regexp_match_substring:str_str
 # 3 @ 1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        assert_eq!(extensions.to_string(" "), input);
    }

    #[test]
    fn test_tuple_mixed_types_parses() {
        let val = parse_extension_value("(&HASH, 8, 'hello')");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(items.len(), 3);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(items[1]).unwrap(), 8);
        assert_eq!(<&str>::try_from(items[2]).unwrap(), "hello");
    }

    #[test]
    fn test_empty_tuple_parses() {
        let val = parse_extension_value("()");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert!(items.is_empty());
    }

    #[test]
    fn test_nested_tuple_parses() {
        let val = parse_extension_value("((&HASH, &RANGE), 8)");
        let ExtensionValue::Tuple(outer) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(outer.len(), 2);
        let ExtensionValue::Tuple(inner) = outer.iter().next().unwrap() else {
            panic!("expected inner Tuple");
        };
        assert_eq!(inner.len(), 2);
        assert!(matches!(inner.iter().next().unwrap(), ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(outer.iter().nth(1).unwrap()).unwrap(), 8);
    }

    #[test]
    fn test_tuple_in_addendum_parses() {
        let inv = AddendumInvocation::parse(
            &SimpleExtensions::default(),
            "+ Enh:Foo[(&HASH, &RANGE), count=8]",
        )
        .unwrap();
        assert_eq!(inv.kind, AddendumKind::Enhancement);
        assert_eq!(inv.name, "Foo");
        assert_eq!(inv.args.positional.len(), 1);
        let ExtensionValue::Tuple(items) = &inv.args.positional[0] else {
            panic!("expected Tuple positional arg");
        };
        assert_eq!(items.len(), 2);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(items[1], ExtensionValue::Enum(s) if s == "RANGE"));
        assert_eq!(inv.args.named.len(), 1);
    }

    #[test]
    fn extension_relation_kind_parses_text_prefixes() {
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionLeaf").unwrap(),
            ExtensionRelationKind::Leaf
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionSingle").unwrap(),
            ExtensionRelationKind::Single
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionMulti").unwrap(),
            ExtensionRelationKind::Multi
        );
    }

    #[test]
    fn extension_relation_kind_rejects_unknown_prefix() {
        let err = ExtensionRelationKind::from_str("ExtensionFoo").unwrap_err();
        assert_eq!(err, "Unknown extension relation type: ExtensionFoo");
    }

    #[test]
    fn extension_multi_allows_any_child_count() {
        assert!(ExtensionRelationKind::Multi.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(1).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert!(ExtensionRelationKind::Single.validate_child_count(0).is_err());
        assert!(ExtensionRelationKind::Single.validate_child_count(2).is_err());
    }

    #[test]
    fn extension_single_accepts_exactly_one_child() {
        assert!(ExtensionRelationKind::Single.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_leaf_requires_no_children() {
        assert!(ExtensionRelationKind::Leaf.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Leaf.validate_child_count(1).is_err());
    }

    /// Headers switch state, a blank line resets it, and over-indented lines are rejected.
    #[test]
    fn extension_parser_tracks_subsection_state() {
        let mut parser = ExtensionParser::default();
        assert_eq!(parser.state(), ExpectedExtensionLine::Extensions);

        parser
            .parse_line(IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER))
            .unwrap();
        assert_eq!(
            parser.state(),
            ExpectedExtensionLine::ExtensionDeclarations(ExtensionKind::Function)
        );

        parser.parse_line(IndentedLine(0, "")).unwrap();
        assert_eq!(parser.state(), ExpectedExtensionLine::Extensions);

        parser
            .parse_line(IndentedLine(0, simple::EXTENSION_URNS_HEADER))
            .unwrap();
        let err = parser
            .parse_line(IndentedLine(2, "@1: /my/urn1"))
            .unwrap_err();
        assert!(matches!(
            err,
            ExtensionParseError::UnexpectedLine(ExpectedExtensionLine::ExtensionUrns)
        ));
    }

    #[test]
    fn test_tuple_textify_roundtrip() {
        let ctx = TestContext::new();
        for text in &[
            "(&HASH, &RANGE)",
            "(&HASH, 8, 'hello')",
            "()",
            "(&HASH,)",
            "((&HASH, &RANGE), 8)",
        ] {
            let val = parse_extension_value(text);
            let rendered = ctx.textify_no_errors(&val);
            assert_eq!(&rendered, text, "roundtrip failed for {text}");
        }
    }

    #[test]
    fn test_literal_expression_value_textifies_to_canonical_literal() {
        let expr = proto::Expression {
            rex_type: Some(RexType::Literal(proto::expression::Literal {
                literal_type: Some(LiteralType::I64(42)),
                nullable: false,
                type_variation_reference: 0,
            })),
        };
        let value = ExtensionValue::from(expr.clone());
        let ctx = TestContext::new();

        let rendered = ctx.textify_no_errors(&value);
        assert_eq!(rendered, "42");

        let parsed = parse_extension_value(&rendered);
        let parsed_expr = Expr::try_from(&parsed).unwrap();
        assert_eq!(parsed_expr.as_proto(), &expr);
    }

    #[test]
    fn test_extension_scalar_literals_stay_scalar_in_verbose_output() {
        let ctx = TestContext::new().with_options(OutputOptions::verbose());

        let scalar = ExtensionValue::from(42_i64);
        assert_eq!(ctx.textify_no_errors(&scalar), "42");

        let expression = ExtensionValue::from(Expr::from(42_i64));
        assert_eq!(ctx.textify_no_errors(&expression), "42:i64");
    }

    #[test]
    fn test_typed_extension_literal_parses_as_expression() {
        let value = parse_extension_value("42:i16");
        assert!(i64::try_from(&value).is_err());

        let expr = Expr::try_from(&value).unwrap();
        assert_eq!(ctx_text(&expr), "42:i16");
    }

    /// Renders an expression with default output options.
    fn ctx_text(value: &Expr) -> String {
        TestContext::new().textify_no_errors(value)
    }
}