Skip to main content

substrait_explain/extensions/
examples.rs

1//! [`PartitionHint`] demonstrates how to implement a custom enhancement:
2//! positional enum arguments combined with an optional named integer argument.
3//!
4//! # Text Format
5//!
6//! ```text
7//! Read[data => col:i64]
8//!   + Enh:PartitionHint[&HASH, count=8]
9//! ```
10//!
11//! Each positional argument is a [`PartitionStrategy`] variant rendered with
12//! the `&` enum prefix.  The optional named argument `count` gives the target
13//! number of partitions (`0` / absent means "let the executor decide").
14
15use crate::extensions::args::{EnumValue, ExtensionArgs, ExtensionValue};
16use crate::extensions::registry::{Explainable, ExtensionError, ExtensionRegistry};
17
18// ---------------------------------------------------------------------------
19// PartitionStrategy enum
20// ---------------------------------------------------------------------------
21
22/// Partitioning strategy for a [`PartitionHint`] enhancement.
23#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, prost::Enumeration)]
24#[repr(i32)]
25pub enum PartitionStrategy {
26    /// No strategy specified.
27    Unspecified = 0,
28    /// Distribute rows by hashing one or more key columns.
29    Hash = 1,
30    /// Sort-based range partitioning.
31    Range = 2,
32    /// Broadcast the entire relation to every partition.
33    Broadcast = 3,
34}
35
36impl PartitionStrategy {
37    /// The identifier used in the text format (without the leading `&`).
38    pub fn as_str_name(self) -> &'static str {
39        match self {
40            PartitionStrategy::Unspecified => "UNSPECIFIED",
41            PartitionStrategy::Hash => "HASH",
42            PartitionStrategy::Range => "RANGE",
43            PartitionStrategy::Broadcast => "BROADCAST",
44        }
45    }
46
47    /// Parse from the text-format identifier (without the leading `&`).
48    pub fn from_str_name(s: &str) -> Option<Self> {
49        match s {
50            "UNSPECIFIED" => Some(PartitionStrategy::Unspecified),
51            "HASH" => Some(PartitionStrategy::Hash),
52            "RANGE" => Some(PartitionStrategy::Range),
53            "BROADCAST" => Some(PartitionStrategy::Broadcast),
54            _ => None,
55        }
56    }
57}
58
59// ---------------------------------------------------------------------------
60// PartitionHint
61// ---------------------------------------------------------------------------
62
63/// Enhancement that hints the executor how to partition a relation's output.
64///
65/// Attach this to any standard relation via `register_enhancement` to convey
66/// partitioning decisions made during planning.
67///
68/// # Text Format
69///
70/// ```rust
71/// # use substrait_explain::extensions::examples;
72/// # use substrait_explain::format_with_registry;
73/// # use substrait_explain::parser::Parser;
74/// #
75/// # let registry = examples::registry();
76/// # let parser = Parser::new().with_extension_registry(registry.clone());
77/// #
78/// # let plan_text = r#"
79/// === Plan
80/// Root[result]
81///   Read[data => col:i64]
82///     + Enh:PartitionHint[&HASH, count=8]
83/// # "#;
84/// #
85/// # let plan = parser.parse_plan(plan_text).unwrap();
86/// # let (formatted, errors) = format_with_registry(&plan, &Default::default(), &registry);
87/// # assert!(errors.is_empty());
88/// # assert_eq!(formatted.trim(), plan_text.trim());
89/// ```
90#[derive(Clone, PartialEq, prost::Message)]
91pub struct PartitionHint {
92    /// The strategies to apply, in order of preference.  Each value is the
93    /// integer representation of [`PartitionStrategy`].
94    #[prost(enumeration = "PartitionStrategy", repeated, tag = "1")]
95    pub strategies: Vec<i32>,
96    /// Target number of partitions.  `0` means "let the executor decide".
97    #[prost(int64, tag = "2")]
98    pub count: i64,
99}
100
101impl prost::Name for PartitionHint {
102    const PACKAGE: &'static str = "example";
103    const NAME: &'static str = "PartitionHint";
104
105    fn full_name() -> String {
106        "example.PartitionHint".to_owned()
107    }
108
109    fn type_url() -> String {
110        "type.googleapis.com/example.PartitionHint".to_owned()
111    }
112}
113
114impl Explainable for PartitionHint {
115    fn name() -> &'static str {
116        "PartitionHint"
117    }
118
119    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
120        // Positional arguments are PartitionStrategy enum values.
121        let strategies: Result<Vec<i32>, ExtensionError> = args
122            .positional
123            .iter()
124            .map(|val| {
125                let EnumValue(ident) = EnumValue::try_from(val)?;
126                PartitionStrategy::from_str_name(&ident)
127                    .map(|s| s as i32)
128                    .ok_or_else(|| {
129                        ExtensionError::InvalidArgument(format!(
130                            "Unknown PartitionStrategy variant '&{ident}'; \
131                             expected one of &UNSPECIFIED, &HASH, &RANGE, &BROADCAST"
132                        ))
133                    })
134            })
135            .collect();
136
137        let mut extractor = args.extractor();
138        let count: i64 = extractor.get_named_or("count", 0)?;
139        extractor.check_exhausted()?;
140
141        Ok(PartitionHint {
142            strategies: strategies?,
143            count,
144        })
145    }
146
147    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
148        let mut args = ExtensionArgs::default();
149        for &raw in &self.strategies {
150            let s = PartitionStrategy::try_from(raw).unwrap_or(PartitionStrategy::Unspecified);
151            args.positional
152                .push(ExtensionValue::Enum(s.as_str_name().to_owned()));
153        }
154        if self.count != 0 {
155            args.insert("count", self.count);
156        }
157        Ok(args)
158    }
159}
160
161// ---------------------------------------------------------------------------
162// PlanHint
163// ---------------------------------------------------------------------------
164
165/// Optimization hint that carries a planner directive as a string.
166///
167/// Attach this to any standard relation via `register_optimization` to convey
168/// planner choices without changing relation semantics.
169///
170/// # Text Format
171///
172/// ```rust
173/// # use substrait_explain::extensions::examples;
174/// # use substrait_explain::format_with_registry;
175/// # use substrait_explain::parser::Parser;
176/// #
177/// # let registry = examples::registry();
178/// # let parser = Parser::new().with_extension_registry(registry.clone());
179/// #
180/// # let plan_text = r#"
181/// === Plan
182/// Root[result]
183///   Read[data => col:i64]
184///     + Opt:PlanHint[hint='use_index']
185/// # "#;
186/// #
187/// # let plan = parser.parse_plan(plan_text).unwrap();
188/// # let (formatted, errors) = format_with_registry(&plan, &Default::default(), &registry);
189/// # assert!(errors.is_empty());
190/// # assert_eq!(formatted.trim(), plan_text.trim());
191/// ```
192#[derive(Clone, PartialEq, prost::Message)]
193pub struct PlanHint {
194    /// Planner directive. The text format stores this as `hint='...'`.
195    #[prost(string, tag = "1")]
196    pub hint: String,
197}
198
199impl prost::Name for PlanHint {
200    const PACKAGE: &'static str = "example";
201    const NAME: &'static str = "PlanHint";
202
203    fn full_name() -> String {
204        "example.PlanHint".to_owned()
205    }
206
207    fn type_url() -> String {
208        "type.googleapis.com/example.PlanHint".to_owned()
209    }
210}
211
212impl Explainable for PlanHint {
213    fn name() -> &'static str {
214        "PlanHint"
215    }
216
217    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
218        if !args.positional.is_empty() {
219            return Err(ExtensionError::InvalidArgument(
220                "PlanHint does not accept positional arguments".to_owned(),
221            ));
222        }
223
224        let mut extractor = args.extractor();
225        let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_owned();
226        extractor.check_exhausted()?;
227        Ok(PlanHint { hint })
228    }
229
230    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
231        let mut args = ExtensionArgs::default();
232        args.named
233            .insert("hint".to_owned(), ExtensionValue::String(self.hint.clone()));
234        Ok(args)
235    }
236}
237
238// ---------------------------------------------------------------------------
239// BlobStoreRead
240// ---------------------------------------------------------------------------
241
242/// ExtensionTable detail for a simple blob-store backed read.
243///
244/// Attach this to `Read:Extension[...]` via `register_extension_table` to
245/// describe a custom table source whose output schema is carried by the
246/// surrounding `ReadRel.base_schema`.
247///
248/// # Text Format
249///
250/// ```rust
251/// # use substrait_explain::extensions::examples;
252/// # use substrait_explain::format_with_registry;
253/// # use substrait_explain::parser::Parser;
254/// #
255/// # let registry = examples::registry();
256/// # let parser = Parser::new().with_extension_registry(registry.clone());
257/// #
258/// # let plan_text = r#"
259/// === Plan
260/// Root[id, payload]
261///   Read:Extension[id:i64, payload:string]
262///     + Ext:BlobStoreRead['path/to/file', limit=100]
263/// # "#;
264/// #
265/// # let plan = parser.parse_plan(plan_text).unwrap();
266/// # let (formatted, errors) = format_with_registry(&plan, &Default::default(), &registry);
267/// # assert!(errors.is_empty());
268/// # assert_eq!(formatted.trim(), plan_text.trim());
269/// ```
270#[derive(Clone, PartialEq, prost::Message)]
271pub struct BlobStoreRead {
272    /// Blob path or URI to read.
273    #[prost(string, tag = "1")]
274    pub path: String,
275    /// Optional row limit. `0` means no limit.
276    #[prost(int64, tag = "2")]
277    pub limit: i64,
278    /// Whether archived blobs should be included.
279    #[prost(bool, tag = "3")]
280    pub include_archived: bool,
281}
282
283impl prost::Name for BlobStoreRead {
284    const PACKAGE: &'static str = "example";
285    const NAME: &'static str = "BlobStoreRead";
286
287    fn full_name() -> String {
288        "example.BlobStoreRead".to_owned()
289    }
290
291    fn type_url() -> String {
292        "type.googleapis.com/example.BlobStoreRead".to_owned()
293    }
294}
295
296impl Explainable for BlobStoreRead {
297    fn name() -> &'static str {
298        "BlobStoreRead"
299    }
300
301    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
302        if args.positional.len() != 1 {
303            return Err(ExtensionError::InvalidArgument(format!(
304                "BlobStoreRead expects exactly 1 positional path argument, got {}",
305                args.positional.len()
306            )));
307        }
308        if !args.output_columns.is_empty() {
309            return Err(ExtensionError::InvalidArgument(
310                "BlobStoreRead output columns belong in Read:Extension[...]".to_owned(),
311            ));
312        }
313
314        let path = <&str>::try_from(&args.positional[0])?.to_owned();
315        let mut extractor = args.extractor();
316        let limit: i64 = extractor.get_named_or("limit", 0)?;
317        let include_archived: bool = extractor.get_named_or("include_archived", false)?;
318        extractor.check_exhausted()?;
319
320        Ok(Self {
321            path,
322            limit,
323            include_archived,
324        })
325    }
326
327    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
328        let mut args = ExtensionArgs::default();
329        args.positional
330            .push(ExtensionValue::String(self.path.clone()));
331        if self.limit != 0 {
332            args.named
333                .insert("limit".to_owned(), ExtensionValue::Integer(self.limit));
334        }
335        if self.include_archived {
336            args.named
337                .insert("include_archived".to_owned(), ExtensionValue::Boolean(true));
338        }
339        Ok(args)
340    }
341}
342
343/// Create an [`ExtensionRegistry`] preloaded with the example extension types.
344pub fn registry() -> ExtensionRegistry {
345    let mut registry = ExtensionRegistry::new();
346    registry
347        .register_enhancement::<PartitionHint>()
348        .expect("register PartitionHint example enhancement");
349    registry
350        .register_optimization::<PlanHint>()
351        .expect("register PlanHint example optimization");
352    registry
353        .register_extension_table::<BlobStoreRead>()
354        .expect("register BlobStoreRead example extension table");
355    registry
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::AnyConvertible;

    /// Builds a `PartitionHint` from typed strategies and a partition count.
    fn make_hint(strategies: Vec<PartitionStrategy>, count: i64) -> PartitionHint {
        let strategies = strategies.into_iter().map(|s| s as i32).collect();
        PartitionHint { strategies, count }
    }

    /// `PlanHint` fixture carrying the given directive.
    fn plan_hint(directive: &str) -> PlanHint {
        PlanHint {
            hint: directive.to_owned(),
        }
    }

    /// `BlobStoreRead` fixture with every field set to a non-default value.
    fn sample_blob_read() -> BlobStoreRead {
        BlobStoreRead {
            path: "path/to/file".to_owned(),
            limit: 100,
            include_archived: true,
        }
    }

    #[test]
    fn round_trip_via_any() {
        let hint = make_hint(vec![PartitionStrategy::Hash, PartitionStrategy::Range], 4);
        let encoded = hint.to_any().expect("encode");
        let restored = PartitionHint::from_any(encoded.as_ref()).expect("decode");
        assert_eq!(hint, restored);
    }

    #[test]
    fn to_args_produces_enum_and_named() {
        let rendered = make_hint(vec![PartitionStrategy::Hash], 8)
            .to_args()
            .unwrap();
        assert_eq!(rendered.positional.len(), 1);
        assert!(matches!(&rendered.positional[0], ExtensionValue::Enum(s) if s == "HASH"));
        let count_value = rendered.named.get("count").expect("count arg");
        assert_eq!(i64::try_from(count_value).unwrap(), 8);
    }

    #[test]
    fn to_args_omits_zero_count() {
        let rendered = make_hint(vec![PartitionStrategy::Broadcast], 0)
            .to_args()
            .unwrap();
        assert!(rendered.named.is_empty(), "count=0 should be omitted");
    }

    #[test]
    fn from_args_round_trip() {
        let hint = make_hint(vec![PartitionStrategy::Hash, PartitionStrategy::Range], 16);
        let rendered = hint.to_args().unwrap();
        let restored = PartitionHint::from_args(&rendered).unwrap();
        assert_eq!(hint, restored);
    }

    #[test]
    fn from_args_rejects_unknown_strategy() {
        let mut bad_args = ExtensionArgs::default();
        bad_args
            .positional
            .push(ExtensionValue::Enum("BOGUS".to_owned()));
        assert!(PartitionHint::from_args(&bad_args).is_err());
    }

    #[test]
    fn from_args_rejects_non_enum_positional() {
        // An integer in a positional slot must be rejected: only enum values
        // are valid there.
        let mut bad_args = ExtensionArgs::default();
        bad_args.push(1_i64);
        let result = PartitionHint::from_args(&bad_args);
        assert!(
            result.is_err(),
            "expected error for non-enum positional arg, got {result:?}"
        );
    }

    #[test]
    fn from_args_rejects_extra_named_args() {
        // A named argument that is never consumed must make check_exhausted fail.
        let mut bad_args = ExtensionArgs::default();
        bad_args.insert("unknown_key", 99_i64);
        let result = PartitionHint::from_args(&bad_args);
        assert!(
            result.is_err(),
            "expected error for unknown named arg, got {result:?}"
        );
    }

    #[test]
    fn from_args_empty_strategies_roundtrip() {
        let hint = make_hint(vec![], 0);
        let rendered = hint.to_args().unwrap();
        let restored = PartitionHint::from_args(&rendered).unwrap();
        assert_eq!(hint, restored);
        assert!(restored.strategies.is_empty());
        assert_eq!(restored.count, 0);
    }

    #[test]
    fn registry_roundtrip() {
        let examples = registry();

        let hint = make_hint(vec![PartitionStrategy::Hash], 4);
        let encoded = hint.to_any().unwrap();

        let (name, rendered) = examples
            .decode_enhancement(encoded.as_ref())
            .expect("decode_enhancement");
        assert_eq!(name, "PartitionHint");
        assert_eq!(rendered.positional.len(), 1);

        let reencoded = examples
            .parse_enhancement("PartitionHint", &rendered)
            .expect("parse_enhancement");
        let restored = PartitionHint::from_any(reencoded.as_ref()).unwrap();
        assert_eq!(hint, restored);
    }

    #[test]
    fn plan_hint_args_round_trip() {
        let hint = plan_hint("use_index");
        let rendered = hint.to_args().unwrap();
        let restored = PlanHint::from_args(&rendered).unwrap();
        assert_eq!(hint, restored);
    }

    #[test]
    fn plan_hint_registry_roundtrip() {
        let examples = registry();

        let hint = plan_hint("parallel");
        let encoded = hint.to_any().unwrap();

        let (name, rendered) = examples
            .decode_optimization(encoded.as_ref())
            .expect("decode_optimization");
        assert_eq!(name, "PlanHint");

        let reencoded = examples
            .parse_optimization("PlanHint", &rendered)
            .expect("parse_optimization");
        let restored = PlanHint::from_any(reencoded.as_ref()).unwrap();
        assert_eq!(hint, restored);
    }

    #[test]
    fn blob_store_read_args_round_trip() {
        let read = sample_blob_read();

        let rendered = read.to_args().unwrap();
        let restored = BlobStoreRead::from_args(&rendered).unwrap();

        assert_eq!(read, restored);
    }

    #[test]
    fn blob_store_read_registry_roundtrip() {
        let examples = registry();

        let read = sample_blob_read();
        let encoded = read.to_any().unwrap();

        let (name, rendered) = examples
            .decode_extension_table(encoded.as_ref())
            .expect("decode_extension_table");
        assert_eq!(name, "BlobStoreRead");

        let reencoded = examples
            .parse_extension_table("BlobStoreRead", &rendered)
            .expect("parse_extension_table");
        let restored = BlobStoreRead::from_any(reencoded.as_ref()).unwrap();
        assert_eq!(read, restored);
    }
}