1use std::collections::HashMap;
86use std::fmt;
87use std::sync::Arc;
88
89use thiserror::Error;
90
91use crate::extensions::any::{Any, AnyRef};
92use crate::extensions::args::ExtensionArgs;
93
/// The kind of extension a registry entry belongs to.
///
/// The kind is part of every registry key, so names and type URLs are tracked
/// separately for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtensionType {
    /// Extensions registered through `register_relation`.
    Relation,
    /// Extensions registered through `register_enhancement`.
    Enhancement,
    /// Extensions registered through `register_optimization`.
    Optimization,
}
104
/// Errors returned when registering an extension fails.
#[derive(Debug, Error, Clone)]
pub enum RegistrationError {
    /// An extension with this name is already registered for the given extension type.
    #[error("{ext_type:?} extension '{name}' already registered")]
    DuplicateName {
        /// Extension type under which the name is already taken.
        ext_type: ExtensionType,
        /// The name that was registered twice.
        name: String,
    },

    /// The type URL is already mapped to a differently named extension of the same type.
    #[error("Type URL '{type_url}' already registered to {ext_type:?} extension '{existing_name}'")]
    ConflictingTypeUrl {
        /// The type URL that is already in use.
        type_url: String,
        /// Extension type under which the conflict occurred.
        ext_type: ExtensionType,
        /// Name of the extension that currently owns the type URL.
        existing_name: String,
    },
}
121
/// Errors produced while looking up, parsing, or converting extensions.
#[derive(Debug, Error, Clone)]
pub enum ExtensionError {
    /// Nothing is registered under the requested key. The registry fills `name`
    /// with the extension name or the type URL, whichever was looked up.
    #[error("Extension '{name}' not found in registry")]
    NotFound { name: String },

    /// A required argument called `name` was not supplied.
    #[error("Missing required argument: {name}")]
    MissingArgument { name: String },

    /// An argument was rejected; the payload is the explanation shown in the message.
    #[error("Invalid argument: {0}")]
    InvalidArgument(String),

    /// The type URL encountered (`actual`) differs from the one required (`expected`).
    #[error("Type URL mismatch: expected {expected}, got {actual}")]
    TypeUrlMismatch { expected: String, actual: String },

    /// Decoding a protobuf message failed; the `prost` error is exposed as the source.
    #[error("Failed to decode protobuf message")]
    DecodeFailed(#[source] prost::DecodeError),

    /// Encoding a protobuf message failed; the `prost` error is exposed as the source.
    #[error("Failed to encode protobuf message")]
    EncodeFailed(#[source] prost::EncodeError),

    /// An extension detail was expected but absent.
    #[error("Extension detail is missing")]
    MissingDetail,

    /// Free-form error whose display text is exactly the contained string.
    #[error("{0}")]
    Custom(String),
}
157
/// Conversion between a Rust type and the crate's `Any` / `AnyRef` representation.
pub trait AnyConvertible: Sized {
    /// Serializes `self` into an owned `Any`.
    fn to_any(&self) -> Result<Any, ExtensionError>;

    /// Reconstructs a value from a borrowed `AnyRef`.
    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError>;

    /// Returns the type URL identifying this type. The registry uses it to
    /// select the handler when decoding a detail.
    fn type_url() -> String;
}
173
174impl<T> AnyConvertible for T
176where
177 T: prost::Message + prost::Name + Default,
178{
179 fn to_any(&self) -> Result<Any, ExtensionError> {
180 Any::encode(self)
181 }
182
183 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
184 any.decode()
185 }
186
187 fn type_url() -> String {
188 T::type_url()
189 }
190}
191
/// Conversion between a Rust type and its `ExtensionArgs` form.
pub trait Explainable: Sized {
    /// Name under which the type is registered and looked up in the registry.
    fn name() -> &'static str;

    /// Builds a value from `args`.
    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError>;

    /// Produces the `ExtensionArgs` describing `self`.
    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError>;
}
204
/// Type-erased bridge between `ExtensionArgs` and `Any`, stored once per
/// registered extension.
trait ExtensionConverter: Send + Sync {
    /// Converts arguments into an `Any` detail.
    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError>;

