Skip to main content

substrait_explain/extensions/
args.rs

//! Text-format data structures used by registered advanced extension handlers.
//!
//! These types describe the arguments accepted by custom relation types,
//! enhancements, and optimization hints. Relation extensions can additionally
//! describe output columns.
//!
//! The interface presented to extension handlers is structured rather than
//! textual: handlers read and write values such as [`ExtensionArgs`], [`Expr`],
//! and [`proto::Type`], while `substrait-explain` handles the surrounding
//! parsing and textification. Some values need plan context before they reach a
//! handler. For example, an expression argument like `add($0, $1)` is parsed
//! using [`SimpleExtensions`](crate::extensions::SimpleExtensions) to resolve
//! the text function name to its protobuf function anchor, and is formatted by
//! resolving that anchor back to a text name.
//!
//! The extension-facing interface for Substrait objects (e.g. [`proto::Type`])
//! should map directly to Substrait protobuf concepts. Sometimes that means
//! storing the protobuf type directly, as named output columns do with
//! [`proto::Type`]; sometimes it means using a small wrapper, as
//! expression-compatible arguments do with [`Expr`] around
//! [`proto::Expression`].
//!
//! Untyped scalar literals (e.g. `2`, `2.435`, `'string'`) are kept as the
//! scalar variants of [`ExtensionValue`] (`String`, `Integer`, `Float`,
//! `Boolean`). This lets text rendering preserve scalar syntax even in verbose
//! output, while handlers that accept expressions can still widen them into
//! default (non-nullable) Substrait literal expressions.

28use std::collections::HashSet;
29use std::fmt;
30
31use indexmap::IndexMap;
32use substrait::proto;
33use substrait::proto::expression::field_reference::ReferenceType;
34use substrait::proto::expression::literal::LiteralType;
35use substrait::proto::expression::{RexType, reference_segment};
36
37use super::ExtensionError;
38use crate::textify::expressions::Reference;
39
/// Kind of `+`-prefixed addendum line that can be attached to a relation in
/// the text format.
///
/// Addenda are purely syntax-level constructs. They are distinct from
/// [`crate::extensions::registry::ExtensionType`], which describes registry
/// namespaces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AddendumKind {
    /// Enhancement addendum (short keyword `Enh`).
    Enhancement,
    /// Optimization addendum (short keyword `Opt`).
    Optimization,
    /// Extension-table addendum (short keyword `Ext`).
    ExtensionTable,
}

