Skip to main content

substrait_explain/extensions/
args.rs

1//! Text-format data structures used by registered advanced extension handlers.
2//!
3//! These types describe the arguments accepted by custom relation types,
4//! enhancements, and optimization hints. Relation extensions can additionally
5//! describe output columns.
6//!
7//! The interface presented to extension handlers is structured rather than
8//! textual: handlers read and write values such as [`ExtensionArgs`], [`Expr`],
9//! and [`proto::Type`]. `substrait-explain` handles the surrounding
10//! parsing/textification. Some values need plan context before they reach a
11//! handler; for example, an expression argument like `add($0, $1)` is parsed
12//! using [`SimpleExtensions`](crate::extensions::SimpleExtensions) to resolve
13//! the text function name to the protobuf function anchor, and formatted by
14//! resolving that anchor back to a text name.
15//!
16//! The extension-facing interface for Substrait objects (e.g. [`proto::Type`])
17//! should map directly to Substrait protobuf concepts. Sometimes that means
18//! storing the protobuf type directly, as named output columns do with
19//! [`proto::Type`]; sometimes it means using a small wrapper, as
20//! expression-compatible arguments do with [`Expr`] around
21//! [`proto::Expression`].
22//!
23//! Untyped scalar literals (e.g. `2`, `2.435`, `'string'`) are kept as
24//! extension scalar values so text rendering can preserve scalar syntax even in
25//! verbose output, while handlers that accept expressions can still widen them
26//! into default Substrait literal expressions.
27
28use std::collections::HashSet;
29use std::fmt;
30
31use indexmap::IndexMap;
32use substrait::proto;
33use substrait::proto::expression::field_reference::ReferenceType;
34use substrait::proto::expression::literal::LiteralType;
35use substrait::proto::expression::{RexType, reference_segment};
36
37use super::ExtensionError;
38use crate::textify::expressions::Reference;
39
/// Syntactic category of a relation addendum in the text format.
///
/// An addendum is a `+`-prefixed line attached to a relation. This is purely a
/// syntax-level distinction and is separate from
/// [`crate::extensions::registry::ExtensionType`], which names registry
/// namespaces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AddendumKind {
    Enhancement,
    Optimization,
    ExtensionTable,
}

