1use std::fmt;
2use std::str::FromStr;
3
4use substrait::proto::{Expression, Type};
5use thiserror::Error;
6
7use super::{
8 MessageParseError, ParsePair, Rule, RuleIter, ScopedParsePair, unescape_string,
9 unwrap_single_pair,
10};
11use crate::extensions::simple::{self, ExtensionKind};
12use crate::extensions::{
13 AddendumKind, ExtensionArgs, ExtensionColumn, ExtensionValue, InsertError, SimpleExtensions,
14 TupleValue,
15};
16use crate::parser::structural::IndentedLine;
17
18#[derive(Debug, Clone, Error)]
19pub enum ExtensionParseError {
20 #[error("Unexpected line, expected {0}")]
21 UnexpectedLine(ExtensionParserState),
22 #[error("Error adding extension: {0}")]
23 ExtensionError(#[from] InsertError),
24 #[error("Error parsing message: {0}")]
25 Message(#[from] super::MessageParseError),
26}
27
28#[derive(Clone, Copy, Debug, PartialEq, Eq)]
31pub enum ExtensionParserState {
32 Extensions,
35 ExtensionUrns,
38 ExtensionDeclarations(ExtensionKind),
41}
42
43impl fmt::Display for ExtensionParserState {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 ExtensionParserState::Extensions => write!(f, "Subsection Header, e.g. 'URNs:'"),
47 ExtensionParserState::ExtensionUrns => write!(f, "Extension URNs"),
48 ExtensionParserState::ExtensionDeclarations(kind) => {
49 write!(f, "Extension Declaration for {kind}")
50 }
51 }
52 }
53}
54
55#[derive(Debug)]
62pub struct ExtensionParser {
63 state: ExtensionParserState,
64 extensions: SimpleExtensions,
65}
66
67impl Default for ExtensionParser {
68 fn default() -> Self {
69 Self {
70 state: ExtensionParserState::Extensions,
71 extensions: SimpleExtensions::new(),
72 }
73 }
74}
75
76impl ExtensionParser {
77 pub fn parse_line(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
78 if line.1.is_empty() {
79 self.state = ExtensionParserState::Extensions;
82 return Ok(());
83 }
84
85 match self.state {
86 ExtensionParserState::Extensions => self.parse_subsection(line),
87 ExtensionParserState::ExtensionUrns => self.parse_extension_urns(line),
88 ExtensionParserState::ExtensionDeclarations(extension_kind) => {
89 self.parse_declarations(line, extension_kind)
90 }
91 }
92 }
93
94 fn parse_subsection(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
95 match line {
96 IndentedLine(0, simple::EXTENSION_URNS_HEADER) => {
97 self.state = ExtensionParserState::ExtensionUrns;
98 Ok(())
99 }
100 IndentedLine(0, simple::EXTENSION_FUNCTIONS_HEADER) => {
101 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Function);
102 Ok(())
103 }
104 IndentedLine(0, simple::EXTENSION_TYPES_HEADER) => {
105 self.state = ExtensionParserState::ExtensionDeclarations(ExtensionKind::Type);
106 Ok(())
107 }
108 IndentedLine(0, simple::EXTENSION_TYPE_VARIATIONS_HEADER) => {
109 self.state =
110 ExtensionParserState::ExtensionDeclarations(ExtensionKind::TypeVariation);
111 Ok(())
112 }
113 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
114 }
115 }
116
117 fn parse_extension_urns(&mut self, line: IndentedLine) -> Result<(), ExtensionParseError> {
118 match line {
119 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
121 let urn =
122 URNExtensionDeclaration::from_str(s).map_err(ExtensionParseError::Message)?;
123 self.extensions.add_extension_urn(urn.urn, urn.anchor)?;
124 Ok(())
125 }
126 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
127 }
128 }
129
130 fn parse_declarations(
131 &mut self,
132 line: IndentedLine,
133 extension_kind: ExtensionKind,
134 ) -> Result<(), ExtensionParseError> {
135 match line {
136 IndentedLine(0, _s) => self.parse_subsection(line), IndentedLine(1, s) => {
138 let decl = SimpleExtensionDeclaration::from_str(s)?;
139 self.extensions.add_extension(
140 extension_kind,
141 decl.urn_anchor,
142 decl.anchor,
143 decl.name,
144 )?;
145 Ok(())
146 }
147 _ => Err(ExtensionParseError::UnexpectedLine(self.state)),
148 }
149 }
150
151 pub fn extensions(&self) -> &SimpleExtensions {
152 &self.extensions
153 }
154
155 pub fn state(&self) -> ExtensionParserState {
156 self.state
157 }
158}
159
/// A parsed URN declaration line, e.g. `@1: /my/urn1`.
#[derive(Clone, Debug, PartialEq)]
pub struct URNExtensionDeclaration {
    /// Numeric anchor following the `@` (1 in the example).
    pub anchor: u32,
    /// URN text following the colon (`/my/urn1` in the example).
    pub urn: String,
}
165
/// A parsed extension declaration line, e.g. `#5@2: my_function_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct SimpleExtensionDeclaration {
    /// Anchor of the declaration itself, following the `#` (5 in the example).
    pub anchor: u32,
    /// Anchor of the URN the declaration belongs to, following the `@`
    /// (2 in the example).
    pub urn_anchor: u32,
    /// Declared name; may be compound, such as `equal:any_any`.
    pub name: String,
}
172
173impl ParsePair for URNExtensionDeclaration {
174 fn rule() -> Rule {
175 Rule::extension_urn_declaration
176 }
177
178 fn message() -> &'static str {
179 "URNExtensionDeclaration"
180 }
181
182 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
183 assert_eq!(pair.as_rule(), Self::rule());
184
185 let mut iter = RuleIter::from(pair.into_inner());
186 let anchor_pair = iter.pop(Rule::urn_anchor);
187 let anchor = unwrap_single_pair(anchor_pair)
188 .as_str()
189 .parse::<u32>()
190 .unwrap();
191 let urn = iter.pop(Rule::urn).as_str().to_string();
192 iter.done();
193
194 URNExtensionDeclaration { anchor, urn }
195 }
196}
197
198impl FromStr for URNExtensionDeclaration {
199 type Err = super::MessageParseError;
200
201 fn from_str(s: &str) -> Result<Self, Self::Err> {
202 Self::parse_str(s)
203 }
204}
205
206impl ParsePair for SimpleExtensionDeclaration {
207 fn rule() -> Rule {
208 Rule::simple_extension
209 }
210
211 fn message() -> &'static str {
212 "SimpleExtensionDeclaration"
213 }
214
215 fn parse_pair(pair: pest::iterators::Pair<Rule>) -> Self {
216 assert_eq!(pair.as_rule(), Self::rule());
217 let mut iter = RuleIter::from(pair.into_inner());
218 let anchor_pair = iter.pop(Rule::anchor);
219 let anchor = unwrap_single_pair(anchor_pair)
220 .as_str()
221 .parse::<u32>()
222 .unwrap();
223 let urn_anchor_pair = iter.pop(Rule::urn_anchor);
224 let urn_anchor = unwrap_single_pair(urn_anchor_pair)
225 .as_str()
226 .parse::<u32>()
227 .unwrap();
228 let name_pair = iter.pop(Rule::compound_name);
230 let name = name_pair.as_str().to_string();
231 iter.done();
232
233 SimpleExtensionDeclaration {
234 anchor,
235 urn_anchor,
236 name,
237 }
238 }
239}
240
241impl FromStr for SimpleExtensionDeclaration {
242 type Err = super::MessageParseError;
243
244 fn from_str(s: &str) -> Result<Self, Self::Err> {
245 Self::parse_str(s)
246 }
247}
248
249use crate::extensions::any::Any;
253use crate::parser::expressions::{FieldIndex, Name};
254use crate::textify::expressions::Reference;
255
256impl ScopedParsePair for ExtensionValue {
257 fn rule() -> Rule {
258 Rule::extension_argument
259 }
260
261 fn message() -> &'static str {
262 "ExtensionValue"
263 }
264
265 fn parse_pair(
266 extensions: &SimpleExtensions,
267 pair: pest::iterators::Pair<Rule>,
268 ) -> Result<Self, MessageParseError> {
269 assert_eq!(pair.as_rule(), Self::rule());
270
271 let inner = unwrap_single_pair(pair); Ok(match inner.as_rule() {
274 Rule::enum_value => {
275 let s = inner.as_str().trim_start_matches('&').to_string();
277 ExtensionValue::Enum(s)
278 }
279 Rule::reference => {
280 let field_index = FieldIndex::parse_pair(inner);
282 ExtensionValue::from(Reference(field_index.0))
283 }
284 Rule::untyped_literal => {
285 let value_pair = unwrap_single_pair(inner);
287 match value_pair.as_rule() {
288 Rule::string_literal => ExtensionValue::String(unescape_string(value_pair)),
289 Rule::integer => {
290 ExtensionValue::Integer(value_pair.as_str().parse::<i64>().unwrap())
291 }
292 Rule::float => {
293 ExtensionValue::Float(value_pair.as_str().parse::<f64>().unwrap())
294 }
295 Rule::boolean => ExtensionValue::Boolean(value_pair.as_str() == "true"),
296 _ => panic!(
297 "Unexpected extension scalar literal type: {:?}",
298 value_pair.as_rule()
299 ),
300 }
301 }
302 Rule::tuple => {
303 let tv = inner
304 .into_inner()
305 .map(|pair| ExtensionValue::parse_pair(extensions, pair))
306 .collect::<Result<TupleValue, MessageParseError>>()?;
307 ExtensionValue::Tuple(tv)
308 }
309 Rule::expression => {
310 let expr = Expression::parse_pair(extensions, inner)?;
311 ExtensionValue::from(expr)
312 }
313 _ => panic!("Unexpected extension argument type: {:?}", inner.as_rule()),
314 })
315 }
316}
317
318impl ScopedParsePair for ExtensionColumn {
319 fn rule() -> Rule {
320 Rule::extension_column
321 }
322
323 fn message() -> &'static str {
324 "ExtensionColumn"
325 }
326
327 fn parse_pair(
328 extensions: &SimpleExtensions,
329 pair: pest::iterators::Pair<Rule>,
330 ) -> Result<Self, MessageParseError> {
331 assert_eq!(pair.as_rule(), Self::rule());
332
333 let inner = unwrap_single_pair(pair); Ok(match inner.as_rule() {
336 Rule::named_column => {
337 let mut iter = inner.into_inner();
338 let name_pair = iter.next().unwrap(); let type_pair = iter.next().unwrap(); let name = Name::parse_pair(name_pair).0.to_string(); let ty = Type::parse_pair(extensions, type_pair)?;
343
344 ExtensionColumn::Named { name, r#type: ty }
345 }
346 Rule::reference => {
347 let field_index = FieldIndex::parse_pair(inner);
349 ExtensionColumn::Expr(Reference(field_index.0).into())
350 }
351 Rule::expression => {
352 let expr = Expression::parse_pair(extensions, inner)?;
353 ExtensionColumn::Expr(expr.into())
354 }
355 _ => panic!("Unexpected extension column type: {:?}", inner.as_rule()),
356 })
357 }
358}
359
/// The three kinds of extension relation, distinguished by how many input
/// relations they take.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ExtensionRelationKind {
    /// Written `ExtensionLeaf`; takes no inputs.
    Leaf,
    /// Written `ExtensionSingle`; takes exactly one input.
    Single,
    /// Written `ExtensionMulti`; takes any number of inputs.
    Multi,
}

