substrait_explain/extensions/
examples.rs1use crate::extensions::args::{EnumValue, ExtensionArgs, ExtensionValue};
16use crate::extensions::registry::{Explainable, ExtensionError, ExtensionRegistry};
17
18#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, prost::Enumeration)]
24#[repr(i32)]
25pub enum PartitionStrategy {
26 Unspecified = 0,
28 Hash = 1,
30 Range = 2,
32 Broadcast = 3,
34}
35
36impl PartitionStrategy {
37 pub fn as_str_name(self) -> &'static str {
39 match self {
40 PartitionStrategy::Unspecified => "UNSPECIFIED",
41 PartitionStrategy::Hash => "HASH",
42 PartitionStrategy::Range => "RANGE",
43 PartitionStrategy::Broadcast => "BROADCAST",
44 }
45 }
46
47 pub fn from_str_name(s: &str) -> Option<Self> {
49 match s {
50 "UNSPECIFIED" => Some(PartitionStrategy::Unspecified),
51 "HASH" => Some(PartitionStrategy::Hash),
52 "RANGE" => Some(PartitionStrategy::Range),
53 "BROADCAST" => Some(PartitionStrategy::Broadcast),
54 _ => None,
55 }
56 }
57}
58
59#[derive(Clone, PartialEq, prost::Message)]
91pub struct PartitionHint {
92 #[prost(enumeration = "PartitionStrategy", repeated, tag = "1")]
95 pub strategies: Vec<i32>,
96 #[prost(int64, tag = "2")]
98 pub count: i64,
99}
100
101impl prost::Name for PartitionHint {
102 const PACKAGE: &'static str = "example";
103 const NAME: &'static str = "PartitionHint";
104
105 fn full_name() -> String {
106 "example.PartitionHint".to_owned()
107 }
108
109 fn type_url() -> String {
110 "type.googleapis.com/example.PartitionHint".to_owned()
111 }
112}
113
114impl Explainable for PartitionHint {
115 fn name() -> &'static str {
116 "PartitionHint"
117 }
118
119 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
120 let strategies: Result<Vec<i32>, ExtensionError> = args
122 .positional
123 .iter()
124 .map(|val| {
125 let EnumValue(ident) = EnumValue::try_from(val)?;
126 PartitionStrategy::from_str_name(&ident)
127 .map(|s| s as i32)
128 .ok_or_else(|| {
129 ExtensionError::InvalidArgument(format!(
130 "Unknown PartitionStrategy variant '&{ident}'; \
131 expected one of &UNSPECIFIED, &HASH, &RANGE, &BROADCAST"
132 ))
133 })
134 })
135 .collect();
136
137 let mut extractor = args.extractor();
138 let count: i64 = extractor.get_named_or("count", 0)?;
139 extractor.check_exhausted()?;
140
141 Ok(PartitionHint {
142 strategies: strategies?,
143 count,
144 })
145 }
146
147 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
148 let mut args = ExtensionArgs::default();
149 for &raw in &self.strategies {
150 let s = PartitionStrategy::try_from(raw).unwrap_or(PartitionStrategy::Unspecified);
151 args.positional
152 .push(ExtensionValue::Enum(s.as_str_name().to_owned()));
153 }
154 if self.count != 0 {
155 args.insert("count", self.count);
156 }
157 Ok(args)
158 }
159}
160
161#[derive(Clone, PartialEq, prost::Message)]
193pub struct PlanHint {
194 #[prost(string, tag = "1")]
196 pub hint: String,
197}
198
199impl prost::Name for PlanHint {
200 const PACKAGE: &'static str = "example";
201 const NAME: &'static str = "PlanHint";
202
203 fn full_name() -> String {
204 "example.PlanHint".to_owned()
205 }
206
207 fn type_url() -> String {
208 "type.googleapis.com/example.PlanHint".to_owned()
209 }
210}
211
212impl Explainable for PlanHint {
213 fn name() -> &'static str {
214 "PlanHint"
215 }
216
217 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
218 if !args.positional.is_empty() {
219 return Err(ExtensionError::InvalidArgument(
220 "PlanHint does not accept positional arguments".to_owned(),
221 ));
222 }
223
224 let mut extractor = args.extractor();
225 let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_owned();
226 extractor.check_exhausted()?;
227 Ok(PlanHint { hint })
228 }
229
230 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
231 let mut args = ExtensionArgs::default();
232 args.named
233 .insert("hint".to_owned(), ExtensionValue::String(self.hint.clone()));
234 Ok(args)
235 }
236}
237
238#[derive(Clone, PartialEq, prost::Message)]
271pub struct BlobStoreRead {
272 #[prost(string, tag = "1")]
274 pub path: String,
275 #[prost(int64, tag = "2")]
277 pub limit: i64,
278 #[prost(bool, tag = "3")]
280 pub include_archived: bool,
281}
282
283impl prost::Name for BlobStoreRead {
284 const PACKAGE: &'static str = "example";
285 const NAME: &'static str = "BlobStoreRead";
286
287 fn full_name() -> String {
288 "example.BlobStoreRead".to_owned()
289 }
290
291 fn type_url() -> String {
292 "type.googleapis.com/example.BlobStoreRead".to_owned()
293 }
294}
295
296impl Explainable for BlobStoreRead {
297 fn name() -> &'static str {
298 "BlobStoreRead"
299 }
300
301 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
302 if args.positional.len() != 1 {
303 return Err(ExtensionError::InvalidArgument(format!(
304 "BlobStoreRead expects exactly 1 positional path argument, got {}",
305 args.positional.len()
306 )));
307 }
308 if !args.output_columns.is_empty() {
309 return Err(ExtensionError::InvalidArgument(
310 "BlobStoreRead output columns belong in Read:Extension[...]".to_owned(),
311 ));
312 }
313
314 let path = <&str>::try_from(&args.positional[0])?.to_owned();
315 let mut extractor = args.extractor();
316 let limit: i64 = extractor.get_named_or("limit", 0)?;
317 let include_archived: bool = extractor.get_named_or("include_archived", false)?;
318 extractor.check_exhausted()?;
319
320 Ok(Self {
321 path,
322 limit,
323 include_archived,
324 })
325 }
326
327 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
328 let mut args = ExtensionArgs::default();
329 args.positional
330 .push(ExtensionValue::String(self.path.clone()));
331 if self.limit != 0 {
332 args.named
333 .insert("limit".to_owned(), ExtensionValue::Integer(self.limit));
334 }
335 if self.include_archived {
336 args.named
337 .insert("include_archived".to_owned(), ExtensionValue::Boolean(true));
338 }
339 Ok(args)
340 }
341}
342
343pub fn registry() -> ExtensionRegistry {
345 let mut registry = ExtensionRegistry::new();
346 registry
347 .register_enhancement::<PartitionHint>()
348 .expect("register PartitionHint example enhancement");
349 registry
350 .register_optimization::<PlanHint>()
351 .expect("register PlanHint example optimization");
352 registry
353 .register_extension_table::<BlobStoreRead>()
354 .expect("register BlobStoreRead example extension table");
355 registry
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::AnyConvertible;

    /// Builds a `PartitionHint` from typed strategies, storing them as the raw
    /// `i32` discriminants the protobuf message uses.
    fn hint_with(strategies: &[PartitionStrategy], count: i64) -> PartitionHint {
        PartitionHint {
            strategies: strategies.iter().map(|&strategy| strategy as i32).collect(),
            count,
        }
    }

    /// Fixture shared by the `BlobStoreRead` round-trip tests.
    fn sample_blob_read() -> BlobStoreRead {
        BlobStoreRead {
            path: "path/to/file".to_owned(),
            limit: 100,
            include_archived: true,
        }
    }

    // ----- PartitionHint (enhancement) -----

    #[test]
    fn round_trip_via_any() {
        let hint = hint_with(&[PartitionStrategy::Hash, PartitionStrategy::Range], 4);
        let encoded = hint.to_any().expect("encode");
        let decoded = PartitionHint::from_any(encoded.as_ref()).expect("decode");
        assert_eq!(hint, decoded);
    }

    #[test]
    fn to_args_produces_enum_and_named() {
        let args = hint_with(&[PartitionStrategy::Hash], 8).to_args().unwrap();

        assert_eq!(args.positional.len(), 1);
        let first_is_hash =
            matches!(&args.positional[0], ExtensionValue::Enum(name) if name == "HASH");
        assert!(first_is_hash);

        let count_value = args.named.get("count").expect("count arg");
        assert_eq!(i64::try_from(count_value).unwrap(), 8);
    }

    #[test]
    fn to_args_omits_zero_count() {
        let args = hint_with(&[PartitionStrategy::Broadcast], 0).to_args().unwrap();
        assert!(args.named.is_empty(), "count=0 should be omitted");
    }

    #[test]
    fn from_args_round_trip() {
        let hint = hint_with(&[PartitionStrategy::Hash, PartitionStrategy::Range], 16);
        let args = hint.to_args().unwrap();
        assert_eq!(hint, PartitionHint::from_args(&args).unwrap());
    }

    #[test]
    fn from_args_rejects_unknown_strategy() {
        let mut args = ExtensionArgs::default();
        args.positional
            .push(ExtensionValue::Enum(String::from("BOGUS")));
        assert!(PartitionHint::from_args(&args).is_err());
    }

    #[test]
    fn from_args_rejects_non_enum_positional() {
        let mut args = ExtensionArgs::default();
        // An integer where an enum identifier is required.
        args.push(1_i64);
        let result = PartitionHint::from_args(&args);
        assert!(
            result.is_err(),
            "expected error for non-enum positional arg, got {result:?}"
        );
    }

    #[test]
    fn from_args_rejects_extra_named_args() {
        let mut args = ExtensionArgs::default();
        // `count` is the only named argument PartitionHint understands.
        args.insert("unknown_key", 99_i64);
        let result = PartitionHint::from_args(&args);
        assert!(
            result.is_err(),
            "expected error for unknown named arg, got {result:?}"
        );
    }

    #[test]
    fn from_args_empty_strategies_roundtrip() {
        let hint = hint_with(&[], 0);
        let decoded = PartitionHint::from_args(&hint.to_args().unwrap()).unwrap();
        assert_eq!(hint, decoded);
        assert!(decoded.strategies.is_empty());
        assert_eq!(decoded.count, 0);
    }

    #[test]
    fn registry_roundtrip() {
        let reg = registry();
        let hint = hint_with(&[PartitionStrategy::Hash], 4);
        let encoded = hint.to_any().unwrap();

        let (name, args) = reg
            .decode_enhancement(encoded.as_ref())
            .expect("decode_enhancement");
        assert_eq!(name, "PartitionHint");
        assert_eq!(args.positional.len(), 1);

        let reencoded = reg
            .parse_enhancement("PartitionHint", &args)
            .expect("parse_enhancement");
        let decoded = PartitionHint::from_any(reencoded.as_ref()).unwrap();
        assert_eq!(hint, decoded);
    }

    // ----- PlanHint (optimization) -----

    #[test]
    fn plan_hint_args_round_trip() {
        let hint = PlanHint {
            hint: "use_index".to_owned(),
        };
        let decoded = PlanHint::from_args(&hint.to_args().unwrap()).unwrap();
        assert_eq!(hint, decoded);
    }

    #[test]
    fn plan_hint_registry_roundtrip() {
        let reg = registry();
        let hint = PlanHint {
            hint: "parallel".to_owned(),
        };
        let encoded = hint.to_any().unwrap();

        let (name, args) = reg
            .decode_optimization(encoded.as_ref())
            .expect("decode_optimization");
        assert_eq!(name, "PlanHint");

        let reencoded = reg
            .parse_optimization("PlanHint", &args)
            .expect("parse_optimization");
        let decoded = PlanHint::from_any(reencoded.as_ref()).unwrap();
        assert_eq!(hint, decoded);
    }

    // ----- BlobStoreRead (extension table) -----

    #[test]
    fn blob_store_read_args_round_trip() {
        let read = sample_blob_read();
        let decoded = BlobStoreRead::from_args(&read.to_args().unwrap()).unwrap();
        assert_eq!(read, decoded);
    }

    #[test]
    fn blob_store_read_registry_roundtrip() {
        let reg = registry();
        let read = sample_blob_read();
        let encoded = read.to_any().unwrap();

        let (name, args) = reg
            .decode_extension_table(encoded.as_ref())
            .expect("decode_extension_table");
        assert_eq!(name, "BlobStoreRead");

        let reencoded = reg
            .parse_extension_table("BlobStoreRead", &args)
            .expect("parse_extension_table");
        let decoded = BlobStoreRead::from_any(reencoded.as_ref()).unwrap();
        assert_eq!(read, decoded);
    }
}