substrait_explain/extensions/
args.rs

1//! Core extension data structures without parser dependencies
2//!
//! This module contains the core data structures for extension arguments,
//! values, and columns. It has no parser dependencies; from `textify` it uses
//! only the `Reference` wrapper and the `escaped` helper.
5
6use std::collections::HashSet;
7use std::fmt;
8
9use indexmap::IndexMap;
10
11use super::ExtensionError;
12use crate::textify::expressions::Reference;
13use crate::textify::types::escaped;
14
/// Stand-in for an expression until a real expression AST exists.
///
/// Only the raw source text of the parsed expression is kept. The field is
/// deliberately private so the representation can be swapped out later
/// without touching users of this type.
#[derive(Debug, Clone)]
pub(crate) struct RawExpression {
    text: String,
}

impl RawExpression {
    /// Wraps the raw expression text, taking ownership of it.
    pub fn new(text: String) -> Self {
        RawExpression { text }
    }
}

impl fmt::Display for RawExpression {
    /// Emits the stored text verbatim; width/precision flags on the outer
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.text)
    }
}
34
/// Represents the arguments and output columns for an extension relation.
///
/// Named arguments are stored in an [`IndexMap`] whose iteration order
/// determines display order. Extension [`super::Explainable::to_args()`]
/// implementations should insert named arguments in the order they should
/// appear in the text format.
///
/// Use [`ExtensionArgs::extractor`] to read named arguments while checking
/// that none are left unhandled.
#[derive(Debug, Clone)]
pub struct ExtensionArgs {
    /// Positional arguments (expressions, literals, references)
    pub positional: Vec<ExtensionValue>,
    /// Named arguments, displayed in the order they were inserted
    pub named: IndexMap<String, ExtensionValue>,
    /// Output columns (named columns, references, or expressions)
    pub output_columns: Vec<ExtensionColumn>,
    /// The type of extension relation (Leaf/Single/Multi)
    pub relation_type: ExtensionRelationType,
}
52
/// Helper struct for extracting named arguments with validation.
///
/// Tracks which arguments have been consumed. Callers **must** call
/// [`check_exhausted`](ArgsExtractor::check_exhausted) before dropping to
/// verify no unexpected arguments remain. In debug builds, dropping without
/// calling `check_exhausted` will panic (matching the [`RuleIter`](crate::parser::RuleIter) pattern).
pub struct ArgsExtractor<'a> {
    // The arguments being read; never modified by the extractor.
    args: &'a ExtensionArgs,
    // Names of the named arguments that were looked up and found. The entries
    // borrow the map's own keys, hence the `'a` lifetime.
    consumed: HashSet<&'a str>,
    // Set by `check_exhausted`; the `Drop` impl reads it to detect a missed call.
    checked: bool,
}
64
65impl<'a> ArgsExtractor<'a> {
66    /// Create a new extractor for the given arguments
67    pub fn new(args: &'a ExtensionArgs) -> Self {
68        Self {
69            args,
70            consumed: HashSet::new(),
71            checked: false,
72        }
73    }
74
75    /// Get a named argument value, marking it as consumed if found.
76    pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
77        match self.args.named.get_key_value(name) {
78            Some((k, value)) => {
79                self.consumed.insert(k);
80                Some(value)
81            }
82            None => None,
83        }
84    }
85
86    /// Get a named argument value or return an error
87    /// Marks the argument as consumed if found
88    pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
89    where
90        T: TryFrom<&'a ExtensionValue>,
91        T::Error: Into<ExtensionError>,
92    {
93        match self.get_named_arg(name) {
94            Some(value) => T::try_from(value).map_err(Into::into),
95            None => Err(ExtensionError::MissingArgument {
96                name: name.to_string(),
97            }),
98        }
99    }
100
101    /// Get a named argument value or default
102    /// Marks the argument as consumed if it exists in the source args
103    pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
104    where
105        T: TryFrom<&'a ExtensionValue>,
106        T::Error: Into<ExtensionError>,
107    {
108        match self.get_named_arg(name) {
109            Some(value) => T::try_from(value).map_err(Into::into),
110            None => Ok(default),
111        }
112    }
113
114    /// Check that all named arguments in the source have been consumed,
115    /// returning an error if not.
116    ///
117    /// Must be called before the extractor is dropped, to validate that all
118    /// args are correctly handled. In debug builds, dropping without calling
119    /// this method will panic.
120    pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
121        self.checked = true;
122
123        let mut unknown_args = Vec::new();
124        for name in self.args.named.keys() {
125            if !self.consumed.contains(name.as_str()) {
126                unknown_args.push(name.as_str());
127            }
128        }
129
130        if unknown_args.is_empty() {
131            Ok(())
132        } else {
133            // Sort for stable error messages
134            unknown_args.sort();
135            Err(ExtensionError::InvalidArgument(format!(
136                "Unknown named arguments: {}",
137                unknown_args.join(", ")
138            )))
139        }
140    }
141}
142
143impl Drop for ArgsExtractor<'_> {
144    fn drop(&mut self) {
145        if self.checked || std::thread::panicking() {
146            return;
147        }
148        // If we get here, the caller forgot to call check_exhausted().
149        debug_assert!(
150            false,
151            "ArgsExtractor dropped without calling check_exhausted()"
152        );
153    }
154}
155
/// Represents a value in extension arguments
#[derive(Debug, Clone)]
pub enum ExtensionValue {
    /// String literal value
    String(String),
    /// Integer literal value
    Integer(i64),
    /// Float literal value
    Float(f64),
    /// Boolean literal value
    Boolean(bool),
    /// Field reference ($0, $1, etc.)
    Reference(i32),
    /// Enum value (e.g. `&CORE`, `&Inner`). The string holds the identifier
    /// without the `&` prefix; use the [`EnumValue`] wrapper to extract it.
    Enum(String),
    /// Expression (function call, etc.) — not yet fully supported, hence the
    /// private interface.
    #[allow(private_interfaces)]
    Expression(RawExpression),
}
176
177impl fmt::Display for ExtensionValue {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        match self {
180            ExtensionValue::String(s) => write!(f, "String({})", escaped(s)),
181            ExtensionValue::Integer(i) => write!(f, "Integer({})", i),
182            ExtensionValue::Float(n) => write!(f, "Float({})", n),
183            ExtensionValue::Boolean(b) => write!(f, "Boolean({})", b),
184            ExtensionValue::Reference(r) => write!(f, "Reference({})", r),
185            ExtensionValue::Enum(e) => write!(f, "Enum(&{})", e),
186            ExtensionValue::Expression(e) => write!(f, "Expression({})", e),
187        }
188    }
189}
190
191impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
192    type Error = ExtensionError;
193
194    fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
195        match value {
196            ExtensionValue::String(s) => Ok(s),
197            v => Err(ExtensionError::InvalidArgument(format!(
198                "Expected string, got {v}",
199            ))),
200        }
201    }
202}
203
204impl TryFrom<ExtensionValue> for String {
205    type Error = ExtensionError;
206
207    fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
208        match value {
209            ExtensionValue::String(s) => Ok(s),
210            v => Err(ExtensionError::InvalidArgument(format!(
211                "Expected string, got {v}",
212            ))),
213        }
214    }
215}
216
/// Helper for extracting the identifier from an [`ExtensionValue::Enum`].
///
/// The wrapped string is the enum identifier without the `&` prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumValue(pub String);
219
220impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
221    type Error = ExtensionError;
222
223    fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
224        match value {
225            ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
226            v => Err(ExtensionError::InvalidArgument(format!(
227                "Expected enum, got {v}",
228            ))),
229        }
230    }
231}
232
233impl TryFrom<&ExtensionValue> for i64 {
234    type Error = ExtensionError;
235
236    fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
237        match value {
238            &ExtensionValue::Integer(i) => Ok(i),
239            v => Err(ExtensionError::InvalidArgument(format!(
240                "Expected integer, got {v}",
241            ))),
242        }
243    }
244}
245
246impl TryFrom<&ExtensionValue> for f64 {
247    type Error = ExtensionError;
248
249    fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
250        match value {
251            &ExtensionValue::Float(f) => Ok(f),
252            v => Err(ExtensionError::InvalidArgument(format!(
253                "Expected float, got {v}",
254            ))),
255        }
256    }
257}
258
259impl TryFrom<&ExtensionValue> for bool {
260    type Error = ExtensionError;
261
262    fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
263        match value {
264            &ExtensionValue::Boolean(b) => Ok(b),
265            v => Err(ExtensionError::InvalidArgument(format!(
266                "Expected boolean, got {v}",
267            ))),
268        }
269    }
270}
271
272impl TryFrom<&ExtensionValue> for Reference {
273    type Error = ExtensionError;
274
275    fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
276        match value {
277            &ExtensionValue::Reference(r) => Ok(Reference(r)),
278            v => Err(ExtensionError::InvalidArgument(format!(
279                "Expected reference, got {v}",
280            ))),
281        }
282    }
283}
284
/// Represents an output column specification
#[derive(Debug, Clone)]
pub enum ExtensionColumn {
    /// Named column with type (name:type). `name` is the part before the
    /// colon and `type_spec` the part after it, both kept as text.
    Named { name: String, type_spec: String },
    /// Field reference ($0, $1, etc.)
    Reference(i32),
    /// Expression column — not yet fully supported, hence the private
    /// interface.
    #[allow(private_interfaces)]
    Expression(RawExpression),
}
297
/// The kinds of extension relation, distinguished by how many input
/// children each one accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionRelationType {
    /// Extension leaf relation: takes no input children
    Leaf,
    /// Extension single relation: takes exactly one input child
    Single,
    /// Extension multi relation: takes any number of input children,
    /// including none
    Multi,
}

