Skip to main content

substrait_explain/textify/
rels.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::convert::TryFrom;
4use std::fmt;
5use std::fmt::Debug;
6
7use prost::{Message, UnknownEnumValue};
8use substrait::proto::fetch_rel::CountMode;
9use substrait::proto::plan_rel::RelType as PlanRelType;
10use substrait::proto::read_rel::ReadType;
11use substrait::proto::rel::RelType;
12use substrait::proto::rel_common::EmitKind;
13use substrait::proto::sort_field::{SortDirection, SortKind};
14use substrait::proto::{
15    AggregateFunction, AggregateRel, Expression, ExtensionLeafRel, ExtensionMultiRel,
16    ExtensionSingleRel, FetchRel, FilterRel, JoinRel, NamedStruct, PlanRel, ProjectRel, ReadRel,
17    Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
18};
19
20use super::addenda::AddendumLines;
21use super::expressions::Reference;
22use super::types::Name;
23use super::{PlanError, Scope, Textify};
24use crate::FormatError;
25use crate::extensions::any::AnyRef;
26use crate::extensions::{ExtensionArgs, ExtensionColumn, ExtensionError, ExtensionValue};
27
/// A relation that can report a short, static display name for itself.
pub trait NamedRelation {
    /// Returns the display name of this relation.
    fn name(&self) -> &'static str;
}
31
32impl NamedRelation for Rel {
33    fn name(&self) -> &'static str {
34        match self.rel_type.as_ref() {
35            None => "UnknownRel",
36            Some(RelType::Read(_)) => "Read",
37            Some(RelType::Filter(_)) => "Filter",
38            Some(RelType::Project(_)) => "Project",
39            Some(RelType::Fetch(_)) => "Fetch",
40            Some(RelType::Aggregate(_)) => "Aggregate",
41            Some(RelType::Sort(_)) => "Sort",
42            Some(RelType::HashJoin(_)) => "HashJoin",
43            Some(RelType::Exchange(_)) => "Exchange",
44            Some(RelType::Join(_)) => "Join",
45            Some(RelType::Set(_)) => "Set",
46            Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
47            Some(RelType::Cross(_)) => "Cross",
48            Some(RelType::Reference(_)) => "Reference",
49            Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
50            Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
51            Some(RelType::Write(_)) => "Write",
52            Some(RelType::Ddl(_)) => "Ddl",
53            Some(RelType::Update(_)) => "Update",
54            Some(RelType::MergeJoin(_)) => "MergeJoin",
55            Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
56            Some(RelType::Window(_)) => "Window",
57            Some(RelType::Expand(_)) => "Expand",
58        }
59    }
60}
61
62impl Textify for Rel {
63    fn name() -> &'static str {
64        "Rel"
65    }
66
67    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
68        // delegates to `Relation` which carries `advanced_extension`, so the full
69        // header → enhancement → children sequence is handled uniformly there.
70        Relation::from_rel(self, ctx).textify(ctx, w)
71    }
72}
73
/// Trait for enums that can be converted to a string representation for
/// textification.
///
/// Implementations return `Ok(str)` for valid enum values, or
/// `Err(`[`PlanError`]`)` for invalid or unknown values.
pub trait ValueEnum {
    /// Returns the textual form of this enum value, or a [`PlanError`] when
    /// the value is invalid or unknown.
    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
}
82
/// A single named argument of a relation: a name paired with its value.
#[derive(Debug, Clone)]
pub struct NamedArg<'a> {
    /// The argument's name.
    pub name: Cow<'a, str>,
    /// The argument's value.
    pub value: Value<'a>,
}
88
/// A single printable item inside a relation header: an argument, a column,
/// or a nested group of either.
#[derive(Debug, Clone)]
pub enum Value<'a> {
    /// A bare name.
    Name(Name<'a>),
    /// A multi-part table name; the parts are written joined with `.`.
    TableName(Vec<Name<'a>>),
    /// A schema field, written as `name:type`. Either half may be absent.
    Field(Option<Name<'a>>, Option<&'a Type>),
    /// A group of values written as `(a, b, ...)`.
    Tuple(Vec<Value<'a>>),
    /// A group of values written as `[a, b, ...]`.
    List(Vec<Value<'a>>),
    /// A field reference by index, written via [`Reference`].
    Reference(i32),
    /// A borrowed expression.
    Expression(&'a Expression),
    /// A borrowed aggregate function.
    AggregateFunction(&'a AggregateFunction),
    /// Represents a missing, invalid, or unspecified value.
    Missing(PlanError),
    /// Represents a valid enum value as a string for textification; written
    /// with a leading `&`.
    Enum(Cow<'a, str>),
    /// An empty grouping set, written as `_`.
    EmptyGroup,
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal.
    Float(f64),
    /// A boolean literal.
    Boolean(bool),
    /// A decoded extension argument value.
    ExtValue(ExtensionValue),
    /// A decoded extension output column.
    ExtColumn(ExtensionColumn),
}
112
113impl<'a> Value<'a> {
114    pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
115        match maybe_value {
116            Some(s) => s,
117            None => Value::Missing(f()),
118        }
119    }
120}
121
122impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
123    fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
124        match token {
125            Ok(value) => Value::TableName(value),
126            Err(err) => Value::Missing(err),
127        }
128    }
129}
130
131impl<'a> Textify for Value<'a> {
132    fn name() -> &'static str {
133        "Value"
134    }
135
136    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
137        match self {
138            Value::Name(name) => write!(w, "{}", ctx.display(name)),
139            Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
140            Value::Field(name, typ) => {
141                write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
142            }
143            Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
144            Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
145            Value::Reference(i) => write!(w, "{}", Reference(*i)),
146            Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
147            Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
148            Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
149            Value::Enum(res) => write!(w, "&{res}"),
150            Value::Integer(i) => write!(w, "{i}"),
151            Value::EmptyGroup => write!(w, "_"),
152            Value::Float(f) => write!(w, "{f}"),
153            Value::Boolean(b) => write!(w, "{b}"),
154            Value::ExtValue(ev) => ev.textify(ctx, w),
155            Value::ExtColumn(ec) => ec.textify(ctx, w),
156        }
157    }
158}
159
160fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
161    let mut fields = schema
162        .r#struct
163        .as_ref()
164        .map(|s| s.types.iter())
165        .into_iter()
166        .flatten();
167    let mut names = schema.names.iter();
168
169    // let field_count = schema.r#struct.as_ref().map(|s| s.types.len()).unwrap_or(0);
170    // let name_count = schema.names.len();
171
172    let mut values = Vec::new();
173    loop {
174        let field = fields.next();
175        let name = names.next().map(|n| Name(n));
176        if field.is_none() && name.is_none() {
177            break;
178        }
179
180        values.push(Value::Field(name, field));
181    }
182
183    values
184}
185
/// Pairs a relation's direct output values with its optional emit kind so the
/// two can be rendered together.
struct Emitted<'a> {
    /// The direct (pre-emit) output values.
    pub values: &'a [Value<'a>],
    /// The emit kind, if any.
    pub emit: Option<&'a EmitKind>,
}
190
impl<'a> Emitted<'a> {
    /// Bundles the direct output `values` with the optional `emit` kind.
    pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
        Self { values, emit }
    }

    /// Writes every value in order, separated by `", "`, without applying any
    /// emit mapping.
    pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        write!(w, "{}", ctx.separated(self.values.iter(), ", "))
    }
}
200
201impl<'a> Textify for Emitted<'a> {
202    fn name() -> &'static str {
203        "Emitted"
204    }
205
206    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
207        if ctx.options().show_emit {
208            return self.write_direct(ctx, w);
209        }
210
211        let indices = match &self.emit {
212            Some(EmitKind::Emit(e)) => &e.output_mapping,
213            Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
214            None => return self.write_direct(ctx, w),
215        };
216
217        for (i, &index) in indices.iter().enumerate() {
218            if i > 0 {
219                write!(w, ", ")?;
220            }
221
222            match self.values.get(index as usize) {
223                Some(value) => write!(w, "{}", ctx.display(value))?,
224                None => write!(w, "{}", ctx.failure(PlanError::invalid(
225                    "Emitted",
226                    Some("output_mapping"),
227                    format!(
228                        "Output mapping index {} is out of bounds for values collection of size {}",
229                        index, self.values.len()
230                    )
231                )))?,
232            }
233        }
234
235        Ok(())
236    }
237}
238
/// The arguments of a relation: positional values followed by named ones.
#[derive(Debug, Clone)]
pub struct Arguments<'a> {
    /// Positional arguments (e.g., a filter condition, group-bys, etc.)
    pub positional: Vec<Value<'a>>,
    /// Named arguments (e.g., limit=10, offset=5)
    pub named: Vec<NamedArg<'a>>,
}
246
247impl<'a> Textify for Arguments<'a> {
248    fn name() -> &'static str {
249        "Arguments"
250    }
251    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
252        if self.positional.is_empty() && self.named.is_empty() {
253            return write!(w, "_");
254        }
255
256        write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
257        if !self.positional.is_empty() && !self.named.is_empty() {
258            write!(w, ", ")?;
259        }
260        write!(w, "{}", ctx.separated(self.named.iter(), ", "))
261    }
262}
263
/// The structured, ready-to-print form of a relation: its name, arguments,
/// output columns, addendum lines, and child relations.
pub struct Relation<'a> {
    /// The name written at the start of the relation's header line.
    pub name: Cow<'a, str>,
    /// Arguments to the relation, if any.
    ///
    /// - `None` means this relation does not take arguments, and the argument
    ///   section is omitted entirely.
    /// - `Some(args)` with both vectors empty means the relation takes
    ///   arguments, but none are provided; this will print as `_ => ...`.
    /// - `Some(args)` with non-empty vectors will print as usual, with
    ///   positional arguments first, then named arguments, separated by commas.
    pub arguments: Option<Arguments<'a>>,
    /// The columns emitted by this relation, pre-emit - the 'direct' column
    /// output.
    pub columns: Vec<Value<'a>>,
    /// The emit kind, if any. If none, use the columns directly.
    pub emit: Option<&'a EmitKind>,
    /// `+`-prefixed addendum lines to emit between this relation's header and
    /// children.  This owns the canonical ordering for `+ Ext`, `+ Enh`, and
    /// `+ Opt` lines rather than making the generic relation shape grow one
    /// field per addendum kind.
    addenda: AddendumLines,
    /// The input relations. A `None` entry marks a missing input; such
    /// entries are skipped when the children are written.
    pub children: Vec<Option<Relation<'a>>>,
}
288
289impl Textify for Relation<'_> {
290    fn name() -> &'static str {
291        "Relation"
292    }
293
294    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
295        self.write_header(ctx, w)?;
296        let child_scope = ctx.push_indent();
297        self.addenda.textify(&child_scope, w)?;
298        self.write_children(ctx, w)?;
299        Ok(())
300    }
301}
302
303impl Relation<'_> {
304    /// Write the single header line for this relation, e.g. `Filter[$0 => $0]`.
305    /// Does not write a trailing newline; callers are responsible for any
306    /// newline that follows (either from an addendum or from the next child).
307    pub fn write_header<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
308        let cols = Emitted::new(&self.columns, self.emit);
309        let indent = ctx.indent();
310        let name = &self.name;
311        let cols = ctx.display(&cols);
312        match &self.arguments {
313            None => {
314                write!(w, "{indent}{name}[{cols}]")
315            }
316            Some(args) => {
317                let args = ctx.display(args);
318                write!(w, "{indent}{name}[{args} => {cols}]")
319            }
320        }
321    }
322
323    /// Write each child relation at one indent level deeper than `ctx`.
324    /// Each child is preceded by a newline.
325    pub fn write_children<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
326        let child_scope = ctx.push_indent();
327        for child in self.children.iter().flatten() {
328            writeln!(w)?;
329            child.textify(&child_scope, w)?;
330        }
331        Ok(())
332    }
333}
334
335impl<'a> Relation<'a> {
336    pub fn emitted(&self) -> usize {
337        match self.emit {
338            Some(EmitKind::Emit(e)) => e.output_mapping.len(),
339            Some(EmitKind::Direct(_)) => self.columns.len(),
340            None => self.columns.len(),
341        }
342    }
343}
344
/// A borrowed multi-part table name; the parts are written joined with `.`.
#[derive(Debug, Copy, Clone)]
pub struct TableName<'a>(&'a [String]);
347
348impl<'a> Textify for TableName<'a> {
349    fn name() -> &'static str {
350        "TableName"
351    }
352
353    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
354        let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
355        write!(w, "{}", ctx.separated(names.iter(), "."))
356    }
357}
358
359impl<'a> Relation<'a> {
360    fn from_read<S: Scope>(rel: &'a ReadRel, ctx: &S) -> Self {
361        let columns = read_columns(rel);
362        let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
363
364        match &rel.read_type {
365            Some(ReadType::NamedTable(table)) => {
366                let table_name = Value::TableName(table.names.iter().map(|n| Name(n)).collect());
367                Relation {
368                    name: Cow::Borrowed("Read"),
369                    arguments: Some(Arguments {
370                        positional: vec![table_name],
371                        named: vec![],
372                    }),
373                    columns,
374                    emit,
375                    addenda: AddendumLines::from_advanced_extension(
376                        ctx,
377                        rel.advanced_extension.as_ref(),
378                    ),
379                    children: vec![],
380                }
381            }
382            Some(ReadType::VirtualTable(vt)) => {
383                let positional = vt
384                    .expressions
385                    .iter()
386                    .map(|row| Value::Tuple(row.fields.iter().map(Value::Expression).collect()))
387                    .collect();
388
389                Relation {
390                    name: Cow::Borrowed("Read:Virtual"),
391                    arguments: Some(Arguments {
392                        positional,
393                        named: vec![],
394                    }),
395                    columns,
396                    emit,
397                    addenda: AddendumLines::from_advanced_extension(
398                        ctx,
399                        rel.advanced_extension.as_ref(),
400                    ),
401                    children: vec![],
402                }
403            }
404            Some(ReadType::ExtensionTable(table)) => {
405                let decoded = match table.detail.as_ref().map(AnyRef::from) {
406                    Some(detail) => ctx.extension_registry().decode_extension_table(detail),
407                    None => Err(ExtensionError::MissingDetail),
408                };
409
410                Relation {
411                    name: Cow::Borrowed("Read:Extension"),
412                    arguments: None,
413                    columns,
414                    emit,
415                    addenda: AddendumLines::extension_table(
416                        ctx,
417                        decoded,
418                        rel.advanced_extension.as_ref(),
419                    ),
420                    children: vec![],
421                }
422            }
423            other => {
424                let err = PlanError::unimplemented(
425                    "ReadRel",
426                    Some("read_type"),
427                    format!("Unsupported read type {other:?}"),
428                );
429                Relation {
430                    name: Cow::Borrowed("Read"),
431                    arguments: Some(Arguments {
432                        positional: vec![Value::Missing(err)],
433                        named: vec![],
434                    }),
435                    columns,
436                    emit,
437                    addenda: AddendumLines::from_advanced_extension(
438                        ctx,
439                        rel.advanced_extension.as_ref(),
440                    ),
441                    children: vec![],
442                }
443            }
444        }
445    }
446}
447
448fn read_columns<'a>(rel: &'a ReadRel) -> Vec<Value<'a>> {
449    match rel.base_schema {
450        Some(ref schema) => schema_to_values(schema),
451        None => {
452            let err =
453                PlanError::unimplemented("ReadRel", Some("base_schema"), "Base schema is required");
454            vec![Value::Missing(err)]
455        }
456    }
457}
458
459pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
460    rel.as_ref().and_then(|c| c.emit_kind.as_ref())
461}
462
463impl<'a> Relation<'a> {
464    /// Create a vector of values that are references to the emitted outputs of
465    /// this relation. "Emitted" here meaning the outputs of this relation after
466    /// the emit kind has been applied.
467    ///
468    /// This is useful for relations like Filter and Limit whose direct outputs
469    /// are primarily those of its children (direct here meaning before the emit
470    /// has been applied).
471    pub fn input_refs(&self) -> Vec<Value<'a>> {
472        let len = self.emitted();
473        (0..len).map(|i| Value::Reference(i as i32)).collect()
474    }
475
476    /// Convert a vector of relation references into their structured form.
477    ///
478    /// Returns a list of children (with None for ones missing), and a count of input columns.
479    pub fn convert_children<S: Scope>(
480        refs: Vec<Option<&'a Rel>>,
481        ctx: &S,
482    ) -> (Vec<Option<Relation<'a>>>, usize) {
483        let mut children = vec![];
484        let mut inputs = 0;
485
486        for maybe_rel in refs {
487            match maybe_rel {
488                Some(rel) => {
489                    let child = Relation::from_rel(rel, ctx);
490                    inputs += child.emitted();
491                    children.push(Some(child));
492                }
493                None => children.push(None),
494            }
495        }
496
497        (children, inputs)
498    }
499}
500
501impl<'a> Relation<'a> {
502    fn from_filter<S: Scope>(rel: &'a FilterRel, ctx: &S) -> Self {
503        let condition = rel
504            .condition
505            .as_ref()
506            .map(|c| Value::Expression(c.as_ref()));
507        let condition = Value::expect(condition, || {
508            PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
509        });
510        let positional = vec![condition];
511        let arguments = Some(Arguments {
512            positional,
513            named: vec![],
514        });
515        let emit = get_emit(rel.common.as_ref());
516        let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
517        let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
518
519        Relation {
520            name: Cow::Borrowed("Filter"),
521            arguments,
522            columns,
523            emit,
524            addenda: AddendumLines::from_advanced_extension(ctx, rel.advanced_extension.as_ref()),
525            children,
526        }
527    }
528
529    fn from_project<S: Scope>(rel: &'a ProjectRel, ctx: &S) -> Self {
530        let (children, input_columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
531        let mut columns: Vec<Value> = vec![];
532        for i in 0..input_columns {
533            columns.push(Value::Reference(i as i32));
534        }
535        for expr in &rel.expressions {
536            columns.push(Value::Expression(expr));
537        }
538
539        Relation {
540            name: Cow::Borrowed("Project"),
541            arguments: None,
542            columns,
543            emit: get_emit(rel.common.as_ref()),
544            addenda: AddendumLines::from_advanced_extension(ctx, rel.advanced_extension.as_ref()),
545            children,
546        }
547    }
548
549    pub fn from_rel<S: Scope>(rel: &'a Rel, ctx: &S) -> Self {
550        match rel.rel_type.as_ref() {
551            Some(RelType::Read(r)) => Relation::from_read(r, ctx),
552            Some(RelType::Filter(r)) => Relation::from_filter(r, ctx),
553            Some(RelType::Project(r)) => Relation::from_project(r, ctx),
554            Some(RelType::Aggregate(r)) => Relation::from_aggregate(r, ctx),
555            Some(RelType::Sort(r)) => Relation::from_sort(r, ctx),
556            Some(RelType::Fetch(r)) => Relation::from_fetch(r, ctx),
557            Some(RelType::Join(r)) => Relation::from_join(r, ctx),
558            Some(RelType::ExtensionLeaf(r)) => Relation::from_extension_leaf(r, ctx),
559            Some(RelType::ExtensionSingle(r)) => Relation::from_extension_single(r, ctx),
560            Some(RelType::ExtensionMulti(r)) => Relation::from_extension_multi(r, ctx),
561            _ => {
562                let name = rel.name();
563                let token = ctx.failure(FormatError::Format(PlanError::unimplemented(
564                    "Rel",
565                    Some(name),
566                    format!("{name} is not yet supported in the text format"),
567                )));
568                Relation {
569                    name: Cow::Owned(format!("{token}")),
570                    arguments: None,
571                    columns: vec![],
572                    emit: None,
573                    addenda: AddendumLines::none(),
574                    children: vec![],
575                }
576            }
577        }
578    }
579
580    fn from_extension_leaf<S: Scope>(rel: &'a ExtensionLeafRel, ctx: &S) -> Self {
581        let detail_ref = rel.detail.as_ref().map(AnyRef::from);
582        let decoded = match detail_ref {
583            Some(d) => ctx.extension_registry().decode(d),
584            None => Err(ExtensionError::MissingDetail),
585        };
586        Relation::from_extension("ExtensionLeaf", decoded, vec![], ctx)
587    }
588
589    fn from_extension_single<S: Scope>(rel: &'a ExtensionSingleRel, ctx: &S) -> Self {
590        let detail_ref = rel.detail.as_ref().map(AnyRef::from);
591        let decoded = match detail_ref {
592            Some(d) => ctx.extension_registry().decode(d),
593            None => Err(ExtensionError::MissingDetail),
594        };
595        Relation::from_extension("ExtensionSingle", decoded, vec![rel.input.as_deref()], ctx)
596    }
597
598    fn from_extension_multi<S: Scope>(rel: &'a ExtensionMultiRel, ctx: &S) -> Self {
599        let detail_ref = rel.detail.as_ref().map(AnyRef::from);
600        let decoded = match detail_ref {
601            Some(d) => ctx.extension_registry().decode(d),
602            None => Err(ExtensionError::MissingDetail),
603        };
604        let mut child_refs: Vec<Option<&'a Rel>> = vec![];
605        for input in &rel.inputs {
606            child_refs.push(Some(input));
607        }
608        Relation::from_extension("ExtensionMulti", decoded, child_refs, ctx)
609    }
610
611    fn from_extension<S: Scope>(
612        ext_type: &'static str,
613        decoded: Result<(String, ExtensionArgs), ExtensionError>,
614        child_refs: Vec<Option<&'a Rel>>,
615        ctx: &S,
616    ) -> Self {
617        match decoded {
618            Ok((name, args)) => {
619                let (children, _) = Relation::convert_children(child_refs, ctx);
620                let mut positional = vec![];
621                for value in args.positional {
622                    positional.push(Value::ExtValue(value));
623                }
624                let mut named = vec![];
625                for (key, value) in args.named {
626                    named.push(NamedArg {
627                        name: Cow::Owned(key),
628                        value: Value::ExtValue(value),
629                    });
630                }
631                let mut columns = vec![];
632                for col in args.output_columns {
633                    columns.push(Value::ExtColumn(col));
634                }
635                Relation {
636                    name: Cow::Owned(format!("{}:{}", ext_type, name)),
637                    arguments: Some(Arguments { positional, named }),
638                    columns,
639                    emit: None,
640                    // Extension relations use `detail` rather than
641                    // `advanced_extension`; the field does not exist on these
642                    // proto types.
643                    addenda: AddendumLines::none(),
644                    children,
645                }
646            }
647            Err(error) => {
648                let (children, _) = Relation::convert_children(child_refs, ctx);
649                Relation {
650                    name: Cow::Borrowed(ext_type),
651                    arguments: None,
652                    columns: vec![Value::Missing(PlanError::invalid(
653                        "extension",
654                        None::<&str>,
655                        error.to_string(),
656                    ))],
657                    emit: None,
658                    addenda: AddendumLines::none(),
659                    children,
660                }
661            }
662        }
663    }
664
665    /// Convert an AggregateRel to a Relation for textification.
666    ///
667    /// The conversion follows this logic:
668    /// 1. Arguments: Group-by expressions (as Value::Expression)
669    /// 2. Columns: All possible outputs in order:
670    ///    - First: Group-by field references (Value::Reference)
671    ///    - Then: Aggregate function measures (Value::AggregateFunction)
672    /// 3. Emit: Uses the relation's emit mapping to select which outputs to display
673    /// 4. Children: The input relation
674    fn from_aggregate<S: Scope>(rel: &'a AggregateRel, ctx: &S) -> Self {
675        let mut grouping_sets: Vec<Vec<Value>> = vec![]; // the Groupings in the Aggregate
676        let expression_list: Vec<Value>; // grouping_expressions defined on Aggregate
677
678        // if rel.grouping_expressions is empty, the deprecated rel.groupings.grouping_expressions might be set
679        // If *both* the deprecated `rel.groupings.grouping_expressions` and `rel.grouping_expressions` are
680        // set, then we silently ignore the deprecated one.
681        #[allow(deprecated)]
682        if rel.grouping_expressions.is_empty()
683            && !rel.groupings.is_empty()
684            && !rel.groupings[0].grouping_expressions.is_empty()
685        {
686            (expression_list, grouping_sets) = Relation::get_grouping_sets(rel);
687        } else {
688            expression_list = rel
689                .grouping_expressions
690                .iter()
691                .map(Value::Expression)
692                .collect::<Vec<_>>(); // already a list of the unique expressions
693            for group in &rel.groupings {
694                let mut grouping_set: Vec<Value> = vec![];
695                for i in &group.expression_references {
696                    grouping_set.push(Value::Reference(*i as i32));
697                }
698                grouping_sets.push(grouping_set);
699            }
700            // no defined groupings means there is global group by
701            if rel.groupings.is_empty() {
702                grouping_sets.push(vec![]);
703            }
704        }
705
706        let is_single = grouping_sets.len() == 1;
707        let mut positional: Vec<Value> = vec![];
708        for g in grouping_sets {
709            if g.is_empty() {
710                positional.push(Value::EmptyGroup);
711            } else if is_single {
712                // Single non-empty grouping set: spread expressions directly without parens
713                positional.extend(g);
714            } else {
715                positional.push(Value::Tuple(g));
716            }
717        }
718
719        // adding the grouping_sets as a list of Arguments to Aggregate Rel
720        let arguments = Some(Arguments {
721            positional,
722            named: vec![],
723        });
724
725        // The columns are the direct outputs of this relation (before emit)
726        let mut all_outputs: Vec<Value> = expression_list;
727
728        // Then, add all measures (aggregate functions)
729        // These are indexed after the group-by fields
730        for m in &rel.measures {
731            if let Some(agg_fn) = m.measure.as_ref() {
732                all_outputs.push(Value::AggregateFunction(agg_fn));
733            }
734        }
735        let emit = get_emit(rel.common.as_ref());
736        let (children, _) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
737
738        Relation {
739            name: Cow::Borrowed("Aggregate"),
740            arguments,
741            columns: all_outputs,
742            emit,
743            addenda: AddendumLines::from_advanced_extension(ctx, rel.advanced_extension.as_ref()),
744            children,
745        }
746    }
747
748    fn get_grouping_sets(rel: &'a AggregateRel) -> (Vec<Value<'a>>, Vec<Vec<Value<'a>>>) {
749        let mut grouping_sets: Vec<Vec<Value>> = vec![];
750        let mut expression_list: Vec<Value> = Vec::new();
751
752        // groupings might have the same expressions in their set so we use a map to get unique expressions
753        let mut expression_index_map = HashMap::new();
754        let mut i: i32 = 0; // index for the unique expression in the grouping_expressions list
755
756        for group in &rel.groupings {
757            let mut grouping_set: Vec<Value> = vec![];
758            #[allow(deprecated)]
759            for exp in &group.grouping_expressions {
760                // TODO: use a better key here than encoding to bytes.
761                // Ideally, substrait-rs would support `PartialEq` and `Hash`,
762                // but as there isn't an easy way to do that now, we'll skip.
763                let key = exp.encode_to_vec();
764                expression_index_map.entry(key.clone()).or_insert_with(|| {
765                    let value = Value::Expression(exp);
766                    expression_list.push(value); // new unique expression found
767                    // mapping the byte encoded expression to its index in the group_expression list
768                    let index = i;
769                    i += 1;
770                    index // is expression returned by this closure and inserted into map
771                });
772                grouping_set.push(Value::Reference(expression_index_map[&key]));
773            }
774            grouping_sets.push(grouping_set);
775        }
776        (expression_list, grouping_sets)
777    }
778}
779
780impl Textify for RelRoot {
781    fn name() -> &'static str {
782        "RelRoot"
783    }
784
785    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
786        let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
787
788        write!(
789            w,
790            "{}Root[{}]",
791            ctx.indent(),
792            ctx.separated(names.iter(), ", ")
793        )?;
794        let child_scope = ctx.push_indent();
795        for child in self.input.iter() {
796            writeln!(w)?;
797            child.textify(&child_scope, w)?;
798        }
799
800        Ok(())
801    }
802}
803
impl Textify for PlanRelType {
    fn name() -> &'static str {
        "PlanRelType"
    }

    /// Writes the wrapped plan relation by delegating to the `textify` of
    /// whichever variant is held (`Rel` or `Root`).
    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
        match self {
            PlanRelType::Rel(rel) => rel.textify(ctx, w),
            PlanRelType::Root(root) => root.textify(ctx, w),
        }
    }
}
816
817impl Textify for PlanRel {
818    fn name() -> &'static str {
819        "PlanRel"
820    }
821
822    /// Write the relation as a string. Inputs are ignored - those are handled
823    /// separately.
824    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
825        write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
826    }
827}
828
829impl<'a> Relation<'a> {
830    fn from_sort<S: Scope>(rel: &'a SortRel, ctx: &S) -> Self {
831        let (children, input_columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
832        let mut positional = vec![];
833        for sort_field in &rel.sorts {
834            positional.push(Value::from(sort_field));
835        }
836        let arguments = Some(Arguments {
837            positional,
838            named: vec![],
839        });
840        // The columns are the direct outputs of this relation (before emit)
841        let mut col_values = vec![];
842        for i in 0..input_columns {
843            col_values.push(Value::Reference(i as i32));
844        }
845        let emit = get_emit(rel.common.as_ref());
846        Relation {
847            name: Cow::Borrowed("Sort"),
848            arguments,
849            columns: col_values,
850            emit,
851            addenda: AddendumLines::from_advanced_extension(ctx, rel.advanced_extension.as_ref()),
852            children,
853        }
854    }
855
856    fn from_fetch<S: Scope>(rel: &'a FetchRel, ctx: &S) -> Self {
857        let (children, input_columns) = Relation::convert_children(vec![rel.input.as_deref()], ctx);
858        let mut named_args: Vec<NamedArg> = vec![];
859        match &rel.count_mode {
860            Some(CountMode::CountExpr(expr)) => {
861                named_args.push(NamedArg {
862                    name: Cow::Borrowed("limit"),
863                    value: Value::Expression(expr),
864                });
865            }
866            #[allow(deprecated)]
867            Some(CountMode::Count(val)) => {
868                named_args.push(NamedArg {
869                    name: Cow::Borrowed("limit"),
870                    value: Value::Integer(*val),
871                });
872            }
873            None => {}
874        }
875        if let Some(offset) = &rel.offset_mode {
876            match offset {
877                substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
878                    named_args.push(NamedArg {
879                        name: Cow::Borrowed("offset"),
880                        value: Value::Expression(expr),
881                    });
882                }
883                #[allow(deprecated)]
884                substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
885                    named_args.push(NamedArg {
886                        name: Cow::Borrowed("offset"),
887                        value: Value::Integer(*val),
888                    });
889                }
890            }
891        }
892
893        let emit = get_emit(rel.common.as_ref());
894        // Fetch is passthrough — direct output is all input columns.
895        let columns: Vec<Value> = (0..input_columns)
896            .map(|i| Value::Reference(i as i32))
897            .collect();
898        Relation {
899            name: Cow::Borrowed("Fetch"),
900            arguments: Some(Arguments {
901                positional: vec![],
902                named: named_args,
903            }),
904            columns,
905            emit,
906            addenda: AddendumLines::from_advanced_extension(ctx, rel.advanced_extension.as_ref()),
907            children,
908        }
909    }
910}
911
912fn join_output_columns(
913    join_type: join_rel::JoinType,
914    left_columns: usize,
915    right_columns: usize,
916) -> Vec<Value<'static>> {
917    let total_columns = match join_type {
918        // Inner, Left, Right, Outer joins output columns from both sides
919        join_rel::JoinType::Inner
920        | join_rel::JoinType::Left
921        | join_rel::JoinType::Right
922        | join_rel::JoinType::Outer => left_columns + right_columns,
923
924        // Left semi/anti joins only output columns from the left side
925        join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
926
927        // Right semi/anti joins output columns from the right side
928        join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
929
930        // Single joins behave like semi joins
931        join_rel::JoinType::LeftSingle => left_columns,
932        join_rel::JoinType::RightSingle => right_columns,
933
934        // Mark joins output base columns plus one mark column
935        join_rel::JoinType::LeftMark => left_columns + 1,
936        join_rel::JoinType::RightMark => right_columns + 1,
937
938        // Unspecified - fallback to all columns
939        join_rel::JoinType::Unspecified => left_columns + right_columns,
940    };
941
942    // Output is always a contiguous range starting from $0
943    (0..total_columns)
944        .map(|i| Value::Reference(i as i32))
945        .collect()
946}
947
948impl<'a> Relation<'a> {
949    fn from_join<S: Scope>(rel: &'a JoinRel, ctx: &S) -> Self {
950        let (children, _total_columns) =
951            Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()], ctx);
952
953        // convert_children should preserve input vector length
954        assert_eq!(
955            children.len(),
956            2,
957            "convert_children should return same number of elements as input"
958        );
959
960        // Calculate left and right column counts separately
961        let left_columns = match &children[0] {
962            Some(child) => child.emitted(),
963            None => 0,
964        };
965        let right_columns = match &children[1] {
966            Some(child) => child.emitted(),
967            None => 0,
968        };
969
970        // Convert join type from protobuf i32 to enum value
971        // JoinType is stored as i32 in protobuf, convert to typed enum for processing
972        let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
973            Ok(join_type) => {
974                let join_type_value = match join_type.as_enum_str() {
975                    Ok(s) => Value::Enum(s),
976                    Err(e) => Value::Missing(e),
977                };
978                (join_type, join_type_value)
979            }
980            Err(_) => {
981                // Use Unspecified for the join_type but create an error for the join_type_value
982                let join_type_error = Value::Missing(PlanError::invalid(
983                    "JoinRel",
984                    Some("type"),
985                    format!("Unknown join type: {}", rel.r#type),
986                ));
987                (join_rel::JoinType::Unspecified, join_type_error)
988            }
989        };
990
991        // Join condition
992        let condition = rel
993            .expression
994            .as_ref()
995            .map(|c| Value::Expression(c.as_ref()));
996        let condition = Value::expect(condition, || {
997            PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
998        });
999
1000        // TODO: Add support for post_join_filter when grammar is extended
1001        // Currently post_join_filter is not supported in the text format
1002        // grammar
1003        let positional = vec![join_type_value, condition];
1004        let arguments = Some(Arguments {
1005            positional,
1006            named: vec![],
1007        });
1008
1009        let emit = get_emit(rel.common.as_ref());
1010        let columns = join_output_columns(join_type, left_columns, right_columns);
1011
1012        Relation {
1013            name: Cow::Borrowed("Join"),
1014            arguments,
1015            columns,
1016            emit,
1017            addenda: AddendumLines::from_advanced_extension(ctx, rel.advanced_extension.as_ref()),
1018            children,
1019        }
1020    }
1021}
1022
1023impl<'a> From<&'a SortField> for Value<'a> {
1024    fn from(sf: &'a SortField) -> Self {
1025        let field = match &sf.expr {
1026            Some(expr) => match &expr.rex_type {
1027                Some(substrait::proto::expression::RexType::Selection(fref)) => {
1028                    if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
1029                        if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
1030                            Value::Reference(sf.field)
1031                        } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
1032                    } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
1033                }
1034                _ => Value::Missing(PlanError::unimplemented(
1035                    "SortField",
1036                    Some("expr"),
1037                    "Not a selection",
1038                )),
1039            },
1040            None => Value::Missing(PlanError::unimplemented(
1041                "SortField",
1042                Some("expr"),
1043                "Missing expr",
1044            )),
1045        };
1046        let direction = match &sf.sort_kind {
1047            Some(kind) => Value::from(kind),
1048            None => Value::Missing(PlanError::invalid(
1049                "SortKind",
1050                Some(Cow::Borrowed("sort_kind")),
1051                "Missing sort_kind",
1052            )),
1053        };
1054        Value::Tuple(vec![field, direction])
1055    }
1056}
1057
1058impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
1059    fn from(enum_val: &'a T) -> Self {
1060        match enum_val.as_enum_str() {
1061            Ok(s) => Value::Enum(s),
1062            Err(e) => Value::Missing(e),
1063        }
1064    }
1065}
1066
1067impl ValueEnum for SortKind {
1068    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
1069        let d = match self {
1070            &SortKind::Direction(d) => SortDirection::try_from(d),
1071            SortKind::ComparisonFunctionReference(f) => {
1072                return Err(PlanError::invalid(
1073                    "SortKind",
1074                    Some(Cow::Owned(format!("function reference{f}"))),
1075                    "SortKind::ComparisonFunctionReference unimplemented",
1076                ));
1077            }
1078        };
1079        let s = match d {
1080            Err(UnknownEnumValue(d)) => {
1081                return Err(PlanError::invalid(
1082                    "SortKind",
1083                    Some(Cow::Owned(format!("unknown variant: {d:?}"))),
1084                    "Unknown SortDirection",
1085                ));
1086            }
1087            Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
1088            Ok(SortDirection::AscNullsLast) => "AscNullsLast",
1089            Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
1090            Ok(SortDirection::DescNullsLast) => "DescNullsLast",
1091            Ok(SortDirection::Clustered) => "Clustered",
1092            Ok(SortDirection::Unspecified) => {
1093                return Err(PlanError::invalid(
1094                    "SortKind",
1095                    Option::<Cow<str>>::None,
1096                    "Unspecified SortDirection",
1097                ));
1098            }
1099        };
1100        Ok(Cow::Borrowed(s))
1101    }
1102}
1103
1104impl ValueEnum for join_rel::JoinType {
1105    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
1106        let s = match self {
1107            join_rel::JoinType::Unspecified => {
1108                return Err(PlanError::invalid(
1109                    "JoinType",
1110                    Option::<Cow<str>>::None,
1111                    "Unspecified JoinType",
1112                ));
1113            }
1114            join_rel::JoinType::Inner => "Inner",
1115            join_rel::JoinType::Outer => "Outer",
1116            join_rel::JoinType::Left => "Left",
1117            join_rel::JoinType::Right => "Right",
1118            join_rel::JoinType::LeftSemi => "LeftSemi",
1119            join_rel::JoinType::RightSemi => "RightSemi",
1120            join_rel::JoinType::LeftAnti => "LeftAnti",
1121            join_rel::JoinType::RightAnti => "RightAnti",
1122            join_rel::JoinType::LeftSingle => "LeftSingle",
1123            join_rel::JoinType::RightSingle => "RightSingle",
1124            join_rel::JoinType::LeftMark => "LeftMark",
1125            join_rel::JoinType::RightMark => "RightMark",
1126        };
1127        Ok(Cow::Borrowed(s))
1128    }
1129}
1130
1131impl<'a> Textify for NamedArg<'a> {
1132    fn name() -> &'static str {
1133        "NamedArg"
1134    }
1135    fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
1136        write!(w, "{}=", self.name)?;
1137        self.value.textify(ctx, w)
1138    }
1139}
1140
1141#[cfg(test)]
1142mod tests {
1143    use substrait::proto::aggregate_rel::Grouping;
1144    use substrait::proto::expression::literal::LiteralType;
1145    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
1146    use substrait::proto::function_argument::ArgType;
1147    use substrait::proto::read_rel::{NamedTable, ReadType};
1148    use substrait::proto::rel_common::{Direct, Emit};
1149    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
1150    use substrait::proto::{
1151        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
1152    };
1153
1154    use super::*;
1155    use crate::fixtures::TestContext;
1156    use crate::parser::expressions::FieldIndex;
1157
1158    #[test]
1159    fn test_read_rel() {
1160        let ctx = TestContext::new();
1161
1162        // Create a simple ReadRel with a NamedStruct schema
1163        let read_rel = ReadRel {
1164            common: None,
1165            base_schema: Some(NamedStruct {
1166                names: vec!["col1".into(), "column 2".into()],
1167                r#struct: Some(Struct {
1168                    type_variation_reference: 0,
1169                    types: vec![
1170                        Type {
1171                            kind: Some(Kind::I32(ptype::I32 {
1172                                type_variation_reference: 0,
1173                                nullability: Nullability::Nullable as i32,
1174                            })),
1175                        },
1176                        Type {
1177                            kind: Some(Kind::String(ptype::String {
1178                                type_variation_reference: 0,
1179                                nullability: Nullability::Nullable as i32,
1180                            })),
1181                        },
1182                    ],
1183                    nullability: Nullability::Nullable as i32,
1184                }),
1185            }),
1186            filter: None,
1187            best_effort_filter: None,
1188            projection: None,
1189            advanced_extension: None,
1190            read_type: Some(ReadType::NamedTable(NamedTable {
1191                names: vec!["some_db".into(), "test_table".into()],
1192                advanced_extension: None,
1193            })),
1194        };
1195
1196        let rel = Rel {
1197            rel_type: Some(RelType::Read(Box::new(read_rel))),
1198        };
1199        let (result, errors) = ctx.textify(&rel);
1200        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1201        assert_eq!(
1202            result,
1203            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
1204        );
1205    }
1206
1207    #[test]
1208    fn test_filter_rel() {
1209        let ctx = TestContext::new()
1210            .with_urn(1, "test_urn")
1211            .with_function(1, 10, "gt");
1212
1213        // Create a simple FilterRel with a ReadRel input and a filter expression
1214        let read_rel = ReadRel {
1215            common: None,
1216            base_schema: Some(NamedStruct {
1217                names: vec!["col1".into(), "col2".into()],
1218                r#struct: Some(Struct {
1219                    type_variation_reference: 0,
1220                    types: vec![
1221                        Type {
1222                            kind: Some(Kind::I32(ptype::I32 {
1223                                type_variation_reference: 0,
1224                                nullability: Nullability::Nullable as i32,
1225                            })),
1226                        },
1227                        Type {
1228                            kind: Some(Kind::I32(ptype::I32 {
1229                                type_variation_reference: 0,
1230                                nullability: Nullability::Nullable as i32,
1231                            })),
1232                        },
1233                    ],
1234                    nullability: Nullability::Nullable as i32,
1235                }),
1236            }),
1237            filter: None,
1238            best_effort_filter: None,
1239            projection: None,
1240            advanced_extension: None,
1241            read_type: Some(ReadType::NamedTable(NamedTable {
1242                names: vec!["test_table".into()],
1243                advanced_extension: None,
1244            })),
1245        };
1246
1247        // Create a filter expression: col1 > 10
1248        let filter_expr = Expression {
1249            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
1250                function_reference: 10, // gt function
1251                arguments: vec![
1252                    FunctionArgument {
1253                        arg_type: Some(ArgType::Value(Reference(0).into())),
1254                    },
1255                    FunctionArgument {
1256                        arg_type: Some(ArgType::Value(Expression {
1257                            rex_type: Some(RexType::Literal(Literal {
1258                                literal_type: Some(LiteralType::I32(10)),
1259                                nullable: false,
1260                                type_variation_reference: 0,
1261                            })),
1262                        })),
1263                    },
1264                ],
1265                options: vec![],
1266                output_type: None,
1267                #[allow(deprecated)]
1268                args: vec![],
1269            })),
1270        };
1271
1272        let filter_rel = FilterRel {
1273            common: None,
1274            input: Some(Box::new(Rel {
1275                rel_type: Some(RelType::Read(Box::new(read_rel))),
1276            })),
1277            condition: Some(Box::new(filter_expr)),
1278            advanced_extension: None,
1279        };
1280
1281        let rel = Rel {
1282            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
1283        };
1284
1285        let (result, errors) = ctx.textify(&rel);
1286        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1287        let expected = r#"
1288Filter[gt($0, 10:i32) => $0, $1]
1289  Read[test_table => col1:i32?, col2:i32?]"#
1290            .trim_start();
1291        assert_eq!(result, expected);
1292    }
1293
1294    #[test]
1295    fn test_aggregate_function_textify() {
1296        let ctx = TestContext::new()
1297        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1298        .with_function(1, 10, "sum")
1299        .with_function(1, 11, "count");
1300
1301        // Create a simple AggregateFunction
1302        let agg_fn = get_aggregate_func(10, 1);
1303
1304        let value = Value::AggregateFunction(&agg_fn);
1305        let (result, errors) = ctx.textify(&value);
1306
1307        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1308        assert_eq!(result, "sum($1)");
1309    }
1310
1311    #[test]
1312    fn test_aggregate_relation_textify() {
1313        let ctx = TestContext::new()
1314        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1315        .with_function(1, 10, "sum")
1316        .with_function(1, 11, "count");
1317
1318        // Create a simple AggregateRel
1319        let agg_fn1 = get_aggregate_func(10, 1);
1320        let agg_fn2 = get_aggregate_func(11, 1);
1321
1322        let grouping_expressions = vec![Expression {
1323            rex_type: Some(RexType::Selection(Box::new(
1324                FieldIndex(0).to_field_reference(),
1325            ))),
1326        }];
1327
1328        let measures = vec![
1329            aggregate_rel::Measure {
1330                measure: Some(agg_fn1),
1331                filter: None,
1332            },
1333            aggregate_rel::Measure {
1334                measure: Some(agg_fn2),
1335                filter: None,
1336            },
1337        ];
1338
1339        let common = Some(RelCommon {
1340            emit_kind: Some(EmitKind::Emit(Emit {
1341                output_mapping: vec![1, 2], // measures only
1342            })),
1343            ..Default::default()
1344        });
1345
1346        let aggregate_rel = create_aggregate_rel(grouping_expressions, vec![], measures, common);
1347
1348        let rel = Rel {
1349            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1350        };
1351        let (result, errors) = ctx.textify(&rel);
1352
1353        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1354        // Expected: Aggregate[_ => sum($1), count($1)] we chose to emit only measures
1355        assert!(result.contains("Aggregate[_ => sum($1), count($1)]"));
1356    }
1357
1358    #[test]
1359    fn test_multiple_groupings_on_aggregate_deprecated() {
1360        // Protobuf plan that uses AggregateRel.groupings with deprecated
1361        // grouping_expressions, leaving AggregateRel.grouping_expressions empty.
1362        let ctx = TestContext::new()
1363        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1364        .with_function(1, 11, "count");
1365
1366        let grouping_expr_0 = create_exp(0);
1367        let grouping_expr_1 = create_exp(1);
1368
1369        let grouping_sets = vec![
1370            aggregate_rel::Grouping {
1371                #[allow(deprecated)]
1372                grouping_expressions: vec![grouping_expr_0.clone()],
1373                expression_references: vec![],
1374            },
1375            aggregate_rel::Grouping {
1376                #[allow(deprecated)]
1377                grouping_expressions: vec![grouping_expr_0.clone(), grouping_expr_1.clone()],
1378                expression_references: vec![],
1379            },
1380        ];
1381
1382        let aggregate_rel = create_aggregate_rel(vec![], grouping_sets, vec![], None);
1383
1384        let rel = Rel {
1385            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1386        };
1387        let (result, errors) = ctx.textify(&rel);
1388
1389        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1390        assert!(result.contains("Aggregate[($0), ($0, $1) => $0, $1]"));
1391    }
1392
1393    #[test]
1394    fn test_multiple_groupings_with_measure_deprecated() {
1395        // Protobuf plan that uses AggregateRel.groupings with deprecated
1396        // grouping_expressions, leaving AggregateRel.grouping_expressions empty.
1397        let ctx = TestContext::new()
1398        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1399        .with_function(1, 11, "count");
1400
1401        let agg_fn1 = get_aggregate_func(11, 2);
1402
1403        let grouping_expr_0 = create_exp(0);
1404        let grouping_expr_1 = create_exp(1);
1405
1406        let grouping_sets = vec![
1407            aggregate_rel::Grouping {
1408                #[allow(deprecated)]
1409                grouping_expressions: vec![grouping_expr_0.clone()],
1410                expression_references: vec![],
1411            },
1412            aggregate_rel::Grouping {
1413                #[allow(deprecated)]
1414                grouping_expressions: vec![grouping_expr_0.clone(), grouping_expr_1.clone()],
1415                expression_references: vec![],
1416            },
1417        ];
1418
1419        let measures = vec![aggregate_rel::Measure {
1420            measure: Some(agg_fn1),
1421            filter: None,
1422        }];
1423
1424        let aggregate_rel = create_aggregate_rel(vec![], grouping_sets, measures, None);
1425
1426        let rel = Rel {
1427            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1428        };
1429        let (result, errors) = ctx.textify(&rel);
1430
1431        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1432        assert!(result.contains("($0), ($0, $1) => $0, $1, count($2)"));
1433    }
1434
1435    #[test]
1436    fn test_multiple_groupings_on_aggregate() {
1437        let ctx = TestContext::new()
1438        .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
1439        .with_function(1, 11, "count");
1440
1441        let agg_fn2 = get_aggregate_func(11, 2);
1442
1443        let grouping_expressions = vec![
1444            Expression {
1445                rex_type: Some(RexType::Selection(Box::new(
1446                    FieldIndex(0).to_field_reference(),
1447                ))),
1448            },
1449            Expression {
1450                rex_type: Some(RexType::Selection(Box::new(
1451                    FieldIndex(1).to_field_reference(),
1452                ))),
1453            },
1454        ];
1455
1456        let grouping_sets = vec![
1457            Grouping {
1458                #[allow(deprecated)]
1459                grouping_expressions: vec![],
1460                expression_references: vec![0, 1],
1461            },
1462            Grouping {
1463                #[allow(deprecated)]
1464                grouping_expressions: vec![],
1465                expression_references: vec![0, 1],
1466            },
1467            Grouping {
1468                #[allow(deprecated)]
1469                grouping_expressions: vec![],
1470                expression_references: vec![1],
1471            },
1472            Grouping {
1473                #[allow(deprecated)]
1474                grouping_expressions: vec![],
1475                expression_references: vec![1, 1],
1476            },
1477            Grouping {
1478                #[allow(deprecated)]
1479                grouping_expressions: vec![],
1480                expression_references: vec![],
1481            },
1482        ];
1483
1484        let measures = vec![aggregate_rel::Measure {
1485            measure: Some(agg_fn2),
1486            filter: None,
1487        }];
1488
1489        let aggregate_rel =
1490            create_aggregate_rel(grouping_expressions, grouping_sets, measures, None);
1491
1492        let rel = Rel {
1493            rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
1494        };
1495        let (result, errors) = ctx.textify(&rel);
1496
1497        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1498        assert!(
1499            result
1500                .contains("Aggregate[($0, $1), ($0, $1), ($1), ($1, $1), _ => $0, $1, count($2)]")
1501        );
1502    }
1503
1504    #[test]
1505    fn test_arguments_textify_positional_only() {
1506        let ctx = TestContext::new();
1507        let args = Arguments {
1508            positional: vec![Value::Integer(42), Value::Integer(7)],
1509            named: vec![],
1510        };
1511        let (result, errors) = ctx.textify(&args);
1512        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1513        assert_eq!(result, "42, 7");
1514    }
1515
1516    #[test]
1517    fn test_arguments_textify_named_only() {
1518        let ctx = TestContext::new();
1519        let args = Arguments {
1520            positional: vec![],
1521            named: vec![
1522                NamedArg {
1523                    name: Cow::Borrowed("limit"),
1524                    value: Value::Integer(10),
1525                },
1526                NamedArg {
1527                    name: Cow::Borrowed("offset"),
1528                    value: Value::Integer(5),
1529                },
1530            ],
1531        };
1532        let (result, errors) = ctx.textify(&args);
1533        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1534        assert_eq!(result, "limit=10, offset=5");
1535    }
1536
1537    #[test]
1538    fn test_join_relation_unknown_type() {
1539        let ctx = TestContext::new();
1540
1541        // Create a join with an unknown/invalid type
1542        let join_rel = JoinRel {
1543            left: Some(Box::new(Rel {
1544                rel_type: Some(RelType::Read(Box::default())),
1545            })),
1546            right: Some(Box::new(Rel {
1547                rel_type: Some(RelType::Read(Box::default())),
1548            })),
1549            expression: Some(Box::new(Expression::default())),
1550            r#type: 999, // Invalid join type
1551            common: None,
1552            post_join_filter: None,
1553            advanced_extension: None,
1554        };
1555
1556        let rel = Rel {
1557            rel_type: Some(RelType::Join(Box::new(join_rel))),
1558        };
1559        let (result, errors) = ctx.textify(&rel);
1560
1561        // Should contain error for unknown join type but still show condition and columns
1562        assert!(!errors.is_empty(), "Expected errors for unknown join type");
1563        assert!(
1564            result.contains("!{JoinRel}"),
1565            "Expected error token for unknown join type"
1566        );
1567        assert!(
1568            result.contains("Join["),
1569            "Expected Join relation to be formatted"
1570        );
1571    }
1572
1573    #[test]
1574    fn test_arguments_textify_both() {
1575        let ctx = TestContext::new();
1576        let args = Arguments {
1577            positional: vec![Value::Integer(1)],
1578            named: vec![NamedArg {
1579                name: "foo".into(),
1580                value: Value::Integer(2),
1581            }],
1582        };
1583        let (result, errors) = ctx.textify(&args);
1584        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1585        assert_eq!(result, "1, foo=2");
1586    }
1587
1588    #[test]
1589    fn test_arguments_textify_empty() {
1590        let ctx = TestContext::new();
1591        let args = Arguments {
1592            positional: vec![],
1593            named: vec![],
1594        };
1595        let (result, errors) = ctx.textify(&args);
1596        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1597        assert_eq!(result, "_");
1598    }
1599
1600    #[test]
1601    fn test_named_arg_textify_error_token() {
1602        let ctx = TestContext::new();
1603        let named_arg = NamedArg {
1604            name: "foo".into(),
1605            value: Value::Missing(PlanError::invalid(
1606                "my_enum",
1607                Some(Cow::Borrowed("my_enum")),
1608                Cow::Borrowed("my_enum"),
1609            )),
1610        };
1611        let (result, errors) = ctx.textify(&named_arg);
1612        // Should show !{my_enum} in the output
1613        assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
1614        // Should also accumulate an error
1615        assert!(!errors.is_empty(), "Expected error for error token");
1616    }
1617
1618    #[test]
1619    fn test_join_type_enum_textify() {
1620        // Test that JoinType enum values convert correctly to their string representation
1621        assert_eq!(join_rel::JoinType::Inner.as_enum_str().unwrap(), "Inner");
1622        assert_eq!(join_rel::JoinType::Left.as_enum_str().unwrap(), "Left");
1623        assert_eq!(
1624            join_rel::JoinType::LeftSemi.as_enum_str().unwrap(),
1625            "LeftSemi"
1626        );
1627        assert_eq!(
1628            join_rel::JoinType::LeftAnti.as_enum_str().unwrap(),
1629            "LeftAnti"
1630        );
1631    }
1632
1633    #[test]
1634    fn test_join_output_columns() {
1635        // Test Inner join - outputs all columns from both sides
1636        let inner_cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
1637        assert_eq!(inner_cols.len(), 5); // 2 + 3 = 5 columns
1638        assert!(matches!(inner_cols[0], Value::Reference(0)));
1639        assert!(matches!(inner_cols[4], Value::Reference(4)));
1640
1641        // Test LeftSemi join - outputs only left columns
1642        let left_semi_cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
1643        assert_eq!(left_semi_cols.len(), 2); // Only left columns
1644        assert!(matches!(left_semi_cols[0], Value::Reference(0)));
1645        assert!(matches!(left_semi_cols[1], Value::Reference(1)));
1646
1647        // Test RightSemi join - outputs right columns as contiguous range starting from $0
1648        let right_semi_cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
1649        assert_eq!(right_semi_cols.len(), 3); // Only right columns
1650        assert!(matches!(right_semi_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1651        assert!(matches!(right_semi_cols[1], Value::Reference(1)));
1652        assert!(matches!(right_semi_cols[2], Value::Reference(2))); // Last right column
1653
1654        // Test LeftMark join - outputs left columns plus a mark column as contiguous range
1655        let left_mark_cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
1656        assert_eq!(left_mark_cols.len(), 3); // 2 left + 1 mark
1657        assert!(matches!(left_mark_cols[0], Value::Reference(0)));
1658        assert!(matches!(left_mark_cols[1], Value::Reference(1)));
1659        assert!(matches!(left_mark_cols[2], Value::Reference(2))); // Mark column at contiguous position
1660
1661        // Test RightMark join - outputs right columns plus a mark column as contiguous range
1662        let right_mark_cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
1663        assert_eq!(right_mark_cols.len(), 4); // 3 right + 1 mark
1664        assert!(matches!(right_mark_cols[0], Value::Reference(0))); // Contiguous range starts at $0
1665        assert!(matches!(right_mark_cols[1], Value::Reference(1)));
1666        assert!(matches!(right_mark_cols[2], Value::Reference(2))); // Last right column
1667        assert!(matches!(right_mark_cols[3], Value::Reference(3))); // Mark column at contiguous position
1668    }
1669
1670    fn get_aggregate_func(func_ref: u32, column_ind: i32) -> AggregateFunction {
1671        AggregateFunction {
1672            function_reference: func_ref,
1673            arguments: vec![FunctionArgument {
1674                arg_type: Some(ArgType::Value(Expression {
1675                    rex_type: Some(RexType::Selection(Box::new(
1676                        FieldIndex(column_ind).to_field_reference(),
1677                    ))),
1678                })),
1679            }],
1680            options: vec![],
1681            output_type: None,
1682            invocation: 0,
1683            phase: 0,
1684            sorts: vec![],
1685            #[allow(deprecated)]
1686            args: vec![],
1687        }
1688    }
1689
1690    fn create_aggregate_rel(
1691        grouping_expressions: Vec<Expression>,
1692        grouping_sets: Vec<Grouping>,
1693        measures: Vec<aggregate_rel::Measure>,
1694        common: Option<RelCommon>,
1695    ) -> AggregateRel {
1696        let common = common.or_else(|| {
1697            Some(RelCommon {
1698                emit_kind: Some(EmitKind::Direct(Direct {})),
1699                ..Default::default()
1700            })
1701        });
1702        AggregateRel {
1703            input: Some(Box::new(Rel {
1704                rel_type: Some(RelType::Read(Box::new(ReadRel {
1705                    common: None,
1706                    base_schema: Some(get_basic_schema()),
1707                    filter: None,
1708                    best_effort_filter: None,
1709                    projection: None,
1710                    advanced_extension: None,
1711                    read_type: Some(ReadType::NamedTable(NamedTable {
1712                        names: vec!["orders".into()],
1713                        advanced_extension: None,
1714                    })),
1715                }))),
1716            })),
1717            grouping_expressions,
1718            groupings: grouping_sets,
1719            measures,
1720            common,
1721            advanced_extension: None,
1722        }
1723    }
1724
1725    fn get_basic_schema() -> NamedStruct {
1726        NamedStruct {
1727            names: vec!["category".into(), "amount".into(), "value".into()],
1728            r#struct: Some(Struct {
1729                type_variation_reference: 0,
1730                types: vec![
1731                    Type {
1732                        kind: Some(Kind::String(ptype::String {
1733                            type_variation_reference: 0,
1734                            nullability: Nullability::Nullable as i32,
1735                        })),
1736                    },
1737                    Type {
1738                        kind: Some(Kind::Fp64(ptype::Fp64 {
1739                            type_variation_reference: 0,
1740                            nullability: Nullability::Nullable as i32,
1741                        })),
1742                    },
1743                    Type {
1744                        kind: Some(Kind::I32(ptype::I32 {
1745                            type_variation_reference: 0,
1746                            nullability: Nullability::Nullable as i32,
1747                        })),
1748                    },
1749                ],
1750                nullability: Nullability::Nullable as i32,
1751            }),
1752        }
1753    }
1754
1755    fn create_exp(column_ind: i32) -> Expression {
1756        Expression {
1757            rex_type: Some(RexType::Selection(Box::new(
1758                FieldIndex(column_ind).to_field_reference(),
1759            ))),
1760        }
1761    }
1762
1763    #[test]
1764    fn test_unsupported_rel_type_produces_failure_token() {
1765        use substrait::proto::CrossRel;
1766
1767        let ctx = TestContext::new();
1768
1769        // CrossRel is a valid Substrait relation type that the textifier
1770        // does not yet support.  Wrapping it in a Rel and textifying should
1771        // produce a `!{Rel}` failure token rather than panicking.
1772        let rel = Rel {
1773            rel_type: Some(RelType::Cross(Box::new(CrossRel {
1774                common: None,
1775                left: None,
1776                right: None,
1777                advanced_extension: None,
1778            }))),
1779        };
1780
1781        let (result, errors) = ctx.textify(&rel);
1782
1783        // The output should contain the failure token, not an empty string.
1784        assert!(
1785            result.contains("!{Rel}"),
1786            "Expected '!{{Rel}}' in output, got: {result}"
1787        );
1788
1789        // Exactly one error should have been collected.
1790        assert_eq!(errors.0.len(), 1, "Expected exactly one error: {errors:?}");
1791
1792        // The error should be a Format / Unimplemented error mentioning CrossRel.
1793        match &errors.0[0] {
1794            FormatError::Format(plan_err) => {
1795                assert_eq!(plan_err.message, "Rel");
1796                assert_eq!(
1797                    plan_err.error_type,
1798                    crate::textify::foundation::FormatErrorType::Unimplemented
1799                );
1800                assert!(
1801                    plan_err.lookup.as_deref().unwrap_or("").contains("Cross"),
1802                    "Expected lookup to mention 'Cross', got: {:?}",
1803                    plan_err.lookup
1804                );
1805            }
1806            other => panic!("Expected FormatError::Format, got: {other:?}"),
1807        }
1808    }
1809}