1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::fmt;
4use std::fmt::Debug;
5
6use prost::UnknownEnumValue;
7use substrait::proto::fetch_rel::CountMode;
8use substrait::proto::plan_rel::RelType as PlanRelType;
9use substrait::proto::read_rel::ReadType;
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::EmitKind;
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14 AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct,
15 PlanRel, ProjectRel, ReadRel, Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
16};
17
18use super::expressions::Reference;
19use super::types::Name;
20use super::{PlanError, Scope, Textify};
21
/// Gives a relation a short, static display name.
pub trait NamedRelation {
    /// Returns the display name of this relation.
    fn name(&self) -> &'static str;
}
25
26impl NamedRelation for Rel {
27 fn name(&self) -> &'static str {
28 match self.rel_type.as_ref() {
29 None => "UnknownRel",
30 Some(RelType::Read(_)) => "Read",
31 Some(RelType::Filter(_)) => "Filter",
32 Some(RelType::Project(_)) => "Project",
33 Some(RelType::Fetch(_)) => "Fetch",
34 Some(RelType::Aggregate(_)) => "Aggregate",
35 Some(RelType::Sort(_)) => "Sort",
36 Some(RelType::HashJoin(_)) => "HashJoin",
37 Some(RelType::Exchange(_)) => "Exchange",
38 Some(RelType::Join(_)) => "Join",
39 Some(RelType::Set(_)) => "Set",
40 Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
41 Some(RelType::Cross(_)) => "Cross",
42 Some(RelType::Reference(_)) => "Reference",
43 Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
44 Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
45 Some(RelType::Write(_)) => "Write",
46 Some(RelType::Ddl(_)) => "Ddl",
47 Some(RelType::Update(_)) => "Update",
48 Some(RelType::MergeJoin(_)) => "MergeJoin",
49 Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
50 Some(RelType::Window(_)) => "Window",
51 Some(RelType::Expand(_)) => "Expand",
52 }
53 }
54}
55
56impl Textify for Rel {
57 fn name() -> &'static str {
58 "Rel"
59 }
60
61 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
62 match &self.rel_type {
64 Some(substrait::proto::rel::RelType::ExtensionLeaf(ext)) => ext.textify(ctx, w),
65 Some(substrait::proto::rel::RelType::ExtensionSingle(ext)) => ext.textify(ctx, w),
66 Some(substrait::proto::rel::RelType::ExtensionMulti(ext)) => ext.textify(ctx, w),
67 _ => {
68 let relation = Relation::from(self);
70 relation.textify(ctx, w)
71 }
72 }
73 }
74}
75
/// Implemented by enum-like values that can be rendered as a [`Value::Enum`].
pub trait ValueEnum {
    /// Returns the text for this value, or a [`PlanError`] when it has no
    /// textual form.
    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
}
84
/// A `name=value` argument of a relation.
#[derive(Debug, Clone)]
pub struct NamedArg<'a> {
    /// Text written before the `=`.
    pub name: &'a str,
    /// Value written after the `=`.
    pub value: Value<'a>,
}
90
/// One renderable item in a relation's argument list or column list.
#[derive(Debug, Clone)]
pub enum Value<'a> {
    /// A single name, rendered through the scope's `display` wrapper.
    Name(Name<'a>),
    /// A multi-part table name; the parts are joined with `.` when rendered.
    TableName(Vec<Name<'a>>),
    /// A schema field rendered as `name:type`; either half may be absent.
    Field(Option<Name<'a>>, Option<&'a Type>),
    /// A group of values rendered as `(a, b, ...)`.
    Tuple(Vec<Value<'a>>),
    /// A group of values rendered as `[a, b, ...]`.
    List(Vec<Value<'a>>),
    /// A field index, rendered via `Reference`.
    Reference(i32),
    /// A borrowed expression, rendered through the scope's `display` wrapper.
    Expression(&'a Expression),
    /// A borrowed aggregate function, rendered by its own `textify`.
    AggregateFunction(&'a AggregateFunction),
    /// A placeholder carrying the error that explains why a value is absent;
    /// rendered through the scope's `failure` hook.
    Missing(PlanError),
    /// An enum variant name, rendered with a leading `&`.
    Enum(Cow<'a, str>),
    /// An integer, rendered with its `Display` form.
    Integer(i64),
    /// A float, rendered with its `Display` form.
    Float(f64),
    /// A boolean, rendered with its `Display` form.
    Boolean(bool),
}
109
110impl<'a> Value<'a> {
111 pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
112 match maybe_value {
113 Some(s) => s,
114 None => Value::Missing(f()),
115 }
116 }
117}
118
119impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
120 fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
121 match token {
122 Ok(value) => Value::TableName(value),
123 Err(err) => Value::Missing(err),
124 }
125 }
126}
127
128impl<'a> Textify for Value<'a> {
129 fn name() -> &'static str {
130 "Value"
131 }
132
133 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
134 match self {
135 Value::Name(name) => write!(w, "{}", ctx.display(name)),
136 Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
137 Value::Field(name, typ) => {
138 write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
139 }
140 Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
141 Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
142 Value::Reference(i) => write!(w, "{}", Reference(*i)),
143 Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
144 Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
145 Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
146 Value::Enum(res) => write!(w, "&{res}"),
147 Value::Integer(i) => write!(w, "{i}"),
148 Value::Float(f) => write!(w, "{f}"),
149 Value::Boolean(b) => write!(w, "{b}"),
150 }
151 }
152}
153
154fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
155 let mut fields = schema
156 .r#struct
157 .as_ref()
158 .map(|s| s.types.iter())
159 .into_iter()
160 .flatten();
161 let mut names = schema.names.iter();
162
163 let mut values = Vec::new();
167 loop {
168 let field = fields.next();
169 let name = names.next().map(|n| Name(n));
170 if field.is_none() && name.is_none() {
171 break;
172 }
173
174 values.push(Value::Field(name, field));
175 }
176
177 values
178}
179
/// A relation's column values together with the emit kind that selects which
/// of them are written.
struct Emitted<'a> {
    /// The values that an emit mapping indexes into.
    pub values: &'a [Value<'a>],
    /// The emit kind, if any; `None` is treated the same as a direct emit.
    pub emit: Option<&'a EmitKind>,
}
184
185impl<'a> Emitted<'a> {
186 pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
187 Self { values, emit }
188 }
189
190 pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
191 write!(w, "{}", ctx.separated(self.values.iter(), ", "))
192 }
193}
194
195impl<'a> Textify for Emitted<'a> {
196 fn name() -> &'static str {
197 "Emitted"
198 }
199
200 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
201 if ctx.options().show_emit {
202 return self.write_direct(ctx, w);
203 }
204
205 let indices = match &self.emit {
206 Some(EmitKind::Emit(e)) => &e.output_mapping,
207 Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
208 None => return self.write_direct(ctx, w),
209 };
210
211 for (i, &index) in indices.iter().enumerate() {
212 if i > 0 {
213 write!(w, ", ")?;
214 }
215
216 match self.values.get(index as usize) {
217 Some(value) => write!(w, "{}", ctx.display(value))?,
218 None => write!(w, "{}", ctx.failure(PlanError::invalid(
219 "Emitted",
220 Some("output_mapping"),
221 format!(
222 "Output mapping index {} is out of bounds for values collection of size {}",
223 index, self.values.len()
224 )
225 )))?,
226 }
227 }
228
229 Ok(())
230 }
231}
232
/// The arguments of a relation: positional values followed by named ones.
#[derive(Debug, Clone)]
pub struct Arguments<'a> {
    /// Values written first, comma separated.
    pub positional: Vec<Value<'a>>,
    /// `name=value` pairs written after the positional values.
    pub named: Vec<NamedArg<'a>>,
}
240
241impl<'a> Textify for Arguments<'a> {
242 fn name() -> &'static str {
243 "Arguments"
244 }
245 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
246 if self.positional.is_empty() && self.named.is_empty() {
247 return write!(w, "_");
248 }
249
250 write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
251 if !self.positional.is_empty() && !self.named.is_empty() {
252 write!(w, ", ")?;
253 }
254 write!(w, "{}", ctx.separated(self.named.iter(), ", "))
255 }
256}
257
/// A relation in a uniform, renderable shape:
/// `Name[arguments => columns]` followed by its children, one per line.
pub struct Relation<'a> {
    /// Text written before the opening bracket.
    pub name: &'a str,
    /// Arguments written before `=>`; when `None`, only the columns are
    /// written inside the brackets.
    pub arguments: Option<Arguments<'a>>,
    /// The values the emit mapping (if any) selects from when rendering.
    pub columns: Vec<Value<'a>>,
    /// The emit kind taken from the relation's common section, if present.
    pub emit: Option<&'a EmitKind>,
    /// Input relations; `None` entries are skipped when rendering.
    pub children: Vec<Option<Relation<'a>>>,
}
277
278impl Textify for Relation<'_> {
279 fn name() -> &'static str {
280 "Relation"
281 }
282
283 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
284 let cols = Emitted::new(&self.columns, self.emit);
285 let indent = ctx.indent();
286 let name = self.name;
287 let cols = ctx.display(&cols);
288 match &self.arguments {
289 None => {
290 write!(w, "{indent}{name}[{cols}]")?;
291 }
292 Some(args) => {
293 let args = ctx.display(args);
294 write!(w, "{indent}{name}[{args} => {cols}]")?;
295 }
296 }
297 let child_scope = ctx.push_indent();
298 for child in self.children.iter().flatten() {
299 writeln!(w)?;
300 child.textify(&child_scope, w)?;
301 }
302 Ok(())
303 }
304}
305
306impl<'a> Relation<'a> {
307 pub fn emitted(&self) -> usize {
308 match self.emit {
309 Some(EmitKind::Emit(e)) => e.output_mapping.len(),
310 Some(EmitKind::Direct(_)) => self.columns.len(),
311 None => self.columns.len(),
312 }
313 }
314}
315
/// A borrowed multi-part table name; the parts are joined with `.` when
/// rendered.
#[derive(Debug, Copy, Clone)]
pub struct TableName<'a>(&'a [String]);
318
319impl<'a> Textify for TableName<'a> {
320 fn name() -> &'static str {
321 "TableName"
322 }
323
324 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
325 let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
326 write!(w, "{}", ctx.separated(names.iter(), "."))
327 }
328}
329
330pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
331 match rel {
332 Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
333 _ => Err(PlanError::unimplemented(
334 "ReadRel",
335 Some("table_name"),
336 format!("Unexpected read type {rel:?}") as String,
337 )),
338 }
339}
340
341impl<'a> From<&'a ReadRel> for Relation<'a> {
342 fn from(rel: &'a ReadRel) -> Self {
343 let name = get_table_name(rel.read_type.as_ref());
344 let table_name: Value = match name {
345 Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
346 Err(e) => Value::Missing(e),
347 };
348
349 let columns = match rel.base_schema {
350 Some(ref schema) => schema_to_values(schema),
351 None => {
352 let err = PlanError::unimplemented(
353 "ReadRel",
354 Some("base_schema"),
355 "Base schema is required",
356 );
357 vec![Value::Missing(err)]
358 }
359 };
360 let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
361
362 Relation {
363 name: "Read",
364 arguments: Some(Arguments {
365 positional: vec![table_name],
366 named: vec![],
367 }),
368 columns,
369 emit,
370 children: vec![],
371 }
372 }
373}
374
375pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
376 rel.as_ref().and_then(|c| c.emit_kind.as_ref())
377}
378
379impl<'a> Relation<'a> {
380 pub fn input_refs(&self) -> Vec<Value<'a>> {
388 let len = self.emitted();
389 (0..len).map(|i| Value::Reference(i as i32)).collect()
390 }
391
392 pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
396 let mut children = vec![];
397 let mut inputs = 0;
398
399 for maybe_rel in refs {
400 match maybe_rel {
401 Some(rel) => {
402 let child = Relation::from(rel);
403 inputs += child.emitted();
404 children.push(Some(child));
405 }
406 None => children.push(None),
407 }
408 }
409
410 (children, inputs)
411 }
412}
413
414impl<'a> From<&'a FilterRel> for Relation<'a> {
415 fn from(rel: &'a FilterRel) -> Self {
416 let condition = rel
417 .condition
418 .as_ref()
419 .map(|c| Value::Expression(c.as_ref()));
420 let condition = Value::expect(condition, || {
421 PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
422 });
423 let positional = vec![condition];
424 let arguments = Some(Arguments {
425 positional,
426 named: vec![],
427 });
428 let emit = get_emit(rel.common.as_ref());
429 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
430 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
431
432 Relation {
433 name: "Filter",
434 arguments,
435 columns,
436 emit,
437 children,
438 }
439 }
440}
441
442impl<'a> From<&'a ProjectRel> for Relation<'a> {
443 fn from(rel: &'a ProjectRel) -> Self {
444 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
445 let expressions = rel.expressions.iter().map(Value::Expression);
446 let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
447 columns.extend(expressions);
448
449 Relation {
450 name: "Project",
451 arguments: None,
452 columns,
453 emit: get_emit(rel.common.as_ref()),
454 children,
455 }
456 }
457}
458
459impl<'a> From<&'a Rel> for Relation<'a> {
460 fn from(rel: &'a Rel) -> Self {
461 match rel.rel_type.as_ref() {
462 Some(RelType::Read(r)) => Relation::from(r.as_ref()),
463 Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
464 Some(RelType::Project(r)) => Relation::from(r.as_ref()),
465 Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
466 Some(RelType::Sort(r)) => Relation::from(r.as_ref()),
467 Some(RelType::Fetch(r)) => Relation::from(r.as_ref()),
468 Some(RelType::Join(r)) => Relation::from(r.as_ref()),
469 Some(RelType::ExtensionLeaf(_)) => {
470 unreachable!("Extension relations should use Textify trait directly")
471 }
472 Some(RelType::ExtensionSingle(_)) => {
473 unreachable!("Extension relations should use Textify trait directly")
474 }
475 Some(RelType::ExtensionMulti(_)) => {
476 unreachable!("Extension relations should use Textify trait directly")
477 }
478 _ => todo!(),
479 }
480 }
481}
482
483impl<'a> From<&'a AggregateRel> for Relation<'a> {
484 fn from(rel: &'a AggregateRel) -> Self {
494 let positional = rel
496 .grouping_expressions
497 .iter()
498 .map(Value::Expression)
499 .collect::<Vec<_>>();
500
501 let arguments = Some(Arguments {
502 positional,
503 named: vec![],
504 });
505 let mut all_outputs: Vec<Value> = vec![];
507 let input_field_count = rel.grouping_expressions.len();
508 for i in 0..input_field_count {
509 all_outputs.push(Value::Reference(i as i32));
510 }
511
512 for m in &rel.measures {
515 if let Some(agg_fn) = m.measure.as_ref() {
516 all_outputs.push(Value::AggregateFunction(agg_fn));
517 }
518 }
519 let emit = get_emit(rel.common.as_ref());
520 Relation {
521 name: "Aggregate",
522 arguments,
523 columns: all_outputs,
524 emit,
525 children: rel
526 .input
527 .as_ref()
528 .map(|c| Some(Relation::from(c.as_ref())))
529 .into_iter()
530 .collect(),
531 }
532 }
533}
534
535impl Textify for RelRoot {
536 fn name() -> &'static str {
537 "RelRoot"
538 }
539
540 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
541 let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
542
543 write!(
544 w,
545 "{}Root[{}]",
546 ctx.indent(),
547 ctx.separated(names.iter(), ", ")
548 )?;
549 let child_scope = ctx.push_indent();
550 for child in self.input.iter() {
551 writeln!(w)?;
552 child.textify(&child_scope, w)?;
553 }
554
555 Ok(())
556 }
557}
558
559impl Textify for PlanRelType {
560 fn name() -> &'static str {
561 "PlanRelType"
562 }
563
564 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
565 match self {
566 PlanRelType::Rel(rel) => rel.textify(ctx, w),
567 PlanRelType::Root(root) => root.textify(ctx, w),
568 }
569 }
570}
571
572impl Textify for PlanRel {
573 fn name() -> &'static str {
574 "PlanRel"
575 }
576
577 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
580 write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
581 }
582}
583
584impl<'a> From<&'a SortRel> for Relation<'a> {
585 fn from(rel: &'a SortRel) -> Self {
586 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
587 let positional = rel.sorts.iter().map(Value::from).collect::<Vec<_>>();
588 let arguments = Some(Arguments {
589 positional,
590 named: vec![],
591 });
592 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
594 let emit = get_emit(rel.common.as_ref());
595 Relation {
596 name: "Sort",
597 arguments,
598 columns,
599 emit,
600 children,
601 }
602 }
603}
604
605impl<'a> From<&'a FetchRel> for Relation<'a> {
606 fn from(rel: &'a FetchRel) -> Self {
607 let (children, _columns) = Relation::convert_children(vec![rel.input.as_deref()]);
608 let mut named_args = Vec::new();
609 match &rel.count_mode {
610 Some(CountMode::CountExpr(expr)) => {
611 named_args.push(NamedArg {
612 name: "limit",
613 value: Value::Expression(expr),
614 });
615 }
616 #[allow(deprecated)]
617 Some(CountMode::Count(val)) => {
618 named_args.push(NamedArg {
619 name: "limit",
620 value: Value::Integer(*val),
621 });
622 }
623 None => {}
624 }
625 if let Some(offset) = &rel.offset_mode {
626 match offset {
627 substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
628 named_args.push(NamedArg {
629 name: "offset",
630 value: Value::Expression(expr),
631 });
632 }
633 #[allow(deprecated)]
634 substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
635 named_args.push(NamedArg {
636 name: "offset",
637 value: Value::Integer(*val),
638 });
639 }
640 }
641 }
642
643 let emit = get_emit(rel.common.as_ref());
644 let columns = match emit {
645 Some(EmitKind::Emit(e)) => e
646 .output_mapping
647 .iter()
648 .map(|&i| Value::Reference(i))
649 .collect(),
650 _ => vec![],
651 };
652 Relation {
653 name: "Fetch",
654 arguments: Some(Arguments {
655 positional: vec![],
656 named: named_args,
657 }),
658 columns,
659 emit,
660 children,
661 }
662 }
663}
664
665fn join_output_columns(
666 join_type: join_rel::JoinType,
667 left_columns: usize,
668 right_columns: usize,
669) -> Vec<Value<'static>> {
670 let total_columns = match join_type {
671 join_rel::JoinType::Inner
673 | join_rel::JoinType::Left
674 | join_rel::JoinType::Right
675 | join_rel::JoinType::Outer => left_columns + right_columns,
676
677 join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
679
680 join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
682
683 join_rel::JoinType::LeftSingle => left_columns,
685 join_rel::JoinType::RightSingle => right_columns,
686
687 join_rel::JoinType::LeftMark => left_columns + 1,
689 join_rel::JoinType::RightMark => right_columns + 1,
690
691 join_rel::JoinType::Unspecified => left_columns + right_columns,
693 };
694
695 (0..total_columns)
697 .map(|i| Value::Reference(i as i32))
698 .collect()
699}
700
701impl<'a> From<&'a JoinRel> for Relation<'a> {
702 fn from(rel: &'a JoinRel) -> Self {
703 let (children, _total_columns) =
704 Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()]);
705
706 assert_eq!(
708 children.len(),
709 2,
710 "convert_children should return same number of elements as input"
711 );
712
713 let left_columns = match &children[0] {
715 Some(child) => child.emitted(),
716 None => 0,
717 };
718 let right_columns = match &children[1] {
719 Some(child) => child.emitted(),
720 None => 0,
721 };
722
723 let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
726 Ok(join_type) => {
727 let join_type_value = match join_type.as_enum_str() {
728 Ok(s) => Value::Enum(s),
729 Err(e) => Value::Missing(e),
730 };
731 (join_type, join_type_value)
732 }
733 Err(_) => {
734 let join_type_error = Value::Missing(PlanError::invalid(
736 "JoinRel",
737 Some("type"),
738 format!("Unknown join type: {}", rel.r#type),
739 ));
740 (join_rel::JoinType::Unspecified, join_type_error)
741 }
742 };
743
744 let condition = rel
746 .expression
747 .as_ref()
748 .map(|c| Value::Expression(c.as_ref()));
749 let condition = Value::expect(condition, || {
750 PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
751 });
752
753 let positional = vec![join_type_value, condition];
757 let arguments = Some(Arguments {
758 positional,
759 named: vec![],
760 });
761
762 let emit = get_emit(rel.common.as_ref());
763 let columns = join_output_columns(join_type, left_columns, right_columns);
764
765 Relation {
766 name: "Join",
767 arguments,
768 columns,
769 emit,
770 children,
771 }
772 }
773}
774
775impl<'a> From<&'a SortField> for Value<'a> {
776 fn from(sf: &'a SortField) -> Self {
777 let field = match &sf.expr {
778 Some(expr) => match &expr.rex_type {
779 Some(substrait::proto::expression::RexType::Selection(fref)) => {
780 if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
781 if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
782 Value::Reference(sf.field)
783 } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
784 } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
785 }
786 _ => Value::Missing(PlanError::unimplemented(
787 "SortField",
788 Some("expr"),
789 "Not a selection",
790 )),
791 },
792 None => Value::Missing(PlanError::unimplemented(
793 "SortField",
794 Some("expr"),
795 "Missing expr",
796 )),
797 };
798 let direction = match &sf.sort_kind {
799 Some(kind) => Value::from(kind),
800 None => Value::Missing(PlanError::invalid(
801 "SortKind",
802 Some(Cow::Borrowed("sort_kind")),
803 "Missing sort_kind",
804 )),
805 };
806 Value::Tuple(vec![field, direction])
807 }
808}
809
810impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
811 fn from(enum_val: &'a T) -> Self {
812 match enum_val.as_enum_str() {
813 Ok(s) => Value::Enum(s),
814 Err(e) => Value::Missing(e),
815 }
816 }
817}
818
819impl ValueEnum for SortKind {
820 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
821 let d = match self {
822 &SortKind::Direction(d) => SortDirection::try_from(d),
823 SortKind::ComparisonFunctionReference(f) => {
824 return Err(PlanError::invalid(
825 "SortKind",
826 Some(Cow::Owned(format!("function reference{f}"))),
827 "SortKind::ComparisonFunctionReference unimplemented",
828 ));
829 }
830 };
831 let s = match d {
832 Err(UnknownEnumValue(d)) => {
833 return Err(PlanError::invalid(
834 "SortKind",
835 Some(Cow::Owned(format!("unknown variant: {d:?}"))),
836 "Unknown SortDirection",
837 ));
838 }
839 Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
840 Ok(SortDirection::AscNullsLast) => "AscNullsLast",
841 Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
842 Ok(SortDirection::DescNullsLast) => "DescNullsLast",
843 Ok(SortDirection::Clustered) => "Clustered",
844 Ok(SortDirection::Unspecified) => {
845 return Err(PlanError::invalid(
846 "SortKind",
847 Option::<Cow<str>>::None,
848 "Unspecified SortDirection",
849 ));
850 }
851 };
852 Ok(Cow::Borrowed(s))
853 }
854}
855
856impl ValueEnum for join_rel::JoinType {
857 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
858 let s = match self {
859 join_rel::JoinType::Unspecified => {
860 return Err(PlanError::invalid(
861 "JoinType",
862 Option::<Cow<str>>::None,
863 "Unspecified JoinType",
864 ));
865 }
866 join_rel::JoinType::Inner => "Inner",
867 join_rel::JoinType::Outer => "Outer",
868 join_rel::JoinType::Left => "Left",
869 join_rel::JoinType::Right => "Right",
870 join_rel::JoinType::LeftSemi => "LeftSemi",
871 join_rel::JoinType::RightSemi => "RightSemi",
872 join_rel::JoinType::LeftAnti => "LeftAnti",
873 join_rel::JoinType::RightAnti => "RightAnti",
874 join_rel::JoinType::LeftSingle => "LeftSingle",
875 join_rel::JoinType::RightSingle => "RightSingle",
876 join_rel::JoinType::LeftMark => "LeftMark",
877 join_rel::JoinType::RightMark => "RightMark",
878 };
879 Ok(Cow::Borrowed(s))
880 }
881}
882
883impl<'a> Textify for NamedArg<'a> {
884 fn name() -> &'static str {
885 "NamedArg"
886 }
887 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
888 write!(w, "{}=", self.name)?;
889 self.value.textify(ctx, w)
890 }
891}
892
893#[cfg(test)]
894mod tests {
895 use substrait::proto::expression::literal::LiteralType;
896 use substrait::proto::expression::{Literal, RexType, ScalarFunction};
897 use substrait::proto::function_argument::ArgType;
898 use substrait::proto::read_rel::{NamedTable, ReadType};
899 use substrait::proto::rel_common::Emit;
900 use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
901 use substrait::proto::{
902 Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
903 };
904
905 use super::*;
906 use crate::fixtures::TestContext;
907 use crate::parser::expressions::FieldIndex;
908
    /// A named-table read renders its dotted table name and `name:type`
    /// columns; the expected output shows a name containing a space quoted.
    #[test]
    fn test_read_rel() {
        let ctx = TestContext::new();

        // Two nullable columns (i32 and string) read from `some_db.test_table`.
        let read_rel = ReadRel {
            common: None,
            base_schema: Some(NamedStruct {
                names: vec!["col1".into(), "column 2".into()],
                r#struct: Some(Struct {
                    type_variation_reference: 0,
                    types: vec![
                        Type {
                            kind: Some(Kind::I32(ptype::I32 {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                        Type {
                            kind: Some(Kind::String(ptype::String {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                    ],
                    nullability: Nullability::Nullable as i32,
                }),
            }),
            filter: None,
            best_effort_filter: None,
            projection: None,
            advanced_extension: None,
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["some_db".into(), "test_table".into()],
                advanced_extension: None,
            })),
        };

        let rel = Relation::from(&read_rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(
            result,
            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
        );
    }
956
    /// A filter over a read renders the condition as its argument, passes the
    /// input columns through, and places the read on the following line.
    #[test]
    fn test_filter_rel() {
        // Function anchor 10 is registered as `gt` so the condition can be named.
        let ctx = TestContext::new()
            .with_urn(1, "test_urn")
            .with_function(1, 10, "gt");

        // Input: a two-column read of `test_table`.
        let read_rel = ReadRel {
            common: None,
            base_schema: Some(NamedStruct {
                names: vec!["col1".into(), "col2".into()],
                r#struct: Some(Struct {
                    type_variation_reference: 0,
                    types: vec![
                        Type {
                            kind: Some(Kind::I32(ptype::I32 {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                        Type {
                            kind: Some(Kind::I32(ptype::I32 {
                                type_variation_reference: 0,
                                nullability: Nullability::Nullable as i32,
                            })),
                        },
                    ],
                    nullability: Nullability::Nullable as i32,
                }),
            }),
            filter: None,
            best_effort_filter: None,
            projection: None,
            advanced_extension: None,
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: vec!["test_table".into()],
                advanced_extension: None,
            })),
        };

        // Condition: function 10 applied to field 0 and the literal 10.
        let filter_expr = Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                function_reference: 10,
                arguments: vec![
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Reference(0).into())),
                    },
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Expression {
                            rex_type: Some(RexType::Literal(Literal {
                                literal_type: Some(LiteralType::I32(10)),
                                nullable: false,
                                type_variation_reference: 0,
                            })),
                        })),
                    },
                ],
                options: vec![],
                output_type: None,
                #[allow(deprecated)]
                args: vec![],
            })),
        };

        let filter_rel = FilterRel {
            common: None,
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(read_rel))),
            })),
            condition: Some(Box::new(filter_expr)),
            advanced_extension: None,
        };

        let rel = Rel {
            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
        };

        let rel = Relation::from(&rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        // The raw string is whitespace-sensitive; only its leading newline is trimmed.
        let expected = r#"
Filter[gt($0, 10:i32) => $0, $1]
 Read[test_table => col1:i32?, col2:i32?]"#
            .trim_start();
        assert_eq!(result, expected);
    }
