substrait_explain/textify/
rels.rs

1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::fmt;
4use std::fmt::Debug;
5
6use prost::UnknownEnumValue;
7use substrait::proto::fetch_rel::CountMode;
8use substrait::proto::plan_rel::RelType as PlanRelType;
9use substrait::proto::read_rel::ReadType;
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::EmitKind;
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14    AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct,
15    PlanRel, ProjectRel, ReadRel, Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
16};
17
18use super::expressions::Reference;
19use super::types::Name;
20use super::{PlanError, Scope, Textify};
21
/// Provides a short, fixed display name for a relation node.
///
/// The name is `'static`; the `Rel` implementation maps each relation variant
/// to a constant such as `"Read"` or `"Filter"`.
pub trait NamedRelation {
    /// Returns the display name of this relation.
    fn name(&self) -> &'static str;
}
25
26impl NamedRelation for Rel {
27    fn name(&self) -> &'static str {
28        match self.rel_type.as_ref() {
29            None => "UnknownRel",
30            Some(RelType::Read(_)) => "Read",
31            Some(RelType::Filter(_)) => "Filter",
32            Some(RelType::Project(_)) => "Project",
33            Some(RelType::Fetch(_)) => "Fetch",
34            Some(RelType::Aggregate(_)) => "Aggregate",
35            Some(RelType::Sort(_)) => "Sort",
36            Some(RelType::HashJoin(_)) => "HashJoin",
37            Some(RelType::Exchange(_)) => "Exchange",
38            Some(RelType::Join(_)) => "Join",
39            Some(RelType::Set(_)) => "Set",
40            Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
41            Some(RelType::Cross(_)) => "Cross",
42            Some(RelType::Reference(_)) => "Reference",
43            Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
44            Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
45            Some(RelType::Write(_)) => "Write",
46            Some(RelType::Ddl(_)) => "Ddl",
47            Some(RelType::Update(_)) => "Update",
48            Some(RelType::MergeJoin(_)) => "MergeJoin",
49            Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
50            Some(RelType::Window(_)) => "Window",
51            Some(RelType::Expand(_)) => "Expand",
52        }
53    }
54}
55
/// Trait for enums that can be converted to a string representation for
/// textification.
///
/// Implementations return `Ok` with the variant's textual name for valid enum
/// values, or an `Err` containing a [`PlanError`] for invalid, unknown, or
/// unspecified values. The name is a `Cow<'static, str>`, so it does not
/// borrow from the enum value itself.
pub trait ValueEnum {
    /// Returns the textual name of this enum value, or a [`PlanError`]
    /// describing why it has none.
    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
}
64
/// A `name=value` argument to a relation, such as `limit=10`.
#[derive(Debug, Clone)]
pub struct NamedArg<'a> {
    /// The argument name, written before the `=`.
    pub name: &'a str,
    /// The argument value, written after the `=`.
    pub value: Value<'a>,
}
70
/// A single textifiable item in a relation's argument list or column list.
#[derive(Debug, Clone)]
pub enum Value<'a> {
    /// A single name.
    Name(Name<'a>),
    /// A table name, written as its parts joined by `.`.
    TableName(Vec<Name<'a>>),
    /// A schema field, written as `name:type`; each part is optional and is
    /// handed to `ctx.expect` when written.
    Field(Option<Name<'a>>, Option<&'a Type>),
    /// A tuple, written as `(a, b, ...)`.
    Tuple(Vec<Value<'a>>),
    /// A list, written as `[a, b, ...]`.
    List(Vec<Value<'a>>),
    /// A field reference by index, written via [`Reference`].
    Reference(i32),
    /// An expression.
    Expression(&'a Expression),
    /// An aggregate function call.
    AggregateFunction(&'a AggregateFunction),
    /// Represents a missing, invalid, or unspecified value.
    Missing(PlanError),
    /// Represents a valid enum value as a string for textification; written
    /// with a leading `&`.
    Enum(Cow<'a, str>),
    /// An integer literal, written in decimal.
    Integer(i32),
}
87
88impl<'a> Value<'a> {
89    pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
90        match maybe_value {
91            Some(s) => s,
92            None => Value::Missing(f()),
93        }
94    }
95}
96
97impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
98    fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
99        match token {
100            Ok(value) => Value::TableName(value),
101            Err(err) => Value::Missing(err),
102        }
103    }
104}
105
impl<'a> Textify for Value<'a> {
    fn name() -> &'static str {
        "Value"
    }

    /// Writes the value in its text-format form:
    ///
    /// - names and expressions through `ctx.display`;
    /// - table names joined with `.`;
    /// - fields as `name:type`, each (optional) part passed to `ctx.expect`;
    /// - tuples as `(a, b)` and lists as `[a, b]`;
    /// - references through [`Reference`];
    /// - aggregate functions through their own `Textify` impl;
    /// - missing values through `ctx.failure`;
    /// - enum values prefixed with `&` (e.g. `&Inner`);
    /// - integers in plain decimal.
    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        match self {
            Value::Name(name) => write!(w, "{}", ctx.display(name)),
            Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
            Value::Field(name, typ) => {
                write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
            }
            Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
            Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
            Value::Reference(i) => write!(w, "{}", Reference(*i)),
            Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
            Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
            Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
            Value::Enum(res) => write!(w, "&{res}"),
            Value::Integer(i) => write!(w, "{i}"),
        }
    }
}
129
130fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
131    let mut fields = schema
132        .r#struct
133        .as_ref()
134        .map(|s| s.types.iter())
135        .into_iter()
136        .flatten();
137    let mut names = schema.names.iter();
138
139    // let field_count = schema.r#struct.as_ref().map(|s| s.types.len()).unwrap_or(0);
140    // let name_count = schema.names.len();
141
142    let mut values = Vec::new();
143    loop {
144        let field = fields.next();
145        let name = names.next().map(|n| Name(n));
146        if field.is_none() && name.is_none() {
147            break;
148        }
149
150        values.push(Value::Field(name, field));
151    }
152
153    values
154}
155
/// A relation's direct (pre-emit) output columns paired with its emit kind;
/// textified as the columns the relation actually emits.
struct Emitted<'a> {
    /// The relation's direct output columns, before any emit is applied.
    pub values: &'a [Value<'a>],
    /// The emit kind; `None` means the direct columns are written as-is.
    pub emit: Option<&'a EmitKind>,
}

