Skip to main content

substrait_explain/textify/
plan.rs

1use std::fmt;
2
3use substrait::proto;
4
5use super::Textify;
6use crate::extensions::{ExtensionRegistry, SimpleExtensions};
7use crate::parser::PLAN_HEADER;
8use crate::textify::foundation::ErrorAccumulator;
9use crate::textify::{OutputOptions, ScopedContext};
10
11#[derive(Debug, Clone)]
12pub(crate) struct PlanWriter<'a, E: ErrorAccumulator + Default> {
13    options: &'a OutputOptions,
14    extensions: SimpleExtensions,
15    relations: &'a [proto::PlanRel],
16    errors: E,
17    extension_registry: &'a ExtensionRegistry,
18}
19
20impl<'a, E: ErrorAccumulator + Default + Clone> PlanWriter<'a, E> {
21    pub(crate) fn new(
22        options: &'a OutputOptions,
23        plan: &'a proto::Plan,
24        extension_registry: &'a ExtensionRegistry,
25    ) -> (Self, E) {
26        let (extensions, errs) =
27            SimpleExtensions::from_extensions(&plan.extension_urns, &plan.extensions);
28
29        let errors = E::default();
30        for err in errs {
31            errors.push(err.into());
32        }
33
34        let relations = plan.relations.as_slice();
35
36        (
37            Self {
38                options,
39                extensions,
40                relations,
41                errors: errors.clone(),
42                extension_registry,
43            },
44            errors,
45        )
46    }
47
48    pub(crate) fn scope(&'a self) -> ScopedContext<'a, E> {
49        ScopedContext::new(
50            self.options,
51            &self.errors,
52            &self.extensions,
53            self.extension_registry,
54        )
55    }
56
57    pub(crate) fn write_extensions(&self, w: &mut impl fmt::Write) -> fmt::Result {
58        self.extensions.write(w, &self.options.indent)
59    }
60
61    pub(crate) fn write_relations(&self, w: &mut impl fmt::Write) -> fmt::Result {
62        // We always write the plan header, even if there are no relations.
63        writeln!(w, "{PLAN_HEADER}")?;
64        let scope = self.scope();
65        for (i, relation) in self.relations.iter().enumerate() {
66            if i > 0 {
67                writeln!(w)?;
68                writeln!(w)?;
69            }
70            relation.textify(&scope, w)?;
71        }
72        Ok(())
73    }
74}
75
76impl<'a, E: ErrorAccumulator + Default> fmt::Display for PlanWriter<'a, E> {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        self.write_extensions(f)?;
79        if !self.extensions.is_empty() {
80            writeln!(f)?;
81        }
82        self.write_relations(f)?;
83        writeln!(f)?;
84        Ok(())
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use pext::simple_extension_declaration::{ExtensionFunction, MappingType};
    use substrait::proto::expression::{RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::r#type::{I64, Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, extensions as pext,
    };

    use super::*;
    use crate::parser::expressions::FieldIndex;
    use crate::textify::ErrorQueue;

    /// URN of the arithmetic extension declared by the test plan.
    const ARITHMETIC_URN: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml";

    /// A nullable `i32` type with no type variation.
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(proto::r#type::I32 {
                nullability: Nullability::Nullable as i32,
                type_variation_reference: 0,
            })),
        }
    }

    /// A value argument that selects the input field identified by `field`.
    fn field_arg(field: FieldIndex) -> FunctionArgument {
        FunctionArgument {
            arg_type: Some(ArgType::Value(Expression {
                rex_type: Some(RexType::Selection(Box::new(field.to_field_reference()))),
            })),
        }
    }

    /// End-to-end check of `PlanWriter` on a small hand-built plan.
    ///
    /// The plan declares one extension function (`add`) and contains
    /// `Project[add($0, $1)]` over a two-column table read. It is built
    /// directly from protobuf types rather than through the parser; broader
    /// coverage is in the integration tests.
    #[test]
    fn test_plan_writer() {
        let mut plan = proto::Plan::default();

        // Extension URN anchor 1 -> arithmetic YAML; function anchor 10 -> `add`.
        plan.extension_urns.push(pext::SimpleExtensionUrn {
            extension_urn_anchor: 1,
            urn: ARITHMETIC_URN.to_string(),
        });
        plan.extensions.push(pext::SimpleExtensionDeclaration {
            mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                extension_urn_reference: 1,
                function_anchor: 10,
                name: "add".to_string(),
            })),
        });

        // Leaf relation: read `table1` with columns col1, col2 (both i32?).
        let read = proto::Rel {
            rel_type: Some(proto::rel::RelType::Read(Box::new(ReadRel {
                read_type: Some(ReadType::NamedTable(NamedTable {
                    names: vec!["table1".to_string()],
                    ..Default::default()
                })),
                base_schema: Some(NamedStruct {
                    names: vec!["col1".to_string(), "col2".to_string()],
                    r#struct: Some(Struct {
                        types: vec![nullable_i32(), nullable_i32()],
                        ..Default::default()
                    }),
                }),
                ..Default::default()
            }))),
        };

        // add($0, $1) returning a required i64, bound to function anchor 10.
        let add = ScalarFunction {
            function_reference: 10,
            arguments: vec![field_arg(FieldIndex(0)), field_arg(FieldIndex(1))],
            options: vec![],
            output_type: Some(Type {
                kind: Some(Kind::I64(I64 {
                    nullability: Nullability::Required as i32,
                    type_variation_reference: 0,
                })),
            }),
            #[allow(deprecated)]
            args: vec![],
        };

        let project = proto::ProjectRel {
            expressions: vec![Expression {
                rex_type: Some(RexType::ScalarFunction(add)),
            }],
            input: Some(Box::new(read)),
            common: None,
            advanced_extension: None,
        };

        plan.relations.push(proto::PlanRel {
            rel_type: Some(proto::plan_rel::RelType::Rel(proto::Rel {
                rel_type: Some(proto::rel::RelType::Project(Box::new(project))),
            })),
        });

        let options = OutputOptions::default();
        let extension_registry = ExtensionRegistry::new();
        let (writer, errors) = PlanWriter::<ErrorQueue>::new(&options, &plan, &extension_registry);
        let mut output = String::new();
        write!(output, "{writer}").unwrap();

        // Formatting this plan must not report any errors.
        let errors: Vec<_> = errors.into();
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");

        let expected = r#"
=== Extensions
URNs:
  @  1: https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
Functions:
  # 10 @  1: add

=== Plan
Project[$0, $1, add($0, $1):i64]
  Read[table1 => col1:i32?, col2:i32?]
"#
        .trim_start();

        assert_eq!(
            output, expected,
            "Output:\n---\n{output}\n---\nExpected:\n---\n{expected}\n---"
        );
    }
}