1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::fmt;
4use std::fmt::Debug;
5
6use prost::UnknownEnumValue;
7use substrait::proto::fetch_rel::CountMode;
8use substrait::proto::plan_rel::RelType as PlanRelType;
9use substrait::proto::read_rel::ReadType;
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::EmitKind;
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14 AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct,
15 PlanRel, ProjectRel, ReadRel, Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
16};
17
18use super::expressions::Reference;
19use super::types::Name;
20use super::{PlanError, Scope, Textify};
21
/// Gives a relation a short, static display name.
pub trait NamedRelation {
    /// Returns the display name for this relation.
    fn name(&self) -> &'static str;
}
25
26impl NamedRelation for Rel {
27 fn name(&self) -> &'static str {
28 match self.rel_type.as_ref() {
29 None => "UnknownRel",
30 Some(RelType::Read(_)) => "Read",
31 Some(RelType::Filter(_)) => "Filter",
32 Some(RelType::Project(_)) => "Project",
33 Some(RelType::Fetch(_)) => "Fetch",
34 Some(RelType::Aggregate(_)) => "Aggregate",
35 Some(RelType::Sort(_)) => "Sort",
36 Some(RelType::HashJoin(_)) => "HashJoin",
37 Some(RelType::Exchange(_)) => "Exchange",
38 Some(RelType::Join(_)) => "Join",
39 Some(RelType::Set(_)) => "Set",
40 Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
41 Some(RelType::Cross(_)) => "Cross",
42 Some(RelType::Reference(_)) => "Reference",
43 Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
44 Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
45 Some(RelType::Write(_)) => "Write",
46 Some(RelType::Ddl(_)) => "Ddl",
47 Some(RelType::Update(_)) => "Update",
48 Some(RelType::MergeJoin(_)) => "MergeJoin",
49 Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
50 Some(RelType::Window(_)) => "Window",
51 Some(RelType::Expand(_)) => "Expand",
52 }
53 }
54}
55
/// Implemented by enumerations that can be rendered as a textual enum value.
pub trait ValueEnum {
    /// Returns the textual form of this enum value, or a [`PlanError`] when the
    /// value has no valid textual form.
    fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
}
64
/// A `name=value` argument of a relation.
#[derive(Debug, Clone)]
pub struct NamedArg<'a> {
    /// Argument name, written before the `=`.
    pub name: &'a str,
    /// Argument value, written after the `=`.
    pub value: Value<'a>,
}
70
/// A single item that can appear in a relation's argument or column list.
#[derive(Debug, Clone)]
pub enum Value<'a> {
    /// A single name.
    Name(Name<'a>),
    /// A multi-part table name; parts are joined with `.` when textified.
    TableName(Vec<Name<'a>>),
    /// A `name:type` pair; either half may be absent.
    Field(Option<Name<'a>>, Option<&'a Type>),
    /// A group of values textified as `(a, b, ...)`.
    Tuple(Vec<Value<'a>>),
    /// A group of values textified as `[a, b, ...]`.
    List(Vec<Value<'a>>),
    /// A field reference by index, textified through `Reference`.
    Reference(i32),
    /// A borrowed expression.
    Expression(&'a Expression),
    /// A borrowed aggregate function invocation.
    AggregateFunction(&'a AggregateFunction),
    /// A value that could not be produced; textified as a failure via the scope.
    Missing(PlanError),
    /// An enum value, textified with a leading `&`.
    Enum(Cow<'a, str>),
    /// A plain integer.
    Integer(i32),
}
87
88impl<'a> Value<'a> {
89 pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
90 match maybe_value {
91 Some(s) => s,
92 None => Value::Missing(f()),
93 }
94 }
95}
96
97impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
98 fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
99 match token {
100 Ok(value) => Value::TableName(value),
101 Err(err) => Value::Missing(err),
102 }
103 }
104}
105
106impl<'a> Textify for Value<'a> {
107 fn name() -> &'static str {
108 "Value"
109 }
110
111 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
112 match self {
113 Value::Name(name) => write!(w, "{}", ctx.display(name)),
114 Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
115 Value::Field(name, typ) => {
116 write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
117 }
118 Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
119 Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
120 Value::Reference(i) => write!(w, "{}", Reference(*i)),
121 Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
122 Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
123 Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
124 Value::Enum(res) => write!(w, "&{res}"),
125 Value::Integer(i) => write!(w, "{i}"),
126 }
127 }
128}
129
130fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
131 let mut fields = schema
132 .r#struct
133 .as_ref()
134 .map(|s| s.types.iter())
135 .into_iter()
136 .flatten();
137 let mut names = schema.names.iter();
138
139 let mut values = Vec::new();
143 loop {
144 let field = fields.next();
145 let name = names.next().map(|n| Name(n));
146 if field.is_none() && name.is_none() {
147 break;
148 }
149
150 values.push(Value::Field(name, field));
151 }
152
153 values
154}
155
/// A list of values together with the optional emit kind that selects among them.
struct Emitted<'a> {
    /// All candidate values.
    pub values: &'a [Value<'a>],
    /// Emit specification, if any.
    pub emit: Option<&'a EmitKind>,
}
160
161impl<'a> Emitted<'a> {
162 pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
163 Self { values, emit }
164 }
165
166 pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
167 write!(w, "{}", ctx.separated(self.values.iter(), ", "))
168 }
169}
170
171impl<'a> Textify for Emitted<'a> {
172 fn name() -> &'static str {
173 "Emitted"
174 }
175
176 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
177 if ctx.options().show_emit {
178 return self.write_direct(ctx, w);
179 }
180
181 let indices = match &self.emit {
182 Some(EmitKind::Emit(e)) => &e.output_mapping,
183 Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
184 None => return self.write_direct(ctx, w),
185 };
186
187 for (i, &index) in indices.iter().enumerate() {
188 if i > 0 {
189 write!(w, ", ")?;
190 }
191
192 match self.values.get(index as usize) {
193 Some(value) => write!(w, "{}", ctx.display(value))?,
194 None => write!(w, "{}", ctx.failure(PlanError::invalid(
195 "Emitted",
196 Some("output_mapping"),
197 format!(
198 "Output mapping index {} is out of bounds for values collection of size {}",
199 index, self.values.len()
200 )
201 )))?,
202 }
203 }
204
205 Ok(())
206 }
207}
208
/// The arguments of a relation: positional values followed by named ones.
#[derive(Debug, Clone)]
pub struct Arguments<'a> {
    /// Positional arguments, written first.
    pub positional: Vec<Value<'a>>,
    /// Named arguments, written after the positional ones.
    pub named: Vec<NamedArg<'a>>,
}
216
217impl<'a> Textify for Arguments<'a> {
218 fn name() -> &'static str {
219 "Arguments"
220 }
221 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
222 if self.positional.is_empty() && self.named.is_empty() {
223 return write!(w, "_");
224 }
225
226 write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
227 if !self.positional.is_empty() && !self.named.is_empty() {
228 write!(w, ", ")?;
229 }
230 write!(w, "{}", ctx.separated(self.named.iter(), ", "))
231 }
232}
233
/// A relation prepared for textification as `Name[arguments => columns]`,
/// followed by its children on subsequent, further-indented lines.
pub struct Relation<'a> {
    /// Relation name, written before the opening bracket.
    pub name: &'a str,
    /// Arguments written before `=>`; when `None`, only the columns are written.
    pub arguments: Option<Arguments<'a>>,
    /// Candidate output columns, before any emit mapping is applied.
    pub columns: Vec<Value<'a>>,
    /// Emit specification used to select among `columns`, if any.
    pub emit: Option<&'a EmitKind>,
    /// Child relations; `None` entries are skipped when textifying.
    pub children: Vec<Option<Relation<'a>>>,
}
253
254impl Textify for Relation<'_> {
255 fn name() -> &'static str {
256 "Relation"
257 }
258
259 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
260 let cols = Emitted::new(&self.columns, self.emit);
261 let indent = ctx.indent();
262 let name = self.name;
263 let cols = ctx.display(&cols);
264 match &self.arguments {
265 None => {
266 write!(w, "{indent}{name}[{cols}]")?;
267 }
268 Some(args) => {
269 let args = ctx.display(args);
270 write!(w, "{indent}{name}[{args} => {cols}]")?;
271 }
272 }
273 let child_scope = ctx.push_indent();
274 for child in self.children.iter().flatten() {
275 writeln!(w)?;
276 child.textify(&child_scope, w)?;
277 }
278 Ok(())
279 }
280}
281
282impl<'a> Relation<'a> {
283 pub fn emitted(&self) -> usize {
284 match self.emit {
285 Some(EmitKind::Emit(e)) => e.output_mapping.len(),
286 Some(EmitKind::Direct(_)) => self.columns.len(),
287 None => self.columns.len(),
288 }
289 }
290}
291
/// A borrowed, multi-part table name; textified with the parts joined by `.`.
#[derive(Debug, Copy, Clone)]
pub struct TableName<'a>(&'a [String]);
294
295impl<'a> Textify for TableName<'a> {
296 fn name() -> &'static str {
297 "TableName"
298 }
299
300 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
301 let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
302 write!(w, "{}", ctx.separated(names.iter(), "."))
303 }
304}
305
306pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
307 match rel {
308 Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
309 _ => Err(PlanError::unimplemented(
310 "ReadRel",
311 Some("table_name"),
312 format!("Unexpected read type {rel:?}") as String,
313 )),
314 }
315}
316
317impl<'a> From<&'a ReadRel> for Relation<'a> {
318 fn from(rel: &'a ReadRel) -> Self {
319 let name = get_table_name(rel.read_type.as_ref());
320 let table_name: Value = match name {
321 Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
322 Err(e) => Value::Missing(e),
323 };
324
325 let columns = match rel.base_schema {
326 Some(ref schema) => schema_to_values(schema),
327 None => {
328 let err = PlanError::unimplemented(
329 "ReadRel",
330 Some("base_schema"),
331 "Base schema is required",
332 );
333 vec![Value::Missing(err)]
334 }
335 };
336 let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
337
338 Relation {
339 name: "Read",
340 arguments: Some(Arguments {
341 positional: vec![table_name],
342 named: vec![],
343 }),
344 columns,
345 emit,
346 children: vec![],
347 }
348 }
349}
350
351pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
352 rel.as_ref().and_then(|c| c.emit_kind.as_ref())
353}
354
355impl<'a> Relation<'a> {
356 pub fn input_refs(&self) -> Vec<Value<'a>> {
364 let len = self.emitted();
365 (0..len).map(|i| Value::Reference(i as i32)).collect()
366 }
367
368 pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
372 let mut children = vec![];
373 let mut inputs = 0;
374
375 for maybe_rel in refs {
376 match maybe_rel {
377 Some(rel) => {
378 let child = Relation::from(rel);
379 inputs += child.emitted();
380 children.push(Some(child));
381 }
382 None => children.push(None),
383 }
384 }
385
386 (children, inputs)
387 }
388}
389
390impl<'a> From<&'a FilterRel> for Relation<'a> {
391 fn from(rel: &'a FilterRel) -> Self {
392 let condition = rel
393 .condition
394 .as_ref()
395 .map(|c| Value::Expression(c.as_ref()));
396 let condition = Value::expect(condition, || {
397 PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
398 });
399 let positional = vec![condition];
400 let arguments = Some(Arguments {
401 positional,
402 named: vec![],
403 });
404 let emit = get_emit(rel.common.as_ref());
405 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
406 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
407
408 Relation {
409 name: "Filter",
410 arguments,
411 columns,
412 emit,
413 children,
414 }
415 }
416}
417
418impl<'a> From<&'a ProjectRel> for Relation<'a> {
419 fn from(rel: &'a ProjectRel) -> Self {
420 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
421 let expressions = rel.expressions.iter().map(Value::Expression);
422 let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
423 columns.extend(expressions);
424
425 Relation {
426 name: "Project",
427 arguments: None,
428 columns,
429 emit: get_emit(rel.common.as_ref()),
430 children,
431 }
432 }
433}
434
435impl<'a> From<&'a Rel> for Relation<'a> {
436 fn from(rel: &'a Rel) -> Self {
437 match rel.rel_type.as_ref() {
438 Some(RelType::Read(r)) => Relation::from(r.as_ref()),
439 Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
440 Some(RelType::Project(r)) => Relation::from(r.as_ref()),
441 Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
442 Some(RelType::Sort(r)) => Relation::from(r.as_ref()),
443 Some(RelType::Fetch(r)) => Relation::from(r.as_ref()),
444 Some(RelType::Join(r)) => Relation::from(r.as_ref()),
445 _ => todo!(),
446 }
447 }
448}
449
450impl<'a> From<&'a AggregateRel> for Relation<'a> {
451 fn from(rel: &'a AggregateRel) -> Self {
461 let positional = rel
463 .grouping_expressions
464 .iter()
465 .map(Value::Expression)
466 .collect::<Vec<_>>();
467
468 let arguments = Some(Arguments {
469 positional,
470 named: vec![],
471 });
472 let mut all_outputs: Vec<Value> = vec![];
474 let input_field_count = rel.grouping_expressions.len();
475 for i in 0..input_field_count {
476 all_outputs.push(Value::Reference(i as i32));
477 }
478
479 for m in &rel.measures {
482 if let Some(agg_fn) = m.measure.as_ref() {
483 all_outputs.push(Value::AggregateFunction(agg_fn));
484 }
485 }
486 let emit = get_emit(rel.common.as_ref());
487 Relation {
488 name: "Aggregate",
489 arguments,
490 columns: all_outputs,
491 emit,
492 children: rel
493 .input
494 .as_ref()
495 .map(|c| Some(Relation::from(c.as_ref())))
496 .into_iter()
497 .collect(),
498 }
499 }
500}
501
502impl Textify for RelRoot {
503 fn name() -> &'static str {
504 "RelRoot"
505 }
506
507 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
508 let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
509
510 write!(
511 w,
512 "{}Root[{}]",
513 ctx.indent(),
514 ctx.separated(names.iter(), ", ")
515 )?;
516 let child_scope = ctx.push_indent();
517 for child in self.input.iter() {
518 let child = Relation::from(child);
519 writeln!(w)?;
520 child.textify(&child_scope, w)?;
521 }
522
523 Ok(())
524 }
525}
526
527impl Textify for PlanRelType {
528 fn name() -> &'static str {
529 "PlanRelType"
530 }
531
532 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
533 match self {
534 PlanRelType::Rel(rel) => Relation::from(rel).textify(ctx, w),
535 PlanRelType::Root(root) => root.textify(ctx, w),
536 }
537 }
538}
539
540impl Textify for PlanRel {
541 fn name() -> &'static str {
542 "PlanRel"
543 }
544
545 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
548 write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
549 }
550}
551
552impl<'a> From<&'a SortRel> for Relation<'a> {
553 fn from(rel: &'a SortRel) -> Self {
554 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
555 let positional = rel.sorts.iter().map(Value::from).collect::<Vec<_>>();
556 let arguments = Some(Arguments {
557 positional,
558 named: vec![],
559 });
560 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
562 let emit = get_emit(rel.common.as_ref());
563 Relation {
564 name: "Sort",
565 arguments,
566 columns,
567 emit,
568 children,
569 }
570 }
571}
572
573impl<'a> From<&'a FetchRel> for Relation<'a> {
574 fn from(rel: &'a FetchRel) -> Self {
575 let (children, _columns) = Relation::convert_children(vec![rel.input.as_deref()]);
576 let mut named_args = Vec::new();
577 match &rel.count_mode {
578 Some(CountMode::CountExpr(expr)) => {
579 named_args.push(NamedArg {
580 name: "limit",
581 value: Value::Expression(expr),
582 });
583 }
584 Some(CountMode::Count(val)) => {
585 named_args.push(NamedArg {
586 name: "limit",
587 value: Value::Integer(*val as i32),
588 });
589 }
590 None => {}
591 }
592 if let Some(offset) = &rel.offset_mode {
593 match offset {
594 substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
595 named_args.push(NamedArg {
596 name: "offset",
597 value: Value::Expression(expr),
598 });
599 }
600 substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
601 named_args.push(NamedArg {
602 name: "offset",
603 value: Value::Integer(*val as i32),
604 });
605 }
606 }
607 }
608
609 let emit = get_emit(rel.common.as_ref());
610 let columns = match emit {
611 Some(EmitKind::Emit(e)) => e
612 .output_mapping
613 .iter()
614 .map(|&i| Value::Reference(i))
615 .collect(),
616 _ => vec![],
617 };
618 Relation {
619 name: "Fetch",
620 arguments: Some(Arguments {
621 positional: vec![],
622 named: named_args,
623 }),
624 columns,
625 emit,
626 children,
627 }
628 }
629}
630
631fn join_output_columns(
632 join_type: join_rel::JoinType,
633 left_columns: usize,
634 right_columns: usize,
635) -> Vec<Value<'static>> {
636 let total_columns = match join_type {
637 join_rel::JoinType::Inner
639 | join_rel::JoinType::Left
640 | join_rel::JoinType::Right
641 | join_rel::JoinType::Outer => left_columns + right_columns,
642
643 join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
645
646 join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
648
649 join_rel::JoinType::LeftSingle => left_columns,
651 join_rel::JoinType::RightSingle => right_columns,
652
653 join_rel::JoinType::LeftMark => left_columns + 1,
655 join_rel::JoinType::RightMark => right_columns + 1,
656
657 join_rel::JoinType::Unspecified => left_columns + right_columns,
659 };
660
661 (0..total_columns)
663 .map(|i| Value::Reference(i as i32))
664 .collect()
665}
666
667impl<'a> From<&'a JoinRel> for Relation<'a> {
668 fn from(rel: &'a JoinRel) -> Self {
669 let (children, _total_columns) =
670 Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()]);
671
672 assert_eq!(
674 children.len(),
675 2,
676 "convert_children should return same number of elements as input"
677 );
678
679 let left_columns = match &children[0] {
681 Some(child) => child.emitted(),
682 None => 0,
683 };
684 let right_columns = match &children[1] {
685 Some(child) => child.emitted(),
686 None => 0,
687 };
688
689 let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
692 Ok(join_type) => {
693 let join_type_value = match join_type.as_enum_str() {
694 Ok(s) => Value::Enum(s),
695 Err(e) => Value::Missing(e),
696 };
697 (join_type, join_type_value)
698 }
699 Err(_) => {
700 let join_type_error = Value::Missing(PlanError::invalid(
702 "JoinRel",
703 Some("type"),
704 format!("Unknown join type: {}", rel.r#type),
705 ));
706 (join_rel::JoinType::Unspecified, join_type_error)
707 }
708 };
709
710 let condition = rel
712 .expression
713 .as_ref()
714 .map(|c| Value::Expression(c.as_ref()));
715 let condition = Value::expect(condition, || {
716 PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
717 });
718
719 let positional = vec![join_type_value, condition];
723 let arguments = Some(Arguments {
724 positional,
725 named: vec![],
726 });
727
728 let emit = get_emit(rel.common.as_ref());
729 let columns = join_output_columns(join_type, left_columns, right_columns);
730
731 Relation {
732 name: "Join",
733 arguments,
734 columns,
735 emit,
736 children,
737 }
738 }
739}
740
741impl<'a> From<&'a SortField> for Value<'a> {
742 fn from(sf: &'a SortField) -> Self {
743 let field = match &sf.expr {
744 Some(expr) => match &expr.rex_type {
745 Some(substrait::proto::expression::RexType::Selection(fref)) => {
746 if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
747 if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
748 Value::Reference(sf.field)
749 } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
750 } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
751 }
752 _ => Value::Missing(PlanError::unimplemented(
753 "SortField",
754 Some("expr"),
755 "Not a selection",
756 )),
757 },
758 None => Value::Missing(PlanError::unimplemented(
759 "SortField",
760 Some("expr"),
761 "Missing expr",
762 )),
763 };
764 let direction = match &sf.sort_kind {
765 Some(kind) => Value::from(kind),
766 None => Value::Missing(PlanError::invalid(
767 "SortKind",
768 Some(Cow::Borrowed("sort_kind")),
769 "Missing sort_kind",
770 )),
771 };
772 Value::Tuple(vec![field, direction])
773 }
774}
775
776impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
777 fn from(enum_val: &'a T) -> Self {
778 match enum_val.as_enum_str() {
779 Ok(s) => Value::Enum(s),
780 Err(e) => Value::Missing(e),
781 }
782 }
783}
784
785impl ValueEnum for SortKind {
786 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
787 let d = match self {
788 &SortKind::Direction(d) => SortDirection::try_from(d),
789 SortKind::ComparisonFunctionReference(f) => {
790 return Err(PlanError::invalid(
791 "SortKind",
792 Some(Cow::Owned(format!("function reference{f}"))),
793 "SortKind::ComparisonFunctionReference unimplemented",
794 ));
795 }
796 };
797 let s = match d {
798 Err(UnknownEnumValue(d)) => {
799 return Err(PlanError::invalid(
800 "SortKind",
801 Some(Cow::Owned(format!("unknown variant: {d:?}"))),
802 "Unknown SortDirection",
803 ));
804 }
805 Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
806 Ok(SortDirection::AscNullsLast) => "AscNullsLast",
807 Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
808 Ok(SortDirection::DescNullsLast) => "DescNullsLast",
809 Ok(SortDirection::Clustered) => "Clustered",
810 Ok(SortDirection::Unspecified) => {
811 return Err(PlanError::invalid(
812 "SortKind",
813 Option::<Cow<str>>::None,
814 "Unspecified SortDirection",
815 ));
816 }
817 };
818 Ok(Cow::Borrowed(s))
819 }
820}
821
822impl ValueEnum for join_rel::JoinType {
823 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
824 let s = match self {
825 join_rel::JoinType::Unspecified => {
826 return Err(PlanError::invalid(
827 "JoinType",
828 Option::<Cow<str>>::None,
829 "Unspecified JoinType",
830 ));
831 }
832 join_rel::JoinType::Inner => "Inner",
833 join_rel::JoinType::Outer => "Outer",
834 join_rel::JoinType::Left => "Left",
835 join_rel::JoinType::Right => "Right",
836 join_rel::JoinType::LeftSemi => "LeftSemi",
837 join_rel::JoinType::RightSemi => "RightSemi",
838 join_rel::JoinType::LeftAnti => "LeftAnti",
839 join_rel::JoinType::RightAnti => "RightAnti",
840 join_rel::JoinType::LeftSingle => "LeftSingle",
841 join_rel::JoinType::RightSingle => "RightSingle",
842 join_rel::JoinType::LeftMark => "LeftMark",
843 join_rel::JoinType::RightMark => "RightMark",
844 };
845 Ok(Cow::Borrowed(s))
846 }
847}
848
849impl<'a> Textify for NamedArg<'a> {
850 fn name() -> &'static str {
851 "NamedArg"
852 }
853 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
854 write!(w, "{}=", self.name)?;
855 self.value.textify(ctx, w)
856 }
857}
858
#[cfg(test)]
mod tests {
    use substrait::proto::expression::literal::LiteralType;
    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::rel_common::Emit;
    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
    };

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::expressions::FieldIndex;

    /// Extension URI registered by the aggregate tests.
    const AGGREGATE_URI: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml";

    /// A nullable `i32` type with no type variation.
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(ptype::I32 {
                type_variation_reference: 0,
                nullability: Nullability::Nullable as i32,
            })),
        }
    }

    /// A nullable `string` type with no type variation.
    fn nullable_string() -> Type {
        Type {
            kind: Some(Kind::String(ptype::String {
                type_variation_reference: 0,
                nullability: Nullability::Nullable as i32,
            })),
        }
    }

    /// A nullable `fp64` type with no type variation.
    fn nullable_fp64() -> Type {
        Type {
            kind: Some(Kind::Fp64(ptype::Fp64 {
                type_variation_reference: 0,
                nullability: Nullability::Nullable as i32,
            })),
        }
    }

    /// A read of the named table `table` exposing the given `(name, type)` columns.
    fn named_table_read(table: &[&str], columns: Vec<(&str, Type)>) -> ReadRel {
        let mut names = Vec::with_capacity(columns.len());
        let mut types = Vec::with_capacity(columns.len());
        for (name, typ) in columns {
            names.push(name.to_string());
            types.push(typ);
        }

        ReadRel {
            common: None,
            base_schema: Some(NamedStruct {
                names,
                r#struct: Some(Struct {
                    type_variation_reference: 0,
                    types,
                    nullability: Nullability::Nullable as i32,
                }),
            }),
            filter: None,
            best_effort_filter: None,
            projection: None,
            advanced_extension: None,
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: table.iter().map(|part| part.to_string()).collect(),
                advanced_extension: None,
            })),
        }
    }

    /// An aggregate function call whose single argument is field `$1`.
    fn aggregate_over_field_one(function_reference: u32) -> AggregateFunction {
        AggregateFunction {
            function_reference,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(
                        FieldIndex(1).to_field_reference(),
                    ))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            #[allow(deprecated)]
            args: vec![],
        }
    }

    #[test]
    fn test_read_rel() {
        let ctx = TestContext::new();

        let read_rel = named_table_read(
            &["some_db", "test_table"],
            vec![("col1", nullable_i32()), ("column 2", nullable_string())],
        );

        let rel = Relation::from(&read_rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(
            result,
            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
        );
    }

    #[test]
    fn test_filter_rel() {
        let ctx = TestContext::new()
            .with_uri(1, "test_uri")
            .with_function(1, 10, "gt");

        let read_rel = named_table_read(
            &["test_table"],
            vec![("col1", nullable_i32()), ("col2", nullable_i32())],
        );

        let ten = Expression {
            rex_type: Some(RexType::Literal(Literal {
                literal_type: Some(LiteralType::I32(10)),
                nullable: false,
                type_variation_reference: 0,
            })),
        };
        let filter_expr = Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                function_reference: 10,
                arguments: vec![
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Reference(0).into())),
                    },
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(ten)),
                    },
                ],
                options: vec![],
                output_type: None,
                #[allow(deprecated)]
                args: vec![],
            })),
        };

        let filter_rel = FilterRel {
            common: None,
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(read_rel))),
            })),
            condition: Some(Box::new(filter_expr)),
            advanced_extension: None,
        };

        let rel = Rel {
            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
        };

        let rel = Relation::from(&rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        let expected = "Filter[gt($0, 10:i32) => $0, $1]\n  Read[test_table => col1:i32?, col2:i32?]";
        assert_eq!(result, expected);
    }

    #[test]
    fn test_aggregate_function_textify() {
        let ctx = TestContext::new()
            .with_uri(1, AGGREGATE_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        let agg_fn = aggregate_over_field_one(10);

        let value = Value::AggregateFunction(&agg_fn);
        let (result, errors) = ctx.textify(&value);

        println!("Textification result: {result}");
        if !errors.is_empty() {
            println!("Errors: {errors:?}");
        }

        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "sum($1)");
    }

    #[test]
    fn test_aggregate_relation_textify() {
        let ctx = TestContext::new()
            .with_uri(1, AGGREGATE_URI)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        let measures = [10, 11]
            .into_iter()
            .map(|function_reference| aggregate_rel::Measure {
                measure: Some(aggregate_over_field_one(function_reference)),
                filter: None,
            })
            .collect();

        let orders = named_table_read(
            &["orders"],
            vec![("category", nullable_string()), ("amount", nullable_fp64())],
        );

        let aggregate_rel = AggregateRel {
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(orders))),
            })),
            grouping_expressions: vec![Expression {
                rex_type: Some(RexType::Selection(Box::new(
                    FieldIndex(0).to_field_reference(),
                ))),
            }],
            groupings: vec![],
            measures,
            common: Some(RelCommon {
                emit_kind: Some(EmitKind::Emit(Emit {
                    output_mapping: vec![1, 2],
                })),
                ..Default::default()
            }),
            advanced_extension: None,
        };

        let relation = Relation::from(&aggregate_rel);
        let (result, errors) = ctx.textify(&relation);

        println!("Aggregate relation textification result:");
        println!("{result}");
        if !errors.is_empty() {
            println!("Errors: {errors:?}");
        }

        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert!(result.contains("Aggregate[$0 => sum($1), count($1)]"));
    }

    #[test]
    fn test_arguments_textify_positional_only() {
        let ctx = TestContext::new();
        let positional = [42, 7].into_iter().map(Value::Integer).collect();
        let args = Arguments {
            positional,
            named: vec![],
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "42, 7");
    }

    #[test]
    fn test_arguments_textify_named_only() {
        let ctx = TestContext::new();
        let named = [("limit", 10), ("offset", 5)]
            .into_iter()
            .map(|(name, number)| NamedArg {
                name,
                value: Value::Integer(number),
            })
            .collect();
        let args = Arguments {
            positional: vec![],
            named,
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "limit=10, offset=5");
    }

    #[test]
    fn test_join_relation_unknown_type() {
        let ctx = TestContext::new();

        let empty_read = || {
            Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::default())),
            }))
        };

        let join_rel = JoinRel {
            left: empty_read(),
            right: empty_read(),
            expression: Some(Box::new(Expression::default())),
            r#type: 999,
            common: None,
            post_join_filter: None,
            advanced_extension: None,
        };

        let relation = Relation::from(&join_rel);
        let (result, errors) = ctx.textify(&relation);

        assert!(!errors.is_empty(), "Expected errors for unknown join type");
        assert!(
            result.contains("!{JoinRel}"),
            "Expected error token for unknown join type"
        );
        assert!(
            result.contains("Join["),
            "Expected Join relation to be formatted"
        );
        println!("Unknown join type result: {result}");
    }

    #[test]
    fn test_arguments_textify_both() {
        let ctx = TestContext::new();
        let foo = NamedArg {
            name: "foo",
            value: Value::Integer(2),
        };
        let args = Arguments {
            positional: vec![Value::Integer(1)],
            named: vec![foo],
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "1, foo=2");
    }

    #[test]
    fn test_arguments_textify_empty() {
        let ctx = TestContext::new();
        let args = Arguments {
            positional: Vec::new(),
            named: Vec::new(),
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "_");
    }

    #[test]
    fn test_named_arg_textify_error_token() {
        let ctx = TestContext::new();
        let error = PlanError::invalid(
            "my_enum",
            Some(Cow::Borrowed("my_enum")),
            Cow::Borrowed("my_enum"),
        );
        let named_arg = NamedArg {
            name: "foo",
            value: Value::Missing(error),
        };
        let (result, errors) = ctx.textify(&named_arg);
        assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
        assert!(!errors.is_empty(), "Expected error for error token");
    }

    #[test]
    fn test_join_type_enum_textify() {
        let expectations = [
            (join_rel::JoinType::Inner, "Inner"),
            (join_rel::JoinType::Left, "Left"),
            (join_rel::JoinType::LeftSemi, "LeftSemi"),
            (join_rel::JoinType::LeftAnti, "LeftAnti"),
        ];
        for (join_type, expected) in expectations {
            assert_eq!(join_type.as_enum_str().unwrap(), expected);
        }
    }

    #[test]
    fn test_join_output_columns() {
        // (join type, expected column count, positions whose reference is checked)
        let cases = [
            (join_rel::JoinType::Inner, 5, vec![0usize, 4]),
            (join_rel::JoinType::LeftSemi, 2, vec![0, 1]),
            (join_rel::JoinType::RightSemi, 3, vec![0, 1, 2]),
            (join_rel::JoinType::LeftMark, 3, vec![0, 1, 2]),
            (join_rel::JoinType::RightMark, 4, vec![0, 1, 2, 3]),
        ];

        for (join_type, expected_len, checked_positions) in cases {
            let columns = super::join_output_columns(join_type, 2, 3);
            assert_eq!(columns.len(), expected_len);
            for position in checked_positions {
                assert!(matches!(
                    columns[position],
                    Value::Reference(index) if index == position as i32
                ));
            }
        }
    }
}