/// Parses the textual relation prefix (`ExtensionLeaf`, `ExtensionSingle`,
/// `ExtensionMulti`). Matching is exact and case-sensitive; anything else
/// yields an error message naming the rejected text.
impl FromStr for ExtensionRelationKind {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let kind = match s {
            "ExtensionLeaf" => Self::Leaf,
            "ExtensionSingle" => Self::Single,
            "ExtensionMulti" => Self::Multi,
            _ => return Err(format!("Unknown extension relation type: {s}")),
        };
        Ok(kind)
    }
}
381
382impl ExtensionRelationKind {
383 pub(crate) fn validate_child_count(self, child_count: usize) -> Result<(), String> {
384 match self {
385 ExtensionRelationKind::Leaf => {
386 if child_count == 0 {
387 Ok(())
388 } else {
389 Err(format!(
390 "ExtensionLeaf should have no input children, got {child_count}"
391 ))
392 }
393 }
394 ExtensionRelationKind::Single => {
395 if child_count == 1 {
396 Ok(())
397 } else {
398 Err(format!(
399 "ExtensionSingle should have exactly 1 input child, got {child_count}"
400 ))
401 }
402 }
403 ExtensionRelationKind::Multi => Ok(()),
404 }
405 }
406
407 pub(crate) fn create_rel(
409 self,
410 detail: Option<Any>,
411 children: Vec<substrait::proto::Rel>,
412 ) -> substrait::proto::Rel {
413 use substrait::proto::rel::RelType;
414 use substrait::proto::{ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel};
415
416 let rel_type = match self {
417 ExtensionRelationKind::Leaf => RelType::ExtensionLeaf(ExtensionLeafRel {
418 common: None,
419 detail: detail.map(Into::into),
420 }),
421 ExtensionRelationKind::Single => {
422 let input = children.into_iter().next();
423 RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
424 common: None,
425 detail: detail.map(Into::into),
426 input: input.map(Box::new),
427 }))
428 }
429 ExtensionRelationKind::Multi => RelType::ExtensionMulti(ExtensionMultiRel {
430 common: None,
431 detail: detail.map(Into::into),
432 inputs: children,
433 }),
434 };
435
436 substrait::proto::Rel {
437 rel_type: Some(rel_type),
438 }
439 }
440}
441
442#[derive(Debug, Clone)]
445pub(crate) struct ExtensionInvocation {
446 pub(crate) relation_kind: ExtensionRelationKind,
447 pub(crate) name: String,
448 pub(crate) args: ExtensionArgs,
449}
450
451impl ScopedParsePair for ExtensionInvocation {
452 fn rule() -> Rule {
453 Rule::extension_relation
454 }
455
456 fn message() -> &'static str {
457 "ExtensionInvocation"
458 }
459
460 fn parse_pair(
461 extensions: &SimpleExtensions,
462 pair: pest::iterators::Pair<Rule>,
463 ) -> Result<Self, MessageParseError> {
464 assert_eq!(pair.as_rule(), Self::rule());
465
466 let mut iter = pair.into_inner();
467
468 let extension_name_pair = iter.next().unwrap(); let full_extension_name = extension_name_pair.as_str();
471
472 let (relation_type_str, custom_name) = if full_extension_name.contains(':') {
475 let parts: Vec<&str> = full_extension_name.splitn(2, ':').collect();
476 (parts[0], parts[1].to_string())
477 } else {
478 (full_extension_name, "UnknownExtension".to_string())
479 };
480
481 let relation_kind = ExtensionRelationKind::from_str(relation_type_str).unwrap();
482 let mut args = ExtensionArgs::default();
483
484 let ext_arguments = iter.next().unwrap();
486 match ext_arguments.as_rule() {
487 Rule::arguments => {
488 arguments_rule_parsing(extensions, ext_arguments, &mut args)?;
489 }
490 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
491 }
492
493 let extension_columns = iter.next();
495 if let Some(value) = extension_columns {
496 match value.as_rule() {
497 Rule::extension_columns => {
498 for col_pair in value.into_inner() {
499 if col_pair.as_rule() == Rule::extension_column {
500 let column = ExtensionColumn::parse_pair(extensions, col_pair)?;
501 args.output_columns.push(column);
502 }
503 }
504 }
505 r => unreachable!("Unexpected rule in ExtensionArgs: {:?}", r),
506 }
507 }
508
509 Ok(ExtensionInvocation {
510 relation_kind,
511 name: custom_name,
512 args,
513 })
514 }
515}
516
517#[derive(Debug, Clone)]
519pub(crate) struct AddendumInvocation {
520 pub(crate) kind: AddendumKind,
521 pub(crate) name: String,
522 pub(crate) args: ExtensionArgs,
523}
524
525impl ScopedParsePair for AddendumInvocation {
526 fn rule() -> Rule {
527 Rule::addendum
528 }
529
530 fn message() -> &'static str {
531 "AddendumInvocation"
532 }
533
534 fn parse_pair(
535 extensions: &SimpleExtensions,
536 pair: pest::iterators::Pair<Rule>,
537 ) -> Result<Self, MessageParseError> {
538 assert_eq!(pair.as_rule(), Self::rule());
539
540 let mut iter = pair.into_inner();
541
542 let type_pair = iter.next().unwrap(); let kind = match type_pair.as_str() {
545 "Enh" => AddendumKind::Enhancement,
546 "Opt" => AddendumKind::Optimization,
547 "Ext" => AddendumKind::ExtensionTable,
548 other => unreachable!("Unexpected addendum_type: {other}"),
549 };
550
551 let name_pair = iter.next().unwrap();
553 let name = Name::parse_pair(name_pair).0.to_string();
554
555 let mut args = ExtensionArgs::default();
557
558 let arguments_pair = iter.next().unwrap();
559 match arguments_pair.as_rule() {
560 Rule::arguments => {
561 arguments_rule_parsing(extensions, arguments_pair, &mut args)?;
562 }
563 r => unreachable!("Unexpected rule in AddendumInvocation args: {r:?}"),
564 }
565
566 Ok(AddendumInvocation { kind, name, args })
567 }
568}
569
570fn arguments_rule_parsing(
571 extensions: &SimpleExtensions,
572 inner_pair: pest::iterators::Pair<'_, Rule>,
573 args: &mut ExtensionArgs,
574) -> Result<(), MessageParseError> {
575 for arg in inner_pair.into_inner() {
576 match arg.as_rule() {
577 Rule::extension_arguments => {
578 for arg_pair in arg.into_inner() {
579 assert_eq!(arg_pair.as_rule(), Rule::extension_argument);
580 args.push(ExtensionValue::parse_pair(extensions, arg_pair)?);
581 }
582 }
583 Rule::extension_named_arguments => {
584 for arg_pair in arg.into_inner() {
585 assert_eq!(arg_pair.as_rule(), Rule::extension_named_argument);
586 let mut arg_iter = arg_pair.into_inner();
587 let name_p = arg_iter.next().unwrap();
588 let value_p = arg_iter.next().unwrap();
589 let key = Name::parse_pair(name_p).0.to_string();
590 let val = ExtensionValue::parse_pair(extensions, value_p)?;
591 args.insert(key, val);
592 }
593 }
594 Rule::empty => {}
595 r => unreachable!("Unexpected rule in extension args: {r:?}"),
596 }
597 }
598 Ok(())
599}
600
/// Unit tests for declaration parsing, extension argument values, and
/// extension relation kinds.
#[cfg(test)]
mod tests {
    use substrait::proto;
    use substrait::proto::expression::RexType;
    use substrait::proto::expression::literal::LiteralType;

