substrait_explain/extensions/
args.rs

//! Core extension data structures without parser dependencies
//!
//! This module contains the core data structures for extension arguments,
//! values, and columns. It has no parser dependencies; from `textify` it
//! uses only the `Reference` type and the `escaped` helper.
5
6use std::collections::HashSet;
7use std::fmt;
8
9use indexmap::IndexMap;
10
11use super::ExtensionError;
12use crate::textify::expressions::Reference;
13use crate::textify::types::escaped;
14
/// Stand-in for a future expression implementation.
///
/// Stores the raw text of a parsed expression. The field is intentionally
/// private, because this type is expected to be replaced by a proper
/// expression AST later.
#[derive(Debug, Clone)]
pub(crate) struct RawExpression {
    text: String,
}

impl RawExpression {
    /// Wraps the given expression text.
    pub fn new(text: String) -> Self {
        RawExpression { text }
    }
}

impl fmt::Display for RawExpression {
    /// Writes the stored text verbatim. Width and padding flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.text)
    }
}
34
35/// Represents the arguments and output columns for an extension relation.
36///
37/// Named arguments are stored in an [`IndexMap`] whose iteration order
38/// determines display order. Extension [`Explainable::to_args()`]
39/// implementations should insert named arguments in the order they should
40/// appear in the text format.
41#[derive(Debug, Clone)]
42pub struct ExtensionArgs {
43    /// Positional arguments (expressions, literals, references)
44    pub positional: Vec<ExtensionValue>,
45    /// Named arguments, displayed in the order they were inserted
46    pub named: IndexMap<String, ExtensionValue>,
47    /// Output columns (named columns, references, or expressions)
48    pub output_columns: Vec<ExtensionColumn>,
49    /// The type of extension relation (Leaf/Single/Multi)
50    pub relation_type: ExtensionRelationType,
51}
52
53/// Helper struct for extracting named arguments with validation.
54///
55/// Tracks which arguments have been consumed. Callers **must** call
56/// [`check_exhausted`](ArgsExtractor::check_exhausted) before dropping to
57/// verify no unexpected arguments remain. In debug builds, dropping without
58/// calling `check_exhausted` will panic (matching the [`RuleIter`] pattern).
59pub struct ArgsExtractor<'a> {
60    args: &'a ExtensionArgs,
61    consumed: HashSet<&'a str>,
62    checked: bool,
63}
64
65impl<'a> ArgsExtractor<'a> {
66    /// Create a new extractor for the given arguments
67    pub fn new(args: &'a ExtensionArgs) -> Self {
68        Self {
69            args,
70            consumed: HashSet::new(),
71            checked: false,
72        }
73    }
74
75    /// Get a named argument value, marking it as consumed if found.
76    pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
77        match self.args.named.get_key_value(name) {
78            Some((k, value)) => {
79                self.consumed.insert(k);
80                Some(value)
81            }
82            None => None,
83        }
84    }
85
86    /// Get a named argument value or return an error
87    /// Marks the argument as consumed if found
88    pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
89    where
90        T: TryFrom<&'a ExtensionValue>,
91        T::Error: Into<ExtensionError>,
92    {
93        match self.get_named_arg(name) {
94            Some(value) => T::try_from(value).map_err(Into::into),
95            None => Err(ExtensionError::MissingArgument {
96                name: name.to_string(),
97            }),
98        }
99    }
100
101    /// Get a named argument value or default
102    /// Marks the argument as consumed if it exists in the source args
103    pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
104    where
105        T: TryFrom<&'a ExtensionValue>,
106        T::Error: Into<ExtensionError>,
107    {
108        match self.get_named_arg(name) {
109            Some(value) => T::try_from(value).map_err(Into::into),
110            None => Ok(default),
111        }
112    }
113
114    /// Check that all named arguments in the source have been consumed,
115    /// returning an error if not.
116    ///
117    /// Must be called before the extractor is dropped, to validate that all
118    /// args are correctly handled. In debug builds, dropping without calling
119    /// this method will panic.
120    pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
121        self.checked = true;
122
123        let mut unknown_args = Vec::new();
124        for name in self.args.named.keys() {
125            if !self.consumed.contains(name.as_str()) {
126                unknown_args.push(name.as_str());
127            }
128        }
129
130        if unknown_args.is_empty() {
131            Ok(())
132        } else {
133            // Sort for stable error messages
134            unknown_args.sort();
135            Err(ExtensionError::InvalidArgument(format!(
136                "Unknown named arguments: {}",
137                unknown_args.join(", ")
138            )))
139        }
140    }
141}
142
143impl Drop for ArgsExtractor<'_> {
144    fn drop(&mut self) {
145        if self.checked || std::thread::panicking() {
146            return;
147        }
148        // If we get here, the caller forgot to call check_exhausted().
149        debug_assert!(
150            false,
151            "ArgsExtractor dropped without calling check_exhausted()"
152        );
153    }
154}
155
156/// Represents a value in extension arguments
157#[derive(Debug, Clone)]
158pub enum ExtensionValue {
159    /// String literal value
160    String(String),
161    /// Integer literal value
162    Integer(i64),
163    /// Float literal value
164    Float(f64),
165    /// Boolean literal value
166    Boolean(bool),
167    /// Field reference ($0, $1, etc.)
168    Reference(i32),
169    /// Expression (function call, etc.) — not yet fully supported, hence the
170    /// private interface.
171    #[allow(private_interfaces)]
172    Expression(RawExpression),
173}
174
175impl fmt::Display for ExtensionValue {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        match self {
178            ExtensionValue::String(s) => write!(f, "String({})", escaped(s)),
179            ExtensionValue::Integer(i) => write!(f, "Integer({})", i),
180            ExtensionValue::Float(n) => write!(f, "Float({})", n),
181            ExtensionValue::Boolean(b) => write!(f, "Boolean({})", b),
182            ExtensionValue::Reference(r) => write!(f, "Reference({})", r),
183            ExtensionValue::Expression(e) => write!(f, "Expression({})", e),
184        }
185    }
186}
187
188impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
189    type Error = ExtensionError;
190
191    fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
192        match value {
193            ExtensionValue::String(s) => Ok(s),
194            v => Err(ExtensionError::InvalidArgument(format!(
195                "Expected string, got {v}",
196            ))),
197        }
198    }
199}
200
201impl TryFrom<ExtensionValue> for String {
202    type Error = ExtensionError;
203
204    fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
205        match value {
206            ExtensionValue::String(s) => Ok(s),
207            v => Err(ExtensionError::InvalidArgument(format!(
208                "Expected string, got {v}",
209            ))),
210        }
211    }
212}
213
214impl TryFrom<&ExtensionValue> for i64 {
215    type Error = ExtensionError;
216
217    fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
218        match value {
219            &ExtensionValue::Integer(i) => Ok(i),
220            v => Err(ExtensionError::InvalidArgument(format!(
221                "Expected integer, got {v}",
222            ))),
223        }
224    }
225}
226
227impl TryFrom<&ExtensionValue> for f64 {
228    type Error = ExtensionError;
229
230    fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
231        match value {
232            &ExtensionValue::Float(f) => Ok(f),
233            v => Err(ExtensionError::InvalidArgument(format!(
234                "Expected float, got {v}",
235            ))),
236        }
237    }
238}
239
240impl TryFrom<&ExtensionValue> for bool {
241    type Error = ExtensionError;
242
243    fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
244        match value {
245            &ExtensionValue::Boolean(b) => Ok(b),
246            v => Err(ExtensionError::InvalidArgument(format!(
247                "Expected boolean, got {v}",
248            ))),
249        }
250    }
251}
252
253impl TryFrom<&ExtensionValue> for Reference {
254    type Error = ExtensionError;
255
256    fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
257        match value {
258            &ExtensionValue::Reference(r) => Ok(Reference(r)),
259            v => Err(ExtensionError::InvalidArgument(format!(
260                "Expected reference, got {v}",
261            ))),
262        }
263    }
264}
265
266/// Represents an output column specification
267#[derive(Debug, Clone)]
268pub enum ExtensionColumn {
269    /// Named column with type (name:type)
270    Named { name: String, type_spec: String },
271    /// Field reference ($0, $1, etc.)
272    Reference(i32),
273    /// Expression column — not yet fully supported, hence the private
274    /// interface.
275    #[allow(private_interfaces)]
276    Expression(RawExpression),
277}
278
/// The kinds of extension relation, distinguished by how many input
/// children they take.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionRelationType {
    /// Extension leaf relation: no input children.
    Leaf,
    /// Extension single relation: exactly one input child.
    Single,
    /// Extension multi relation: zero or more input children.
    Multi,
}

