1use std::fs;
2use std::io::{self, Read, Write};
3use std::process::ExitCode;
4
5use anyhow::{Context, Result};
6use clap::{Parser, Subcommand};
7use prost::Message;
8
9use crate::extensions::ExtensionRegistry;
10use crate::{FormatError, OutputOptions, Visibility, format_with_registry, parse_with_registry};
11
/// Result of a CLI operation that ran to completion without a hard error.
///
/// Hard failures (I/O errors, parse errors, unsupported formats) are reported
/// through `Err(..)` instead; this type only distinguishes a clean run from a
/// run whose output was written but had formatter-reported problems.
#[derive(Debug)]
pub enum Outcome {
    /// The operation finished and no formatting issues were reported.
    Success,
    /// Output was written, but the text formatter reported these issues.
    /// `Cli::run_with_extensions` prints them and exits with a failure code.
    HadFormattingIssues(Vec<FormatError>),
}
23
// Top-level command-line definition; clap derives the argument parser from
// this struct and its attributes.
//
// Plain `//` comments are used here on purpose: with clap's derive API a `///`
// doc comment is turned into user-visible help text, so adding one would
// change the program's `--help` output.
#[derive(Parser)]
#[command(name = "substrait-explain")]
#[command(about = "A CLI for parsing and formatting Substrait query plans")]
#[command(version)]
pub struct Cli {
    // The subcommand selected on the command line.
    #[command(subcommand)]
    pub command: Commands,
}
32
33impl Cli {
34 pub fn run(self) -> ExitCode {
38 self.run_with_extensions(ExtensionRegistry::default())
39 }
40
41 pub fn run_with_extensions(self, registry: ExtensionRegistry) -> ExitCode {
52 match self.run_inner(®istry) {
53 Ok(Outcome::Success) => ExitCode::SUCCESS,
54 Ok(Outcome::HadFormattingIssues(errors)) => {
55 eprintln!("Formatting issues:");
56 for error in errors {
57 eprintln!(" {error}");
58 }
59 ExitCode::FAILURE
60 }
61 Err(e) => {
62 eprintln!("Error: {e:?}");
63 ExitCode::FAILURE
64 }
65 }
66 }
67
68 fn run_inner(self, registry: &ExtensionRegistry) -> Result<Outcome> {
69 match &self.command {
70 Commands::Convert {
71 input,
72 output,
73 from,
74 to,
75 show_literal_types,
76 show_expression_types,
77 verbose,
78 } => {
79 let reader = get_reader(input)
80 .with_context(|| format!("Failed to open input file: {input}"))?;
81 let writer = get_writer(output)
82 .with_context(|| format!("Failed to create output file: {output}"))?;
83 let options =
84 self.create_output_options(*show_literal_types, *show_expression_types);
85 let from_format = self.resolve_input_format(from, input)?;
86 let to_format = self.resolve_output_format(to, output)?;
87 self.run_convert_with_io(
88 reader,
89 writer,
90 &from_format,
91 &to_format,
92 &options,
93 *verbose,
94 registry,
95 )
96 }
97
98 Commands::Validate {
99 input,
100 output,
101 verbose,
102 } => {
103 let reader = get_reader(input)
104 .with_context(|| format!("Failed to open input file: {input}"))?;
105 let writer = get_writer(output)
106 .with_context(|| format!("Failed to create output file: {output}"))?;
107 self.run_validate_with_io(reader, writer, *verbose, registry)
108 }
109 }
110 }
111
112 pub fn run_with_io<R: Read, W: Write>(
114 &self,
115 reader: R,
116 writer: W,
117 registry: &ExtensionRegistry,
118 ) -> Result<Outcome> {
119 match &self.command {
120 Commands::Convert {
121 input,
122 output,
123 from,
124 to,
125 show_literal_types,
126 show_expression_types,
127 verbose,
128 ..
129 } => {
130 let options =
131 self.create_output_options(*show_literal_types, *show_expression_types);
132 let from_format = self.resolve_input_format(from, input)?;
133 let to_format = self.resolve_output_format(to, output)?;
134 self.run_convert_with_io(
135 reader,
136 writer,
137 &from_format,
138 &to_format,
139 &options,
140 *verbose,
141 registry,
142 )
143 }
144
145 Commands::Validate { verbose, .. } => {
146 self.run_validate_with_io(reader, writer, *verbose, registry)
147 }
148 }
149 }
150
151 fn create_output_options(
152 &self,
153 show_literal_types: bool,
154 show_expression_types: bool,
155 ) -> OutputOptions {
156 let mut options = OutputOptions::default();
157
158 if show_literal_types {
159 options.literal_types = Visibility::Always;
160 }
161
162 if show_expression_types {
163 options.fn_types = true;
164 }
165
166 options
167 }
168
169 fn resolve_input_format(&self, format: &Option<Format>, input_path: &str) -> Result<Format> {
170 match format {
171 Some(fmt) => Ok(fmt.clone()),
172 None => Format::from_extension(input_path).ok_or_else(|| {
173 anyhow::anyhow!(
174 "Could not auto-detect input format from file extension. \
175 Please specify format explicitly with -f/--from. \
176 Supported formats: text, json, yaml, protobuf/proto/pb"
177 )
178 }),
179 }
180 }
181
182 fn resolve_output_format(&self, format: &Option<Format>, output_path: &str) -> Result<Format> {
183 match format {
184 Some(fmt) => Ok(fmt.clone()),
185 None => Format::from_extension(output_path).ok_or_else(|| {
186 anyhow::anyhow!(
187 "Could not auto-detect output format from file extension. \
188 Please specify format explicitly with -t/--to. \
189 Supported formats: text, json, yaml, protobuf/proto/pb"
190 )
191 }),
192 }
193 }
194
195 #[allow(clippy::too_many_arguments)]
199 fn run_convert_with_io<R: Read, W: Write>(
200 &self,
201 reader: R,
202 writer: W,
203 from: &Format,
204 to: &Format,
205 options: &OutputOptions,
206 verbose: bool,
207 registry: &ExtensionRegistry,
208 ) -> Result<Outcome> {
209 let plan = from.read_plan(reader, registry).with_context(|| {
211 format!(
212 "Failed to parse input as {} format",
213 format!("{from:?}").to_lowercase()
214 )
215 })?;
216
217 let outcome = to
219 .write_plan(writer, &plan, options, registry)
220 .with_context(|| {
221 format!(
222 "Failed to write output as {} format",
223 format!("{to:?}").to_lowercase()
224 )
225 })?;
226
227 if verbose && matches!(outcome, Outcome::Success) {
228 eprintln!("Successfully converted from {from:?} to {to:?}");
229 }
230
231 Ok(outcome)
232 }
233
234 fn run_validate_with_io<R: Read, W: Write>(
235 &self,
236 reader: R,
237 writer: W,
238 verbose: bool,
239 registry: &ExtensionRegistry,
240 ) -> Result<Outcome> {
241 let plan = Format::Text
242 .read_plan(reader, registry)
243 .with_context(|| "Failed to parse input as Substrait text format")?;
244
245 let outcome = Format::Text
246 .write_plan(writer, &plan, &OutputOptions::default(), registry)
247 .with_context(|| "Failed to format plan as Substrait text format")?;
248
249 if verbose && matches!(outcome, Outcome::Success) {
250 eprintln!("Successfully validated plan");
251 }
252
253 Ok(outcome)
254 }
255}
256
// Subcommands understood by the CLI.
//
// Comments here are plain `//` comments on purpose: with clap's derive API a
// `///` doc comment becomes user-visible help text, so adding one would change
// the program's `--help` output.
#[derive(Subcommand)]
pub enum Commands {
    // Reads a plan in one format and writes it in another.
    Convert {
        // Input path; the default "-" selects stdin.
        #[arg(short, long, default_value = "-")]
        input: String,
        // Output path; the default "-" selects stdout.
        #[arg(short, long, default_value = "-")]
        output: String,
        // Input format (`-f`); when absent it is guessed from `input`'s extension.
        #[arg(short = 'f', long)]
        from: Option<Format>,
        // Output format (`-t`); when absent it is guessed from `output`'s extension.
        #[arg(short = 't', long)]
        to: Option<Format>,
        // Sets `OutputOptions::literal_types` to `Visibility::Always`.
        #[arg(long)]
        show_literal_types: bool,
        // Sets `OutputOptions::fn_types` to `true`.
        #[arg(long)]
        show_expression_types: bool,
        // Prints a confirmation to stderr after a clean conversion.
        #[arg(short, long)]
        verbose: bool,
    },
    // Parses a text-format plan and writes it back out in text format.
    Validate {
        // Input path; the default "-" selects stdin.
        #[arg(short, long, default_value = "-")]
        input: String,
        // Output path; the default "-" selects stdout.
        #[arg(short, long, default_value = "-")]
        output: String,
        // Prints a confirmation to stderr after a clean validation.
        #[arg(short, long)]
        verbose: bool,
    },
}
310
/// Serialization formats a plan can be read from or written to.
#[derive(Clone, Debug, PartialEq)]
pub enum Format {
    /// The Substrait text format handled by this crate's parser and formatter.
    Text,
    /// JSON.
    Json,
    /// YAML.
    Yaml,
    /// Binary protobuf encoding.
    Protobuf,
}