impl AddendumKind {
    /// Short tag identifying this addendum kind in the text format
    /// (e.g. `Enh` for an enhancement).
    pub(crate) fn prefix(self) -> &'static str {
        match self {
            Self::Enhancement => "Enh",
            Self::Optimization => "Opt",
            Self::ExtensionTable => "Ext",
        }
    }
}
61
62/// A Substrait expression carried as an extension argument or output column.
63///
64/// Boxed because `proto::Expression` is large (multiple `Vec` fields in
65/// variants like `ScalarFunction`).
66#[derive(Debug, Clone)]
67pub struct Expr(Box<proto::Expression>);
68
69impl Expr {
70    /// Borrow the underlying Substrait expression protobuf.
71    pub fn as_proto(&self) -> &proto::Expression {
72        self.0.as_ref()
73    }
74
75    /// Clone the underlying Substrait expression protobuf.
76    pub fn to_proto(&self) -> proto::Expression {
77        self.as_proto().clone()
78    }
79
80    /// If this expression is a direct field reference (`$N`), return it.
81    pub fn as_direct_reference(&self) -> Option<Reference> {
82        let Some(RexType::Selection(field_ref)) = self.as_proto().rex_type.as_ref() else {
83            return None;
84        };
85        let Some(ReferenceType::DirectReference(segment)) = field_ref.reference_type.as_ref()
86        else {
87            return None;
88        };
89        let Some(reference_segment::ReferenceType::StructField(field)) =
90            segment.reference_type.as_ref()
91        else {
92            return None;
93        };
94        if field.child.is_some() {
95            return None;
96        }
97        Some(Reference(field.field))
98    }
99}
100
101impl From<proto::Expression> for Expr {
102    fn from(expr: proto::Expression) -> Self {
103        Expr(Box::new(expr))
104    }
105}
106
107impl From<proto::expression::Literal> for Expr {
108    fn from(literal: proto::expression::Literal) -> Self {
109        proto::Expression {
110            rex_type: Some(RexType::Literal(literal)),
111        }
112        .into()
113    }
114}
115
116impl From<Reference> for Expr {
117    fn from(reference: Reference) -> Self {
118        proto::Expression::from(reference).into()
119    }
120}
121
122impl From<Expr> for proto::Expression {
123    fn from(expr: Expr) -> Self {
124        *expr.0
125    }
126}
127
128impl From<i64> for Expr {
129    fn from(value: i64) -> Self {
130        proto::expression::Literal {
131            literal_type: Some(LiteralType::I64(value)),
132            nullable: false,
133            type_variation_reference: 0,
134        }
135        .into()
136    }
137}
138
139impl From<f64> for Expr {
140    fn from(value: f64) -> Self {
141        proto::expression::Literal {
142            literal_type: Some(LiteralType::Fp64(value)),
143            nullable: false,
144            type_variation_reference: 0,
145        }
146        .into()
147    }
148}
149
150impl From<bool> for Expr {
151    fn from(value: bool) -> Self {
152        proto::expression::Literal {
153            literal_type: Some(LiteralType::Boolean(value)),
154            nullable: false,
155            type_variation_reference: 0,
156        }
157        .into()
158    }
159}
160
161impl From<String> for Expr {
162    fn from(value: String) -> Self {
163        proto::expression::Literal {
164            literal_type: Some(LiteralType::String(value)),
165            nullable: false,
166            type_variation_reference: 0,
167        }
168        .into()
169    }
170}
171
172impl From<&str> for Expr {
173    fn from(value: &str) -> Self {
174        value.to_string().into()
175    }
176}
177
178/// Represents extension arguments plus optional output columns.
179///
180/// Named arguments are stored in an [`IndexMap`] whose iteration order
181/// determines display order. Extension [`super::Explainable::to_args()`]
182/// implementations should insert named arguments in the order they should
183/// appear in the text format.
184#[derive(Debug, Clone, Default)]
185pub struct ExtensionArgs {
186    /// Positional arguments.
187    pub positional: Vec<ExtensionValue>,
188    /// Named arguments, displayed in the order they were inserted
189    pub named: IndexMap<String, ExtensionValue>,
190    /// Output columns for custom relation types.
191    pub output_columns: Vec<ExtensionColumn>,
192}
193
194/// Helper struct for extracting named arguments with validation.
195///
196/// Tracks which arguments have been consumed. Callers **must** call
197/// [`check_exhausted`](ArgsExtractor::check_exhausted) before dropping to
198/// verify no unexpected arguments remain. In debug builds, dropping without
199/// calling `check_exhausted` will panic (matching the [`RuleIter`](crate::parser::RuleIter) pattern).
200pub struct ArgsExtractor<'a> {
201    args: &'a ExtensionArgs,
202    consumed: HashSet<&'a str>,
203    checked: bool,
204}
205
206impl<'a> ArgsExtractor<'a> {
207    /// Create a new extractor for the given arguments
208    pub fn new(args: &'a ExtensionArgs) -> Self {
209        Self {
210            args,
211            consumed: HashSet::new(),
212            checked: false,
213        }
214    }
215
216    /// Get a named argument value, marking it as consumed if found.
217    pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
218        match self.args.named.get_key_value(name) {
219            Some((k, value)) => {
220                self.consumed.insert(k);
221                Some(value)
222            }
223            None => None,
224        }
225    }
226
227    /// Get a named argument value or return an error
228    /// Marks the argument as consumed if found
229    pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
230    where
231        T: TryFrom<&'a ExtensionValue>,
232        T::Error: Into<ExtensionError>,
233    {
234        match self.get_named_arg(name) {
235            Some(value) => T::try_from(value).map_err(Into::into),
236            None => Err(ExtensionError::MissingArgument {
237                name: name.to_string(),
238            }),
239        }
240    }
241
242    /// Get a named argument value or default
243    /// Marks the argument as consumed if it exists in the source args
244    pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
245    where
246        T: TryFrom<&'a ExtensionValue>,
247        T::Error: Into<ExtensionError>,
248    {
249        match self.get_named_arg(name) {
250            Some(value) => T::try_from(value).map_err(Into::into),
251            None => Ok(default),
252        }
253    }
254
255    /// Check that all named arguments in the source have been consumed,
256    /// returning an error if not.
257    ///
258    /// Must be called before the extractor is dropped, to validate that all
259    /// args are correctly handled. In debug builds, dropping without calling
260    /// this method will panic.
261    pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
262        self.checked = true;
263
264        let mut unknown_args = Vec::new();
265        for name in self.args.named.keys() {
266            if !self.consumed.contains(name.as_str()) {
267                unknown_args.push(name.as_str());
268            }
269        }
270
271        if unknown_args.is_empty() {
272            Ok(())
273        } else {
274            // Sort for stable error messages
275            unknown_args.sort();
276            Err(ExtensionError::InvalidArgument(format!(
277                "Unknown named arguments: {}",
278                unknown_args.join(", ")
279            )))
280        }
281    }
282}
283
284impl Drop for ArgsExtractor<'_> {
285    fn drop(&mut self) {
286        if self.checked || std::thread::panicking() {
287            return;
288        }
289        // If we get here, the caller forgot to call check_exhausted().
290        debug_assert!(
291            false,
292            "ArgsExtractor dropped without calling check_exhausted()"
293        );
294    }
295}
296
297/// A tuple-valued extension argument.
298///
299/// Tuple values preserve positional order and can be iterated by value or by
300/// reference.
301#[derive(Debug, Clone)]
302pub struct TupleValue(Vec<ExtensionValue>);
303
304impl TupleValue {
305    pub fn len(&self) -> usize {
306        self.0.len()
307    }
308
309    pub fn is_empty(&self) -> bool {
310        self.0.is_empty()
311    }
312
313    pub fn iter(&self) -> std::slice::Iter<'_, ExtensionValue> {
314        self.0.iter()
315    }
316}
317
318impl<'a> IntoIterator for &'a TupleValue {
319    type Item = &'a ExtensionValue;
320    type IntoIter = std::slice::Iter<'a, ExtensionValue>;
321
322    fn into_iter(self) -> Self::IntoIter {
323        self.0.iter()
324    }
325}
326
327impl IntoIterator for TupleValue {
328    type Item = ExtensionValue;
329    type IntoIter = std::vec::IntoIter<ExtensionValue>;
330
331    fn into_iter(self) -> Self::IntoIter {
332        self.0.into_iter()
333    }
334}
335
336impl FromIterator<ExtensionValue> for TupleValue {
337    fn from_iter<I: IntoIterator<Item = ExtensionValue>>(iter: I) -> Self {
338        TupleValue(iter.into_iter().collect())
339    }
340}
341
342impl From<Vec<ExtensionValue>> for TupleValue {
343    fn from(items: Vec<ExtensionValue>) -> Self {
344        TupleValue(items)
345    }
346}
347
348/// Represents a value in extension arguments.
349///
350/// These values are the structured form of text-format extension arguments,
351/// fully resolved - i.e. any additional context (such as function anchors etc)
352/// are part of this struct itself.
353#[derive(Debug, Clone)]
354pub enum ExtensionValue {
355    /// Untyped literals. These are not input or output with types (e.g. `2`,
356    /// not `2:i64`), and suitable for protobuf extension fields that are not
357    /// substrait types.
358    String(String),
359    Integer(i64),
360    Float(f64),
361    Boolean(bool),
362
363    /// Substrait expression value, including typed literals and field references.
364    ///
365    /// Use `TryFrom<&ExtensionValue> for Expr` when a handler accepts either an
366    /// expression or a scalar value widened into an expression.
367    Expr(Expr),
368    /// Enum value (e.g. &CORE, &Inner) — the string holds the identifier
369    /// without the `&` prefix
370    Enum(String),
371    /// Tuple of values, e.g. (&HASH, &RANGE) or (42, 'hello')
372    Tuple(TupleValue),
373    // TODO: Consider adding support for types as arguments. May need dedicated
374    // syntax (`:typename`, perhaps?), as type names may not be distinguishable
375    // from identifiers
376}
377
/// Which kind of [`ExtensionValue`] was expected or found, for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionValueKind {
    String,
    Integer,
    Float,
    Boolean,
    Reference,
    Enum,
    Tuple,
    Expression,
}