impl AddendumKind {
    /// Short keyword used for this addendum kind in the text format.
    pub(crate) fn prefix(self) -> &'static str {
        match self {
            Self::Enhancement => "Enh",
            Self::Optimization => "Opt",
            Self::ExtensionTable => "Ext",
        }
    }
}
61
62/// A Substrait expression carried as an extension argument or output column.
63///
64/// Boxed because `proto::Expression` is large (multiple `Vec` fields in
65/// variants like `ScalarFunction`).
66#[derive(Debug, Clone)]
67pub struct Expr(Box<proto::Expression>);
68
69impl Expr {
70    /// Create a direct field-reference expression (`$N`).
71    pub fn field(index: i32) -> Self {
72        Reference(index).into()
73    }
74
75    /// Borrow the underlying Substrait expression protobuf.
76    pub fn as_proto(&self) -> &proto::Expression {
77        self.0.as_ref()
78    }
79
80    /// Clone the underlying Substrait expression protobuf.
81    pub fn to_proto(&self) -> proto::Expression {
82        self.as_proto().clone()
83    }
84
85    /// If this expression is a direct field reference (`$N`), return it.
86    pub fn as_direct_reference(&self) -> Option<i32> {
87        let Some(RexType::Selection(field_ref)) = self.as_proto().rex_type.as_ref() else {
88            return None;
89        };
90        let Some(ReferenceType::DirectReference(segment)) = field_ref.reference_type.as_ref()
91        else {
92            return None;
93        };
94        let Some(reference_segment::ReferenceType::StructField(field)) =
95            segment.reference_type.as_ref()
96        else {
97            return None;
98        };
99        if field.child.is_some() {
100            return None;
101        }
102        Some(field.field)
103    }
104}
105
106impl From<proto::Expression> for Expr {
107    fn from(expr: proto::Expression) -> Self {
108        Expr(Box::new(expr))
109    }
110}
111
112impl From<proto::expression::Literal> for Expr {
113    fn from(literal: proto::expression::Literal) -> Self {
114        proto::Expression {
115            rex_type: Some(RexType::Literal(literal)),
116        }
117        .into()
118    }
119}
120
121impl From<Reference> for Expr {
122    fn from(reference: Reference) -> Self {
123        proto::Expression::from(reference).into()
124    }
125}
126
127impl From<Expr> for proto::Expression {
128    fn from(expr: Expr) -> Self {
129        *expr.0
130    }
131}
132
133impl From<i64> for Expr {
134    fn from(value: i64) -> Self {
135        proto::expression::Literal {
136            literal_type: Some(LiteralType::I64(value)),
137            nullable: false,
138            type_variation_reference: 0,
139        }
140        .into()
141    }
142}
143
144impl From<f64> for Expr {
145    fn from(value: f64) -> Self {
146        proto::expression::Literal {
147            literal_type: Some(LiteralType::Fp64(value)),
148            nullable: false,
149            type_variation_reference: 0,
150        }
151        .into()
152    }
153}
154
155impl From<bool> for Expr {
156    fn from(value: bool) -> Self {
157        proto::expression::Literal {
158            literal_type: Some(LiteralType::Boolean(value)),
159            nullable: false,
160            type_variation_reference: 0,
161        }
162        .into()
163    }
164}
165
166impl From<String> for Expr {
167    fn from(value: String) -> Self {
168        proto::expression::Literal {
169            literal_type: Some(LiteralType::String(value)),
170            nullable: false,
171            type_variation_reference: 0,
172        }
173        .into()
174    }
175}
176
177impl From<&str> for Expr {
178    fn from(value: &str) -> Self {
179        value.to_string().into()
180    }
181}
182
183/// Represents extension arguments plus optional output columns.
184///
185/// Named arguments are stored in an [`IndexMap`] whose iteration order
186/// determines display order. Extension [`super::Explainable::to_args()`]
187/// implementations should insert named arguments in the order they should
188/// appear in the text format.
189#[derive(Debug, Clone, Default)]
190pub struct ExtensionArgs {
191    /// Positional arguments.
192    pub positional: Vec<ExtensionValue>,
193    /// Named arguments, displayed in the order they were inserted
194    pub named: IndexMap<String, ExtensionValue>,
195    /// Output columns for custom relation types.
196    pub output_columns: Vec<ExtensionColumn>,
197}
198
199/// Helper struct for extracting named arguments with validation.
200///
201/// Tracks which arguments have been consumed. Callers **must** call
202/// [`check_exhausted`](ArgsExtractor::check_exhausted) before dropping to
203/// verify no unexpected arguments remain. In debug builds, dropping without
204/// calling `check_exhausted` will panic. This catches [`Explainable`](super::Explainable)
205/// implementations that forget to reject unexpected named arguments.
206pub struct ArgsExtractor<'a> {
207    args: &'a ExtensionArgs,
208    consumed: HashSet<&'a str>,
209    checked: bool,
210}
211
212impl<'a> ArgsExtractor<'a> {
213    /// Create a new extractor for the given arguments
214    pub fn new(args: &'a ExtensionArgs) -> Self {
215        Self {
216            args,
217            consumed: HashSet::new(),
218            checked: false,
219        }
220    }
221
222    /// Get a named argument value, marking it as consumed if found.
223    pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
224        match self.args.named.get_key_value(name) {
225            Some((k, value)) => {
226                self.consumed.insert(k);
227                Some(value)
228            }
229            None => None,
230        }
231    }
232
233    /// Get a named argument value or return an error
234    /// Marks the argument as consumed if found
235    pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
236    where
237        T: TryFrom<&'a ExtensionValue>,
238        T::Error: Into<ExtensionError>,
239    {
240        match self.get_named_arg(name) {
241            Some(value) => T::try_from(value).map_err(Into::into),
242            None => Err(ExtensionError::MissingArgument {
243                name: name.to_string(),
244            }),
245        }
246    }
247
248    /// Get a named argument value or default
249    /// Marks the argument as consumed if it exists in the source args
250    pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
251    where
252        T: TryFrom<&'a ExtensionValue>,
253        T::Error: Into<ExtensionError>,
254    {
255        match self.get_named_arg(name) {
256            Some(value) => T::try_from(value).map_err(Into::into),
257            None => Ok(default),
258        }
259    }
260
261    /// Check that all named arguments in the source have been consumed,
262    /// returning an error if not.
263    ///
264    /// Must be called before the extractor is dropped, to validate that all
265    /// args are correctly handled. In debug builds, dropping without calling
266    /// this method will panic.
267    pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
268        self.checked = true;
269
270        let mut unknown_args = Vec::new();
271        for name in self.args.named.keys() {
272            if !self.consumed.contains(name.as_str()) {
273                unknown_args.push(name.as_str());
274            }
275        }
276
277        if unknown_args.is_empty() {
278            Ok(())
279        } else {
280            // Sort for stable error messages
281            unknown_args.sort();
282            Err(ExtensionError::InvalidArgument(format!(
283                "Unknown named arguments: {}",
284                unknown_args.join(", ")
285            )))
286        }
287    }
288}
289
290impl Drop for ArgsExtractor<'_> {
291    fn drop(&mut self) {
292        if self.checked || std::thread::panicking() {
293            return;
294        }
295        // If we get here, the caller forgot to call check_exhausted().
296        debug_assert!(
297            false,
298            "ArgsExtractor dropped without calling check_exhausted()"
299        );
300    }
301}
302
303/// A tuple-valued extension argument.
304///
305/// Tuple values preserve positional order and can be iterated by value or by
306/// reference.
307#[derive(Debug, Clone)]
308pub struct TupleValue(Vec<ExtensionValue>);
309
310impl TupleValue {
311    pub fn len(&self) -> usize {
312        self.0.len()
313    }
314
315    pub fn is_empty(&self) -> bool {
316        self.0.is_empty()
317    }
318
319    pub fn iter(&self) -> std::slice::Iter<'_, ExtensionValue> {
320        self.0.iter()
321    }
322}
323
324impl<'a> IntoIterator for &'a TupleValue {
325    type Item = &'a ExtensionValue;
326    type IntoIter = std::slice::Iter<'a, ExtensionValue>;
327
328    fn into_iter(self) -> Self::IntoIter {
329        self.0.iter()
330    }
331}
332
333impl IntoIterator for TupleValue {
334    type Item = ExtensionValue;
335    type IntoIter = std::vec::IntoIter<ExtensionValue>;
336
337    fn into_iter(self) -> Self::IntoIter {
338        self.0.into_iter()
339    }
340}
341
342impl FromIterator<ExtensionValue> for TupleValue {
343    fn from_iter<I: IntoIterator<Item = ExtensionValue>>(iter: I) -> Self {
344        TupleValue(iter.into_iter().collect())
345    }
346}
347
348impl From<Vec<ExtensionValue>> for TupleValue {
349    fn from(items: Vec<ExtensionValue>) -> Self {
350        TupleValue(items)
351    }
352}
353
354/// Represents a value in extension arguments.
355///
356/// These values are the structured form of text-format extension arguments,
357/// fully resolved - i.e. any additional context (such as function anchors etc)
358/// are part of this struct itself.
359#[derive(Debug, Clone)]
360pub enum ExtensionValue {
361    /// Untyped literals. These are not input or output with types (e.g. `2`,
362    /// not `2:i64`), and suitable for protobuf extension fields that are not
363    /// substrait types.
364    String(String),
365    Integer(i64),
366    Float(f64),
367    Boolean(bool),
368
369    /// Substrait expression value, including typed literals and field references.
370    ///
371    /// Use `TryFrom<&ExtensionValue> for Expr` when a handler accepts either an
372    /// expression or a scalar value widened into an expression.
373    Expr(Expr),
374    /// Enum value (e.g. &CORE, &Inner) — the string holds the identifier
375    /// without the `&` prefix
376    Enum(String),
377    /// Tuple of values, e.g. (&HASH, &RANGE) or (42, 'hello')
378    Tuple(TupleValue),
379    // TODO: Consider adding support for types as arguments. May need dedicated
380    // syntax (`:typename`, perhaps?), as type names may not be distinguishable
381    // from identifiers
382}
383
/// Coarse classification of an [`ExtensionValue`], used in diagnostics such
/// as "expected X, got Y".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionValueKind {
    String,
    Integer,
    Float,
    Boolean,
    Reference,
    Enum,
    Tuple,
    Expression,
}

