substrait_explain/textify/
plan.rs

1use std::fmt;
2
3use substrait::proto;
4
5use super::Textify;
6use crate::extensions::SimpleExtensions;
7use crate::parser::PLAN_HEADER;
8use crate::textify::foundation::ErrorAccumulator;
9use crate::textify::{OutputOptions, ScopedContext};
10
11#[derive(Debug, Clone)]
12pub struct PlanWriter<'a, E: ErrorAccumulator + Default> {
13    options: &'a OutputOptions,
14    extensions: SimpleExtensions,
15    relations: &'a [proto::PlanRel],
16    errors: E,
17}
18
19impl<'a, E: ErrorAccumulator + Default + Clone> PlanWriter<'a, E> {
20    pub fn new(options: &'a OutputOptions, plan: &'a proto::Plan) -> (Self, E) {
21        let (extensions, errs) =
22            SimpleExtensions::from_extensions(&plan.extension_uris, &plan.extensions);
23
24        let errors = E::default();
25        for err in errs {
26            errors.push(err.into());
27        }
28
29        let relations = plan.relations.as_slice();
30
31        (
32            Self {
33                options,
34                extensions,
35                relations,
36                errors: errors.clone(),
37            },
38            errors,
39        )
40    }
41
42    pub fn scope(&'a self) -> ScopedContext<'a, E> {
43        ScopedContext::new(self.options, &self.errors, &self.extensions)
44    }
45
46    pub fn write_extensions(&self, w: &mut impl fmt::Write) -> fmt::Result {
47        self.extensions.write(w, &self.options.indent)
48    }
49
50    pub fn write_relations(&self, w: &mut impl fmt::Write) -> fmt::Result {
51        // We always write the plan header, even if there are no relations.
52        writeln!(w, "{PLAN_HEADER}")?;
53        let scope = self.scope();
54        for (i, relation) in self.relations.iter().enumerate() {
55            if i > 0 {
56                writeln!(w)?;
57                writeln!(w)?;
58            }
59            relation.textify(&scope, w)?;
60        }
61        Ok(())
62    }
63}
64
65impl<'a, E: ErrorAccumulator + Default> fmt::Display for PlanWriter<'a, E> {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        self.write_extensions(f)?;
68        if !self.extensions.is_empty() {
69            writeln!(f)?;
70        }
71        self.write_relations(f)?;
72        writeln!(f)?;
73        Ok(())
74    }
75}
76
#[cfg(test)]
mod tests {
    use pext::simple_extension_declaration::{ExtensionFunction, MappingType};
    use substrait::proto::expression::{RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::r#type::{Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, extensions as pext,
    };

    use super::*;
    use crate::parser::expressions::FieldIndex;
    use crate::textify::ErrorQueue;

    /// A nullable `i32` type with type variation reference 0.
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(proto::r#type::I32 {
                nullability: Nullability::Nullable as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// A value argument that selects the input field `field`.
    fn field_arg(field: FieldIndex) -> FunctionArgument {
        let selection = Box::new(field.to_field_reference());
        FunctionArgument {
            arg_type: Some(ArgType::Value(Expression {
                rex_type: Some(RexType::Selection(selection)),
            })),
        }
    }

    /// Renders a small hand-built plan and checks the exact text produced.
    ///
    /// The plan contains one extension function (`add`), a read, and a
    /// project. It is assembled directly from protobuf structs rather than
    /// through the parser; the integration tests cover the broader surface.
    #[test]
    fn test_plan_writer() {
        let read = ReadRel {
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["table1".to_string()],
                ..Default::default()
            })),
            base_schema: Some(NamedStruct {
                names: vec!["col1".to_string(), "col2".to_string()],
                r#struct: Some(Struct {
                    types: vec![nullable_i32(), nullable_i32()],
                    ..Default::default()
                }),
            }),
            ..Default::default()
        };

        // add($0, $1), referencing the function declared with anchor 10 below.
        let add = ScalarFunction {
            function_reference: 10,
            arguments: vec![field_arg(FieldIndex(0)), field_arg(FieldIndex(1))],
            options: vec![],
            output_type: None,
            #[allow(deprecated)]
            args: vec![],
        };

        let project = proto::ProjectRel {
            expressions: vec![Expression {
                rex_type: Some(RexType::ScalarFunction(add)),
            }],
            input: Some(Box::new(proto::Rel {
                rel_type: Some(proto::rel::RelType::Read(Box::new(read))),
            })),
            common: None,
            advanced_extension: None,
        };

        let plan = proto::Plan {
            extension_uris: vec![pext::SimpleExtensionUri {
                extension_uri_anchor: 1,
                uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml".to_string(),
            }],
            extensions: vec![pext::SimpleExtensionDeclaration {
                mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                    extension_uri_reference: 1,
                    function_anchor: 10,
                    name: "add".to_string(),
                })),
            }],
            relations: vec![proto::PlanRel {
                rel_type: Some(proto::plan_rel::RelType::Rel(proto::Rel {
                    rel_type: Some(proto::rel::RelType::Project(Box::new(project))),
                })),
            }],
            ..Default::default()
        };

        let options = OutputOptions::default();
        let (writer, errors) = PlanWriter::<ErrorQueue>::new(&options, &plan);
        let output = writer.to_string();

        // Collected only after rendering, so errors raised while writing are included.
        let errors: Vec<_> = errors.into();
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");

        let expected = r#"
=== Extensions
URIs:
  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @  1: add

=== Plan
Project[$0, $1, add($0, $1)]
  Read[table1 => col1:i32?, col2:i32?]
"#
        .trim_start();

        assert_eq!(
            output, expected,
            "Output:\n---\n{output}\n---\nExpected:\n---\n{expected}\n---"
        );
    }
}