1use std::fs;
2use std::io::{self, Read, Write};
3use std::process::ExitCode;
4
5use anyhow::{Context, Result};
6use clap::{Parser, Subcommand};
7use prost::Message;
8
9use crate::extensions::ExtensionRegistry;
10use crate::{FormatError, OutputOptions, Visibility, format_with_registry, parse_with_registry};
11
/// Result of a CLI operation that finished without a hard error.
///
/// Hard errors are reported separately through `Result::Err`; this type only
/// distinguishes a clean run from one where the formatter reported issues.
#[derive(Debug)]
pub enum Outcome {
    /// The operation completed and no formatting issues were reported.
    Success,
    /// The formatter reported the contained issues. `Format::write_plan`
    /// returns this only after the text output has already been written.
    HadFormattingIssues(Vec<FormatError>),
}
23
/// Parsed command-line arguments for the `substrait-explain` binary.
#[derive(Parser)]
#[command(name = "substrait-explain")]
#[command(about = "A CLI for parsing and formatting Substrait query plans")]
#[command(version)]
pub struct Cli {
    // The subcommand selected on the command line (`convert` or `validate`).
    #[command(subcommand)]
    pub command: Commands,
}
32
33impl Cli {
34 pub fn run(self) -> ExitCode {
38 self.run_with_extensions(ExtensionRegistry::default())
39 }
40
41 pub fn run_with_extensions(self, registry: ExtensionRegistry) -> ExitCode {
52 match self.run_inner(®istry) {
53 Ok(Outcome::Success) => ExitCode::SUCCESS,
54 Ok(Outcome::HadFormattingIssues(errors)) => {
55 eprintln!("Formatting issues:");
56 for error in errors {
57 eprintln!(" {error}");
58 }
59 ExitCode::FAILURE
60 }
61 Err(e) => {
62 eprintln!("Error: {e:?}");
63 ExitCode::FAILURE
64 }
65 }
66 }
67
68 fn run_inner(self, registry: &ExtensionRegistry) -> Result<Outcome> {
69 match &self.command {
70 Commands::Convert {
71 input,
72 output,
73 from,
74 to,
75 show_literal_types,
76 show_expression_types,
77 verbose,
78 } => {
79 let reader = get_reader(input)
80 .with_context(|| format!("Failed to open input file: {input}"))?;
81 let writer = get_writer(output)
82 .with_context(|| format!("Failed to create output file: {output}"))?;
83 let options =
84 self.create_output_options(*show_literal_types, *show_expression_types);
85 let from_format = self.resolve_input_format(from, input)?;
86 let to_format = self.resolve_output_format(to, output)?;
87 self.run_convert_with_io(
88 reader,
89 writer,
90 &from_format,
91 &to_format,
92 &options,
93 *verbose,
94 registry,
95 )
96 }
97
98 Commands::Validate {
99 input,
100 output,
101 verbose,
102 } => {
103 let reader = get_reader(input)
104 .with_context(|| format!("Failed to open input file: {input}"))?;
105 let writer = get_writer(output)
106 .with_context(|| format!("Failed to create output file: {output}"))?;
107 self.run_validate_with_io(reader, writer, *verbose, registry)
108 }
109 }
110 }
111
112 pub fn run_with_io<R: Read, W: Write>(
114 &self,
115 reader: R,
116 writer: W,
117 registry: &ExtensionRegistry,
118 ) -> Result<Outcome> {
119 match &self.command {
120 Commands::Convert {
121 input,
122 output,
123 from,
124 to,
125 show_literal_types,
126 show_expression_types,
127 verbose,
128 ..
129 } => {
130 let options =
131 self.create_output_options(*show_literal_types, *show_expression_types);
132 let from_format = self.resolve_input_format(from, input)?;
133 let to_format = self.resolve_output_format(to, output)?;
134 self.run_convert_with_io(
135 reader,
136 writer,
137 &from_format,
138 &to_format,
139 &options,
140 *verbose,
141 registry,
142 )
143 }
144
145 Commands::Validate { verbose, .. } => {
146 self.run_validate_with_io(reader, writer, *verbose, registry)
147 }
148 }
149 }
150
151 fn create_output_options(
152 &self,
153 show_literal_types: bool,
154 show_expression_types: bool,
155 ) -> OutputOptions {
156 let mut options = OutputOptions::default();
157
158 if show_literal_types {
159 options.literal_types = Visibility::Always;
160 }
161
162 if show_expression_types {
163 options.fn_types = true;
164 }
165
166 options
167 }
168
169 fn resolve_input_format(&self, format: &Option<Format>, input_path: &str) -> Result<Format> {
170 match format {
171 Some(fmt) => Ok(fmt.clone()),
172 None => Format::from_extension(input_path).ok_or_else(|| {
173 anyhow::anyhow!(
174 "Could not auto-detect input format from file extension. \
175 Please specify format explicitly with -f/--from. \
176 Supported formats: text, json, yaml, protobuf/proto/pb"
177 )
178 }),
179 }
180 }
181
182 fn resolve_output_format(&self, format: &Option<Format>, output_path: &str) -> Result<Format> {
183 match format {
184 Some(fmt) => Ok(fmt.clone()),
185 None => Format::from_extension(output_path).ok_or_else(|| {
186 anyhow::anyhow!(
187 "Could not auto-detect output format from file extension. \
188 Please specify format explicitly with -t/--to. \
189 Supported formats: text, json, yaml, protobuf/proto/pb"
190 )
191 }),
192 }
193 }
194
195 #[allow(clippy::too_many_arguments)]
199 fn run_convert_with_io<R: Read, W: Write>(
200 &self,
201 reader: R,
202 writer: W,
203 from: &Format,
204 to: &Format,
205 options: &OutputOptions,
206 verbose: bool,
207 registry: &ExtensionRegistry,
208 ) -> Result<Outcome> {
209 let plan = from.read_plan(reader, registry).with_context(|| {
211 format!(
212 "Failed to parse input as {} format",
213 format!("{from:?}").to_lowercase()
214 )
215 })?;
216
217 let outcome = to
219 .write_plan(writer, &plan, options, registry)
220 .with_context(|| {
221 format!(
222 "Failed to write output as {} format",
223 format!("{to:?}").to_lowercase()
224 )
225 })?;
226
227 if verbose && matches!(outcome, Outcome::Success) {
228 eprintln!("Successfully converted from {from:?} to {to:?}");
229 }
230
231 Ok(outcome)
232 }
233
234 fn run_validate_with_io<R: Read, W: Write>(
235 &self,
236 reader: R,
237 writer: W,
238 verbose: bool,
239 registry: &ExtensionRegistry,
240 ) -> Result<Outcome> {
241 let plan = Format::Text
242 .read_plan(reader, registry)
243 .with_context(|| "Failed to parse input as Substrait text format")?;
244
245 let outcome = Format::Text
246 .write_plan(writer, &plan, &OutputOptions::default(), registry)
247 .with_context(|| "Failed to format plan as Substrait text format")?;
248
249 if verbose && matches!(outcome, Outcome::Success) {
250 eprintln!("Successfully validated plan");
251 }
252
253 Ok(outcome)
254 }
255}
256
257#[derive(Subcommand)]
258pub enum Commands {
259 Convert {
275 #[arg(short, long, default_value = "-")]
277 input: String,
278 #[arg(short, long, default_value = "-")]
280 output: String,
281 #[arg(short = 'f', long)]
283 from: Option<Format>,
284 #[arg(short = 't', long)]
286 to: Option<Format>,
287 #[arg(long)]
289 show_literal_types: bool,
290 #[arg(long)]
292 show_expression_types: bool,
293 #[arg(short, long)]
295 verbose: bool,
296 },
297 Validate {
299 #[arg(short, long, default_value = "-")]
301 input: String,
302 #[arg(short, long, default_value = "-")]
304 output: String,
305 #[arg(short, long)]
307 verbose: bool,
308 },
309}
310
/// Plan serialization formats the CLI can read and write.
///
/// Parsed from a name via `FromStr` and guessed from a file extension via
/// `Format::from_extension`.
#[derive(Clone, Debug, PartialEq)]
pub enum Format {
    /// Substrait text format, handled by `parse_with_registry` and
    /// `format_with_registry`.
    Text,
    /// JSON representation of the plan.
    Json,
    /// YAML representation of the plan (needs the `serde` feature).
    Yaml,
    /// Binary protobuf encoding of `substrait::proto::Plan`.
    Protobuf,
}
318
319impl std::str::FromStr for Format {
320 type Err = String;
321
322 fn from_str(s: &str) -> Result<Self, Self::Err> {
323 match s.to_lowercase().as_str() {
324 "text" => Ok(Format::Text),
325 "json" => Ok(Format::Json),
326 "yaml" => Ok(Format::Yaml),
327 "protobuf" | "proto" | "pb" => Ok(Format::Protobuf),
328 _ => Err(format!(
329 "Invalid format: '{s}'. Supported formats: text, json, yaml, protobuf/proto/pb"
330 )),
331 }
332 }
333}
334
335impl Format {
336 pub fn from_extension(path: &str) -> Option<Format> {
338 if path == "-" {
339 return None; }
341
342 let extension = std::path::Path::new(path)
343 .extension()
344 .and_then(|ext| ext.to_str())
345 .map(|ext| ext.to_lowercase());
346
347 match extension.as_deref() {
348 Some("substrait") | Some("txt") => Some(Format::Text),
349 Some("json") => Some(Format::Json),
350 Some("yaml") | Some("yml") => Some(Format::Yaml),
351 Some("pb") | Some("proto") | Some("protobuf") => Some(Format::Protobuf),
352 _ => None,
353 }
354 }
355
356 pub fn read_plan<R: Read>(
357 &self,
358 reader: R,
359 registry: &ExtensionRegistry,
360 ) -> Result<substrait::proto::Plan> {
361 match self {
362 Format::Text => {
363 let input_text = read_text_input(reader)?;
364 Ok(parse_with_registry(&input_text, registry)?)
365 }
366 Format::Json => {
367 let input_text = read_text_input(reader)?;
368 let pool = crate::json::build_descriptor_pool(®istry.descriptors())?;
369 crate::json::parse_json(&input_text, &pool)
370 }
371 Format::Yaml => {
372 #[cfg(feature = "serde")]
373 {
374 let input_text = read_text_input(reader)?;
375 Ok(serde_yaml::from_str(&input_text)?)
376 }
377 #[cfg(not(feature = "serde"))]
378 {
379 Err("YAML support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
380 }
381 }
382 Format::Protobuf => {
383 let input_bytes = read_binary_input(reader)?;
384 Ok(substrait::proto::Plan::decode(&input_bytes[..])?)
385 }
386 }
387 }
388
389 pub fn write_plan<W: Write>(
390 &self,
391 writer: W,
392 plan: &substrait::proto::Plan,
393 options: &OutputOptions,
394 registry: &ExtensionRegistry,
395 ) -> Result<Outcome> {
396 match self {
397 Format::Text => {
398 let (text, errors) = format_with_registry(plan, options, registry);
399
400 write_text_output(writer, &text)?;
402
403 if errors.is_empty() {
405 Ok(Outcome::Success)
406 } else {
407 Ok(Outcome::HadFormattingIssues(errors))
408 }
409 }
410 Format::Json => {
411 #[cfg(feature = "serde")]
412 {
413 let json = serde_json::to_string_pretty(plan)?;
414 write_text_output(writer, &json)?;
415 Ok(Outcome::Success)
416 }
417 #[cfg(not(feature = "serde"))]
418 {
419 Err("JSON support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
420 }
421 }
422 Format::Yaml => {
423 #[cfg(feature = "serde")]
424 {
425 let yaml = serde_yaml::to_string(plan)?;
426 write_text_output(writer, &yaml)?;
427 Ok(Outcome::Success)
428 }
429 #[cfg(not(feature = "serde"))]
430 {
431 Err("YAML support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
432 }
433 }
434 Format::Protobuf => {
435 let bytes = plan.encode_to_vec();
436 write_binary_output(writer, &bytes)?;
437 Ok(Outcome::Success)
438 }
439 }
440 }
441}
442
443fn read_text_input<R: Read>(mut reader: R) -> Result<String> {
445 let mut buffer = String::new();
446 reader.read_to_string(&mut buffer)?;
447 Ok(buffer)
448}
449
450fn read_binary_input<R: Read>(mut reader: R) -> Result<Vec<u8>> {
452 let mut buffer = Vec::new();
453 reader.read_to_end(&mut buffer)?;
454 Ok(buffer)
455}
456
457fn write_text_output<W: Write>(mut writer: W, content: &str) -> Result<()> {
459 writer.write_all(content.as_bytes())?;
460 Ok(())
461}
462
463fn write_binary_output<W: Write>(mut writer: W, content: &[u8]) -> Result<()> {
465 writer.write_all(content)?;
466 Ok(())
467}
468
469fn get_reader(path: &str) -> Result<Box<dyn Read>> {
471 if path == "-" {
472 Ok(Box::new(io::stdin()))
473 } else {
474 Ok(Box::new(fs::File::open(path)?))
475 }
476}
477
478fn get_writer(path: &str) -> Result<Box<dyn Write>> {
480 if path == "-" {
481 Ok(Box::new(io::stdout()))
482 } else {
483 Ok(Box::new(fs::File::create(path)?))
484 }
485}
486
487#[cfg(test)]
488mod tests {
489 use std::io::Cursor;
490
491 use substrait::proto::expression::RexType;
492 use substrait::proto::plan_rel;
493 use substrait::proto::rel::RelType;
494
495 use super::*;
496 use crate::extensions::{Explainable, ExtensionArgs, ExtensionColumn, ExtensionError};
497 use crate::fixtures::parse_type;
498 use crate::parse;
499
    // Minimal text-format plan used by most tests: a Root over a Project over
    // a Read of two columns. No extensions section.
    const BASIC_PLAN: &str = r#"=== Plan
Root[result]
  Project[$0, $1]
    Read[data => a:i64, b:string]
"#;
505
    // Text-format plan with an extensions section: declares one URN and the
    // `gt` function, then uses `gt` in a Filter.
    const PLAN_WITH_EXTENSIONS: &str = r#"=== Extensions
URNs:
  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @  1: gt

=== Plan
Root[result]
  Filter[gt($2, 100) => $0, $1, $2]
    Project[$0, $1, $2]
      Read[data => a:i64, b:string, c:i32]
"#;
518
519 #[test]
520 fn test_convert_text_to_text() {
521 let input = Cursor::new(BASIC_PLAN);
522 let mut output = Vec::new();
523
524 let cli = Cli {
525 command: Commands::Convert {
526 input: "input.substrait".to_string(),
527 output: "output.substrait".to_string(),
528 from: Some(Format::Text),
529 to: Some(Format::Text),
530 show_literal_types: false,
531 show_expression_types: false,
532 verbose: false,
533 },
534 };
535
536 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
537 .unwrap();
538
539 let output_content = String::from_utf8(output).unwrap();
540 assert!(output_content.contains("=== Plan"));
541 assert!(output_content.contains("Root[result]"));
542 assert!(output_content.contains("Project[$0, $1]"));
543 assert!(output_content.contains("Read[data => a:i64, b:string]"));
544 }
545
546 #[test]
547 fn test_convert_text_to_json() {
548 let input = Cursor::new(BASIC_PLAN);
549 let mut output = Vec::new();
550
551 let cli = Cli {
552 command: Commands::Convert {
553 input: "input.substrait".to_string(),
554 output: "output.json".to_string(),
555 from: Some(Format::Text),
556 to: Some(Format::Json),
557 show_literal_types: false,
558 show_expression_types: false,
559 verbose: false,
560 },
561 };
562
563 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
564 .unwrap();
565
566 let output_content = String::from_utf8(output).unwrap();
567 assert!(output_content.contains("\"relations\""));
568 assert!(output_content.contains("\"root\""));
569 assert!(output_content.contains("\"project\""));
570 assert!(output_content.contains("\"read\""));
571 }
572
573 #[test]
574 fn test_convert_json_to_text() {
575 let input = Cursor::new(BASIC_PLAN);
577 let mut json_output = Vec::new();
578
579 let cli_to_json = Cli {
580 command: Commands::Convert {
581 input: "input.substrait".to_string(),
582 output: "output.json".to_string(),
583 from: Some(Format::Text),
584 to: Some(Format::Json),
585 show_literal_types: false,
586 show_expression_types: false,
587 verbose: false,
588 },
589 };
590
591 cli_to_json
592 .run_with_io(input, &mut json_output, &ExtensionRegistry::default())
593 .unwrap();
594
595 let json_input = Cursor::new(json_output);
597 let mut text_output = Vec::new();
598
599 let cli_to_text = Cli {
600 command: Commands::Convert {
601 input: "input.json".to_string(),
602 output: "output.substrait".to_string(),
603 from: Some(Format::Json),
604 to: Some(Format::Text),
605 show_literal_types: false,
606 show_expression_types: false,
607 verbose: false,
608 },
609 };
610
611 cli_to_text
612 .run_with_io(json_input, &mut text_output, &ExtensionRegistry::default())
613 .unwrap();
614
615 let output_content = String::from_utf8(text_output).unwrap();
616 assert!(output_content.contains("=== Plan"));
617 assert!(output_content.contains("Root[result]"));
618 }
619
620 #[test]
621 fn test_convert_with_protobuf_output() {
622 let input = Cursor::new(BASIC_PLAN);
623 let mut output = Vec::new();
624
625 let cli = Cli {
626 command: Commands::Convert {
627 input: "input.substrait".to_string(),
628 output: "output.pb".to_string(),
629 from: Some(Format::Text),
630 to: Some(Format::Protobuf),
631 show_literal_types: false,
632 show_expression_types: false,
633 verbose: false,
634 },
635 };
636
637 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
638 .unwrap();
639
640 assert!(!output.is_empty());
642
643 let output_string = String::from_utf8_lossy(&output);
645 assert!(!output_string.contains("=== Plan"));
646 }
647
648 #[test]
649 fn test_validate_command() {
650 let input = Cursor::new(BASIC_PLAN);
651 let mut output = Vec::new();
652
653 let cli = Cli {
654 command: Commands::Validate {
655 input: String::new(),
656 output: String::new(),
657 verbose: false,
658 },
659 };
660
661 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
662 .unwrap();
663
664 let output_content = String::from_utf8(output).unwrap();
665 assert!(output_content.contains("=== Plan"));
666 assert!(output_content.contains("Root[result]"));
667 assert!(output_content.contains("Project[$0, $1]"));
668 assert!(output_content.contains("Read[data => a:i64, b:string]"));
669 }
670
671 #[test]
672 fn test_validate_with_extensions() {
673 let input = Cursor::new(PLAN_WITH_EXTENSIONS);
674 let mut output = Vec::new();
675
676 let cli = Cli {
677 command: Commands::Validate {
678 input: String::new(),
679 output: String::new(),
680 verbose: false,
681 },
682 };
683
684 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
685 .unwrap();
686
687 let output_content = String::from_utf8(output).unwrap();
688 assert!(output_content.contains("=== Extensions"));
689 assert!(output_content.contains("=== Plan"));
690 assert!(output_content.contains("Root[result]"));
691 assert!(output_content.contains("Filter[gt($2, 100)"));
692 }
693
694 #[test]
695 fn test_convert_with_formatting_options() {
696 let input = Cursor::new(BASIC_PLAN);
697 let mut output = Vec::new();
698
699 let cli = Cli {
700 command: Commands::Convert {
701 input: "input.substrait".to_string(),
702 output: "output.substrait".to_string(),
703 from: Some(Format::Text),
704 to: Some(Format::Text),
705 show_literal_types: true,
706 show_expression_types: true,
707 verbose: false,
708 },
709 };
710
711 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
712 .unwrap();
713
714 let output_content = String::from_utf8(output).unwrap();
715 assert!(output_content.contains("=== Plan"));
716 assert!(output_content.contains("Root[result]"));
717 }
718
719 #[test]
720 fn test_auto_detect_from_extension() {
721 assert_eq!(Format::from_extension("plan.substrait"), Some(Format::Text));
723 assert_eq!(Format::from_extension("plan.txt"), Some(Format::Text));
724
725 assert_eq!(Format::from_extension("plan.json"), Some(Format::Json));
727
728 assert_eq!(Format::from_extension("plan.yaml"), Some(Format::Yaml));
730 assert_eq!(Format::from_extension("plan.yml"), Some(Format::Yaml));
731
732 assert_eq!(Format::from_extension("plan.pb"), Some(Format::Protobuf));
734 assert_eq!(Format::from_extension("plan.proto"), Some(Format::Protobuf));
735 assert_eq!(
736 Format::from_extension("plan.protobuf"),
737 Some(Format::Protobuf)
738 );
739
740 assert_eq!(Format::from_extension("plan.unknown"), None);
742 assert_eq!(Format::from_extension("plan"), None);
743
744 assert_eq!(Format::from_extension("-"), None);
746 }
747
748 #[test]
749 fn test_convert_with_auto_detection() {
750 let input = Cursor::new(BASIC_PLAN);
751 let mut output = Vec::new();
752
753 let cli = Cli {
754 command: Commands::Convert {
755 input: "input.substrait".to_string(),
756 output: "output.json".to_string(),
757 from: None, to: None, show_literal_types: false,
760 show_expression_types: false,
761 verbose: false,
762 },
763 };
764
765 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
766 .unwrap();
767
768 let output_content = String::from_utf8(output).unwrap();
769 assert!(output_content.contains("\"relations\""));
770 assert!(output_content.contains("\"root\""));
771 assert!(output_content.contains("\"project\""));
772 assert!(output_content.contains("\"read\""));
773 }
774
775 #[test]
776 fn test_auto_detection_error_unknown_input_extension() {
777 let input = Cursor::new(BASIC_PLAN);
778 let mut output = Vec::new();
779
780 let cli = Cli {
781 command: Commands::Convert {
782 input: "input.unknown".to_string(),
783 output: "output.json".to_string(),
784 from: None, to: None,
786 show_literal_types: false,
787 show_expression_types: false,
788 verbose: false,
789 },
790 };
791
792 let result = cli.run_with_io(input, &mut output, &ExtensionRegistry::default());
793 assert!(result.is_err());
794 assert!(
795 result
796 .unwrap_err()
797 .to_string()
798 .contains("Could not auto-detect input format")
799 );
800 }
801
802 #[test]
803 fn test_auto_detection_error_unknown_output_extension() {
804 let input = Cursor::new(BASIC_PLAN);
805 let mut output = Vec::new();
806
807 let cli = Cli {
808 command: Commands::Convert {
809 input: "input.substrait".to_string(),
810 output: "output.unknown".to_string(),
811 from: None,
812 to: None, show_literal_types: false,
814 show_expression_types: false,
815 verbose: false,
816 },
817 };
818
819 let result = cli.run_with_io(input, &mut output, &ExtensionRegistry::default());
820 assert!(result.is_err());
821 assert!(
822 result
823 .unwrap_err()
824 .to_string()
825 .contains("Could not auto-detect output format")
826 );
827 }
828
829 #[test]
830 fn test_explicit_format_overrides_auto_detection() {
831 let input = Cursor::new(BASIC_PLAN);
832 let mut output = Vec::new();
833
834 let cli = Cli {
835 command: Commands::Convert {
836 input: "input.json".to_string(), output: "output.pb".to_string(), from: Some(Format::Text), to: Some(Format::Text), show_literal_types: false,
841 show_expression_types: false,
842 verbose: false,
843 },
844 };
845
846 cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
847 .unwrap();
848
849 let output_content = String::from_utf8(output).unwrap();
850 assert!(output_content.contains("=== Plan"));
851 assert!(output_content.contains("Root[result]"));
852 }
853
854 #[test]
855 fn test_protobuf_roundtrip() {
856 let input = Cursor::new(BASIC_PLAN);
858 let mut protobuf_output = Vec::new();
859
860 let cli_to_protobuf = Cli {
861 command: Commands::Convert {
862 input: "input.substrait".to_string(),
863 output: "output.pb".to_string(),
864 from: Some(Format::Text),
865 to: Some(Format::Protobuf),
866 show_literal_types: false,
867 show_expression_types: false,
868 verbose: false,
869 },
870 };
871
872 cli_to_protobuf
873 .run_with_io(input, &mut protobuf_output, &ExtensionRegistry::default())
874 .unwrap();
875
876 let protobuf_input = Cursor::new(protobuf_output);
878 let mut text_output = Vec::new();
879
880 let cli_to_text = Cli {
881 command: Commands::Convert {
882 input: "input.pb".to_string(),
883 output: "output.substrait".to_string(),
884 from: Some(Format::Protobuf),
885 to: Some(Format::Text),
886 show_literal_types: false,
887 show_expression_types: false,
888 verbose: false,
889 },
890 };
891
892 cli_to_text
893 .run_with_io(
894 protobuf_input,
895 &mut text_output,
896 &ExtensionRegistry::default(),
897 )
898 .unwrap();
899
900 let output_content = String::from_utf8(text_output).unwrap();
901 assert!(output_content.contains("=== Plan"));
902 assert!(output_content.contains("Root[result]"));
903 assert!(output_content.contains("Read[data => a:i64, b:string]"));
904 }
905
    // Minimal prost message registered as a custom extension relation in the
    // registry-aware tests below.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestSource {
        // Single string payload, encoded as protobuf field 1.
        #[prost(string, tag = "1")]
        tag: String,
    }
917
918 impl prost::Name for TestSource {
919 const NAME: &'static str = "TestSource";
920 const PACKAGE: &'static str = "test";
921 fn full_name() -> String {
922 "test.TestSource".to_string()
923 }
924 fn type_url() -> String {
925 "type.googleapis.com/test.TestSource".to_string()
926 }
927 }
928
929 impl Explainable for TestSource {
930 fn name() -> &'static str {
931 "TestSource"
932 }
933
934 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
935 let mut extractor = args.extractor();
936 let tag: &str = extractor.expect_named_arg("tag")?;
937 extractor.check_exhausted()?;
938 Ok(TestSource {
939 tag: tag.to_string(),
940 })
941 }
942
943 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
944 let mut args = ExtensionArgs::default();
945 args.insert("tag", self.tag.clone());
946 args.output_columns.push(ExtensionColumn::Named {
947 name: "val".to_string(),
948 r#type: parse_type("i64"),
949 });
950 Ok(args)
951 }
952 }
953
954 fn make_extension_registry() -> ExtensionRegistry {
955 let mut registry = ExtensionRegistry::new();
956 registry.register_relation::<TestSource>().unwrap();
957 registry
958 }
959
    // Text-format plan whose only relation is the custom `TestSource`
    // extension leaf; parsing it needs a registry with `TestSource` registered.
    const PLAN_WITH_CUSTOM_EXTENSION: &str = r#"=== Plan
