substrait_explain/extensions/
args.rs1use std::collections::HashSet;
29use std::fmt;
30
31use indexmap::IndexMap;
32use substrait::proto;
33use substrait::proto::expression::field_reference::ReferenceType;
34use substrait::proto::expression::literal::LiteralType;
35use substrait::proto::expression::{RexType, reference_segment};
36
37use super::ExtensionError;
38use crate::textify::expressions::Reference;
39
/// The kind of an extension addendum, each rendered with a short three-letter tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AddendumKind {
    Enhancement,
    Optimization,
    ExtensionTable,
}

impl AddendumKind {
    /// Returns the three-letter tag used when printing this addendum kind.
    pub(crate) fn prefix(self) -> &'static str {
        let tag = match self {
            Self::Enhancement => "Enh",
            Self::Optimization => "Opt",
            Self::ExtensionTable => "Ext",
        };
        tag
    }
}
61
62#[derive(Debug, Clone)]
67pub struct Expr(Box<proto::Expression>);
68
69impl Expr {
70 pub fn field(index: i32) -> Self {
72 Reference(index).into()
73 }
74
75 pub fn as_proto(&self) -> &proto::Expression {
77 self.0.as_ref()
78 }
79
80 pub fn to_proto(&self) -> proto::Expression {
82 self.as_proto().clone()
83 }
84
85 pub fn as_direct_reference(&self) -> Option<i32> {
87 let Some(RexType::Selection(field_ref)) = self.as_proto().rex_type.as_ref() else {
88 return None;
89 };
90 let Some(ReferenceType::DirectReference(segment)) = field_ref.reference_type.as_ref()
91 else {
92 return None;
93 };
94 let Some(reference_segment::ReferenceType::StructField(field)) =
95 segment.reference_type.as_ref()
96 else {
97 return None;
98 };
99 if field.child.is_some() {
100 return None;
101 }
102 Some(field.field)
103 }
104}
105
106impl From<proto::Expression> for Expr {
107 fn from(expr: proto::Expression) -> Self {
108 Expr(Box::new(expr))
109 }
110}
111
112impl From<proto::expression::Literal> for Expr {
113 fn from(literal: proto::expression::Literal) -> Self {
114 proto::Expression {
115 rex_type: Some(RexType::Literal(literal)),
116 }
117 .into()
118 }
119}
120
121impl From<Reference> for Expr {
122 fn from(reference: Reference) -> Self {
123 proto::Expression::from(reference).into()
124 }
125}
126
127impl From<Expr> for proto::Expression {
128 fn from(expr: Expr) -> Self {
129 *expr.0
130 }
131}
132
133impl From<i64> for Expr {
134 fn from(value: i64) -> Self {
135 proto::expression::Literal {
136 literal_type: Some(LiteralType::I64(value)),
137 nullable: false,
138 type_variation_reference: 0,
139 }
140 .into()
141 }
142}
143
144impl From<f64> for Expr {
145 fn from(value: f64) -> Self {
146 proto::expression::Literal {
147 literal_type: Some(LiteralType::Fp64(value)),
148 nullable: false,
149 type_variation_reference: 0,
150 }
151 .into()
152 }
153}
154
155impl From<bool> for Expr {
156 fn from(value: bool) -> Self {
157 proto::expression::Literal {
158 literal_type: Some(LiteralType::Boolean(value)),
159 nullable: false,
160 type_variation_reference: 0,
161 }
162 .into()
163 }
164}
165
166impl From<String> for Expr {
167 fn from(value: String) -> Self {
168 proto::expression::Literal {
169 literal_type: Some(LiteralType::String(value)),
170 nullable: false,
171 type_variation_reference: 0,
172 }
173 .into()
174 }
175}
176
177impl From<&str> for Expr {
178 fn from(value: &str) -> Self {
179 value.to_string().into()
180 }
181}
182
/// The arguments attached to an extension: positional values, named values,
/// and the declared output columns.
#[derive(Debug, Clone, Default)]
pub struct ExtensionArgs {
    /// Positional arguments, in the order they were pushed.
    pub positional: Vec<ExtensionValue>,
    /// Named arguments keyed by name; an `IndexMap` keeps insertion order.
    pub named: IndexMap<String, ExtensionValue>,
    /// Output columns declared for the extension.
    pub output_columns: Vec<ExtensionColumn>,
}
198
199pub struct ArgsExtractor<'a> {
207 args: &'a ExtensionArgs,
208 consumed: HashSet<&'a str>,
209 checked: bool,
210}
211
212impl<'a> ArgsExtractor<'a> {
213 pub fn new(args: &'a ExtensionArgs) -> Self {
215 Self {
216 args,
217 consumed: HashSet::new(),
218 checked: false,
219 }
220 }
221
222 pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
224 match self.args.named.get_key_value(name) {
225 Some((k, value)) => {
226 self.consumed.insert(k);
227 Some(value)
228 }
229 None => None,
230 }
231 }
232
233 pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
236 where
237 T: TryFrom<&'a ExtensionValue>,
238 T::Error: Into<ExtensionError>,
239 {
240 match self.get_named_arg(name) {
241 Some(value) => T::try_from(value).map_err(Into::into),
242 None => Err(ExtensionError::MissingArgument {
243 name: name.to_string(),
244 }),
245 }
246 }
247
248 pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
251 where
252 T: TryFrom<&'a ExtensionValue>,
253 T::Error: Into<ExtensionError>,
254 {
255 match self.get_named_arg(name) {
256 Some(value) => T::try_from(value).map_err(Into::into),
257 None => Ok(default),
258 }
259 }
260
261 pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
268 self.checked = true;
269
270 let mut unknown_args = Vec::new();
271 for name in self.args.named.keys() {
272 if !self.consumed.contains(name.as_str()) {
273 unknown_args.push(name.as_str());
274 }
275 }
276
277 if unknown_args.is_empty() {
278 Ok(())
279 } else {
280 unknown_args.sort();
282 Err(ExtensionError::InvalidArgument(format!(
283 "Unknown named arguments: {}",
284 unknown_args.join(", ")
285 )))
286 }
287 }
288}
289
290impl Drop for ArgsExtractor<'_> {
291 fn drop(&mut self) {
292 if self.checked || std::thread::panicking() {
293 return;
294 }
295 debug_assert!(
297 false,
298 "ArgsExtractor dropped without calling check_exhausted()"
299 );
300 }
301}
302
303#[derive(Debug, Clone)]
308pub struct TupleValue(Vec<ExtensionValue>);
309
310impl TupleValue {
311 pub fn len(&self) -> usize {
312 self.0.len()
313 }
314
315 pub fn is_empty(&self) -> bool {
316 self.0.is_empty()
317 }
318
319 pub fn iter(&self) -> std::slice::Iter<'_, ExtensionValue> {
320 self.0.iter()
321 }
322}
323
324impl<'a> IntoIterator for &'a TupleValue {
325 type Item = &'a ExtensionValue;
326 type IntoIter = std::slice::Iter<'a, ExtensionValue>;
327
328 fn into_iter(self) -> Self::IntoIter {
329 self.0.iter()
330 }
331}
332
333impl IntoIterator for TupleValue {
334 type Item = ExtensionValue;
335 type IntoIter = std::vec::IntoIter<ExtensionValue>;
336
337 fn into_iter(self) -> Self::IntoIter {
338 self.0.into_iter()
339 }
340}
341
342impl FromIterator<ExtensionValue> for TupleValue {
343 fn from_iter<I: IntoIterator<Item = ExtensionValue>>(iter: I) -> Self {
344 TupleValue(iter.into_iter().collect())
345 }
346}
347
348impl From<Vec<ExtensionValue>> for TupleValue {
349 fn from(items: Vec<ExtensionValue>) -> Self {
350 TupleValue(items)
351 }
352}
353
/// A single argument value attached to an extension.
///
/// Strings, integers, floats, and booleans are held as native Rust values.
/// Expressions, including field references (see the `From<Reference>` impl),
/// are held as an [`Expr`].
#[derive(Debug, Clone)]
pub enum ExtensionValue {
    /// A string value.
    String(String),
    /// A 64-bit integer value.
    Integer(i64),
    /// A 64-bit floating-point value.
    Float(f64),
    /// A boolean value.
    Boolean(bool),