impl<'a> Emitted<'a> {
    /// Pairs the direct columns with the relation's emit kind.
    pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
        Self { values, emit }
    }

    /// Writes every direct value, comma-separated, ignoring any output mapping.
    pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        write!(w, "{}", ctx.separated(self.values.iter(), ", "))
    }
}
170
171impl<'a> Textify for Emitted<'a> {
172    fn name() -> &'static str {
173        "Emitted"
174    }
175
176    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
177        if ctx.options().show_emit {
178            return self.write_direct(ctx, w);
179        }
180
181        let indices = match &self.emit {
182            Some(EmitKind::Emit(e)) => &e.output_mapping,
183            Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
184            None => return self.write_direct(ctx, w),
185        };
186
187        for (i, &index) in indices.iter().enumerate() {
188            if i > 0 {
189                write!(w, ", ")?;
190            }
191
192            match self.values.get(index as usize) {
193                Some(value) => write!(w, "{}", ctx.display(value))?,
194                None => write!(w, "{}", ctx.failure(PlanError::invalid(
195                    "Emitted",
196                    Some("output_mapping"),
197                    format!(
198                        "Output mapping index {} is out of bounds for values collection of size {}",
199                        index, self.values.len()
200                    )
201                )))?,
202            }
203        }
204
205        Ok(())
206    }
207}
208
/// The arguments section of a relation: positional arguments, then named ones.
#[derive(Debug, Clone)]
pub struct Arguments<'a> {
    /// Positional arguments (e.g., a filter condition, group-bys, etc.)
    pub positional: Vec<Value<'a>>,
    /// Named arguments (e.g., limit=10, offset=5)
    pub named: Vec<NamedArg<'a>>,
}
216
217impl<'a> Textify for Arguments<'a> {
218    fn name() -> &'static str {
219        "Arguments"
220    }
221    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
222        if self.positional.is_empty() && self.named.is_empty() {
223            return write!(w, "_");
224        }
225
226        write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
227        if !self.positional.is_empty() && !self.named.is_empty() {
228            write!(w, ", ")?;
229        }
230        write!(w, "{}", ctx.separated(self.named.iter(), ", "))
231    }
232}
233
/// A relation in the shape used for textification.
///
/// Written as `Name[args => columns]`, or `Name[columns]` when `arguments` is
/// `None`. Each present child follows on its own line, one indent level deeper.
pub struct Relation<'a> {
    /// The display name, e.g. `"Filter"`.
    pub name: &'a str,
    /// Arguments to the relation, if any.
    ///
    /// - `None` means this relation does not take arguments, and the argument
    ///   section is omitted entirely.
    /// - `Some(args)` with both vectors empty means the relation takes
    ///   arguments, but none are provided; this will print as `_ => ...`.
    /// - `Some(args)` with non-empty vectors will print as usual, with
    ///   positional arguments first, then named arguments, separated by commas.
    pub arguments: Option<Arguments<'a>>,
    /// The columns emitted by this relation, pre-emit - the 'direct' column
    /// output.
    pub columns: Vec<Value<'a>>,
    /// The emit kind, if any. If none, use the columns directly.
    pub emit: Option<&'a EmitKind>,
    /// The input relations; `None` entries (missing inputs) are skipped when
    /// writing.
    pub children: Vec<Option<Relation<'a>>>,
}
253
254impl Textify for Relation<'_> {
255    fn name() -> &'static str {
256        "Relation"
257    }
258
259    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
260        let cols = Emitted::new(&self.columns, self.emit);
261        let indent = ctx.indent();
262        let name = self.name;
263        let cols = ctx.display(&cols);
264        match &self.arguments {
265            None => {
266                write!(w, "{indent}{name}[{cols}]")?;
267            }
268            Some(args) => {
269                let args = ctx.display(args);
270                write!(w, "{indent}{name}[{args} => {cols}]")?;
271            }
272        }
273        let child_scope = ctx.push_indent();
274        for child in self.children.iter().flatten() {
275            writeln!(w)?;
276            child.textify(&child_scope, w)?;
277        }
278        Ok(())
279    }
280}
281
282impl<'a> Relation<'a> {
283    pub fn emitted(&self) -> usize {
284        match self.emit {
285            Some(EmitKind::Emit(e)) => e.output_mapping.len(),
286            Some(EmitKind::Direct(_)) => self.columns.len(),
287            None => self.columns.len(),
288        }
289    }
290}
291
/// A dotted table name (e.g. `some_db.test_table`) borrowed from its parts.
#[derive(Debug, Copy, Clone)]
pub struct TableName<'a>(&'a [String]);