1045
    /// A lone aggregate function value renders as `name(args)`.
    #[test]
    fn test_aggregate_function_textify() {
        let ctx = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        // Function 10 (`sum`) applied to field 1.
        let agg_fn = AggregateFunction {
            function_reference: 10,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        FieldIndex(1).to_field_reference(),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            #[allow(deprecated)]
            args: vec![],
        };

        let value = Value::AggregateFunction(&agg_fn);
        let (result, errors) = ctx.textify(&value);

        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "sum($1)");
    }
1078
    /// An aggregate with one grouping expression and two measures, whose emit
    /// mapping selects only the measures.
    #[test]
    fn test_aggregate_relation_textify() {
        let ctx = TestContext::new()
            .with_urn(1, "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml")
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        // Function 10 (`sum`) applied to field 1.
        let agg_fn1 = AggregateFunction {
            function_reference: 10,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        FieldIndex(1).to_field_reference(),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            #[allow(deprecated)]
            args: vec![],
        };

        // Function 11 (`count`) applied to field 1.
        let agg_fn2 = AggregateFunction {
            function_reference: 11,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        FieldIndex(1).to_field_reference(),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            #[allow(deprecated)]
            args: vec![],
        };

        let aggregate_rel = AggregateRel {
            // Input: a read of `orders` with `category` and `amount` columns.
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(ReadRel {
                    common: None,
                    base_schema: Some(NamedStruct {
                        names: vec!["category".into(), "amount".into()],
                        r#struct: Some(Struct {
                            type_variation_reference: 0,
                            types: vec![
                                Type {
                                    kind: Some(Kind::String(ptype::String {
                                        type_variation_reference: 0,
                                        nullability: Nullability::Nullable as i32,
                                    })),
                                },
                                Type {
                                    kind: Some(Kind::Fp64(ptype::Fp64 {
                                        type_variation_reference: 0,
                                        nullability: Nullability::Nullable as i32,
                                    })),
                                },
                            ],
                            nullability: Nullability::Nullable as i32,
                        }),
                    }),
                    filter: None,
                    best_effort_filter: None,
                    projection: None,
                    advanced_extension: None,
                    read_type: Some(ReadType::NamedTable(NamedTable {
                        names: vec!["orders".into()],
                        advanced_extension: None,
                    })),
                }))),
            })),
            // Group by field 0.
            grouping_expressions: vec![Expression {
                rex_type: Some(RexType::Selection(Box::new(
                    FieldIndex(0).to_field_reference(),
                ))),
            }],
            groupings: vec![],
            measures: vec![
                aggregate_rel::Measure {
                    measure: Some(agg_fn1),
                    filter: None,
                },
                aggregate_rel::Measure {
                    measure: Some(agg_fn2),
                    filter: None,
                },
            ],
            // Indices 1 and 2 select the two measures; index 0 (the grouping
            // column) is left out.
            common: Some(RelCommon {
                emit_kind: Some(EmitKind::Emit(Emit {
                    output_mapping: vec![1, 2],
                })),
                ..Default::default()
            }),
            advanced_extension: None,
        };

        let relation = Relation::from(&aggregate_rel);
        let (result, errors) = ctx.textify(&relation);

        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert!(result.contains("Aggregate[$0 => sum($1), count($1)]"));
    }
