substrait_explain/extensions/
args.rs1use std::collections::HashSet;
29use std::fmt;
30
31use indexmap::IndexMap;
32use substrait::proto;
33use substrait::proto::expression::field_reference::ReferenceType;
34use substrait::proto::expression::literal::LiteralType;
35use substrait::proto::expression::{RexType, reference_segment};
36
37use super::ExtensionError;
38use crate::textify::expressions::Reference;
39
/// Kinds of addendum.
///
/// Each variant has a short string prefix, returned by `AddendumKind::prefix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AddendumKind {
    /// Prefix `"Enh"`.
    Enhancement,
    /// Prefix `"Opt"`.
    Optimization,
    /// Prefix `"Ext"`.
    ExtensionTable,
}
51
52impl AddendumKind {
53 pub(crate) fn prefix(self) -> &'static str {
54 match self {
55 AddendumKind::Enhancement => "Enh",
56 AddendumKind::Optimization => "Opt",
57 AddendumKind::ExtensionTable => "Ext",
58 }
59 }
60}
61
/// An owned, boxed Substrait [`proto::Expression`].
///
/// Built from protobuf expressions, literals, field references and the
/// primitives `i64`, `f64`, `bool`, `String` and `&str` via the `From` impls in
/// this module, and convertible back into a `proto::Expression`.
#[derive(Debug, Clone)]
pub struct Expr(Box<proto::Expression>);
68
69impl Expr {
70 pub fn as_proto(&self) -> &proto::Expression {
72 self.0.as_ref()
73 }
74
75 pub fn to_proto(&self) -> proto::Expression {
77 self.as_proto().clone()
78 }
79
80 pub fn as_direct_reference(&self) -> Option<Reference> {
82 let Some(RexType::Selection(field_ref)) = self.as_proto().rex_type.as_ref() else {
83 return None;
84 };
85 let Some(ReferenceType::DirectReference(segment)) = field_ref.reference_type.as_ref()
86 else {
87 return None;
88 };
89 let Some(reference_segment::ReferenceType::StructField(field)) =
90 segment.reference_type.as_ref()
91 else {
92 return None;
93 };
94 if field.child.is_some() {
95 return None;
96 }
97 Some(Reference(field.field))
98 }
99}
100
101impl From<proto::Expression> for Expr {
102 fn from(expr: proto::Expression) -> Self {
103 Expr(Box::new(expr))
104 }
105}
106
107impl From<proto::expression::Literal> for Expr {
108 fn from(literal: proto::expression::Literal) -> Self {
109 proto::Expression {
110 rex_type: Some(RexType::Literal(literal)),
111 }
112 .into()
113 }
114}
115
116impl From<Reference> for Expr {
117 fn from(reference: Reference) -> Self {
118 proto::Expression::from(reference).into()
119 }
120}
121
122impl From<Expr> for proto::Expression {
123 fn from(expr: Expr) -> Self {
124 *expr.0
125 }
126}
127
128impl From<i64> for Expr {
129 fn from(value: i64) -> Self {
130 proto::expression::Literal {
131 literal_type: Some(LiteralType::I64(value)),
132 nullable: false,
133 type_variation_reference: 0,
134 }
135 .into()
136 }
137}
138
139impl From<f64> for Expr {
140 fn from(value: f64) -> Self {
141 proto::expression::Literal {
142 literal_type: Some(LiteralType::Fp64(value)),
143 nullable: false,
144 type_variation_reference: 0,
145 }
146 .into()
147 }
148}
149
150impl From<bool> for Expr {
151 fn from(value: bool) -> Self {
152 proto::expression::Literal {
153 literal_type: Some(LiteralType::Boolean(value)),
154 nullable: false,
155 type_variation_reference: 0,
156 }
157 .into()
158 }
159}
160
161impl From<String> for Expr {
162 fn from(value: String) -> Self {
163 proto::expression::Literal {
164 literal_type: Some(LiteralType::String(value)),
165 nullable: false,
166 type_variation_reference: 0,
167 }
168 .into()
169 }
170}
171
172impl From<&str> for Expr {
173 fn from(value: &str) -> Self {
174 value.to_string().into()
175 }
176}
177
/// The arguments attached to an extension: positional values, named values,
/// and declared output columns.
#[derive(Debug, Clone, Default)]
pub struct ExtensionArgs {
    /// Positional arguments, in the order they were pushed.
    pub positional: Vec<ExtensionValue>,
    /// Named arguments. `IndexMap` keeps insertion order; inserting an
    /// existing name replaces its value without moving it.
    pub named: IndexMap<String, ExtensionValue>,
    /// Output columns declared for the extension.
    pub output_columns: Vec<ExtensionColumn>,
}
193
/// Reads named arguments out of an [`ExtensionArgs`] and records which names
/// were looked up, so that `check_exhausted` can reject the rest.
///
/// In debug builds, `Drop` asserts that `checked` has been set, unless the
/// thread is already panicking.
pub struct ArgsExtractor<'a> {
    /// The arguments being read.
    args: &'a ExtensionArgs,
    /// Names of the named arguments looked up so far.
    consumed: HashSet<&'a str>,
    /// Set by `check_exhausted`; consulted by `Drop`.
    checked: bool,
}
205
206impl<'a> ArgsExtractor<'a> {
207 pub fn new(args: &'a ExtensionArgs) -> Self {
209 Self {
210 args,
211 consumed: HashSet::new(),
212 checked: false,
213 }
214 }
215
216 pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
218 match self.args.named.get_key_value(name) {
219 Some((k, value)) => {
220 self.consumed.insert(k);
221 Some(value)
222 }
223 None => None,
224 }
225 }
226
227 pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
230 where
231 T: TryFrom<&'a ExtensionValue>,
232 T::Error: Into<ExtensionError>,
233 {
234 match self.get_named_arg(name) {
235 Some(value) => T::try_from(value).map_err(Into::into),
236 None => Err(ExtensionError::MissingArgument {
237 name: name.to_string(),
238 }),
239 }
240 }
241
242 pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
245 where
246 T: TryFrom<&'a ExtensionValue>,
247 T::Error: Into<ExtensionError>,
248 {
249 match self.get_named_arg(name) {
250 Some(value) => T::try_from(value).map_err(Into::into),
251 None => Ok(default),
252 }
253 }
254
255 pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
262 self.checked = true;
263
264 let mut unknown_args = Vec::new();
265 for name in self.args.named.keys() {
266 if !self.consumed.contains(name.as_str()) {
267 unknown_args.push(name.as_str());
268 }
269 }
270
271 if unknown_args.is_empty() {
272 Ok(())
273 } else {
274 unknown_args.sort();
276 Err(ExtensionError::InvalidArgument(format!(
277 "Unknown named arguments: {}",
278 unknown_args.join(", ")
279 )))
280 }
281 }
282}
283
284impl Drop for ArgsExtractor<'_> {
285 fn drop(&mut self) {
286 if self.checked || std::thread::panicking() {
287 return;
288 }
289 debug_assert!(
291 false,
292 "ArgsExtractor dropped without calling check_exhausted()"
293 );
294 }
295}
296
/// An ordered sequence of [`ExtensionValue`]s, carried by
/// `ExtensionValue::Tuple`.
#[derive(Debug, Clone)]
pub struct TupleValue(Vec<ExtensionValue>);
303
304impl TupleValue {
305 pub fn len(&self) -> usize {
306 self.0.len()
307 }
308
309 pub fn is_empty(&self) -> bool {
310 self.0.is_empty()
311 }
312
313 pub fn iter(&self) -> std::slice::Iter<'_, ExtensionValue> {
314 self.0.iter()
315 }
316}
317
318impl<'a> IntoIterator for &'a TupleValue {
319 type Item = &'a ExtensionValue;
320 type IntoIter = std::slice::Iter<'a, ExtensionValue>;
321
322 fn into_iter(self) -> Self::IntoIter {
323 self.0.iter()
324 }
325}
326
327impl IntoIterator for TupleValue {
328 type Item = ExtensionValue;
329 type IntoIter = std::vec::IntoIter<ExtensionValue>;
330
331 fn into_iter(self) -> Self::IntoIter {
332 self.0.into_iter()
333 }
334}
335
336impl FromIterator<ExtensionValue> for TupleValue {
337 fn from_iter<I: IntoIterator<Item = ExtensionValue>>(iter: I) -> Self {
338 TupleValue(iter.into_iter().collect())
339 }
340}
341
342impl From<Vec<ExtensionValue>> for TupleValue {
343 fn from(items: Vec<ExtensionValue>) -> Self {
344 TupleValue(items)
345 }
346}
347
/// A single argument value held in [`ExtensionArgs`].
///
/// Field references have no variant of their own: they are carried as
/// [`Expr`] and recovered with `Reference::try_from`.
#[derive(Debug, Clone)]
pub enum ExtensionValue {
    /// A string value.
    String(String),
    /// A 64-bit signed integer.
    Integer(i64),
    /// A 64-bit float.
    Float(f64),
    /// A boolean.
    Boolean(bool),

