1use std::collections::HashMap;
84use std::fmt;
85use std::sync::Arc;
86
87use substrait::proto::NamedStruct;
88use substrait::proto::r#type::{Nullability, Struct};
89use thiserror::Error;
90
91use crate::extensions::any::{Any, AnyRef};
92use crate::extensions::args::{ExtensionArgs, ExtensionColumn, ExtensionValueKind};
93
/// Namespace an extension is registered under.
///
/// Each variant is a separate namespace: names and type URLs only collide with
/// other registrations of the same `ExtensionType`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExtensionType {
    /// Relation extensions.
    Relation,
    /// Extension-table extensions.
    ExtensionTable,
    /// Enhancement extensions.
    Enhancement,
    /// Optimization extensions.
    Optimization,
}
106
107#[derive(Debug, Error, Clone)]
109pub enum RegistrationError {
110 #[error("{ext_type:?} extension '{name}' already registered")]
111 DuplicateName {
112 ext_type: ExtensionType,
113 name: String,
114 },
115
116 #[error("Type URL '{type_url}' already registered to {ext_type:?} extension '{existing_name}'")]
117 ConflictingTypeUrl {
118 type_url: String,
119 ext_type: ExtensionType,
120 existing_name: String,
121 },
122}
123
124#[derive(Debug, Error, Clone)]
126pub enum ExtensionError {
127 #[error("Extension '{name}' not found in registry")]
129 NotFound { name: String },
130
131 #[error("Missing required argument: {name}")]
133 MissingArgument { name: String },
134
135 #[error("Invalid argument: expected {expected}, got {actual}")]
137 InvalidArgumentType {
138 expected: ExtensionValueKind,
139 actual: ExtensionValueKind,
140 },
141
142 #[error("Invalid argument: {0}")]
148 InvalidArgument(String),
149
150 #[error("Type URL mismatch: expected {expected}, got {actual}")]
152 TypeUrlMismatch { expected: String, actual: String },
153
154 #[error("Failed to decode protobuf message")]
156 DecodeFailed(#[source] prost::DecodeError),
157
158 #[error("Failed to encode protobuf message")]
160 EncodeFailed(#[source] prost::EncodeError),
161
162 #[error("Extension detail is missing")]
164 MissingDetail,
165
166 #[error("{0}")]
168 Custom(String),
169}
170
171pub trait AnyConvertible: Sized {
175 fn to_any(&self) -> Result<Any, ExtensionError>;
177
178 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError>;
180
181 fn type_url() -> String;
185}
186
187impl<T> AnyConvertible for T
189where
190 T: prost::Message + prost::Name + Default,
191{
192 fn to_any(&self) -> Result<Any, ExtensionError> {
193 Any::encode(self)
194 }
195
196 fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
197 any.decode()
198 }
199
200 fn type_url() -> String {
201 T::type_url()
202 }
203}
204
205pub trait ExtensionProtoConvert<T> {
216 fn convert(&self) -> Result<T, ExtensionError>;
218}
219
220impl ExtensionProtoConvert<NamedStruct> for [ExtensionColumn] {
221 fn convert(&self) -> Result<NamedStruct, ExtensionError> {
222 let mut names = Vec::with_capacity(self.len());
223 let mut types = Vec::with_capacity(self.len());
224 for col in self {
225 match col {
226 ExtensionColumn::Named { name, r#type: ty } => {
227 names.push(name.clone());
228 types.push(ty.clone());
229 }
230 other => {
231 return Err(ExtensionError::InvalidArgument(format!(
232 "Expected named column, got {other:?}"
233 )));
234 }
235 }
236 }
237 Ok(NamedStruct {
238 names,
239 r#struct: Some(Struct {
240 types,
241 type_variation_reference: 0,
242 nullability: Nullability::Required as i32,
246 }),
247 })
248 }
249}
250
251impl ExtensionProtoConvert<Vec<ExtensionColumn>> for NamedStruct {
252 fn convert(&self) -> Result<Vec<ExtensionColumn>, ExtensionError> {
253 let types = self
254 .r#struct
255 .as_ref()
256 .map(|s| s.types.as_slice())
257 .unwrap_or_default();
258 if self.names.len() != types.len() {
259 return Err(ExtensionError::InvalidArgument(format!(
260 "NamedStruct has {} names but {} types",
261 self.names.len(),
262 types.len()
263 )));
264 }
265 Ok(self
266 .names
267 .iter()
268 .zip(types.iter())
269 .map(|(name, ty)| ExtensionColumn::Named {
270 name: name.clone(),
271 r#type: ty.clone(),
272 })
273 .collect())
274 }
275}
276
277pub trait Explainable: Sized {
279 fn name() -> &'static str;
282
283 fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError>;
285
286 fn to_args(&self) -> Result<ExtensionArgs, ExtensionError>;
288}
289
290trait ExtensionConverter: Send + Sync {
304 fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError>;
305
306 fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError>;
307}
308
309struct ExtensionAdapter<T>(std::marker::PhantomData<T>);
326
327impl<T: AnyConvertible + Explainable + Send + Sync> ExtensionConverter for ExtensionAdapter<T> {
328 fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError> {
329 T::from_args(args)?.to_any()
330 }
331
332 fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError> {
333 let owned_any = Any::new(detail.type_url.to_string(), detail.value.to_vec());
334 T::from_any(owned_any.as_ref())?.to_args()
335 }
336}
337
338pub trait Extension: AnyConvertible + Explainable + Send + Sync + 'static {}
339
340impl<T> Extension for T where T: AnyConvertible + Explainable + Send + Sync + 'static {}
341
342#[derive(Default, Clone)]
344pub struct ExtensionRegistry {
345 handlers: HashMap<(ExtensionType, String), Arc<dyn ExtensionConverter>>,
347 type_urls: HashMap<(ExtensionType, String), String>,
349 descriptors: Vec<Vec<u8>>,
355}
356
357impl ExtensionRegistry {
358 pub fn new() -> Self {
360 Self {
361 handlers: HashMap::new(),
362 type_urls: HashMap::new(),
363 descriptors: Vec::new(),
364 }
365 }
366
367 pub fn add_descriptor(&mut self, bytes: Vec<u8>) {
375 self.descriptors.push(bytes);
376 }
377
378 pub fn descriptors(&self) -> Vec<&[u8]> {
380 self.descriptors.iter().map(|b| b.as_slice()).collect()
381 }
382
383 fn register<T>(&mut self, ext_type: ExtensionType) -> Result<(), RegistrationError>
385 where
386 T: Extension,
387 {
388 let canonical_name = T::name();
389 let type_url = T::type_url();
390 let handler: Arc<dyn ExtensionConverter> =
391 Arc::new(ExtensionAdapter::<T>(std::marker::PhantomData));
392
393 let key = (ext_type, canonical_name.to_string());
394 if self.handlers.contains_key(&key) {
395 return Err(RegistrationError::DuplicateName {
396 ext_type,
397 name: canonical_name.to_string(),
398 });
399 }
400
401 let type_url_key = (ext_type, type_url.clone());
403 if let Some(existing) = self.type_urls.get(&type_url_key)
404 && existing != canonical_name
405 {
406 return Err(RegistrationError::ConflictingTypeUrl {
407 type_url,
408 ext_type,
409 existing_name: existing.clone(),
410 });
411 }
412
413 self.handlers.insert(key, Arc::clone(&handler));
415 self.type_urls
416 .insert(type_url_key, canonical_name.to_string());
417 Ok(())
418 }
419
420 pub fn register_relation<T>(&mut self) -> Result<(), RegistrationError>
424 where
425 T: Extension,
426 {
427 self.register::<T>(ExtensionType::Relation)
428 }
429
430 pub fn register_extension_table<T>(&mut self) -> Result<(), RegistrationError>
438 where
439 T: Extension,
440 {
441 self.register::<T>(ExtensionType::ExtensionTable)
442 }
443
444 pub fn register_enhancement<T>(&mut self) -> Result<(), RegistrationError>
451 where
452 T: Extension,
453 {
454 self.register::<T>(ExtensionType::Enhancement)
455 }
456
457 pub fn register_optimization<T>(&mut self) -> Result<(), RegistrationError>
464 where
465 T: Extension,
466 {
467 self.register::<T>(ExtensionType::Optimization)
468 }
469
470 pub fn parse_extension(
472 &self,
473 extension_name: &str,
474 args: &ExtensionArgs,
475 ) -> Result<Any, ExtensionError> {
476 self.parse_with_type(ExtensionType::Relation, extension_name, args)
477 }
478
479 pub fn parse_extension_table(
484 &self,
485 extension_table_name: &str,
486 args: &ExtensionArgs,
487 ) -> Result<Any, ExtensionError> {
488 self.parse_with_type(ExtensionType::ExtensionTable, extension_table_name, args)
489 }
490
491 pub fn parse_enhancement(
496 &self,
497 enhancement_name: &str,
498 args: &ExtensionArgs,
499 ) -> Result<Any, ExtensionError> {
500 self.parse_with_type(ExtensionType::Enhancement, enhancement_name, args)
501 }
502
503 pub fn parse_optimization(
508 &self,
509 optimization_name: &str,
510 args: &ExtensionArgs,
511 ) -> Result<Any, ExtensionError> {
512 self.parse_with_type(ExtensionType::Optimization, optimization_name, args)
513 }
514
515 fn parse_with_type(
517 &self,
518 ext_type: ExtensionType,
519 name: &str,
520 args: &ExtensionArgs,
521 ) -> Result<Any, ExtensionError> {
522 let key = (ext_type, name.to_string());
523 let handler = self
524 .handlers
525 .get(&key)
526 .ok_or_else(|| ExtensionError::NotFound {
527 name: name.to_string(),
528 })?;
529 handler.parse_detail(args)
530 }
531
532 pub fn decode(&self, detail: AnyRef<'_>) -> Result<(String, ExtensionArgs), ExtensionError> {
536 self.decode_with_type(ExtensionType::Relation, detail)
537 }
538
539 pub fn decode_extension_table(
545 &self,
546 detail: AnyRef<'_>,
547 ) -> Result<(String, ExtensionArgs), ExtensionError> {
548 self.decode_with_type(ExtensionType::ExtensionTable, detail)
549 }
550
551 pub fn decode_enhancement(
559 &self,
560 detail: AnyRef<'_>,
561 ) -> Result<(String, ExtensionArgs), ExtensionError> {
562 self.decode_with_type(ExtensionType::Enhancement, detail)
563 }
564
565 pub fn decode_optimization(
573 &self,
574 detail: AnyRef<'_>,
575 ) -> Result<(String, ExtensionArgs), ExtensionError> {
576 self.decode_with_type(ExtensionType::Optimization, detail)
577 }
578
579 fn decode_with_type(
581 &self,
582 ext_type: ExtensionType,
583 detail: AnyRef<'_>,
584 ) -> Result<(String, ExtensionArgs), ExtensionError> {
585 let type_url_key = (ext_type, detail.type_url.to_string());
587 let extension_name =
588 self.type_urls
589 .get(&type_url_key)
590 .ok_or_else(|| ExtensionError::NotFound {
591 name: detail.type_url.to_string(),
592 })?;
593
594 let name_key = (ext_type, extension_name.clone());
596 let handler = self
597 .handlers
598 .get(&name_key)
599 .ok_or_else(|| ExtensionError::NotFound {
600 name: extension_name.clone(),
601 })?;
602
603 let args = handler.textify_detail(detail)?;
604
605 Ok((extension_name.clone(), args))
606 }
607
608 pub fn extension_names(&self, ext_type: ExtensionType) -> Vec<&str> {
610 let mut names: Vec<&str> = self
611 .type_urls
612 .iter()
613 .filter_map(|((t, _), name)| {
614 if *t == ext_type {
615 Some(name.as_str())
616 } else {
617 None
618 }
619 })
620 .collect();
621 names.sort_unstable();
622 names.dedup();
623 names
624 }
625
626 pub fn has_extension(&self, ext_type: ExtensionType, name: &str) -> bool {
628 self.handlers.contains_key(&(ext_type, name.to_string()))
629 }
630}
631
632impl fmt::Debug for ExtensionRegistry {
633 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
634 let mut keys: Vec<_> = self
635 .handlers
636 .keys()
637 .map(|(t, n)| (format!("{t:?}"), n.as_str()))
638 .collect();
639 keys.sort();
640 f.debug_struct("ExtensionRegistry")
641 .field("handlers", &keys)
642 .finish()
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::ExtensionColumn;

    /// Test double whose "encoding" is a hand-built JSON string.
    ///
    /// `from_any` only checks that the expected keys are present and then
    /// returns canned values, so decoded output deliberately differs from the
    /// parsed input.
    struct TestExtension {
        path: String,
        batch_size: i64,
    }

    impl AnyConvertible for TestExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            // No JSON escaping: adequate for the fixed inputs used below.
            let json_str = format!(
                r#"{{"path":"{}","batch_size":{}}}"#,
                self.path, self.batch_size
            );
            Ok(Any::new(Self::type_url(), json_str.into_bytes()))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let json_str = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;

            // Fake decode: presence check only, then fixed values.
            if json_str.contains("path") && json_str.contains("batch_size") {
                Ok(TestExtension {
                    path: "test.parquet".to_string(),
                    batch_size: 1024,
                })
            } else {
                Err(ExtensionError::Custom("Missing fields".to_string()))
            }
        }
    }

    impl Explainable for TestExtension {
        fn name() -> &'static str {
            "TestExtension"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let path: String = extractor.expect_named_arg::<&str>("path")?.to_string();
            let batch_size: i64 = extractor.expect_named_arg("batch_size")?;
            extractor.check_exhausted()?;

            Ok(TestExtension { path, batch_size })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::default();
            args.insert("path", self.path.clone());
            args.insert("batch_size", self.batch_size);
            Ok(args)
        }
    }

    #[test]
    fn test_extension_registry_basic() {
        let mut registry = ExtensionRegistry::new();

        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 0);
        assert_eq!(
            registry
                .extension_names(ExtensionType::ExtensionTable)
                .len(),
            0
        );
        assert!(!registry.has_extension(ExtensionType::Relation, "TestExtension"));

        registry.register_relation::<TestExtension>().unwrap();

        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));

        let mut args = ExtensionArgs::default();
        args.insert("path", "data.parquet");
        args.insert("batch_size", 2048_i64);

        let any = registry.parse_extension("TestExtension", &args).unwrap();
        assert_eq!(any.type_url, "test.TestExtension");

        let any_ref = any.as_ref();
        let result = registry.decode(any_ref).unwrap();
        assert_eq!(result.0, "TestExtension");
        assert_eq!(
            <&str>::try_from(result.1.named.get("path").unwrap()).unwrap(),
            "test.parquet"
        );
    }

    #[test]
    fn test_extension_table_registry_basic() {
        let mut registry = ExtensionRegistry::new();

        registry
            .register_extension_table::<TestExtension>()
            .unwrap();

        assert_eq!(
            registry.extension_names(ExtensionType::ExtensionTable),
            vec!["TestExtension"]
        );
        assert!(registry.has_extension(ExtensionType::ExtensionTable, "TestExtension"));

        let mut args = ExtensionArgs::default();
        args.insert("path", "data.parquet");
        args.insert("batch_size", 2048_i64);

        let any = registry
            .parse_extension_table("TestExtension", &args)
            .unwrap();
        assert_eq!(any.type_url, "test.TestExtension");

        let (name, decoded_args) = registry.decode_extension_table(any.as_ref()).unwrap();
        assert_eq!(name, "TestExtension");
        assert_eq!(
            <&str>::try_from(decoded_args.named.get("path").unwrap()).unwrap(),
            "test.parquet"
        );
    }

    #[test]
    fn test_extension_args() {
        let mut args = ExtensionArgs::default();

        args.insert("path", "data/*.parquet");
        args.insert("batch_size", 1024_i64);

        args.push(crate::textify::expressions::Reference(0));

        args.output_columns.push(ExtensionColumn::Named {
            name: "col1".to_string(),
            r#type: crate::fixtures::parse_type("i32"),
        });

        let mut extractor = args.extractor();

        let path = extractor.get_named_arg("path").unwrap();
        assert_eq!(<&str>::try_from(path).unwrap(), "data/*.parquet");

        let batch_size = extractor.get_named_arg("batch_size").unwrap();
        assert_eq!(i64::try_from(batch_size).unwrap(), 1024);

        assert!(extractor.check_exhausted().is_ok());

        assert_eq!(args.positional.len(), 1);
        assert_eq!(args.output_columns.len(), 1);
    }

    #[test]
    fn test_extension_error_cases() {
        let registry = ExtensionRegistry::new();

        let args = ExtensionArgs::default();
        let result = registry.parse_extension("NonExistent", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));

        let args = ExtensionArgs::default();
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("missing");
        assert!(result.is_none());
        assert!(extractor.check_exhausted().is_ok());

        let mut args = ExtensionArgs::default();
        args.insert("test", 42_i64);
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("test");
        assert_eq!(i64::try_from(result.unwrap()).unwrap(), 42);
        assert!(extractor.check_exhausted().is_ok());
    }

    /// Second test double. It reuses `TestExtension`'s type URL on purpose, to
    /// show that type URLs are scoped per namespace.
    struct TestEnhancement {
        hint: String,
    }

    impl AnyConvertible for TestEnhancement {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            let json_str = format!(r#"{{"hint":"{}"}}"#, self.hint);
            Ok(Any::new(Self::type_url(), json_str.into_bytes()))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let json_str = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
            if json_str.contains("hint") {
                Ok(TestEnhancement {
                    hint: "test_hint".to_string(),
                })
            } else {
                Err(ExtensionError::Custom("Missing hint field".to_string()))
            }
        }
    }

    impl Explainable for TestEnhancement {
        fn name() -> &'static str {
            "TestEnhancement"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_string();
            extractor.check_exhausted()?;
            Ok(TestEnhancement { hint })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::default();
            args.insert("hint", self.hint.clone());
            Ok(args)
        }
    }

    #[test]
    fn test_namespace_separation() {
        let mut registry = ExtensionRegistry::new();

        registry.register_relation::<TestExtension>().unwrap();
        registry
            .register_extension_table::<TestExtension>()
            .unwrap();
        registry.register_enhancement::<TestEnhancement>().unwrap();

        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(registry.has_extension(ExtensionType::ExtensionTable, "TestExtension"));
        assert!(registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert_eq!(
            registry
                .extension_names(ExtensionType::ExtensionTable)
                .len(),
            1
        );
        assert_eq!(
            registry.extension_names(ExtensionType::Enhancement).len(),
            1
        );

        let mut ext_args = ExtensionArgs::default();
        ext_args.insert("path", "data.parquet");
        ext_args.insert("batch_size", 2048_i64);

        let ext_any = registry
            .parse_extension("TestExtension", &ext_args)
            .unwrap();
        assert_eq!(ext_any.type_url, "test.TestExtension");

        let table_any = registry
            .parse_extension_table("TestExtension", &ext_args)
            .unwrap();
        assert_eq!(table_any.type_url, "test.TestExtension");

        let mut enh_args = ExtensionArgs::default();
        enh_args.insert("hint", "optimize");

        let enh_any = registry
            .parse_enhancement("TestEnhancement", &enh_args)
            .unwrap();
        assert_eq!(enh_any.type_url, "test.TestExtension");

        let enh_ref = enh_any.as_ref();
        let (name, args) = registry.decode_enhancement(enh_ref).unwrap();
        assert_eq!(name, "TestEnhancement");
        assert_eq!(
            <&str>::try_from(args.named.get("hint").unwrap()).unwrap(),
            "test_hint"
        );
    }

    #[test]
    fn test_enhancement_duplicate_registration_returns_error() {
        let mut registry = ExtensionRegistry::new();
        registry.register_enhancement::<TestEnhancement>().unwrap();
        let result = registry.register_enhancement::<TestEnhancement>();
        assert!(matches!(
            result,
            Err(RegistrationError::DuplicateName { .. })
        ));
    }

    #[test]
    fn test_extension_table_duplicate_registration_returns_error() {
        let mut registry = ExtensionRegistry::new();
        registry
            .register_extension_table::<TestExtension>()
            .unwrap();
        let result = registry.register_extension_table::<TestExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::DuplicateName { .. })
        ));
    }

    #[test]
    fn test_extension_table_not_found_error() {
        let registry = ExtensionRegistry::new();
        let args = ExtensionArgs::default();
        let result = registry.parse_extension_table("NonExistentExtensionTable", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    #[test]
    fn test_enhancement_not_found_error() {
        let registry = ExtensionRegistry::new();
        let args = ExtensionArgs::default();
        let result = registry.parse_enhancement("NonExistentEnhancement", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    #[test]
    fn test_decode_unknown_type_url_is_not_found() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        let unknown = Any::new("test.Unknown".to_string(), Vec::new());
        let result = registry.decode(unknown.as_ref());
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));

        // A type URL registered as a relation is not visible to other namespaces.
        let known = Any::new("test.TestExtension".to_string(), Vec::new());
        let result = registry.decode_enhancement(known.as_ref());
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    #[test]
    fn test_named_struct_round_trip() {
        let columns = vec![
            ExtensionColumn::Named {
                name: "a".to_string(),
                r#type: crate::fixtures::parse_type("i32"),
            },
            ExtensionColumn::Named {
                name: "b".to_string(),
                r#type: crate::fixtures::parse_type("i32"),
            },
        ];

        let named =
            <[ExtensionColumn] as ExtensionProtoConvert<NamedStruct>>::convert(columns.as_slice())
                .unwrap();
        assert_eq!(named.names, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(named.r#struct.as_ref().unwrap().types.len(), 2);

        let back =
            <NamedStruct as ExtensionProtoConvert<Vec<ExtensionColumn>>>::convert(&named).unwrap();
        assert_eq!(back.len(), 2);
        assert!(matches!(&back[0], ExtensionColumn::Named { name, .. } if name == "a"));
        assert!(matches!(&back[1], ExtensionColumn::Named { name, .. } if name == "b"));
    }

    #[test]
    fn test_named_struct_length_mismatch_is_error() {
        let mut named = NamedStruct::default();
        named.names.push("orphan".to_string());

        let result = <NamedStruct as ExtensionProtoConvert<Vec<ExtensionColumn>>>::convert(&named);
        assert!(matches!(result, Err(ExtensionError::InvalidArgument(_))));
    }

    /// Shares `TestExtension`'s type URL under a different name, to exercise
    /// `ConflictingTypeUrl`.
    struct ConflictingExtension;

    impl AnyConvertible for ConflictingExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            Ok(Any::new(Self::type_url(), vec![]))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(_any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }
    }

    impl Explainable for ConflictingExtension {
        fn name() -> &'static str {
            "ConflictingExtension"
        }

        fn from_args(_args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            Ok(ExtensionArgs::default())
        }
    }

    #[test]
    fn test_conflicting_type_url_leaves_registry_unchanged() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        let result = registry.register_relation::<ConflictingExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::ConflictingTypeUrl { .. })
        ));

        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(!registry.has_extension(ExtensionType::Relation, "ConflictingExtension"));
        assert_eq!(
            registry.extension_names(ExtensionType::Relation),
            vec!["TestExtension"]
        );
    }

    #[test]
    fn test_extension_table_conflicting_type_url_leaves_registry_unchanged() {
        let mut registry = ExtensionRegistry::new();
        registry
            .register_extension_table::<TestExtension>()
            .unwrap();

        let result = registry.register_extension_table::<ConflictingExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::ConflictingTypeUrl { .. })
        ));

        assert!(registry.has_extension(ExtensionType::ExtensionTable, "TestExtension"));
        assert!(!registry.has_extension(ExtensionType::ExtensionTable, "ConflictingExtension"));
        assert_eq!(
            registry.extension_names(ExtensionType::ExtensionTable),
            vec!["TestExtension"]
        );
    }
}