// substrait_explain/textify/plan.rs

1use std::fmt;
2
3use substrait::proto;
4
5use super::Textify;
6use crate::extensions::SimpleExtensions;
7use crate::parser::PLAN_HEADER;
8use crate::textify::foundation::ErrorAccumulator;
9use crate::textify::{OutputOptions, ScopedContext};
10
11#[derive(Debug, Clone)]
12pub struct PlanWriter<'a, E: ErrorAccumulator + Default> {
13    options: &'a OutputOptions,
14    extensions: SimpleExtensions,
15    relations: &'a [proto::PlanRel],
16    errors: E,
17}
18
19impl<'a, E: ErrorAccumulator + Default + Clone> PlanWriter<'a, E> {
20    pub fn new(options: &'a OutputOptions, plan: &'a proto::Plan) -> (Self, E) {
21        let (extensions, errs) =
22            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
23
24        let errors = E::default();
25        for err in errs {
26            errors.push(err.into());
27        }
28
29        let relations = plan.relations.as_slice();
30
31        (
32            Self {
33                options,
34                extensions,
35                relations,
36                errors: errors.clone(),
37            },
38            errors,
39        )
40    }
41
42    pub fn scope(&'a self) -> ScopedContext<'a, E> {
43        ScopedContext::new(self.options, &self.errors, &self.extensions)
44    }
45
46    pub fn write_extensions(&self, w: &mut impl fmt::Write) -> fmt::Result {
47        self.extensions.write(w, &self.options.indent)
48    }
49
50    pub fn write_relations(&self, w: &mut impl fmt::Write) -> fmt::Result {
51        // We always write the plan header, even if there are no relations.
52        writeln!(w, "{PLAN_HEADER}")?;
53        let scope = self.scope();
54        for (i, relation) in self.relations.iter().enumerate() {
55            if i > 0 {
56                writeln!(w)?;
57                writeln!(w)?;
58            }
59            relation.textify(&scope, w)?;
60        }
61        Ok(())
62    }
63}
64
65impl<'a, E: ErrorAccumulator + Default> fmt::Display for PlanWriter<'a, E> {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        self.write_extensions(f)?;
68        if !self.extensions.is_empty() {
69            writeln!(f)?;
70        }
71        self.write_relations(f)?;
72        writeln!(f)?;
73        Ok(())
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use pext::simple_extension_declaration::{ExtensionFunction, MappingType};
    use substrait::proto::expression::{RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::r#type::{Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, extensions as pext,
    };

    use super::*;
    use crate::parser::expressions::FieldIndex;
    use crate::textify::ErrorQueue;

    /// A nullable `i32` type with no type variation; both test columns use it.
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(proto::r#type::I32 {
                nullability: Nullability::Nullable as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// Writes a hand-built plan (one extension function, a read, and a project
    /// calling that function) and checks the exact text produced.
    ///
    /// The plan is built directly from protobuf structs rather than by the
    /// parser; broader coverage lives in the integration tests.
    #[test]
    fn test_plan_writer() {
        // Leaf: read `table1` with two nullable i32 columns.
        let read = proto::Rel {
            rel_type: Some(proto::rel::RelType::Read(Box::new(ReadRel {
                read_type: Some(ReadType::NamedTable(NamedTable {
                    names: vec!["table1".to_string()],
                    ..Default::default()
                })),
                base_schema: Some(NamedStruct {
                    names: vec!["col1".to_string(), "col2".to_string()],
                    r#struct: Some(Struct {
                        types: vec![nullable_i32(), nullable_i32()],
                        ..Default::default()
                    }),
                }),
                ..Default::default()
            }))),
        };

        // A value argument selecting input column `index`.
        let column_arg = |index| FunctionArgument {
            arg_type: Some(ArgType::Value(Expression {
                rex_type: Some(RexType::Selection(Box::new(
                    FieldIndex(index).to_field_reference(),
                ))),
            })),
        };

        // `add($0, $1)`, resolved through function anchor 10 declared below.
        let add = ScalarFunction {
            function_reference: 10,
            arguments: vec![column_arg(0), column_arg(1)],
            options: vec![],
            output_type: None,
            // `args` is deprecated but must still be named in a struct literal.
            #[allow(deprecated)]
            args: vec![],
        };

        let project = proto::ProjectRel {
            expressions: vec![Expression {
                rex_type: Some(RexType::ScalarFunction(add)),
            }],
            input: Some(Box::new(read)),
            common: None,
            advanced_extension: None,
        };

        let plan = proto::Plan {
            extension_urns: vec![pext::SimpleExtensionUrn {
                extension_urn_anchor: 1,
                urn: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml".to_string(),
            }],
            extensions: vec![pext::SimpleExtensionDeclaration {
                // `extension_uri_reference` is deprecated but must still be set.
                #[allow(deprecated)]
                mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                    extension_urn_reference: 1,
                    extension_uri_reference: Default::default(),
                    function_anchor: 10,
                    name: "add".to_string(),
                })),
            }],
            relations: vec![proto::PlanRel {
                rel_type: Some(proto::plan_rel::RelType::Rel(proto::Rel {
                    rel_type: Some(proto::rel::RelType::Project(Box::new(project))),
                })),
            }],
            ..Default::default()
        };

        let options = OutputOptions::default();
        let (writer, errors) = PlanWriter::<ErrorQueue>::new(&options, &plan);

        let mut output = String::new();
        write!(output, "{writer}").unwrap();

        let errors: Vec<_> = errors.into();
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");

        let expected = r#"
=== Extensions
URNs:
  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @  1: add

=== Plan
Project[$0, $1, add($0, $1)]
  Read[table1 => col1:i32?, col2:i32?]
"#
        .trim_start();

        assert_eq!(
            output, expected,
            "Output:\n---\n{output}\n---\nExpected:\n---\n{expected}\n---"
        );
    }
}