1190
1191 #[test]
1192 fn test_arguments_textify_positional_only() {
1193 let ctx = TestContext::new();
1194 let args = Arguments {
1195 positional: vec![Value::Integer(42), Value::Integer(7)],
1196 named: vec![],
1197 };
1198 let (result, errors) = ctx.textify(&args);
1199 assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1200 assert_eq!(result, "42, 7");
1201 }
1202
1203 #[test]
1204 fn test_arguments_textify_named_only() {
1205 let ctx = TestContext::new();
1206 let args = Arguments {
1207 positional: vec![],
1208 named: vec![
1209 NamedArg {
1210 name: "limit",
1211 value: Value::Integer(10),
1212 },
1213 NamedArg {
1214 name: "offset",
1215 value: Value::Integer(5),
1216 },
1217 ],
1218 };
1219 let (result, errors) = ctx.textify(&args);
1220 assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1221 assert_eq!(result, "limit=10, offset=5");
1222 }
1223
    /// A join whose `type` is not a known join type is still rendered, with an
    /// error token in the output and at least one reported error.
    #[test]
    fn test_join_relation_unknown_type() {
        let ctx = TestContext::new();

        // Both inputs are default reads; 999 is not a known join type.
        let join_rel = JoinRel {
            left: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::default())),
            })),
            right: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::default())),
            })),
            expression: Some(Box::new(Expression::default())),
            r#type: 999,
            common: None,
            post_join_filter: None,
            advanced_extension: None,
        };

        let relation = Relation::from(&join_rel);
        let (result, errors) = ctx.textify(&relation);

        assert!(!errors.is_empty(), "Expected errors for unknown join type");
        assert!(
            result.contains("!{JoinRel}"),
            "Expected error token for unknown join type"
        );
        assert!(
            result.contains("Join["),
            "Expected Join relation to be formatted"
        );
    }