    use super::*;
    use crate::OutputOptions;
    use crate::extensions::{Expr, ExtensionValue};
    use crate::fixtures::TestContext;
    use crate::parser::{Parser, ScopedParse};

    /// Parses `text` as a single extension argument with no extensions
    /// registered, panicking on failure.
    fn parse_extension_value(text: &str) -> ExtensionValue {
        ExtensionValue::parse(&SimpleExtensions::default(), text).unwrap()
    }

    /// Textifies an expression with a default test context.
    fn ctx_text(value: &Expr) -> String {
        TestContext::new().textify_no_errors(value)
    }

    #[test]
    fn test_parse_urn_extension_declaration() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::parse_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    #[test]
    fn test_parse_simple_extension_declaration() {
        let line = "#5@2: my_function_name";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 5);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "my_function_name");

        let line2 = "#10 @200: another_ext_123";
        let decl = SimpleExtensionDeclaration::from_str(line2).unwrap();
        assert_eq!(decl.anchor, 10);
        assert_eq!(decl.urn_anchor, 200);
        assert_eq!(decl.name, "another_ext_123");
    }

    // Goes through the `FromStr` impl rather than `parse_str`, so both entry
    // points into URN declaration parsing are exercised.
    #[test]
    fn test_parse_urn_extension_declaration_str() {
        let line = "@1: /my/urn1";
        let urn = URNExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(urn.anchor, 1);
        assert_eq!(urn.urn, "/my/urn1");
    }

    #[test]
    fn test_extensions_round_trip_plan() {
        let input = r#"
=== Extensions
URNs:
 @ 1: /urn/common
 @ 2: /urn/specific_funcs
Functions:
 # 10 @ 1: func_a
 # 11 @ 2: func_b_special
Types:
 # 20 @ 1: SomeType
Type Variations:
 # 30 @ 2: VarX
"#
        .trim_start();

        let plan = Parser::parse(input).unwrap();

        assert_eq!(plan.extension_urns.len(), 2);
        assert_eq!(plan.extensions.len(), 4);

        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);

        assert!(errors.is_empty());
        let output = extensions.to_string(" ");

        assert_eq!(output, input);
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name() {
        let line = "#1 @2: equal:any_any";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 1);
        assert_eq!(decl.urn_anchor, 2);
        assert_eq!(decl.name, "equal:any_any");
    }

    #[test]
    fn test_parse_simple_extension_declaration_compound_name_multi_segment() {
        let line = "#3 @1: regexp_match_substring:str_str_i64";
        let decl = SimpleExtensionDeclaration::from_str(line).unwrap();
        assert_eq!(decl.anchor, 3);
        assert_eq!(decl.urn_anchor, 1);
        assert_eq!(decl.name, "regexp_match_substring:str_str_i64");
    }

    #[test]
    fn test_extensions_round_trip_plan_with_compound_names() {
        let input = r#"=== Extensions
URNs:
 @ 1: extension:io.substrait:functions_string
 @ 2: extension:io.substrait:functions_comparison
Functions:
 # 1 @ 2: equal:any_any
 # 2 @ 1: regexp_match_substring:str_str
 # 3 @ 1: regexp_match_substring:str_str_i64
"#;
        let plan = Parser::parse(input).unwrap();
        let (extensions, errors) =
            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
        assert!(errors.is_empty());
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 1)
                .unwrap()
                .1
                .full(),
            "equal:any_any"
        );
        assert_eq!(
            extensions
                .find_by_anchor(crate::extensions::simple::ExtensionKind::Function, 3)
                .unwrap()
                .1
                .full(),
            "regexp_match_substring:str_str_i64"
        );
        assert_eq!(extensions.to_string(" "), input);
    }

    #[test]
    fn test_tuple_mixed_types_parses() {
        let val = parse_extension_value("(&HASH, 8, 'hello')");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(items.len(), 3);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(items[1]).unwrap(), 8);
        assert_eq!(<&str>::try_from(items[2]).unwrap(), "hello");
    }

    #[test]
    fn test_empty_tuple_parses() {
        let val = parse_extension_value("()");
        let ExtensionValue::Tuple(items) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert!(items.is_empty());
    }

    #[test]
    fn test_nested_tuple_parses() {
        let val = parse_extension_value("((&HASH, &RANGE), 8)");
        let ExtensionValue::Tuple(outer) = val else {
            panic!("expected Tuple, got {val:?}");
        };
        assert_eq!(outer.len(), 2);
        let ExtensionValue::Tuple(inner) = outer.iter().next().unwrap() else {
            panic!("expected inner Tuple");
        };
        assert_eq!(inner.len(), 2);
        assert!(matches!(inner.iter().next().unwrap(), ExtensionValue::Enum(s) if s == "HASH"));
        assert_eq!(i64::try_from(outer.iter().nth(1).unwrap()).unwrap(), 8);
    }

    #[test]
    fn test_tuple_in_addendum_parses() {
        let inv = AddendumInvocation::parse(
            &SimpleExtensions::default(),
            "+ Enh:Foo[(&HASH, &RANGE), count=8]",
        )
        .unwrap();
        assert_eq!(inv.kind, AddendumKind::Enhancement);
        assert_eq!(inv.name, "Foo");
        assert_eq!(inv.args.positional.len(), 1);
        let ExtensionValue::Tuple(items) = &inv.args.positional[0] else {
            panic!("expected Tuple positional arg");
        };
        assert_eq!(items.len(), 2);
        let items: Vec<&ExtensionValue> = items.iter().collect();
        assert!(matches!(items[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(items[1], ExtensionValue::Enum(s) if s == "RANGE"));
        assert_eq!(inv.args.named.len(), 1);
    }

    #[test]
    fn extension_relation_kind_parses_text_prefixes() {
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionLeaf").unwrap(),
            ExtensionRelationKind::Leaf
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionSingle").unwrap(),
            ExtensionRelationKind::Single
        );
        assert_eq!(
            ExtensionRelationKind::from_str("ExtensionMulti").unwrap(),
            ExtensionRelationKind::Multi
        );
    }

    #[test]
    fn extension_multi_allows_any_child_count() {
        assert!(ExtensionRelationKind::Multi.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(1).is_ok());
        assert!(ExtensionRelationKind::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert!(
            ExtensionRelationKind::Single
                .validate_child_count(0)
                .is_err()
        );
        assert!(
            ExtensionRelationKind::Single
                .validate_child_count(2)
                .is_err()
        );
    }

    #[test]
    fn extension_single_accepts_exactly_one_child() {
        assert!(ExtensionRelationKind::Single.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_leaf_accepts_only_zero_children() {
        assert!(ExtensionRelationKind::Leaf.validate_child_count(0).is_ok());
        assert!(ExtensionRelationKind::Leaf.validate_child_count(1).is_err());
        assert!(ExtensionRelationKind::Leaf.validate_child_count(3).is_err());
    }

    #[test]
    fn test_tuple_textify_roundtrip() {
        let ctx = TestContext::new();
        for text in &[
            "(&HASH, &RANGE)",
            "(&HASH, 8, 'hello')",
            "()",
            "(&HASH,)",
            "((&HASH, &RANGE), 8)",
        ] {
            let val = parse_extension_value(text);
            let rendered = ctx.textify_no_errors(&val);
            assert_eq!(&rendered, text, "roundtrip failed for {text}");
        }
    }

    #[test]
    fn test_literal_expression_value_textifies_to_canonical_literal() {
        let expr = proto::Expression {
            rex_type: Some(RexType::Literal(proto::expression::Literal {
                literal_type: Some(LiteralType::I64(42)),
                nullable: false,
                type_variation_reference: 0,
            })),
        };
        let value = ExtensionValue::from(expr.clone());
        let ctx = TestContext::new();

        let rendered = ctx.textify_no_errors(&value);
        assert_eq!(rendered, "42");

        let parsed = parse_extension_value(&rendered);
        let parsed_expr = Expr::try_from(&parsed).unwrap();
        assert_eq!(parsed_expr.as_proto(), &expr);
    }

    #[test]
    fn test_extension_scalar_literals_stay_scalar_in_verbose_output() {
        let ctx = TestContext::new().with_options(OutputOptions::verbose());

        let scalar = ExtensionValue::from(42_i64);
        assert_eq!(ctx.textify_no_errors(&scalar), "42");

        let expression = ExtensionValue::from(Expr::from(42_i64));
        assert_eq!(ctx.textify_no_errors(&expression), "42:i64");
    }

    #[test]
    fn test_typed_extension_literal_parses_as_expression() {
        let value = parse_extension_value("42:i16");
        assert!(i64::try_from(&value).is_err());

        let expr = Expr::try_from(&value).unwrap();
        assert_eq!(ctx_text(&expr), "42:i16");
    }
}