    /// An arbitrary expression. Plain field references also use this variant;
    /// `Expr::as_direct_reference` recovers their index.
    Expr(Expr),
    /// An enum value, held as its textual name.
    Enum(String),
    /// An ordered tuple of nested values.
    Tuple(TupleValue),
}
383
/// The kind of an extension value, without its payload.
///
/// Used to describe expected and actual kinds in type-mismatch errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionValueKind {
    String,
    Integer,
    Float,
    Boolean,
    Reference,
    Enum,
    Tuple,
    Expression,
}

impl fmt::Display for ExtensionValueKind {
    /// Writes the lowercase name of the kind (e.g. `integer`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::String => "string",
            Self::Integer => "integer",
            Self::Float => "float",
            Self::Boolean => "boolean",
            Self::Reference => "reference",
            Self::Enum => "enum",
            Self::Tuple => "tuple",
            Self::Expression => "expression",
        };
        f.write_str(label)
    }
}
411
412impl ExtensionValue {
413 pub fn kind(&self) -> ExtensionValueKind {
415 match self {
416 ExtensionValue::String(_) => ExtensionValueKind::String,
417 ExtensionValue::Integer(_) => ExtensionValueKind::Integer,
418 ExtensionValue::Float(_) => ExtensionValueKind::Float,
419 ExtensionValue::Boolean(_) => ExtensionValueKind::Boolean,
420 ExtensionValue::Expr(_) => ExtensionValueKind::Expression,
421 ExtensionValue::Enum(_) => ExtensionValueKind::Enum,
422 ExtensionValue::Tuple(_) => ExtensionValueKind::Tuple,
423 }
424 }
425}
426
427impl From<Expr> for ExtensionValue {
428 fn from(expr: Expr) -> Self {
429 ExtensionValue::Expr(expr)
430 }
431}
432
433impl From<proto::Expression> for ExtensionValue {
434 fn from(expr: proto::Expression) -> Self {
435 Expr::from(expr).into()
436 }
437}
438
439impl From<proto::expression::Literal> for ExtensionValue {
440 fn from(literal: proto::expression::Literal) -> Self {
441 Expr::from(literal).into()
442 }
443}
444
445impl From<Reference> for ExtensionValue {
446 fn from(reference: Reference) -> Self {
447 Expr::from(reference).into()
448 }
449}
450
451impl From<i64> for ExtensionValue {
452 fn from(value: i64) -> Self {
453 ExtensionValue::Integer(value)
454 }
455}
456
457impl From<f64> for ExtensionValue {
458 fn from(value: f64) -> Self {
459 ExtensionValue::Float(value)
460 }
461}
462
463impl From<bool> for ExtensionValue {
464 fn from(value: bool) -> Self {
465 ExtensionValue::Boolean(value)
466 }
467}
468
469impl From<String> for ExtensionValue {
470 fn from(value: String) -> Self {
471 ExtensionValue::String(value)
472 }
473}
474
475impl From<&str> for ExtensionValue {
476 fn from(value: &str) -> Self {
477 ExtensionValue::String(value.to_string())
478 }
479}
480
481fn invalid_type(expected: ExtensionValueKind, actual: &ExtensionValue) -> ExtensionError {
482 ExtensionError::InvalidArgumentType {
483 expected,
484 actual: actual.kind(),
485 }
486}
487
488impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
489 type Error = ExtensionError;
490
491 fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
492 match value {
493 ExtensionValue::String(s) => Ok(s),
494 v => Err(invalid_type(ExtensionValueKind::String, v)),
495 }
496 }
497}
498
499impl TryFrom<ExtensionValue> for String {
500 type Error = ExtensionError;
501
502 fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
503 <&str>::try_from(&value).map(ToOwned::to_owned)
504 }
505}
506
507pub struct EnumValue(pub String);
509
510impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
511 type Error = ExtensionError;
512
513 fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
514 match value {
515 ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
516 v => Err(invalid_type(ExtensionValueKind::Enum, v)),
517 }
518 }
519}
520
521impl<'a> TryFrom<&'a ExtensionValue> for &'a TupleValue {
522 type Error = ExtensionError;
523
524 fn try_from(value: &'a ExtensionValue) -> Result<&'a TupleValue, Self::Error> {
525 match value {
526 ExtensionValue::Tuple(tv) => Ok(tv),
527 v => Err(invalid_type(ExtensionValueKind::Tuple, v)),
528 }
529 }
530}
531
532impl TryFrom<&ExtensionValue> for i64 {
533 type Error = ExtensionError;
534
535 fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
536 match value {
537 ExtensionValue::Integer(i) => Ok(*i),
538 v => Err(invalid_type(ExtensionValueKind::Integer, v)),
539 }
540 }
541}
542
543impl TryFrom<&ExtensionValue> for f64 {
544 type Error = ExtensionError;
545
546 fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
547 match value {
548 ExtensionValue::Float(f) => Ok(*f),
549 v => Err(invalid_type(ExtensionValueKind::Float, v)),
550 }
551 }
552}
553
554impl TryFrom<&ExtensionValue> for bool {
555 type Error = ExtensionError;
556
557 fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
558 match value {
559 ExtensionValue::Boolean(b) => Ok(*b),
560 v => Err(invalid_type(ExtensionValueKind::Boolean, v)),
561 }
562 }
563}
564
565impl TryFrom<&ExtensionValue> for Reference {
566 type Error = ExtensionError;
567
568 fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
569 match value {
570 ExtensionValue::Expr(expr) => expr
571 .as_direct_reference()
572 .map(Reference)
573 .ok_or_else(|| invalid_type(ExtensionValueKind::Reference, value)),
574 v => Err(invalid_type(ExtensionValueKind::Reference, v)),
575 }
576 }
577}
578
579impl TryFrom<&ExtensionValue> for Expr {
580 type Error = ExtensionError;
581
582 fn try_from(value: &ExtensionValue) -> Result<Expr, Self::Error> {
583 match value {
584 ExtensionValue::Expr(e) => Ok(e.clone()),
585 ExtensionValue::Integer(i) => Ok(Expr::from(*i)),
592 ExtensionValue::Float(f) => Ok(Expr::from(*f)),
593 ExtensionValue::String(s) => Ok(Expr::from(s.as_str())),
594 ExtensionValue::Boolean(b) => Ok(Expr::from(*b)),
595 v => Err(invalid_type(ExtensionValueKind::Expression, v)),
596 }
597 }
598}
599
600#[derive(Debug, Clone)]
606pub enum ExtensionColumn {
607 Named {
609 name: String,
611 r#type: proto::Type,
615 },
616 Expr(Expr),
618}
619
620impl ExtensionColumn {
621 pub fn field(index: i32) -> Self {
623 Self::Expr(Expr::field(index))
624 }
625}
626
627impl ExtensionArgs {
628 pub fn push<T>(&mut self, value: T)
630 where
631 T: Into<ExtensionValue>,
632 {
633 self.positional.push(value.into());
634 }
635
636 pub fn insert<K, V>(&mut self, name: K, value: V) -> Option<ExtensionValue>
638 where
639 K: Into<String>,
640 V: Into<ExtensionValue>,
641 {
642 self.named.insert(name.into(), value.into())
643 }
644
645 pub fn extractor(&self) -> ArgsExtractor<'_> {
647 ArgsExtractor::new(self)
648 }
649}