1use std::borrow::Cow;
2use std::convert::TryFrom;
3use std::fmt;
4use std::fmt::Debug;
5
6use prost::UnknownEnumValue;
7use substrait::proto::fetch_rel::CountMode;
8use substrait::proto::plan_rel::RelType as PlanRelType;
9use substrait::proto::read_rel::ReadType;
10use substrait::proto::rel::RelType;
11use substrait::proto::rel_common::EmitKind;
12use substrait::proto::sort_field::{SortDirection, SortKind};
13use substrait::proto::{
14 AggregateFunction, AggregateRel, Expression, FetchRel, FilterRel, JoinRel, NamedStruct,
15 PlanRel, ProjectRel, ReadRel, Rel, RelCommon, RelRoot, SortField, SortRel, Type, join_rel,
16};
17
18use super::expressions::Reference;
19use super::types::Name;
20use super::{PlanError, Scope, Textify};
21
22pub trait NamedRelation {
23 fn name(&self) -> &'static str;
24}
25
26impl NamedRelation for Rel {
27 fn name(&self) -> &'static str {
28 match self.rel_type.as_ref() {
29 None => "UnknownRel",
30 Some(RelType::Read(_)) => "Read",
31 Some(RelType::Filter(_)) => "Filter",
32 Some(RelType::Project(_)) => "Project",
33 Some(RelType::Fetch(_)) => "Fetch",
34 Some(RelType::Aggregate(_)) => "Aggregate",
35 Some(RelType::Sort(_)) => "Sort",
36 Some(RelType::HashJoin(_)) => "HashJoin",
37 Some(RelType::Exchange(_)) => "Exchange",
38 Some(RelType::Join(_)) => "Join",
39 Some(RelType::Set(_)) => "Set",
40 Some(RelType::ExtensionLeaf(_)) => "ExtensionLeaf",
41 Some(RelType::Cross(_)) => "Cross",
42 Some(RelType::Reference(_)) => "Reference",
43 Some(RelType::ExtensionSingle(_)) => "ExtensionSingle",
44 Some(RelType::ExtensionMulti(_)) => "ExtensionMulti",
45 Some(RelType::Write(_)) => "Write",
46 Some(RelType::Ddl(_)) => "Ddl",
47 Some(RelType::Update(_)) => "Update",
48 Some(RelType::MergeJoin(_)) => "MergeJoin",
49 Some(RelType::NestedLoopJoin(_)) => "NestedLoopJoin",
50 Some(RelType::Window(_)) => "Window",
51 Some(RelType::Expand(_)) => "Expand",
52 }
53 }
54}
55
56pub trait ValueEnum {
62 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError>;
63}
64
65#[derive(Debug, Clone)]
66pub struct NamedArg<'a> {
67 pub name: &'a str,
68 pub value: Value<'a>,
69}
70
71#[derive(Debug, Clone)]
72pub enum Value<'a> {
73 Name(Name<'a>),
74 TableName(Vec<Name<'a>>),
75 Field(Option<Name<'a>>, Option<&'a Type>),
76 Tuple(Vec<Value<'a>>),
77 List(Vec<Value<'a>>),
78 Reference(i32),
79 Expression(&'a Expression),
80 AggregateFunction(&'a AggregateFunction),
81 Missing(PlanError),
83 Enum(Cow<'a, str>),
85 Integer(i32),
86}
87
88impl<'a> Value<'a> {
89 pub fn expect(maybe_value: Option<Self>, f: impl FnOnce() -> PlanError) -> Self {
90 match maybe_value {
91 Some(s) => s,
92 None => Value::Missing(f()),
93 }
94 }
95}
96
97impl<'a> From<Result<Vec<Name<'a>>, PlanError>> for Value<'a> {
98 fn from(token: Result<Vec<Name<'a>>, PlanError>) -> Self {
99 match token {
100 Ok(value) => Value::TableName(value),
101 Err(err) => Value::Missing(err),
102 }
103 }
104}
105
106impl<'a> Textify for Value<'a> {
107 fn name() -> &'static str {
108 "Value"
109 }
110
111 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
112 match self {
113 Value::Name(name) => write!(w, "{}", ctx.display(name)),
114 Value::TableName(names) => write!(w, "{}", ctx.separated(names, ".")),
115 Value::Field(name, typ) => {
116 write!(w, "{}:{}", ctx.expect(name.as_ref()), ctx.expect(*typ))
117 }
118 Value::Tuple(values) => write!(w, "({})", ctx.separated(values, ", ")),
119 Value::List(values) => write!(w, "[{}]", ctx.separated(values, ", ")),
120 Value::Reference(i) => write!(w, "{}", Reference(*i)),
121 Value::Expression(e) => write!(w, "{}", ctx.display(*e)),
122 Value::AggregateFunction(agg_fn) => agg_fn.textify(ctx, w),
123 Value::Missing(err) => write!(w, "{}", ctx.failure(err.clone())),
124 Value::Enum(res) => write!(w, "&{res}"),
125 Value::Integer(i) => write!(w, "{i}"),
126 }
127 }
128}
129
130fn schema_to_values<'a>(schema: &'a NamedStruct) -> Vec<Value<'a>> {
131 let mut fields = schema
132 .r#struct
133 .as_ref()
134 .map(|s| s.types.iter())
135 .into_iter()
136 .flatten();
137 let mut names = schema.names.iter();
138
139 let mut values = Vec::new();
143 loop {
144 let field = fields.next();
145 let name = names.next().map(|n| Name(n));
146 if field.is_none() && name.is_none() {
147 break;
148 }
149
150 values.push(Value::Field(name, field));
151 }
152
153 values
154}
155
156struct Emitted<'a> {
157 pub values: &'a [Value<'a>],
158 pub emit: Option<&'a EmitKind>,
159}
160
161impl<'a> Emitted<'a> {
162 pub fn new(values: &'a [Value<'a>], emit: Option<&'a EmitKind>) -> Self {
163 Self { values, emit }
164 }
165
166 pub fn write_direct<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
167 write!(w, "{}", ctx.separated(self.values.iter(), ", "))
168 }
169}
170
171impl<'a> Textify for Emitted<'a> {
172 fn name() -> &'static str {
173 "Emitted"
174 }
175
176 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
177 if ctx.options().show_emit {
178 return self.write_direct(ctx, w);
179 }
180
181 let indices = match &self.emit {
182 Some(EmitKind::Emit(e)) => &e.output_mapping,
183 Some(EmitKind::Direct(_)) => return self.write_direct(ctx, w),
184 None => return self.write_direct(ctx, w),
185 };
186
187 for (i, &index) in indices.iter().enumerate() {
188 if i > 0 {
189 write!(w, ", ")?;
190 }
191
192 match self.values.get(index as usize) {
193 Some(value) => write!(w, "{}", ctx.display(value))?,
194 None => write!(w, "{}", ctx.failure(PlanError::invalid(
195 "Emitted",
196 Some("output_mapping"),
197 format!(
198 "Output mapping index {} is out of bounds for values collection of size {}",
199 index, self.values.len()
200 )
201 )))?,
202 }
203 }
204
205 Ok(())
206 }
207}
208
209#[derive(Debug, Clone)]
210pub struct Arguments<'a> {
211 pub positional: Vec<Value<'a>>,
213 pub named: Vec<NamedArg<'a>>,
215}
216
217impl<'a> Textify for Arguments<'a> {
218 fn name() -> &'static str {
219 "Arguments"
220 }
221 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
222 if self.positional.is_empty() && self.named.is_empty() {
223 return write!(w, "_");
224 }
225
226 write!(w, "{}", ctx.separated(self.positional.iter(), ", "))?;
227 if !self.positional.is_empty() && !self.named.is_empty() {
228 write!(w, ", ")?;
229 }
230 write!(w, "{}", ctx.separated(self.named.iter(), ", "))
231 }
232}
233
234pub struct Relation<'a> {
235 pub name: &'a str,
236 pub arguments: Option<Arguments<'a>>,
245 pub columns: Vec<Value<'a>>,
248 pub emit: Option<&'a EmitKind>,
250 pub children: Vec<Option<Relation<'a>>>,
252}
253
254impl Textify for Relation<'_> {
255 fn name() -> &'static str {
256 "Relation"
257 }
258
259 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
260 let cols = Emitted::new(&self.columns, self.emit);
261 let indent = ctx.indent();
262 let name = self.name;
263 let cols = ctx.display(&cols);
264 match &self.arguments {
265 None => {
266 write!(w, "{indent}{name}[{cols}]")?;
267 }
268 Some(args) => {
269 let args = ctx.display(args);
270 write!(w, "{indent}{name}[{args} => {cols}]")?;
271 }
272 }
273 let child_scope = ctx.push_indent();
274 for child in self.children.iter().flatten() {
275 writeln!(w)?;
276 child.textify(&child_scope, w)?;
277 }
278 Ok(())
279 }
280}
281
282impl<'a> Relation<'a> {
283 pub fn emitted(&self) -> usize {
284 match self.emit {
285 Some(EmitKind::Emit(e)) => e.output_mapping.len(),
286 Some(EmitKind::Direct(_)) => self.columns.len(),
287 None => self.columns.len(),
288 }
289 }
290}
291
292#[derive(Debug, Copy, Clone)]
293pub struct TableName<'a>(&'a [String]);
294
295impl<'a> Textify for TableName<'a> {
296 fn name() -> &'static str {
297 "TableName"
298 }
299
300 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
301 let names = self.0.iter().map(|n| Name(n)).collect::<Vec<_>>();
302 write!(w, "{}", ctx.separated(names.iter(), "."))
303 }
304}
305
306pub fn get_table_name(rel: Option<&ReadType>) -> Result<&[String], PlanError> {
307 match rel {
308 Some(ReadType::NamedTable(r)) => Ok(r.names.as_slice()),
309 _ => Err(PlanError::unimplemented(
310 "ReadRel",
311 Some("table_name"),
312 format!("Unexpected read type {rel:?}") as String,
313 )),
314 }
315}
316
317impl<'a> From<&'a ReadRel> for Relation<'a> {
318 fn from(rel: &'a ReadRel) -> Self {
319 let name = get_table_name(rel.read_type.as_ref());
320 let table_name: Value = match name {
321 Ok(n) => Value::TableName(n.iter().map(|n| Name(n)).collect()),
322 Err(e) => Value::Missing(e),
323 };
324
325 let columns = match rel.base_schema {
326 Some(ref schema) => schema_to_values(schema),
327 None => {
328 let err = PlanError::unimplemented(
329 "ReadRel",
330 Some("base_schema"),
331 "Base schema is required",
332 );
333 vec![Value::Missing(err)]
334 }
335 };
336 let emit = rel.common.as_ref().and_then(|c| c.emit_kind.as_ref());
337
338 Relation {
339 name: "Read",
340 arguments: Some(Arguments {
341 positional: vec![table_name],
342 named: vec![],
343 }),
344 columns,
345 emit,
346 children: vec![],
347 }
348 }
349}
350
351pub fn get_emit(rel: Option<&RelCommon>) -> Option<&EmitKind> {
352 rel.as_ref().and_then(|c| c.emit_kind.as_ref())
353}
354
355impl<'a> Relation<'a> {
356 pub fn input_refs(&self) -> Vec<Value<'a>> {
364 let len = self.emitted();
365 (0..len).map(|i| Value::Reference(i as i32)).collect()
366 }
367
368 pub fn convert_children(refs: Vec<Option<&'a Rel>>) -> (Vec<Option<Relation<'a>>>, usize) {
372 let mut children = vec![];
373 let mut inputs = 0;
374
375 for maybe_rel in refs {
376 match maybe_rel {
377 Some(rel) => {
378 let child = Relation::from(rel);
379 inputs += child.emitted();
380 children.push(Some(child));
381 }
382 None => children.push(None),
383 }
384 }
385
386 (children, inputs)
387 }
388}
389
390impl<'a> From<&'a FilterRel> for Relation<'a> {
391 fn from(rel: &'a FilterRel) -> Self {
392 let condition = rel
393 .condition
394 .as_ref()
395 .map(|c| Value::Expression(c.as_ref()));
396 let condition = Value::expect(condition, || {
397 PlanError::unimplemented("FilterRel", Some("condition"), "Condition is None")
398 });
399 let positional = vec![condition];
400 let arguments = Some(Arguments {
401 positional,
402 named: vec![],
403 });
404 let emit = get_emit(rel.common.as_ref());
405 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
406 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
407
408 Relation {
409 name: "Filter",
410 arguments,
411 columns,
412 emit,
413 children,
414 }
415 }
416}
417
418impl<'a> From<&'a ProjectRel> for Relation<'a> {
419 fn from(rel: &'a ProjectRel) -> Self {
420 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
421 let expressions = rel.expressions.iter().map(Value::Expression);
422 let mut columns: Vec<Value> = (0..columns).map(|i| Value::Reference(i as i32)).collect();
423 columns.extend(expressions);
424
425 Relation {
426 name: "Project",
427 arguments: None,
428 columns,
429 emit: get_emit(rel.common.as_ref()),
430 children,
431 }
432 }
433}
434
435impl<'a> From<&'a Rel> for Relation<'a> {
436 fn from(rel: &'a Rel) -> Self {
437 match rel.rel_type.as_ref() {
438 Some(RelType::Read(r)) => Relation::from(r.as_ref()),
439 Some(RelType::Filter(r)) => Relation::from(r.as_ref()),
440 Some(RelType::Project(r)) => Relation::from(r.as_ref()),
441 Some(RelType::Aggregate(r)) => Relation::from(r.as_ref()),
442 Some(RelType::Sort(r)) => Relation::from(r.as_ref()),
443 Some(RelType::Fetch(r)) => Relation::from(r.as_ref()),
444 Some(RelType::Join(r)) => Relation::from(r.as_ref()),
445 _ => todo!(),
446 }
447 }
448}
449
450impl<'a> From<&'a AggregateRel> for Relation<'a> {
451 fn from(rel: &'a AggregateRel) -> Self {
461 let positional = rel
463 .grouping_expressions
464 .iter()
465 .map(Value::Expression)
466 .collect::<Vec<_>>();
467
468 let arguments = Some(Arguments {
469 positional,
470 named: vec![],
471 });
472 let mut all_outputs: Vec<Value> = vec![];
474 let input_field_count = rel.grouping_expressions.len();
475 for i in 0..input_field_count {
476 all_outputs.push(Value::Reference(i as i32));
477 }
478
479 for m in &rel.measures {
482 if let Some(agg_fn) = m.measure.as_ref() {
483 all_outputs.push(Value::AggregateFunction(agg_fn));
484 }
485 }
486 let emit = get_emit(rel.common.as_ref());
487 Relation {
488 name: "Aggregate",
489 arguments,
490 columns: all_outputs,
491 emit,
492 children: rel
493 .input
494 .as_ref()
495 .map(|c| Some(Relation::from(c.as_ref())))
496 .into_iter()
497 .collect(),
498 }
499 }
500}
501
502impl Textify for RelRoot {
503 fn name() -> &'static str {
504 "RelRoot"
505 }
506
507 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
508 let names = self.names.iter().map(|n| Name(n)).collect::<Vec<_>>();
509
510 write!(
511 w,
512 "{}Root[{}]",
513 ctx.indent(),
514 ctx.separated(names.iter(), ", ")
515 )?;
516 let child_scope = ctx.push_indent();
517 for child in self.input.iter() {
518 let child = Relation::from(child);
519 writeln!(w)?;
520 child.textify(&child_scope, w)?;
521 }
522
523 Ok(())
524 }
525}
526
527impl Textify for PlanRelType {
528 fn name() -> &'static str {
529 "PlanRelType"
530 }
531
532 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
533 match self {
534 PlanRelType::Rel(rel) => Relation::from(rel).textify(ctx, w),
535 PlanRelType::Root(root) => root.textify(ctx, w),
536 }
537 }
538}
539
540impl Textify for PlanRel {
541 fn name() -> &'static str {
542 "PlanRel"
543 }
544
545 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
548 write!(w, "{}", ctx.expect(self.rel_type.as_ref()))
549 }
550}
551
552impl<'a> From<&'a SortRel> for Relation<'a> {
553 fn from(rel: &'a SortRel) -> Self {
554 let (children, columns) = Relation::convert_children(vec![rel.input.as_deref()]);
555 let positional = rel.sorts.iter().map(Value::from).collect::<Vec<_>>();
556 let arguments = Some(Arguments {
557 positional,
558 named: vec![],
559 });
560 let columns = (0..columns).map(|i| Value::Reference(i as i32)).collect();
562 let emit = get_emit(rel.common.as_ref());
563 Relation {
564 name: "Sort",
565 arguments,
566 columns,
567 emit,
568 children,
569 }
570 }
571}
572
573impl<'a> From<&'a FetchRel> for Relation<'a> {
574 fn from(rel: &'a FetchRel) -> Self {
575 let (children, _columns) = Relation::convert_children(vec![rel.input.as_deref()]);
576 let mut named_args = Vec::new();
577 match &rel.count_mode {
578 Some(CountMode::CountExpr(expr)) => {
579 named_args.push(NamedArg {
580 name: "limit",
581 value: Value::Expression(expr),
582 });
583 }
584 #[allow(deprecated)]
585 Some(CountMode::Count(val)) => {
586 named_args.push(NamedArg {
587 name: "limit",
588 value: Value::Integer(*val as i32),
589 });
590 }
591 None => {}
592 }
593 if let Some(offset) = &rel.offset_mode {
594 match offset {
595 substrait::proto::fetch_rel::OffsetMode::OffsetExpr(expr) => {
596 named_args.push(NamedArg {
597 name: "offset",
598 value: Value::Expression(expr),
599 });
600 }
601 #[allow(deprecated)]
602 substrait::proto::fetch_rel::OffsetMode::Offset(val) => {
603 named_args.push(NamedArg {
604 name: "offset",
605 value: Value::Integer(*val as i32),
606 });
607 }
608 }
609 }
610
611 let emit = get_emit(rel.common.as_ref());
612 let columns = match emit {
613 Some(EmitKind::Emit(e)) => e
614 .output_mapping
615 .iter()
616 .map(|&i| Value::Reference(i))
617 .collect(),
618 _ => vec![],
619 };
620 Relation {
621 name: "Fetch",
622 arguments: Some(Arguments {
623 positional: vec![],
624 named: named_args,
625 }),
626 columns,
627 emit,
628 children,
629 }
630 }
631}
632
633fn join_output_columns(
634 join_type: join_rel::JoinType,
635 left_columns: usize,
636 right_columns: usize,
637) -> Vec<Value<'static>> {
638 let total_columns = match join_type {
639 join_rel::JoinType::Inner
641 | join_rel::JoinType::Left
642 | join_rel::JoinType::Right
643 | join_rel::JoinType::Outer => left_columns + right_columns,
644
645 join_rel::JoinType::LeftSemi | join_rel::JoinType::LeftAnti => left_columns,
647
648 join_rel::JoinType::RightSemi | join_rel::JoinType::RightAnti => right_columns,
650
651 join_rel::JoinType::LeftSingle => left_columns,
653 join_rel::JoinType::RightSingle => right_columns,
654
655 join_rel::JoinType::LeftMark => left_columns + 1,
657 join_rel::JoinType::RightMark => right_columns + 1,
658
659 join_rel::JoinType::Unspecified => left_columns + right_columns,
661 };
662
663 (0..total_columns)
665 .map(|i| Value::Reference(i as i32))
666 .collect()
667}
668
669impl<'a> From<&'a JoinRel> for Relation<'a> {
670 fn from(rel: &'a JoinRel) -> Self {
671 let (children, _total_columns) =
672 Relation::convert_children(vec![rel.left.as_deref(), rel.right.as_deref()]);
673
674 assert_eq!(
676 children.len(),
677 2,
678 "convert_children should return same number of elements as input"
679 );
680
681 let left_columns = match &children[0] {
683 Some(child) => child.emitted(),
684 None => 0,
685 };
686 let right_columns = match &children[1] {
687 Some(child) => child.emitted(),
688 None => 0,
689 };
690
691 let (join_type, join_type_value) = match join_rel::JoinType::try_from(rel.r#type) {
694 Ok(join_type) => {
695 let join_type_value = match join_type.as_enum_str() {
696 Ok(s) => Value::Enum(s),
697 Err(e) => Value::Missing(e),
698 };
699 (join_type, join_type_value)
700 }
701 Err(_) => {
702 let join_type_error = Value::Missing(PlanError::invalid(
704 "JoinRel",
705 Some("type"),
706 format!("Unknown join type: {}", rel.r#type),
707 ));
708 (join_rel::JoinType::Unspecified, join_type_error)
709 }
710 };
711
712 let condition = rel
714 .expression
715 .as_ref()
716 .map(|c| Value::Expression(c.as_ref()));
717 let condition = Value::expect(condition, || {
718 PlanError::unimplemented("JoinRel", Some("expression"), "Join condition is None")
719 });
720
721 let positional = vec![join_type_value, condition];
725 let arguments = Some(Arguments {
726 positional,
727 named: vec![],
728 });
729
730 let emit = get_emit(rel.common.as_ref());
731 let columns = join_output_columns(join_type, left_columns, right_columns);
732
733 Relation {
734 name: "Join",
735 arguments,
736 columns,
737 emit,
738 children,
739 }
740 }
741}
742
743impl<'a> From<&'a SortField> for Value<'a> {
744 fn from(sf: &'a SortField) -> Self {
745 let field = match &sf.expr {
746 Some(expr) => match &expr.rex_type {
747 Some(substrait::proto::expression::RexType::Selection(fref)) => {
748 if let Some(substrait::proto::expression::field_reference::ReferenceType::DirectReference(seg)) = &fref.reference_type {
749 if let Some(substrait::proto::expression::reference_segment::ReferenceType::StructField(sf)) = &seg.reference_type {
750 Value::Reference(sf.field)
751 } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a struct field")) }
752 } else { Value::Missing(PlanError::unimplemented("SortField", Some("expr"), "Not a direct reference")) }
753 }
754 _ => Value::Missing(PlanError::unimplemented(
755 "SortField",
756 Some("expr"),
757 "Not a selection",
758 )),
759 },
760 None => Value::Missing(PlanError::unimplemented(
761 "SortField",
762 Some("expr"),
763 "Missing expr",
764 )),
765 };
766 let direction = match &sf.sort_kind {
767 Some(kind) => Value::from(kind),
768 None => Value::Missing(PlanError::invalid(
769 "SortKind",
770 Some(Cow::Borrowed("sort_kind")),
771 "Missing sort_kind",
772 )),
773 };
774 Value::Tuple(vec![field, direction])
775 }
776}
777
778impl<'a, T: ValueEnum + ?Sized> From<&'a T> for Value<'a> {
779 fn from(enum_val: &'a T) -> Self {
780 match enum_val.as_enum_str() {
781 Ok(s) => Value::Enum(s),
782 Err(e) => Value::Missing(e),
783 }
784 }
785}
786
787impl ValueEnum for SortKind {
788 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
789 let d = match self {
790 &SortKind::Direction(d) => SortDirection::try_from(d),
791 SortKind::ComparisonFunctionReference(f) => {
792 return Err(PlanError::invalid(
793 "SortKind",
794 Some(Cow::Owned(format!("function reference{f}"))),
795 "SortKind::ComparisonFunctionReference unimplemented",
796 ));
797 }
798 };
799 let s = match d {
800 Err(UnknownEnumValue(d)) => {
801 return Err(PlanError::invalid(
802 "SortKind",
803 Some(Cow::Owned(format!("unknown variant: {d:?}"))),
804 "Unknown SortDirection",
805 ));
806 }
807 Ok(SortDirection::AscNullsFirst) => "AscNullsFirst",
808 Ok(SortDirection::AscNullsLast) => "AscNullsLast",
809 Ok(SortDirection::DescNullsFirst) => "DescNullsFirst",
810 Ok(SortDirection::DescNullsLast) => "DescNullsLast",
811 Ok(SortDirection::Clustered) => "Clustered",
812 Ok(SortDirection::Unspecified) => {
813 return Err(PlanError::invalid(
814 "SortKind",
815 Option::<Cow<str>>::None,
816 "Unspecified SortDirection",
817 ));
818 }
819 };
820 Ok(Cow::Borrowed(s))
821 }
822}
823
824impl ValueEnum for join_rel::JoinType {
825 fn as_enum_str(&self) -> Result<Cow<'static, str>, PlanError> {
826 let s = match self {
827 join_rel::JoinType::Unspecified => {
828 return Err(PlanError::invalid(
829 "JoinType",
830 Option::<Cow<str>>::None,
831 "Unspecified JoinType",
832 ));
833 }
834 join_rel::JoinType::Inner => "Inner",
835 join_rel::JoinType::Outer => "Outer",
836 join_rel::JoinType::Left => "Left",
837 join_rel::JoinType::Right => "Right",
838 join_rel::JoinType::LeftSemi => "LeftSemi",
839 join_rel::JoinType::RightSemi => "RightSemi",
840 join_rel::JoinType::LeftAnti => "LeftAnti",
841 join_rel::JoinType::RightAnti => "RightAnti",
842 join_rel::JoinType::LeftSingle => "LeftSingle",
843 join_rel::JoinType::RightSingle => "RightSingle",
844 join_rel::JoinType::LeftMark => "LeftMark",
845 join_rel::JoinType::RightMark => "RightMark",
846 };
847 Ok(Cow::Borrowed(s))
848 }
849}
850
851impl<'a> Textify for NamedArg<'a> {
852 fn name() -> &'static str {
853 "NamedArg"
854 }
855 fn textify<S: Scope, W: fmt::Write>(&self, ctx: &S, w: &mut W) -> fmt::Result {
856 write!(w, "{}=", self.name)?;
857 self.value.textify(ctx, w)
858 }
859}
860
#[cfg(test)]
mod tests {
    use substrait::proto::expression::literal::LiteralType;
    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::read_rel::{NamedTable, ReadType};
    use substrait::proto::rel_common::Emit;
    use substrait::proto::r#type::{self as ptype, Kind, Nullability, Struct};
    use substrait::proto::{
        Expression, FunctionArgument, NamedStruct, ReadRel, Type, aggregate_rel,
    };