1257
1258 #[test]
1259 fn test_arguments_textify_both() {
1260 let ctx = TestContext::new();
1261 let args = Arguments {
1262 positional: vec![Value::Integer(1)],
1263 named: vec![NamedArg {
1264 name: "foo",
1265 value: Value::Integer(2),
1266 }],
1267 };
1268 let (result, errors) = ctx.textify(&args);
1269 assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1270 assert_eq!(result, "1, foo=2");
1271 }
1272
1273 #[test]
1274 fn test_arguments_textify_empty() {
1275 let ctx = TestContext::new();
1276 let args = Arguments {
1277 positional: vec![],
1278 named: vec![],
1279 };
1280 let (result, errors) = ctx.textify(&args);
1281 assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
1282 assert_eq!(result, "_");
1283 }
1284
1285 #[test]
1286 fn test_named_arg_textify_error_token() {
1287 let ctx = TestContext::new();
1288 let named_arg = NamedArg {
1289 name: "foo",
1290 value: Value::Missing(PlanError::invalid(
1291 "my_enum",
1292 Some(Cow::Borrowed("my_enum")),
1293 Cow::Borrowed("my_enum"),
1294 )),
1295 };
1296 let (result, errors) = ctx.textify(&named_arg);
1297 assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
1299 assert!(!errors.is_empty(), "Expected error for error token");
1301 }
1302
1303 #[test]
1304 fn test_join_type_enum_textify() {
1305 assert_eq!(join_rel::JoinType::Inner.as_enum_str().unwrap(), "Inner");
1307 assert_eq!(join_rel::JoinType::Left.as_enum_str().unwrap(), "Left");
1308 assert_eq!(
1309 join_rel::JoinType::LeftSemi.as_enum_str().unwrap(),
1310 "LeftSemi"
1311 );
1312 assert_eq!(
1313 join_rel::JoinType::LeftAnti.as_enum_str().unwrap(),
1314 "LeftAnti"
1315 );
1316 }
1317
1318 #[test]
1319 fn test_join_output_columns() {
1320 let inner_cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
1322 assert_eq!(inner_cols.len(), 5); assert!(matches!(inner_cols[0], Value::Reference(0)));
1324 assert!(matches!(inner_cols[4], Value::Reference(4)));
1325
1326 let left_semi_cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
1328 assert_eq!(left_semi_cols.len(), 2); assert!(matches!(left_semi_cols[0], Value::Reference(0)));
1330 assert!(matches!(left_semi_cols[1], Value::Reference(1)));
1331
1332 let right_semi_cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
1334 assert_eq!(right_semi_cols.len(), 3); assert!(matches!(right_semi_cols[0], Value::Reference(0))); assert!(matches!(right_semi_cols[1], Value::Reference(1)));
1337 assert!(matches!(right_semi_cols[2], Value::Reference(2))); let left_mark_cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
1341 assert_eq!(left_mark_cols.len(), 3); assert!(matches!(left_mark_cols[0], Value::Reference(0)));
1343 assert!(matches!(left_mark_cols[1], Value::Reference(1)));
1344 assert!(matches!(left_mark_cols[2], Value::Reference(2))); let right_mark_cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
1348 assert_eq!(right_mark_cols.len(), 4); assert!(matches!(right_mark_cols[0], Value::Reference(0))); assert!(matches!(right_mark_cols[1], Value::Reference(1)));
1351 assert!(matches!(right_mark_cols[2], Value::Reference(2))); assert!(matches!(right_mark_cols[3], Value::Reference(3))); }
1354}