substrait_explain/extensions/
args.rs1use std::collections::HashSet;
7use std::fmt;
8
9use indexmap::IndexMap;
10
11use super::ExtensionError;
12use crate::textify::expressions::Reference;
13use crate::textify::types::escaped;
14
/// Expression text carried verbatim rather than as a parsed expression tree.
///
/// Used as the payload of `ExtensionValue::Expression` and
/// `ExtensionColumn::Expression`.
#[derive(Debug, Clone)]
pub(crate) struct RawExpression {
    text: String,
}

impl RawExpression {
    /// Wraps `text` unchanged.
    pub fn new(text: String) -> Self {
        RawExpression { text }
    }
}

impl fmt::Display for RawExpression {
    /// Writes the stored text exactly as given; formatter flags such as
    /// width or fill are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.text)
    }
}
34
/// Arguments attached to an extension relation: positional values, named
/// values, declared output columns, and the extension relation kind.
///
/// Read named arguments through [`ExtensionArgs::extractor`]; note that its
/// `check_exhausted()` only inspects `named`, not `positional`.
#[derive(Debug, Clone)]
pub struct ExtensionArgs {
    /// Positional arguments, in the order they are stored.
    pub positional: Vec<ExtensionValue>,
    /// Named arguments keyed by name; `IndexMap` keeps insertion order.
    pub named: IndexMap<String, ExtensionValue>,
    /// Declared output columns of the relation.
    pub output_columns: Vec<ExtensionColumn>,
    /// Which extension relation kind (leaf / single / multi) these belong to.
    pub relation_type: ExtensionRelationType,
}
52
53pub struct ArgsExtractor<'a> {
60 args: &'a ExtensionArgs,
61 consumed: HashSet<&'a str>,
62 checked: bool,
63}
64
65impl<'a> ArgsExtractor<'a> {
66 pub fn new(args: &'a ExtensionArgs) -> Self {
68 Self {
69 args,
70 consumed: HashSet::new(),
71 checked: false,
72 }
73 }
74
75 pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
77 match self.args.named.get_key_value(name) {
78 Some((k, value)) => {
79 self.consumed.insert(k);
80 Some(value)
81 }
82 None => None,
83 }
84 }
85
86 pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
89 where
90 T: TryFrom<&'a ExtensionValue>,
91 T::Error: Into<ExtensionError>,
92 {
93 match self.get_named_arg(name) {
94 Some(value) => T::try_from(value).map_err(Into::into),
95 None => Err(ExtensionError::MissingArgument {
96 name: name.to_string(),
97 }),
98 }
99 }
100
101 pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
104 where
105 T: TryFrom<&'a ExtensionValue>,
106 T::Error: Into<ExtensionError>,
107 {
108 match self.get_named_arg(name) {
109 Some(value) => T::try_from(value).map_err(Into::into),
110 None => Ok(default),
111 }
112 }
113
114 pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
121 self.checked = true;
122
123 let mut unknown_args = Vec::new();
124 for name in self.args.named.keys() {
125 if !self.consumed.contains(name.as_str()) {
126 unknown_args.push(name.as_str());
127 }
128 }
129
130 if unknown_args.is_empty() {
131 Ok(())
132 } else {
133 unknown_args.sort();
135 Err(ExtensionError::InvalidArgument(format!(
136 "Unknown named arguments: {}",
137 unknown_args.join(", ")
138 )))
139 }
140 }
141}
142
143impl Drop for ArgsExtractor<'_> {
144 fn drop(&mut self) {
145 if self.checked || std::thread::panicking() {
146 return;
147 }
148 debug_assert!(
150 false,
151 "ArgsExtractor dropped without calling check_exhausted()"
152 );
153 }
154}
155
/// A single argument value attached to an extension relation.
#[derive(Debug, Clone)]
pub enum ExtensionValue {
    /// A string value (rendered escaped by `Display`).
    String(String),
    /// A 64-bit signed integer value.
    Integer(i64),
    /// A 64-bit floating-point value.
    Float(f64),
    /// A boolean value.
    Boolean(bool),
    /// A reference, stored as an `i32` (converted to `Reference` by `TryFrom`).
    Reference(i32),
    /// An enum identifier; `Display` renders it with a leading `&`.
    Enum(String),
    /// Unparsed expression text.
    // `RawExpression` is `pub(crate)` while this enum is `pub`, so the
    // `private_interfaces` lint is silenced for this variant.
    #[allow(private_interfaces)]
    Expression(RawExpression),
}
176
177impl fmt::Display for ExtensionValue {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 match self {
180 ExtensionValue::String(s) => write!(f, "String({})", escaped(s)),
181 ExtensionValue::Integer(i) => write!(f, "Integer({})", i),
182 ExtensionValue::Float(n) => write!(f, "Float({})", n),
183 ExtensionValue::Boolean(b) => write!(f, "Boolean({})", b),
184 ExtensionValue::Reference(r) => write!(f, "Reference({})", r),
185 ExtensionValue::Enum(e) => write!(f, "Enum(&{})", e),
186 ExtensionValue::Expression(e) => write!(f, "Expression({})", e),
187 }
188 }
189}
190
191impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
192 type Error = ExtensionError;
193
194 fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
195 match value {
196 ExtensionValue::String(s) => Ok(s),
197 v => Err(ExtensionError::InvalidArgument(format!(
198 "Expected string, got {v}",
199 ))),
200 }
201 }
202}
203
204impl TryFrom<ExtensionValue> for String {
205 type Error = ExtensionError;
206
207 fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
208 match value {
209 ExtensionValue::String(s) => Ok(s),
210 v => Err(ExtensionError::InvalidArgument(format!(
211 "Expected string, got {v}",
212 ))),
213 }
214 }
215}
216
217pub struct EnumValue(pub String);
219
220impl<'a> TryFrom<&'a ExtensionValue> for EnumValue {
221 type Error = ExtensionError;
222
223 fn try_from(value: &'a ExtensionValue) -> Result<EnumValue, Self::Error> {
224 match value {
225 ExtensionValue::Enum(s) => Ok(EnumValue(s.clone())),
226 v => Err(ExtensionError::InvalidArgument(format!(
227 "Expected enum, got {v}",
228 ))),
229 }
230 }
231}
232
233impl TryFrom<&ExtensionValue> for i64 {
234 type Error = ExtensionError;
235
236 fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
237 match value {
238 &ExtensionValue::Integer(i) => Ok(i),
239 v => Err(ExtensionError::InvalidArgument(format!(
240 "Expected integer, got {v}",
241 ))),
242 }
243 }
244}
245
246impl TryFrom<&ExtensionValue> for f64 {
247 type Error = ExtensionError;
248
249 fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
250 match value {
251 &ExtensionValue::Float(f) => Ok(f),
252 v => Err(ExtensionError::InvalidArgument(format!(
253 "Expected float, got {v}",
254 ))),
255 }
256 }
257}
258
259impl TryFrom<&ExtensionValue> for bool {
260 type Error = ExtensionError;
261
262 fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
263 match value {
264 &ExtensionValue::Boolean(b) => Ok(b),
265 v => Err(ExtensionError::InvalidArgument(format!(
266 "Expected boolean, got {v}",
267 ))),
268 }
269 }
270}
271
272impl TryFrom<&ExtensionValue> for Reference {
273 type Error = ExtensionError;
274
275 fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
276 match value {
277 &ExtensionValue::Reference(r) => Ok(Reference(r)),
278 v => Err(ExtensionError::InvalidArgument(format!(
279 "Expected reference, got {v}",
280 ))),
281 }
282 }
283}
284
/// One entry of an extension relation's declared output columns
/// (`ExtensionArgs::output_columns`).
#[derive(Debug, Clone)]
pub enum ExtensionColumn {
    /// A column given as a name plus a type specification string.
    Named { name: String, type_spec: String },
    /// A column given as a reference, stored as an `i32`.
    Reference(i32),
    /// A column given as unparsed expression text.
    // `RawExpression` is `pub(crate)` while this enum is `pub`, so the
    // `private_interfaces` lint is silenced for this variant.
    #[allow(private_interfaces)]
    Expression(RawExpression),
}
297
/// Kind of extension relation; the kinds differ in how many input children
/// they accept (see [`ExtensionRelationType::validate_child_count`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtensionRelationType {
    /// `ExtensionLeaf`: accepts no input children.
    Leaf,
    /// `ExtensionSingle`: accepts exactly one input child.
    Single,
    /// `ExtensionMulti`: accepts any number of input children, including zero.
    Multi,
}

