Skip to main content

substrait_explain/
fixtures.rs

1//! Test fixtures for working with Substrait plans and substrait_explain
2
3use substrait::proto;
4
5use crate::extensions::simple::ExtensionKind;
6use crate::extensions::{ExtensionRegistry, SimpleExtensions};
7use crate::format;
8use crate::parser::{MessageParseError, Parser, ScopedParse};
9use crate::textify::foundation::{ErrorAccumulator, ErrorList};
10use crate::textify::{ErrorQueue, OutputOptions, Scope, ScopedContext, Textify};
11
12pub struct TestContext {
13    pub options: OutputOptions,
14    pub extensions: SimpleExtensions,
15    pub extension_registry: ExtensionRegistry,
16}
17
18impl Default for TestContext {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl TestContext {
25    pub fn new() -> Self {
26        Self {
27            options: OutputOptions::default(),
28            extensions: SimpleExtensions::new(),
29            extension_registry: ExtensionRegistry::new(),
30        }
31    }
32
33    pub fn with_options(mut self, options: OutputOptions) -> Self {
34        self.options = options;
35        self
36    }
37
38    pub fn with_urn(mut self, anchor: u32, urn: &str) -> Self {
39        self.extensions
40            .add_extension_urn(urn.to_string(), anchor)
41            .unwrap();
42        self
43    }
44
45    pub fn with_function(mut self, urn: u32, anchor: u32, name: impl Into<String>) -> Self {
46        assert!(self.extensions.find_urn(urn).is_ok());
47        self.extensions
48            .add_extension(ExtensionKind::Function, urn, anchor, name.into())
49            .unwrap();
50        self
51    }
52
53    pub fn with_type(mut self, urn: u32, anchor: u32, name: impl Into<String>) -> Self {
54        assert!(self.extensions.find_urn(urn).is_ok());
55        self.extensions
56            .add_extension(ExtensionKind::Type, urn, anchor, name.into())
57            .unwrap();
58        self
59    }
60
61    pub fn with_type_variation(mut self, urn: u32, anchor: u32, name: impl Into<String>) -> Self {
62        assert!(self.extensions.find_urn(urn).is_ok());
63        self.extensions
64            .add_extension(ExtensionKind::TypeVariation, urn, anchor, name.into())
65            .unwrap();
66        self
67    }
68
69    pub fn scope<'e, E: ErrorAccumulator>(&'e self, errors: &'e E) -> impl Scope + 'e {
70        ScopedContext::new(
71            &self.options,
72            errors,
73            &self.extensions,
74            &self.extension_registry,
75        )
76    }
77
78    pub fn textify<T: Textify>(&self, t: &T) -> (String, ErrorList) {
79        let errors = ErrorQueue::default();
80        let mut output = String::new();
81
82        let scope = self.scope(&errors);
83        t.textify(&scope, &mut output).unwrap();
84
85        let evec = errors.into_iter().collect();
86        (output, ErrorList(evec))
87    }
88
89    pub fn textify_no_errors<T: Textify>(&self, t: &T) -> String {
90        let (s, errs) = self.textify(t);
91        assert!(errs.is_empty(), "{} Errors: {}", errs.0.len(), errs.0[0]);
92        s
93    }
94
95    pub fn parse<T: ScopedParse>(&self, input: &str) -> Result<T, MessageParseError> {
96        T::parse(&self.extensions, input)
97    }
98}
99
100/// Roundtrip a plan and verify that the output is the same as the input, after
101/// being parsed to a Substrait plan and then back to text.
102pub fn roundtrip_plan(input: &str) {
103    // Parse the plan using the simplified interface
104    let plan = Parser::parse(input).unwrap_or_else(|e| {
105        println!("Error parsing plan:\n{e}");
106        panic!("{e}");
107    });
108
109    // Format the plan back to text using the simplified interface
110    let (actual, errors) = format(&plan);
111
112    // Check for formatting errors
113    if !errors.is_empty() {
114        println!("Formatting errors:");
115        for error in errors {
116            println!("  {error}");
117        }
118        panic!("Formatting errors occurred");
119    }
120
121    // Compare the output with the input, printing the difference.
122    assert_eq!(
123        actual.trim(),
124        input.trim(),
125        "Expected:\n---\n{}\n---\nActual:\n---\n{}\n---",
126        input.trim(),
127        actual.trim()
128    );
129}
130/// Parse a built-in type string (e.g. `"i64"`, `"string?"`) into a
131/// `proto::Type`. Panics on invalid type names.
132pub fn parse_type(s: &str) -> proto::Type {
133    proto::Type::parse(&SimpleExtensions::default(), s)
134        .unwrap_or_else(|e| panic!("failed to parse type '{s}': {e}"))
135}