Root[val]
  ExtensionLeaf:TestSource[tag='hello' => val:i64]
"#;
964
965 #[test]
966 fn test_convert_text_to_text_with_extension_registry() {
967 let registry = make_extension_registry();
968 let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
969 let mut output = Vec::new();
970
971 let cli = Cli {
972 command: Commands::Convert {
973 input: "input.substrait".to_string(),
974 output: "output.substrait".to_string(),
975 from: Some(Format::Text),
976 to: Some(Format::Text),
977 show_literal_types: false,
978 show_expression_types: false,
979 verbose: false,
980 },
981 };
982
983 cli.run_with_io(input, &mut output, ®istry).unwrap();
984
985 let output_content = String::from_utf8(output).unwrap();
986 assert_eq!(output_content, PLAN_WITH_CUSTOM_EXTENSION);
987 }
988
989 #[test]
990 fn test_convert_text_to_json_with_extension_registry() {
991 let registry = make_extension_registry();
992 let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
993 let mut output = Vec::new();
994
995 let cli = Cli {
996 command: Commands::Convert {
997 input: "input.substrait".to_string(),
998 output: "output.json".to_string(),
999 from: Some(Format::Text),
1000 to: Some(Format::Json),
1001 show_literal_types: false,
1002 show_expression_types: false,
1003 verbose: false,
1004 },
1005 };
1006
1007 cli.run_with_io(input, &mut output, ®istry).unwrap();
1008
1009 let output_content = String::from_utf8(output).unwrap();
1010 assert!(output_content.contains("\"extensionLeaf\""));
1011 }
1012
1013 #[test]
1014 fn test_validate_with_extension_registry() {
1015 let registry = make_extension_registry();
1016 let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
1017 let mut output = Vec::new();
1018
1019 let cli = Cli {
1020 command: Commands::Validate {
1021 input: String::new(),
1022 output: String::new(),
1023 verbose: false,
1024 },
1025 };
1026
1027 cli.run_with_io(input, &mut output, ®istry).unwrap();
1028
1029 let output_content = String::from_utf8(output).unwrap();
1030 assert_eq!(output_content, PLAN_WITH_CUSTOM_EXTENSION);
1031 }
1032
1033 #[test]
1034 fn test_convert_text_fails_without_extension_registry() {
1035 let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
1037 let mut output = Vec::new();
1038
1039 let cli = Cli {
1040 command: Commands::Convert {
1041 input: "input.substrait".to_string(),
1042 output: "output.substrait".to_string(),
1043 from: Some(Format::Text),
1044 to: Some(Format::Text),
1045 show_literal_types: false,
1046 show_expression_types: false,
1047 verbose: false,
1048 },
1049 };
1050
1051 let result = cli.run_with_io(input, &mut output, &ExtensionRegistry::default());
1052 assert!(result.is_err());
1053 }
1054
1055 fn make_plan_with_invalid_function_ref() -> substrait::proto::Plan {
1057 const VALID_PLAN: &str = r#"=== Extensions
1058URNs:
1059 @ 1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml
1060Functions:
1061 # 10 @ 1: equal
1062
1063=== Plan
1064Root[result]
1065 Filter[equal($0, 42:i32) => $0]
1066 Read[data => a:i32]
1067"#;
1068
1069 let mut plan = parse(VALID_PLAN).expect("Failed to parse valid plan");
1070
1071 let rel_root = plan.relations.first_mut().unwrap();
1073 let plan_rel::RelType::Root(root) = rel_root.rel_type.as_mut().unwrap() else {
1074 panic!("Expected Root relation");
1075 };
1076 let rel = root.input.as_mut().unwrap();
1077 let RelType::Filter(filter) = rel.rel_type.as_mut().unwrap() else {
1078 panic!("Expected Filter relation");
1079 };
1080 let condition = filter.condition.as_mut().unwrap();
1081 let RexType::ScalarFunction(func) = condition.rex_type.as_mut().unwrap() else {
1082 panic!("Expected ScalarFunction");
1083 };
1084 func.function_reference = 999; plan
1087 }
1088
1089 #[test]
1090 fn test_write_plan_reports_formatting_issues() {
1091 let plan = make_plan_with_invalid_function_ref();
1092 let mut output = Vec::new();
1093
1094 let result = Format::Text.write_plan(
1095 &mut output,
1096 &plan,
1097 &OutputOptions::default(),
1098 &ExtensionRegistry::default(),
1099 );
1100
1101 let outcome = result.expect("write_plan should not return hard error");
1103 assert!(
1104 matches!(outcome, Outcome::HadFormattingIssues(ref errors) if !errors.is_empty()),
1105 "Expected HadFormattingIssues with errors, got {outcome:?}"
1106 );
1107 assert!(
1109 !output.is_empty(),
1110 "Output should be written even with issues"
1111 );
1112 }
1113}