    /// An arbitrary expression, including direct field references.
    Expr(Expr),
    /// An enum value, held as a string; extracted as [`EnumValue`].
    Enum(String),
    /// A tuple of nested values.
    Tuple(TupleValue),
}
377
/// The variant of an [`ExtensionValue`] without its payload, used to report
/// type mismatches (`ExtensionError::InvalidArgumentType`).
///
/// `Expression` corresponds to `ExtensionValue::Expr`. `Reference` has no
/// matching `ExtensionValue` variant: it is the expected kind when a value
/// must be an `Expr` that `Expr::as_direct_reference` accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionValueKind {
    String,
    Integer,
    Float,
    Boolean,
    Reference,
    Enum,
    Tuple,
    Expression,
}
390
391impl fmt::Display for ExtensionValueKind {
392 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 match self {
394 ExtensionValueKind::String => write!(f, "string"),
395 ExtensionValueKind::Integer => write!(f, "integer"),
396 ExtensionValueKind::Float => write!(f, "float"),
397 ExtensionValueKind::Boolean => write!(f, "boolean"),
398 ExtensionValueKind::Reference => write!(f, "reference"),
399 ExtensionValueKind::Enum => write!(f, "enum"),
400 ExtensionValueKind::Tuple => write!(f, "tuple"),
401 ExtensionValueKind::Expression => write!(f, "expression"),
402 }
403 }
404}
405
406impl ExtensionValue {
407 pub fn kind(&self) -> ExtensionValueKind {
409 match self {
410 ExtensionValue::String(_) => ExtensionValueKind::String,
411 ExtensionValue::Integer(_) => ExtensionValueKind::Integer,
412 ExtensionValue::Float(_) => ExtensionValueKind::Float,
413 ExtensionValue::Boolean(_) => ExtensionValueKind::Boolean,
414 ExtensionValue::Expr(_) => ExtensionValueKind::Expression,
415 ExtensionValue::Enum(_) => ExtensionValueKind::Enum,
416 ExtensionValue::Tuple(_) => ExtensionValueKind::Tuple,
417 }
418 }
419}
420
421impl From<Expr> for ExtensionValue {
422 fn from(expr: Expr) -> Self {
423 ExtensionValue::Expr(expr)
424 }
425}
426
427impl From<proto::Expression> for ExtensionValue {
428 fn from(expr: proto::Expression) -> Self {
429 Expr::from(expr).into()
430 }
431}
432
433impl From<proto::expression::Literal> for ExtensionValue {
434 fn from(literal: proto::expression::Literal) -> Self {
435 Expr::from(literal).into()
436 }
437}
438
439impl From<Reference> for ExtensionValue {
440 fn from(reference: Reference) -> Self {
441 Expr::from(reference).into()
442 }
443}
444
445impl From<i64> for ExtensionValue {
446 fn from(value: i64) -> Self {
447 ExtensionValue::Integer(value)
448 }
449}
450
451impl From<f64> for ExtensionValue {
452 fn from(value: f64) -> Self {
453 ExtensionValue::Float(value)
454 }
455}
456
457impl From<bool> for ExtensionValue {
458 fn from(value: bool) -> Self {
459 ExtensionValue::Boolean(value)
460 }
461}
462
463impl From<String> for ExtensionValue {
464 fn from(value: String) -> Self {
465 ExtensionValue::String(value)
466 }
467}
468
469impl From<&str> for ExtensionValue {
470 fn from(value: &str) -> Self {
471 ExtensionValue::String(value.to_string())
472 }
473}
474
475fn invalid_type(expected: ExtensionValueKind, actual: &ExtensionValue) -> ExtensionError {
476 ExtensionError::InvalidArgumentType {
477 expected,
478 actual: actual.kind(),
479 }
480}
481
482impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
483 type Error = ExtensionError;
484
485 fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
486 match value {
487 ExtensionValue::String(s) => Ok(s),
488 v => Err(invalid_type(ExtensionValueKind::String, v)),
489 }
490 }
491}
492
493impl TryFrom<ExtensionValue> for String {
494 type Error = ExtensionError;
495
496 fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
497 <&str>::try_from(&value).map(ToOwned::to_owned)
498 }
499}
500
/// Typed extraction target for `ExtensionValue::Enum`: holds the enum
/// value's string (see its `TryFrom<&ExtensionValue>` impl).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumValue(pub String);
504impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
505 type Error = ExtensionError;
506
507 fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
508 match value {
509 ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
510 v => Err(invalid_type(ExtensionValueKind::Enum, v)),
511 }
512 }
513}
514
515impl<'a> TryFrom<&'a ExtensionValue> for &'a TupleValue {
516 type Error = ExtensionError;
517
518 fn try_from(value: &'a ExtensionValue) -> Result<&'a TupleValue, Self::Error> {
519 match value {
520 ExtensionValue::Tuple(tv) => Ok(tv),
521 v => Err(invalid_type(ExtensionValueKind::Tuple, v)),
522 }
523 }
524}
525
526impl TryFrom<&ExtensionValue> for i64 {
527 type Error = ExtensionError;
528
529 fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
530 match value {
531 ExtensionValue::Integer(i) => Ok(*i),
532 v => Err(invalid_type(ExtensionValueKind::Integer, v)),
533 }
534 }
535}
536
537impl TryFrom<&ExtensionValue> for f64 {
538 type Error = ExtensionError;
539
540 fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
541 match value {
542 ExtensionValue::Float(f) => Ok(*f),
543 v => Err(invalid_type(ExtensionValueKind::Float, v)),
544 }
545 }
546}
547
548impl TryFrom<&ExtensionValue> for bool {
549 type Error = ExtensionError;
550
551 fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
552 match value {
553 ExtensionValue::Boolean(b) => Ok(*b),
554 v => Err(invalid_type(ExtensionValueKind::Boolean, v)),
555 }
556 }
557}
558
559impl TryFrom<&ExtensionValue> for Reference {
560 type Error = ExtensionError;
561
562 fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
563 match value {
564 ExtensionValue::Expr(expr) => expr
565 .as_direct_reference()
566 .ok_or_else(|| invalid_type(ExtensionValueKind::Reference, value)),
567 v => Err(invalid_type(ExtensionValueKind::Reference, v)),
568 }
569 }
570}
571
572impl TryFrom<&ExtensionValue> for Expr {
573 type Error = ExtensionError;
574
575 fn try_from(value: &ExtensionValue) -> Result<Expr, Self::Error> {
576 match value {
577 ExtensionValue::Expr(e) => Ok(e.clone()),
578 ExtensionValue::Integer(i) => Ok(Expr::from(*i)),
585 ExtensionValue::Float(f) => Ok(Expr::from(*f)),
586 ExtensionValue::String(s) => Ok(Expr::from(s.as_str())),
587 ExtensionValue::Boolean(b) => Ok(Expr::from(*b)),
588 v => Err(invalid_type(ExtensionValueKind::Expression, v)),
589 }
590 }
591}
592
/// An output column, as listed in `ExtensionArgs::output_columns`.
#[derive(Debug, Clone)]
pub enum ExtensionColumn {
    /// A column given by name together with an explicit Substrait type.
    Named {
        /// The column's name.
        name: String,
        /// The column's Substrait type.
        r#type: proto::Type,
    },
    /// A column given as an expression.
    Expr(Expr),
}
612
613impl ExtensionArgs {
614 pub fn push<T>(&mut self, value: T)
616 where
617 T: Into<ExtensionValue>,
618 {
619 self.positional.push(value.into());
620 }
621
622 pub fn insert<K, V>(&mut self, name: K, value: V) -> Option<ExtensionValue>
624 where
625 K: Into<String>,
626 V: Into<ExtensionValue>,
627 {
628 self.named.insert(name.into(), value.into())
629 }
630
631 pub fn extractor(&self) -> ArgsExtractor<'_> {
633 ArgsExtractor::new(self)
634 }
635}