substrait_explain/textify/
plan.rs

1use std::fmt;
2
3use substrait::proto;
4
5use super::Textify;
6use crate::extensions::{ExtensionRegistry, SimpleExtensions};
7use crate::parser::PLAN_HEADER;
8use crate::textify::foundation::ErrorAccumulator;
9use crate::textify::{OutputOptions, ScopedContext};
10
11#[derive(Debug, Clone)]
12pub struct PlanWriter<'a, E: ErrorAccumulator + Default> {
13    options: &'a OutputOptions,
14    extensions: SimpleExtensions,
15    relations: &'a [proto::PlanRel],
16    errors: E,
17    extension_registry: &'a ExtensionRegistry,
18}
19
20impl<'a, E: ErrorAccumulator + Default + Clone> PlanWriter<'a, E> {
21    pub fn new(
22        options: &'a OutputOptions,
23        plan: &'a proto::Plan,
24        extension_registry: &'a ExtensionRegistry,
25    ) -> (Self, E) {
26        let (extensions, errs) =
27            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
28
29        let errors = E::default();
30        for err in errs {
31            errors.push(err.into());
32        }
33
34        let relations = plan.relations.as_slice();
35
36        (
37            Self {
38                options,
39                extensions,
40                relations,
41                errors: errors.clone(),
42                extension_registry,
43            },
44            errors,
45        )
46    }
47
48    pub fn scope(&'a self) -> ScopedContext<'a, E> {
49        ScopedContext::new(
50            self.options,
51            &self.errors,
52            &self.extensions,
53            self.extension_registry,
54        )
55    }
56
57    pub fn write_extensions(&self, w: &mut impl fmt::Write) -> fmt::Result {
58        self.extensions.write(w, &self.options.indent)
59    }
60
61    pub fn write_relations(&self, w: &mut impl fmt::Write) -> fmt::Result {
62        // We always write the plan header, even if there are no relations.
63        writeln!(w, "{PLAN_HEADER}")?;
64        let scope = self.scope();
65        for (i, relation) in self.relations.iter().enumerate() {
66            if i > 0 {
67                writeln!(w)?;
68                writeln!(w)?;
69            }
70            relation.textify(&scope, w)?;
71        }
72        Ok(())
73    }
74}
75
76impl<'a, E: ErrorAccumulator + Default> fmt::Display for PlanWriter<'a, E> {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        self.write_extensions(f)?;
79        if !self.extensions.is_empty() {
80            writeln!(f)?;
81        }
82        self.write_relations(f)?;
83        writeln!(f)?;
84        Ok(())
85    }
86}
87
#[cfg(test)]
mod tests {
    use pext::simple_extension_declaration::{ExtensionFunction, MappingType};
    use substrait::proto::expression::{RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::r#type::{Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, extensions as pext,
    };

    use super::*;
    use crate::parser::expressions::FieldIndex;
    use crate::textify::ErrorQueue;

    /// A nullable `i32` column type (rendered as `i32?`).
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(proto::r#type::I32 {
                nullability: Nullability::Nullable as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// A function argument that selects the given input field (rendered as `$N`).
    fn field_arg(index: FieldIndex) -> FunctionArgument {
        let selection = Box::new(index.to_field_reference());
        FunctionArgument {
            arg_type: Some(ArgType::Value(Expression {
                rex_type: Some(RexType::Selection(selection)),
            })),
        }
    }

    /// Writes a small hand-built plan: one extension function, a read, and a project.
    ///
    /// The plan is assembled directly from protobuf structs instead of going
    /// through the parser; the integration tests cover the format more fully.
    #[test]
    fn test_plan_writer() {
        // Extension URN with anchor 1.
        let arithmetic_urn = pext::SimpleExtensionUrn {
            extension_urn_anchor: 1,
            urn: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml".to_string(),
        };

        // Function anchor 10 -> "add" from the URN above.
        let add_declaration = pext::SimpleExtensionDeclaration {
            // `extension_uri_reference` is deprecated, but the struct literal
            // below still has to name it.
            #[allow(deprecated)]
            mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                extension_urn_reference: 1,
                extension_uri_reference: Default::default(),
                function_anchor: 10,
                name: "add".to_string(),
            })),
        };

        // Read `table1` with two nullable i32 columns.
        let read = ReadRel {
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["table1".to_string()],
                ..Default::default()
            })),
            base_schema: Some(NamedStruct {
                names: vec!["col1".to_string(), "col2".to_string()],
                r#struct: Some(Struct {
                    types: vec![nullable_i32(), nullable_i32()],
                    ..Default::default()
                }),
            }),
            ..Default::default()
        };

        // add($0, $1), referring to the function anchor declared above.
        let add_call = ScalarFunction {
            function_reference: 10,
            arguments: vec![field_arg(FieldIndex(0)), field_arg(FieldIndex(1))],
            options: vec![],
            output_type: None,
            // `args` is deprecated, but the struct literal still has to name it.
            #[allow(deprecated)]
            args: vec![],
        };

        // Project the add() result on top of the read.
        let project = proto::ProjectRel {
            expressions: vec![Expression {
                rex_type: Some(RexType::ScalarFunction(add_call)),
            }],
            input: Some(Box::new(proto::Rel {
                rel_type: Some(proto::rel::RelType::Read(Box::new(read))),
            })),
            common: None,
            advanced_extension: None,
        };

        let mut plan = proto::Plan::default();
        plan.extension_urns.push(arithmetic_urn);
        plan.extensions.push(add_declaration);
        plan.relations.push(proto::PlanRel {
            rel_type: Some(proto::plan_rel::RelType::Rel(proto::Rel {
                rel_type: Some(proto::rel::RelType::Project(Box::new(project))),
            })),
        });

        let options = OutputOptions::default();
        let extension_registry = ExtensionRegistry::new();
        let (writer, errors) = PlanWriter::<ErrorQueue>::new(&options, &plan, &extension_registry);
        let output = writer.to_string();

        // Drain the error queue only after writing, so textify-time errors are included.
        let errors: Vec<_> = errors.into();
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");

        let expected = r#"
=== Extensions
URNs:
  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @  1: add

=== Plan
Project[$0, $1, add($0, $1)]
  Read[table1 => col1:i32?, col2:i32?]
"#
        .trim_start();

        assert_eq!(
            output, expected,
            "Output:\n---\n{output}\n---\nExpected:\n---\n{expected}\n---"
        );
    }
}