impl fmt::Display for ExtensionValueKind {
    /// Writes the lower-case name of the kind (e.g. `integer`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::String => "string",
            Self::Integer => "integer",
            Self::Float => "float",
            Self::Boolean => "boolean",
            Self::Reference => "reference",
            Self::Enum => "enum",
            Self::Tuple => "tuple",
            Self::Expression => "expression",
        };
        f.write_str(name)
    }
}
405
406impl ExtensionValue {
407    /// Return the variant kind of this value for structured diagnostics.
408    pub fn kind(&self) -> ExtensionValueKind {
409        match self {
410            ExtensionValue::String(_) => ExtensionValueKind::String,
411            ExtensionValue::Integer(_) => ExtensionValueKind::Integer,
412            ExtensionValue::Float(_) => ExtensionValueKind::Float,
413            ExtensionValue::Boolean(_) => ExtensionValueKind::Boolean,
414            ExtensionValue::Expr(_) => ExtensionValueKind::Expression,
415            ExtensionValue::Enum(_) => ExtensionValueKind::Enum,
416            ExtensionValue::Tuple(_) => ExtensionValueKind::Tuple,
417        }
418    }
419}
420
421impl From<Expr> for ExtensionValue {
422    fn from(expr: Expr) -> Self {
423        ExtensionValue::Expr(expr)
424    }
425}
426
427impl From<proto::Expression> for ExtensionValue {
428    fn from(expr: proto::Expression) -> Self {
429        Expr::from(expr).into()
430    }
431}
432
433impl From<proto::expression::Literal> for ExtensionValue {
434    fn from(literal: proto::expression::Literal) -> Self {
435        Expr::from(literal).into()
436    }
437}
438
439impl From<Reference> for ExtensionValue {
440    fn from(reference: Reference) -> Self {
441        Expr::from(reference).into()
442    }
443}
444
445impl From<i64> for ExtensionValue {
446    fn from(value: i64) -> Self {
447        ExtensionValue::Integer(value)
448    }
449}
450
451impl From<f64> for ExtensionValue {
452    fn from(value: f64) -> Self {
453        ExtensionValue::Float(value)
454    }
455}
456
457impl From<bool> for ExtensionValue {
458    fn from(value: bool) -> Self {
459        ExtensionValue::Boolean(value)
460    }
461}
462
463impl From<String> for ExtensionValue {
464    fn from(value: String) -> Self {
465        ExtensionValue::String(value)
466    }
467}
468
469impl From<&str> for ExtensionValue {
470    fn from(value: &str) -> Self {
471        ExtensionValue::String(value.to_string())
472    }
473}
474
475fn invalid_type(expected: ExtensionValueKind, actual: &ExtensionValue) -> ExtensionError {
476    ExtensionError::InvalidArgumentType {
477        expected,
478        actual: actual.kind(),
479    }
480}
481
482impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
483    type Error = ExtensionError;
484
485    fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
486        match value {
487            ExtensionValue::String(s) => Ok(s),
488            v => Err(invalid_type(ExtensionValueKind::String, v)),
489        }
490    }
491}
492
493impl TryFrom<ExtensionValue> for String {
494    type Error = ExtensionError;
495
496    fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
497        <&str>::try_from(&value).map(ToOwned::to_owned)
498    }
499}
500
501/// Helper for extracting the identifier from an [`ExtensionValue::Enum`].
502pub struct EnumValue(pub String);
503
504impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
505    type Error = ExtensionError;
506
507    fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
508        match value {
509            ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
510            v => Err(invalid_type(ExtensionValueKind::Enum, v)),
511        }
512    }
513}
514
515impl<'a> TryFrom<&'a ExtensionValue> for &'a TupleValue {
516    type Error = ExtensionError;
517
518    fn try_from(value: &'a ExtensionValue) -> Result<&'a TupleValue, Self::Error> {
519        match value {
520            ExtensionValue::Tuple(tv) => Ok(tv),
521            v => Err(invalid_type(ExtensionValueKind::Tuple, v)),
522        }
523    }
524}
525
526impl TryFrom<&ExtensionValue> for i64 {
527    type Error = ExtensionError;
528
529    fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
530        match value {
531            ExtensionValue::Integer(i) => Ok(*i),
532            v => Err(invalid_type(ExtensionValueKind::Integer, v)),
533        }
534    }
535}
536
537impl TryFrom<&ExtensionValue> for f64 {
538    type Error = ExtensionError;
539
540    fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
541        match value {
542            ExtensionValue::Float(f) => Ok(*f),
543            v => Err(invalid_type(ExtensionValueKind::Float, v)),
544        }
545    }
546}
547
548impl TryFrom<&ExtensionValue> for bool {
549    type Error = ExtensionError;
550
551    fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
552        match value {
553            ExtensionValue::Boolean(b) => Ok(*b),
554            v => Err(invalid_type(ExtensionValueKind::Boolean, v)),
555        }
556    }
557}
558
559impl TryFrom<&ExtensionValue> for Reference {
560    type Error = ExtensionError;
561
562    fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
563        match value {
564            ExtensionValue::Expr(expr) => expr
565                .as_direct_reference()
566                .ok_or_else(|| invalid_type(ExtensionValueKind::Reference, value)),
567            v => Err(invalid_type(ExtensionValueKind::Reference, v)),
568        }
569    }
570}
571
572impl TryFrom<&ExtensionValue> for Expr {
573    type Error = ExtensionError;
574
575    fn try_from(value: &ExtensionValue) -> Result<Expr, Self::Error> {
576        match value {
577            ExtensionValue::Expr(e) => Ok(e.clone()),
578            // Untyped extension scalars are intentionally expression-compatible:
579            // `arg=2` carries no syntax that distinguishes "configuration
580            // integer" from "i64 literal expression". Scalar-specific
581            // extraction (`i64`, `&str`, `bool`, etc.) still requires the scalar
582            // variants, while expression extraction widens them to default
583            // non-nullable Substrait literal expressions.
584            ExtensionValue::Integer(i) => Ok(Expr::from(*i)),
585            ExtensionValue::Float(f) => Ok(Expr::from(*f)),
586            ExtensionValue::String(s) => Ok(Expr::from(s.as_str())),
587            ExtensionValue::Boolean(b) => Ok(Expr::from(*b)),
588            v => Err(invalid_type(ExtensionValueKind::Expression, v)),
589        }
590    }
591}
592
593/// Represents an output column specification.
594///
595/// These values mirror the text-format output column forms. Named columns keep
596/// the parsed Substrait type protobuf so handlers can convert directly to
597/// relation schemas.
598#[derive(Debug, Clone)]
599pub enum ExtensionColumn {
600    /// Named column with a parsed Substrait type (e.g. `name:i64?`).
601    Named {
602        /// Column name as it appears in the extension relation output.
603        name: String,
604        /// Parsed Substrait type for the column.
605        ///
606        /// This uses the protobuf field name, hence the raw identifier.
607        r#type: proto::Type,
608    },
609    /// Expression-compatible output column, including field references.
610    Expr(Expr),
611}
612
613impl ExtensionArgs {
614    /// Push a positional extension argument.
615    pub fn push<T>(&mut self, value: T)
616    where
617        T: Into<ExtensionValue>,
618    {
619        self.positional.push(value.into());
620    }
621
622    /// Insert a named extension argument, returning any previous value.
623    pub fn insert<K, V>(&mut self, name: K, value: V) -> Option<ExtensionValue>
624    where
625        K: Into<String>,
626        V: Into<ExtensionValue>,
627    {
628        self.named.insert(name.into(), value.into())
629    }
630
631    /// Create an extractor for validating named arguments
632    pub fn extractor(&self) -> ArgsExtractor<'_> {
633        ArgsExtractor::new(self)
634    }
635}