substrait_explain/textify/
rels.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::convert::TryFrom;
4use std::fmt;
5use std::fmt::Debug;
6
7use prost::{Message, UnknownEnumValue};
8use substrait::proto::extensions::AdvancedExtension;
9use substrait::proto::fetch_rel::CountMode;
10use substrait::proto::plan_rel::RelType as PlanRelType;
11use substrait::proto::read_rel::ReadType;
12use substrait::proto::rel::RelType;
13use substrait::proto::rel_common::EmitKind;
14use substrait::proto::sort_field::{SortDirection, SortKind};
15use substrait::proto::{
16    AggregateFunction, AggregateRel, Expression, ExtensionLeafRel, ExtensionMultiRel,
17    ExtensionSingleRel, FetchRel, FilterRel, JoinRel, NamedStruct, PlanRel, ProjectRel, ReadRel,
18    Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
19};
20
21use super::expressions::Reference;
22use super::types::Name;
23use super::{PlanError, Scope, Textify};
24use crate::FormatError;
25use crate::extensions::any::AnyRef;
26use crate::extensions::{ExtensionColumn, ExtensionValue};
27
28pub trait NamedRelation {
29    fn name(&self) -> &'static str;
30}
31
32impl NamedRelation for Rel {
33    fn name(&self) -> &'static str {
34        match self.rel_type.as_ref() {
35            None => "UnknownRel",
36            Some(RelType::Read(_)) => "Read",
37            Some(RelType::Filter(_)) => "Filter",
38            Some(RelType::Project(_)) => "Project",
39            Some(RelType::Fetch(_)) => "Fetch",
40            Some(RelType::Aggregate(_)) => "Aggregate",
41            Some(RelType::Sort(_)) => "Sort",
42            Some(RelType::HashJoin(_)) => "HashJoin",
43            Some(RelType::Exchange(_)) => "Exchange",
44            Some(RelType::Join(_)) => "Join",
45            Some(RelType::Set(_)) => "Set",
46            Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
47            Some(RelType::Cross(_)) => "Cross",
48            Some(RelType::Reference(_)) => "Reference",
49            Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
50            Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
51            Some(RelType::Write(_)) => "Write",
52            Some(RelType::Ddl(_)) => "Ddl",
53            Some(RelType::Update(_)) => "Update",
54            Some(RelType::MergeJoin(_)) => "MergeJoin",
55            Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
56            Some(RelType::Window(_)) => "Window",
57            Some(RelType::Expand(_)) => "Expand",
58        }
59    }
60}
61
62impl Textify for Rel {
63    fn name() -> &'static str {
64        "Rel"
65    }
66
67    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
68        // delegates to `Relation` which carries `advanced_extension`, so the full
69        // header → enhancement → children sequence is handled uniformly there.
70        Relation::from_rel(self, ctx).textify(ctx, w)
71    }
72}
73
74/// Trait for enums that can be converted to a string representation for
75/// textification.
76///
77/// Returns Ok(str) for valid enum values, or Err([PlanError]) for invalid or
78/// unknown values.
79pub trait ValueEnum {
80    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
81}
82
83#[derive(Debug, Clone)]
84pub struct NamedArg<'a> {
85    pub name: Cow<'a, str>,
86    pub value: Value<'a>,
87}
88
89#[derive(Debug, Clone)]
90pub enum Value<'a> {
91    Name(Name<'a>),
92    TableName(Vec<Name<'a>>),
93    Field(Option<Name<'a>>, Option<&'a Type>),
94    Tuple(Vec<Value<'a>>),
95    List(Vec<Value<'a>>),
96    Reference(i32),
97    Expression(&'a Expression),
98    AggregateFunction(&'a AggregateFunction),
99    /// Represents a missing, invalid, or unspecified value.
100    Missing(PlanError),
101    /// Represents a valid enum value as a string for textification.
102    Enum(Cow<'a, str>),
103    EmptyGroup,
104    Integer(i64),
105    Float(f64),
106    Boolean(bool),
107    /// A decoded extension argument value.
108    ExtValue(ExtensionValue),
109    /// A decoded extension output column.
110    ExtColumn(ExtensionColumn),
111}
112
113impl<'a> Value<'a> {
114    pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
115        match maybe_value {
116            Some(s) => s,
117            None => Value::Missing(f()),
118        }
119    }
120}
121
122impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
123    fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
124        match token {
125            Ok(value) => Value::TableName(value),
126            Err(err) => Value::Missing(err),
127        }
128    }
129}
130
131impl<'a> Textify for Value<'a> {
132    fn name() -> &'static str {
133        "Value"
134    }
135
136    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
137        match self {
138            Value::Name(name) => write!(w, "{}", ctx.display(name)),
139            Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
140            Value::Field(name, typ) => {
141                write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
142            }
143            Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
144            Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
145            Value::Reference(i) => write!(w, "{}", Reference(*i)),
146            Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
147            Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
148            Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
149            Value::Enum(res) => write!(w, "&{res}"),
150            Value::Integer(i) => write!(w, "{i}"),
151            Value::EmptyGroup => write!(w, "_"),
152            Value::Float(f) => write!(w, "{f}"),
153            Value::Boolean(b) => write!(w, "{b}"),
154            Value::ExtValue(ev) => ev.textify(ctx, w),
155            Value::ExtColumn(ec) => ec.textify(ctx, w),
156        }
157    }
158}
159
160fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
161    let mut fields = schema
162        .r#struct
163        .as_ref()
164        .map(|s| s.types.iter())
165        .into_iter()
166        .flatten();
167    let mut names = schema.names.iter();
168
169    // let field_count = schema.r#struct.as_ref().map(|s| s.types.len()).unwrap_or(0);
170    // let name_count = schema.names.len();
171
172    let mut values = Vec::new();
173    loop {
174        let field = fields.next();
175        let name = names.next().map(|n| Name(n));
176        if field.is_none() && name.is_none() {
177            break;
178        }
179
180        values.push(Value::Field(name, field));
181    }
182
183    values
184}
185
186struct Emitted<'a> {
187    pub values: &'a [Value<'a>],
188    pub emit: Option<&'a EmitKind>,
189}
190
191impl<'a> Emitted<'a> {
192    pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
193        Self { values, emit }
194    }
195
196    pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
197        write!(w, "{}", ctx.separated(self.values.iter(), ", "))
198    }
199}
200
201impl<'a> Textify for Emitted<'a> {
202    fn name() -> &'static str {
203        "Emitted"
204    }
205
206    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
207        if ctx.options().show_emit {
208            return self.write_direct(ctx, w);
209        }
210
211        let indices = match &self.emit {
212            Some(EmitKind::Emit(e)) => &e.output_mapping,
213            Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
214            None => return self.write_direct(ctx, w),
215        };
216
217        for (i, &index) in indices.iter().enumerate() {
218            if i > 0 {
219                write!(w, ", ")?;
220            }
221
222            match self.values.get(index as usize) {
223                Some(value) => write!(w, "{}", ctx.display(value))?,
224                None => write!(w, "{}", ctx.failure(PlanError::invalid(
225                    "Emitted",
226                    Some("output_mapping"),
227                    format!(
228                        "Output mapping index {} is out of bounds for values collection of size {}",
229                        index, self.values.len()
230                    )
231                )))?,
232            }
233        }
234
235        Ok(())
236    }
237}
238
239#[derive(Debug, Clone)]
240pub struct Arguments<'a> {
241    /// Positional arguments (e.g., a filter condition, group-bys, etc.)
242    pub positional: Vec<Value<'a>>,
243    /// Named arguments (e.g., limit=10, offset=5)
244    pub named: Vec<NamedArg<'a>>,
245}
246
247impl<'a> Textify for Arguments<'a> {
248    fn name() -> &'static str {
249        "Arguments"
250    }
251    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
252        if self.positional.is_empty() && self.named.is_empty() {
253            return write!(w, "_");
254        }
255
256        write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
257        if !self.positional.is_empty() && !self.named.is_empty() {
258            write!(w, ", ")?;
259        }
260        write!(w, "{}", ctx.separated(self.named.iter(), ", "))
261    }
262}
263
264pub struct Relation<'a> {
265    pub name: Cow<'a, str>,
266    /// Arguments to the relation, if any.
267    ///
268    /// - `None` means this relation does not take arguments, and the argument
269    ///   section is omitted entirely.
270    /// - `Some(args)` with both vectors empty means the relation takes
271    ///   arguments, but none are provided; this will print as `_ => ...`.
272    /// - `Some(args)` with non-empty vectors will print as usual, with
273    ///   positional arguments first, then named arguments, separated by commas.
274    pub arguments: Option<Arguments<'a>>,
275    /// The columns emitted by this relation, pre-emit - the 'direct' column
276    /// output.
277    pub columns: Vec<Value<'a>>,
278    /// The emit kind, if any. If none, use the columns directly.
279    pub emit: Option<&'a EmitKind>,
280    /// The advanced extension (enhancement and/or optimizations) attached to
281    /// this relation, if any.  Mirrors the `advanced_extension` field carried
282    /// by standard relation types in the protobuf (Read, Filter, Project, etc...)
283    ///  Extension relations (`ExtensionLeaf`, `ExtensionSingle`, `ExtensionMulti`)
284    /// do not carry this field and always set it to `None`.
285    pub advanced_extension: Option<&'a AdvancedExtension>,
286    /// The input relations.
287    pub children: Vec<Option<Relation<'a>>>,
288}
289
290impl Textify for Relation<'_> {
291    fn name() -> &'static str {
292        "Relation"
293    }
294
295    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
296        self.write_header(ctx, w)?;
297        // Emit any enhancement / optimizations between the header line and the
298        // child relations, indented one level deeper than this relation — the
299        // same position they occupy in the text format when parsed.
300        if let Some(adv_ext) = self.advanced_extension {
301            let child_scope = ctx.push_indent();
302            adv_ext.textify(&child_scope, w)?;
303        }
304        self.write_children(ctx, w)?;
305        Ok(())
306    }
307}
308
309impl Relation<'_> {
310    /// Write the single header line for this relation, e.g. `Filter[$0 => $0]`.
311    /// Does not write a trailing newline; callers are responsible for any
312    /// newline that follows (either from adv_ext or from the next child).
313    pub fn write_header<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
314        let cols = Emitted::new(&self.columns, self.emit);
315        let indent = ctx.indent();
316        let name = &self.name;
317        let cols = ctx.display(&cols);
318        match &self.arguments {
319            None => {
320                write!(w, "{indent}{name}[{cols}]")
321            }
322            Some(args) => {
323                let args = ctx.display(args);
324                if self.columns.is_empty() && self.emit.is_none() {
325                    write!(w, "{indent}{name}[{args}]")
326                } else {
327                    write!(w, "{indent}{name}[{args} => {cols}]")
328                }
329            }
330        }
331    }
332
333    /// Write each child relation at one indent level deeper than `ctx`.
334    /// Each child is preceded by a newline.
335    pub fn write_children<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
336        let child_scope = ctx.push_indent();
337        for child in self.children.iter().flatten() {
338            writeln!(w)?;
339            child.textify(&child_scope, w)?;
340        }
341        Ok(())
342    }
343}
344
345impl<'a> Relation<'a> {
346    pub fn emitted(&self) -> usize {
347        match self.emit {
348            Some(EmitKind::Emit(e)) => e.output_mapping.len(),
349            Some(EmitKind::Direct(_)) => self.columns.len(),
350            None => self.columns.len(),
351        }
352    }
353}
354
355#[derive(Debug, Copy, Clone)]
356pub struct TableName<'a>(&'a [String]);
357
358impl<'a> Textify for TableName<'a> {
359    fn name() -> &'static str {
360        "TableName"
361    }
362
363    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
364        let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
365        write!(w, "{}", ctx.separated(names.iter(), "."))
366    }
367}
368
369pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
370    match rel {
371        Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
372        _ => Err(PlanError::unimplemented(
373            "ReadRel",
374            Some("table_name"),
375            format!("Unexpected read type {rel:?}") as String,
376        )),
377    }
378}
379
380impl<'a> Relation<'a> {
381    fn from_read<S: Scope>(rel: &'a ReadRel, _ctx: &S) -> Self {
382        let name = get_table_name(rel.read_type.as_ref());
383        let table_name: Value = match name {
384            Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
385            Err(e) => Value::Missing(e),
386        };
387
388        let columns = match rel.base_schema {
389            Some(ref schema) => schema_to_values(schema),
390            None => {
391                let err = PlanError::unimplemented(
392                    "ReadRel",
393                    Some("base_schema"),
394                    "Base schema is required",
395                );
396                vec![Value::Missing(err)]
397            }
398        };
399        let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
400
401        Relation {
402            name: Cow::Borrowed("Read"),
403            arguments: Some(Arguments {
404                positional: vec![table_name],
405                named: vec![],
406            }),
407            columns,
408            emit,
409            advanced_extension: rel.advanced_extension.as_ref(),
410            children: vec![],
411        }
412    }
413}
414
415pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
416    rel.as_ref().and_then(|c| c.emit_kind.as_ref())
417}
418
419impl<'a> Relation<'a> {
420    /// Create a vector of values that are references to the emitted outputs of
421    /// this relation. "Emitted" here meaning the outputs of this relation after
422    /// the emit kind has been applied.
423    ///
424    /// This is useful for relations like Filter and Limit whose direct outputs
425    /// are primarily those of its children (direct here meaning before the emit
426    /// has been applied).
427    pub fn input_refs(&self) -> Vec<Value<'a>> {
428        let len = self.emitted();
429        (0..len).map(|i| Value::Reference(i as i32)).collect()
430    }
431
432    /// Convert a vector of relation references into their structured form.
433    ///
434    /// Returns a list of children (with None for ones missing), and a count of input columns.
435    pub fn convert_children<S: Scope>(
436        refs: Vec<Option<&'a Rel>>,
437        ctx: &S,
438    ) -> (Vec<Option<Relation<'a>>>, usize) {
439        let mut children = vec![];
440        let mut inputs = 0;
441
442        for maybe_rel in refs {
443            match maybe_rel {
444                Some(rel) => {
445                    let child = Relation::from_rel(rel, ctx);
446                    inputs += child.emitted();
447                    children.push(Some(child));
448                }
449                None => children.push(None),
450            }
451        }
452
453        (children, inputs)
454    }
455}
456
457impl<'a> Relation<'a> {
458    fn from_filter<S: Scope>(rel: &'a FilterRel, ctx: &S) -> Self {
459        let condition = rel
460            .condition
461            .as_ref()
462            .map(|c| Value::Expression(c.as_ref()));
463        let condition = Value::expect(condition, || {
464            PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
465        });
466        let positional = vec![condition];
467        let arguments = Some(Arguments {
468            positional,
469            named: vec![],
470        });
471        let emit = get_emit(rel.common.as_ref());
472        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
473        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
474
475        Relation {
476            name: Cow::Borrowed("Filter"),
477            arguments,
478            columns,
479            emit,
480            advanced_extension: rel.advanced_extension.as_ref(),
481            children,
482        }
483    }
484
485    fn from_project<S: Scope>(rel: &'a ProjectRel, ctx: &S) -> Self {
486        let (children, input_columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
487        let mut columns: Vec<Value> = vec![];
488        for i in 0..input_columns {
489            columns.push(Value::Reference(i as i32));
490        }
491        for expr in &rel.expressions {
492            columns.push(Value::Expression(expr));
493        }
494
495        Relation {
496            name: Cow::Borrowed("Project"),
497            arguments: None,
498            columns,
499            emit: get_emit(rel.common.as_ref()),
500            advanced_extension: rel.advanced_extension.as_ref(),
501            children,
502        }
503    }
504
505    pub fn from_rel<S: Scope>(rel: &'a Rel, ctx: &S) -> Self {
506        match rel.rel_type.as_ref() {
507            Some(RelType::Read(r)) => Relation::from_read(r, ctx),
508            Some(RelType::Filter(r)) => Relation::from_filter(r, ctx),
509            Some(RelType::Project(r)) => Relation::from_project(r, ctx),
510            Some(RelType::Aggregate(r)) => Relation::from_aggregate(r, ctx),
511            Some(RelType::Sort(r)) => Relation::from_sort(r, ctx),
512            Some(RelType::Fetch(r)) => Relation::from_fetch(r, ctx),
513            Some(RelType::Join(r)) => Relation::from_join(r, ctx),
514            Some(RelType::ExtensionLeaf(r)) => Relation::from_extension_leaf(r, ctx),
515            Some(RelType::ExtensionSingle(r)) => Relation::from_extension_single(r, ctx),
516            Some(RelType::ExtensionMulti(r)) => Relation::from_extension_multi(r, ctx),
517            _ => {
518                let name = rel.name();
519                let token = ctx.failure(FormatError::Format(PlanError::unimplemented(
520                    "Rel",
521                    Some(name),
522                    format!("{name} is not yet supported in the text format"),
523                )));
524                Relation {
525                    name: Cow::Owned(format!("{token}")),
526                    arguments: None,
527                    columns: vec![],
528                    emit: None,
529                    advanced_extension: None,
530                    children: vec![],
531                }
532            }
533        }
534    }
535
536    fn from_extension_leaf<S: Scope>(rel: &'a ExtensionLeafRel, ctx: &S) -> Self {
537        let detail_ref = rel.detail.as_ref().map(AnyRef::from);
538        let decoded = match detail_ref {
539            Some(d) => ctx.extension_registry().decode(d),
540            None => Err(crate::extensions::registry::ExtensionError::MissingDetail),
541        };
542        Relation::from_extension("ExtensionLeaf", decoded, vec![], ctx)
543    }
544
545    fn from_extension_single<S: Scope>(rel: &'a ExtensionSingleRel, ctx: &S) -> Self {
546        let detail_ref = rel.detail.as_ref().map(AnyRef::from);
547        let decoded = match detail_ref {
548            Some(d) => ctx.extension_registry().decode(d),
549            None => Err(crate::extensions::registry::ExtensionError::MissingDetail),
550        };
551        Relation::from_extension("ExtensionSingle", decoded, vec![rel.input.as_deref()], ctx)
552    }
553
554    fn from_extension_multi<S: Scope>(rel: &'a ExtensionMultiRel, ctx: &S) -> Self {
555        let detail_ref = rel.detail.as_ref().map(AnyRef::from);
556        let decoded = match detail_ref {
557            Some(d) => ctx.extension_registry().decode(d),
558            None => Err(crate::extensions::registry::ExtensionError::MissingDetail),
559        };
560        let mut child_refs: Vec<Option<&'a Rel>> = vec![];
561        for input in &rel.inputs {
562            child_refs.push(Some(input));
563        }
564        Relation::from_extension("ExtensionMulti", decoded, child_refs, ctx)
565    }
566
567    fn from_extension<S: Scope>(
568        ext_type: &'static str,
569        decoded: Result<
570            (String, crate::extensions::ExtensionArgs),
571            crate::extensions::registry::ExtensionError,
572        >,
573        child_refs: Vec<Option<&'a Rel>>,
574        ctx: &S,
575    ) -> Self {
576        match decoded {
577            Ok((name, args)) => {
578                let (children, _) = Relation::convert_children(child_refs, ctx);
579                let mut positional = vec![];
580                for value in args.positional {
581                    positional.push(Value::ExtValue(value));
582                }
583                let mut named = vec![];
584                for (key, value) in args.named {
585                    named.push(NamedArg {
586                        name: Cow::Owned(key),
587                        value: Value::ExtValue(value),
588                    });
589                }
590                let mut columns = vec![];
591                for col in args.output_columns {
592                    columns.push(Value::ExtColumn(col));
593                }
594                Relation {
595                    name: Cow::Owned(format!("{}:{}", ext_type, name)),
596                    arguments: Some(Arguments { positional, named }),
597                    columns,
598                    emit: None,
599                    // Extension relations use `detail` rather than
600                    // `advanced_extension`; the field does not exist on these
601                    // proto types.
602                    advanced_extension: None,
603                    children,
604                }
605            }
606            Err(error) => {
607                let (children, _) = Relation::convert_children(child_refs, ctx);
608                Relation {
609                    name: Cow::Borrowed(ext_type),
610                    arguments: None,
611                    columns: vec![Value::Missing(PlanError::invalid(
612                        "extension",
613                        None::<&str>,
614                        error.to_string(),
615                    ))],
616                    emit: None,
617                    advanced_extension: None,
618                    children,
619                }
620            }
621        }
622    }
623
624    /// Convert an AggregateRel to a Relation for textification.
625    ///
626    /// The conversion follows this logic:
627    /// 1. Arguments: Group-by expressions (as Value::Expression)
628    /// 2. Columns: All possible outputs in order:
629    ///    - First: Group-by field references (Value::Reference)
630    ///    - Then: Aggregate function measures (Value::AggregateFunction)
631    /// 3. Emit: Uses the relation's emit mapping to select which outputs to display
632    /// 4. Children: The input relation
633    fn from_aggregate<S: Scope>(rel: &'a AggregateRel, ctx: &S) -> Self {
634        let mut grouping_sets: Vec<Vec<Value>> = vec![]; // the Groupings in the Aggregate
635        let expression_list: Vec<Value>; // grouping_expressions defined on Aggregate
636
637        // if rel.grouping_expressions is empty, the deprecated rel.groupings.grouping_expressions might be set
638        // If *both* the deprecated `rel.groupings.grouping_expressions` and `rel.grouping_expressions` are
639        // set, then we silently ignore the deprecated one.
640        #[allow(deprecated)]
641        if rel.grouping_expressions.is_empty()
642            && !rel.groupings.is_empty()
643            && !rel.groupings[0].grouping_expressions.is_empty()
644        {
645            (expression_list, grouping_sets) = Relation::get_grouping_sets(rel);
646        } else {
647            expression_list = rel
648                .grouping_expressions
649                .iter()
650                .map(Value::Expression)
651                .collect::<Vec<_>>(); // already a list of the unique expressions
652            for group in &rel.groupings {
653                let mut grouping_set: Vec<Value> = vec![];
654                for i in &group.expression_references {
655                    grouping_set.push(Value::Reference(*i as i32));
656                }
657                grouping_sets.push(grouping_set);
658            }
659            // no defined groupings means there is global group by
660            if rel.groupings.is_empty() {
661                grouping_sets.push(vec![]);
662            }
663        }
664
665        let is_single = grouping_sets.len() == 1;
666        let mut positional: Vec<Value> = vec![];
667        for g in grouping_sets {
668            if g.is_empty() {
669                positional.push(Value::EmptyGroup);
670            } else if is_single {
671                // Single non-empty grouping set: spread expressions directly without parens
672                positional.extend(g);
673            } else {
674                positional.push(Value::Tuple(g));
675            }
676        }
677
678        // adding the grouping_sets as a list of Arguments to Aggregate Rel
679        let arguments = Some(Arguments {
680            positional,
681            named: vec![],
682        });
683
684        // The columns are the direct outputs of this relation (before emit)
685        let mut all_outputs: Vec<Value> = expression_list;
686
687        // Then, add all measures (aggregate functions)
688        // These are indexed after the group-by fields
689        for m in &rel.measures {
690            if let Some(agg_fn) = m.measure.as_ref() {
691                all_outputs.push(Value::AggregateFunction(agg_fn));
692            }
693        }
694        let emit = get_emit(rel.common.as_ref());
695        let (children, _) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
696
697        Relation {
698            name: Cow::Borrowed("Aggregate"),
699            arguments,
700            columns: all_outputs,
701            emit,
702            advanced_extension: rel.advanced_extension.as_ref(),
703            children,
704        }
705    }
706
707    fn get_grouping_sets(rel: &'a AggregateRel) -> (Vec<Value<'a>>, Vec<Vec<Value<'a>>>) {
708        let mut grouping_sets: Vec<Vec<Value>> = vec![];
709        let mut expression_list: Vec<Value> = Vec::new();
710
711        // groupings might have the same expressions in their set so we use a map to get unique expressions
712        let mut expression_index_map = HashMap::new();
713        let mut i: i32 = 0; // index for the unique expression in the grouping_expressions list
714
715        for group in &rel.groupings {
716            let mut grouping_set: Vec<Value> = vec![];
717            #[allow(deprecated)]
718            for exp in &group.grouping_expressions {
719                // TODO: use a better key here than encoding to bytes.
720                // Ideally, substrait-rs would support `PartialEq` and `Hash`,
721                // but as there isn't an easy way to do that now, we'll skip.
722                let key = exp.encode_to_vec();
723                expression_index_map.entry(key.clone()).or_insert_with(|| {
724                    let value = Value::Expression(exp);
725                    expression_list.push(value); // new unique expression found
726                    // mapping the byte encoded expression to its index in the group_expression list
727                    let index = i;
728                    i += 1;
729                    index // is expression returned by this closure and inserted into map
730                });
731                grouping_set.push(Value::Reference(expression_index_map[&key]));
732            }
733            grouping_sets.push(grouping_set);
734        }
735        (expression_list, grouping_sets)
736    }
737}
738
739impl Textify for RelRoot {
740    fn name() -> &'static str {
741        "RelRoot"
742    }
743
744    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
745        let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
746
747        write!(
748            w,
749            "{}Root[{}]",
750            ctx.indent(),
751            ctx.separated(names.iter(), ", ")
752        )?;
753        let child_scope = ctx.push_indent();
754        for child in self.input.iter() {
755            writeln!(w)?;
756            child.textify(&child_scope, w)?;
757        }
758
759        Ok(())
760    }
761}
762
763impl Textify for PlanRelType {
764    fn name() -> &'static str {
765        "PlanRelType"
766    }
767
768    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
769        match self {
770            PlanRelType::Rel(rel) => rel.textify(ctx, w),
771            PlanRelType::Root(root) => root.textify(ctx, w),
772        }
773    }
774}
775
776impl Textify for PlanRel {
777    fn name() -> &'static str {
778        "PlanRel"
779    }
780
781    /// Write the relation as a string. Inputs are ignored - those are handled
782    /// separately.
783    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
784        write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
785    }
786}
787
788impl<'a> Relation<'a> {
789    fn from_sort<S: Scope>(rel: &'a SortRel, ctx: &S) -> Self {
790        let (children, input_columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
791        let mut positional = vec![];
792        for sort_field in &rel.sorts {
793            positional.push(Value::from(sort_field));
794        }
795        let arguments = Some(Arguments {
796            positional,
797            named: vec![],
798        });
799        // The columns are the direct outputs of this relation (before emit)
800        let mut col_values = vec![];
801        for i in 0..input_columns {
802            col_values.push(Value::Reference(i as i32));
803        }
804        let emit = get_emit(rel.common.as_ref());
805        Relation {
806            name: Cow::Borrowed("Sort"),
807            arguments,
808            columns: col_values,
809            emit,
810            advanced_extension: rel.advanced_extension.as_ref(),
811            children,
812        }
813    }
814
815    fn from_fetch<S: Scope>(rel: &'a FetchRel, ctx: &S) -> Self {
816        let (children, _) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
817        let mut named_args: Vec<NamedArg> = vec![];
818        match &rel.count_mode {
819            Some(CountMode::CountExpr(expr)) => {
820                named_args.push(NamedArg {
821                    name: Cow::Borrowed("limit"),
822                    value: Value::Expression(expr),
823                });
824            }
825            #[allow(deprecated)]
826            Some(CountMode::Count(val)) => {
827                named_args.push(NamedArg {
828                    name: Cow::Borrowed("limit"),
829                    value: Value::Integer(*val),
830                });
831            }
832            None => {}
833        }
834        if let Some(offset) = &rel.offset_mode {
835            match offset {
836                substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
837                    named_args.push(NamedArg {
838                        name: Cow::Borrowed("offset"),
839                        value: Value::Expression(expr),
840                    });
841                }
842                #[allow(deprecated)]
843                substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
844                    named_args.push(NamedArg {
845                        name: Cow::Borrowed("offset"),
846                        value: Value::Integer(*val),
847                    });
848                }
849            }
850        }
851
852        let emit = get_emit(rel.common.as_ref());
853        let mut columns = vec![];
854        if let Some(EmitKind::Emit(e)) = emit {
855            for &i in &e.output_mapping {
856                columns.push(Value::Reference(i));
857            }
858        }
859        Relation {
860            name: Cow::Borrowed("Fetch"),
861            arguments: Some(Arguments {
862                positional: vec![],
863                named: named_args,
864            }),
865            columns,
866            emit,
867            advanced_extension: rel.advanced_extension.as_ref(),
868            children,
869        }
870    }
871}
872
873fn join_output_columns(
874    join_type: join_rel::JoinType,
875    left_columns: usize,
876    right_columns: usize,
877) -> Vec<Value<'static>> {
878    let total_columns = match join_type {
879        // Inner, Left, Right, Outer joins output columns from both sides
880        join_rel::JoinType::Inner
881        | join_rel::JoinType::Left
882        | join_rel::JoinType::Right
883        | join_rel::JoinType::Outer => left_columns + right_columns,
884
885        // Left semi/anti joins only output columns from the left side
886        join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
887
888        // Right semi/anti joins output columns from the right side
889        join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
890
891        // Single joins behave like semi joins
892        join_rel::JoinType::LeftSingle => left_columns,
893        join_rel::JoinType::RightSingle => right_columns,
894
895        // Mark joins output base columns plus one mark column
896        join_rel::JoinType::LeftMark => left_columns + 1,
897        join_rel::JoinType::RightMark => right_columns + 1,
898
899        // Unspecified - fallback to all columns
900        join_rel::JoinType::Unspecified => left_columns + right_columns,
901    };
902
903    // Output is always a contiguous range starting from $0
904    (0..total_columns)
905        .map(|i| Value::Reference(i as i32))
906        .collect()
907}
908
909impl<'a> Relation<'a> {
910    fn from_join<S: Scope>(rel: &'a JoinRel, ctx: &S) -> Self {
911        let (children, _total_columns) =
912            Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()], ctx);
913
914        // convert_children should preserve input vector length
915        assert_eq!(
916            children.len(),
917            2,
918            "convert_children should return same number of elements as input"
919        );
920
921        // Calculate left and right column counts separately
922        let left_columns = match &children[0] {
923            Some(child) => child.emitted(),
924            None => 0,
925        };
926        let right_columns = match &children[1] {
927            Some(child) => child.emitted(),
928            None => 0,
929        };
930
931        // Convert join type from protobuf i32 to enum value
932        // JoinType is stored as i32 in protobuf, convert to typed enum for processing
933        let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
934            Ok(join_type) => {
935                let join_type_value = match join_type.as_enum_str() {
936                    Ok(s) => Value::Enum(s),
937                    Err(e) => Value::Missing(e),
938                };
939                (join_type, join_type_value)
940            }
941            Err(_) => {
942                // Use Unspecified for the join_type but create an error for the join_type_value
943                let join_type_error = Value::Missing(PlanError::invalid(
944                    "JoinRel",
945                    Some("type"),
946                    format!("Unknown join type: {}", rel.r#type),
947                ));
948                (join_rel::JoinType::Unspecified, join_type_error)
949            }
950        };
951
952        // Join condition
953        let condition = rel
954            .expression
955            .as_ref()
956            .map(|c| Value::Expression(c.as_ref()));
957        let condition = Value::expect(condition, || {
958            PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
959        });
960
961        // TODO: Add support for post_join_filter when grammar is extended
962        // Currently post_join_filter is not supported in the text format
963        // grammar
964        let positional = vec![join_type_value, condition];
965        let arguments = Some(Arguments {
966            positional,
967            named: vec![],
968        });
969
970        let emit = get_emit(rel.common.as_ref());
971        let columns = join_output_columns(join_type, left_columns, right_columns);
972
973        Relation {
974            name: Cow::Borrowed("Join"),
975            arguments,
976            columns,
977            emit,
978            advanced_extension: rel.advanced_extension.as_ref(),
979            children,
980        }
981    }
982}
983
984impl<'a> From<&'a SortField> for Value<'a> {
985    fn from(sf: &'a SortField) -> Self {
986        let field = match &sf.expr {
987            Some(expr) => match &expr.rex_type {
988                Some(substrait::proto::expression::RexType::Selection(fref)) => {
989                    if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
990                        if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
991                            Value::Reference(sf.field)
992                        } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
993                    } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
994                }
995                _ => Value::Missing(PlanError::unimplemented(
996                    "SortField",
997                    Some("expr"),
998                    "Not a selection",
999                )),
1000            },
1001            None => Value::Missing(PlanError::unimplemented(
1002                "SortField",
1003                Some("expr"),
1004                "Missing expr",
1005            )),
1006        };
1007        let direction = match &sf.sort_kind {
1008            Some(kind) => Value::from(kind),
1009            None => Value::Missing(PlanError::invalid(
1010                "SortKind",
1011                Some(Cow::Borrowed("sort_kind")),
1012                "Missing sort_kind",
1013            )),
1014        };
1015        Value::Tuple(vec![field, direction])
1016    }
1017}
1018
1019impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
1020    fn from(enum_val: &'a T) -> Self {
1021        match enum_val.as_enum_str() {
1022            Ok(s) => Value::Enum(s),
1023            Err(e) => Value::Missing(e),
1024        }
1025    }
1026}
1027
1028impl ValueEnum for SortKind {
1029    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
1030        let d = match self {
1031            &SortKind::Direction(d) => SortDirection::try_from(d),
1032            SortKind::ComparisonFunctionReference(f) => {
1033                return Err(PlanError::invalid(
1034                    "SortKind",
1035                    Some(Cow::Owned(format!("function reference{f}"))),
1036                    "SortKind::ComparisonFunctionReference unimplemented",
1037                ));
1038            }
1039        };
1040        let s = match d {
1041            Err(UnknownEnumValue(d)) => {
1042                return Err(PlanError::invalid(
1043                    "SortKind",
1044                    Some(Cow::Owned(format!("unknown variant: {d:?}"))),
1045                    "Unknown SortDirection",
1046                ));
1047            }
1048            Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
1049            Ok(SortDirection::AscNullsLast) => "AscNullsLast",
1050            Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
1051            Ok(SortDirection::DescNullsLast) => "DescNullsLast",
1052            Ok(SortDirection::Clustered) => "Clustered",
1053            Ok(SortDirection::Unspecified) => {
1054                return Err(PlanError::invalid(
1055                    "SortKind",
1056                    Option::<Cow<str>>::None,
1057                    "Unspecified SortDirection",
1058                ));
1059            }
1060        };
1061        Ok(Cow::Borrowed(s))
1062    }
1063}
1064
1065impl ValueEnum for join_rel::JoinType {
1066    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
1067        let s = match self {
1068            join_rel::JoinType::Unspecified => {
1069                return Err(PlanError::invalid(
1070                    "JoinType",
1071                    Option::<Cow<str>>::None,
1072                    "Unspecified JoinType",
1073                ));
1074            }
1075            join_rel::JoinType::Inner => "Inner",
1076            join_rel::JoinType::Outer => "Outer",
1077            join_rel::JoinType::Left => "Left",
1078            join_rel::JoinType::Right => "Right",
1079            join_rel::JoinType::LeftSemi => "LeftSemi",
1080            join_rel::JoinType::RightSemi => "RightSemi",
1081            join_rel::JoinType::LeftAnti => "LeftAnti",
1082            join_rel::JoinType::RightAnti => "RightAnti",
1083            join_rel::JoinType::LeftSingle => "LeftSingle",
1084            join_rel::JoinType::RightSingle => "RightSingle",
1085            join_rel::JoinType::LeftMark => "LeftMark",
1086            join_rel::JoinType::RightMark => "RightMark",
1087        };
1088        Ok(Cow::Borrowed(s))
1089    }
1090}
1091
1092impl<'a> Textify for NamedArg<'a> {
1093    fn name() -> &'static str {
1094        "NamedArg"
1095    }
1096    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
1097        write!(w, "{}=", self.name)?;
1098        self.value.textify(ctx, w)
1099    }
1100}
1101
1102#[cfg(test)]
1103mod tests {
1104    use substrait::proto::aggregate_rel::Grouping;
1105    use substrait::proto::expression::literal::LiteralType;
1106    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
1107    use substrait::proto::function_argument::ArgType;
1108    use substrait::proto::read_rel::{NamedTable, ReadType};
1109    use substrait::proto::rel_common::{Direct, Emit};
1110    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
1111    use substrait::proto::{
1112        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
1113    };
1114
1115    use super::*;
1116    use crate::fixtures::TestContext;
1117    use crate::parser::expressions::FieldIndex;
1118
1119    #[test]
1120    fn test_read_rel() {
1121        let ctx = TestContext::new();
1122
1123        // Create a simple ReadRel with a NamedStruct schema
1124        let read_rel = ReadRel {
1125            common: None,
1126            base_schema: Some(NamedStruct {
1127                names: vec!["col1".into(), "column 2".into()],
1128                r#struct: Some(Struct {
1129                    type_variation_reference: 0,
1130                    types: vec![
1131                        Type {
1132                            kind: Some(Kind::I32(ptype::I32 {
1133                                type_variation_reference: 0,
1134                                nullability: Nullability::Nullable as i32,
1135                            })),
1136                        },
1137                        Type {
1138                            kind: Some(Kind::String(ptype::String {
1139                                type_variation_reference: 0,
1140                                nullability: Nullability::Nullable as i32,
1141                            })),
1142                        },
1143                    ],
1144                    nullability: Nullability::Nullable as i32,
1145                }),
1146            }),
1147            filter: None,
1148            best_effort_filter: None,
1149            projection: None,
1150            advanced_extension: None,
1151            read_type: Some(ReadType::NamedTable(NamedTable {
1152                names: vec!["some_db".into(), "test_table".into()],
1153                advanced_extension: None,
1154            })),
1155        };
1156
1157        let rel = Rel {
1158            rel_type: Some(RelType::Read(Box::new(read_rel))),
1159        };
1160        let (result, errors) = ctx.textify(&rel);
1161        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1162        assert_eq!(
1163            result,
1164            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
1165        );
1166    }
1167
1168    #[test]
1169    fn test_filter_rel() {
1170        let ctx = TestContext::new()
1171            .with_urn(1, "test_urn")
1172            .with_function(1, 10, "gt");
1173
1174        // Create a simple FilterRel with a ReadRel input and a filter expression
1175        let read_rel = ReadRel {
1176            common: None,
1177            base_schema: Some(NamedStruct {
1178                names: vec!["col1".into(), "col2".into()],
1179                r#struct: Some(Struct {
1180                    type_variation_reference: 0,
1181                    types: vec![
1182                        Type {
1183                            kind: Some(Kind::I32(ptype::I32 {
1184                                type_variation_reference: 0,
1185                                nullability: Nullability::Nullable as i32,
1186                            })),
1187                        },
1188                        Type {
1189                            kind: Some(Kind::I32(ptype::I32 {
1190                                type_variation_reference: 0,
1191                                nullability: Nullability::Nullable as i32,
1192                            })),
1193                        },
1194                    ],
1195                    nullability: Nullability::Nullable as i32,
1196                }),
1197            }),
1198            filter: None,
1199            best_effort_filter: None,
1200            projection: None,
1201            advanced_extension: None,
1202            read_type: Some(ReadType::NamedTable(NamedTable {
1203                names: vec!["test_table".into()],
1204                advanced_extension: None,
1205            })),
1206        };
1207
1208        // Create a filter expression: col1 > 10
1209        let filter_expr = Expression {
1210            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
1211                function_reference: 10, // gt function
1212                arguments: vec![
1213                    FunctionArgument {
1214                        arg_type: Some(ArgType::Value(Reference(0).into())),
1215                    },
1216                    FunctionArgument {
1217                        arg_type: Some(ArgType::Value(Expression {
1218                            rex_type: Some(RexType::Literal(Literal {
1219                                literal_type: Some(LiteralType::I32(10)),
1220                                nullable: false,
1221                                type_variation_reference: 0,
1222                            })),
1223                        })),
1224                    },
1225                ],
1226                options: vec![],
1227                output_type: None,
1228                #[allow(deprecated)]
1229                args: vec![],
1230            })),
1231        };
1232
1233        let filter_rel = FilterRel {
1234            common: None,
1235            input: Some(Box::new(Rel {
1236                rel_type: Some(RelType::Read(Box::new(read_rel))),
1237            })),
1238            condition: Some(Box::new(filter_expr)),
1239            advanced_extension: None,
1240        };
1241
1242        let rel = Rel {
1243            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
1244        };
1245
1246        let (result, errors) = ctx.textify(&rel);
1247        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1248        let expected = r#"
1249Filter[gt($0, 10:i32) => $0, $1]
1250  Read[test_table => col1:i32?, col2:i32?]"#
1251            .trim_start();
1252        assert_eq!(result, expected);
1253    }
1254
1255    #[test]
1256    fn test_aggregate_function_textify() {
1257        let ctx = TestContext::new()
1258        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1259        .with_function(1, 10, "sum")
1260        .with_function(1, 11, "count");
1261
1262        // Create a simple AggregateFunction
1263        let agg_fn = get_aggregate_func(10, 1);
1264
1265        let value = Value::AggregateFunction(&agg_fn);
1266        let (result, errors) = ctx.textify(&value);
1267
1268        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1269        assert_eq!(result, "sum($1)");
1270    }
1271
1272    #[test]
1273    fn test_aggregate_relation_textify() {
1274        let ctx = TestContext::new()
1275        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1276        .with_function(1, 10, "sum")
1277        .with_function(1, 11, "count");
1278
1279        // Create a simple AggregateRel
1280        let agg_fn1 = get_aggregate_func(10, 1);
1281        let agg_fn2 = get_aggregate_func(11, 1);
1282
1283        let grouping_expressions = vec![Expression {
1284            rex_type: Some(RexType::Selection(Box::new(
1285                FieldIndex(0).to_field_reference(),
1286            ))),
1287        }];
1288
1289        let measures = vec![
1290            aggregate_rel::Measure {
1291                measure: Some(agg_fn1),
1292                filter: None,
1293            },
1294            aggregate_rel::Measure {
1295                measure: Some(agg_fn2),
1296                filter: None,
1297            },
1298        ];
1299
1300        let common = Some(RelCommon {
1301            emit_kind: Some(EmitKind::Emit(Emit {
1302                output_mapping: vec![1, 2], // measures only
1303            })),
1304            ..Default::default()
1305        });
1306
1307        let aggregate_rel = create_aggregate_rel(grouping_expressions, vec![], measures, common);
1308
1309        let rel = Rel {
1310            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1311        };
1312        let (result, errors) = ctx.textify(&rel);
1313
1314        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1315        // Expected: Aggregate[_ => sum($1), count($1)] we chose to emit only measures
1316        assert!(result.contains("Aggregate[_ => sum($1), count($1)]"));
1317    }
1318
1319    #[test]
1320    fn test_multiple_groupings_on_aggregate_deprecated() {
1321        // Protobuf plan that uses AggregateRel.groupings with deprecated
1322        // grouping_expressions, leaving AggregateRel.grouping_expressions empty.
1323        let ctx = TestContext::new()
1324        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1325        .with_function(1, 11, "count");
1326
1327        let grouping_expr_0 = create_exp(0);
1328        let grouping_expr_1 = create_exp(1);
1329
1330        let grouping_sets = vec![
1331            aggregate_rel::Grouping {
1332                #[allow(deprecated)]
1333                grouping_expressions: vec![grouping_expr_0.clone()],
1334                expression_references: vec![],
1335            },
1336            aggregate_rel::Grouping {
1337                #[allow(deprecated)]
1338                grouping_expressions: vec![grouping_expr_0.clone(), grouping_expr_1.clone()],
1339                expression_references: vec![],
1340            },
1341        ];
1342
1343        let aggregate_rel = create_aggregate_rel(vec![], grouping_sets, vec![], None);
1344
1345        let rel = Rel {
1346            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1347        };
1348        let (result, errors) = ctx.textify(&rel);
1349
1350        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1351        assert!(result.contains("Aggregate[($0), ($0, $1) => $0, $1]"));
1352    }
1353
1354    #[test]
1355    fn test_multiple_groupings_with_measure_deprecated() {
1356        // Protobuf plan that uses AggregateRel.groupings with deprecated
1357        // grouping_expressions, leaving AggregateRel.grouping_expressions empty.
1358        let ctx = TestContext::new()
1359        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1360        .with_function(1, 11, "count");
1361
1362        let agg_fn1 = get_aggregate_func(11, 2);
1363
1364        let grouping_expr_0 = create_exp(0);
1365        let grouping_expr_1 = create_exp(1);
1366
1367        let grouping_sets = vec![
1368            aggregate_rel::Grouping {
1369                #[allow(deprecated)]
1370                grouping_expressions: vec![grouping_expr_0.clone()],
1371                expression_references: vec![],
1372            },
1373            aggregate_rel::Grouping {
1374                #[allow(deprecated)]
1375                grouping_expressions: vec![grouping_expr_0.clone(), grouping_expr_1.clone()],
1376                expression_references: vec![],
1377            },
1378        ];
1379
1380        let measures = vec![aggregate_rel::Measure {
1381            measure: Some(agg_fn1),
1382            filter: None,
1383        }];
1384
1385        let aggregate_rel = create_aggregate_rel(vec![], grouping_sets, measures, None);
1386
1387        let rel = Rel {
1388            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1389        };
1390        let (result, errors) = ctx.textify(&rel);
1391
1392        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1393        assert!(result.contains("($0), ($0, $1) => $0, $1, count($2)"));
1394    }
1395
1396    #[test]
1397    fn test_multiple_groupings_on_aggregate() {
1398        let ctx = TestContext::new()
1399        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1400        .with_function(1, 11, "count");
1401
1402        let agg_fn2 = get_aggregate_func(11, 2);
1403
1404        let grouping_expressions = vec![
1405            Expression {
1406                rex_type: Some(RexType::Selection(Box::new(
1407                    FieldIndex(0).to_field_reference(),
1408                ))),
1409            },
1410            Expression {
1411                rex_type: Some(RexType::Selection(Box::new(
1412                    FieldIndex(1).to_field_reference(),
1413                ))),
1414            },
1415        ];
1416
1417        let grouping_sets = vec![
1418            Grouping {
1419                #[allow(deprecated)]
1420                grouping_expressions: vec![],
1421                expression_references: vec![0, 1],
1422            },
1423            Grouping {
1424                #[allow(deprecated)]
1425                grouping_expressions: vec![],
1426                expression_references: vec![0, 1],
1427            },
1428            Grouping {
1429                #[allow(deprecated)]
1430                grouping_expressions: vec![],
1431                expression_references: vec![1],
1432            },
1433            Grouping {
1434                #[allow(deprecated)]
1435                grouping_expressions: vec![],
1436                expression_references: vec![1, 1],
1437            },
1438            Grouping {
1439                #[allow(deprecated)]
1440                grouping_expressions: vec![],
1441                expression_references: vec![],
1442            },
1443        ];
1444
1445        let measures = vec![aggregate_rel::Measure {
1446            measure: Some(agg_fn2),
1447            filter: None,
1448        }];
1449
1450        let aggregate_rel =
1451            create_aggregate_rel(grouping_expressions, grouping_sets, measures, None);
1452
1453        let rel = Rel {
1454            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1455        };
1456        let (result, errors) = ctx.textify(&rel);
1457
1458        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1459        assert!(
1460            result
1461                .contains("Aggregate[($0, $1), ($0, $1), ($1), ($1, $1), _ => $0, $1, count($2)]")
1462        );
1463    }
1464
1465    #[test]
1466    fn test_arguments_textify_positional_only() {
1467        let ctx = TestContext::new();
1468        let args = Arguments {
1469            positional: vec![Value::Integer(42), Value::Integer(7)],
1470            named: vec![],
1471        };
1472        let (result, errors) = ctx.textify(&args);
1473        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1474        assert_eq!(result, "42, 7");
1475    }
1476
1477    #[test]
1478    fn test_arguments_textify_named_only() {
1479        let ctx = TestContext::new();
1480        let args = Arguments {
1481            positional: vec![],
1482            named: vec![
1483                NamedArg {
1484                    name: Cow::Borrowed("limit"),
1485                    value: Value::Integer(10),
1486                },
1487                NamedArg {
1488                    name: Cow::Borrowed("offset"),
1489                    value: Value::Integer(5),
1490                },
1491            ],
1492        };
1493        let (result, errors) = ctx.textify(&args);
1494        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1495        assert_eq!(result, "limit=10, offset=5");
1496    }
1497
1498    #[test]
1499    fn test_join_relation_unknown_type() {
1500        let ctx = TestContext::new();
1501
1502        // Create a join with an unknown/invalid type
1503        let join_rel = JoinRel {
1504            left: Some(Box::new(Rel {
1505                rel_type: Some(RelType::Read(Box::default())),
1506            })),
1507            right: Some(Box::new(Rel {
1508                rel_type: Some(RelType::Read(Box::default())),
1509            })),
1510            expression: Some(Box::new(Expression::default())),
1511            r#type: 999, // Invalid join type
1512            common: None,
1513            post_join_filter: None,
1514            advanced_extension: None,
1515        };
1516
1517        let rel = Rel {
1518            rel_type: Some(RelType::Join(Box::new(join_rel))),
1519        };
1520        let (result, errors) = ctx.textify(&rel);
1521
1522        // Should contain error for unknown join type but still show condition and columns
1523        assert!(!errors.is_empty(), "Expected errors for unknown join type");
1524        assert!(
1525            result.contains("!{JoinRel}"),
1526            "Expected error token for unknown join type"
1527        );
1528        assert!(
1529            result.contains("Join["),
1530            "Expected Join relation to be formatted"
1531        );
1532    }
1533
1534    #[test]
1535    fn test_arguments_textify_both() {
1536        let ctx = TestContext::new();
1537        let args = Arguments {
1538            positional: vec![Value::Integer(1)],
1539            named: vec![NamedArg {
1540                name: "foo".into(),
1541                value: Value::Integer(2),
1542            }],
1543        };
1544        let (result, errors) = ctx.textify(&args);
1545        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1546        assert_eq!(result, "1, foo=2");
1547    }
1548
1549    #[test]
1550    fn test_arguments_textify_empty() {
1551        let ctx = TestContext::new();
1552        let args = Arguments {
1553            positional: vec![],
1554            named: vec![],
1555        };
1556        let (result, errors) = ctx.textify(&args);
1557        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1558        assert_eq!(result, "_");
1559    }
1560
1561    #[test]
1562    fn test_named_arg_textify_error_token() {
1563        let ctx = TestContext::new();
1564        let named_arg = NamedArg {
1565            name: "foo".into(),
1566            value: Value::Missing(PlanError::invalid(
1567                "my_enum",
1568                Some(Cow::Borrowed("my_enum")),
1569                Cow::Borrowed("my_enum"),
1570            )),
1571        };
1572        let (result, errors) = ctx.textify(&named_arg);
1573        // Should show !{my_enum} in the output
1574        assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
1575        // Should also accumulate an error
1576        assert!(!errors.is_empty(), "Expected error for error token");
1577    }
1578
1579    #[test]
1580    fn test_join_type_enum_textify() {
1581        // Test that JoinType enum values convert correctly to their string representation
1582        assert_eq!(join_rel::JoinType::Inner.as_enum_str().unwrap(), "Inner");
1583        assert_eq!(join_rel::JoinType::Left.as_enum_str().unwrap(), "Left");
1584        assert_eq!(
1585            join_rel::JoinType::LeftSemi.as_enum_str().unwrap(),
1586            "LeftSemi"
1587        );
1588        assert_eq!(
1589            join_rel::JoinType::LeftAnti.as_enum_str().unwrap(),
1590            "LeftAnti"
1591        );
1592    }
1593
1594    #[test]
1595    fn test_join_output_columns() {
1596        // Test Inner join - outputs all columns from both sides
1597        let inner_cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
1598        assert_eq!(inner_cols.len(), 5); // 2 + 3 = 5 columns
1599        assert!(matches!(inner_cols[0], Value::Reference(0)));
1600        assert!(matches!(inner_cols[4], Value::Reference(4)));
1601
1602        // Test LeftSemi join - outputs only left columns
1603        let left_semi_cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
1604        assert_eq!(left_semi_cols.len(), 2); // Only left columns
1605        assert!(matches!(left_semi_cols[0], Value::Reference(0)));
1606        assert!(matches!(left_semi_cols[1], Value::Reference(1)));
1607
1608        // Test RightSemi join - outputs right columns as contiguous range starting from $0
1609        let right_semi_cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
1610        assert_eq!(right_semi_cols.len(), 3); // Only right columns
1611        assert!(matches!(right_semi_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1612        assert!(matches!(right_semi_cols[1], Value::Reference(1)));
1613        assert!(matches!(right_semi_cols[2], Value::Reference(2))); // Last right column
1614
1615        // Test LeftMark join - outputs left columns plus a mark column as contiguous range
1616        let left_mark_cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
1617        assert_eq!(left_mark_cols.len(), 3); // 2 left + 1 mark
1618        assert!(matches!(left_mark_cols[0], Value::Reference(0)));
1619        assert!(matches!(left_mark_cols[1], Value::Reference(1)));
1620        assert!(matches!(left_mark_cols[2], Value::Reference(2))); // Mark column at contiguous position
1621
1622        // Test RightMark join - outputs right columns plus a mark column as contiguous range
1623        let right_mark_cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
1624        assert_eq!(right_mark_cols.len(), 4); // 3 right + 1 mark
1625        assert!(matches!(right_mark_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1626        assert!(matches!(right_mark_cols[1], Value::Reference(1)));
1627        assert!(matches!(right_mark_cols[2], Value::Reference(2))); // Last right column
1628        assert!(matches!(right_mark_cols[3], Value::Reference(3))); // Mark column at contiguous position
1629    }
1630
1631    fn get_aggregate_func(func_ref: u32, column_ind: i32) -> AggregateFunction {
1632        AggregateFunction {
1633            function_reference: func_ref,
1634            arguments: vec![FunctionArgument {
1635                arg_type: Some(ArgType::Value(Expression {
1636                    rex_type: Some(RexType::Selection(Box::new(
1637                        FieldIndex(column_ind).to_field_reference(),
1638                    ))),
1639                })),
1640            }],
1641            options: vec![],
1642            output_type: None,
1643            invocation: 0,
1644            phase: 0,
1645            sorts: vec![],
1646            #[allow(deprecated)]
1647            args: vec![],
1648        }
1649    }
1650
1651    fn create_aggregate_rel(
1652        grouping_expressions: Vec<Expression>,
1653        grouping_sets: Vec<Grouping>,
1654        measures: Vec<aggregate_rel::Measure>,
1655        common: Option<RelCommon>,
1656    ) -> AggregateRel {
1657        let common = common.or_else(|| {
1658            Some(RelCommon {
1659                emit_kind: Some(EmitKind::Direct(Direct {})),
1660                ..Default::default()
1661            })
1662        });
1663        AggregateRel {
1664            input: Some(Box::new(Rel {
1665                rel_type: Some(RelType::Read(Box::new(ReadRel {
1666                    common: None,
1667                    base_schema: Some(get_basic_schema()),
1668                    filter: None,
1669                    best_effort_filter: None,
1670                    projection: None,
1671                    advanced_extension: None,
1672                    read_type: Some(ReadType::NamedTable(NamedTable {
1673                        names: vec!["orders".into()],
1674                        advanced_extension: None,
1675                    })),
1676                }))),
1677            })),
1678            grouping_expressions,
1679            groupings: grouping_sets,
1680            measures,
1681            common,
1682            advanced_extension: None,
1683        }
1684    }
1685
1686    fn get_basic_schema() -> NamedStruct {
1687        NamedStruct {
1688            names: vec!["category".into(), "amount".into(), "value".into()],
1689            r#struct: Some(Struct {
1690                type_variation_reference: 0,
1691                types: vec![
1692                    Type {
1693                        kind: Some(Kind::String(ptype::String {
1694                            type_variation_reference: 0,
1695                            nullability: Nullability::Nullable as i32,
1696                        })),
1697                    },
1698                    Type {
1699                        kind: Some(Kind::Fp64(ptype::Fp64 {
1700                            type_variation_reference: 0,
1701                            nullability: Nullability::Nullable as i32,
1702                        })),
1703                    },
1704                    Type {
1705                        kind: Some(Kind::I32(ptype::I32 {
1706                            type_variation_reference: 0,
1707                            nullability: Nullability::Nullable as i32,
1708                        })),
1709                    },
1710                ],
1711                nullability: Nullability::Nullable as i32,
1712            }),
1713        }
1714    }
1715
1716    fn create_exp(column_ind: i32) -> Expression {
1717        Expression {
1718            rex_type: Some(RexType::Selection(Box::new(
1719                FieldIndex(column_ind).to_field_reference(),
1720            ))),
1721        }
1722    }
1723
1724    #[test]
1725    fn test_unsupported_rel_type_produces_failure_token() {
1726        use substrait::proto::CrossRel;
1727
1728        let ctx = TestContext::new();
1729
1730        // CrossRel is a valid Substrait relation type that the textifier
1731        // does not yet support.  Wrapping it in a Rel and textifying should
1732        // produce a `!{Rel}` failure token rather than panicking.
1733        let rel = Rel {
1734            rel_type: Some(RelType::Cross(Box::new(CrossRel {
1735                common: None,
1736                left: None,
1737                right: None,
1738                advanced_extension: None,
1739            }))),
1740        };
1741
1742        let (result, errors) = ctx.textify(&rel);
1743
1744        // The output should contain the failure token, not an empty string.
1745        assert!(
1746            result.contains("!{Rel}"),
1747            "Expected '!{{Rel}}' in output, got: {result}"
1748        );
1749
1750        // Exactly one error should have been collected.
1751        assert_eq!(errors.0.len(), 1, "Expected exactly one error: {errors:?}");
1752
1753        // The error should be a Format / Unimplemented error mentioning CrossRel.
1754        match &errors.0[0] {
1755            FormatError::Format(plan_err) => {
1756                assert_eq!(plan_err.message, "Rel");
1757                assert_eq!(
1758                    plan_err.error_type,
1759                    crate::textify::foundation::FormatErrorType::Unimplemented
1760                );
1761                assert!(
1762                    plan_err.lookup.as_deref().unwrap_or("").contains("Cross"),
1763                    "Expected lookup to mention 'Cross', got: {:?}",
1764                    plan_err.lookup
1765                );
1766            }
1767            other => panic!("Expected FormatError::Format, got: {other:?}"),
1768        }
1769    }
1770}