substrait_explain/extensions/
examples.rs

//! [`PartitionHint`] demonstrates how to implement a custom enhancement:
//! positional enum arguments combined with an optional named integer argument.
//!
//! # Text Format
//!
//! ```text
//! Read[data => col:i64]
//!   + Enh:PartitionHint[&HASH, count=8]
//! ```
//!
//! Each positional argument is a [`PartitionStrategy`] variant rendered with
//! the `&` enum prefix. The optional named argument `count` gives the target
//! number of partitions (`0` or absent means "let the executor decide").
14
15use crate::extensions::args::{EnumValue, ExtensionArgs, ExtensionRelationType, ExtensionValue};
16use crate::extensions::registry::{Explainable, ExtensionError};
17
18// ---------------------------------------------------------------------------
19// PartitionStrategy enum
20// ---------------------------------------------------------------------------
21
22/// Partitioning strategy for a [`PartitionHint`] enhancement.
23#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, prost::Enumeration)]
24#[repr(i32)]
25pub enum PartitionStrategy {
26    /// No strategy specified.
27    Unspecified = 0,
28    /// Distribute rows by hashing one or more key columns.
29    Hash = 1,
30    /// Sort-based range partitioning.
31    Range = 2,
32    /// Broadcast the entire relation to every partition.
33    Broadcast = 3,
34}
35
36impl PartitionStrategy {
37    /// The identifier used in the text format (without the leading `&`).
38    pub fn as_str_name(self) -> &'static str {
39        match self {
40            PartitionStrategy::Unspecified => "UNSPECIFIED",
41            PartitionStrategy::Hash => "HASH",
42            PartitionStrategy::Range => "RANGE",
43            PartitionStrategy::Broadcast => "BROADCAST",
44        }
45    }
46
47    /// Parse from the text-format identifier (without the leading `&`).
48    pub fn from_str_name(s: &str) -> Option<Self> {
49        match s {
50            "UNSPECIFIED" => Some(PartitionStrategy::Unspecified),
51            "HASH" => Some(PartitionStrategy::Hash),
52            "RANGE" => Some(PartitionStrategy::Range),
53            "BROADCAST" => Some(PartitionStrategy::Broadcast),
54            _ => None,
55        }
56    }
57}
58
59// ---------------------------------------------------------------------------
60// PartitionHint
61// ---------------------------------------------------------------------------
62
63/// Enhancement that hints the executor how to partition a relation's output.
64///
65/// Attach this to any standard relation via `register_enhancement` to convey
66/// partitioning decisions made during planning.
67///
68/// # Text Format
69///
70/// ```rust
71/// # use substrait_explain::extensions::{ExtensionRegistry, examples::PartitionHint};
72/// # use substrait_explain::format_with_registry;
73/// # use substrait_explain::parser::Parser;
74/// #
75/// # let mut registry = ExtensionRegistry::new();
76/// # registry.register_enhancement::<PartitionHint>().unwrap();
77/// # let parser = Parser::new().with_extension_registry(registry.clone());
78/// #
79/// # let plan_text = r#"
80/// === Plan
81/// Root[result]
82///   Read[data => col:i64]
83///     + Enh:PartitionHint[&HASH, count=8]
84/// # "#;
85/// #
86/// # let plan = parser.parse_plan(plan_text).unwrap();
87/// # let (formatted, errors) = format_with_registry(&plan, &Default::default(), &registry);
88/// # assert!(errors.is_empty());
89/// # assert_eq!(formatted.trim(), plan_text.trim());
90/// ```
91#[derive(Clone, PartialEq, prost::Message)]
92pub struct PartitionHint {
93    /// The strategies to apply, in order of preference.  Each value is the
94    /// integer representation of [`PartitionStrategy`].
95    #[prost(enumeration = "PartitionStrategy", repeated, tag = "1")]
96    pub strategies: Vec<i32>,
97    /// Target number of partitions.  `0` means "let the executor decide".
98    #[prost(int64, tag = "2")]
99    pub count: i64,
100}
101
102impl prost::Name for PartitionHint {
103    const PACKAGE: &'static str = "example";
104    const NAME: &'static str = "PartitionHint";
105
106    fn full_name() -> String {
107        "example.PartitionHint".to_owned()
108    }
109
110    fn type_url() -> String {
111        "type.googleapis.com/example.PartitionHint".to_owned()
112    }
113}
114
115impl Explainable for PartitionHint {
116    fn name() -> &'static str {
117        "PartitionHint"
118    }
119
120    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
121        // Positional arguments are PartitionStrategy enum values.
122        let strategies: Result<Vec<i32>, ExtensionError> = args
123            .positional
124            .iter()
125            .map(|val| {
126                let EnumValue(ident) = EnumValue::try_from(val)?;
127                PartitionStrategy::from_str_name(&ident)
128                    .map(|s| s as i32)
129                    .ok_or_else(|| {
130                        ExtensionError::InvalidArgument(format!(
131                            "Unknown PartitionStrategy variant '&{ident}'; \
132                             expected one of &UNSPECIFIED, &HASH, &RANGE, &BROADCAST"
133                        ))
134                    })
135            })
136            .collect();
137
138        let mut extractor = args.extractor();
139        let count: i64 = extractor.get_named_or("count", 0)?;
140        extractor.check_exhausted()?;
141
142        Ok(PartitionHint {
143            strategies: strategies?,
144            count,
145        })
146    }
147
148    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
149        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
150        for &raw in &self.strategies {
151            let s = PartitionStrategy::try_from(raw).unwrap_or(PartitionStrategy::Unspecified);
152            args.positional
153                .push(ExtensionValue::Enum(s.as_str_name().to_owned()));
154        }
155        if self.count != 0 {
156            args.named
157                .insert("count".to_owned(), ExtensionValue::Integer(self.count));
158        }
159        Ok(args)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{AnyConvertible, ExtensionRegistry};

    /// Builds a `PartitionHint` from typed strategies, converting each one to
    /// the `i32` representation stored in the message.
    fn make_hint(strategies: Vec<PartitionStrategy>, count: i64) -> PartitionHint {
        PartitionHint {
            strategies: strategies.into_iter().map(|s| s as i32).collect(),
            count,
        }
    }

    #[test]
    fn round_trip_via_any() {
        let original = make_hint(vec![PartitionStrategy::Hash, PartitionStrategy::Range], 4);
        let any = original.to_any().expect("encode");
        let decoded = PartitionHint::from_any(any.as_ref()).expect("decode");
        assert_eq!(original, decoded);
    }

    #[test]
    fn to_args_produces_enum_and_named() {
        let hint = make_hint(vec![PartitionStrategy::Hash], 8);
        let args = hint.to_args().unwrap();
        assert_eq!(args.positional.len(), 1);
        assert!(matches!(&args.positional[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(
            args.named.get("count"),
            Some(ExtensionValue::Integer(8))
        ));
    }

    #[test]
    fn to_args_omits_zero_count() {
        let hint = make_hint(vec![PartitionStrategy::Broadcast], 0);
        let args = hint.to_args().unwrap();
        assert!(args.named.is_empty(), "count=0 should be omitted");
    }

    #[test]
    fn from_args_round_trip() {
        let original = make_hint(vec![PartitionStrategy::Hash, PartitionStrategy::Range], 16);
        let args = original.to_args().unwrap();
        let decoded = PartitionHint::from_args(&args).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn from_args_rejects_unknown_strategy() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.positional
            .push(ExtensionValue::Enum("BOGUS".to_owned()));
        assert!(PartitionHint::from_args(&args).is_err());
    }

    #[test]
    fn from_args_rejects_non_enum_positional() {
        // An integer positional arg where an enum is expected should fail.
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.positional.push(ExtensionValue::Integer(1));
        let result = PartitionHint::from_args(&args);
        assert!(
            result.is_err(),
            "expected error for non-enum positional arg, got {result:?}"
        );
    }

    #[test]
    fn from_args_rejects_extra_named_args() {
        // check_exhausted should reject unknown named args.
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named
            .insert("unknown_key".to_owned(), ExtensionValue::Integer(99));
        let result = PartitionHint::from_args(&args);
        assert!(
            result.is_err(),
            "expected error for unknown named arg, got {result:?}"
        );
    }

    #[test]
    fn from_args_empty_strategies_roundtrip() {
        let original = make_hint(vec![], 0);
        let args = original.to_args().unwrap();
        let decoded = PartitionHint::from_args(&args).unwrap();
        assert_eq!(original, decoded);
        assert!(decoded.strategies.is_empty());
        assert_eq!(decoded.count, 0);
    }

    #[test]
    fn strategy_names_round_trip() {
        // Every variant, including UNSPECIFIED, must map name -> value -> name.
        let all = [
            PartitionStrategy::Unspecified,
            PartitionStrategy::Hash,
            PartitionStrategy::Range,
            PartitionStrategy::Broadcast,
        ];
        for strategy in all {
            assert_eq!(
                PartitionStrategy::from_str_name(strategy.as_str_name()),
                Some(strategy)
            );
        }
        // Lookup is case-sensitive and expects the name without the `&` prefix.
        assert_eq!(PartitionStrategy::from_str_name("hash"), None);
        assert_eq!(PartitionStrategy::from_str_name("&HASH"), None);
        assert_eq!(PartitionStrategy::from_str_name(""), None);
    }

    #[test]
    fn from_args_parses_every_variant() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        for name in ["UNSPECIFIED", "HASH", "RANGE", "BROADCAST"] {
            args.positional
                .push(ExtensionValue::Enum(name.to_owned()));
        }
        let decoded = PartitionHint::from_args(&args).unwrap();
        let expected = make_hint(
            vec![
                PartitionStrategy::Unspecified,
                PartitionStrategy::Hash,
                PartitionStrategy::Range,
                PartitionStrategy::Broadcast,
            ],
            0,
        );
        assert_eq!(decoded, expected);
    }

    #[test]
    fn registry_roundtrip() {
        let mut registry = ExtensionRegistry::new();
        registry.register_enhancement::<PartitionHint>().unwrap();

        let original = make_hint(vec![PartitionStrategy::Hash], 4);
        let any = original.to_any().unwrap();

        let (name, args) = registry
            .decode_enhancement(any.as_ref())
            .expect("decode_enhancement");
        assert_eq!(name, "PartitionHint");
        assert_eq!(args.positional.len(), 1);

        let any2 = registry
            .parse_enhancement("PartitionHint", &args)
            .expect("parse_enhancement");
        let decoded = PartitionHint::from_any(any2.as_ref()).unwrap();
        assert_eq!(original, decoded);
    }
}