substrait_explain/
fixtures.rs

1//! Test fixtures for working with Substrait plans and substrait_explain
2
3use crate::extensions::simple::ExtensionKind;
4use crate::extensions::{ExtensionRegistry, SimpleExtensions};
5use crate::format;
6use crate::parser::{MessageParseError, Parser, ScopedParse};
7use crate::textify::foundation::{ErrorAccumulator, ErrorList};
8use crate::textify::{ErrorQueue, OutputOptions, Scope, ScopedContext, Textify};
9
/// Bundle of everything the test helpers need to textify or parse Substrait
/// messages: output options plus the extension lookup tables.
///
/// Build one with [`TestContext::new`] and customise it with the chained
/// `with_*` methods.
pub struct TestContext {
    /// Output options handed to the scope created by [`TestContext::scope`].
    pub options: OutputOptions,
    /// Simple extensions (URNs, functions, types, type variations) registered
    /// through the `with_*` builder methods.
    pub extensions: SimpleExtensions,
    /// Extension registry handed to the scope created by
    /// [`TestContext::scope`]; starts out as `ExtensionRegistry::new()`.
    pub extension_registry: ExtensionRegistry,
}
15
16impl Default for TestContext {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl TestContext {
23    pub fn new() -> Self {
24        Self {
25            options: OutputOptions::default(),
26            extensions: SimpleExtensions::new(),
27            extension_registry: ExtensionRegistry::new(),
28        }
29    }
30
31    pub fn with_options(mut self, options: OutputOptions) -> Self {
32        self.options = options;
33        self
34    }
35
36    pub fn with_urn(mut self, anchor: u32, urn: &str) -> Self {
37        self.extensions
38            .add_extension_urn(urn.to_string(), anchor)
39            .unwrap();
40        self
41    }
42
43    pub fn with_function(mut self, urn: u32, anchor: u32, name: impl Into<String>) -> Self {
44        assert!(self.extensions.find_urn(urn).is_ok());
45        self.extensions
46            .add_extension(ExtensionKind::Function, urn, anchor, name.into())
47            .unwrap();
48        self
49    }
50
51    pub fn with_type(mut self, urn: u32, anchor: u32, name: impl Into<String>) -> Self {
52        assert!(self.extensions.find_urn(urn).is_ok());
53        self.extensions
54            .add_extension(ExtensionKind::Type, urn, anchor, name.into())
55            .unwrap();
56        self
57    }
58
59    pub fn with_type_variation(mut self, urn: u32, anchor: u32, name: impl Into<String>) -> Self {
60        assert!(self.extensions.find_urn(urn).is_ok());
61        self.extensions
62            .add_extension(ExtensionKind::TypeVariation, urn, anchor, name.into())
63            .unwrap();
64        self
65    }
66
67    pub fn scope<'e, E: ErrorAccumulator>(&'e self, errors: &'e E) -> impl Scope + 'e {
68        ScopedContext::new(
69            &self.options,
70            errors,
71            &self.extensions,
72            &self.extension_registry,
73        )
74    }
75
76    pub fn textify<T: Textify>(&self, t: &T) -> (String, ErrorList) {
77        let errors = ErrorQueue::default();
78        let mut output = String::new();
79
80        let scope = self.scope(&errors);
81        t.textify(&scope, &mut output).unwrap();
82
83        let evec = errors.into_iter().collect();
84        (output, ErrorList(evec))
85    }
86
87    pub fn textify_no_errors<T: Textify>(&self, t: &T) -> String {
88        let (s, errs) = self.textify(t);
89        assert!(errs.is_empty(), "{} Errors: {}", errs.0.len(), errs.0[0]);
90        s
91    }
92
93    pub fn parse<T: ScopedParse>(&self, input: &str) -> Result<T, MessageParseError> {
94        T::parse(&self.extensions, input)
95    }
96}
97
98/// Roundtrip a plan and verify that the output is the same as the input, after
99/// being parsed to a Substrait plan and then back to text.
100pub fn roundtrip_plan(input: &str) {
101    // Parse the plan using the simplified interface
102    let plan = Parser::parse(input).unwrap_or_else(|e| {
103        println!("Error parsing plan:\n{e}");
104        panic!("{e}");
105    });
106
107    // Format the plan back to text using the simplified interface
108    let (actual, errors) = format(&plan);
109
110    // Check for formatting errors
111    if !errors.is_empty() {
112        println!("Formatting errors:");
113        for error in errors {
114            println!("  {error}");
115        }
116        panic!("Formatting errors occurred");
117    }
118
119    // Compare the output with the input, printing the difference.
120    assert_eq!(
121        actual.trim(),
122        input.trim(),
123        "Expected:\n---\n{}\n---\nActual:\n---\n{}\n---",
124        input.trim(),
125        actual.trim()
126    );
127}