substrait_explain/extensions/
args.rs

1//! Core extension data structures without parser dependencies
2//!
3//! This module contains the core data structures for extension arguments,
4//! values, and columns without any parser or textify dependencies.
5
6use std::collections::HashSet;
7use std::fmt;
8
9use indexmap::IndexMap;
10
11use super::ExtensionError;
12use crate::textify::expressions::Reference;
13use crate::textify::types::escaped;
14
/// Stand-in for a parsed expression until a proper expression AST exists.
///
/// Only the raw source text is kept. The field is private on purpose, so the
/// representation can be replaced later without touching callers.
#[derive(Clone, Debug)]
pub(crate) struct RawExpression {
    text: String,
}

impl RawExpression {
    /// Wraps the raw text of an expression.
    pub fn new(text: String) -> Self {
        RawExpression { text }
    }
}

impl fmt::Display for RawExpression {
    /// Writes the stored expression text unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.text)
    }
}
34
35/// Represents the arguments and output columns for an extension relation.
36///
37/// Named arguments are stored in an [`IndexMap`] whose iteration order
38/// determines display order. Extension [`super::Explainable::to_args()`]
39/// implementations should insert named arguments in the order they should
40/// appear in the text format.
41#[derive(Debug, Clone)]
42pub struct ExtensionArgs {
43    /// Positional arguments (expressions, literals, references)
44    pub positional: Vec<ExtensionValue>,
45    /// Named arguments, displayed in the order they were inserted
46    pub named: IndexMap<String, ExtensionValue>,
47    /// Output columns (named columns, references, or expressions)
48    pub output_columns: Vec<ExtensionColumn>,
49    /// The type of extension relation (Leaf/Single/Multi)
50    pub relation_type: ExtensionRelationType,
51}
52
53/// Helper struct for extracting named arguments with validation.
54///
55/// Tracks which arguments have been consumed. Callers **must** call
56/// [`check_exhausted`](ArgsExtractor::check_exhausted) before dropping to
57/// verify no unexpected arguments remain. In debug builds, dropping without
58/// calling `check_exhausted` will panic (matching the [`RuleIter`](crate::parser::RuleIter) pattern).
59pub struct ArgsExtractor<'a> {
60    args: &'a ExtensionArgs,
61    consumed: HashSet<&'a str>,
62    checked: bool,
63}
64
65impl<'a> ArgsExtractor<'a> {
66    /// Create a new extractor for the given arguments
67    pub fn new(args: &'a ExtensionArgs) -> Self {
68        Self {
69            args,
70            consumed: HashSet::new(),
71            checked: false,
72        }
73    }
74
75    /// Get a named argument value, marking it as consumed if found.
76    pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
77        match self.args.named.get_key_value(name) {
78            Some((k, value)) => {
79                self.consumed.insert(k);
80                Some(value)
81            }
82            None => None,
83        }
84    }
85
86    /// Get a named argument value or return an error
87    /// Marks the argument as consumed if found
88    pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
89    where
90        T: TryFrom<&'a ExtensionValue>,
91        T::Error: Into<ExtensionError>,
92    {
93        match self.get_named_arg(name) {
94            Some(value) => T::try_from(value).map_err(Into::into),
95            None => Err(ExtensionError::MissingArgument {
96                name: name.to_string(),
97            }),
98        }
99    }
100
101    /// Get a named argument value or default
102    /// Marks the argument as consumed if it exists in the source args
103    pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
104    where
105        T: TryFrom<&'a ExtensionValue>,
106        T::Error: Into<ExtensionError>,
107    {
108        match self.get_named_arg(name) {
109            Some(value) => T::try_from(value).map_err(Into::into),
110            None => Ok(default),
111        }
112    }
113
114    /// Check that all named arguments in the source have been consumed,
115    /// returning an error if not.
116    ///
117    /// Must be called before the extractor is dropped, to validate that all
118    /// args are correctly handled. In debug builds, dropping without calling
119    /// this method will panic.
120    pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
121        self.checked = true;
122
123        let mut unknown_args = Vec::new();
124        for name in self.args.named.keys() {
125            if !self.consumed.contains(name.as_str()) {
126                unknown_args.push(name.as_str());
127            }
128        }
129
130        if unknown_args.is_empty() {
131            Ok(())
132        } else {
133            // Sort for stable error messages
134            unknown_args.sort();
135            Err(ExtensionError::InvalidArgument(format!(
136                "Unknown named arguments: {}",
137                unknown_args.join(", ")
138            )))
139        }
140    }
141}
142
143impl Drop for ArgsExtractor<'_> {
144    fn drop(&mut self) {
145        if self.checked || std::thread::panicking() {
146            return;
147        }
148        // If we get here, the caller forgot to call check_exhausted().
149        debug_assert!(
150            false,
151            "ArgsExtractor dropped without calling check_exhausted()"
152        );
153    }
154}
155
156#[derive(Debug, Clone)]
157pub struct TupleValue(Vec<ExtensionValue>);
158
159impl TupleValue {
160    pub fn len(&self) -> usize {
161        self.0.len()
162    }
163
164    pub fn is_empty(&self) -> bool {
165        self.0.is_empty()
166    }
167
168    pub fn iter(&self) -> std::slice::Iter<'_, ExtensionValue> {
169        self.0.iter()
170    }
171}
172
173impl fmt::Display for TupleValue {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        write!(f, "(")?;
176        for (i, item) in self.0.iter().enumerate() {
177            if i > 0 {
178                write!(f, ", ")?;
179            }
180            write!(f, "{item}")?;
181        }
182        if self.0.len() == 1 {
183            write!(f, ",")?;
184        }
185        write!(f, ")")
186    }
187}
188
189impl<'a> IntoIterator for &'a TupleValue {
190    type Item = &'a ExtensionValue;
191    type IntoIter = std::slice::Iter<'a, ExtensionValue>;
192
193    fn into_iter(self) -> Self::IntoIter {
194        self.0.iter()
195    }
196}
197
198impl IntoIterator for TupleValue {
199    type Item = ExtensionValue;
200    type IntoIter = std::vec::IntoIter<ExtensionValue>;
201
202    fn into_iter(self) -> Self::IntoIter {
203        self.0.into_iter()
204    }
205}
206
207impl FromIterator<ExtensionValue> for TupleValue {
208    fn from_iter<I: IntoIterator<Item = ExtensionValue>>(iter: I) -> Self {
209        TupleValue(iter.into_iter().collect())
210    }
211}
212
213impl From<Vec<ExtensionValue>> for TupleValue {
214    fn from(items: Vec<ExtensionValue>) -> Self {
215        TupleValue(items)
216    }
217}
218
219/// Represents a value in extension arguments
220#[derive(Debug, Clone)]
221pub enum ExtensionValue {
222    /// String literal value
223    String(String),
224    /// Integer literal value
225    Integer(i64),
226    /// Float literal value
227    Float(f64),
228    /// Boolean literal value
229    Boolean(bool),
230    /// Field reference ($0, $1, etc.)
231    Reference(i32),
232    /// Enum value (e.g. &CORE, &Inner) — Uses the wrapper EnumValue. the string holds the identifier without the `&` prefix
233    Enum(String),
234    /// Tuple of values, e.g. (&HASH, &RANGE) or (42, 'hello')
235    Tuple(TupleValue),
236    /// Expression (function call, etc.) — not yet fully supported, hence the
237    /// private interface.
238    #[allow(private_interfaces)]
239    Expression(RawExpression),
240}
241
242impl fmt::Display for ExtensionValue {
243    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244        match self {
245            ExtensionValue::String(s) => write!(f, "String({})", escaped(s)),
246            ExtensionValue::Integer(i) => write!(f, "Integer({})", i),
247            ExtensionValue::Float(n) => write!(f, "Float({})", n),
248            ExtensionValue::Boolean(b) => write!(f, "Boolean({})", b),
249            ExtensionValue::Reference(r) => write!(f, "Reference({})", r),
250            ExtensionValue::Enum(e) => write!(f, "Enum(&{})", e),
251            ExtensionValue::Tuple(tv) => write!(f, "Tuple{tv}"),
252            ExtensionValue::Expression(e) => write!(f, "Expression({})", e),
253        }
254    }
255}
256
257impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
258    type Error = ExtensionError;
259
260    fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
261        match value {
262            ExtensionValue::String(s) => Ok(s),
263            v => Err(ExtensionError::InvalidArgument(format!(
264                "Expected string, got {v}",
265            ))),
266        }
267    }
268}
269
270impl TryFrom<ExtensionValue> for String {
271    type Error = ExtensionError;
272
273    fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
274        match value {
275            ExtensionValue::String(s) => Ok(s),
276            v => Err(ExtensionError::InvalidArgument(format!(
277                "Expected string, got {v}",
278            ))),
279        }
280    }
281}
282
283/// Helper for extracting the identifier from an [`ExtensionValue::Enum`].
284pub struct EnumValue(pub String);
285
286impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
287    type Error = ExtensionError;
288
289    fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
290        match value {
291            ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
292            v => Err(ExtensionError::InvalidArgument(format!(
293                "Expected enum, got {v}",
294            ))),
295        }
296    }
297}
298
299impl<'a> TryFrom<&'a ExtensionValue> for &'a TupleValue {
300    type Error = ExtensionError;
301
302    fn try_from(value: &'a ExtensionValue) -> Result<&'a TupleValue, Self::Error> {
303        match value {
304            ExtensionValue::Tuple(tv) => Ok(tv),
305            v => Err(ExtensionError::InvalidArgument(format!(
306                "Expected tuple, got {v}",
307            ))),
308        }
309    }
310}
311
312impl TryFrom<&ExtensionValue> for i64 {
313    type Error = ExtensionError;
314
315    fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
316        match value {
317            &ExtensionValue::Integer(i) => Ok(i),
318            v => Err(ExtensionError::InvalidArgument(format!(
319                "Expected integer, got {v}",
320            ))),
321        }
322    }
323}
324
325impl TryFrom<&ExtensionValue> for f64 {
326    type Error = ExtensionError;
327
328    fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
329        match value {
330            &ExtensionValue::Float(f) => Ok(f),
331            v => Err(ExtensionError::InvalidArgument(format!(
332                "Expected float, got {v}",
333            ))),
334        }
335    }
336}
337
338impl TryFrom<&ExtensionValue> for bool {
339    type Error = ExtensionError;
340
341    fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
342        match value {
343            &ExtensionValue::Boolean(b) => Ok(b),
344            v => Err(ExtensionError::InvalidArgument(format!(
345                "Expected boolean, got {v}",
346            ))),
347        }
348    }
349}
350
351impl TryFrom<&ExtensionValue> for Reference {
352    type Error = ExtensionError;
353
354    fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
355        match value {
356            &ExtensionValue::Reference(r) => Ok(Reference(r)),
357            v => Err(ExtensionError::InvalidArgument(format!(
358                "Expected reference, got {v}",
359            ))),
360        }
361    }
362}
363
364/// Represents an output column specification
365#[derive(Debug, Clone)]
366pub enum ExtensionColumn {
367    /// Named column with type (name:type)
368    Named { name: String, type_spec: String },
369    /// Field reference ($0, $1, etc.)
370    Reference(i32),
371    /// Expression column — not yet fully supported, hence the private
372    /// interface.
373    #[allow(private_interfaces)]
374    Expression(RawExpression),
375}
376
/// The three kinds of extension relation, distinguished by how many input
/// children they take.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExtensionRelationType {
    /// `ExtensionLeaf`: takes no input children.
    Leaf,
    /// `ExtensionSingle`: takes exactly one input child.
    Single,
    /// `ExtensionMulti`: takes any number of input children, including none.
    Multi,
}

