substrait_explain/extensions/
args.rs1use std::collections::HashSet;
7use std::fmt;
8
9use indexmap::IndexMap;
10
11use super::ExtensionError;
12use crate::textify::expressions::Reference;
13use crate::textify::types::escaped;
14
/// Expression text carried around verbatim as a single string.
///
/// The wrapped text is private to this module; outside of it the text is
/// only observable through the [`fmt::Display`] implementation.
#[derive(Clone, Debug)]
pub(crate) struct RawExpression {
    text: String,
}

impl RawExpression {
    /// Wraps `text` as-is, without inspecting or modifying it.
    pub fn new(text: String) -> Self {
        RawExpression { text }
    }
}

impl fmt::Display for RawExpression {
    /// Writes the stored text unchanged. Width/alignment flags on the
    /// formatter are not applied to the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.text)
    }
}
34
35#[derive(Debug, Clone)]
42pub struct ExtensionArgs {
43 pub positional: Vec<ExtensionValue>,
45 pub named: IndexMap<String, ExtensionValue>,
47 pub output_columns: Vec<ExtensionColumn>,
49 pub relation_type: ExtensionRelationType,
51}
52
53pub struct ArgsExtractor<'a> {
60 args: &'a ExtensionArgs,
61 consumed: HashSet<&'a str>,
62 checked: bool,
63}
64
65impl<'a> ArgsExtractor<'a> {
66 pub fn new(args: &'a ExtensionArgs) -> Self {
68 Self {
69 args,
70 consumed: HashSet::new(),
71 checked: false,
72 }
73 }
74
75 pub fn get_named_arg(&mut self, name: &str) -> Option<&'a ExtensionValue> {
77 match self.args.named.get_key_value(name) {
78 Some((k, value)) => {
79 self.consumed.insert(k);
80 Some(value)
81 }
82 None => None,
83 }
84 }
85
86 pub fn expect_named_arg<T>(&mut self, name: &str) -> Result<T, ExtensionError>
89 where
90 T: TryFrom<&'a ExtensionValue>,
91 T::Error: Into<ExtensionError>,
92 {
93 match self.get_named_arg(name) {
94 Some(value) => T::try_from(value).map_err(Into::into),
95 None => Err(ExtensionError::MissingArgument {
96 name: name.to_string(),
97 }),
98 }
99 }
100
101 pub fn get_named_or<T>(&mut self, name: &str, default: T) -> Result<T, ExtensionError>
104 where
105 T: TryFrom<&'a ExtensionValue>,
106 T::Error: Into<ExtensionError>,
107 {
108 match self.get_named_arg(name) {
109 Some(value) => T::try_from(value).map_err(Into::into),
110 None => Ok(default),
111 }
112 }
113
114 pub fn check_exhausted(&mut self) -> Result<(), ExtensionError> {
121 self.checked = true;
122
123 let mut unknown_args = Vec::new();
124 for name in self.args.named.keys() {
125 if !self.consumed.contains(name.as_str()) {
126 unknown_args.push(name.as_str());
127 }
128 }
129
130 if unknown_args.is_empty() {
131 Ok(())
132 } else {
133 unknown_args.sort();
135 Err(ExtensionError::InvalidArgument(format!(
136 "Unknown named arguments: {}",
137 unknown_args.join(", ")
138 )))
139 }
140 }
141}
142
143impl Drop for ArgsExtractor<'_> {
144 fn drop(&mut self) {
145 if self.checked || std::thread::panicking() {
146 return;
147 }
148 debug_assert!(
150 false,
151 "ArgsExtractor dropped without calling check_exhausted()"
152 );
153 }
154}
155
156#[derive(Debug, Clone)]
158pub enum ExtensionValue {
159 String(String),
161 Integer(i64),
163 Float(f64),
165 Boolean(bool),
167 Reference(i32),
169 #[allow(private_interfaces)]
172 Expression(RawExpression),
173}
174
175impl fmt::Display for ExtensionValue {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 match self {
178 ExtensionValue::String(s) => write!(f, "String({})", escaped(s)),
179 ExtensionValue::Integer(i) => write!(f, "Integer({})", i),
180 ExtensionValue::Float(n) => write!(f, "Float({})", n),
181 ExtensionValue::Boolean(b) => write!(f, "Boolean({})", b),
182 ExtensionValue::Reference(r) => write!(f, "Reference({})", r),
183 ExtensionValue::Expression(e) => write!(f, "Expression({})", e),
184 }
185 }
186}
187
188impl<'a> TryFrom<&'a ExtensionValue> for &'a str {
189 type Error = ExtensionError;
190
191 fn try_from(value: &'a ExtensionValue) -> Result<&'a str, Self::Error> {
192 match value {
193 ExtensionValue::String(s) => Ok(s),
194 v => Err(ExtensionError::InvalidArgument(format!(
195 "Expected string, got {v}",
196 ))),
197 }
198 }
199}
200
201impl TryFrom<ExtensionValue> for String {
202 type Error = ExtensionError;
203
204 fn try_from(value: ExtensionValue) -> Result<String, Self::Error> {
205 match value {
206 ExtensionValue::String(s) => Ok(s),
207 v => Err(ExtensionError::InvalidArgument(format!(
208 "Expected string, got {v}",
209 ))),
210 }
211 }
212}
213
214impl TryFrom<&ExtensionValue> for i64 {
215 type Error = ExtensionError;
216
217 fn try_from(value: &ExtensionValue) -> Result<i64, Self::Error> {
218 match value {
219 &ExtensionValue::Integer(i) => Ok(i),
220 v => Err(ExtensionError::InvalidArgument(format!(
221 "Expected integer, got {v}",
222 ))),
223 }
224 }
225}
226
227impl TryFrom<&ExtensionValue> for f64 {
228 type Error = ExtensionError;
229
230 fn try_from(value: &ExtensionValue) -> Result<f64, Self::Error> {
231 match value {
232 &ExtensionValue::Float(f) => Ok(f),
233 v => Err(ExtensionError::InvalidArgument(format!(
234 "Expected float, got {v}",
235 ))),
236 }
237 }
238}
239
240impl TryFrom<&ExtensionValue> for bool {
241 type Error = ExtensionError;
242
243 fn try_from(value: &ExtensionValue) -> Result<bool, Self::Error> {
244 match value {
245 &ExtensionValue::Boolean(b) => Ok(b),
246 v => Err(ExtensionError::InvalidArgument(format!(
247 "Expected boolean, got {v}",
248 ))),
249 }
250 }
251}
252
253impl TryFrom<&ExtensionValue> for Reference {
254 type Error = ExtensionError;
255
256 fn try_from(value: &ExtensionValue) -> Result<Reference, Self::Error> {
257 match value {
258 &ExtensionValue::Reference(r) => Ok(Reference(r)),
259 v => Err(ExtensionError::InvalidArgument(format!(
260 "Expected reference, got {v}",
261 ))),
262 }
263 }
264}
265
266#[derive(Debug, Clone)]
268pub enum ExtensionColumn {
269 Named { name: String, type_spec: String },
271 Reference(i32),
273 #[allow(private_interfaces)]
276 Expression(RawExpression),
277}
278
/// The three extension relation types, distinguished by how many input
/// children they accept (see [`ExtensionRelationType::validate_child_count`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExtensionRelationType {
    /// `ExtensionLeaf`: takes no input children.
    Leaf,
    /// `ExtensionSingle`: takes exactly one input child.
    Single,
    /// `ExtensionMulti`: takes any number of input children.
    Multi,
}