impl std::str::FromStr for ExtensionRelationType {
    type Err = String;

    /// Parses the relation name used in the text format (the inverse of
    /// [`ExtensionRelationType::as_str`]).
    ///
    /// Matching is exact and case-sensitive; anything else is reported as an
    /// unknown relation type.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let kind = match s {
            "ExtensionLeaf" => Self::Leaf,
            "ExtensionSingle" => Self::Single,
            "ExtensionMulti" => Self::Multi,
            _ => return Err(format!("Unknown extension relation type: {}", s)),
        };
        Ok(kind)
    }
}

impl ExtensionRelationType {
    /// Returns the relation name as written in the text format.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leaf => "ExtensionLeaf",
            Self::Single => "ExtensionSingle",
            Self::Multi => "ExtensionMulti",
        }
    }

    /// Checks that `child_count` is a legal number of inputs for this
    /// relation type.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when a `Leaf` has any children or a
    /// `Single` does not have exactly one. `Multi` accepts every count.
    pub fn validate_child_count(&self, child_count: usize) -> Result<(), String> {
        match (*self, child_count) {
            // The accepted combinations; Multi takes zero or more children.
            (Self::Leaf, 0) | (Self::Single, 1) | (Self::Multi, _) => Ok(()),
            (Self::Leaf, _) => Err(format!(
                "ExtensionLeaf should have no input children, got {child_count}"
            )),
            (Self::Single, _) => Err(format!(
                "ExtensionSingle should have exactly 1 input child, got {child_count}"
            )),
        }
    }
}
360
361// Note: create_rel is implemented in parser/extensions.rs to avoid
362// pulling in protobuf dependencies in the core args module
363
364impl ExtensionArgs {
365    /// Create a new empty ExtensionArgs
366    pub fn new(relation_type: ExtensionRelationType) -> Self {
367        Self {
368            positional: Vec::new(),
369            named: IndexMap::new(),
370            output_columns: Vec::new(),
371            relation_type,
372        }
373    }
374
375    /// Create an extractor for validating named arguments
376    pub fn extractor(&self) -> ArgsExtractor<'_> {
377        ArgsExtractor::new(self)
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::ExtensionRelationType;

    // Multi relations place no restriction on the number of children.

    #[test]
    fn extension_multi_allows_zero_children() {
        let outcome = ExtensionRelationType::Multi.validate_child_count(0);
        assert_eq!(outcome, Ok(()));
    }

    #[test]
    fn extension_multi_allows_single_child() {
        let outcome = ExtensionRelationType::Multi.validate_child_count(1);
        assert_eq!(outcome, Ok(()));
    }

    #[test]
    fn extension_multi_allows_multiple_children() {
        let outcome = ExtensionRelationType::Multi.validate_child_count(3);
        assert_eq!(outcome, Ok(()));
    }

    // Single relations accept exactly one child; check one count on each side.
    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        for count in [0, 2] {
            let outcome = ExtensionRelationType::Single.validate_child_count(count);
            assert!(outcome.is_err());
        }
    }
}
411        );
412    }
413}