substrait_explain/textify/plan.rs

1use std::fmt;
2
3use substrait::proto;
4
5use super::Textify;
6use crate::extensions::SimpleExtensions;
7use crate::parser::PLAN_HEADER;
8use crate::textify::foundation::ErrorAccumulator;
9use crate::textify::{OutputOptions, ScopedContext};
10
11#[derive(Debug, Clone)]
12pub struct PlanWriter<'a, E: ErrorAccumulator + Default> {
13    options: &'a OutputOptions,
14    extensions: SimpleExtensions,
15    relations: &'a [proto::PlanRel],
16    errors: E,
17}
18
19impl<'a, E: ErrorAccumulator + Default + Clone> PlanWriter<'a, E> {
20    pub fn new(options: &'a OutputOptions, plan: &'a proto::Plan) -> (Self, E) {
21        let (extensions, errs) =
22            SimpleExtensions::from_extensions(&plan.extension_uris, &plan.extensions);
23
24        let errors = E::default();
25        for err in errs {
26            errors.push(err.into());
27        }
28
29        let relations = plan.relations.as_slice();
30
31        (
32            Self {
33                options,
34                extensions,
35                relations,
36                errors: errors.clone(),
37            },
38            errors,
39        )
40    }
41
42    pub fn scope(&'a self) -> ScopedContext<'a, E> {
43        ScopedContext::new(self.options, &self.errors, &self.extensions)
44    }
45
46    pub fn write_extensions(&self, w: &mut impl fmt::Write) -> fmt::Result {
47        self.extensions.write(w, &self.options.indent)
48    }
49
50    pub fn write_relations(&self, w: &mut impl fmt::Write) -> fmt::Result {
51        // We always write the plan header, even if there are no relations.
52        writeln!(w, "{PLAN_HEADER}")?;
53        let scope = self.scope();
54        for (i, relation) in self.relations.iter().enumerate() {
55            if i > 0 {
56                writeln!(w)?;
57            }
58            relation.textify(&scope, w)?;
59        }
60        Ok(())
61    }
62}
63
64impl<'a, E: ErrorAccumulator + Default> fmt::Display for PlanWriter<'a, E> {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        self.write_extensions(f)?;
67        if !self.extensions.is_empty() {
68            writeln!(f)?;
69        }
70        self.write_relations(f)?;
71        Ok(())
72    }
73}
74
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use pext::simple_extension_declaration::{ExtensionFunction, MappingType};
    use substrait::proto::expression::{RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::r#type::{Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, extensions as pext,
    };

    use super::*;
    use crate::parser::expressions::FieldIndex;
    use crate::textify::ErrorQueue;

    /// URI of the arithmetic extension file declared by the test plan.
    const ARITHMETIC_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml";

    /// A nullable `i32` type with no type variation (`i32?`).
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(proto::r#type::I32 {
                nullability: Nullability::Nullable as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// A value argument that selects the given input field.
    fn field_arg(field: FieldIndex) -> FunctionArgument {
        let selection = Expression {
            rex_type: Some(RexType::Selection(Box::new(field.to_field_reference()))),
        };
        FunctionArgument {
            arg_type: Some(ArgType::Value(selection)),
        }
    }

    /// A read of `table1` with two nullable `i32` columns, `col1` and `col2`.
    fn read_table1() -> ReadRel {
        let schema = NamedStruct {
            names: vec!["col1".to_string(), "col2".to_string()],
            r#struct: Some(Struct {
                types: vec![nullable_i32(), nullable_i32()],
                ..Default::default()
            }),
        };
        ReadRel {
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["table1".to_string()],
                ..Default::default()
            })),
            base_schema: Some(schema),
            ..Default::default()
        }
    }

    /// Builds `Project[add($0, $1)]` over `read_table1()`, with `add` declared
    /// as function anchor 10 from extension URI anchor 1.
    fn build_plan() -> proto::Plan {
        let add_call = ScalarFunction {
            function_reference: 10,
            arguments: vec![field_arg(FieldIndex(0)), field_arg(FieldIndex(1))],
            options: vec![],
            output_type: None,
            #[allow(deprecated)]
            args: vec![],
        };

        let read_input = proto::Rel {
            rel_type: Some(proto::rel::RelType::Read(Box::new(read_table1()))),
        };

        let project = proto::ProjectRel {
            expressions: vec![Expression {
                rex_type: Some(RexType::ScalarFunction(add_call)),
            }],
            input: Some(Box::new(read_input)),
            common: None,
            advanced_extension: None,
        };

        let add_declaration = pext::SimpleExtensionDeclaration {
            mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                extension_uri_reference: 1,
                function_anchor: 10,
                name: "add".to_string(),
            })),
        };

        proto::Plan {
            extension_uris: vec![pext::SimpleExtensionUri {
                extension_uri_anchor: 1,
                uri: ARITHMETIC_URI.to_string(),
            }],
            extensions: vec![add_declaration],
            relations: vec![proto::PlanRel {
                rel_type: Some(proto::plan_rel::RelType::Rel(proto::Rel {
                    rel_type: Some(proto::rel::RelType::Project(Box::new(project))),
                })),
            }],
            ..Default::default()
        }
    }

    /// Writes a small hand-built plan (one extension, a read, and a project)
    /// and checks the exact text produced.
    ///
    /// The plan is built directly from protobuf types rather than parsed;
    /// broader coverage lives in the integration tests.
    #[test]
    fn test_plan_writer() {
        let plan = build_plan();
        let options = OutputOptions::default();
        let (writer, errors) = PlanWriter::<ErrorQueue>::new(&options, &plan);

        let mut output = String::new();
        write!(output, "{writer}").unwrap();

        // Nothing should have been reported while resolving or writing.
        let errors: Vec<_> = errors.into();
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");

        let expected = r#"
=== Extensions
URIs:
  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @  1: add

=== Plan
Project[$0, $1, add($0, $1)]
  Read[table1 => col1:i32?, col2:i32?]"#
            .trim_start();

        assert_eq!(
            output, expected,
            "Output:\n---\n{output}\n---\nExpected:\n---\n{expected}\n---"
        );
    }
}