/// Parses a format name case-insensitively.
///
/// Accepted names are `text`, `json`, `yaml`, and `protobuf` (with the aliases
/// `proto` and `pb`). Anything else yields an error message that echoes the
/// input exactly as it was given.
impl std::str::FromStr for Format {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let format = match normalized.as_str() {
            "text" => Format::Text,
            "json" => Format::Json,
            "yaml" => Format::Yaml,
            "protobuf" | "proto" | "pb" => Format::Protobuf,
            _ => {
                return Err(format!(
                    "Invalid format: '{s}'. Supported formats: text, json, yaml, protobuf/proto/pb"
                ));
            }
        };
        Ok(format)
    }
}
334
335impl Format {
336 pub fn from_extension(path: &str) -> Option<Format> {
338 if path == "-" {
339 return None; }
341
342 let extension = std::path::Path::new(path)
343 .extension()
344 .and_then(|ext| ext.to_str())
345 .map(|ext| ext.to_lowercase());
346
347 match extension.as_deref() {
348 Some("substrait") | Some("txt") => Some(Format::Text),
349 Some("json") => Some(Format::Json),
350 Some("yaml") | Some("yml") => Some(Format::Yaml),
351 Some("pb") | Some("proto") | Some("protobuf") => Some(Format::Protobuf),
352 _ => None,
353 }
354 }
355
356 pub fn read_plan<R: Read>(
357 &self,
358 reader: R,
359 registry: &ExtensionRegistry,
360 ) -> Result<substrait::proto::Plan> {
361 match self {
362 Format::Text => {
363 let input_text = read_text_input(reader)?;
364 Ok(parse_with_registry(&input_text, registry)?)
365 }
366 Format::Json => {
367 let input_text = read_text_input(reader)?;
368 let pool = crate::json::build_descriptor_pool(®istry.descriptors())?;
369 crate::json::parse_json(&input_text, &pool)
370 }
371 Format::Yaml => {
372 #[cfg(feature = "serde")]
373 {
374 let input_text = read_text_input(reader)?;
375 Ok(serde_yaml::from_str(&input_text)?)
376 }
377 #[cfg(not(feature = "serde"))]
378 {
379 Err("YAML support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
380 }
381 }
382 Format::Protobuf => {
383 let input_bytes = read_binary_input(reader)?;
384 Ok(substrait::proto::Plan::decode(&input_bytes[..])?)
385 }
386 }
387 }
388
389 pub fn write_plan<W: Write>(
390 &self,
391 writer: W,
392 plan: &substrait::proto::Plan,
393 options: &OutputOptions,
394 registry: &ExtensionRegistry,
395 ) -> Result<Outcome> {
396 match self {
397 Format::Text => {
398 let (text, errors) = format_with_registry(plan, options, registry);
399
400 write_text_output(writer, &text)?;
402
403 if errors.is_empty() {
405 Ok(Outcome::Success)
406 } else {
407 Ok(Outcome::HadFormattingIssues(errors))
408 }
409 }
410 Format::Json => {
411 #[cfg(feature = "serde")]
412 {
413 let json = serde_json::to_string_pretty(plan)?;
414 write_text_output(writer, &json)?;
415 Ok(Outcome::Success)
416 }
417 #[cfg(not(feature = "serde"))]
418 {
419 Err("JSON support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
420 }
421 }
422 Format::Yaml => {
423 #[cfg(feature = "serde")]
424 {
425 let yaml = serde_yaml::to_string(plan)?;
426 write_text_output(writer, &yaml)?;
427 Ok(Outcome::Success)
428 }
429 #[cfg(not(feature = "serde"))]
430 {
431 Err("YAML support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
432 }
433 }
434 Format::Protobuf => {
435 let bytes = plan.encode_to_vec();
436 write_binary_output(writer, &bytes)?;
437 Ok(Outcome::Success)
438 }
439 }
440 }
441}
442
443fn read_text_input<R: Read>(mut reader: R) -> Result<String> {
445 let mut buffer = String::new();
446 reader.read_to_string(&mut buffer)?;
447 Ok(buffer)
448}
449
450fn read_binary_input<R: Read>(mut reader: R) -> Result<Vec<u8>> {
452 let mut buffer = Vec::new();
453 reader.read_to_end(&mut buffer)?;
454 Ok(buffer)
455}
456
457fn write_text_output<W: Write>(mut writer: W, content: &str) -> Result<()> {
459 writer.write_all(content.as_bytes())?;
460 Ok(())
461}
462
463fn write_binary_output<W: Write>(mut writer: W, content: &[u8]) -> Result<()> {
465 writer.write_all(content)?;
466 Ok(())
467}
468
469fn get_reader(path: &str) -> Result<Box<dyn Read>> {
471 if path == "-" {
472 Ok(Box::new(io::stdin()))
473 } else {
474 Ok(Box::new(fs::File::open(path)?))
475 }
476}
477
478fn get_writer(path: &str) -> Result<Box<dyn Write>> {
480 if path == "-" {
481 Ok(Box::new(io::stdout()))
482 } else {
483 Ok(Box::new(fs::File::create(path)?))
484 }
485}
486
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use substrait::proto::expression::RexType;
    use substrait::proto::plan_rel;
    use substrait::proto::rel::RelType;

    use super::*;
    use crate::extensions::{
        Explainable, ExtensionArgs, ExtensionColumn, ExtensionError, ExtensionRelationType,
        ExtensionValue,
    };
    use crate::parse;

    // NOTE(review): the plan text format is indentation-sensitive. The leading
    // indentation inside the raw-string fixtures in this module is written as
    // two spaces per nesting level — confirm against the formatter's output.

    /// Minimal plan without an extensions section.
    const BASIC_PLAN: &str = r#"=== Plan
Root[result]
  Project[$0, $1]
    Read[data => a:i64, b:string]
"#;

    /// Plan that declares and uses one extension function (`gt`).
    const PLAN_WITH_EXTENSIONS: &str = r#"=== Extensions
URNs:
  @ 1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @ 1: gt

=== Plan
Root[result]
  Filter[gt($2, 100) => $0, $1, $2]
    Project[$0, $1, $2]
      Read[data => a:i64, b:string, c:i32]
"#;

    /// Builds a `convert` invocation with every display flag switched off.
    fn convert_cli(input: &str, output: &str, from: Option<Format>, to: Option<Format>) -> Cli {
        Cli {
            command: Commands::Convert {
                input: input.to_string(),
                output: output.to_string(),
                from,
                to,
                show_literal_types: false,
                show_expression_types: false,
                verbose: false,
            },
        }
    }

    /// Builds a `validate` invocation; the paths are unused by `run_with_io`.
    fn validate_cli() -> Cli {
        Cli {
            command: Commands::Validate {
                input: String::new(),
                output: String::new(),
                verbose: false,
            },
        }
    }

    /// Runs `cli` against in-memory input and returns the bytes it wrote.
    /// The `Outcome` is discarded, so formatting issues do not fail the run.
    fn run_in_memory(
        cli: &Cli,
        input: impl AsRef<[u8]>,
        registry: &ExtensionRegistry,
    ) -> Result<Vec<u8>> {
        let mut output = Vec::new();
        cli.run_with_io(Cursor::new(input), &mut output, registry)?;
        Ok(output)
    }

    /// Like [`run_in_memory`], but panics on failure and decodes the output as
    /// UTF-8.
    fn run_to_string(cli: &Cli, input: impl AsRef<[u8]>, registry: &ExtensionRegistry) -> String {
        let bytes = run_in_memory(cli, input, registry).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    /// Asserts that `text` contains every entry of `needles`.
    #[track_caller]
    fn assert_contains_all(text: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                text.contains(needle),
                "expected output to contain {needle:?}, got:\n{text}"
            );
        }
    }

    #[test]
    fn test_convert_text_to_text() {
        let cli = convert_cli(
            "input.substrait",
            "output.substrait",
            Some(Format::Text),
            Some(Format::Text),
        );

        let output_content = run_to_string(&cli, BASIC_PLAN, &ExtensionRegistry::default());

        assert_contains_all(
            &output_content,
            &[
                "=== Plan",
                "Root[result]",
                "Project[$0, $1]",
                "Read[data => a:i64, b:string]",
            ],
        );
    }

    #[test]
    fn test_convert_text_to_json() {
        let cli = convert_cli(
            "input.substrait",
            "output.json",
            Some(Format::Text),
            Some(Format::Json),
        );

        let output_content = run_to_string(&cli, BASIC_PLAN, &ExtensionRegistry::default());

        assert_contains_all(
            &output_content,
            &["\"relations\"", "\"root\"", "\"project\"", "\"read\""],
        );
    }

    #[test]
    fn test_convert_json_to_text() {
        // Produce JSON from the text fixture first, then feed it back in.
        let cli_to_json = convert_cli(
            "input.substrait",
            "output.json",
            Some(Format::Text),
            Some(Format::Json),
        );
        let json_output =
            run_in_memory(&cli_to_json, BASIC_PLAN, &ExtensionRegistry::default()).unwrap();

        let cli_to_text = convert_cli(
            "input.json",
            "output.substrait",
            Some(Format::Json),
            Some(Format::Text),
        );
        let output_content =
            run_to_string(&cli_to_text, json_output, &ExtensionRegistry::default());

        assert_contains_all(&output_content, &["=== Plan", "Root[result]"]);
    }

    #[test]
    fn test_convert_with_protobuf_output() {
        let cli = convert_cli(
            "input.substrait",
            "output.pb",
            Some(Format::Text),
            Some(Format::Protobuf),
        );

        let output = run_in_memory(&cli, BASIC_PLAN, &ExtensionRegistry::default()).unwrap();

        assert!(!output.is_empty());

        // Binary output must not be the text rendering of the plan.
        let output_string = String::from_utf8_lossy(&output);
        assert!(!output_string.contains("=== Plan"));
    }

    #[test]
    fn test_validate_command() {
        let output_content =
            run_to_string(&validate_cli(), BASIC_PLAN, &ExtensionRegistry::default());

        assert_contains_all(
            &output_content,
            &[
                "=== Plan",
                "Root[result]",
                "Project[$0, $1]",
                "Read[data => a:i64, b:string]",
            ],
        );
    }

    #[test]
    fn test_validate_with_extensions() {
        let output_content = run_to_string(
            &validate_cli(),
            PLAN_WITH_EXTENSIONS,
            &ExtensionRegistry::default(),
        );

        assert_contains_all(
            &output_content,
            &[
                "=== Extensions",
                "=== Plan",
                "Root[result]",
                "Filter[gt($2, 100)",
            ],
        );
    }

    #[test]
    fn test_convert_with_formatting_options() {
        let cli = Cli {
            command: Commands::Convert {
                input: "input.substrait".to_string(),
                output: "output.substrait".to_string(),
                from: Some(Format::Text),
                to: Some(Format::Text),
                show_literal_types: true,
                show_expression_types: true,
                verbose: false,
            },
        };

        let output_content = run_to_string(&cli, BASIC_PLAN, &ExtensionRegistry::default());

        assert_contains_all(&output_content, &["=== Plan", "Root[result]"]);
    }

    #[test]
    fn test_auto_detect_from_extension() {
        let cases = [
            // Text
            ("plan.substrait", Some(Format::Text)),
            ("plan.txt", Some(Format::Text)),
            // JSON
            ("plan.json", Some(Format::Json)),
            // YAML
            ("plan.yaml", Some(Format::Yaml)),
            ("plan.yml", Some(Format::Yaml)),
            // Protobuf
            ("plan.pb", Some(Format::Protobuf)),
            ("plan.proto", Some(Format::Protobuf)),
            ("plan.protobuf", Some(Format::Protobuf)),
            // Unknown or missing extension
            ("plan.unknown", None),
            ("plan", None),
            // stdin/stdout marker
            ("-", None),
        ];

        for (path, expected) in cases {
            assert_eq!(Format::from_extension(path), expected, "path: {path}");
        }
    }

    #[test]
    fn test_convert_with_auto_detection() {
        // No explicit formats: both are inferred from the file extensions.
        let cli = convert_cli("input.substrait", "output.json", None, None);

        let output_content = run_to_string(&cli, BASIC_PLAN, &ExtensionRegistry::default());

        assert_contains_all(
            &output_content,
            &["\"relations\"", "\"root\"", "\"project\"", "\"read\""],
        );
    }

    #[test]
    fn test_auto_detection_error_unknown_input_extension() {
        let cli = convert_cli("input.unknown", "output.json", None, None);

        let error = run_in_memory(&cli, BASIC_PLAN, &ExtensionRegistry::default()).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("Could not auto-detect input format")
        );
    }

    #[test]
    fn test_auto_detection_error_unknown_output_extension() {
        let cli = convert_cli("input.substrait", "output.unknown", None, None);

        let error = run_in_memory(&cli, BASIC_PLAN, &ExtensionRegistry::default()).unwrap_err();

        assert!(
            error
                .to_string()
                .contains("Could not auto-detect output format")
        );
    }

    #[test]
    fn test_explicit_format_overrides_auto_detection() {
        // The extensions suggest JSON -> protobuf, but the explicit formats
        // (text -> text) must win.
        let cli = convert_cli(
            "input.json",
            "output.pb",
            Some(Format::Text),
            Some(Format::Text),
        );

        let output_content = run_to_string(&cli, BASIC_PLAN, &ExtensionRegistry::default());

        assert_contains_all(&output_content, &["=== Plan", "Root[result]"]);
    }

    #[test]
    fn test_protobuf_roundtrip() {
        let cli_to_protobuf = convert_cli(
            "input.substrait",
            "output.pb",
            Some(Format::Text),
            Some(Format::Protobuf),
        );
        let protobuf_output =
            run_in_memory(&cli_to_protobuf, BASIC_PLAN, &ExtensionRegistry::default()).unwrap();

        let cli_to_text = convert_cli(
            "input.pb",
            "output.substrait",
            Some(Format::Protobuf),
            Some(Format::Text),
        );
        let output_content =
            run_to_string(&cli_to_text, protobuf_output, &ExtensionRegistry::default());

        assert_contains_all(
            &output_content,
            &[
                "=== Plan",
                "Root[result]",
                "Read[data => a:i64, b:string]",
            ],
        );
    }

    /// Custom leaf relation used to exercise a non-default extension registry.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestSource {
        #[prost(string, tag = "1")]
        tag: String,
    }

    impl prost::Name for TestSource {
        const NAME: &'static str = "TestSource";
        const PACKAGE: &'static str = "test";
        fn full_name() -> String {
            "test.TestSource".to_string()
        }
        fn type_url() -> String {
            "type.googleapis.com/test.TestSource".to_string()
        }
    }

    impl Explainable for TestSource {
        fn name() -> &'static str {
            "TestSource"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let tag: &str = extractor.expect_named_arg("tag")?;
            extractor.check_exhausted()?;
            Ok(TestSource {
                tag: tag.to_string(),
            })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named
                .insert("tag".to_string(), ExtensionValue::String(self.tag.clone()));
            args.output_columns.push(ExtensionColumn::Named {
                name: "val".to_string(),
                type_spec: "i64".to_string(),
            });
            Ok(args)
        }
    }

    /// Registry that knows about [`TestSource`] and nothing else.
    fn make_extension_registry() -> ExtensionRegistry {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestSource>().unwrap();
        registry
    }

    /// Plan whose only relation is the custom [`TestSource`] leaf.
    const PLAN_WITH_CUSTOM_EXTENSION: &str = r#"=== Plan
