substrait_explain/extensions/
examples.rs1use crate::extensions::args::{EnumValue, ExtensionArgs, ExtensionRelationType, ExtensionValue};
16use crate::extensions::registry::{Explainable, ExtensionError};
17
/// Partitioning strategy referenced by [`PartitionHint`].
///
/// Derives `prost::Enumeration`, so it is stored in the protobuf message (see
/// [`PartitionHint::strategies`]) as its `i32` discriminant. Its textual form is
/// produced and parsed by [`PartitionStrategy::as_str_name`] and
/// [`PartitionStrategy::from_str_name`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, prost::Enumeration)]
#[repr(i32)]
pub enum PartitionStrategy {
    /// No strategy specified; wire value `0`.
    Unspecified = 0,
    /// Hash strategy; wire value `1`.
    Hash = 1,
    /// Range strategy; wire value `2`.
    Range = 2,
    /// Broadcast strategy; wire value `3`.
    Broadcast = 3,
}
35
36impl PartitionStrategy {
37 pub fn as_str_name(self) -> &'static str {
39 match self {
40 PartitionStrategy::Unspecified => "UNSPECIFIED",
41 PartitionStrategy::Hash => "HASH",
42 PartitionStrategy::Range => "RANGE",
43 PartitionStrategy::Broadcast => "BROADCAST",
44 }
45 }
46
47 pub fn from_str_name(s: &str) -> Option<Self> {
49 match s {
50 "UNSPECIFIED" => Some(PartitionStrategy::Unspecified),
51 "HASH" => Some(PartitionStrategy::Hash),
52 "RANGE" => Some(PartitionStrategy::Range),
53 "BROADCAST" => Some(PartitionStrategy::Broadcast),
54 _ => None,
55 }
56 }
57}
58
/// Example protobuf message (`example.PartitionHint`) carrying a partitioning hint.
///
/// Encoded and decoded through the `prost::Message` derive. Its explain-format
/// text form is produced by its [`Explainable`] implementation.
#[derive(Clone, PartialEq, prost::Message)]
pub struct PartitionHint {
    /// Requested strategies, in order, as raw [`PartitionStrategy`] discriminants
    /// (protobuf `repeated` enum field, tag 1).
    ///
    /// Because the entries are plain `i32`s, they are not guaranteed to be valid
    /// `PartitionStrategy` values.
    #[prost(enumeration = "PartitionStrategy", repeated, tag = "1")]
    pub strategies: Vec<i32>,
    /// Partition count (protobuf `int64`, tag 2); presumably the number of
    /// partitions requested — TODO confirm. The explain conversion treats `0`
    /// as absent: it is omitted on output and is the default on input.
    #[prost(int64, tag = "2")]
    pub count: i64,
}
101
102impl prost::Name for PartitionHint {
103 const PACKAGE: &'static str = "example";
104 const NAME: &'static str = "PartitionHint";
105
106 fn full_name() -> String {
107 "example.PartitionHint".to_owned()
108 }
109
110 fn type_url() -> String {
111 "type.googleapis.com/example.PartitionHint".to_owned()
112 }
113}
114
115impl Explainable for PartitionHint {
116 fn name() -> &'static str {
117 "PartitionHint"
118 }
119
120 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
121 let strategies: Result<Vec<i32>, ExtensionError> = args
123 .positional
124 .iter()
125 .map(|val| {
126 let EnumValue(ident) = EnumValue::try_from(val)?;
127 PartitionStrategy::from_str_name(&ident)
128 .map(|s| s as i32)
129 .ok_or_else(|| {
130 ExtensionError::InvalidArgument(format!(
131 "Unknown PartitionStrategy variant '&{ident}'; \
132 expected one of &UNSPECIFIED, &HASH, &RANGE, &BROADCAST"
133 ))
134 })
135 })
136 .collect();
137
138 let mut extractor = args.extractor();
139 let count: i64 = extractor.get_named_or("count", 0)?;
140 extractor.check_exhausted()?;
141
142 Ok(PartitionHint {
143 strategies: strategies?,
144 count,
145 })
146 }
147
148 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
149 let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
150 for &raw in &self.strategies {
151 let s = PartitionStrategy::try_from(raw).unwrap_or(PartitionStrategy::Unspecified);
152 args.positional
153 .push(ExtensionValue::Enum(s.as_str_name().to_owned()));
154 }
155 if self.count != 0 {
156 args.named
157 .insert("count".to_owned(), ExtensionValue::Integer(self.count));
158 }
159 Ok(args)
160 }
161}
162
// Unit tests for the `PartitionHint` example extension:
// - protobuf `Any` round-trips;
// - conversion to and from explain-format `ExtensionArgs`;
// - rejection of malformed arguments;
// - a round-trip through an `ExtensionRegistry`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{AnyConvertible, ExtensionRegistry};

    /// Builds a `PartitionHint`, converting typed strategies into the raw `i32`
    /// values the message stores.
    fn make_hint(strategies: Vec<PartitionStrategy>, count: i64) -> PartitionHint {
        PartitionHint {
            strategies: strategies.into_iter().map(|s| s as i32).collect(),
            count,
        }
    }

    // Encoding to a protobuf `Any` and decoding it back yields an equal message.
    #[test]
    fn round_trip_via_any() {
        let original = make_hint(vec![PartitionStrategy::Hash, PartitionStrategy::Range], 4);
        let any = original.to_any().expect("encode");
        let decoded = PartitionHint::from_any(any.as_ref()).expect("decode");
        assert_eq!(original, decoded);
    }

    // A strategy becomes a positional enum value carrying its canonical name,
    // and a non-zero count becomes the named `count` argument.
    #[test]
    fn to_args_produces_enum_and_named() {
        let hint = make_hint(vec![PartitionStrategy::Hash], 8);
        let args = hint.to_args().unwrap();
        assert_eq!(args.positional.len(), 1);
        assert!(matches!(&args.positional[0], ExtensionValue::Enum(s) if s == "HASH"));
        assert!(matches!(
            args.named.get("count"),
            Some(ExtensionValue::Integer(8))
        ));
    }

    // A zero count is left out of the named arguments entirely.
    #[test]
    fn to_args_omits_zero_count() {
        let hint = make_hint(vec![PartitionStrategy::Broadcast], 0);
        let args = hint.to_args().unwrap();
        assert!(args.named.is_empty(), "count=0 should be omitted");
    }

    // `to_args` followed by `from_args` reproduces the original hint.
    #[test]
    fn from_args_round_trip() {
        let original = make_hint(vec![PartitionStrategy::Hash, PartitionStrategy::Range], 16);
        let args = original.to_args().unwrap();
        let decoded = PartitionHint::from_args(&args).unwrap();
        assert_eq!(original, decoded);
    }

    // An enum value that does not name a `PartitionStrategy` variant is rejected.
    #[test]
    fn from_args_rejects_unknown_strategy() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.positional
            .push(ExtensionValue::Enum("BOGUS".to_owned()));
        assert!(PartitionHint::from_args(&args).is_err());
    }

    // Positional arguments must be enum values; anything else is rejected.
    #[test]
    fn from_args_rejects_non_enum_positional() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        // An integer where an enum value (e.g. `&HASH`) is expected.
        args.positional.push(ExtensionValue::Integer(1));
        let result = PartitionHint::from_args(&args);
        assert!(
            result.is_err(),
            "expected error for non-enum positional arg, got {result:?}"
        );
    }

    // Named arguments other than `count` are rejected.
    #[test]
    fn from_args_rejects_extra_named_args() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        // `unknown_key` is never consumed by `from_args`, so the exhaustion
        // check should report it.
        args.named
            .insert("unknown_key".to_owned(), ExtensionValue::Integer(99));
        let result = PartitionHint::from_args(&args);
        assert!(
            result.is_err(),
            "expected error for unknown named arg, got {result:?}"
        );
    }

    // A hint with no strategies and a zero count survives the args round-trip.
    #[test]
    fn from_args_empty_strategies_roundtrip() {
        let original = make_hint(vec![], 0);
        let args = original.to_args().unwrap();
        let decoded = PartitionHint::from_args(&args).unwrap();
        assert_eq!(original, decoded);
        assert!(decoded.strategies.is_empty());
        assert_eq!(decoded.count, 0);
    }

    // After registering `PartitionHint` as an enhancement:
    // - decoding its `Any` through the registry yields the registered name and args;
    // - parsing those args back through the registry reproduces the message.
    #[test]
    fn registry_roundtrip() {
        let mut registry = ExtensionRegistry::new();
        registry.register_enhancement::<PartitionHint>().unwrap();

        let original = make_hint(vec![PartitionStrategy::Hash], 4);
        let any = original.to_any().unwrap();

        let (name, args) = registry
            .decode_enhancement(any.as_ref())
            .expect("decode_enhancement");
        assert_eq!(name, "PartitionHint");
        assert_eq!(args.positional.len(), 1);

        let any2 = registry
            .parse_enhancement("PartitionHint", &args)
            .expect("parse_enhancement");
        let decoded = PartitionHint::from_any(any2.as_ref()).unwrap();
        assert_eq!(original, decoded);
    }
}