substrait_explain/
cli.rs

1use std::fs;
2use std::io::{self, Read, Write};
3use std::process::ExitCode;
4
5use anyhow::{Context, Result};
6use clap::{Parser, Subcommand};
7use prost::Message;
8
9use crate::extensions::ExtensionRegistry;
10use crate::{FormatError, OutputOptions, Visibility, format_with_registry, parse_with_registry};
11
/// The outcome of a CLI operation that ran to completion.
///
/// Distinguishes between complete success and "soft failures" like formatting
/// issues where output was still written but there were problems. Hard
/// failures are not modelled here; the functions producing an `Outcome`
/// report those through the `Err` side of their `Result`.
#[derive(Debug)]
pub enum Outcome {
    /// Operation completed successfully with no issues.
    Success,
    /// Output was written, but there were formatting issues.
    ///
    /// Carries the collected formatting errors. The CLI entry point prints
    /// each of them to stderr and exits with a failure code.
    HadFormattingIssues(Vec<FormatError>),
}
23
// Top-level command-line arguments, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used here on purpose. With clap's derive
// macros, `///` doc comments become user-visible help text, so adding them
// would change the program's output.
#[derive(Parser)]
#[command(name = "substrait-explain")]
#[command(about = "A CLI for parsing and formatting Substrait query plans")]
#[command(version)]
pub struct Cli {
    // The subcommand chosen on the command line (see `Commands`).
    #[command(subcommand)]
    pub command: Commands,
}
32
33impl Cli {
34    /// Run the CLI and return an exit code.
35    ///
36    /// Errors are printed to stderr.
37    pub fn run(self) -> ExitCode {
38        self.run_with_extensions(ExtensionRegistry::default())
39    }
40
41    /// Run the CLI with a custom extension registry and return an exit code.
42    ///
43    /// Use this when embedding the CLI in a binary that registers custom
44    /// extension relation types:
45    ///
46    /// ```rust,ignore
47    /// let mut registry = ExtensionRegistry::new();
48    /// registry.register_relation::<MyCustomScan>().unwrap();
49    /// Cli::parse().run_with_extensions(registry)
50    /// ```
51    pub fn run_with_extensions(self, registry: ExtensionRegistry) -> ExitCode {
52        match self.run_inner(&registry) {
53            Ok(Outcome::Success) => ExitCode::SUCCESS,
54            Ok(Outcome::HadFormattingIssues(errors)) => {
55                eprintln!("Formatting issues:");
56                for error in errors {
57                    eprintln!("  {error}");
58                }
59                ExitCode::FAILURE
60            }
61            Err(e) => {
62                eprintln!("Error: {e:?}");
63                ExitCode::FAILURE
64            }
65        }
66    }
67
68    fn run_inner(self, registry: &ExtensionRegistry) -> Result<Outcome> {
69        match &self.command {
70            Commands::Convert {
71                input,
72                output,
73                from,
74                to,
75                show_literal_types,
76                show_expression_types,
77                verbose,
78            } => {
79                let reader = get_reader(input)
80                    .with_context(|| format!("Failed to open input file: {input}"))?;
81                let writer = get_writer(output)
82                    .with_context(|| format!("Failed to create output file: {output}"))?;
83                let options =
84                    self.create_output_options(*show_literal_types, *show_expression_types);
85                let from_format = self.resolve_input_format(from, input)?;
86                let to_format = self.resolve_output_format(to, output)?;
87                self.run_convert_with_io(
88                    reader,
89                    writer,
90                    &from_format,
91                    &to_format,
92                    &options,
93                    *verbose,
94                    registry,
95                )
96            }
97
98            Commands::Validate {
99                input,
100                output,
101                verbose,
102            } => {
103                let reader = get_reader(input)
104                    .with_context(|| format!("Failed to open input file: {input}"))?;
105                let writer = get_writer(output)
106                    .with_context(|| format!("Failed to create output file: {output}"))?;
107                self.run_validate_with_io(reader, writer, *verbose, registry)
108            }
109        }
110    }
111
112    /// Run CLI with provided readers and writers for testing
113    pub fn run_with_io<R: Read, W: Write>(
114        &self,
115        reader: R,
116        writer: W,
117        registry: &ExtensionRegistry,
118    ) -> Result<Outcome> {
119        match &self.command {
120            Commands::Convert {
121                input,
122                output,
123                from,
124                to,
125                show_literal_types,
126                show_expression_types,
127                verbose,
128                ..
129            } => {
130                let options =
131                    self.create_output_options(*show_literal_types, *show_expression_types);
132                let from_format = self.resolve_input_format(from, input)?;
133                let to_format = self.resolve_output_format(to, output)?;
134                self.run_convert_with_io(
135                    reader,
136                    writer,
137                    &from_format,
138                    &to_format,
139                    &options,
140                    *verbose,
141                    registry,
142                )
143            }
144
145            Commands::Validate { verbose, .. } => {
146                self.run_validate_with_io(reader, writer, *verbose, registry)
147            }
148        }
149    }
150
151    fn create_output_options(
152        &self,
153        show_literal_types: bool,
154        show_expression_types: bool,
155    ) -> OutputOptions {
156        let mut options = OutputOptions::default();
157
158        if show_literal_types {
159            options.literal_types = Visibility::Always;
160        }
161
162        if show_expression_types {
163            options.fn_types = true;
164        }
165
166        options
167    }
168
169    fn resolve_input_format(&self, format: &Option<Format>, input_path: &str) -> Result<Format> {
170        match format {
171            Some(fmt) => Ok(fmt.clone()),
172            None => Format::from_extension(input_path).ok_or_else(|| {
173                anyhow::anyhow!(
174                    "Could not auto-detect input format from file extension. \
175                     Please specify format explicitly with -f/--from. \
176                     Supported formats: text, json, yaml, protobuf/proto/pb"
177                )
178            }),
179        }
180    }
181
182    fn resolve_output_format(&self, format: &Option<Format>, output_path: &str) -> Result<Format> {
183        match format {
184            Some(fmt) => Ok(fmt.clone()),
185            None => Format::from_extension(output_path).ok_or_else(|| {
186                anyhow::anyhow!(
187                    "Could not auto-detect output format from file extension. \
188                     Please specify format explicitly with -t/--to. \
189                     Supported formats: text, json, yaml, protobuf/proto/pb"
190                )
191            }),
192        }
193    }
194
195    // TODO: this could use a refactor; the too_many_arguments tells us
196    // something useful here. We could perhaps add a type containing (registry,
197    // formats, options) or something
198    #[allow(clippy::too_many_arguments)]
199    fn run_convert_with_io<R: Read, W: Write>(
200        &self,
201        reader: R,
202        writer: W,
203        from: &Format,
204        to: &Format,
205        options: &OutputOptions,
206        verbose: bool,
207        registry: &ExtensionRegistry,
208    ) -> Result<Outcome> {
209        // Read input based on format
210        let plan = from.read_plan(reader, registry).with_context(|| {
211            format!(
212                "Failed to parse input as {} format",
213                format!("{from:?}").to_lowercase()
214            )
215        })?;
216
217        // Write output based on format
218        let outcome = to
219            .write_plan(writer, &plan, options, registry)
220            .with_context(|| {
221                format!(
222                    "Failed to write output as {} format",
223                    format!("{to:?}").to_lowercase()
224                )
225            })?;
226
227        if verbose && matches!(outcome, Outcome::Success) {
228            eprintln!("Successfully converted from {from:?} to {to:?}");
229        }
230
231        Ok(outcome)
232    }
233
234    fn run_validate_with_io<R: Read, W: Write>(
235        &self,
236        reader: R,
237        writer: W,
238        verbose: bool,
239        registry: &ExtensionRegistry,
240    ) -> Result<Outcome> {
241        let plan = Format::Text
242            .read_plan(reader, registry)
243            .with_context(|| "Failed to parse input as Substrait text format")?;
244
245        let outcome = Format::Text
246            .write_plan(writer, &plan, &OutputOptions::default(), registry)
247            .with_context(|| "Failed to format plan as Substrait text format")?;
248
249        if verbose && matches!(outcome, Outcome::Success) {
250            eprintln!("Successfully validated plan");
251        }
252
253        Ok(outcome)
254    }
255}
256
// Subcommands of the CLI.
//
// NOTE: the `///` doc comments in this enum are picked up by clap's derive
// macro and shown to users as help text, so they are part of the program's
// output. Maintainer-only notes belong in plain `//` comments like this one.
#[derive(Subcommand)]
pub enum Commands {
    /// Convert between different Substrait plan formats
    ///
    /// Format auto-detection:
    ///   If -f/--from or -t/--to are not specified, formats will be auto-detected
    ///   from file extensions:
    ///     .substrait, .txt    -> text format
    ///     .json               -> json format
    ///     .yaml, .yml         -> yaml format
    ///     .pb, .proto, .protobuf -> protobuf format
    ///
    /// Plan formats:
    ///   text     - Human-readable Substrait text format
    ///   json     - JSON serialized protobuf
    ///   yaml     - YAML serialized protobuf
    ///   protobuf - Binary protobuf format
    Convert {
        /// Input file (use - for stdin)
        #[arg(short, long, default_value = "-")]
        input: String,
        /// Output file (use - for stdout)
        #[arg(short, long, default_value = "-")]
        output: String,
        /// Input format: text, json, yaml, protobuf/proto/pb (auto-detected from file extension if not specified)
        // `None` means "detect from the extension of `input`".
        #[arg(short = 'f', long)]
        from: Option<Format>,
        /// Output format: text, json, yaml, protobuf/proto/pb (auto-detected from file extension if not specified)
        // `None` means "detect from the extension of `output`".
        #[arg(short = 't', long)]
        to: Option<Format>,
        /// Show literal types (text output only)
        #[arg(long)]
        show_literal_types: bool,
        /// Show expression types (text output only)
        #[arg(long)]
        show_expression_types: bool,
        /// Verbose output
        #[arg(short, long)]
        verbose: bool,
    },
    /// Validate text format by parsing and formatting (roundtrip test)
    // Input is always read, and output always written, in the text format.
    Validate {
        /// Input file (use - for stdin)
        #[arg(short, long, default_value = "-")]
        input: String,
        /// Output file (use - for stdout)
        #[arg(short, long, default_value = "-")]
        output: String,
        /// Verbose output
        #[arg(short, long)]
        verbose: bool,
    },
}
310
/// A serialization format that a Substrait plan can be read from or written to.
#[derive(Clone, Debug, PartialEq)]
pub enum Format {
    /// Human-readable Substrait text format.
    Text,
    /// JSON serialized protobuf.
    Json,
    /// YAML serialized protobuf.
    Yaml,
    /// Binary protobuf format.
    Protobuf,
}
318
319impl std::str::FromStr for Format {
320    type Err = String;
321
322    fn from_str(s: &str) -> Result<Self, Self::Err> {
323        match s.to_lowercase().as_str() {
324            "text" => Ok(Format::Text),
325            "json" => Ok(Format::Json),
326            "yaml" => Ok(Format::Yaml),
327            "protobuf" | "proto" | "pb" => Ok(Format::Protobuf),
328            _ => Err(format!(
329                "Invalid format: '{s}'. Supported formats: text, json, yaml, protobuf/proto/pb"
330            )),
331        }
332    }
333}
334
335impl Format {
336    /// Detect format from file extension
337    pub fn from_extension(path: &str) -> Option<Format> {
338        if path == "-" {
339            return None; // stdin/stdout - no extension
340        }
341
342        let extension = std::path::Path::new(path)
343            .extension()
344            .and_then(|ext| ext.to_str())
345            .map(|ext| ext.to_lowercase());
346
347        match extension.as_deref() {
348            Some("substrait") | Some("txt") => Some(Format::Text),
349            Some("json") => Some(Format::Json),
350            Some("yaml") | Some("yml") => Some(Format::Yaml),
351            Some("pb") | Some("proto") | Some("protobuf") => Some(Format::Protobuf),
352            _ => None,
353        }
354    }
355
356    pub fn read_plan<R: Read>(
357        &self,
358        reader: R,
359        registry: &ExtensionRegistry,
360    ) -> Result<substrait::proto::Plan> {
361        match self {
362            Format::Text => {
363                let input_text = read_text_input(reader)?;
364                Ok(parse_with_registry(&input_text, registry)?)
365            }
366            Format::Json => {
367                let input_text = read_text_input(reader)?;
368                let pool = crate::json::build_descriptor_pool(&registry.descriptors())?;
369                crate::json::parse_json(&input_text, &pool)
370            }
371            Format::Yaml => {
372                #[cfg(feature = "serde")]
373                {
374                    let input_text = read_text_input(reader)?;
375                    Ok(serde_yaml::from_str(&input_text)?)
376                }
377                #[cfg(not(feature = "serde"))]
378                {
379                    Err("YAML support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
380                }
381            }
382            Format::Protobuf => {
383                let input_bytes = read_binary_input(reader)?;
384                Ok(substrait::proto::Plan::decode(&input_bytes[..])?)
385            }
386        }
387    }
388
389    pub fn write_plan<W: Write>(
390        &self,
391        writer: W,
392        plan: &substrait::proto::Plan,
393        options: &OutputOptions,
394        registry: &ExtensionRegistry,
395    ) -> Result<Outcome> {
396        match self {
397            Format::Text => {
398                let (text, errors) = format_with_registry(plan, options, registry);
399
400                // Write output first (best-effort)
401                write_text_output(writer, &text)?;
402
403                // Return outcome based on whether there were formatting issues
404                if errors.is_empty() {
405                    Ok(Outcome::Success)
406                } else {
407                    Ok(Outcome::HadFormattingIssues(errors))
408                }
409            }
410            Format::Json => {
411                #[cfg(feature = "serde")]
412                {
413                    let json = serde_json::to_string_pretty(plan)?;
414                    write_text_output(writer, &json)?;
415                    Ok(Outcome::Success)
416                }
417                #[cfg(not(feature = "serde"))]
418                {
419                    Err("JSON support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
420                }
421            }
422            Format::Yaml => {
423                #[cfg(feature = "serde")]
424                {
425                    let yaml = serde_yaml::to_string(plan)?;
426                    write_text_output(writer, &yaml)?;
427                    Ok(Outcome::Success)
428                }
429                #[cfg(not(feature = "serde"))]
430                {
431                    Err("YAML support requires the 'serde' feature. Install with: cargo install substrait-explain --features cli,serde".into())
432                }
433            }
434            Format::Protobuf => {
435                let bytes = plan.encode_to_vec();
436                write_binary_output(writer, &bytes)?;
437                Ok(Outcome::Success)
438            }
439        }
440    }
441}
442
/// Drain `reader` to its end and return the contents as a `String`.
///
/// Fails if reading fails or the data is not valid UTF-8.
fn read_text_input<R>(mut reader: R) -> Result<String>
where
    R: Read,
{
    let mut text = String::new();
    let _bytes_read = reader.read_to_string(&mut text)?;
    Ok(text)
}
449
/// Drain `reader` to its end and return the raw bytes.
fn read_binary_input<R>(mut reader: R) -> Result<Vec<u8>>
where
    R: Read,
{
    let mut bytes = Vec::new();
    let _bytes_read = reader.read_to_end(&mut bytes)?;
    Ok(bytes)
}
456
/// Write `content` to `writer` as UTF-8 bytes, then flush the writer.
///
/// # Errors
///
/// Returns an error if writing or flushing fails.
fn write_text_output<W: Write>(mut writer: W, content: &str) -> Result<()> {
    writer.write_all(content.as_bytes())?;
    // Flush explicitly: a buffering writer would otherwise flush on drop,
    // where any I/O error is silently discarded.
    writer.flush()?;
    Ok(())
}
462
/// Write the bytes in `content` to `writer`, then flush the writer.
///
/// # Errors
///
/// Returns an error if writing or flushing fails.
fn write_binary_output<W: Write>(mut writer: W, content: &[u8]) -> Result<()> {
    writer.write_all(content)?;
    // Flush explicitly: a buffering writer would otherwise flush on drop,
    // where any I/O error is silently discarded.
    writer.flush()?;
    Ok(())
}
468
/// Open the input source named by `path`: stdin for `"-"`, otherwise the
/// file at that path.
///
/// Fails if the file cannot be opened.
fn get_reader(path: &str) -> Result<Box<dyn Read>> {
    let source: Box<dyn Read> = match path {
        "-" => Box::new(io::stdin()),
        file_path => Box::new(fs::File::open(file_path)?),
    };
    Ok(source)
}
477
/// Open the output sink named by `path`: stdout for `"-"`, otherwise the
/// file at that path, opened with `fs::File::create`.
///
/// Fails if the file cannot be created.
fn get_writer(path: &str) -> Result<Box<dyn Write>> {
    let sink: Box<dyn Write> = match path {
        "-" => Box::new(io::stdout()),
        file_path => Box::new(fs::File::create(file_path)?),
    };
    Ok(sink)
}
486
487#[cfg(test)]
488mod tests {
489    use std::io::Cursor;
490
491    use substrait::proto::expression::RexType;
492    use substrait::proto::plan_rel;
493    use substrait::proto::rel::RelType;
494
495    use super::*;
496    use crate::extensions::{
497        Explainable, ExtensionArgs, ExtensionColumn, ExtensionError, ExtensionRelationType,
498        ExtensionValue,
499    };
500    use crate::parse;
501
502    const BASIC_PLAN: &str = r#"=== Plan
503Root[result]
504  Project[$0, $1]
505    Read[data => a:i64, b:string]
506"#;
507
508    const PLAN_WITH_EXTENSIONS: &str = r#"=== Extensions
509URNs:
510  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
511Functions:
512  # 10 @  1: gt
513
514=== Plan
515Root[result]
516  Filter[gt($2, 100) => $0, $1, $2]
517    Project[$0, $1, $2]
518      Read[data => a:i64, b:string, c:i32]
519"#;
520
521    #[test]
522    fn test_convert_text_to_text() {
523        let input = Cursor::new(BASIC_PLAN);
524        let mut output = Vec::new();
525
526        let cli = Cli {
527            command: Commands::Convert {
528                input: "input.substrait".to_string(),
529                output: "output.substrait".to_string(),
530                from: Some(Format::Text),
531                to: Some(Format::Text),
532                show_literal_types: false,
533                show_expression_types: false,
534                verbose: false,
535            },
536        };
537
538        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
539            .unwrap();
540
541        let output_content = String::from_utf8(output).unwrap();
542        assert!(output_content.contains("=== Plan"));
543        assert!(output_content.contains("Root[result]"));
544        assert!(output_content.contains("Project[$0, $1]"));
545        assert!(output_content.contains("Read[data => a:i64, b:string]"));
546    }
547
548    #[test]
549    fn test_convert_text_to_json() {
550        let input = Cursor::new(BASIC_PLAN);
551        let mut output = Vec::new();
552
553        let cli = Cli {
554            command: Commands::Convert {
555                input: "input.substrait".to_string(),
556                output: "output.json".to_string(),
557                from: Some(Format::Text),
558                to: Some(Format::Json),
559                show_literal_types: false,
560                show_expression_types: false,
561                verbose: false,
562            },
563        };
564
565        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
566            .unwrap();
567
568        let output_content = String::from_utf8(output).unwrap();
569        assert!(output_content.contains("\"relations\""));
570        assert!(output_content.contains("\"root\""));
571        assert!(output_content.contains("\"project\""));
572        assert!(output_content.contains("\"read\""));
573    }
574
575    #[test]
576    fn test_convert_json_to_text() {
577        // First convert text to JSON
578        let input = Cursor::new(BASIC_PLAN);
579        let mut json_output = Vec::new();
580
581        let cli_to_json = Cli {
582            command: Commands::Convert {
583                input: "input.substrait".to_string(),
584                output: "output.json".to_string(),
585                from: Some(Format::Text),
586                to: Some(Format::Json),
587                show_literal_types: false,
588                show_expression_types: false,
589                verbose: false,
590            },
591        };
592
593        cli_to_json
594            .run_with_io(input, &mut json_output, &ExtensionRegistry::default())
595            .unwrap();
596
597        // Now convert JSON back to text
598        let json_input = Cursor::new(json_output);
599        let mut text_output = Vec::new();
600
601        let cli_to_text = Cli {
602            command: Commands::Convert {
603                input: "input.json".to_string(),
604                output: "output.substrait".to_string(),
605                from: Some(Format::Json),
606                to: Some(Format::Text),
607                show_literal_types: false,
608                show_expression_types: false,
609                verbose: false,
610            },
611        };
612
613        cli_to_text
614            .run_with_io(json_input, &mut text_output, &ExtensionRegistry::default())
615            .unwrap();
616
617        let output_content = String::from_utf8(text_output).unwrap();
618        assert!(output_content.contains("=== Plan"));
619        assert!(output_content.contains("Root[result]"));
620    }
621
622    #[test]
623    fn test_convert_with_protobuf_output() {
624        let input = Cursor::new(BASIC_PLAN);
625        let mut output = Vec::new();
626
627        let cli = Cli {
628            command: Commands::Convert {
629                input: "input.substrait".to_string(),
630                output: "output.pb".to_string(),
631                from: Some(Format::Text),
632                to: Some(Format::Protobuf),
633                show_literal_types: false,
634                show_expression_types: false,
635                verbose: false,
636            },
637        };
638
639        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
640            .unwrap();
641
642        // Protobuf output should be binary, so we just check that it's not empty
643        assert!(!output.is_empty());
644
645        // Should not contain readable text
646        let output_string = String::from_utf8_lossy(&output);
647        assert!(!output_string.contains("=== Plan"));
648    }
649
650    #[test]
651    fn test_validate_command() {
652        let input = Cursor::new(BASIC_PLAN);
653        let mut output = Vec::new();
654
655        let cli = Cli {
656            command: Commands::Validate {
657                input: String::new(),
658                output: String::new(),
659                verbose: false,
660            },
661        };
662
663        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
664            .unwrap();
665
666        let output_content = String::from_utf8(output).unwrap();
667        assert!(output_content.contains("=== Plan"));
668        assert!(output_content.contains("Root[result]"));
669        assert!(output_content.contains("Project[$0, $1]"));
670        assert!(output_content.contains("Read[data => a:i64, b:string]"));
671    }
672
673    #[test]
674    fn test_validate_with_extensions() {
675        let input = Cursor::new(PLAN_WITH_EXTENSIONS);
676        let mut output = Vec::new();
677
678        let cli = Cli {
679            command: Commands::Validate {
680                input: String::new(),
681                output: String::new(),
682                verbose: false,
683            },
684        };
685
686        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
687            .unwrap();
688
689        let output_content = String::from_utf8(output).unwrap();
690        assert!(output_content.contains("=== Extensions"));
691        assert!(output_content.contains("=== Plan"));
692        assert!(output_content.contains("Root[result]"));
693        assert!(output_content.contains("Filter[gt($2, 100)"));
694    }
695
696    #[test]
697    fn test_convert_with_formatting_options() {
698        let input = Cursor::new(BASIC_PLAN);
699        let mut output = Vec::new();
700
701        let cli = Cli {
702            command: Commands::Convert {
703                input: "input.substrait".to_string(),
704                output: "output.substrait".to_string(),
705                from: Some(Format::Text),
706                to: Some(Format::Text),
707                show_literal_types: true,
708                show_expression_types: true,
709                verbose: false,
710            },
711        };
712
713        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
714            .unwrap();
715
716        let output_content = String::from_utf8(output).unwrap();
717        assert!(output_content.contains("=== Plan"));
718        assert!(output_content.contains("Root[result]"));
719    }
720
721    #[test]
722    fn test_auto_detect_from_extension() {
723        // Test auto-detection of text format
724        assert_eq!(Format::from_extension("plan.substrait"), Some(Format::Text));
725        assert_eq!(Format::from_extension("plan.txt"), Some(Format::Text));
726
727        // Test auto-detection of JSON format
728        assert_eq!(Format::from_extension("plan.json"), Some(Format::Json));
729
730        // Test auto-detection of YAML format
731        assert_eq!(Format::from_extension("plan.yaml"), Some(Format::Yaml));
732        assert_eq!(Format::from_extension("plan.yml"), Some(Format::Yaml));
733
734        // Test auto-detection of protobuf format
735        assert_eq!(Format::from_extension("plan.pb"), Some(Format::Protobuf));
736        assert_eq!(Format::from_extension("plan.proto"), Some(Format::Protobuf));
737        assert_eq!(
738            Format::from_extension("plan.protobuf"),
739            Some(Format::Protobuf)
740        );
741
742        // Test unknown extensions
743        assert_eq!(Format::from_extension("plan.unknown"), None);
744        assert_eq!(Format::from_extension("plan"), None);
745
746        // Test stdin/stdout
747        assert_eq!(Format::from_extension("-"), None);
748    }
749
750    #[test]
751    fn test_convert_with_auto_detection() {
752        let input = Cursor::new(BASIC_PLAN);
753        let mut output = Vec::new();
754
755        let cli = Cli {
756            command: Commands::Convert {
757                input: "input.substrait".to_string(),
758                output: "output.json".to_string(),
759                from: None, // Auto-detect from extension
760                to: None,   // Auto-detect from extension
761                show_literal_types: false,
762                show_expression_types: false,
763                verbose: false,
764            },
765        };
766
767        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
768            .unwrap();
769
770        let output_content = String::from_utf8(output).unwrap();
771        assert!(output_content.contains("\"relations\""));
772        assert!(output_content.contains("\"root\""));
773        assert!(output_content.contains("\"project\""));
774        assert!(output_content.contains("\"read\""));
775    }
776
777    #[test]
778    fn test_auto_detection_error_unknown_input_extension() {
779        let input = Cursor::new(BASIC_PLAN);
780        let mut output = Vec::new();
781
782        let cli = Cli {
783            command: Commands::Convert {
784                input: "input.unknown".to_string(),
785                output: "output.json".to_string(),
786                from: None, // Should fail auto-detection
787                to: None,
788                show_literal_types: false,
789                show_expression_types: false,
790                verbose: false,
791            },
792        };
793
794        let result = cli.run_with_io(input, &mut output, &ExtensionRegistry::default());
795        assert!(result.is_err());
796        assert!(
797            result
798                .unwrap_err()
799                .to_string()
800                .contains("Could not auto-detect input format")
801        );
802    }
803
804    #[test]
805    fn test_auto_detection_error_unknown_output_extension() {
806        let input = Cursor::new(BASIC_PLAN);
807        let mut output = Vec::new();
808
809        let cli = Cli {
810            command: Commands::Convert {
811                input: "input.substrait".to_string(),
812                output: "output.unknown".to_string(),
813                from: None,
814                to: None, // Should fail auto-detection
815                show_literal_types: false,
816                show_expression_types: false,
817                verbose: false,
818            },
819        };
820
821        let result = cli.run_with_io(input, &mut output, &ExtensionRegistry::default());
822        assert!(result.is_err());
823        assert!(
824            result
825                .unwrap_err()
826                .to_string()
827                .contains("Could not auto-detect output format")
828        );
829    }
830
831    #[test]
832    fn test_explicit_format_overrides_auto_detection() {
833        let input = Cursor::new(BASIC_PLAN);
834        let mut output = Vec::new();
835
836        let cli = Cli {
837            command: Commands::Convert {
838                input: "input.json".to_string(), // Would auto-detect as JSON
839                output: "output.pb".to_string(), // Would auto-detect as Protobuf
840                from: Some(Format::Text),        // Explicit override
841                to: Some(Format::Text),          // Explicit override
842                show_literal_types: false,
843                show_expression_types: false,
844                verbose: false,
845            },
846        };
847
848        cli.run_with_io(input, &mut output, &ExtensionRegistry::default())
849            .unwrap();
850
851        let output_content = String::from_utf8(output).unwrap();
852        assert!(output_content.contains("=== Plan"));
853        assert!(output_content.contains("Root[result]"));
854    }
855
856    #[test]
857    fn test_protobuf_roundtrip() {
858        // Convert text to protobuf
859        let input = Cursor::new(BASIC_PLAN);
860        let mut protobuf_output = Vec::new();
861
862        let cli_to_protobuf = Cli {
863            command: Commands::Convert {
864                input: "input.substrait".to_string(),
865                output: "output.pb".to_string(),
866                from: Some(Format::Text),
867                to: Some(Format::Protobuf),
868                show_literal_types: false,
869                show_expression_types: false,
870                verbose: false,
871            },
872        };
873
874        cli_to_protobuf
875            .run_with_io(input, &mut protobuf_output, &ExtensionRegistry::default())
876            .unwrap();
877
878        // Convert protobuf back to text
879        let protobuf_input = Cursor::new(protobuf_output);
880        let mut text_output = Vec::new();
881
882        let cli_to_text = Cli {
883            command: Commands::Convert {
884                input: "input.pb".to_string(),
885                output: "output.substrait".to_string(),
886                from: Some(Format::Protobuf),
887                to: Some(Format::Text),
888                show_literal_types: false,
889                show_expression_types: false,
890                verbose: false,
891            },
892        };
893
894        cli_to_text
895            .run_with_io(
896                protobuf_input,
897                &mut text_output,
898                &ExtensionRegistry::default(),
899            )
900            .unwrap();
901
902        let output_content = String::from_utf8(text_output).unwrap();
903        assert!(output_content.contains("=== Plan"));
904        assert!(output_content.contains("Root[result]"));
905        assert!(output_content.contains("Read[data => a:i64, b:string]"));
906    }
907
908    // -----------------------------------------------------------------
909    // Minimal test extension for verifying registry-aware CLI parsing
910    // -----------------------------------------------------------------
911
    /// A minimal ExtensionLeaf with one named argument, used to verify
    /// that CLI commands pass the registry through to the text parser.
    ///
    /// It is registered by `make_extension_registry` and appears in
    /// `PLAN_WITH_CUSTOM_EXTENSION` as `ExtensionLeaf:TestSource[...]`.
    #[derive(Clone, PartialEq, prost::Message)]
    struct TestSource {
        // Protobuf string field number 1; populated from the named argument
        // `tag` when the relation is built from parsed arguments.
        #[prost(string, tag = "1")]
        tag: String,
    }
919
920    impl prost::Name for TestSource {
921        const NAME: &'static str = "TestSource";
922        const PACKAGE: &'static str = "test";
923        fn full_name() -> String {
924            "test.TestSource".to_string()
925        }
926        fn type_url() -> String {
927            "type.googleapis.com/test.TestSource".to_string()
928        }
929    }
930
931    impl Explainable for TestSource {
932        fn name() -> &'static str {
933            "TestSource"
934        }
935
936        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
937            let mut extractor = args.extractor();
938            let tag: &str = extractor.expect_named_arg("tag")?;
939            extractor.check_exhausted()?;
940            Ok(TestSource {
941                tag: tag.to_string(),
942            })
943        }
944
945        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
946            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
947            args.named
948                .insert("tag".to_string(), ExtensionValue::String(self.tag.clone()));
949            args.output_columns.push(ExtensionColumn::Named {
950                name: "val".to_string(),
951                type_spec: "i64".to_string(),
952            });
953            Ok(args)
954        }
955    }
956
957    fn make_extension_registry() -> ExtensionRegistry {
958        let mut registry = ExtensionRegistry::new();
959        registry.register_relation::<TestSource>().unwrap();
960        registry
961    }
962
963    const PLAN_WITH_CUSTOM_EXTENSION: &str = r#"=== Plan
964Root[val]
965  ExtensionLeaf:TestSource[tag='hello' => val:i64]
966"#;
967
968    #[test]
969    fn test_convert_text_to_text_with_extension_registry() {
970        let registry = make_extension_registry();
971        let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
972        let mut output = Vec::new();
973
974        let cli = Cli {
975            command: Commands::Convert {
976                input: "input.substrait".to_string(),
977                output: "output.substrait".to_string(),
978                from: Some(Format::Text),
979                to: Some(Format::Text),
980                show_literal_types: false,
981                show_expression_types: false,
982                verbose: false,
983            },
984        };
985
986        cli.run_with_io(input, &mut output, &registry).unwrap();
987
988        let output_content = String::from_utf8(output).unwrap();
989        assert_eq!(output_content, PLAN_WITH_CUSTOM_EXTENSION);
990    }
991
992    #[test]
993    fn test_convert_text_to_json_with_extension_registry() {
994        let registry = make_extension_registry();
995        let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
996        let mut output = Vec::new();
997
998        let cli = Cli {
999            command: Commands::Convert {
1000                input: "input.substrait".to_string(),
1001                output: "output.json".to_string(),
1002                from: Some(Format::Text),
1003                to: Some(Format::Json),
1004                show_literal_types: false,
1005                show_expression_types: false,
1006                verbose: false,
1007            },
1008        };
1009
1010        cli.run_with_io(input, &mut output, &registry).unwrap();
1011
1012        let output_content = String::from_utf8(output).unwrap();
1013        assert!(output_content.contains("\"extensionLeaf\""));
1014    }
1015
1016    #[test]
1017    fn test_validate_with_extension_registry() {
1018        let registry = make_extension_registry();
1019        let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
1020        let mut output = Vec::new();
1021
1022        let cli = Cli {
1023            command: Commands::Validate {
1024                input: String::new(),
1025                output: String::new(),
1026                verbose: false,
1027            },
1028        };
1029
1030        cli.run_with_io(input, &mut output, &registry).unwrap();
1031
1032        let output_content = String::from_utf8(output).unwrap();
1033        assert_eq!(output_content, PLAN_WITH_CUSTOM_EXTENSION);
1034    }
1035
1036    #[test]
1037    fn test_convert_text_fails_without_extension_registry() {
1038        // Without the registry, parsing a plan with custom extensions should fail
1039        let input = Cursor::new(PLAN_WITH_CUSTOM_EXTENSION);
1040        let mut output = Vec::new();
1041
1042        let cli = Cli {
1043            command: Commands::Convert {
1044                input: "input.substrait".to_string(),
1045                output: "output.substrait".to_string(),
1046                from: Some(Format::Text),
1047                to: Some(Format::Text),
1048                show_literal_types: false,
1049                show_expression_types: false,
1050                verbose: false,
1051            },
1052        };
1053
1054        let result = cli.run_with_io(input, &mut output, &ExtensionRegistry::default());
1055        assert!(result.is_err());
1056    }
1057
1058    /// Creates a plan with an invalid function reference that will cause formatting errors.
1059    fn make_plan_with_invalid_function_ref() -> substrait::proto::Plan {
1060        const VALID_PLAN: &str = r#"=== Extensions
1061URNs:
1062  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml
1063Functions:
1064  # 10 @  1: equal
1065
1066=== Plan
1067Root[result]
1068  Filter[equal($0, 42:i32) => $0]
1069    Read[data => a:i32]
1070"#;
1071
1072        let mut plan = parse(VALID_PLAN).expect("Failed to parse valid plan");
1073
1074        // Navigate to the function and corrupt its reference
1075        let rel_root = plan.relations.first_mut().unwrap();
1076        let plan_rel::RelType::Root(root) = rel_root.rel_type.as_mut().unwrap() else {
1077            panic!("Expected Root relation");
1078        };
1079        let rel = root.input.as_mut().unwrap();
1080        let RelType::Filter(filter) = rel.rel_type.as_mut().unwrap() else {
1081            panic!("Expected Filter relation");
1082        };
1083        let condition = filter.condition.as_mut().unwrap();
1084        let RexType::ScalarFunction(func) = condition.rex_type.as_mut().unwrap() else {
1085            panic!("Expected ScalarFunction");
1086        };
1087        func.function_reference = 999; // Invalid - doesn't exist in extensions
1088
1089        plan
1090    }
1091
1092    #[test]
1093    fn test_write_plan_reports_formatting_issues() {
1094        let plan = make_plan_with_invalid_function_ref();
1095        let mut output = Vec::new();
1096
1097        let result = Format::Text.write_plan(
1098            &mut output,
1099            &plan,
1100            &OutputOptions::default(),
1101            &ExtensionRegistry::default(),
1102        );
1103
1104        // Should succeed but report formatting issues
1105        let outcome = result.expect("write_plan should not return hard error");
1106        assert!(
1107            matches!(outcome, Outcome::HadFormattingIssues(ref errors) if !errors.is_empty()),
1108            "Expected HadFormattingIssues with errors, got {outcome:?}"
1109        );
1110        // Output should still be written (best-effort formatting)
1111        assert!(
1112            !output.is_empty(),
1113            "Output should be written even with issues"
1114        );
1115    }
1116}