impl<'a> Textify for TableName<'a> {
    fn name() -> &'static str {
        "TableName"
    }

    /// Writes each part as a [`Name`], joined with `.`.
    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
        write!(w, "{}", ctx.separated(names.iter(), "."))
    }
}
305
306pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
307    match rel {
308        Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
309        _ => Err(PlanError::unimplemented(
310            "ReadRel",
311            Some("table_name"),
312            format!("Unexpected read type {rel:?}") as String,
313        )),
314    }
315}
316
317impl<'a> From<&'a ReadRel> for Relation<'a> {
318    fn from(rel: &'a ReadRel) -> Self {
319        let name = get_table_name(rel.read_type.as_ref());
320        let table_name: Value = match name {
321            Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
322            Err(e) => Value::Missing(e),
323        };
324
325        let columns = match rel.base_schema {
326            Some(ref schema) => schema_to_values(schema),
327            None => {
328                let err = PlanError::unimplemented(
329                    "ReadRel",
330                    Some("base_schema"),
331                    "Base schema is required",
332                );
333                vec![Value::Missing(err)]
334            }
335        };
336        let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
337
338        Relation {
339            name: "Read",
340            arguments: Some(Arguments {
341                positional: vec![table_name],
342                named: vec![],
343            }),
344            columns,
345            emit,
346            children: vec![],
347        }
348    }
349}
350
351pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
352    rel.as_ref().and_then(|c| c.emit_kind.as_ref())
353}
354
355impl<'a> Relation<'a> {
356    /// Create a vector of values that are references to the emitted outputs of
357    /// this relation. "Emitted" here meaning the outputs of this relation after
358    /// the emit kind has been applied.
359    ///
360    /// This is useful for relations like Filter and Limit whose direct outputs
361    /// are primarily those of its children (direct here meaning before the emit
362    /// has been applied).
363    pub fn input_refs(&self) -> Vec<Value<'a>> {
364        let len = self.emitted();
365        (0..len).map(|i| Value::Reference(i as i32)).collect()
366    }
367
368    /// Convert a vector of relation references into their structured form.
369    ///
370    /// Returns a list of children (with None for ones missing), and a count of input columns.
371    pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
372        let mut children = vec![];
373        let mut inputs = 0;
374
375        for maybe_rel in refs {
376            match maybe_rel {
377                Some(rel) => {
378                    let child = Relation::from(rel);
379                    inputs += child.emitted();
380                    children.push(Some(child));
381                }
382                None => children.push(None),
383            }
384        }
385
386        (children, inputs)
387    }
388}
389
390impl<'a> From<&'a FilterRel> for Relation<'a> {
391    fn from(rel: &'a FilterRel) -> Self {
392        let condition = rel
393            .condition
394            .as_ref()
395            .map(|c| Value::Expression(c.as_ref()));
396        let condition = Value::expect(condition, || {
397            PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
398        });
399        let positional = vec![condition];
400        let arguments = Some(Arguments {
401            positional,
402            named: vec![],
403        });
404        let emit = get_emit(rel.common.as_ref());
405        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
406        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
407
408        Relation {
409            name: "Filter",
410            arguments,
411            columns,
412            emit,
413            children,
414        }
415    }
416}
417
418impl<'a> From<&'a ProjectRel> for Relation<'a> {
419    fn from(rel: &'a ProjectRel) -> Self {
420        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
421        let expressions = rel.expressions.iter().map(Value::Expression);
422        let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
423        columns.extend(expressions);
424
425        Relation {
426            name: "Project",
427            arguments: None,
428            columns,
429            emit: get_emit(rel.common.as_ref()),
430            children,
431        }
432    }
433}
434
435impl<'a> From<&'a Rel> for Relation<'a> {
436    fn from(rel: &'a Rel) -> Self {
437        match rel.rel_type.as_ref() {
438            Some(RelType::Read(r)) => Relation::from(r.as_ref()),
439            Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
440            Some(RelType::Project(r)) => Relation::from(r.as_ref()),
441            Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
442            Some(RelType::Sort(r)) => Relation::from(r.as_ref()),
443            Some(RelType::Fetch(r)) => Relation::from(r.as_ref()),
444            Some(RelType::Join(r)) => Relation::from(r.as_ref()),
445            _ => todo!(),
446        }
447    }
448}
449
450impl<'a> From<&'a AggregateRel> for Relation<'a> {
451    /// Convert an AggregateRel to a Relation for textification.
452    ///
453    /// The conversion follows this logic:
454    /// 1. Arguments: Group-by expressions (as Value::Expression)
455    /// 2. Columns: All possible outputs in order:
456    ///    - First: Group-by field references (Value::Reference)
457    ///    - Then: Aggregate function measures (Value::AggregateFunction)
458    /// 3. Emit: Uses the relation's emit mapping to select which outputs to display
459    /// 4. Children: The input relation
460    fn from(rel: &'a AggregateRel) -> Self {
461        // Arguments: group-by fields (as expressions)
462        let positional = rel
463            .grouping_expressions
464            .iter()
465            .map(Value::Expression)
466            .collect::<Vec<_>>();
467
468        let arguments = Some(Arguments {
469            positional,
470            named: vec![],
471        });
472        // The columns are the direct outputs of this relation (before emit)
473        let mut all_outputs: Vec<Value> = vec![];
474        let input_field_count = rel.grouping_expressions.len();
475        for i in 0..input_field_count {
476            all_outputs.push(Value::Reference(i as i32));
477        }
478
479        // Then, add all measures (aggregate functions)
480        // These are indexed after the group-by fields
481        for m in &rel.measures {
482            if let Some(agg_fn) = m.measure.as_ref() {
483                all_outputs.push(Value::AggregateFunction(agg_fn));
484            }
485        }
486        let emit = get_emit(rel.common.as_ref());
487        Relation {
488            name: "Aggregate",
489            arguments,
490            columns: all_outputs,
491            emit,
492            children: rel
493                .input
494                .as_ref()
495                .map(|c| Some(Relation::from(c.as_ref())))
496                .into_iter()
497                .collect(),
498        }
499    }
500}
501
502impl Textify for RelRoot {
503    fn name() -> &'static str {
504        "RelRoot"
505    }
506
507    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
508        let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
509
510        write!(
511            w,
512            "{}Root[{}]",
513            ctx.indent(),
514            ctx.separated(names.iter(), ", ")
515        )?;
516        let child_scope = ctx.push_indent();
517        for child in self.input.iter() {
518            let child = Relation::from(child);
519            writeln!(w)?;
520            child.textify(&child_scope, w)?;
521        }
522
523        Ok(())
524    }
525}
526
impl Textify for PlanRelType {
    fn name() -> &'static str {
        "PlanRelType"
    }

    /// Writes either a plain relation, converted through [`Relation`] and so
    /// including its children, or a root relation via its own `Textify` impl.
    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        match self {
            PlanRelType::Rel(rel) => Relation::from(rel).textify(ctx, w),
            PlanRelType::Root(root) => root.textify(ctx, w),
        }
    }
}
539
impl Textify for PlanRel {
    fn name() -> &'static str {
        "PlanRel"
    }

    /// Write the relation as a string, including its input relations.
    ///
    /// Delegates to the `rel_type`'s textification through `ctx.expect`,
    /// which also receives the `None` case when `rel_type` is missing.
    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
    }
}
551
552impl<'a> From<&'a SortRel> for Relation<'a> {
553    fn from(rel: &'a SortRel) -> Self {
554        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
555        let positional = rel.sorts.iter().map(Value::from).collect::<Vec<_>>();
556        let arguments = Some(Arguments {
557            positional,
558            named: vec![],
559        });
560        // The columns are the direct outputs of this relation (before emit)
561        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
562        let emit = get_emit(rel.common.as_ref());
563        Relation {
564            name: "Sort",
565            arguments,
566            columns,
567            emit,
568            children,
569        }
570    }
571}
572
573impl<'a> From<&'a FetchRel> for Relation<'a> {
574    fn from(rel: &'a FetchRel) -> Self {
575        let (children, _columns) = Relation::convert_children(vec![rel.input.as_deref()]);
576        let mut named_args = Vec::new();
577        match &rel.count_mode {
578            Some(CountMode::CountExpr(expr)) => {
579                named_args.push(NamedArg {
580                    name: "limit",
581                    value: Value::Expression(expr),
582                });
583            }
584            Some(CountMode::Count(val)) => {
585                named_args.push(NamedArg {
586                    name: "limit",
587                    value: Value::Integer(*val as i32),
588                });
589            }
590            None => {}
591        }
592        if let Some(offset) = &rel.offset_mode {
593            match offset {
594                substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
595                    named_args.push(NamedArg {
596                        name: "offset",
597                        value: Value::Expression(expr),
598                    });
599                }
600                substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
601                    named_args.push(NamedArg {
602                        name: "offset",
603                        value: Value::Integer(*val as i32),
604                    });
605                }
606            }
607        }
608
609        let emit = get_emit(rel.common.as_ref());
610        let columns = match emit {
611            Some(EmitKind::Emit(e)) => e
612                .output_mapping
613                .iter()
614                .map(|&i| Value::Reference(i))
615                .collect(),
616            _ => vec![],
617        };
618        Relation {
619            name: "Fetch",
620            arguments: Some(Arguments {
621                positional: vec![],
622                named: named_args,
623            }),
624            columns,
625            emit,
626            children,
627        }
628    }
629}
630
631fn join_output_columns(
632    join_type: join_rel::JoinType,
633    left_columns: usize,
634    right_columns: usize,
635) -> Vec<Value<'static>> {
636    let total_columns = match join_type {
637        // Inner, Left, Right, Outer joins output columns from both sides
638        join_rel::JoinType::Inner
639        | join_rel::JoinType::Left
640        | join_rel::JoinType::Right
641        | join_rel::JoinType::Outer => left_columns + right_columns,
642
643        // Left semi/anti joins only output columns from the left side
644        join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
645
646        // Right semi/anti joins output columns from the right side
647        join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
648
649        // Single joins behave like semi joins
650        join_rel::JoinType::LeftSingle => left_columns,
651        join_rel::JoinType::RightSingle => right_columns,
652
653        // Mark joins output base columns plus one mark column
654        join_rel::JoinType::LeftMark => left_columns + 1,
655        join_rel::JoinType::RightMark => right_columns + 1,
656
657        // Unspecified - fallback to all columns
658        join_rel::JoinType::Unspecified => left_columns + right_columns,
659    };
660
661    // Output is always a contiguous range starting from $0
662    (0..total_columns)
663        .map(|i| Value::Reference(i as i32))
664        .collect()
665}
666
667impl<'a> From<&'a JoinRel> for Relation<'a> {
668    fn from(rel: &'a JoinRel) -> Self {
669        let (children, _total_columns) =
670            Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()]);
671
672        // convert_children should preserve input vector length
673        assert_eq!(
674            children.len(),
675            2,
676            "convert_children should return same number of elements as input"
677        );
678
679        // Calculate left and right column counts separately
680        let left_columns = match &children[0] {
681            Some(child) => child.emitted(),
682            None => 0,
683        };
684        let right_columns = match &children[1] {
685            Some(child) => child.emitted(),
686            None => 0,
687        };
688
689        // Convert join type from protobuf i32 to enum value
690        // JoinType is stored as i32 in protobuf, convert to typed enum for processing
691        let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
692            Ok(join_type) => {
693                let join_type_value = match join_type.as_enum_str() {
694                    Ok(s) => Value::Enum(s),
695                    Err(e) => Value::Missing(e),
696                };
697                (join_type, join_type_value)
698            }
699            Err(_) => {
700                // Use Unspecified for the join_type but create an error for the join_type_value
701                let join_type_error = Value::Missing(PlanError::invalid(
702                    "JoinRel",
703                    Some("type"),
704                    format!("Unknown join type: {}", rel.r#type),
705                ));
706                (join_rel::JoinType::Unspecified, join_type_error)
707            }
708        };
709
710        // Join condition
711        let condition = rel
712            .expression
713            .as_ref()
714            .map(|c| Value::Expression(c.as_ref()));
715        let condition = Value::expect(condition, || {
716            PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
717        });
718
719        // TODO: Add support for post_join_filter when grammar is extended
720        // Currently post_join_filter is not supported in the text format
721        // grammar
722        let positional = vec![join_type_value, condition];
723        let arguments = Some(Arguments {
724            positional,
725            named: vec![],
726        });
727
728        let emit = get_emit(rel.common.as_ref());
729        let columns = join_output_columns(join_type, left_columns, right_columns);
730
731        Relation {
732            name: "Join",
733            arguments,
734            columns,
735            emit,
736            children,
737        }
738    }
739}
740
741impl<'a> From<&'a SortField> for Value<'a> {
742    fn from(sf: &'a SortField) -> Self {
743        let field = match &sf.expr {
744            Some(expr) => match &expr.rex_type {
745                Some(substrait::proto::expression::RexType::Selection(fref)) => {
746                    if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
747                        if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
748                            Value::Reference(sf.field)
749                        } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
750                    } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
751                }
752                _ => Value::Missing(PlanError::unimplemented(
753                    "SortField",
754                    Some("expr"),
755                    "Not a selection",
756                )),
757            },
758            None => Value::Missing(PlanError::unimplemented(
759                "SortField",
760                Some("expr"),
761                "Missing expr",
762            )),
763        };
764        let direction = match &sf.sort_kind {
765            Some(kind) => Value::from(kind),
766            None => Value::Missing(PlanError::invalid(
767                "SortKind",
768                Some(Cow::Borrowed("sort_kind")),
769                "Missing sort_kind",
770            )),
771        };
772        Value::Tuple(vec![field, direction])
773    }
774}
775
776impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
777    fn from(enum_val: &'a T) -> Self {
778        match enum_val.as_enum_str() {
779            Ok(s) => Value::Enum(s),
780            Err(e) => Value::Missing(e),
781        }
782    }
783}
784
785impl ValueEnum for SortKind {
786    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
787        let d = match self {
788            &SortKind::Direction(d) => SortDirection::try_from(d),
789            SortKind::ComparisonFunctionReference(f) => {
790                return Err(PlanError::invalid(
791                    "SortKind",
792                    Some(Cow::Owned(format!("function reference{f}"))),
793                    "SortKind::ComparisonFunctionReference unimplemented",
794                ));
795            }
796        };
797        let s = match d {
798            Err(UnknownEnumValue(d)) => {
799                return Err(PlanError::invalid(
800                    "SortKind",
801                    Some(Cow::Owned(format!("unknown variant: {d:?}"))),
802                    "Unknown SortDirection",
803                ));
804            }
805            Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
806            Ok(SortDirection::AscNullsLast) => "AscNullsLast",
807            Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
808            Ok(SortDirection::DescNullsLast) => "DescNullsLast",
809            Ok(SortDirection::Clustered) => "Clustered",
810            Ok(SortDirection::Unspecified) => {
811                return Err(PlanError::invalid(
812                    "SortKind",
813                    Option::<Cow<str>>::None,
814                    "Unspecified SortDirection",
815                ));
816            }
817        };
818        Ok(Cow::Borrowed(s))
819    }
820}
821
822impl ValueEnum for join_rel::JoinType {
823    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
824        let s = match self {
825            join_rel::JoinType::Unspecified => {
826                return Err(PlanError::invalid(
827                    "JoinType",
828                    Option::<Cow<str>>::None,
829                    "Unspecified JoinType",
830                ));
831            }
832            join_rel::JoinType::Inner => "Inner",
833            join_rel::JoinType::Outer => "Outer",
834            join_rel::JoinType::Left => "Left",
835            join_rel::JoinType::Right => "Right",
836            join_rel::JoinType::LeftSemi => "LeftSemi",
837            join_rel::JoinType::RightSemi => "RightSemi",
838            join_rel::JoinType::LeftAnti => "LeftAnti",
839            join_rel::JoinType::RightAnti => "RightAnti",
840            join_rel::JoinType::LeftSingle => "LeftSingle",
841            join_rel::JoinType::RightSingle => "RightSingle",
842            join_rel::JoinType::LeftMark => "LeftMark",
843            join_rel::JoinType::RightMark => "RightMark",
844        };
845        Ok(Cow::Borrowed(s))
846    }
847}
848
849impl<'a> Textify for NamedArg<'a> {
850    fn name() -> &'static str {
851        "NamedArg"
852    }
853    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
854        write!(w, "{}=", self.name)?;
855        self.value.textify(ctx, w)
856    }
857}
858
859#[cfg(test)]
860mod tests {
861    use substrait::proto::expression::literal::LiteralType;
862    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
863    use substrait::proto::function_argument::ArgType;
864    use substrait::proto::read_rel::{NamedTable, ReadType};
865    use substrait::proto::rel_common::Emit;
866    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
867    use substrait::proto::{
868        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
869    };
870
871    use super::*;
872    use crate::fixtures::TestContext;
873    use crate::parser::expressions::FieldIndex;
874
875    #[test]
876    fn test_read_rel() {
877        let ctx = TestContext::new();
878
879        // Create a simple ReadRel with a NamedStruct schema
880        let read_rel = ReadRel {
881            common: None,
882            base_schema: Some(NamedStruct {
883                names: vec!["col1".into(), "column 2".into()],
884                r#struct: Some(Struct {
885                    type_variation_reference: 0,
886                    types: vec![
887                        Type {
888                            kind: Some(Kind::I32(ptype::I32 {
889                                type_variation_reference: 0,
890                                nullability: Nullability::Nullable as i32,
891                            })),
892                        },
893                        Type {
894                            kind: Some(Kind::String(ptype::String {
895                                type_variation_reference: 0,
896                                nullability: Nullability::Nullable as i32,
897                            })),
898                        },
899                    ],
900                    nullability: Nullability::Nullable as i32,
901                }),
902            }),
903            filter: None,
904            best_effort_filter: None,
905            projection: None,
906            advanced_extension: None,
907            read_type: Some(ReadType::NamedTable(NamedTable {
908                names: vec!["some_db".into(), "test_table".into()],
909                advanced_extension: None,
910            })),
911        };
912
913        let rel = Relation::from(&read_rel);
914
915        let (result, errors) = ctx.textify(&rel);
916        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
917        assert_eq!(
918            result,
919            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
920        );
921    }
922
923    #[test]
924    fn test_filter_rel() {
925        let ctx = TestContext::new()
926            .with_uri(1, "test_uri")
927            .with_function(1, 10, "gt");
928
929        // Create a simple FilterRel with a ReadRel input and a filter expression
930        let read_rel = ReadRel {
931            common: None,
932            base_schema: Some(NamedStruct {
933                names: vec!["col1".into(), "col2".into()],
934                r#struct: Some(Struct {
935                    type_variation_reference: 0,
936                    types: vec![
937                        Type {
938                            kind: Some(Kind::I32(ptype::I32 {
939                                type_variation_reference: 0,
940                                nullability: Nullability::Nullable as i32,
941                            })),
942                        },
943                        Type {
944                            kind: Some(Kind::I32(ptype::I32 {
945                                type_variation_reference: 0,
946                                nullability: Nullability::Nullable as i32,
947                            })),
948                        },
949                    ],
950                    nullability: Nullability::Nullable as i32,
951                }),
952            }),
953            filter: None,
954            best_effort_filter: None,
955            projection: None,
956            advanced_extension: None,
957            read_type: Some(ReadType::NamedTable(NamedTable {
958                names: vec!["test_table".into()],
959                advanced_extension: None,
960            })),
961        };
962
963        // Create a filter expression: col1 > 10
964        let filter_expr = Expression {
965            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
966                function_reference: 10, // gt function
967                arguments: vec![
968                    FunctionArgument {
969                        arg_type: Some(ArgType::Value(Reference(0).into())),
970                    },
971                    FunctionArgument {
972                        arg_type: Some(ArgType::Value(Expression {
973                            rex_type: Some(RexType::Literal(Literal {
974                                literal_type: Some(LiteralType::I32(10)),
975                                nullable: false,
976                                type_variation_reference: 0,
977                            })),
978                        })),
979                    },
980                ],
981                options: vec![],
982                output_type: None,
983                #[allow(deprecated)]
984                args: vec![],
985            })),
986        };
987
988        let filter_rel = FilterRel {
989            common: None,
990            input: Some(Box::new(Rel {
991                rel_type: Some(RelType::Read(Box::new(read_rel))),
992            })),
993            condition: Some(Box::new(filter_expr)),
994            advanced_extension: None,
995        };
996
997        let rel = Rel {
998            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
999        };
1000
1001        let rel = Relation::from(&rel);
1002
1003        let (result, errors) = ctx.textify(&rel);
1004        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1005        let expected = r#"
1006Filter[gt($0, 10:i32) => $0, $1]
1007  Read[test_table => col1:i32?, col2:i32?]"#
1008            .trim_start();
1009        assert_eq!(result, expected);
1010    }
1011
1012    #[test]
1013    fn test_aggregate_function_textify() {
1014        let ctx = TestContext::new()
1015        .with_uri(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1016        .with_function(1, 10, "sum")
1017        .with_function(1, 11, "count");
1018
1019        // Create a simple AggregateFunction
1020        let agg_fn = AggregateFunction {
1021            function_reference: 10, // sum
1022            arguments: vec![FunctionArgument {
1023                arg_type: Some(ArgType::Value(Expression {
1024                    rex_type: Some(RexType::Selection(Box::new(
1025                        FieldIndex(1).to_field_reference(),
1026                    ))),
1027                })),
1028            }],
1029            options: vec![],
1030            output_type: None,
1031            invocation: 0,
1032            phase: 0,
1033            sorts: vec![],
1034            #[allow(deprecated)]
1035            args: vec![],
1036        };
1037
1038        let value = Value::AggregateFunction(&agg_fn);
1039        let (result, errors) = ctx.textify(&value);
1040
1041        println!("Textification result: {result}");
1042        if !errors.is_empty() {
1043            println!("Errors: {errors:?}");
1044        }
1045
1046        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1047        assert_eq!(result, "sum($1)");
1048    }
1049
1050    #[test]
1051    fn test_aggregate_relation_textify() {
1052        let ctx = TestContext::new()
1053        .with_uri(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1054        .with_function(1, 10, "sum")
1055        .with_function(1, 11, "count");
1056
1057        // Create a simple AggregateRel
1058        let agg_fn1 = AggregateFunction {
1059            function_reference: 10, // sum
1060            arguments: vec![FunctionArgument {
1061                arg_type: Some(ArgType::Value(Expression {
1062                    rex_type: Some(RexType::Selection(Box::new(
1063                        FieldIndex(1).to_field_reference(),
1064                    ))),
1065                })),
1066            }],
1067            options: vec![],
1068            output_type: None,
1069            invocation: 0,
1070            phase: 0,
1071            sorts: vec![],
1072            #[allow(deprecated)]
1073            args: vec![],
1074        };
1075
1076        let agg_fn2 = AggregateFunction {
1077            function_reference: 11, // count
1078            arguments: vec![FunctionArgument {
1079                arg_type: Some(ArgType::Value(Expression {
1080                    rex_type: Some(RexType::Selection(Box::new(
1081                        FieldIndex(1).to_field_reference(),
1082                    ))),
1083                })),
1084            }],
1085            options: vec![],
1086            output_type: None,
1087            invocation: 0,
1088            phase: 0,
1089            sorts: vec![],
1090            #[allow(deprecated)]
1091            args: vec![],
1092        };
1093
1094        let aggregate_rel = AggregateRel {
1095            input: Some(Box::new(Rel {
1096                rel_type: Some(RelType::Read(Box::new(ReadRel {
1097                    common: None,
1098                    base_schema: Some(NamedStruct {
1099                        names: vec!["category".into(), "amount".into()],
1100                        r#struct: Some(Struct {
1101                            type_variation_reference: 0,
1102                            types: vec![
1103                                Type {
1104                                    kind: Some(Kind::String(ptype::String {
1105                                        type_variation_reference: 0,
1106                                        nullability: Nullability::Nullable as i32,
1107                                    })),
1108                                },
1109                                Type {
1110                                    kind: Some(Kind::Fp64(ptype::Fp64 {
1111                                        type_variation_reference: 0,
1112                                        nullability: Nullability::Nullable as i32,
1113                                    })),
1114                                },
1115                            ],
1116                            nullability: Nullability::Nullable as i32,
1117                        }),
1118                    }),
1119                    filter: None,
1120                    best_effort_filter: None,
1121                    projection: None,
1122                    advanced_extension: None,
1123                    read_type: Some(ReadType::NamedTable(NamedTable {
1124                        names: vec!["orders".into()],
1125                        advanced_extension: None,
1126                    })),
1127                }))),
1128            })),
1129            grouping_expressions: vec![Expression {
1130                rex_type: Some(RexType::Selection(Box::new(
1131                    FieldIndex(0).to_field_reference(),
1132                ))),
1133            }],
1134            groupings: vec![],
1135            measures: vec![
1136                aggregate_rel::Measure {
1137                    measure: Some(agg_fn1),
1138                    filter: None,
1139                },
1140                aggregate_rel::Measure {
1141                    measure: Some(agg_fn2),
1142                    filter: None,
1143                },
1144            ],
1145            common: Some(RelCommon {
1146                emit_kind: Some(EmitKind::Emit(Emit {
1147                    output_mapping: vec![1, 2], // measures only
1148                })),
1149                ..Default::default()
1150            }),
1151            advanced_extension: None,
1152        };
1153
1154        let relation = Relation::from(&aggregate_rel);
1155        let (result, errors) = ctx.textify(&relation);
1156
1157        println!("Aggregate relation textification result:");
1158        println!("{result}");
1159        if !errors.is_empty() {
1160            println!("Errors: {errors:?}");
1161        }
1162
1163        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1164        // Expected: Aggregate[$0 => sum($1), count($1)]
1165        assert!(result.contains("Aggregate[$0 => sum($1), count($1)]"));
1166    }
1167
1168    #[test]
1169    fn test_arguments_textify_positional_only() {
1170        let ctx = TestContext::new();
1171        let args = Arguments {
1172            positional: vec![Value::Integer(42), Value::Integer(7)],
1173            named: vec![],
1174        };
1175        let (result, errors) = ctx.textify(&args);
1176        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1177        assert_eq!(result, "42, 7");
1178    }
1179
1180    #[test]
1181    fn test_arguments_textify_named_only() {
1182        let ctx = TestContext::new();
1183        let args = Arguments {
1184            positional: vec![],
1185            named: vec![
1186                NamedArg {
1187                    name: "limit",
1188                    value: Value::Integer(10),
1189                },
1190                NamedArg {
1191                    name: "offset",
1192                    value: Value::Integer(5),
1193                },
1194            ],
1195        };
1196        let (result, errors) = ctx.textify(&args);
1197        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1198        assert_eq!(result, "limit=10, offset=5");
1199    }
1200
1201    #[test]
1202    fn test_join_relation_unknown_type() {
1203        let ctx = TestContext::new();
1204
1205        // Create a join with an unknown/invalid type
1206        let join_rel = JoinRel {
1207            left: Some(Box::new(Rel {
1208                rel_type: Some(RelType::Read(Box::default())),
1209            })),
1210            right: Some(Box::new(Rel {
1211                rel_type: Some(RelType::Read(Box::default())),
1212            })),
1213            expression: Some(Box::new(Expression::default())),
1214            r#type: 999, // Invalid join type
1215            common: None,
1216            post_join_filter: None,
1217            advanced_extension: None,
1218        };
1219
1220        let relation = Relation::from(&join_rel);
1221        let (result, errors) = ctx.textify(&relation);
1222
1223        // Should contain error for unknown join type but still show condition and columns
1224        assert!(!errors.is_empty(), "Expected errors for unknown join type");
1225        assert!(
1226            result.contains("!{JoinRel}"),
1227            "Expected error token for unknown join type"
1228        );
1229        assert!(
1230            result.contains("Join["),
1231            "Expected Join relation to be formatted"
1232        );
1233        println!("Unknown join type result: {result}");
1234    }
1235
1236    #[test]
1237    fn test_arguments_textify_both() {
1238        let ctx = TestContext::new();
1239        let args = Arguments {
1240            positional: vec![Value::Integer(1)],
1241            named: vec![NamedArg {
1242                name: "foo",
1243                value: Value::Integer(2),
1244            }],
1245        };
1246        let (result, errors) = ctx.textify(&args);
1247        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1248        assert_eq!(result, "1, foo=2");
1249    }
1250
1251    #[test]
1252    fn test_arguments_textify_empty() {
1253        let ctx = TestContext::new();
1254        let args = Arguments {
1255            positional: vec![],
1256            named: vec![],
1257        };
1258        let (result, errors) = ctx.textify(&args);
1259        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1260        assert_eq!(result, "_");
1261    }
1262
1263    #[test]
1264    fn test_named_arg_textify_error_token() {
1265        let ctx = TestContext::new();
1266        let named_arg = NamedArg {
1267            name: "foo",
1268            value: Value::Missing(PlanError::invalid(
1269                "my_enum",
1270                Some(Cow::Borrowed("my_enum")),
1271                Cow::Borrowed("my_enum"),
1272            )),
1273        };
1274        let (result, errors) = ctx.textify(&named_arg);
1275        // Should show !{my_enum} in the output
1276        assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
1277        // Should also accumulate an error
1278        assert!(!errors.is_empty(), "Expected error for error token");
1279    }
1280
1281    #[test]
1282    fn test_join_type_enum_textify() {
1283        // Test that JoinType enum values convert correctly to their string representation
1284        assert_eq!(join_rel::JoinType::Inner.as_enum_str().unwrap(), "Inner");
1285        assert_eq!(join_rel::JoinType::Left.as_enum_str().unwrap(), "Left");
1286        assert_eq!(
1287            join_rel::JoinType::LeftSemi.as_enum_str().unwrap(),
1288            "LeftSemi"
1289        );
1290        assert_eq!(
1291            join_rel::JoinType::LeftAnti.as_enum_str().unwrap(),
1292            "LeftAnti"
1293        );
1294    }
1295
1296    #[test]
1297    fn test_join_output_columns() {
1298        // Test Inner join - outputs all columns from both sides
1299        let inner_cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
1300        assert_eq!(inner_cols.len(), 5); // 2 + 3 = 5 columns
1301        assert!(matches!(inner_cols[0], Value::Reference(0)));
1302        assert!(matches!(inner_cols[4], Value::Reference(4)));
1303
1304        // Test LeftSemi join - outputs only left columns
1305        let left_semi_cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
1306        assert_eq!(left_semi_cols.len(), 2); // Only left columns
1307        assert!(matches!(left_semi_cols[0], Value::Reference(0)));
1308        assert!(matches!(left_semi_cols[1], Value::Reference(1)));
1309
1310        // Test RightSemi join - outputs right columns as contiguous range starting from $0
1311        let right_semi_cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
1312        assert_eq!(right_semi_cols.len(), 3); // Only right columns
1313        assert!(matches!(right_semi_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1314        assert!(matches!(right_semi_cols[1], Value::Reference(1)));
1315        assert!(matches!(right_semi_cols[2], Value::Reference(2))); // Last right column
1316
1317        // Test LeftMark join - outputs left columns plus a mark column as contiguous range
1318        let left_mark_cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
1319        assert_eq!(left_mark_cols.len(), 3); // 2 left + 1 mark
1320        assert!(matches!(left_mark_cols[0], Value::Reference(0)));
1321        assert!(matches!(left_mark_cols[1], Value::Reference(1)));
1322        assert!(matches!(left_mark_cols[2], Value::Reference(2))); // Mark column at contiguous position
1323
1324        // Test RightMark join - outputs right columns plus a mark column as contiguous range
1325        let right_mark_cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
1326        assert_eq!(right_mark_cols.len(), 4); // 3 right + 1 mark
1327        assert!(matches!(right_mark_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1328        assert!(matches!(right_mark_cols[1], Value::Reference(1)));
1329        assert!(matches!(right_mark_cols[2], Value::Reference(2))); // Last right column
1330        assert!(matches!(right_mark_cols[3], Value::Reference(3))); // Mark column at contiguous position
1331    }
1332}