substrait_explain/extensions/
args.rs1use std::collections::HashSet;
7use std::fmt;
8
9use indexmap::IndexMap;
10
11use super::ExtensionError;
12use crate::textify::expressions::Reference;
13use crate::textify::types::escaped;
14
/// Expression text held as an opaque string and re-emitted unchanged by
/// `Display`.
#[derive(Debug, Clone)]
pub(crate) struct RawExpression {
    text: String,
}

impl RawExpression {
    /// Wraps already-rendered expression text.
    pub fn new(text: String) -> Self {
        RawExpression { text }
    }
}

impl fmt::Display for RawExpression {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emit the stored text verbatim; width/fill flags on `f` are not applied.
        f.write_str(&self.text)
    }
}
34
/// Arguments parsed for an extension relation: positional values, named
/// values, declared output columns, and the relation kind.
#[derive(Debug, Clone)]
pub struct ExtensionArgs {
    /// Positional arguments, in order.
    pub positional: Vec<ExtensionValue>,
    /// Named arguments; `IndexMap` keeps them in insertion order.
    pub named: IndexMap<String, ExtensionValue>,
    /// Output columns declared for the relation.
    pub output_columns: Vec<ExtensionColumn>,
    /// Which kind of extension relation (leaf / single / multi) these
    /// arguments describe.
    pub relation_type: ExtensionRelationType,
}
52
53pub struct ArgsExtractor<'a> {
60 args: &'a ExtensionArgs,
61 consumed: HashSet<&'a str>,
62 checked: bool,
63}
64
65impl<'a> ArgsExtractor<'a> {
66 pub fn new(args: &'a ExtensionArgs) -> Self {
68 Self {
69 args,
70 consumed: HashSet::new(),
71 checked: false,
72 }
73 }
74
75 pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
77 match self.args.named.get_key_value(name) {
78 Some((k, value)) => {
79 self.consumed.insert(k);
80 Some(value)
81 }
82 None => None,
83 }
84 }
85
86 pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
89 where
90 T: TryFrom<&'a ExtensionValue>,
91 T::Error: Into<ExtensionError>,
92 {
93 match self.get_named_arg(name) {
94 Some(value) => T::try_from(value).map_err(Into::into),
95 None => Err(ExtensionError::MissingArgument {
96 name: name.to_string(),
97 }),
98 }
99 }
100
101 pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
104 where
105 T: TryFrom<&'a ExtensionValue>,
106 T::Error: Into<ExtensionError>,
107 {
108 match self.get_named_arg(name) {
109 Some(value) => T::try_from(value).map_err(Into::into),
110 None => Ok(default),
111 }
112 }
113
114 pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
121 self.checked = true;
122
123 let mut unknown_args = Vec::new();
124 for name in self.args.named.keys() {
125 if !self.consumed.contains(name.as_str()) {
126 unknown_args.push(name.as_str());
127 }
128 }
129
130 if unknown_args.is_empty() {
131 Ok(())
132 } else {
133 unknown_args.sort();
135 Err(ExtensionError::InvalidArgument(format!(
136 "Unknown named arguments: {}",
137 unknown_args.join(", ")
138 )))
139 }
140 }
141}
142
143impl Drop for ArgsExtractor<'_> {
144 fn drop(&mut self) {
145 if self.checked || std::thread::panicking() {
146 return;
147 }
148 debug_assert!(
150 false,
151 "ArgsExtractor dropped without calling check_exhausted()"
152 );
153 }
154}
155
156#[derive(Debug, Clone)]
157pub struct TupleValue(Vec<ExtensionValue>);
158
159impl TupleValue {
160 pub fn len(&self) -> usize {
161 self.0.len()
162 }
163
164 pub fn is_empty(&self) -> bool {
165 self.0.is_empty()
166 }
167
168 pub fn iter(&self) -> std::slice::Iter<'_, ExtensionValue> {
169 self.0.iter()
170 }
171}
172
173impl fmt::Display for TupleValue {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 write!(f, "(")?;
176 for (i, item) in self.0.iter().enumerate() {
177 if i > 0 {
178 write!(f, ", ")?;
179 }
180 write!(f, "{item}")?;
181 }
182 if self.0.len() == 1 {
183 write!(f, ",")?;
184 }
185 write!(f, ")")
186 }
187}
188
189impl<'a> IntoIterator for &'a TupleValue {
190 type Item = &'a ExtensionValue;
191 type IntoIter = std::slice::Iter<'a, ExtensionValue>;
192
193 fn into_iter(self) -> Self::IntoIter {
194 self.0.iter()
195 }
196}
197
198impl IntoIterator for TupleValue {
199 type Item = ExtensionValue;
200 type IntoIter = std::vec::IntoIter<ExtensionValue>;
201
202 fn into_iter(self) -> Self::IntoIter {
203 self.0.into_iter()
204 }
205}
206
207impl FromIterator<ExtensionValue> for TupleValue {
208 fn from_iter<I: IntoIterator<Item = ExtensionValue>>(iter: I) -> Self {
209 TupleValue(iter.into_iter().collect())
210 }
211}
212
213impl From<Vec<ExtensionValue>> for TupleValue {
214 fn from(items: Vec<ExtensionValue>) -> Self {
215 TupleValue(items)
216 }
217}
218
219#[derive(Debug, Clone)]
221pub enum ExtensionValue {
222 String(String),
224 Integer(i64),
226 Float(f64),
228 Boolean(bool),
230 Reference(i32),
232 Enum(String),
234 Tuple(TupleValue),
236 #[allow(private_interfaces)]
239 Expression(RawExpression),
240}
241
242impl fmt::Display for ExtensionValue {
243 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244 match self {
245 ExtensionValue::String(s) => write!(f, "String({})", escaped(s)),
246 ExtensionValue::Integer(i) => write!(f, "Integer({})", i),
247 ExtensionValue::Float(n) => write!(f, "Float({})", n),
248 ExtensionValue::Boolean(b) => write!(f, "Boolean({})", b),
249 ExtensionValue::Reference(r) => write!(f, "Reference({})", r),
250 ExtensionValue::Enum(e) => write!(f, "Enum(&{})", e),
251 ExtensionValue::Tuple(tv) => write!(f, "Tuple{tv}"),
252 ExtensionValue::Expression(e) => write!(f, "Expression({})", e),
253 }
254 }
255}
256
257impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
258 type Error = ExtensionError;
259
260 fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
261 match value {
262 ExtensionValue::String(s) => Ok(s),
263 v => Err(ExtensionError::InvalidArgument(format!(
264 "Expected string, got {v}",
265 ))),
266 }
267 }
268}
269
270impl TryFrom<ExtensionValue> for String {
271 type Error = ExtensionError;
272
273 fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
274 match value {
275 ExtensionValue::String(s) => Ok(s),
276 v => Err(ExtensionError::InvalidArgument(format!(
277 "Expected string, got {v}",
278 ))),
279 }
280 }
281}
282
283pub struct EnumValue(pub String);
285
286impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
287 type Error = ExtensionError;
288
289 fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
290 match value {
291 ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
292 v => Err(ExtensionError::InvalidArgument(format!(
293 "Expected enum, got {v}",
294 ))),
295 }
296 }
297}
298
299impl<'a> TryFrom<&'a ExtensionValue> for &'a TupleValue {
300 type Error = ExtensionError;
301
302 fn try_from(value: &'a ExtensionValue) -> Result<&'a TupleValue, Self::Error> {
303 match value {
304 ExtensionValue::Tuple(tv) => Ok(tv),
305 v => Err(ExtensionError::InvalidArgument(format!(
306 "Expected tuple, got {v}",
307 ))),
308 }
309 }
310}
311
312impl TryFrom<&ExtensionValue> for i64 {
313 type Error = ExtensionError;
314
315 fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
316 match value {
317 &ExtensionValue::Integer(i) => Ok(i),
318 v => Err(ExtensionError::InvalidArgument(format!(
319 "Expected integer, got {v}",
320 ))),
321 }
322 }
323}
324
325impl TryFrom<&ExtensionValue> for f64 {
326 type Error = ExtensionError;
327
328 fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
329 match value {
330 &ExtensionValue::Float(f) => Ok(f),
331 v => Err(ExtensionError::InvalidArgument(format!(
332 "Expected float, got {v}",
333 ))),
334 }
335 }
336}
337
338impl TryFrom<&ExtensionValue> for bool {
339 type Error = ExtensionError;
340
341 fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
342 match value {
343 &ExtensionValue::Boolean(b) => Ok(b),
344 v => Err(ExtensionError::InvalidArgument(format!(
345 "Expected boolean, got {v}",
346 ))),
347 }
348 }
349}
350
351impl TryFrom<&ExtensionValue> for Reference {
352 type Error = ExtensionError;
353
354 fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
355 match value {
356 &ExtensionValue::Reference(r) => Ok(Reference(r)),
357 v => Err(ExtensionError::InvalidArgument(format!(
358 "Expected reference, got {v}",
359 ))),
360 }
361 }
362}
363
/// One output column declared for an extension relation.
#[derive(Debug, Clone)]
pub enum ExtensionColumn {
    /// A column given by name, with its type written as text in `type_spec`.
    Named { name: String, type_spec: String },
    /// A column given by reference number (presumably a field index — confirm
    /// against the parser).
    Reference(i32),
    /// A column given as an expression kept as raw text.
    // `RawExpression` is only `pub(crate)`, which would otherwise trigger the
    // `private_interfaces` lint on this public enum.
    #[allow(private_interfaces)]
    Expression(RawExpression),
}
376
/// Kind of Substrait extension relation, which determines how many input
/// children it accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionRelationType {
    /// `ExtensionLeaf`: takes no inputs.
    Leaf,
    /// `ExtensionSingle`: takes exactly one input.
    Single,
    /// `ExtensionMulti`: accepts any number of inputs, including zero.
    Multi,
}