impl std::str::FromStr for ExtensionRelationType {
    type Err = String;

    /// Parses the exact, case-sensitive names produced by
    /// [`ExtensionRelationType::as_str`].
    ///
    /// # Errors
    ///
    /// Any other input yields a message naming the unrecognised string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Drive parsing from `as_str` so the two directions cannot drift apart.
        [Self::Leaf, Self::Single, Self::Multi]
            .iter()
            .copied()
            .find(|kind| kind.as_str() == s)
            .ok_or_else(|| format!("Unknown extension relation type: {}", s))
    }
}

impl ExtensionRelationType {
    /// Returns the textual name of this relation type.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Leaf => "ExtensionLeaf",
            Self::Single => "ExtensionSingle",
            Self::Multi => "ExtensionMulti",
        }
    }

    /// Checks that `child_count` input children is acceptable for this
    /// relation type: none for `Leaf`, exactly one for `Single`, and any
    /// number (including zero) for `Multi`.
    ///
    /// # Errors
    ///
    /// Returns a message stating the expectation and the count received.
    pub fn validate_child_count(&self, child_count: usize) -> Result<(), String> {
        match (*self, child_count) {
            (Self::Leaf, 0) | (Self::Single, 1) | (Self::Multi, _) => Ok(()),
            (Self::Leaf, _) => Err(format!(
                "ExtensionLeaf should have no input children, got {child_count}"
            )),
            (Self::Single, _) => Err(format!(
                "ExtensionSingle should have exactly 1 input child, got {child_count}"
            )),
        }
    }
}
341
342impl ExtensionArgs {
346 pub fn new(relation_type: ExtensionRelationType) -> Self {
348 Self {
349 positional: Vec::new(),
350 named: IndexMap::new(),
351 output_columns: Vec::new(),
352 relation_type,
353 }
354 }
355
356 pub fn extractor(&self) -> ArgsExtractor<'_> {
358 ArgsExtractor::new(self)
359 }
360}
361
#[cfg(test)]
mod tests {
    //! Child-count validation tests covering every `ExtensionRelationType`
    //! variant, on both the accepting and the rejecting side.

    use super::ExtensionRelationType;

    #[test]
    fn extension_leaf_allows_zero_children() {
        assert!(ExtensionRelationType::Leaf.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_leaf_rejects_any_children() {
        for count in [1usize, 2, 5] {
            let err = ExtensionRelationType::Leaf
                .validate_child_count(count)
                .unwrap_err();
            // The message must report the offending count.
            assert!(err.contains(&format!("got {count}")), "{err}");
        }
    }

    #[test]
    fn extension_single_allows_exactly_one_child() {
        assert!(
            ExtensionRelationType::Single
                .validate_child_count(1)
                .is_ok()
        );
    }

    #[test]
    fn extension_multi_allows_zero_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(0).is_ok());
    }

    #[test]
    fn extension_multi_allows_single_child() {
        assert!(ExtensionRelationType::Multi.validate_child_count(1).is_ok());
    }

    #[test]
    fn extension_multi_allows_multiple_children() {
        assert!(ExtensionRelationType::Multi.validate_child_count(3).is_ok());
    }

    #[test]
    fn extension_single_rejects_wrong_child_counts() {
        for count in [0usize, 2] {
            let err = ExtensionRelationType::Single
                .validate_child_count(count)
                .unwrap_err();
            // The message must report the offending count.
            assert!(err.contains(&format!("got {count}")), "{err}");
        }
    }
}