    /// Converts an `Any` detail back into arguments.
    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError>;
}
223
/// Adapter carrying only a `PhantomData<T>`; it exists so that `T`'s static
/// trait methods can be reached through a `dyn ExtensionConverter`.
struct ExtensionAdapter<T>(std::marker::PhantomData<T>);
241
242impl<T: AnyConvertible + Explainable + Send + Sync> ExtensionConverter for ExtensionAdapter<T> {
243 fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError> {
244 T::from_args(args)?.to_any()
246 }
247
248 fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError> {
249 let owned_any = Any::new(detail.type_url.to_string(), detail.value.to_vec());
252 T::from_any(owned_any.as_ref())?.to_args()
253 }
254}
255
/// Marker trait for types that can be registered: they must be
/// `AnyConvertible`, `Explainable`, thread-safe, and `'static`.
pub trait Extension: AnyConvertible + Explainable + Send + Sync + 'static {}

/// Every type meeting the bounds is automatically an `Extension`.
impl<T> Extension for T where T: AnyConvertible + Explainable + Send + Sync + 'static {}
259
/// Registry of extension handlers, keyed by extension type and name.
#[derive(Default, Clone)]
pub struct ExtensionRegistry {
    /// Handler for each registered `(extension type, extension name)` pair.
    handlers: HashMap<(ExtensionType, String), Arc<dyn ExtensionConverter>>,
    /// Maps `(extension type, type URL)` to the name of the extension registered for it.
    type_urls: HashMap<(ExtensionType, String), String>,
    /// Descriptor byte blobs added through `add_descriptor`, stored unmodified.
    descriptors: Vec<Vec<u8>>,
}
274
275impl ExtensionRegistry {
276 pub fn new() -> Self {
278 Self {
279 handlers: HashMap::new(),
280 type_urls: HashMap::new(),
281 descriptors: Vec::new(),
282 }
283 }
284
285 pub fn add_descriptor(&mut self, bytes: Vec<u8>) {
293 self.descriptors.push(bytes);
294 }
295
296 pub fn descriptors(&self) -> Vec<&[u8]> {
298 self.descriptors.iter().map(|b| b.as_slice()).collect()
299 }
300
301 fn register<T>(&mut self, ext_type: ExtensionType) -> Result<(), RegistrationError>
303 where
304 T: Extension,
305 {
306 let canonical_name = T::name();
307 let type_url = T::type_url();
308 let handler: Arc<dyn ExtensionConverter> =
309 Arc::new(ExtensionAdapter::<T>(std::marker::PhantomData));
310
311 let key = (ext_type, canonical_name.to_string());
312 if self.handlers.contains_key(&key) {
313 return Err(RegistrationError::DuplicateName {
314 ext_type,
315 name: canonical_name.to_string(),
316 });
317 }
318
319 let type_url_key = (ext_type, type_url.clone());
321 if let Some(existing) = self.type_urls.get(&type_url_key)
322 && existing != canonical_name
323 {
324 return Err(RegistrationError::ConflictingTypeUrl {
325 type_url,
326 ext_type,
327 existing_name: existing.clone(),
328 });
329 }
330
331 self.handlers.insert(key, Arc::clone(&handler));
333 self.type_urls
334 .insert(type_url_key, canonical_name.to_string());
335 Ok(())
336 }
337
338 pub fn register_relation<T>(&mut self) -> Result<(), RegistrationError>
342 where
343 T: Extension,
344 {
345 self.register::<T>(ExtensionType::Relation)
346 }
347
348 pub fn register_enhancement<T>(&mut self) -> Result<(), RegistrationError>
355 where
356 T: Extension,
357 {
358 self.register::<T>(ExtensionType::Enhancement)
359 }
360
361 pub fn register_optimization<T>(&mut self) -> Result<(), RegistrationError>
368 where
369 T: Extension,
370 {
371 self.register::<T>(ExtensionType::Optimization)
372 }
373
374 pub fn parse_extension(
376 &self,
377 extension_name: &str,
378 args: &ExtensionArgs,
379 ) -> Result<Any, ExtensionError> {
380 self.parse_with_type(ExtensionType::Relation, extension_name, args)
381 }
382
383 pub fn parse_enhancement(
388 &self,
389 enhancement_name: &str,
390 args: &ExtensionArgs,
391 ) -> Result<Any, ExtensionError> {
392 self.parse_with_type(ExtensionType::Enhancement, enhancement_name, args)
393 }
394
395 pub fn parse_optimization(
400 &self,
401 optimization_name: &str,
402 args: &ExtensionArgs,
403 ) -> Result<Any, ExtensionError> {
404 self.parse_with_type(ExtensionType::Optimization, optimization_name, args)
405 }
406
407 fn parse_with_type(
409 &self,
410 ext_type: ExtensionType,
411 name: &str,
412 args: &ExtensionArgs,
413 ) -> Result<Any, ExtensionError> {
414 let key = (ext_type, name.to_string());
415 let handler = self
416 .handlers
417 .get(&key)
418 .ok_or_else(|| ExtensionError::NotFound {
419 name: name.to_string(),
420 })?;
421 handler.parse_detail(args)
422 }
423
424 pub fn decode(&self, detail: AnyRef<'_>) -> Result<(String, ExtensionArgs), ExtensionError> {
428 self.decode_with_type(ExtensionType::Relation, detail)
429 }
430
431 pub fn decode_enhancement(
439 &self,
440 detail: AnyRef<'_>,
441 ) -> Result<(String, ExtensionArgs), ExtensionError> {
442 self.decode_with_type(ExtensionType::Enhancement, detail)
443 }
444
445 pub fn decode_optimization(
453 &self,
454 detail: AnyRef<'_>,
455 ) -> Result<(String, ExtensionArgs), ExtensionError> {
456 self.decode_with_type(ExtensionType::Optimization, detail)
457 }
458
459 fn decode_with_type(
461 &self,
462 ext_type: ExtensionType,
463 detail: AnyRef<'_>,
464 ) -> Result<(String, ExtensionArgs), ExtensionError> {
465 let type_url_key = (ext_type, detail.type_url.to_string());
467 let extension_name =
468 self.type_urls
469 .get(&type_url_key)
470 .ok_or_else(|| ExtensionError::NotFound {
471 name: detail.type_url.to_string(),
472 })?;
473
474 let name_key = (ext_type, extension_name.clone());
476 let handler = self
477 .handlers
478 .get(&name_key)
479 .ok_or_else(|| ExtensionError::NotFound {
480 name: extension_name.clone(),
481 })?;
482
483 let args = handler.textify_detail(detail)?;
484
485 Ok((extension_name.clone(), args))
486 }
487
488 pub fn extension_names(&self, ext_type: ExtensionType) -> Vec<&str> {
490 let mut names: Vec<&str> = self
491 .type_urls
492 .iter()
493 .filter_map(|((t, _), name)| {
494 if *t == ext_type {
495 Some(name.as_str())
496 } else {
497 None
498 }
499 })
500 .collect();
501 names.sort_unstable();
502 names.dedup();
503 names
504 }
505
506 pub fn has_extension(&self, ext_type: ExtensionType, name: &str) -> bool {
508 self.handlers.contains_key(&(ext_type, name.to_string()))
509 }
510}
511
512impl fmt::Debug for ExtensionRegistry {
513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514 let mut keys: Vec<_> = self
515 .handlers
516 .keys()
517 .map(|(t, n)| (format!("{t:?}"), n.as_str()))
518 .collect();
519 keys.sort();
520 f.debug_struct("ExtensionRegistry")
521 .field("handlers", &keys)
522 .finish()
523 }
524}
525
526#[cfg(test)]
527mod tests {
528 use super::*;
529 use crate::extensions::{ExtensionColumn, ExtensionRelationType, ExtensionValue};
530
    /// Test fixture with a path and a batch size, used by the tests in this
    /// module as a relation extension.
    struct TestExtension {
        path: String,
        batch_size: i64,
    }
536
537 impl AnyConvertible for TestExtension {
539 fn to_any(&self) -> Result<Any, ExtensionError> {
540 let json_str = format!(
542 r#"{{"path":"{}","batch_size":{}}}"#,
543 self.path, self.batch_size
544 );
545 Ok(Any::new(Self::type_url(), json_str.into_bytes()))
546 }
547
548 fn type_url() -> String {
549 "test.TestExtension".to_string()
550 }
551
552 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
553 let json_str = String::from_utf8(any.value.to_vec())
555 .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
556
557 if json_str.contains("path") && json_str.contains("batch_size") {
559 Ok(TestExtension {
560 path: "test.parquet".to_string(),
561 batch_size: 1024,
562 })
563 } else {
564 Err(ExtensionError::Custom("Missing fields".to_string()))
565 }
566 }
567 }
568
569 impl Explainable for TestExtension {
570 fn name() -> &'static str {
571 "TestExtension"
572 }
573
574 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
575 let mut extractor = args.extractor();
576 let path: String = extractor.expect_named_arg::<&str>("path")?.to_string();
577 let batch_size: i64 = extractor.expect_named_arg("batch_size")?;
578 extractor.check_exhausted()?;
579
580 Ok(TestExtension {
581 path: path.to_string(),
582 batch_size,
583 })
584 }
585
586 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
587 let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
588 args.named.insert(
589 "path".to_string(),
590 ExtensionValue::String(self.path.clone()),
591 );
592 args.named.insert(
593 "batch_size".to_string(),
594 ExtensionValue::Integer(self.batch_size),
595 );
596 Ok(args)
597 }
598 }
599
600 #[test]
601 fn test_extension_registry_basic() {
602 let mut registry = ExtensionRegistry::new();
603
604 assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 0);
606 assert!(!registry.has_extension(ExtensionType::Relation, "TestExtension"));
607
608 registry.register_relation::<TestExtension>().unwrap();
610
611 assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
613 assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
614
615 let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
617 args.named.insert(
618 "path".to_string(),
619 ExtensionValue::String("data.parquet".to_string()),
620 );
621 args.named
622 .insert("batch_size".to_string(), ExtensionValue::Integer(2048));
623
624 let any = registry.parse_extension("TestExtension", &args).unwrap();
625 assert_eq!(any.type_url, "test.TestExtension");
626
627 let any_ref = any.as_ref();
628 let result = registry.decode(any_ref).unwrap();
629 assert_eq!(result.0, "TestExtension");
630 match result.1.named.get("path") {
631 Some(ExtensionValue::String(s)) => assert_eq!(s, "test.parquet"), _ => panic!("Expected String for path"),
633 }
634 }
635
636 #[test]
637 fn test_extension_args() {
638 let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
639
640 args.named.insert(
642 "path".to_string(),
643 ExtensionValue::String("data/*.parquet".to_string()),
644 );
645 args.named
646 .insert("batch_size".to_string(), ExtensionValue::Integer(1024));
647
648 args.positional.push(ExtensionValue::Reference(0));
650
651 args.output_columns.push(ExtensionColumn::Named {
653 name: "col1".to_string(),
654 type_spec: "i32".to_string(),
655 });
656
657 let mut extractor = args.extractor();
659
660 match extractor.get_named_arg("path") {
661 Some(ExtensionValue::String(s)) => assert_eq!(s, "data/*.parquet"),
662 _ => panic!("Expected String for path"),
663 }
664
665 match extractor.get_named_arg("batch_size") {
666 Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 1024),
667 _ => panic!("Expected Integer for batch_size"),
668 }
669
670 assert!(extractor.check_exhausted().is_ok());
672
673 assert_eq!(args.positional.len(), 1);
674 assert_eq!(args.output_columns.len(), 1);
675 }
676
677 #[test]
678 fn test_extension_error_cases() {
679 let registry = ExtensionRegistry::new();
680
681 let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
683 let result = registry.parse_extension("NonExistent", &args);
684 assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
685
686 let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
688 let mut extractor = args.extractor();
689 let result = extractor.get_named_arg("missing");
690 assert!(result.is_none());
691 assert!(extractor.check_exhausted().is_ok());
692
693 let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
695 args.named
696 .insert("test".to_string(), ExtensionValue::Integer(42));
697 let mut extractor = args.extractor();
698 let result = extractor.get_named_arg("test");
699 match result {
700 Some(ExtensionValue::Integer(42)) => {} _ => panic!("Expected Integer(42), got {result:?}"),
702 }
703 assert!(extractor.check_exhausted().is_ok());
704 }
705
    /// Test fixture with a single hint string, used by the tests in this
    /// module as an enhancement.
    struct TestEnhancement {
        hint: String,
    }
