substrait_explain/textify/
rels.rs

1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::fmt;
4use std::fmt::Debug;
5
6use prost::UnknownEnumValue;
7use substrait::proto::fetch_rel::CountMode;
8use substrait::proto::plan_rel::RelType as PlanRelType;
9use substrait::proto::read_rel::ReadType;
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::EmitKind;
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14    AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct,
15    PlanRel, ProjectRel, ReadRel, Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
16};
17
18use super::expressions::Reference;
19use super::types::Name;
20use super::{PlanError, Scope, Textify};
21
/// Gives a relation a short, static display name.
pub trait NamedRelation {
    /// Returns the display name of this relation.
    fn name(&self) -> &'static str;
}
25
26impl NamedRelation for Rel {
27    fn name(&self) -> &'static str {
28        match self.rel_type.as_ref() {
29            None => "UnknownRel",
30            Some(RelType::Read(_)) => "Read",
31            Some(RelType::Filter(_)) => "Filter",
32            Some(RelType::Project(_)) => "Project",
33            Some(RelType::Fetch(_)) => "Fetch",
34            Some(RelType::Aggregate(_)) => "Aggregate",
35            Some(RelType::Sort(_)) => "Sort",
36            Some(RelType::HashJoin(_)) => "HashJoin",
37            Some(RelType::Exchange(_)) => "Exchange",
38            Some(RelType::Join(_)) => "Join",
39            Some(RelType::Set(_)) => "Set",
40            Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
41            Some(RelType::Cross(_)) => "Cross",
42            Some(RelType::Reference(_)) => "Reference",
43            Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
44            Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
45            Some(RelType::Write(_)) => "Write",
46            Some(RelType::Ddl(_)) => "Ddl",
47            Some(RelType::Update(_)) => "Update",
48            Some(RelType::MergeJoin(_)) => "MergeJoin",
49            Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
50            Some(RelType::Window(_)) => "Window",
51            Some(RelType::Expand(_)) => "Expand",
52        }
53    }
54}
55
56/// Trait for enums that can be converted to a string representation for
57/// textification.
58///
59/// Returns Ok(str) for valid enum values, or Err([PlanError]) for invalid or
60/// unknown values.
61pub trait ValueEnum {
62    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
63}
64
65#[derive(Debug, Clone)]
66pub struct NamedArg<'a> {
67    pub name: &'a str,
68    pub value: Value<'a>,
69}
70
71#[derive(Debug, Clone)]
72pub enum Value<'a> {
73    Name(Name<'a>),
74    TableName(Vec<Name<'a>>),
75    Field(Option<Name<'a>>, Option<&'a Type>),
76    Tuple(Vec<Value<'a>>),
77    List(Vec<Value<'a>>),
78    Reference(i32),
79    Expression(&'a Expression),
80    AggregateFunction(&'a AggregateFunction),
81    /// Represents a missing, invalid, or unspecified value.
82    Missing(PlanError),
83    /// Represents a valid enum value as a string for textification.
84    Enum(Cow<'a, str>),
85    Integer(i32),
86}
87
88impl<'a> Value<'a> {
89    pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
90        match maybe_value {
91            Some(s) => s,
92            None => Value::Missing(f()),
93        }
94    }
95}
96
97impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
98    fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
99        match token {
100            Ok(value) => Value::TableName(value),
101            Err(err) => Value::Missing(err),
102        }
103    }
104}
105
106impl<'a> Textify for Value<'a> {
107    fn name() -> &'static str {
108        "Value"
109    }
110
111    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
112        match self {
113            Value::Name(name) => write!(w, "{}", ctx.display(name)),
114            Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
115            Value::Field(name, typ) => {
116                write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
117            }
118            Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
119            Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
120            Value::Reference(i) => write!(w, "{}", Reference(*i)),
121            Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
122            Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
123            Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
124            Value::Enum(res) => write!(w, "&{res}"),
125            Value::Integer(i) => write!(w, "{i}"),
126        }
127    }
128}
129
130fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
131    let mut fields = schema
132        .r#struct
133        .as_ref()
134        .map(|s| s.types.iter())
135        .into_iter()
136        .flatten();
137    let mut names = schema.names.iter();
138
139    // let field_count = schema.r#struct.as_ref().map(|s| s.types.len()).unwrap_or(0);
140    // let name_count = schema.names.len();
141
142    let mut values = Vec::new();
143    loop {
144        let field = fields.next();
145        let name = names.next().map(|n| Name(n));
146        if field.is_none() && name.is_none() {
147            break;
148        }
149
150        values.push(Value::Field(name, field));
151    }
152
153    values
154}
155
156struct Emitted<'a> {
157    pub values: &'a [Value<'a>],
158    pub emit: Option<&'a EmitKind>,
159}
160
161impl<'a> Emitted<'a> {
162    pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
163        Self { values, emit }
164    }
165
166    pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
167        write!(w, "{}", ctx.separated(self.values.iter(), ", "))
168    }
169}
170
171impl<'a> Textify for Emitted<'a> {
172    fn name() -> &'static str {
173        "Emitted"
174    }
175
176    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
177        if ctx.options().show_emit {
178            return self.write_direct(ctx, w);
179        }
180
181        let indices = match &self.emit {
182            Some(EmitKind::Emit(e)) => &e.output_mapping,
183            Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
184            None => return self.write_direct(ctx, w),
185        };
186
187        for (i, &index) in indices.iter().enumerate() {
188            if i > 0 {
189                write!(w, ", ")?;
190            }
191
192            match self.values.get(index as usize) {
193                Some(value) => write!(w, "{}", ctx.display(value))?,
194                None => write!(w, "{}", ctx.failure(PlanError::invalid(
195                    "Emitted",
196                    Some("output_mapping"),
197                    format!(
198                        "Output mapping index {} is out of bounds for values collection of size {}",
199                        index, self.values.len()
200                    )
201                )))?,
202            }
203        }
204
205        Ok(())
206    }
207}
208
209#[derive(Debug, Clone)]
210pub struct Arguments<'a> {
211    /// Positional arguments (e.g., a filter condition, group-bys, etc.))
212    pub positional: Vec<Value<'a>>,
213    /// Named arguments (e.g., limit=10, offset=5)
214    pub named: Vec<NamedArg<'a>>,
215}
216
217impl<'a> Textify for Arguments<'a> {
218    fn name() -> &'static str {
219        "Arguments"
220    }
221    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
222        if self.positional.is_empty() && self.named.is_empty() {
223            return write!(w, "_");
224        }
225
226        write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
227        if !self.positional.is_empty() && !self.named.is_empty() {
228            write!(w, ", ")?;
229        }
230        write!(w, "{}", ctx.separated(self.named.iter(), ", "))
231    }
232}
233
234pub struct Relation<'a> {
235    pub name: &'a str,
236    /// Arguments to the relation, if any.
237    ///
238    /// - `None` means this relation does not take arguments, and the argument
239    ///   section is omitted entirely.
240    /// - `Some(args)` with both vectors empty means the relation takes
241    ///   arguments, but none are provided; this will print as `_ => ...`.
242    /// - `Some(args)` with non-empty vectors will print as usual, with
243    ///   positional arguments first, then named arguments, separated by commas.
244    pub arguments: Option<Arguments<'a>>,
245    /// The columns emitted by this relation, pre-emit - the 'direct' column
246    /// output.
247    pub columns: Vec<Value<'a>>,
248    /// The emit kind, if any. If none, use the columns directly.
249    pub emit: Option<&'a EmitKind>,
250    /// The input relations.
251    pub children: Vec<Option<Relation<'a>>>,
252}
253
254impl Textify for Relation<'_> {
255    fn name() -> &'static str {
256        "Relation"
257    }
258
259    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
260        let cols = Emitted::new(&self.columns, self.emit);
261        let indent = ctx.indent();
262        let name = self.name;
263        let cols = ctx.display(&cols);
264        match &self.arguments {
265            None => {
266                write!(w, "{indent}{name}[{cols}]")?;
267            }
268            Some(args) => {
269                let args = ctx.display(args);
270                write!(w, "{indent}{name}[{args} => {cols}]")?;
271            }
272        }
273        let child_scope = ctx.push_indent();
274        for child in self.children.iter().flatten() {
275            writeln!(w)?;
276            child.textify(&child_scope, w)?;
277        }
278        Ok(())
279    }
280}
281
282impl<'a> Relation<'a> {
283    pub fn emitted(&self) -> usize {
284        match self.emit {
285            Some(EmitKind::Emit(e)) => e.output_mapping.len(),
286            Some(EmitKind::Direct(_)) => self.columns.len(),
287            None => self.columns.len(),
288        }
289    }
290}
291
/// A borrowed multi-part table name; the parts are written joined by `.`.
#[derive(Debug, Copy, Clone)]
pub struct TableName<'a>(&'a [String]);
294
295impl<'a> Textify for TableName<'a> {
296    fn name() -> &'static str {
297        "TableName"
298    }
299
300    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
301        let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
302        write!(w, "{}", ctx.separated(names.iter(), "."))
303    }
304}
305
306pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
307    match rel {
308        Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
309        _ => Err(PlanError::unimplemented(
310            "ReadRel",
311            Some("table_name"),
312            format!("Unexpected read type {rel:?}") as String,
313        )),
314    }
315}
316
317impl<'a> From<&'a ReadRel> for Relation<'a> {
318    fn from(rel: &'a ReadRel) -> Self {
319        let name = get_table_name(rel.read_type.as_ref());
320        let table_name: Value = match name {
321            Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
322            Err(e) => Value::Missing(e),
323        };
324
325        let columns = match rel.base_schema {
326            Some(ref schema) => schema_to_values(schema),
327            None => {
328                let err = PlanError::unimplemented(
329                    "ReadRel",
330                    Some("base_schema"),
331                    "Base schema is required",
332                );
333                vec![Value::Missing(err)]
334            }
335        };
336        let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
337
338        Relation {
339            name: "Read",
340            arguments: Some(Arguments {
341                positional: vec![table_name],
342                named: vec![],
343            }),
344            columns,
345            emit,
346            children: vec![],
347        }
348    }
349}
350
351pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
352    rel.as_ref().and_then(|c| c.emit_kind.as_ref())
353}
354
355impl<'a> Relation<'a> {
356    /// Create a vector of values that are references to the emitted outputs of
357    /// this relation. "Emitted" here meaning the outputs of this relation after
358    /// the emit kind has been applied.
359    ///
360    /// This is useful for relations like Filter and Limit whose direct outputs
361    /// are primarily those of its children (direct here meaning before the emit
362    /// has been applied).
363    pub fn input_refs(&self) -> Vec<Value<'a>> {
364        let len = self.emitted();
365        (0..len).map(|i| Value::Reference(i as i32)).collect()
366    }
367
368    /// Convert a vector of relation references into their structured form.
369    ///
370    /// Returns a list of children (with None for ones missing), and a count of input columns.
371    pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
372        let mut children = vec![];
373        let mut inputs = 0;
374
375        for maybe_rel in refs {
376            match maybe_rel {
377                Some(rel) => {
378                    let child = Relation::from(rel);
379                    inputs += child.emitted();
380                    children.push(Some(child));
381                }
382                None => children.push(None),
383            }
384        }
385
386        (children, inputs)
387    }
388}
389
390impl<'a> From<&'a FilterRel> for Relation<'a> {
391    fn from(rel: &'a FilterRel) -> Self {
392        let condition = rel
393            .condition
394            .as_ref()
395            .map(|c| Value::Expression(c.as_ref()));
396        let condition = Value::expect(condition, || {
397            PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
398        });
399        let positional = vec![condition];
400        let arguments = Some(Arguments {
401            positional,
402            named: vec![],
403        });
404        let emit = get_emit(rel.common.as_ref());
405        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
406        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
407
408        Relation {
409            name: "Filter",
410            arguments,
411            columns,
412            emit,
413            children,
414        }
415    }
416}
417
418impl<'a> From<&'a ProjectRel> for Relation<'a> {
419    fn from(rel: &'a ProjectRel) -> Self {
420        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
421        let expressions = rel.expressions.iter().map(Value::Expression);
422        let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
423        columns.extend(expressions);
424
425        Relation {
426            name: "Project",
427            arguments: None,
428            columns,
429            emit: get_emit(rel.common.as_ref()),
430            children,
431        }
432    }
433}
434
435impl<'a> From<&'a Rel> for Relation<'a> {
436    fn from(rel: &'a Rel) -> Self {
437        match rel.rel_type.as_ref() {
438            Some(RelType::Read(r)) => Relation::from(r.as_ref()),
439            Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
440            Some(RelType::Project(r)) => Relation::from(r.as_ref()),
441            Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
442            Some(RelType::Sort(r)) => Relation::from(r.as_ref()),
443            Some(RelType::Fetch(r)) => Relation::from(r.as_ref()),
444            Some(RelType::Join(r)) => Relation::from(r.as_ref()),
445            _ => todo!(),
446        }
447    }
448}
449
450impl<'a> From<&'a AggregateRel> for Relation<'a> {
451    /// Convert an AggregateRel to a Relation for textification.
452    ///
453    /// The conversion follows this logic:
454    /// 1. Arguments: Group-by expressions (as Value::Expression)
455    /// 2. Columns: All possible outputs in order:
456    ///    - First: Group-by field references (Value::Reference)
457    ///    - Then: Aggregate function measures (Value::AggregateFunction)
458    /// 3. Emit: Uses the relation's emit mapping to select which outputs to display
459    /// 4. Children: The input relation
460    fn from(rel: &'a AggregateRel) -> Self {
461        // Arguments: group-by fields (as expressions)
462        let positional = rel
463            .grouping_expressions
464            .iter()
465            .map(Value::Expression)
466            .collect::<Vec<_>>();
467
468        let arguments = Some(Arguments {
469            positional,
470            named: vec![],
471        });
472        // The columns are the direct outputs of this relation (before emit)
473        let mut all_outputs: Vec<Value> = vec![];
474        let input_field_count = rel.grouping_expressions.len();
475        for i in 0..input_field_count {
476            all_outputs.push(Value::Reference(i as i32));
477        }
478
479        // Then, add all measures (aggregate functions)
480        // These are indexed after the group-by fields
481        for m in &rel.measures {
482            if let Some(agg_fn) = m.measure.as_ref() {
483                all_outputs.push(Value::AggregateFunction(agg_fn));
484            }
485        }
486        let emit = get_emit(rel.common.as_ref());
487        Relation {
488            name: "Aggregate",
489            arguments,
490            columns: all_outputs,
491            emit,
492            children: rel
493                .input
494                .as_ref()
495                .map(|c| Some(Relation::from(c.as_ref())))
496                .into_iter()
497                .collect(),
498        }
499    }
500}
501
502impl Textify for RelRoot {
503    fn name() -> &'static str {
504        "RelRoot"
505    }
506
507    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
508        let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
509
510        write!(
511            w,
512            "{}Root[{}]",
513            ctx.indent(),
514            ctx.separated(names.iter(), ", ")
515        )?;
516        let child_scope = ctx.push_indent();
517        for child in self.input.iter() {
518            let child = Relation::from(child);
519            writeln!(w)?;
520            child.textify(&child_scope, w)?;
521        }
522
523        Ok(())
524    }
525}
526
527impl Textify for PlanRelType {
528    fn name() -> &'static str {
529        "PlanRelType"
530    }
531
532    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
533        match self {
534            PlanRelType::Rel(rel) => Relation::from(rel).textify(ctx, w),
535            PlanRelType::Root(root) => root.textify(ctx, w),
536        }
537    }
538}
539
540impl Textify for PlanRel {
541    fn name() -> &'static str {
542        "PlanRel"
543    }
544
545    /// Write the relation as a string. Inputs are ignored - those are handled
546    /// separately.
547    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
548        write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
549    }
550}
551
552impl<'a> From<&'a SortRel> for Relation<'a> {
553    fn from(rel: &'a SortRel) -> Self {
554        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
555        let positional = rel.sorts.iter().map(Value::from).collect::<Vec<_>>();
556        let arguments = Some(Arguments {
557            positional,
558            named: vec![],
559        });
560        // The columns are the direct outputs of this relation (before emit)
561        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
562        let emit = get_emit(rel.common.as_ref());
563        Relation {
564            name: "Sort",
565            arguments,
566            columns,
567            emit,
568            children,
569        }
570    }
571}
572
573impl<'a> From<&'a FetchRel> for Relation<'a> {
574    fn from(rel: &'a FetchRel) -> Self {
575        let (children, _columns) = Relation::convert_children(vec![rel.input.as_deref()]);
576        let mut named_args = Vec::new();
577        match &rel.count_mode {
578            Some(CountMode::CountExpr(expr)) => {
579                named_args.push(NamedArg {
580                    name: "limit",
581                    value: Value::Expression(expr),
582                });
583            }
584            #[allow(deprecated)]
585            Some(CountMode::Count(val)) => {
586                named_args.push(NamedArg {
587                    name: "limit",
588                    value: Value::Integer(*val as i32),
589                });
590            }
591            None => {}
592        }
593        if let Some(offset) = &rel.offset_mode {
594            match offset {
595                substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
596                    named_args.push(NamedArg {
597                        name: "offset",
598                        value: Value::Expression(expr),
599                    });
600                }
601                #[allow(deprecated)]
602                substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
603                    named_args.push(NamedArg {
604                        name: "offset",
605                        value: Value::Integer(*val as i32),
606                    });
607                }
608            }
609        }
610
611        let emit = get_emit(rel.common.as_ref());
612        let columns = match emit {
613            Some(EmitKind::Emit(e)) => e
614                .output_mapping
615                .iter()
616                .map(|&i| Value::Reference(i))
617                .collect(),
618            _ => vec![],
619        };
620        Relation {
621            name: "Fetch",
622            arguments: Some(Arguments {
623                positional: vec![],
624                named: named_args,
625            }),
626            columns,
627            emit,
628            children,
629        }
630    }
631}
632
633fn join_output_columns(
634    join_type: join_rel::JoinType,
635    left_columns: usize,
636    right_columns: usize,
637) -> Vec<Value<'static>> {
638    let total_columns = match join_type {
639        // Inner, Left, Right, Outer joins output columns from both sides
640        join_rel::JoinType::Inner
641        | join_rel::JoinType::Left
642        | join_rel::JoinType::Right
643        | join_rel::JoinType::Outer => left_columns + right_columns,
644
645        // Left semi/anti joins only output columns from the left side
646        join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
647
648        // Right semi/anti joins output columns from the right side
649        join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
650
651        // Single joins behave like semi joins
652        join_rel::JoinType::LeftSingle => left_columns,
653        join_rel::JoinType::RightSingle => right_columns,
654
655        // Mark joins output base columns plus one mark column
656        join_rel::JoinType::LeftMark => left_columns + 1,
657        join_rel::JoinType::RightMark => right_columns + 1,
658
659        // Unspecified - fallback to all columns
660        join_rel::JoinType::Unspecified => left_columns + right_columns,
661    };
662
663    // Output is always a contiguous range starting from $0
664    (0..total_columns)
665        .map(|i| Value::Reference(i as i32))
666        .collect()
667}
668
669impl<'a> From<&'a JoinRel> for Relation<'a> {
670    fn from(rel: &'a JoinRel) -> Self {
671        let (children, _total_columns) =
672            Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()]);
673
674        // convert_children should preserve input vector length
675        assert_eq!(
676            children.len(),
677            2,
678            "convert_children should return same number of elements as input"
679        );
680
681        // Calculate left and right column counts separately
682        let left_columns = match &children[0] {
683            Some(child) => child.emitted(),
684            None => 0,
685        };
686        let right_columns = match &children[1] {
687            Some(child) => child.emitted(),
688            None => 0,
689        };
690
691        // Convert join type from protobuf i32 to enum value
692        // JoinType is stored as i32 in protobuf, convert to typed enum for processing
693        let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
694            Ok(join_type) => {
695                let join_type_value = match join_type.as_enum_str() {
696                    Ok(s) => Value::Enum(s),
697                    Err(e) => Value::Missing(e),
698                };
699                (join_type, join_type_value)
700            }
701            Err(_) => {
702                // Use Unspecified for the join_type but create an error for the join_type_value
703                let join_type_error = Value::Missing(PlanError::invalid(
704                    "JoinRel",
705                    Some("type"),
706                    format!("Unknown join type: {}", rel.r#type),
707                ));
708                (join_rel::JoinType::Unspecified, join_type_error)
709            }
710        };
711
712        // Join condition
713        let condition = rel
714            .expression
715            .as_ref()
716            .map(|c| Value::Expression(c.as_ref()));
717        let condition = Value::expect(condition, || {
718            PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
719        });
720
721        // TODO: Add support for post_join_filter when grammar is extended
722        // Currently post_join_filter is not supported in the text format
723        // grammar
724        let positional = vec![join_type_value, condition];
725        let arguments = Some(Arguments {
726            positional,
727            named: vec![],
728        });
729
730        let emit = get_emit(rel.common.as_ref());
731        let columns = join_output_columns(join_type, left_columns, right_columns);
732
733        Relation {
734            name: "Join",
735            arguments,
736            columns,
737            emit,
738            children,
739        }
740    }
741}
742
743impl<'a> From<&'a SortField> for Value<'a> {
744    fn from(sf: &'a SortField) -> Self {
745        let field = match &sf.expr {
746            Some(expr) => match &expr.rex_type {
747                Some(substrait::proto::expression::RexType::Selection(fref)) => {
748                    if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
749                        if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
750                            Value::Reference(sf.field)
751                        } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
752                    } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
753                }
754                _ => Value::Missing(PlanError::unimplemented(
755                    "SortField",
756                    Some("expr"),
757                    "Not a selection",
758                )),
759            },
760            None => Value::Missing(PlanError::unimplemented(
761                "SortField",
762                Some("expr"),
763                "Missing expr",
764            )),
765        };
766        let direction = match &sf.sort_kind {
767            Some(kind) => Value::from(kind),
768            None => Value::Missing(PlanError::invalid(
769                "SortKind",
770                Some(Cow::Borrowed("sort_kind")),
771                "Missing sort_kind",
772            )),
773        };
774        Value::Tuple(vec![field, direction])
775    }
776}
777
778impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
779    fn from(enum_val: &'a T) -> Self {
780        match enum_val.as_enum_str() {
781            Ok(s) => Value::Enum(s),
782            Err(e) => Value::Missing(e),
783        }
784    }
785}
786
787impl ValueEnum for SortKind {
788    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
789        let d = match self {
790            &SortKind::Direction(d) => SortDirection::try_from(d),
791            SortKind::ComparisonFunctionReference(f) => {
792                return Err(PlanError::invalid(
793                    "SortKind",
794                    Some(Cow::Owned(format!("function reference{f}"))),
795                    "SortKind::ComparisonFunctionReference unimplemented",
796                ));
797            }
798        };
799        let s = match d {
800            Err(UnknownEnumValue(d)) => {
801                return Err(PlanError::invalid(
802                    "SortKind",
803                    Some(Cow::Owned(format!("unknown variant: {d:?}"))),
804                    "Unknown SortDirection",
805                ));
806            }
807            Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
808            Ok(SortDirection::AscNullsLast) => "AscNullsLast",
809            Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
810            Ok(SortDirection::DescNullsLast) => "DescNullsLast",
811            Ok(SortDirection::Clustered) => "Clustered",
812            Ok(SortDirection::Unspecified) => {
813                return Err(PlanError::invalid(
814                    "SortKind",
815                    Option::<Cow<str>>::None,
816                    "Unspecified SortDirection",
817                ));
818            }
819        };
820        Ok(Cow::Borrowed(s))
821    }
822}
823
824impl ValueEnum for join_rel::JoinType {
825    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
826        let s = match self {
827            join_rel::JoinType::Unspecified => {
828                return Err(PlanError::invalid(
829                    "JoinType",
830                    Option::<Cow<str>>::None,
831                    "Unspecified JoinType",
832                ));
833            }
834            join_rel::JoinType::Inner => "Inner",
835            join_rel::JoinType::Outer => "Outer",
836            join_rel::JoinType::Left => "Left",
837            join_rel::JoinType::Right => "Right",
838            join_rel::JoinType::LeftSemi => "LeftSemi",
839            join_rel::JoinType::RightSemi => "RightSemi",
840            join_rel::JoinType::LeftAnti => "LeftAnti",
841            join_rel::JoinType::RightAnti => "RightAnti",
842            join_rel::JoinType::LeftSingle => "LeftSingle",
843            join_rel::JoinType::RightSingle => "RightSingle",
844            join_rel::JoinType::LeftMark => "LeftMark",
845            join_rel::JoinType::RightMark => "RightMark",
846        };
847        Ok(Cow::Borrowed(s))
848    }
849}
850
851impl<'a> Textify for NamedArg<'a> {
852    fn name() -> &'static str {
853        "NamedArg"
854    }
855    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
856        write!(w, "{}=", self.name)?;
857        self.value.textify(ctx, w)
858    }
859}
860
861#[cfg(test)]
862mod tests {
863    use substrait::proto::expression::literal::LiteralType;
864    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
865    use substrait::proto::function_argument::ArgType;
866    use substrait::proto::read_rel::{NamedTable, ReadType};
867    use substrait::proto::rel_common::Emit;
868    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
869    use substrait::proto::{
870        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
871    };
872
873    use super::*;
874    use crate::fixtures::TestContext;
875    use crate::parser::expressions::FieldIndex;
876
/// A `ReadRel` over a named table with a two-column schema textifies to a
/// single `Read[...]` line, quoting the column name that contains a space.
#[test]
fn test_read_rel() {
    let nullable = Nullability::Nullable as i32;

    // Schema: one nullable i32 column and one nullable string column.
    let column_types = vec![
        Type {
            kind: Some(Kind::I32(ptype::I32 {
                type_variation_reference: 0,
                nullability: nullable,
            })),
        },
        Type {
            kind: Some(Kind::String(ptype::String {
                type_variation_reference: 0,
                nullability: nullable,
            })),
        },
    ];
    let schema = NamedStruct {
        names: vec!["col1".into(), "column 2".into()],
        r#struct: Some(Struct {
            type_variation_reference: 0,
            types: column_types,
            nullability: nullable,
        }),
    };
    let table = NamedTable {
        names: vec!["some_db".into(), "test_table".into()],
        advanced_extension: None,
    };

    let read_rel = ReadRel {
        common: None,
        base_schema: Some(schema),
        filter: None,
        best_effort_filter: None,
        projection: None,
        advanced_extension: None,
        read_type: Some(ReadType::NamedTable(table)),
    };
    let relation = Relation::from(&read_rel);

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&relation);

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    assert_eq!(
        text,
        "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
    );
}
924
/// A `FilterRel` wrapping a `ReadRel` textifies to a `Filter[...]` line with
/// the read rendered as an indented child line.
#[test]
fn test_filter_rel() {
    let nullable = Nullability::Nullable as i32;
    let nullable_i32 = || Type {
        kind: Some(Kind::I32(ptype::I32 {
            type_variation_reference: 0,
            nullability: nullable,
        })),
    };

    // Input relation: a named table with two nullable i32 columns.
    let read_rel = ReadRel {
        common: None,
        base_schema: Some(NamedStruct {
            names: vec!["col1".into(), "col2".into()],
            r#struct: Some(Struct {
                type_variation_reference: 0,
                types: vec![nullable_i32(), nullable_i32()],
                nullability: nullable,
            }),
        }),
        filter: None,
        best_effort_filter: None,
        projection: None,
        advanced_extension: None,
        read_type: Some(ReadType::NamedTable(NamedTable {
            names: vec!["test_table".into()],
            advanced_extension: None,
        })),
    };

    // Condition: col1 > 10, i.e. gt($0, 10).
    let first_column = FunctionArgument {
        arg_type: Some(ArgType::Value(Reference(0).into())),
    };
    let literal_ten = FunctionArgument {
        arg_type: Some(ArgType::Value(Expression {
            rex_type: Some(RexType::Literal(Literal {
                literal_type: Some(LiteralType::I32(10)),
                nullable: false,
                type_variation_reference: 0,
            })),
        })),
    };
    let condition = Expression {
        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
            function_reference: 10, // anchor registered as "gt" on the context below
            arguments: vec![first_column, literal_ten],
            options: vec![],
            output_type: None,
            #[allow(deprecated)]
            args: vec![],
        })),
    };

    let input = Rel {
        rel_type: Some(RelType::Read(Box::new(read_rel))),
    };
    let filter_rel = FilterRel {
        common: None,
        input: Some(Box::new(input)),
        condition: Some(Box::new(condition)),
        advanced_extension: None,
    };
    let plan_rel = Rel {
        rel_type: Some(RelType::Filter(Box::new(filter_rel))),
    };
    let relation = Relation::from(&plan_rel);

    let ctx = TestContext::new()
        .with_urn(1, "test_urn")
        .with_function(1, 10, "gt");
    let (text, errs) = ctx.textify(&relation);

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    let expected =
        "Filter[gt($0, 10:i32) => $0, $1]\n  Read[test_table => col1:i32?, col2:i32?]";
    assert_eq!(text, expected);
}
1013
/// A standalone aggregate function call over field 1 textifies as `sum($1)`.
#[test]
fn test_aggregate_function_textify() {
    // Single argument: a reference to field 1.
    let field_one = FunctionArgument {
        arg_type: Some(ArgType::Value(Expression {
            rex_type: Some(RexType::Selection(Box::new(
                FieldIndex(1).to_field_reference(),
            ))),
        })),
    };
    let sum_call = AggregateFunction {
        function_reference: 10, // anchor registered as "sum" on the context below
        arguments: vec![field_one],
        options: vec![],
        output_type: None,
        invocation: 0,
        phase: 0,
        sorts: vec![],
        #[allow(deprecated)]
        args: vec![],
    };

    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 10, "sum")
        .with_function(1, 11, "count");

    let wrapped = Value::AggregateFunction(&sum_call);
    let (text, errs) = ctx.textify(&wrapped);

    println!("Textification result: {text}");
    if !errs.is_empty() {
        println!("Errors: {errs:?}");
    }

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    assert_eq!(text, "sum($1)");
}
1051
/// An `AggregateRel` with one grouping expression and two measures, emitting
/// only the measures, textifies with an `Aggregate[...]` line listing both.
#[test]
fn test_aggregate_relation_textify() {
    // Both measures aggregate over field 1; only the function anchor differs.
    let aggregate_over_field_one = |function_reference| AggregateFunction {
        function_reference,
        arguments: vec![FunctionArgument {
            arg_type: Some(ArgType::Value(Expression {
                rex_type: Some(RexType::Selection(Box::new(
                    FieldIndex(1).to_field_reference(),
                ))),
            })),
        }],
        options: vec![],
        output_type: None,
        invocation: 0,
        phase: 0,
        sorts: vec![],
        #[allow(deprecated)]
        args: vec![],
    };
    let sum_call = aggregate_over_field_one(10); // registered as "sum" below
    let count_call = aggregate_over_field_one(11); // registered as "count" below

    // Input: named table `orders` with a string column and an fp64 column.
    let nullable = Nullability::Nullable as i32;
    let orders_schema = NamedStruct {
        names: vec!["category".into(), "amount".into()],
        r#struct: Some(Struct {
            type_variation_reference: 0,
            types: vec![
                Type {
                    kind: Some(Kind::String(ptype::String {
                        type_variation_reference: 0,
                        nullability: nullable,
                    })),
                },
                Type {
                    kind: Some(Kind::Fp64(ptype::Fp64 {
                        type_variation_reference: 0,
                        nullability: nullable,
                    })),
                },
            ],
            nullability: nullable,
        }),
    };
    let orders_read = ReadRel {
        common: None,
        base_schema: Some(orders_schema),
        filter: None,
        best_effort_filter: None,
        projection: None,
        advanced_extension: None,
        read_type: Some(ReadType::NamedTable(NamedTable {
            names: vec!["orders".into()],
            advanced_extension: None,
        })),
    };

    let group_by_field_zero = Expression {
        rex_type: Some(RexType::Selection(Box::new(
            FieldIndex(0).to_field_reference(),
        ))),
    };

    let aggregate_rel = AggregateRel {
        input: Some(Box::new(Rel {
            rel_type: Some(RelType::Read(Box::new(orders_read))),
        })),
        grouping_expressions: vec![group_by_field_zero],
        groupings: vec![],
        measures: vec![
            aggregate_rel::Measure {
                measure: Some(sum_call),
                filter: None,
            },
            aggregate_rel::Measure {
                measure: Some(count_call),
                filter: None,
            },
        ],
        common: Some(RelCommon {
            emit_kind: Some(EmitKind::Emit(Emit {
                output_mapping: vec![1, 2], // measures only
            })),
            ..Default::default()
        }),
        advanced_extension: None,
    };

    let ctx = TestContext::new()
        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
        .with_function(1, 10, "sum")
        .with_function(1, 11, "count");

    let relation = Relation::from(&aggregate_rel);
    let (text, errs) = ctx.textify(&relation);

    println!("Aggregate relation textification result:");
    println!("{text}");
    if !errs.is_empty() {
        println!("Errors: {errs:?}");
    }

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    // The output must contain this header line.
    assert!(text.contains("Aggregate[$0 => sum($1), count($1)]"));
}
1169
/// Positional arguments alone are rendered as a comma-separated list.
#[test]
fn test_arguments_textify_positional_only() {
    let positional_only = Arguments {
        positional: vec![Value::Integer(42), Value::Integer(7)],
        named: vec![],
    };

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&positional_only);

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    assert_eq!(text, "42, 7");
}
1181
/// Named arguments alone are rendered as comma-separated `name=value` pairs.
#[test]
fn test_arguments_textify_named_only() {
    let limit = NamedArg {
        name: "limit",
        value: Value::Integer(10),
    };
    let offset = NamedArg {
        name: "offset",
        value: Value::Integer(5),
    };
    let named_only = Arguments {
        positional: vec![],
        named: vec![limit, offset],
    };

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&named_only);

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    assert_eq!(text, "limit=10, offset=5");
}
1202
/// A join whose `type` is not a known join type must report an error and
/// emit an error token, while still rendering the `Join[` relation.
#[test]
fn test_join_relation_unknown_type() {
    // Both inputs are default-constructed reads.
    let default_read = || {
        Some(Box::new(Rel {
            rel_type: Some(RelType::Read(Box::default())),
        }))
    };

    let invalid_join = JoinRel {
        left: default_read(),
        right: default_read(),
        expression: Some(Box::new(Expression::default())),
        r#type: 999, // Invalid join type
        common: None,
        post_join_filter: None,
        advanced_extension: None,
    };
    let relation = Relation::from(&invalid_join);

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&relation);

    // The bad join type is reported, yet the relation itself is still written out.
    assert!(!errs.is_empty(), "Expected errors for unknown join type");
    assert!(
        text.contains("!{JoinRel}"),
        "Expected error token for unknown join type"
    );
    assert!(
        text.contains("Join["),
        "Expected Join relation to be formatted"
    );
    println!("Unknown join type result: {text}");
}
1237
/// With both kinds present, positional arguments come first, then named ones.
#[test]
fn test_arguments_textify_both() {
    let foo = NamedArg {
        name: "foo",
        value: Value::Integer(2),
    };
    let mixed = Arguments {
        positional: vec![Value::Integer(1)],
        named: vec![foo],
    };

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&mixed);

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    assert_eq!(text, "1, foo=2");
}
1252
/// An argument list with nothing in it is rendered as a lone `_`.
#[test]
fn test_arguments_textify_empty() {
    let no_args = Arguments {
        positional: vec![],
        named: vec![],
    };

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&no_args);

    assert!(errs.is_empty(), "Expected no errors, got: {errs:?}");
    assert_eq!(text, "_");
}
1264
/// A named argument whose value is `Value::Missing` renders an error token
/// after `name=` and records an error.
#[test]
fn test_named_arg_textify_error_token() {
    let missing_value = Value::Missing(PlanError::invalid(
        "my_enum",
        Some(Cow::Borrowed("my_enum")),
        Cow::Borrowed("my_enum"),
    ));
    let arg = NamedArg {
        name: "foo",
        value: missing_value,
    };

    let ctx = TestContext::new();
    let (text, errs) = ctx.textify(&arg);

    // The output carries the `!{my_enum}` token right after the name...
    assert!(text.contains("foo=!{my_enum}"), "Output: {text}");
    // ...and the failure is also accumulated as an error.
    assert!(!errs.is_empty(), "Expected error for error token");
}
1282
/// Every labelled `JoinType` variant must map to its own name, and
/// `Unspecified` must be rejected with an error.
///
/// All twelve labelled variants are listed so that a wrong or swapped label
/// in `as_enum_str` fails here, not only the four most common join types.
#[test]
fn test_join_type_enum_textify() {
    let labelled = [
        (join_rel::JoinType::Inner, "Inner"),
        (join_rel::JoinType::Outer, "Outer"),
        (join_rel::JoinType::Left, "Left"),
        (join_rel::JoinType::Right, "Right"),
        (join_rel::JoinType::LeftSemi, "LeftSemi"),
        (join_rel::JoinType::RightSemi, "RightSemi"),
        (join_rel::JoinType::LeftAnti, "LeftAnti"),
        (join_rel::JoinType::RightAnti, "RightAnti"),
        (join_rel::JoinType::LeftSingle, "LeftSingle"),
        (join_rel::JoinType::RightSingle, "RightSingle"),
        (join_rel::JoinType::LeftMark, "LeftMark"),
        (join_rel::JoinType::RightMark, "RightMark"),
    ];

    for (join_type, expected) in labelled {
        assert_eq!(
            join_type.as_enum_str().unwrap(),
            expected,
            "unexpected label for {join_type:?}"
        );
    }

    // The unspecified variant has no textual form and must surface as an error.
    assert!(
        join_rel::JoinType::Unspecified.as_enum_str().is_err(),
        "Unspecified join type must not produce a label"
    );
}
1297
/// Checks the number of output columns `join_output_columns` produces for
/// several join types (with 2 left and 3 right input columns), and that the
/// inspected entries are references numbered from `$0`.
#[test]
fn test_join_output_columns() {
    // Inner: every column from both sides, 2 + 3 = 5.
    {
        let cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
        assert_eq!(cols.len(), 5);
        assert!(matches!(cols[0], Value::Reference(0)));
        assert!(matches!(cols[4], Value::Reference(4)));
    }

    // LeftSemi: only the 2 left columns.
    {
        let cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
        assert_eq!(cols.len(), 2);
        assert!(matches!(cols[0], Value::Reference(0)));
        assert!(matches!(cols[1], Value::Reference(1)));
    }

    // RightSemi: only the 3 right columns, numbered contiguously from $0.
    {
        let cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
        assert_eq!(cols.len(), 3);
        assert!(matches!(cols[0], Value::Reference(0)));
        assert!(matches!(cols[1], Value::Reference(1)));
        assert!(matches!(cols[2], Value::Reference(2)));
    }

    // LeftMark: 2 left columns plus 1 mark column, numbered contiguously.
    {
        let cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
        assert_eq!(cols.len(), 3);
        assert!(matches!(cols[0], Value::Reference(0)));
        assert!(matches!(cols[1], Value::Reference(1)));
        assert!(matches!(cols[2], Value::Reference(2))); // mark column
    }

    // RightMark: 3 right columns plus 1 mark column, numbered contiguously from $0.
    {
        let cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
        assert_eq!(cols.len(), 4);
        assert!(matches!(cols[0], Value::Reference(0)));
        assert!(matches!(cols[1], Value::Reference(1)));
        assert!(matches!(cols[2], Value::Reference(2)));
        assert!(matches!(cols[3], Value::Reference(3))); // mark column
    }
}
1334}