impl std::str::FromStr for ExtensionRelationType {
    type Err = String;

    /// Parses `ExtensionLeaf`, `ExtensionSingle` or `ExtensionMulti`; the
    /// inverse of [`ExtensionRelationType::as_str`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ALL: [ExtensionRelationType; 3] = [
            ExtensionRelationType::Leaf,
            ExtensionRelationType::Single,
            ExtensionRelationType::Multi,
        ];
        ALL.iter()
            .copied()
            .find(|kind| kind.as_str() == s)
            .ok_or_else(|| format!("Unknown extension relation type: {s}"))
    }
}

impl ExtensionRelationType {
    /// Returns the textual name of this relation kind.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leaf => "ExtensionLeaf",
            Self::Single => "ExtensionSingle",
            Self::Multi => "ExtensionMulti",
        }
    }

    /// Checks that `child_count` inputs are allowed for this relation kind.
    ///
    /// # Errors
    /// Returns a human-readable message when a `Leaf` has any children or a
    /// `Single` has a count other than one. `Multi` accepts every count.
    pub fn validate_child_count(&self, child_count: usize) -> Result<(), String> {
        match (*self, child_count) {
            (Self::Leaf, 0) | (Self::Single, 1) | (Self::Multi, _) => Ok(()),
            (Self::Leaf, n) => Err(format!(
                "ExtensionLeaf should have no input children, got {n}"
            )),
            (Self::Single, n) => Err(format!(
                "ExtensionSingle should have exactly 1 input child, got {n}"
            )),
        }
    }
}
439
440impl ExtensionArgs {
444 pub fn new(relation_type: ExtensionRelationType) -> Self {
446 Self {
447 positional: Vec::new(),
448 named: IndexMap::new(),
449 output_columns: Vec::new(),
450 relation_type,
451 }
452 }
453
454 pub fn extractor(&self) -> ArgsExtractor<'_> {
456 ArgsExtractor::new(self)
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::{ExtensionArgs, ExtensionRelationType, ExtensionValue};

    #[test]
    fn extension_multi_allows_zero_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_multi_allows_single_child() {
        assert!(ExtensionRelationType::Multi.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_multi_allows_multiple_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert!(ExtensionRelationType::Single.validate_child_count(0).is_err());
        assert!(ExtensionRelationType::Single.validate_child_count(2).is_err());
    }

    #[test]
    fn extension_single_accepts_exactly_one_child() {
        assert!(ExtensionRelationType::Single.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_leaf_accepts_only_zero_children() {
        assert!(ExtensionRelationType::Leaf.validate_child_count(0).is_ok());
        assert!(ExtensionRelationType::Leaf.validate_child_count(1).is_err());
    }

    /// `as_str` and `FromStr` must stay inverses of each other.
    #[test]
    fn relation_type_round_trips_through_str() {
        for kind in [
            ExtensionRelationType::Leaf,
            ExtensionRelationType::Single,
            ExtensionRelationType::Multi,
        ] {
            assert_eq!(kind.as_str().parse::<ExtensionRelationType>(), Ok(kind));
        }
    }

    #[test]
    fn relation_type_rejects_unknown_names() {
        assert!("ExtensionFoo".parse::<ExtensionRelationType>().is_err());
    }

    #[test]
    fn args_extractor_reports_unread_named_arguments() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named
            .insert("limit".to_string(), ExtensionValue::Integer(10));
        args.named
            .insert("extra".to_string(), ExtensionValue::Boolean(true));

        let mut extractor = args.extractor();
        assert!(matches!(extractor.expect_named_arg::<i64>("limit"), Ok(10)));
        assert!(matches!(extractor.get_named_or("offset", 5_i64), Ok(5)));
        // "extra" was never read, so the extractor must complain.
        assert!(extractor.check_exhausted().is_err());
    }

    #[test]
    fn args_extractor_accepts_fully_consumed_arguments() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named
            .insert("limit".to_string(), ExtensionValue::Integer(10));

        let mut extractor = args.extractor();
        assert!(extractor.get_named_arg("limit").is_some());
        assert!(extractor.check_exhausted().is_ok());
    }
}