1use std::collections::HashMap;
86use std::fmt;
87use std::sync::Arc;
88
89use thiserror::Error;
90
91use crate::extensions::any::{Any, AnyRef};
92use crate::extensions::args::ExtensionArgs;
93
/// Namespace an extension is registered under.
///
/// The registry keys both names and type URLs by this value, so the same
/// canonical name or type URL can be registered independently for each kind.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExtensionType {
    /// Registered via `register_relation`; parsed/decoded via
    /// `parse_extension` / `decode`.
    Relation,
    /// Registered via `register_enhancement`; parsed/decoded via
    /// `parse_enhancement` / `decode_enhancement`.
    Enhancement,
    /// Registered via `register_optimization`; parsed/decoded via
    /// `parse_optimization` / `decode_optimization`.
    Optimization,
}
104
105#[derive(Debug, Error, Clone)]
107pub enum RegistrationError {
108 #[error("{ext_type:?} extension '{name}' already registered")]
109 DuplicateName {
110 ext_type: ExtensionType,
111 name: String,
112 },
113
114 #[error("Type URL '{type_url}' already registered to {ext_type:?} extension '{existing_name}'")]
115 ConflictingTypeUrl {
116 type_url: String,
117 ext_type: ExtensionType,
118 existing_name: String,
119 },
120}
121
122#[derive(Debug, Error, Clone)]
124pub enum ExtensionError {
125 #[error("Extension '{name}' not found in registry")]
127 NotFound { name: String },
128
129 #[error("Missing required argument: {name}")]
131 MissingArgument { name: String },
132
133 #[error("Invalid argument: {0}")]
135 InvalidArgument(String),
136
137 #[error("Type URL mismatch: expected {expected}, got {actual}")]
139 TypeUrlMismatch { expected: String, actual: String },
140
141 #[error("Failed to decode protobuf message")]
143 DecodeFailed(#[source] prost::DecodeError),
144
145 #[error("Failed to encode protobuf message")]
147 EncodeFailed(#[source] prost::EncodeError),
148
149 #[error("Extension detail is missing")]
151 MissingDetail,
152
153 #[error("{0}")]
155 Custom(String),
156}
157
158pub trait AnyConvertible: Sized {
162 fn to_any(&self) -> Result<Any, ExtensionError>;
164
165 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError>;
167
168 fn type_url() -> String;
172}
173
174impl<T> AnyConvertible for T
176where
177 T: prost::Message + prost::Name + Default,
178{
179 fn to_any(&self) -> Result<Any, ExtensionError> {
180 Any::encode(self)
181 }
182
183 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
184 any.decode()
185 }
186
187 fn type_url() -> String {
188 T::type_url()
189 }
190}
191
192pub trait Explainable: Sized {
194 fn name() -> &'static str;
197
198 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError>;
200
201 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError>;
203}
204
205trait ExtensionConverter: Send + Sync {
219 fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError>;
220
221 fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError>;
222}
223
224struct ExtensionAdapter<T>(std::marker::PhantomData<T>);
241
242impl<T: AnyConvertible + Explainable + Send + Sync> ExtensionConverter for ExtensionAdapter<T> {
243 fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError> {
244 T::from_args(args)?.to_any()
246 }
247
248 fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError> {
249 let owned_any = Any::new(detail.type_url.to_string(), detail.value.to_vec());
252 T::from_any(owned_any.as_ref())?.to_args()
253 }
254}
255
256pub trait Extension: AnyConvertible + Explainable + Send + Sync + 'static {}
257
258impl<T> Extension for T where T: AnyConvertible + Explainable + Send + Sync + 'static {}
259
260#[derive(Default, Clone)]
262pub struct ExtensionRegistry {
263 handlers: HashMap<(ExtensionType, String), Arc<dyn ExtensionConverter>>,
265 type_urls: HashMap<(ExtensionType, String), String>,
267}
268
269impl ExtensionRegistry {
270 pub fn new() -> Self {
272 Self {
273 handlers: HashMap::new(),
274 type_urls: HashMap::new(),
275 }
276 }
277
278 fn register<T>(&mut self, ext_type: ExtensionType) -> Result<(), RegistrationError>
280 where
281 T: Extension,
282 {
283 let canonical_name = T::name();
284 let type_url = T::type_url();
285 let handler: Arc<dyn ExtensionConverter> =
286 Arc::new(ExtensionAdapter::<T>(std::marker::PhantomData));
287
288 let key = (ext_type, canonical_name.to_string());
289 if self.handlers.contains_key(&key) {
290 return Err(RegistrationError::DuplicateName {
291 ext_type,
292 name: canonical_name.to_string(),
293 });
294 }
295
296 let type_url_key = (ext_type, type_url.clone());
298 if let Some(existing) = self.type_urls.get(&type_url_key)
299 && existing != canonical_name
300 {
301 return Err(RegistrationError::ConflictingTypeUrl {
302 type_url,
303 ext_type,
304 existing_name: existing.clone(),
305 });
306 }
307
308 self.handlers.insert(key, Arc::clone(&handler));
310 self.type_urls
311 .insert(type_url_key, canonical_name.to_string());
312 Ok(())
313 }
314
315 pub fn register_relation<T>(&mut self) -> Result<(), RegistrationError>
319 where
320 T: Extension,
321 {
322 self.register::<T>(ExtensionType::Relation)
323 }
324
325 pub fn register_enhancement<T>(&mut self) -> Result<(), RegistrationError>
332 where
333 T: Extension,
334 {
335 self.register::<T>(ExtensionType::Enhancement)
336 }
337
338 pub fn register_optimization<T>(&mut self) -> Result<(), RegistrationError>
345 where
346 T: Extension,
347 {
348 self.register::<T>(ExtensionType::Optimization)
349 }
350
351 pub fn parse_extension(
353 &self,
354 extension_name: &str,
355 args: &ExtensionArgs,
356 ) -> Result<Any, ExtensionError> {
357 self.parse_with_type(ExtensionType::Relation, extension_name, args)
358 }
359
360 pub fn parse_enhancement(
365 &self,
366 enhancement_name: &str,
367 args: &ExtensionArgs,
368 ) -> Result<Any, ExtensionError> {
369 self.parse_with_type(ExtensionType::Enhancement, enhancement_name, args)
370 }
371
372 pub fn parse_optimization(
377 &self,
378 optimization_name: &str,
379 args: &ExtensionArgs,
380 ) -> Result<Any, ExtensionError> {
381 self.parse_with_type(ExtensionType::Optimization, optimization_name, args)
382 }
383
384 fn parse_with_type(
386 &self,
387 ext_type: ExtensionType,
388 name: &str,
389 args: &ExtensionArgs,
390 ) -> Result<Any, ExtensionError> {
391 let key = (ext_type, name.to_string());
392 let handler = self
393 .handlers
394 .get(&key)
395 .ok_or_else(|| ExtensionError::NotFound {
396 name: name.to_string(),
397 })?;
398 handler.parse_detail(args)
399 }
400
401 pub fn decode(&self, detail: AnyRef<'_>) -> Result<(String, ExtensionArgs), ExtensionError> {
405 self.decode_with_type(ExtensionType::Relation, detail)
406 }
407
408 pub fn decode_enhancement(
416 &self,
417 detail: AnyRef<'_>,
418 ) -> Result<(String, ExtensionArgs), ExtensionError> {
419 self.decode_with_type(ExtensionType::Enhancement, detail)
420 }
421
422 pub fn decode_optimization(
430 &self,
431 detail: AnyRef<'_>,
432 ) -> Result<(String, ExtensionArgs), ExtensionError> {
433 self.decode_with_type(ExtensionType::Optimization, detail)
434 }
435
436 fn decode_with_type(
438 &self,
439 ext_type: ExtensionType,
440 detail: AnyRef<'_>,
441 ) -> Result<(String, ExtensionArgs), ExtensionError> {
442 let type_url_key = (ext_type, detail.type_url.to_string());
444 let extension_name =
445 self.type_urls
446 .get(&type_url_key)
447 .ok_or_else(|| ExtensionError::NotFound {
448 name: detail.type_url.to_string(),
449 })?;
450
451 let name_key = (ext_type, extension_name.clone());
453 let handler = self
454 .handlers
455 .get(&name_key)
456 .ok_or_else(|| ExtensionError::NotFound {
457 name: extension_name.clone(),
458 })?;
459
460 let args = handler.textify_detail(detail)?;
461
462 Ok((extension_name.clone(), args))
463 }
464
465 pub fn extension_names(&self, ext_type: ExtensionType) -> Vec<&str> {
467 let mut names: Vec<&str> = self
468 .type_urls
469 .iter()
470 .filter_map(|((t, _), name)| {
471 if *t == ext_type {
472 Some(name.as_str())
473 } else {
474 None
475 }
476 })
477 .collect();
478 names.sort_unstable();
479 names.dedup();
480 names
481 }
482
483 pub fn has_extension(&self, ext_type: ExtensionType, name: &str) -> bool {
485 self.handlers.contains_key(&(ext_type, name.to_string()))
486 }
487}
488
489impl fmt::Debug for ExtensionRegistry {
490 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
491 let mut keys: Vec<_> = self
492 .handlers
493 .keys()
494 .map(|(t, n)| (format!("{t:?}"), n.as_str()))
495 .collect();
496 keys.sort();
497 f.debug_struct("ExtensionRegistry")
498 .field("handlers", &keys)
499 .finish()
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ExtensionColumn, ExtensionRelationType, ExtensionValue};

    /// Relation extension used by most tests. Its `Any` payload is a small
    /// hand-written JSON string rather than a real protobuf message.
    struct TestExtension {
        path: String,
        batch_size: i64,
    }

    impl AnyConvertible for TestExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            let json_str = format!(
                r#"{{"path":"{}","batch_size":{}}}"#,
                self.path, self.batch_size
            );
            Ok(Any::new(Self::type_url(), json_str.into_bytes()))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        // Deliberately minimal: only checks that both keys are present and
        // returns fixed values, so decoded args do NOT echo the parsed input.
        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let json_str = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;

            if json_str.contains("path") && json_str.contains("batch_size") {
                Ok(TestExtension {
                    path: "test.parquet".to_string(),
                    batch_size: 1024,
                })
            } else {
                Err(ExtensionError::Custom("Missing fields".to_string()))
            }
        }
    }

    impl Explainable for TestExtension {
        fn name() -> &'static str {
            "TestExtension"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let path = extractor.expect_named_arg::<&str>("path")?.to_string();
            let batch_size: i64 = extractor.expect_named_arg("batch_size")?;
            extractor.check_exhausted()?;

            Ok(TestExtension { path, batch_size })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named.insert(
                "path".to_string(),
                ExtensionValue::String(self.path.clone()),
            );
            args.named.insert(
                "batch_size".to_string(),
                ExtensionValue::Integer(self.batch_size),
            );
            Ok(args)
        }
    }

    #[test]
    fn test_extension_registry_basic() {
        let mut registry = ExtensionRegistry::new();

        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 0);
        assert!(!registry.has_extension(ExtensionType::Relation, "TestExtension"));

        registry.register_relation::<TestExtension>().unwrap();

        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));

        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data.parquet".to_string()),
        );
        args.named
            .insert("batch_size".to_string(), ExtensionValue::Integer(2048));

        let any = registry.parse_extension("TestExtension", &args).unwrap();
        assert_eq!(any.type_url, "test.TestExtension");

        let any_ref = any.as_ref();
        let result = registry.decode(any_ref).unwrap();
        assert_eq!(result.0, "TestExtension");
        match result.1.named.get("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "test.parquet"),
            _ => panic!("Expected String for path"),
        }
        match result.1.named.get("batch_size") {
            Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 1024),
            _ => panic!("Expected Integer for batch_size"),
        }
    }

    #[test]
    fn test_extension_args() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);

        args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data/*.parquet".to_string()),
        );
        args.named
            .insert("batch_size".to_string(), ExtensionValue::Integer(1024));

        args.positional.push(ExtensionValue::Reference(0));

        args.output_columns.push(ExtensionColumn::Named {
            name: "col1".to_string(),
            type_spec: "i32".to_string(),
        });

        let mut extractor = args.extractor();

        match extractor.get_named_arg("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "data/*.parquet"),
            _ => panic!("Expected String for path"),
        }

        match extractor.get_named_arg("batch_size") {
            Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 1024),
            _ => panic!("Expected Integer for batch_size"),
        }

        assert!(extractor.check_exhausted().is_ok());

        assert_eq!(args.positional.len(), 1);
        assert_eq!(args.output_columns.len(), 1);
    }

    #[test]
    fn test_extension_error_cases() {
        let registry = ExtensionRegistry::new();

        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let result = registry.parse_extension("NonExistent", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));

        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("missing");
        assert!(result.is_none());
        assert!(extractor.check_exhausted().is_ok());

        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named
            .insert("test".to_string(), ExtensionValue::Integer(42));
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("test");
        match result {
            Some(ExtensionValue::Integer(42)) => {}
            _ => panic!("Expected Integer(42), got {result:?}"),
        }
        assert!(extractor.check_exhausted().is_ok());
    }

    /// Enhancement that deliberately reuses `TestExtension`'s type URL, to show
    /// that type URLs are scoped per `ExtensionType`.
    struct TestEnhancement {
        hint: String,
    }

    impl AnyConvertible for TestEnhancement {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            let json_str = format!(r#"{{"hint":"{}"}}"#, self.hint);
            Ok(Any::new(Self::type_url(), json_str.into_bytes()))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let json_str = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
            if json_str.contains("hint") {
                Ok(TestEnhancement {
                    hint: "test_hint".to_string(),
                })
            } else {
                Err(ExtensionError::Custom("Missing hint field".to_string()))
            }
        }
    }

    impl Explainable for TestEnhancement {
        fn name() -> &'static str {
            "TestEnhancement"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_string();
            extractor.check_exhausted()?;
            Ok(TestEnhancement { hint })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named.insert(
                "hint".to_string(),
                ExtensionValue::String(self.hint.clone()),
            );
            Ok(args)
        }
    }

    #[test]
    fn test_namespace_separation() {
        let mut registry = ExtensionRegistry::new();

        registry.register_relation::<TestExtension>().unwrap();
        registry.register_enhancement::<TestEnhancement>().unwrap();

        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert_eq!(
            registry.extension_names(ExtensionType::Enhancement).len(),
            1
        );

        let mut ext_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        ext_args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data.parquet".to_string()),
        );
        ext_args
            .named
            .insert("batch_size".to_string(), ExtensionValue::Integer(2048));

        let ext_any = registry
            .parse_extension("TestExtension", &ext_args)
            .unwrap();
        assert_eq!(ext_any.type_url, "test.TestExtension");

        let mut enh_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        enh_args.named.insert(
            "hint".to_string(),
            ExtensionValue::String("optimize".to_string()),
        );

        let enh_any = registry
            .parse_enhancement("TestEnhancement", &enh_args)
            .unwrap();
        assert_eq!(enh_any.type_url, "test.TestExtension");

        let enh_ref = enh_any.as_ref();
        let (name, args) = registry.decode_enhancement(enh_ref).unwrap();
        assert_eq!(name, "TestEnhancement");
        match args.named.get("hint") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "test_hint"),
            _ => panic!("Expected String for hint"),
        }
    }

    #[test]
    fn test_enhancement_duplicate_registration_returns_error() {
        let mut registry = ExtensionRegistry::new();
        registry.register_enhancement::<TestEnhancement>().unwrap();
        let result = registry.register_enhancement::<TestEnhancement>();
        assert!(matches!(
            result,
            Err(RegistrationError::DuplicateName { .. })
        ));
    }

    #[test]
    fn test_relation_duplicate_registration_returns_error() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();
        let result = registry.register_relation::<TestExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::DuplicateName { .. })
        ));
        assert_eq!(
            registry.extension_names(ExtensionType::Relation),
            vec!["TestExtension"]
        );
    }

    #[test]
    fn test_enhancement_not_found_error() {
        let registry = ExtensionRegistry::new();
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let result = registry.parse_enhancement("NonExistentEnhancement", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    #[test]
    fn test_decode_unknown_type_url_returns_not_found() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        let unknown = Any::new("test.Unknown".to_string(), vec![]);
        let result = registry.decode(unknown.as_ref());
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    #[test]
    fn test_decode_is_scoped_to_namespace() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        let any = TestExtension {
            path: "data.parquet".to_string(),
            batch_size: 1,
        }
        .to_any()
        .unwrap();

        // Registered only as a relation, so the optimization namespace must not
        // resolve this type URL.
        let result = registry.decode_optimization(any.as_ref());
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    /// Relation that reuses `TestExtension`'s type URL under a different name,
    /// which must be rejected within the same namespace.
    struct ConflictingExtension;

    impl AnyConvertible for ConflictingExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            Ok(Any::new(Self::type_url(), vec![]))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(_any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }
    }

    impl Explainable for ConflictingExtension {
        fn name() -> &'static str {
            "ConflictingExtension"
        }

        fn from_args(_args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            Ok(ExtensionArgs::new(ExtensionRelationType::Leaf))
        }
    }

    #[test]
    fn test_conflicting_type_url_leaves_registry_unchanged() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        let result = registry.register_relation::<ConflictingExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::ConflictingTypeUrl { .. })
        ));

        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(!registry.has_extension(ExtensionType::Relation, "ConflictingExtension"));
        assert_eq!(
            registry.extension_names(ExtensionType::Relation),
            vec!["TestExtension"]
        );
    }
}