    use super::*;
    use crate::fixtures::TestContext;
    use crate::parser::expressions::FieldIndex;

    /// Extension URN under which the aggregate test functions are registered.
    const AGGREGATE_URN: &str =
        "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate.yaml";

    /// A nullable `i32` column type.
    fn nullable_i32() -> Type {
        Type {
            kind: Some(Kind::I32(ptype::I32 {
                type_variation_reference: 0,
                nullability: Nullability::Nullable as i32,
            })),
        }
    }

    /// A nullable `string` column type.
    fn nullable_string() -> Type {
        Type {
            kind: Some(Kind::String(ptype::String {
                type_variation_reference: 0,
                nullability: Nullability::Nullable as i32,
            })),
        }
    }

    /// A nullable `fp64` column type.
    fn nullable_fp64() -> Type {
        Type {
            kind: Some(Kind::Fp64(ptype::Fp64 {
                type_variation_reference: 0,
                nullability: Nullability::Nullable as i32,
            })),
        }
    }

    /// A read of the named table `table` (one entry per dotted part) whose base
    /// schema has the given `(name, type)` columns.
    fn named_table_read(table: &[&str], columns: Vec<(&str, Type)>) -> ReadRel {
        let (names, types): (Vec<String>, Vec<Type>) = columns
            .into_iter()
            .map(|(name, typ)| (name.to_string(), typ))
            .unzip();
        ReadRel {
            common: None,
            base_schema: Some(NamedStruct {
                names,
                r#struct: Some(Struct {
                    type_variation_reference: 0,
                    types,
                    nullability: Nullability::Nullable as i32,
                }),
            }),
            filter: None,
            best_effort_filter: None,
            projection: None,
            advanced_extension: None,
            read_type: Some(ReadType::NamedTable(NamedTable {
                names: table.iter().map(|part| part.to_string()).collect(),
                advanced_extension: None,
            })),
        }
    }

    /// An aggregate call `function_reference($field)` with default options.
    fn aggregate_call(function_reference: u32, field: FieldIndex) -> AggregateFunction {
        AggregateFunction {
            function_reference,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(Expression {
                    rex_type: Some(RexType::Selection(Box::new(field.to_field_reference()))),
                })),
            }],
            options: vec![],
            output_type: None,
            invocation: 0,
            phase: 0,
            sorts: vec![],
            #[allow(deprecated)]
            args: vec![],
        }
    }

    #[test]
    fn test_read_rel() {
        let ctx = TestContext::new();
        let read_rel = named_table_read(
            &["some_db", "test_table"],
            vec![("col1", nullable_i32()), ("column 2", nullable_string())],
        );

        let rel = Relation::from(&read_rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(
            result,
            "Read[some_db.test_table => col1:i32?, \"column 2\":string?]"
        );
    }

    #[test]
    fn test_filter_rel() {
        let ctx = TestContext::new()
            .with_urn(1, "test_urn")
            .with_function(1, 10, "gt");

        let read_rel = named_table_read(
            &["test_table"],
            vec![("col1", nullable_i32()), ("col2", nullable_i32())],
        );

        let filter_expr = Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                function_reference: 10,
                arguments: vec![
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Reference(0).into())),
                    },
                    FunctionArgument {
                        arg_type: Some(ArgType::Value(Expression {
                            rex_type: Some(RexType::Literal(Literal {
                                literal_type: Some(LiteralType::I32(10)),
                                nullable: false,
                                type_variation_reference: 0,
                            })),
                        })),
                    },
                ],
                options: vec![],
                output_type: None,
                #[allow(deprecated)]
                args: vec![],
            })),
        };

        let filter_rel = FilterRel {
            common: None,
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(read_rel))),
            })),
            condition: Some(Box::new(filter_expr)),
            advanced_extension: None,
        };

        let rel = Rel {
            rel_type: Some(RelType::Filter(Box::new(filter_rel))),
        };

        let rel = Relation::from(&rel);

        let (result, errors) = ctx.textify(&rel);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        let expected = r#"