710
711 impl AnyConvertible for TestEnhancement {
712 fn to_any(&self) -> Result<Any, ExtensionError> {
713 let json_str = format!(r#"{{"hint":"{}"}}"#, self.hint);
714 Ok(Any::new(Self::type_url(), json_str.into_bytes()))
715 }
716
717 fn type_url() -> String {
718 "test.TestExtension".to_string()
720 }
721
722 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
723 let json_str = String::from_utf8(any.value.to_vec())
724 .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
725 if json_str.contains("hint") {
726 Ok(TestEnhancement {
727 hint: "test_hint".to_string(),
728 })
729 } else {
730 Err(ExtensionError::Custom("Missing hint field".to_string()))
731 }
732 }
733 }
734
735 impl Explainable for TestEnhancement {
736 fn name() -> &'static str {
737 "TestEnhancement"
738 }
739
740 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
741 let mut extractor = args.extractor();
742 let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_string();
743 extractor.check_exhausted()?;
744 Ok(TestEnhancement { hint })
745 }
746
747 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
748 let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
749 args.named.insert(
750 "hint".to_string(),
751 ExtensionValue::String(self.hint.clone()),
752 );
753 Ok(args)
754 }
755 }
756
757 #[test]
758 fn test_namespace_separation() {
759 let mut registry = ExtensionRegistry::new();
760
761 registry.register_relation::<TestExtension>().unwrap();
763 registry.register_enhancement::<TestEnhancement>().unwrap();
764
765 assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
767 assert!(registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
768 assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
769 assert_eq!(
770 registry.extension_names(ExtensionType::Enhancement).len(),
771 1
772 );
773
774 let mut ext_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
776 ext_args.named.insert(
777 "path".to_string(),
778 ExtensionValue::String("data.parquet".to_string()),
779 );
780 ext_args
781 .named
782 .insert("batch_size".to_string(), ExtensionValue::Integer(2048));
783
784 let ext_any = registry
785 .parse_extension("TestExtension", &ext_args)
786 .unwrap();
787 assert_eq!(ext_any.type_url, "test.TestExtension");
788
789 let mut enh_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
791 enh_args.named.insert(
792 "hint".to_string(),
793 ExtensionValue::String("optimize".to_string()),
794 );
795
796 let enh_any = registry
797 .parse_enhancement("TestEnhancement", &enh_args)
798 .unwrap();
799 assert_eq!(enh_any.type_url, "test.TestExtension"); let enh_ref = enh_any.as_ref();
803 let (name, args) = registry.decode_enhancement(enh_ref).unwrap();
804 assert_eq!(name, "TestEnhancement");
805 match args.named.get("hint") {
806 Some(ExtensionValue::String(s)) => assert_eq!(s, "test_hint"), _ => panic!("Expected String for hint"),
808 }
809 }
810
811 #[test]
812 fn test_enhancement_duplicate_registration_returns_error() {
813 let mut registry = ExtensionRegistry::new();
814 registry.register_enhancement::<TestEnhancement>().unwrap();
815 let result = registry.register_enhancement::<TestEnhancement>();
816 assert!(matches!(
817 result,
818 Err(RegistrationError::DuplicateName { .. })
819 ));
820 }
821
822 #[test]
823 fn test_enhancement_not_found_error() {
824 let registry = ExtensionRegistry::new();
825 let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
826 let result = registry.parse_enhancement("NonExistentEnhancement", &args);
827 assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
828 }
829
    /// Field-less test fixture whose type URL equals `TestExtension`'s while
    /// its name differs; used to provoke a type-URL conflict.
    struct ConflictingExtension;
833
834 impl AnyConvertible for ConflictingExtension {
835 fn to_any(&self) -> Result<Any, ExtensionError> {
836 Ok(Any::new(Self::type_url(), vec![]))
837 }
838
839 fn type_url() -> String {
840 "test.TestExtension".to_string()
842 }
843
844 fn from_any<'a>(_any: AnyRef<'a>) -> Result<Self, ExtensionError> {
845 Ok(ConflictingExtension)
846 }
847 }
848
849 impl Explainable for ConflictingExtension {
850 fn name() -> &'static str {
851 "ConflictingExtension"
852 }
853
854 fn from_args(_args: &ExtensionArgs) -> Result<Self, ExtensionError> {
855 Ok(ConflictingExtension)
856 }
857
858 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
859 Ok(ExtensionArgs::new(ExtensionRelationType::Leaf))
860 }
861 }
862
863 #[test]
864 fn test_conflicting_type_url_leaves_registry_unchanged() {
865 let mut registry = ExtensionRegistry::new();
866 registry.register_relation::<TestExtension>().unwrap();
867
868 let result = registry.register_relation::<ConflictingExtension>();
870 assert!(matches!(
871 result,
872 Err(RegistrationError::ConflictingTypeUrl { .. })
873 ));
874
875 assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
877 assert!(!registry.has_extension(ExtensionType::Relation, "ConflictingExtension"));
878 assert_eq!(
879 registry.extension_names(ExtensionType::Relation),
880 vec!["TestExtension"]
881 );
882 }
883}