impl std::str::FromStr for ExtensionRelationType {
    type Err = String;

    /// Parses a text-format name (`ExtensionLeaf`, `ExtensionSingle` or
    /// `ExtensionMulti`, case-sensitive). Any other input yields an error message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Reuse `as_str` so the parser and printer can never disagree.
        [Self::Leaf, Self::Single, Self::Multi]
            .iter()
            .copied()
            .find(|kind| kind.as_str() == s)
            .ok_or_else(|| format!("Unknown extension relation type: {s}"))
    }
}

impl ExtensionRelationType {
    /// Returns the name of this relation type as written in the text format.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leaf => "ExtensionLeaf",
            Self::Single => "ExtensionSingle",
            Self::Multi => "ExtensionMulti",
        }
    }

    /// Checks that `child_count` input children is legal for this relation type.
    ///
    /// # Errors
    /// Returns a human-readable message when a leaf has any children, or when
    /// a single relation has anything other than exactly one child. Multi
    /// relations accept every count.
    pub fn validate_child_count(&self, child_count: usize) -> Result<(), String> {
        match (*self, child_count) {
            (Self::Leaf, 0) | (Self::Single, 1) | (Self::Multi, _) => Ok(()),
            (Self::Leaf, n) => Err(format!(
                "ExtensionLeaf should have no input children, got {n}"
            )),
            (Self::Single, n) => Err(format!(
                "ExtensionSingle should have exactly 1 input child, got {n}"
            )),
        }
    }
}
439
440// Note: create_rel is implemented in parser/extensions.rs to avoid
441// pulling in protobuf dependencies in the core args module
442
443impl ExtensionArgs {
444    /// Create a new empty ExtensionArgs
445    pub fn new(relation_type: ExtensionRelationType) -> Self {
446        Self {
447            positional: Vec::new(),
448            named: IndexMap::new(),
449            output_columns: Vec::new(),
450            relation_type,
451        }
452    }
453
454    /// Create an extractor for validating named arguments
455    pub fn extractor(&self) -> ArgsExtractor<'_> {
456        ArgsExtractor::new(self)
457    }
458}
459
#[cfg(test)]
mod tests {
    //! Child-count validation and text round-tripping for `ExtensionRelationType`.