Filter[gt($0, 10:i32) => $0, $1]
  Read[test_table => col1:i32?, col2:i32?]"#
            .trim_start();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_aggregate_function_textify() {
        let ctx = TestContext::new()
            .with_urn(1, AGGREGATE_URN)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        let agg_fn = aggregate_call(10, FieldIndex(1));

        let value = Value::AggregateFunction(&agg_fn);
        let (result, errors) = ctx.textify(&value);

        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "sum($1)");
    }

    #[test]
    fn test_aggregate_relation_textify() {
        let ctx = TestContext::new()
            .with_urn(1, AGGREGATE_URN)
            .with_function(1, 10, "sum")
            .with_function(1, 11, "count");

        let aggregate_rel = AggregateRel {
            input: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::new(named_table_read(
                    &["orders"],
                    vec![("category", nullable_string()), ("amount", nullable_fp64())],
                )))),
            })),
            grouping_expressions: vec![Expression {
                rex_type: Some(RexType::Selection(Box::new(
                    FieldIndex(0).to_field_reference(),
                ))),
            }],
            groupings: vec![],
            measures: vec![
                aggregate_rel::Measure {
                    measure: Some(aggregate_call(10, FieldIndex(1))),
                    filter: None,
                },
                aggregate_rel::Measure {
                    measure: Some(aggregate_call(11, FieldIndex(1))),
                    filter: None,
                },
            ],
            common: Some(RelCommon {
                emit_kind: Some(EmitKind::Emit(Emit {
                    output_mapping: vec![1, 2],
                })),
                ..Default::default()
            }),
            advanced_extension: None,
        };

        let relation = Relation::from(&aggregate_rel);
        let (result, errors) = ctx.textify(&relation);

        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert!(
            result.contains("Aggregate[$0 => sum($1), count($1)]"),
            "Output: {result}"
        );
    }

    #[test]
    fn test_arguments_textify_positional_only() {
        let ctx = TestContext::new();
        let args = Arguments {
            positional: vec![Value::Integer(42), Value::Integer(7)],
            named: vec![],
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "42, 7");
    }

    #[test]
    fn test_arguments_textify_named_only() {
        let ctx = TestContext::new();
        let args = Arguments {
            positional: vec![],
            named: vec![
                NamedArg {
                    name: "limit",
                    value: Value::Integer(10),
                },
                NamedArg {
                    name: "offset",
                    value: Value::Integer(5),
                },
            ],
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "limit=10, offset=5");
    }

    #[test]
    fn test_join_relation_unknown_type() {
        let ctx = TestContext::new();

        let join_rel = JoinRel {
            left: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::default())),
            })),
            right: Some(Box::new(Rel {
                rel_type: Some(RelType::Read(Box::default())),
            })),
            expression: Some(Box::new(Expression::default())),
            r#type: 999,
            common: None,
            post_join_filter: None,
            advanced_extension: None,
        };

        let relation = Relation::from(&join_rel);
        let (result, errors) = ctx.textify(&relation);

        assert!(!errors.is_empty(), "Expected errors for unknown join type");
        assert!(
            result.contains("!{JoinRel}"),
            "Expected error token for unknown join type"
        );
        assert!(
            result.contains("Join["),
            "Expected Join relation to be formatted"
        );
    }

    #[test]
    fn test_arguments_textify_both() {
        let ctx = TestContext::new();
        let args = Arguments {
            positional: vec![Value::Integer(1)],
            named: vec![NamedArg {
                name: "foo",
                value: Value::Integer(2),
            }],
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "1, foo=2");
    }

    #[test]
    fn test_arguments_textify_empty() {
        let ctx = TestContext::new();
        let args = Arguments {
            positional: vec![],
            named: vec![],
        };
        let (result, errors) = ctx.textify(&args);
        assert!(errors.is_empty(), "Expected no errors, got: {errors:?}");
        assert_eq!(result, "_");
    }

    #[test]
    fn test_named_arg_textify_error_token() {
        let ctx = TestContext::new();
        let named_arg = NamedArg {
            name: "foo",
            value: Value::Missing(PlanError::invalid(
                "my_enum",
                Some(Cow::Borrowed("my_enum")),
                Cow::Borrowed("my_enum"),
            )),
        };
        let (result, errors) = ctx.textify(&named_arg);
        assert!(result.contains("foo=!{my_enum}"), "Output: {result}");
        assert!(!errors.is_empty(), "Expected error for error token");
    }

    #[test]
    fn test_join_type_enum_textify() {
        assert_eq!(join_rel::JoinType::Inner.as_enum_str().unwrap(), "Inner");
        assert_eq!(join_rel::JoinType::Left.as_enum_str().unwrap(), "Left");
        assert_eq!(
            join_rel::JoinType::LeftSemi.as_enum_str().unwrap(),
            "LeftSemi"
        );
        assert_eq!(
            join_rel::JoinType::LeftAnti.as_enum_str().unwrap(),
            "LeftAnti"
        );
    }

    #[test]
    fn test_join_output_columns() {
        let inner_cols = super::join_output_columns(join_rel::JoinType::Inner, 2, 3);
        assert_eq!(inner_cols.len(), 5);
        assert!(matches!(inner_cols[0], Value::Reference(0)));
        assert!(matches!(inner_cols[4], Value::Reference(4)));

        let left_semi_cols = super::join_output_columns(join_rel::JoinType::LeftSemi, 2, 3);
        assert_eq!(left_semi_cols.len(), 2);
        assert!(matches!(left_semi_cols[0], Value::Reference(0)));
        assert!(matches!(left_semi_cols[1], Value::Reference(1)));

        let right_semi_cols = super::join_output_columns(join_rel::JoinType::RightSemi, 2, 3);
        assert_eq!(right_semi_cols.len(), 3);
        assert!(matches!(right_semi_cols[0], Value::Reference(0)));
        assert!(matches!(right_semi_cols[1], Value::Reference(1)));
        assert!(matches!(right_semi_cols[2], Value::Reference(2)));

        let left_mark_cols = super::join_output_columns(join_rel::JoinType::LeftMark, 2, 3);
        assert_eq!(left_mark_cols.len(), 3);
        assert!(matches!(left_mark_cols[0], Value::Reference(0)));
        assert!(matches!(left_mark_cols[1], Value::Reference(1)));
        assert!(matches!(left_mark_cols[2], Value::Reference(2)));

        let right_mark_cols = super::join_output_columns(join_rel::JoinType::RightMark, 2, 3);
        assert_eq!(right_mark_cols.len(), 4);
        assert!(matches!(right_mark_cols[0], Value::Reference(0)));
        assert!(matches!(right_mark_cols[1], Value::Reference(1)));
        assert!(matches!(right_mark_cols[2], Value::Reference(2)));
        assert!(matches!(right_mark_cols[3], Value::Reference(3)));
    }
}