impl fmt::Display for ExtensionValueKind {
    /// Writes the lowercase kind name. Width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::String => "string",
            Self::Integer => "integer",
            Self::Float => "float",
            Self::Boolean => "boolean",
            Self::Reference => "reference",
            Self::Enum => "enum",
            Self::Tuple => "tuple",
            Self::Expression => "expression",
        };
        f.write_str(name)
    }
}
411
412impl ExtensionValue {
413    /// Return the variant kind of this value for structured diagnostics.
414    pub fn kind(&self) -> ExtensionValueKind {
415        match self {
416            ExtensionValue::String(_) => ExtensionValueKind::String,
417            ExtensionValue::Integer(_) => ExtensionValueKind::Integer,
418            ExtensionValue::Float(_) => ExtensionValueKind::Float,
419            ExtensionValue::Boolean(_) => ExtensionValueKind::Boolean,
420            ExtensionValue::Expr(_) => ExtensionValueKind::Expression,
421            ExtensionValue::Enum(_) => ExtensionValueKind::Enum,
422            ExtensionValue::Tuple(_) => ExtensionValueKind::Tuple,
423        }
424    }
425}
426
427impl From<Expr> for ExtensionValue {
428    fn from(expr: Expr) -> Self {
429        ExtensionValue::Expr(expr)
430    }
431}
432
433impl From<proto::Expression> for ExtensionValue {
434    fn from(expr: proto::Expression) -> Self {
435        Expr::from(expr).into()
436    }
437}
438
439impl From<proto::expression::Literal> for ExtensionValue {
440    fn from(literal: proto::expression::Literal) -> Self {
441        Expr::from(literal).into()
442    }
443}
444
445impl From<Reference> for ExtensionValue {
446    fn from(reference: Reference) -> Self {
447        Expr::from(reference).into()
448    }
449}
450
451impl From<i64> for ExtensionValue {
452    fn from(value: i64) -> Self {
453        ExtensionValue::Integer(value)
454    }
455}
456
457impl From<f64> for ExtensionValue {
458    fn from(value: f64) -> Self {
459        ExtensionValue::Float(value)
460    }
461}
462
463impl From<bool> for ExtensionValue {
464    fn from(value: bool) -> Self {
465        ExtensionValue::Boolean(value)
466    }
467}
468
469impl From<String> for ExtensionValue {
470    fn from(value: String) -> Self {
471        ExtensionValue::String(value)
472    }
473}
474
475impl From<&str> for ExtensionValue {
476    fn from(value: &str) -> Self {
477        ExtensionValue::String(value.to_string())
478    }
479}
480
481fn invalid_type(expected: ExtensionValueKind, actual: &ExtensionValue) -> ExtensionError {
482    ExtensionError::InvalidArgumentType {
483        expected,
484        actual: actual.kind(),
485    }
486}
487
488impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
489    type Error = ExtensionError;
490
491    fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
492        match value {
493            ExtensionValue::String(s) => Ok(s),
494            v => Err(invalid_type(ExtensionValueKind::String, v)),
495        }
496    }
497}
498
499impl TryFrom<ExtensionValue> for String {
500    type Error = ExtensionError;
501
502    fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
503        <&str>::try_from(&value).map(ToOwned::to_owned)
504    }
505}
506
507/// Helper for extracting the identifier from an [`ExtensionValue::Enum`].
508pub struct EnumValue(pub String);
509
510impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
511    type Error = ExtensionError;
512
513    fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
514        match value {
515            ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
516            v => Err(invalid_type(ExtensionValueKind::Enum, v)),
517        }
518    }
519}
520
521impl<'a> TryFrom<&'a ExtensionValue> for &'a TupleValue {
522    type Error = ExtensionError;
523
524    fn try_from(value: &'a ExtensionValue) -> Result<&'a TupleValue, Self::Error> {
525        match value {
526            ExtensionValue::Tuple(tv) => Ok(tv),
527            v => Err(invalid_type(ExtensionValueKind::Tuple, v)),
528        }
529    }
530}
531
532impl TryFrom<&ExtensionValue> for i64 {
533    type Error = ExtensionError;
534
535    fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
536        match value {
537            ExtensionValue::Integer(i) => Ok(*i),
538            v => Err(invalid_type(ExtensionValueKind::Integer, v)),
539        }
540    }
541}
542
543impl TryFrom<&ExtensionValue> for f64 {
544    type Error = ExtensionError;
545
546    fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
547        match value {
548            ExtensionValue::Float(f) => Ok(*f),
549            v => Err(invalid_type(ExtensionValueKind::Float, v)),
550        }
551    }
552}
553
554impl TryFrom<&ExtensionValue> for bool {
555    type Error = ExtensionError;
556
557    fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
558        match value {
559            ExtensionValue::Boolean(b) => Ok(*b),
560            v => Err(invalid_type(ExtensionValueKind::Boolean, v)),
561        }
562    }
563}
564
565impl TryFrom<&ExtensionValue> for Reference {
566    type Error = ExtensionError;
567
568    fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
569        match value {
570            ExtensionValue::Expr(expr) => expr
571                .as_direct_reference()
572                .map(Reference)
573                .ok_or_else(|| invalid_type(ExtensionValueKind::Reference, value)),
574            v => Err(invalid_type(ExtensionValueKind::Reference, v)),
575        }
576    }
577}
578
579impl TryFrom<&ExtensionValue> for Expr {
580    type Error = ExtensionError;
581
582    fn try_from(value: &ExtensionValue) -> Result<Expr, Self::Error> {
583        match value {
584            ExtensionValue::Expr(e) => Ok(e.clone()),
585            // Untyped extension scalars are intentionally expression-compatible:
586            // `arg=2` carries no syntax that distinguishes "configuration
587            // integer" from "i64 literal expression". Scalar-specific
588            // extraction (`i64`, `&str`, `bool`, etc.) still requires the scalar
589            // variants, while expression extraction widens them to default
590            // non-nullable Substrait literal expressions.
591            ExtensionValue::Integer(i) => Ok(Expr::from(*i)),
592            ExtensionValue::Float(f) => Ok(Expr::from(*f)),
593            ExtensionValue::String(s) => Ok(Expr::from(s.as_str())),
594            ExtensionValue::Boolean(b) => Ok(Expr::from(*b)),
595            v => Err(invalid_type(ExtensionValueKind::Expression, v)),
596        }
597    }
598}
599
600/// Represents an output column specification.
601///
602/// These values mirror the text-format output column forms. Named columns keep
603/// the parsed Substrait type protobuf so handlers can convert directly to
604/// relation schemas.
605#[derive(Debug, Clone)]
606pub enum ExtensionColumn {
607    /// Named column with a parsed Substrait type (e.g. `name:i64?`).
608    Named {
609        /// Column name as it appears in the extension relation output.
610        name: String,
611        /// Parsed Substrait type for the column.
612        ///
613        /// This uses the protobuf field name, hence the raw identifier.
614        r#type: proto::Type,
615    },
616    /// Expression-compatible output column, including field references.
617    Expr(Expr),
618}
619
620impl ExtensionColumn {
621    /// Create an expression output column that references an existing input field (`$N`).
622    pub fn field(index: i32) -> Self {
623        Self::Expr(Expr::field(index))
624    }
625}
626
627impl ExtensionArgs {
628    /// Push a positional extension argument.
629    pub fn push<T>(&mut self, value: T)
630    where
631        T: Into<ExtensionValue>,
632    {
633        self.positional.push(value.into());
634    }
635
636    /// Insert a named extension argument, returning any previous value.
637    pub fn insert<K, V>(&mut self, name: K, value: V) -> Option<ExtensionValue>
638    where
639        K: Into<String>,
640        V: Into<ExtensionValue>,
641    {
642        self.named.insert(name.into(), value.into())
643    }
644
645    /// Create an extractor for validating named arguments
646    pub fn extractor(&self) -> ArgsExtractor<'_> {
647        ArgsExtractor::new(self)
648    }
649}