    use super::ExtensionRelationType;

    #[test]
    fn extension_leaf_allows_zero_children() {
        assert!(ExtensionRelationType::Leaf.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_leaf_rejects_any_children() {
        assert_eq!(
            ExtensionRelationType::Leaf.validate_child_count(1),
            Err("ExtensionLeaf should have no input children, got 1".to_string())
        );
        assert!(ExtensionRelationType::Leaf.validate_child_count(3).is_err());
    }

    #[test]
    fn extension_single_allows_exactly_one_child() {
        assert!(ExtensionRelationType::Single.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert_eq!(
            ExtensionRelationType::Single.validate_child_count(0),
            Err("ExtensionSingle should have exactly 1 input child, got 0".to_string())
        );
        assert!(
            ExtensionRelationType::Single
                .validate_child_count(2)
                .is_err()
        );
    }

    #[test]
    fn extension_multi_allows_zero_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_multi_allows_single_child() {
        assert!(ExtensionRelationType::Multi.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_multi_allows_multiple_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn relation_type_names_round_trip() {
        // Whatever `as_str` prints must parse back to the same variant.
        for kind in [
            ExtensionRelationType::Leaf,
            ExtensionRelationType::Single,
            ExtensionRelationType::Multi,
        ] {
            assert_eq!(kind.as_str().parse::<ExtensionRelationType>(), Ok(kind));
        }
    }

    #[test]
    fn relation_type_rejects_unknown_names() {
        assert_eq!(
            "ExtensionFoo".parse::<ExtensionRelationType>(),
            Err("Unknown extension relation type: ExtensionFoo".to_string())
        );
        // Matching is exact, so case variations are rejected too.
        assert!("extensionleaf".parse::<ExtensionRelationType>().is_err());
    }
}