substrait_explain/textify/
rels.rs

1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::fmt;
4use std::fmt::Debug;
5
6use prost::UnknownEnumValue;
7use substrait::proto::fetch_rel::CountMode;
8use substrait::proto::plan_rel::RelType as PlanRelType;
9use substrait::proto::read_rel::ReadType;
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::EmitKind;
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14    AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct,
15    PlanRel, ProjectRel, ReadRel, Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
16};
17
18use super::expressions::Reference;
19use super::types::Name;
20use super::{PlanError, Scope, Textify};
21
/// Implemented by relation types that can report a fixed display name.
pub trait NamedRelation {
    /// Returns the static name used to label this relation.
    fn name(&self) -> &'static str;
}
25
26impl NamedRelation for Rel {
27    fn name(&self) -> &'static str {
28        match self.rel_type.as_ref() {
29            None => "UnknownRel",
30            Some(RelType::Read(_)) => "Read",
31            Some(RelType::Filter(_)) => "Filter",
32            Some(RelType::Project(_)) => "Project",
33            Some(RelType::Fetch(_)) => "Fetch",
34            Some(RelType::Aggregate(_)) => "Aggregate",
35            Some(RelType::Sort(_)) => "Sort",
36            Some(RelType::HashJoin(_)) => "HashJoin",
37            Some(RelType::Exchange(_)) => "Exchange",
38            Some(RelType::Join(_)) => "Join",
39            Some(RelType::Set(_)) => "Set",
40            Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
41            Some(RelType::Cross(_)) => "Cross",
42            Some(RelType::Reference(_)) => "Reference",
43            Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
44            Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
45            Some(RelType::Write(_)) => "Write",
46            Some(RelType::Ddl(_)) => "Ddl",
47            Some(RelType::Update(_)) => "Update",
48            Some(RelType::MergeJoin(_)) => "MergeJoin",
49            Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
50            Some(RelType::Window(_)) => "Window",
51            Some(RelType::Expand(_)) => "Expand",
52        }
53    }
54}
55
56impl Textify for Rel {
57    fn name() -> &'static str {
58        "Rel"
59    }
60
61    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
62        // Handle extension relations directly to use registry-aware implementations
63        match &self.rel_type {
64            Some(substrait::proto::rel::RelType::ExtensionLeaf(ext)) => ext.textify(ctx, w),
65            Some(substrait::proto::rel::RelType::ExtensionSingle(ext)) => ext.textify(ctx, w),
66            Some(substrait::proto::rel::RelType::ExtensionMulti(ext)) => ext.textify(ctx, w),
67            _ => {
68                // For non-extension relations, use the old Relation conversion
69                let relation = Relation::from(self);
70                relation.textify(ctx, w)
71            }
72        }
73    }
74}
75
76/// Trait for enums that can be converted to a string representation for
77/// textification.
78///
79/// Returns Ok(str) for valid enum values, or Err([PlanError]) for invalid or
80/// unknown values.
81pub trait ValueEnum {
82    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
83}
84
85#[derive(Debug, Clone)]
86pub struct NamedArg<'a> {
87    pub name: &'a str,
88    pub value: Value<'a>,
89}
90
91#[derive(Debug, Clone)]
92pub enum Value<'a> {
93    Name(Name<'a>),
94    TableName(Vec<Name<'a>>),
95    Field(Option<Name<'a>>, Option<&'a Type>),
96    Tuple(Vec<Value<'a>>),
97    List(Vec<Value<'a>>),
98    Reference(i32),
99    Expression(&'a Expression),
100    AggregateFunction(&'a AggregateFunction),
101    /// Represents a missing, invalid, or unspecified value.
102    Missing(PlanError),
103    /// Represents a valid enum value as a string for textification.
104    Enum(Cow<'a, str>),
105    Integer(i64),
106    Float(f64),
107    Boolean(bool),
108}
109
110impl<'a> Value<'a> {
111    pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
112        match maybe_value {
113            Some(s) => s,
114            None => Value::Missing(f()),
115        }
116    }
117}
118
119impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
120    fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
121        match token {
122            Ok(value) => Value::TableName(value),
123            Err(err) => Value::Missing(err),
124        }
125    }
126}
127
128impl<'a> Textify for Value<'a> {
129    fn name() -> &'static str {
130        "Value"
131    }
132
133    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
134        match self {
135            Value::Name(name) => write!(w, "{}", ctx.display(name)),
136            Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
137            Value::Field(name, typ) => {
138                write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
139            }
140            Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
141            Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
142            Value::Reference(i) => write!(w, "{}", Reference(*i)),
143            Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
144            Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
145            Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
146            Value::Enum(res) => write!(w, "&{res}"),
147            Value::Integer(i) => write!(w, "{i}"),
148            Value::Float(f) => write!(w, "{f}"),
149            Value::Boolean(b) => write!(w, "{b}"),
150        }
151    }
152}
153
154fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
155    let mut fields = schema
156        .r#struct
157        .as_ref()
158        .map(|s| s.types.iter())
159        .into_iter()
160        .flatten();
161    let mut names = schema.names.iter();
162
163    // let field_count = schema.r#struct.as_ref().map(|s| s.types.len()).unwrap_or(0);
164    // let name_count = schema.names.len();
165
166    let mut values = Vec::new();
167    loop {
168        let field = fields.next();
169        let name = names.next().map(|n| Name(n));
170        if field.is_none() && name.is_none() {
171            break;
172        }
173
174        values.push(Value::Field(name, field));
175    }
176
177    values
178}
179
180struct Emitted<'a> {
181    pub values: &'a [Value<'a>],
182    pub emit: Option<&'a EmitKind>,
183}
184
185impl<'a> Emitted<'a> {
186    pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
187        Self { values, emit }
188    }
189
190    pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
191        write!(w, "{}", ctx.separated(self.values.iter(), ", "))
192    }
193}
194
195impl<'a> Textify for Emitted<'a> {
196    fn name() -> &'static str {
197        "Emitted"
198    }
199
200    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
201        if ctx.options().show_emit {
202            return self.write_direct(ctx, w);
203        }
204
205        let indices = match &self.emit {
206            Some(EmitKind::Emit(e)) => &e.output_mapping,
207            Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
208            None => return self.write_direct(ctx, w),
209        };
210
211        for (i, &index) in indices.iter().enumerate() {
212            if i > 0 {
213                write!(w, ", ")?;
214            }
215
216            match self.values.get(index as usize) {
217                Some(value) => write!(w, "{}", ctx.display(value))?,
218                None => write!(w, "{}", ctx.failure(PlanError::invalid(
219                    "Emitted",
220                    Some("output_mapping"),
221                    format!(
222                        "Output mapping index {} is out of bounds for values collection of size {}",
223                        index, self.values.len()
224                    )
225                )))?,
226            }
227        }
228
229        Ok(())
230    }
231}
232
233#[derive(Debug, Clone)]
234pub struct Arguments<'a> {
235    /// Positional arguments (e.g., a filter condition, group-bys, etc.))
236    pub positional: Vec<Value<'a>>,
237    /// Named arguments (e.g., limit=10, offset=5)
238    pub named: Vec<NamedArg<'a>>,
239}
240
241impl<'a> Textify for Arguments<'a> {
242    fn name() -> &'static str {
243        "Arguments"
244    }
245    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
246        if self.positional.is_empty() && self.named.is_empty() {
247            return write!(w, "_");
248        }
249
250        write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
251        if !self.positional.is_empty() && !self.named.is_empty() {
252            write!(w, ", ")?;
253        }
254        write!(w, "{}", ctx.separated(self.named.iter(), ", "))
255    }
256}
257
258pub struct Relation<'a> {
259    pub name: &'a str,
260    /// Arguments to the relation, if any.
261    ///
262    /// - `None` means this relation does not take arguments, and the argument
263    ///   section is omitted entirely.
264    /// - `Some(args)` with both vectors empty means the relation takes
265    ///   arguments, but none are provided; this will print as `_ => ...`.
266    /// - `Some(args)` with non-empty vectors will print as usual, with
267    ///   positional arguments first, then named arguments, separated by commas.
268    pub arguments: Option<Arguments<'a>>,
269    /// The columns emitted by this relation, pre-emit - the 'direct' column
270    /// output.
271    pub columns: Vec<Value<'a>>,
272    /// The emit kind, if any. If none, use the columns directly.
273    pub emit: Option<&'a EmitKind>,
274    /// The input relations.
275    pub children: Vec<Option<Relation<'a>>>,
276}
277
278impl Textify for Relation<'_> {
279    fn name() -> &'static str {
280        "Relation"
281    }
282
283    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
284        let cols = Emitted::new(&self.columns, self.emit);
285        let indent = ctx.indent();
286        let name = self.name;
287        let cols = ctx.display(&cols);
288        match &self.arguments {
289            None => {
290                write!(w, "{indent}{name}[{cols}]")?;
291            }
292            Some(args) => {
293                let args = ctx.display(args);
294                write!(w, "{indent}{name}[{args} => {cols}]")?;
295            }
296        }
297        let child_scope = ctx.push_indent();
298        for child in self.children.iter().flatten() {
299            writeln!(w)?;
300            child.textify(&child_scope, w)?;
301        }
302        Ok(())
303    }
304}
305
306impl<'a> Relation<'a> {
307    pub fn emitted(&self) -> usize {
308        match self.emit {
309            Some(EmitKind::Emit(e)) => e.output_mapping.len(),
310            Some(EmitKind::Direct(_)) => self.columns.len(),
311            None => self.columns.len(),
312        }
313    }
314}
315
/// A borrowed, multi-part table name.
#[derive(Clone, Copy, Debug)]
pub struct TableName<'a>(&'a [String]);
318
319impl<'a> Textify for TableName<'a> {
320    fn name() -> &'static str {
321        "TableName"
322    }
323
324    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
325        let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
326        write!(w, "{}", ctx.separated(names.iter(), "."))
327    }
328}
329
330pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
331    match rel {
332        Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
333        _ => Err(PlanError::unimplemented(
334            "ReadRel",
335            Some("table_name"),
336            format!("Unexpected read type {rel:?}") as String,
337        )),
338    }
339}
340
341impl<'a> From<&'a ReadRel> for Relation<'a> {
342    fn from(rel: &'a ReadRel) -> Self {
343        let name = get_table_name(rel.read_type.as_ref());
344        let table_name: Value = match name {
345            Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
346            Err(e) => Value::Missing(e),
347        };
348
349        let columns = match rel.base_schema {
350            Some(ref schema) => schema_to_values(schema),
351            None => {
352                let err = PlanError::unimplemented(
353                    "ReadRel",
354                    Some("base_schema"),
355                    "Base schema is required",
356                );
357                vec![Value::Missing(err)]
358            }
359        };
360        let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
361
362        Relation {
363            name: "Read",
364            arguments: Some(Arguments {
365                positional: vec![table_name],
366                named: vec![],
367            }),
368            columns,
369            emit,
370            children: vec![],
371        }
372    }
373}
374
375pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
376    rel.as_ref().and_then(|c| c.emit_kind.as_ref())
377}
378
379impl<'a> Relation<'a> {
380    /// Create a vector of values that are references to the emitted outputs of
381    /// this relation. "Emitted" here meaning the outputs of this relation after
382    /// the emit kind has been applied.
383    ///
384    /// This is useful for relations like Filter and Limit whose direct outputs
385    /// are primarily those of its children (direct here meaning before the emit
386    /// has been applied).
387    pub fn input_refs(&self) -> Vec<Value<'a>> {
388        let len = self.emitted();
389        (0..len).map(|i| Value::Reference(i as i32)).collect()
390    }
391
392    /// Convert a vector of relation references into their structured form.
393    ///
394    /// Returns a list of children (with None for ones missing), and a count of input columns.
395    pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
396        let mut children = vec![];
397        let mut inputs = 0;
398
399        for maybe_rel in refs {
400            match maybe_rel {
401                Some(rel) => {
402                    let child = Relation::from(rel);
403                    inputs += child.emitted();
404                    children.push(Some(child));
405                }
406                None => children.push(None),
407            }
408        }
409
410        (children, inputs)
411    }
412}
413
414impl<'a> From<&'a FilterRel> for Relation<'a> {
415    fn from(rel: &'a FilterRel) -> Self {
416        let condition = rel
417            .condition
418            .as_ref()
419            .map(|c| Value::Expression(c.as_ref()));
420        let condition = Value::expect(condition, || {
421            PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
422        });
423        let positional = vec![condition];
424        let arguments = Some(Arguments {
425            positional,
426            named: vec![],
427        });
428        let emit = get_emit(rel.common.as_ref());
429        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
430        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
431
432        Relation {
433            name: "Filter",
434            arguments,
435            columns,
436            emit,
437            children,
438        }
439    }
440}
441
442impl<'a> From<&'a ProjectRel> for Relation<'a> {
443    fn from(rel: &'a ProjectRel) -> Self {
444        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
445        let expressions = rel.expressions.iter().map(Value::Expression);
446        let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
447        columns.extend(expressions);
448
449        Relation {
450            name: "Project",
451            arguments: None,
452            columns,
453            emit: get_emit(rel.common.as_ref()),
454            children,
455        }
456    }
457}
458
459impl<'a> From<&'a Rel> for Relation<'a> {
460    fn from(rel: &'a Rel) -> Self {
461        match rel.rel_type.as_ref() {
462            Some(RelType::Read(r)) => Relation::from(r.as_ref()),
463            Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
464            Some(RelType::Project(r)) => Relation::from(r.as_ref()),
465            Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
466            Some(RelType::Sort(r)) => Relation::from(r.as_ref()),
467            Some(RelType::Fetch(r)) => Relation::from(r.as_ref()),
468            Some(RelType::Join(r)) => Relation::from(r.as_ref()),
469            Some(RelType::ExtensionLeaf(_)) => {
470                unreachable!("Extension relations should use Textify trait directly")
471            }
472            Some(RelType::ExtensionSingle(_)) => {
473                unreachable!("Extension relations should use Textify trait directly")
474            }
475            Some(RelType::ExtensionMulti(_)) => {
476                unreachable!("Extension relations should use Textify trait directly")
477            }
478            _ => todo!(),
479        }
480    }
481}
482
483impl<'a> From<&'a AggregateRel> for Relation<'a> {
484    /// Convert an AggregateRel to a Relation for textification.
485    ///
486    /// The conversion follows this logic:
487    /// 1. Arguments: Group-by expressions (as Value::Expression)
488    /// 2. Columns: All possible outputs in order:
489    ///    - First: Group-by field references (Value::Reference)
490    ///    - Then: Aggregate function measures (Value::AggregateFunction)
491    /// 3. Emit: Uses the relation's emit mapping to select which outputs to display
492    /// 4. Children: The input relation
493    fn from(rel: &'a AggregateRel) -> Self {
494        // Arguments: group-by fields (as expressions)
495        let positional = rel
496            .grouping_expressions
497            .iter()
498            .map(Value::Expression)
499            .collect::<Vec<_>>();
500
501        let arguments = Some(Arguments {
502            positional,
503            named: vec![],
504        });
505        // The columns are the direct outputs of this relation (before emit)
506        let mut all_outputs: Vec<Value> = vec![];
507        let input_field_count = rel.grouping_expressions.len();
508        for i in 0..input_field_count {
509            all_outputs.push(Value::Reference(i as i32));
510        }
511
512        // Then, add all measures (aggregate functions)
513        // These are indexed after the group-by fields
514        for m in &rel.measures {
515            if let Some(agg_fn) = m.measure.as_ref() {
516                all_outputs.push(Value::AggregateFunction(agg_fn));
517            }
518        }
519        let emit = get_emit(rel.common.as_ref());
520        Relation {
521            name: "Aggregate",
522            arguments,
523            columns: all_outputs,
524            emit,
525            children: rel
526                .input
527                .as_ref()
528                .map(|c| Some(Relation::from(c.as_ref())))
529                .into_iter()
530                .collect(),
531        }
532    }
533}
534
535impl Textify for RelRoot {
536    fn name() -> &'static str {
537        "RelRoot"
538    }
539
540    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
541        let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
542
543        write!(
544            w,
545            "{}Root[{}]",
546            ctx.indent(),
547            ctx.separated(names.iter(), ", ")
548        )?;
549        let child_scope = ctx.push_indent();
550        for child in self.input.iter() {
551            writeln!(w)?;
552            child.textify(&child_scope, w)?;
553        }
554
555        Ok(())
556    }
557}
558
559impl Textify for PlanRelType {
560    fn name() -> &'static str {
561        "PlanRelType"
562    }
563
564    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
565        match self {
566            PlanRelType::Rel(rel) => rel.textify(ctx, w),
567            PlanRelType::Root(root) => root.textify(ctx, w),
568        }
569    }
570}
571
572impl Textify for PlanRel {
573    fn name() -> &'static str {
574        "PlanRel"
575    }
576
577    /// Write the relation as a string. Inputs are ignored - those are handled
578    /// separately.
579    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
580        write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
581    }
582}
583
584impl<'a> From<&'a SortRel> for Relation<'a> {
585    fn from(rel: &'a SortRel) -> Self {
586        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
587        let positional = rel.sorts.iter().map(Value::from).collect::<Vec<_>>();
588        let arguments = Some(Arguments {
589            positional,
590            named: vec![],
591        });
592        // The columns are the direct outputs of this relation (before emit)
593        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
594        let emit = get_emit(rel.common.as_ref());
595        Relation {
596            name: "Sort",
597            arguments,
598            columns,
599            emit,
600            children,
601        }
602    }
603}
604
605impl<'a> From<&'a FetchRel> for Relation<'a> {
606    fn from(rel: &'a FetchRel) -> Self {
607        let (children, _columns) = Relation::convert_children(vec![rel.input.as_deref()]);
608        let mut named_args = Vec::new();
609        match &rel.count_mode {
610            Some(CountMode::CountExpr(expr)) => {
611                named_args.push(NamedArg {
612                    name: "limit",
613                    value: Value::Expression(expr),
614                });
615            }
616            #[allow(deprecated)]
617            Some(CountMode::Count(val)) => {
618                named_args.push(NamedArg {
619                    name: "limit",
620                    value: Value::Integer(*val),
621                });
622            }
623            None => {}
624        }
625        if let Some(offset) = &rel.offset_mode {
626            match offset {
627                substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
628                    named_args.push(NamedArg {
629                        name: "offset",
630                        value: Value::Expression(expr),
631                    });
632                }
633                #[allow(deprecated)]
634                substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
635                    named_args.push(NamedArg {
636                        name: "offset",
637                        value: Value::Integer(*val),
638                    });
639                }
640            }
641        }
642
643        let emit = get_emit(rel.common.as_ref());
644        let columns = match emit {
645            Some(EmitKind::Emit(e)) => e
646                .output_mapping
647                .iter()
648                .map(|&i| Value::Reference(i))
649                .collect(),
650            _ => vec![],
651        };
652        Relation {
653            name: "Fetch",
654            arguments: Some(Arguments {
655                positional: vec![],
656                named: named_args,
657            }),
658            columns,
659            emit,
660            children,
661        }
662    }
663}
664
665fn join_output_columns(
666    join_type: join_rel::JoinType,
667    left_columns: usize,
668    right_columns: usize,
669) -> Vec<Value<'static>> {
670    let total_columns = match join_type {
671        // Inner, Left, Right, Outer joins output columns from both sides
672        join_rel::JoinType::Inner
673        | join_rel::JoinType::Left
674        | join_rel::JoinType::Right
675        | join_rel::JoinType::Outer => left_columns + right_columns,
676
677        // Left semi/anti joins only output columns from the left side
678        join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
679
680        // Right semi/anti joins output columns from the right side
681        join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
682
683        // Single joins behave like semi joins
684        join_rel::JoinType::LeftSingle => left_columns,
685        join_rel::JoinType::RightSingle => right_columns,
686
687        // Mark joins output base columns plus one mark column
688        join_rel::JoinType::LeftMark => left_columns + 1,
689        join_rel::JoinType::RightMark => right_columns + 1,
690
691        // Unspecified - fallback to all columns
692        join_rel::JoinType::Unspecified => left_columns + right_columns,
693    };
694
695    // Output is always a contiguous range starting from $0
696    (0..total_columns)
697        .map(|i| Value::Reference(i as i32))
698        .collect()
699}
700
701impl<'a> From<&'a JoinRel> for Relation<'a> {
702    fn from(rel: &'a JoinRel) -> Self {
703        let (children, _total_columns) =
704            Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()]);
705
706        // convert_children should preserve input vector length
707        assert_eq!(
708            children.len(),
709            2,
710            "convert_children should return same number of elements as input"
711        );
712
713        // Calculate left and right column counts separately
714        let left_columns = match &children[0] {
715            Some(child) => child.emitted(),
716            None => 0,
717        };
718        let right_columns = match &children[1] {
719            Some(child) => child.emitted(),
720            None => 0,
721        };
722
723        // Convert join type from protobuf i32 to enum value
724        // JoinType is stored as i32 in protobuf, convert to typed enum for processing
725        let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
726            Ok(join_type) => {
727                let join_type_value = match join_type.as_enum_str() {
728                    Ok(s) => Value::Enum(s),
729                    Err(e) => Value::Missing(e),
730                };
731                (join_type, join_type_value)
732            }
733            Err(_) => {
734                // Use Unspecified for the join_type but create an error for the join_type_value
735                let join_type_error = Value::Missing(PlanError::invalid(
736                    "JoinRel",
737                    Some("type"),
738                    format!("Unknown join type: {}", rel.r#type),
739                ));
740                (join_rel::JoinType::Unspecified, join_type_error)
741            }
742        };
743
744        // Join condition
745        let condition = rel
746            .expression
747            .as_ref()
748            .map(|c| Value::Expression(c.as_ref()));
749        let condition = Value::expect(condition, || {
750            PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
751        });
752
753        // TODO: Add support for post_join_filter when grammar is extended
754        // Currently post_join_filter is not supported in the text format
755        // grammar
756        let positional = vec![join_type_value, condition];
757        let arguments = Some(Arguments {
758            positional,
759            named: vec![],
760        });
761
762        let emit = get_emit(rel.common.as_ref());
763        let columns = join_output_columns(join_type, left_columns, right_columns);
764
765        Relation {
766            name: "Join",
767            arguments,
768            columns,
769            emit,
770            children,
771        }
772    }
773}
774
775impl<'a> From<&'a SortField> for Value<'a> {
776    fn from(sf: &'a SortField) -> Self {
777        let field = match &sf.expr {
778            Some(expr) => match &expr.rex_type {
779                Some(substrait::proto::expression::RexType::Selection(fref)) => {
780                    if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
781                        if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
782                            Value::Reference(sf.field)
783                        } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
784                    } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
785                }
786                _ => Value::Missing(PlanError::unimplemented(
787                    "SortField",
788                    Some("expr"),
789                    "Not a selection",
790                )),
791            },
792            None => Value::Missing(PlanError::unimplemented(
793                "SortField",
794                Some("expr"),
795                "Missing expr",
796            )),
797        };
798        let direction = match &sf.sort_kind {
799            Some(kind) => Value::from(kind),
800            None => Value::Missing(PlanError::invalid(
801                "SortKind",
802                Some(Cow::Borrowed("sort_kind")),
803                "Missing sort_kind",
804            )),
805        };
806        Value::Tuple(vec![field, direction])
807    }
808}
809
810impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
811    fn from(enum_val: &'a T) -> Self {
812        match enum_val.as_enum_str() {
813            Ok(s) => Value::Enum(s),
814            Err(e) => Value::Missing(e),
815        }
816    }
817}
818
819impl ValueEnum for SortKind {
820    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
821        let d = match self {
822            &SortKind::Direction(d) => SortDirection::try_from(d),
823            SortKind::ComparisonFunctionReference(f) => {
824                return Err(PlanError::invalid(
825                    "SortKind",
826                    Some(Cow::Owned(format!("function reference{f}"))),
827                    "SortKind::ComparisonFunctionReference unimplemented",
828                ));
829            }
830        };
831        let s = match d {
832            Err(UnknownEnumValue(d)) => {
833                return Err(PlanError::invalid(
834                    "SortKind",
835                    Some(Cow::Owned(format!("unknown variant: {d:?}"))),
836                    "Unknown SortDirection",
837                ));
838            }
839            Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
840            Ok(SortDirection::AscNullsLast) => "AscNullsLast",
841            Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
842            Ok(SortDirection::DescNullsLast) => "DescNullsLast",
843            Ok(SortDirection::Clustered) => "Clustered",
844            Ok(SortDirection::Unspecified) => {
845                return Err(PlanError::invalid(
846                    "SortKind",
847                    Option::<Cow<str>>::None,
848                    "Unspecified SortDirection",
849                ));
850            }
851        };
852        Ok(Cow::Borrowed(s))
853    }
854}
855
856impl ValueEnum for join_rel::JoinType {
857    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
858        let s = match self {
859            join_rel::JoinType::Unspecified => {
860                return Err(PlanError::invalid(
861                    "JoinType",
862                    Option::<Cow<str>>::None,
863                    "Unspecified JoinType",
864                ));
865            }
866            join_rel::JoinType::Inner => "Inner",
867            join_rel::JoinType::Outer => "Outer",
868            join_rel::JoinType::Left => "Left",
869            join_rel::JoinType::Right => "Right",
870            join_rel::JoinType::LeftSemi => "LeftSemi",
871            join_rel::JoinType::RightSemi => "RightSemi",
872            join_rel::JoinType::LeftAnti => "LeftAnti",
873            join_rel::JoinType::RightAnti => "RightAnti",
874            join_rel::JoinType::LeftSingle => "LeftSingle",
875            join_rel::JoinType::RightSingle => "RightSingle",
876            join_rel::JoinType::LeftMark => "LeftMark",
877            join_rel::JoinType::RightMark => "RightMark",
878        };
879        Ok(Cow::Borrowed(s))
880    }
881}
882
883impl<'a> Textify for NamedArg<'a> {
884    fn name() -> &'static str {
885        "NamedArg"
886    }
887    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
888        write!(w, "{}=", self.name)?;
889        self.value.textify(ctx, w)
890    }
891}
892
893#[cfg(test)]
894mod tests {
895    use substrait::proto::expression::literal::LiteralType;
896    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
897    use substrait::proto::function_argument::ArgType;
898    use substrait::proto::read_rel::{NamedTable, ReadType};
899    use substrait::proto::rel_common::Emit;
900    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
901    use substrait::proto::{
902        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
903    };
904
905    use super::*;
906    use crate::fixtures::TestContext;
907    use crate::parser::expressions::FieldIndex;
908
909    #[test]
910    fn test_read_rel() {
911        let ctx = TestContext::new();
912
913        // Create a simple ReadRel with a NamedStruct schema
914        let read_rel = ReadRel {
915            common: None,
916            base_schema: Some(NamedStruct {
917                names: vec!["col1".into(), "column 2".into()],
918                r#struct: Some(Struct {
919                    type_variation_reference: 0,
920                    types: vec![
921                        Type {
922                            kind: Some(Kind::I32(ptype::I32 {
923                                type_variation_reference: 0,
924                                nullability: Nullability::Nullable as i32,
925                            })),
926                        },
927                        Type {
928                            kind: Some(Kind::String(ptype::String {
929                                type_variation_reference: 0,
930                                nullability: Nullability::Nullable as i32,
931                            })),
932                        },
933                    ],
934                    nullability: Nullability::Nullable as i32,
935                }),
936            }),
937            filter: None,
938            best_effort_filter: None,
939            projection: None,
940            advanced_extension: None,
941            read_type: Some(ReadType::NamedTable(NamedTable {
942                names: vec!["some_db".into(), "test_table".into()],
943                advanced_extension: None,
944            })),
945        };
946
947        let rel = Relation::from(&read_rel);
948
949        let (result, errors) = ctx.textify(&rel);
950        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
951        assert_eq!(
952            result,
953            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
954        );
955    }
956
957    #[test]
958    fn test_filter_rel() {
959        let ctx = TestContext::new()
960            .with_urn(1, "test_urn")
961            .with_function(1, 10, "gt");
962
963        // Create a simple FilterRel with a ReadRel input and a filter expression
964        let read_rel = ReadRel {
965            common: None,
966            base_schema: Some(NamedStruct {
967                names: vec!["col1".into(), "col2".into()],
968                r#struct: Some(Struct {
969                    type_variation_reference: 0,
970                    types: vec![
971                        Type {
972                            kind: Some(Kind::I32(ptype::I32 {
973                                type_variation_reference: 0,
974                                nullability: Nullability::Nullable as i32,
975                            })),
976                        },
977                        Type {
978                            kind: Some(Kind::I32(ptype::I32 {
979                                type_variation_reference: 0,
980                                nullability: Nullability::Nullable as i32,
981                            })),
982                        },
983                    ],
984                    nullability: Nullability::Nullable as i32,
985                }),
986            }),
987            filter: None,
988            best_effort_filter: None,
989            projection: None,
990            advanced_extension: None,
991            read_type: Some(ReadType::NamedTable(NamedTable {
992                names: vec!["test_table".into()],
993                advanced_extension: None,
994            })),
995        };
996
997        // Create a filter expression: col1 > 10
998        let filter_expr = Expression {
999            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
1000                function_reference: 10, // gt function
1001                arguments: vec![
1002                    FunctionArgument {
1003                        arg_type: Some(ArgType::Value(Reference(0).into())),
1004                    },
1005                    FunctionArgument {
1006                        arg_type: Some(ArgType::Value(Expression {
1007                            rex_type: Some(RexType::Literal(Literal {
1008                                literal_type: Some(LiteralType::I32(10)),
1009                                nullable: false,
1010                                type_variation_reference: 0,
1011                            })),
1012                        })),
1013                    },
1014                ],
1015                options: vec![],
1016                output_type: None,
1017                #[allow(deprecated)]
1018                args: vec![],
1019            })),
1020        };
1021
1022        let filter_rel = FilterRel {
1023            common: None,
1024            input: Some(Box::new(Rel {
1025                rel_type: Some(RelType::Read(Box::new(read_rel))),
1026            })),
1027            condition: Some(Box::new(filter_expr)),
1028            advanced_extension: None,
1029        };
1030
1031        let rel = Rel {
1032            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
1033        };
1034
1035        let rel = Relation::from(&rel);
1036
1037        let (result, errors) = ctx.textify(&rel);
1038        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1039        let expected = r#"
1040Filter[gt($0, 10:i32) => $0, $1]
1041  Read[test_table => col1:i32?, col2:i32?]"#
1042            .trim_start();
1043        assert_eq!(result, expected);
1044    }
1045
1046    #[test]
1047    fn test_aggregate_function_textify() {
1048        let ctx = TestContext::new()
1049        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1050        .with_function(1, 10, "sum")
1051        .with_function(1, 11, "count");
1052
1053        // Create a simple AggregateFunction
1054        let agg_fn = AggregateFunction {
1055            function_reference: 10, // sum
1056            arguments: vec![FunctionArgument {
1057                arg_type: Some(ArgType::Value(Expression {
1058                    rex_type: Some(RexType::Selection(Box::new(
1059                        FieldIndex(1).to_field_reference(),
1060                    ))),
1061                })),
1062            }],
1063            options: vec![],
1064            output_type: None,
1065            invocation: 0,
1066            phase: 0,
1067            sorts: vec![],
1068            #[allow(deprecated)]
1069            args: vec![],
1070        };
1071
1072        let value = Value::AggregateFunction(&agg_fn);
1073        let (result, errors) = ctx.textify(&value);
1074
1075        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1076        assert_eq!(result, "sum($1)");
1077    }
1078
1079    #[test]
1080    fn test_aggregate_relation_textify() {
1081        let ctx = TestContext::new()
1082        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1083        .with_function(1, 10, "sum")
1084        .with_function(1, 11, "count");
1085
1086        // Create a simple AggregateRel
1087        let agg_fn1 = AggregateFunction {
1088            function_reference: 10, // sum
1089            arguments: vec![FunctionArgument {
1090                arg_type: Some(ArgType::Value(Expression {
1091                    rex_type: Some(RexType::Selection(Box::new(
1092                        FieldIndex(1).to_field_reference(),
1093                    ))),
1094                })),
1095            }],
1096            options: vec![],
1097            output_type: None,
1098            invocation: 0,
1099            phase: 0,
1100            sorts: vec![],
1101            #[allow(deprecated)]
1102            args: vec![],
1103        };
1104
1105        let agg_fn2 = AggregateFunction {
1106            function_reference: 11, // count
1107            arguments: vec![FunctionArgument {
1108                arg_type: Some(ArgType::Value(Expression {
1109                    rex_type: Some(RexType::Selection(Box::new(
1110                        FieldIndex(1).to_field_reference(),
1111                    ))),
1112                })),
1113            }],
1114            options: vec![],
1115            output_type: None,
1116            invocation: 0,
1117            phase: 0,
1118            sorts: vec![],
1119            #[allow(deprecated)]
1120            args: vec![],
1121        };
1122
1123        let aggregate_rel = AggregateRel {
1124            input: Some(Box::new(Rel {
1125                rel_type: Some(RelType::Read(Box::new(ReadRel {
1126                    common: None,
1127                    base_schema: Some(NamedStruct {
1128                        names: vec!["category".into(), "amount".into()],
1129                        r#struct: Some(Struct {
1130                            type_variation_reference: 0,
1131                            types: vec![
1132                                Type {
1133                                    kind: Some(Kind::String(ptype::String {
1134                                        type_variation_reference: 0,
1135                                        nullability: Nullability::Nullable as i32,
1136                                    })),
1137                                },
1138                                Type {
1139                                    kind: Some(Kind::Fp64(ptype::Fp64 {
1140                                        type_variation_reference: 0,
1141                                        nullability: Nullability::Nullable as i32,
1142                                    })),
1143                                },
1144                            ],
1145                            nullability: Nullability::Nullable as i32,
1146                        }),
1147                    }),
1148                    filter: None,
1149                    best_effort_filter: None,
1150                    projection: None,
1151                    advanced_extension: None,
1152                    read_type: Some(ReadType::NamedTable(NamedTable {
1153                        names: vec!["orders".into()],
1154                        advanced_extension: None,
1155                    })),
1156                }))),
1157            })),
1158            grouping_expressions: vec![Expression {
1159                rex_type: Some(RexType::Selection(Box::new(
1160                    FieldIndex(0).to_field_reference(),
1161                ))),
1162            }],
1163            groupings: vec![],
1164            measures: vec![
1165                aggregate_rel::Measure {
1166                    measure: Some(agg_fn1),
1167                    filter: None,
1168                },
1169                aggregate_rel::Measure {
1170                    measure: Some(agg_fn2),
1171                    filter: None,
1172                },
1173            ],
1174            common: Some(RelCommon {
1175                emit_kind: Some(EmitKind::Emit(Emit {
1176                    output_mapping: vec![1, 2], // measures only
1177                })),
1178                ..Default::default()
1179            }),
1180            advanced_extension: None,
1181        };
1182
1183        let relation = Relation::from(&aggregate_rel);
1184        let (result, errors) = ctx.textify(&relation);
1185
1186        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1187        // Expected: Aggregate[$0 => sum($1), count($1)]
1188        assert!(result.contains("Aggregate[$0 => sum($1), count($1)]"));
1189    }
1190
1191    #[test]
1192    fn test_arguments_textify_positional_only() {
1193        let ctx = TestContext::new();
1194        let args = Arguments {
1195            positional: vec![Value::Integer(42), Value::Integer(7)],
1196            named: vec![],
1197        };
1198        let (result, errors) = ctx.textify(&args);
1199        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1200        assert_eq!(result, "42, 7");
1201    }
1202
1203    #[test]
1204    fn test_arguments_textify_named_only() {
1205        let ctx = TestContext::new();
1206        let args = Arguments {
1207            positional: vec![],
1208            named: vec![
1209                NamedArg {
1210                    name: "limit",
1211                    value: Value::Integer(10),
1212                },
1213                NamedArg {
1214                    name: "offset",
1215                    value: Value::Integer(5),
1216                },
1217            ],
1218        };
1219        let (result, errors) = ctx.textify(&args);
1220        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1221        assert_eq!(result, "limit=10, offset=5");
1222    }
1223
1224    #[test]
1225    fn test_join_relation_unknown_type() {
1226        let ctx = TestContext::new();
1227
1228        // Create a join with an unknown/invalid type
1229        let join_rel = JoinRel {
1230            left: Some(Box::new(Rel {
1231                rel_type: Some(RelType::Read(Box::default())),
1232            })),
1233            right: Some(Box::new(Rel {
1234                rel_type: Some(RelType::Read(Box::default())),
1235            })),
1236            expression: Some(Box::new(Expression::default())),
1237            r#type: 999, // Invalid join type
1238            common: None,
1239            post_join_filter: None,
1240            advanced_extension: None,
1241        };
1242
1243        let relation = Relation::from(&join_rel);
1244        let (result, errors) = ctx.textify(&relation);
1245
1246        // Should contain error for unknown join type but still show condition and columns
1247        assert!(!errors.is_empty(), "Expected errors for unknown join type");
1248        assert!(
1249            result.contains("!{JoinRel}"),
1250            "Expected error token for unknown join type"
1251        );
1252        assert!(
1253            result.contains("Join["),
1254            "Expected Join relation to be formatted"
1255        );
1256    }
1257
1258    #[test]
1259    fn test_arguments_textify_both() {
1260        let ctx = TestContext::new();
1261        let args = Arguments {
1262            positional: vec![Value::Integer(1)],
1263            named: vec![NamedArg {
1264                name: "foo",
1265                value: Value::Integer(2),
1266            }],
1267        };
1268        let (result, errors) = ctx.textify(&args);
1269        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1270        assert_eq!(result, "1, foo=2");
1271    }
1272
1273    #[test]
1274    fn test_arguments_textify_empty() {
1275        let ctx = TestContext::new();
1276        let args = Arguments {
1277            positional: vec![],
1278            named: vec![],
1279        };
1280        let (result, errors) = ctx.textify(&args);
1281        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1282        assert_eq!(result, "_");
1283    }
1284
1285    #[test]
1286    fn test_named_arg_textify_error_token() {
1287        let ctx = TestContext::new();
1288        let named_arg = NamedArg {
1289            name: "foo",
1290            value: Value::Missing(PlanError::invalid(
1291                "my_enum",
1292                Some(Cow::Borrowed("my_enum")),
1293                Cow::Borrowed("my_enum"),
1294            )),
1295        };
1296        let (result, errors) = ctx.textify(&named_arg);
1297        // Should show !{my_enum} in the output
1298        assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
1299        // Should also accumulate an error
1300        assert!(!errors.is_empty(), "Expected error for error token");
1301    }
1302
1303    #[test]
1304    fn test_join_type_enum_textify() {
1305        // Test that JoinType enum values convert correctly to their string representation
1306        assert_eq!(join_rel::JoinType::Inner.as_enum_str().unwrap(), "Inner");
1307        assert_eq!(join_rel::JoinType::Left.as_enum_str().unwrap(), "Left");
1308        assert_eq!(
1309            join_rel::JoinType::LeftSemi.as_enum_str().unwrap(),
1310            "LeftSemi"
1311        );
1312        assert_eq!(
1313            join_rel::JoinType::LeftAnti.as_enum_str().unwrap(),
1314            "LeftAnti"
1315        );
1316    }
1317
1318    #[test]
1319    fn test_join_output_columns() {
1320        // Test Inner join - outputs all columns from both sides
1321        let inner_cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
1322        assert_eq!(inner_cols.len(), 5); // 2 + 3 = 5 columns
1323        assert!(matches!(inner_cols[0], Value::Reference(0)));
1324        assert!(matches!(inner_cols[4], Value::Reference(4)));
1325
1326        // Test LeftSemi join - outputs only left columns
1327        let left_semi_cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
1328        assert_eq!(left_semi_cols.len(), 2); // Only left columns
1329        assert!(matches!(left_semi_cols[0], Value::Reference(0)));
1330        assert!(matches!(left_semi_cols[1], Value::Reference(1)));
1331
1332        // Test RightSemi join - outputs right columns as contiguous range starting from $0
1333        let right_semi_cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
1334        assert_eq!(right_semi_cols.len(), 3); // Only right columns
1335        assert!(matches!(right_semi_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1336        assert!(matches!(right_semi_cols[1], Value::Reference(1)));
1337        assert!(matches!(right_semi_cols[2], Value::Reference(2))); // Last right column
1338
1339        // Test LeftMark join - outputs left columns plus a mark column as contiguous range
1340        let left_mark_cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
1341        assert_eq!(left_mark_cols.len(), 3); // 2 left + 1 mark
1342        assert!(matches!(left_mark_cols[0], Value::Reference(0)));
1343        assert!(matches!(left_mark_cols[1], Value::Reference(1)));
1344        assert!(matches!(left_mark_cols[2], Value::Reference(2))); // Mark column at contiguous position
1345
1346        // Test RightMark join - outputs right columns plus a mark column as contiguous range
1347        let right_mark_cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
1348        assert_eq!(right_mark_cols.len(), 4); // 3 right + 1 mark
1349        assert!(matches!(right_mark_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1350        assert!(matches!(right_mark_cols[1], Value::Reference(1)));
1351        assert!(matches!(right_mark_cols[2], Value::Reference(2))); // Last right column
1352        assert!(matches!(right_mark_cols[3], Value::Reference(3))); // Mark column at contiguous position
1353    }
1354}