Root[val]
  ExtensionLeaf:TestSource[tag='hello' => val:i64]
"#;

    #[test]
    fn test_convert_text_to_text_with_extension_registry() {
        let registry = make_extension_registry();
        let cli = convert_cli(
            "input.substrait",
            "output.substrait",
            Some(Format::Text),
            Some(Format::Text),
        );

        let output_content = run_to_string(&cli, PLAN_WITH_CUSTOM_EXTENSION, &registry);

        assert_eq!(output_content, PLAN_WITH_CUSTOM_EXTENSION);
    }

    #[test]
    fn test_convert_text_to_json_with_extension_registry() {
        let registry = make_extension_registry();
        let cli = convert_cli(
            "input.substrait",
            "output.json",
            Some(Format::Text),
            Some(Format::Json),
        );

        let output_content = run_to_string(&cli, PLAN_WITH_CUSTOM_EXTENSION, &registry);

        assert!(output_content.contains("\"extensionLeaf\""));
    }

    #[test]
    fn test_validate_with_extension_registry() {
        let registry = make_extension_registry();

        let output_content =
            run_to_string(&validate_cli(), PLAN_WITH_CUSTOM_EXTENSION, &registry);

        assert_eq!(output_content, PLAN_WITH_CUSTOM_EXTENSION);
    }

    #[test]
    fn test_convert_text_fails_without_extension_registry() {
        let cli = convert_cli(
            "input.substrait",
            "output.substrait",
            Some(Format::Text),
            Some(Format::Text),
        );

        // The default registry does not know `TestSource`, so parsing fails.
        let result = run_in_memory(
            &cli,
            PLAN_WITH_CUSTOM_EXTENSION,
            &ExtensionRegistry::default(),
        );

        assert!(result.is_err());
    }

    /// Parses a valid plan, then points its filter function at an anchor
    /// (999) that the plan's extensions section does not declare.
    fn make_plan_with_invalid_function_ref() -> substrait::proto::Plan {
        const VALID_PLAN: &str = r#"=== Extensions
URNs:
  @ 1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml
Functions:
  # 10 @ 1: equal

=== Plan
Root[result]
  Filter[equal($0, 42:i32) => $0]
    Read[data => a:i32]
"#;

        let mut plan = parse(VALID_PLAN).expect("Failed to parse valid plan");

        // Walk Root -> Filter -> condition to reach the scalar function call.
        let rel_root = plan.relations.first_mut().unwrap();
        let plan_rel::RelType::Root(root) = rel_root.rel_type.as_mut().unwrap() else {
            panic!("Expected Root relation");
        };
        let rel = root.input.as_mut().unwrap();
        let RelType::Filter(filter) = rel.rel_type.as_mut().unwrap() else {
            panic!("Expected Filter relation");
        };
        let condition = filter.condition.as_mut().unwrap();
        let RexType::ScalarFunction(func) = condition.rex_type.as_mut().unwrap() else {
            panic!("Expected ScalarFunction");
        };
        func.function_reference = 999;

        plan
    }

    #[test]
    fn test_write_plan_reports_formatting_issues() {
        let plan = make_plan_with_invalid_function_ref();
        let mut output = Vec::new();

        let result = Format::Text.write_plan(
            &mut output,
            &plan,
            &OutputOptions::default(),
            &ExtensionRegistry::default(),
        );

        let outcome = result.expect("write_plan should not return hard error");
        assert!(
            matches!(outcome, Outcome::HadFormattingIssues(ref errors) if !errors.is_empty()),
            "Expected HadFormattingIssues with errors, got {outcome:?}"
        );
        assert!(
            !output.is_empty(),
            "Output should be written even with issues"
        );
    }
}