impl std::str::FromStr for ExtensionRelationType {
    type Err = String;

    /// Parses the text-format name (`ExtensionLeaf`, `ExtensionSingle`, or
    /// `ExtensionMulti`). The match is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns a message naming the unrecognised input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let known = [Self::Leaf, Self::Single, Self::Multi];
        known
            .iter()
            .copied()
            .find(|kind| kind.as_str() == s)
            .ok_or_else(|| format!("Unknown extension relation type: {}", s))
    }
}

impl ExtensionRelationType {
    /// Returns the name used for this relation type in the text format.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leaf => "ExtensionLeaf",
            Self::Single => "ExtensionSingle",
            Self::Multi => "ExtensionMulti",
        }
    }

    /// Checks that `child_count` is allowed for this relation type.
    ///
    /// `Leaf` requires zero children, `Single` requires exactly one, and
    /// `Multi` accepts any count, including zero.
    ///
    /// # Errors
    ///
    /// Returns a message describing the expected and actual child count.
    pub fn validate_child_count(&self, child_count: usize) -> Result<(), String> {
        match (*self, child_count) {
            (Self::Leaf, 0) | (Self::Single, 1) | (Self::Multi, _) => Ok(()),
            (Self::Leaf, _) => Err(format!(
                "ExtensionLeaf should have no input children, got {child_count}"
            )),
            (Self::Single, _) => Err(format!(
                "ExtensionSingle should have exactly 1 input child, got {child_count}"
            )),
        }
    }
}
341
342// Note: create_rel is implemented in parser/extensions.rs to avoid
343// pulling in protobuf dependencies in the core args module
344
345impl ExtensionArgs {
346    /// Create a new empty ExtensionArgs
347    pub fn new(relation_type: ExtensionRelationType) -> Self {
348        Self {
349            positional: Vec::new(),
350            named: IndexMap::new(),
351            output_columns: Vec::new(),
352            relation_type,
353        }
354    }
355
356    /// Create an extractor for validating named arguments
357    pub fn extractor(&self) -> ArgsExtractor<'_> {
358        ArgsExtractor::new(self)
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::ExtensionRelationType::{Multi, Single};

    /// A multi relation with no inputs is accepted.
    #[test]
    fn extension_multi_allows_zero_children() {
        assert_eq!(Multi.validate_child_count(0), Ok(()));
    }

    /// A multi relation with one input is accepted.
    #[test]
    fn extension_multi_allows_single_child() {
        assert_eq!(Multi.validate_child_count(1), Ok(()));
    }

    /// A multi relation with several inputs is accepted.
    #[test]
    fn extension_multi_allows_multiple_children() {
        assert_eq!(Multi.validate_child_count(3), Ok(()));
    }

    /// A single relation rejects both too few and too many inputs.
    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        let too_few = Single.validate_child_count(0);
        assert!(too_few.is_err());

        let too_many = Single.validate_child_count(2);
        assert!(too_many.is_err());
    }
}