impl std::str::FromStr for ExtensionRelationType {
    type Err = String;

    /// Parses the relation name produced by [`ExtensionRelationType::as_str`];
    /// any other text yields an "Unknown extension relation type" message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let candidates = [
            ExtensionRelationType::Leaf,
            ExtensionRelationType::Single,
            ExtensionRelationType::Multi,
        ];
        candidates
            .iter()
            .copied()
            .find(|kind| kind.as_str() == s)
            .ok_or_else(|| format!("Unknown extension relation type: {}", s))
    }
}

impl ExtensionRelationType {
    /// Returns the textual relation name accepted by `FromStr`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Leaf => "ExtensionLeaf",
            Self::Single => "ExtensionSingle",
            Self::Multi => "ExtensionMulti",
        }
    }

    /// Checks that `child_count` input children are allowed for this kind.
    ///
    /// # Errors
    ///
    /// Returns a message describing the expected count when a leaf has any
    /// children or a single relation does not have exactly one. Multi
    /// relations accept any count.
    pub fn validate_child_count(&self, child_count: usize) -> Result<(), String> {
        match (*self, child_count) {
            (Self::Leaf, 0) | (Self::Single, 1) | (Self::Multi, _) => Ok(()),
            (Self::Leaf, _) => Err(format!(
                "ExtensionLeaf should have no input children, got {child_count}"
            )),
            (Self::Single, _) => Err(format!(
                "ExtensionSingle should have exactly 1 input child, got {child_count}"
            )),
        }
    }
}
360
361impl ExtensionArgs {
365 pub fn new(relation_type: ExtensionRelationType) -> Self {
367 Self {
368 positional: Vec::new(),
369 named: IndexMap::new(),
370 output_columns: Vec::new(),
371 relation_type,
372 }
373 }
374
375 pub fn extractor(&self) -> ArgsExtractor<'_> {
377 ArgsExtractor::new(self)
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::ExtensionRelationType;

    #[test]
    fn extension_multi_allows_zero_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_multi_allows_single_child() {
        assert!(ExtensionRelationType::Multi.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_multi_allows_multiple_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        assert!(
            ExtensionRelationType::Single
                .validate_child_count(0)
                .is_err()
        );
        assert!(
            ExtensionRelationType::Single
                .validate_child_count(2)
                .is_err()
        );
    }

    // The success path of `Single` and the whole `Leaf` arm were previously
    // untested.
    #[test]
    fn extension_single_allows_exactly_one_child() {
        assert!(ExtensionRelationType::Single.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_leaf_allows_zero_children() {
        assert!(ExtensionRelationType::Leaf.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_leaf_rejects_children() {
        assert!(ExtensionRelationType::Leaf.validate_child_count(1).is_err());
        assert!(ExtensionRelationType::Leaf.validate_child_count(2).is_err());
    }

    // `as_str` and `FromStr` must stay in agreement for every variant.
    #[test]
    fn relation_type_round_trips_through_str() {
        for kind in [
            ExtensionRelationType::Leaf,
            ExtensionRelationType::Single,
            ExtensionRelationType::Multi,
        ]
        .iter()
        .copied()
        {
            assert_eq!(kind.as_str().parse::<ExtensionRelationType>(), Ok(kind));
        }
    }

    #[test]
    fn relation_type_rejects_unknown_name() {
        assert!("ExtensionBogus".parse::<ExtensionRelationType>().is_err());
    }
}