// substrait_explain/extensions/registry.rs

//! Extension Registry for Custom Substrait Extension Relations
//!
//! This module provides a registry system for custom Substrait extension
//! relations (and for the enhancements and optimizations that can be attached
//! to relations), allowing users to register their own extension types with
//! custom parsing and textification logic.
//!
//! # Overview
//!
//! The extension registry allows users to:
//! - Register custom extension handlers for specific extension names
//! - Parse positional and named extension arguments into `google.protobuf.Any`
//!   detail fields
//! - Textify extension detail fields back into readable text format
//! - Support both compile-time and runtime extension registration
//!
//! # Architecture
//!
//! The system is built around two key traits and one registry type:
//! - `AnyConvertible`: For converting types to/from protobuf Any messages
//! - `Explainable`: For converting types to/from ExtensionArgs
//! - `ExtensionRegistry`: Registry for managing extension types, with separate
//!   namespaces (`ExtensionType`) for relations, enhancements, and optimizations
//!
//! # Example Usage
//!
//! ```rust
//! use substrait_explain::extensions::{
//!     Any, AnyConvertible, AnyRef, Explainable, ExtensionArgs, ExtensionError, ExtensionRegistry,
//!     ExtensionRelationType, ExtensionValue,
//! };
//!
//! // Define a custom extension type
//! struct CustomScanConfig {
//!     path: String,
//! }
//!
//! // Implement AnyConvertible for protobuf serialization
//! impl AnyConvertible for CustomScanConfig {
//!     fn to_any(&self) -> Result<Any, ExtensionError> {
//!         // For this example, we'll create a simple Any (protobuf details field) with the path
//!         Ok(Any::new(Self::type_url(), self.path.as_bytes().to_vec()))
//!     }
//!
//!     fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
//!         // Deserialize from Any
//!         let path = String::from_utf8(any.value.to_vec())
//!             .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {}", e)))?;
//!         Ok(CustomScanConfig { path })
//!     }
//!
//!     fn type_url() -> String {
//!         "type.googleapis.com/example.CustomScanConfig".to_string()
//!     }
//! }
//!
//! // Implement Explainable for text format conversion
//! impl Explainable for CustomScanConfig {
//!     fn name() -> &'static str {
//!         "ParquetScan"
//!     }
//!
//!     fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
//!         let mut extractor = args.extractor();
//!         let path: &str = extractor.expect_named_arg("path")?;
//!         extractor.check_exhausted()?;
//!         Ok(CustomScanConfig {
//!             path: path.to_string(),
//!         })
//!     }
//!
//!     fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
//!         let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
//!         args.named.insert(
//!             "path".to_string(),
//!             ExtensionValue::String(self.path.clone()),
//!         );
//!         Ok(args)
//!     }
//! }
//!
//! // Register the extension type
//! let mut registry = ExtensionRegistry::new();
//! registry.register_relation::<CustomScanConfig>().unwrap();
//! ```
84
85use std::collections::HashMap;
86use std::fmt;
87use std::sync::Arc;
88
89use thiserror::Error;
90
91use crate::extensions::any::{Any, AnyRef};
92use crate::extensions::args::ExtensionArgs;
93
/// Namespace an extension type is registered in.
///
/// The registry keys every entry by this value, so relations, enhancements,
/// and optimizations never clash with each other, even when they share a
/// name or a type URL.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExtensionType {
    /// A relation extension (for example `ExtensionLeaf`, `ExtensionSingle`,
    /// or `ExtensionMulti`).
    Relation,
    /// An enhancement attached to a relation; written with the `+ Enh:`
    /// prefix in the text format.
    Enhancement,
    /// An optimization attached to a relation; written with the `+ Opt:`
    /// prefix in the text format.
    Optimization,
}
104
105/// Errors during extension registration (setup phase)
106#[derive(Debug, Error, Clone)]
107pub enum RegistrationError {
108    #[error("{ext_type:?} extension '{name}' already registered")]
109    DuplicateName {
110        ext_type: ExtensionType,
111        name: String,
112    },
113
114    #[error("Type URL '{type_url}' already registered to {ext_type:?} extension '{existing_name}'")]
115    ConflictingTypeUrl {
116        type_url: String,
117        ext_type: ExtensionType,
118        existing_name: String,
119    },
120}
121
122/// Errors during extension parsing, formatting, and argument extraction (runtime)
123#[derive(Debug, Error, Clone)]
124pub enum ExtensionError {
125    /// Extension not found in registry during lookup
126    #[error("Extension '{name}' not found in registry")]
127    NotFound { name: String },
128
129    /// Required argument not present (from ArgsExtractor)
130    #[error("Missing required argument: {name}")]
131    MissingArgument { name: String },
132
133    /// Invalid argument value (from Explainable impls and ArgsExtractor)
134    #[error("Invalid argument: {0}")]
135    InvalidArgument(String),
136
137    /// Type URL mismatch during protobuf Any decode
138    #[error("Type URL mismatch: expected {expected}, got {actual}")]
139    TypeUrlMismatch { expected: String, actual: String },
140
141    /// Protobuf message decode failure
142    #[error("Failed to decode protobuf message")]
143    DecodeFailed(#[source] prost::DecodeError),
144
145    /// Protobuf message encode failure
146    #[error("Failed to encode protobuf message")]
147    EncodeFailed(#[source] prost::EncodeError),
148
149    /// Extension detail field is missing from the relation
150    #[error("Extension detail is missing")]
151    MissingDetail,
152
153    /// Error from a custom AnyConvertible implementation
154    #[error("{0}")]
155    Custom(String),
156}
157
158/// Trait for types that can be converted to/from protobuf Any messages. Note
159/// that this is already implemented for all prost::Message types. For custom
160/// types, implement this trait.
161pub trait AnyConvertible: Sized {
162    /// Convert this type to a protobuf Any message
163    fn to_any(&self) -> Result<Any, ExtensionError>;
164
165    /// Convert from a protobuf Any message to this type
166    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError>;
167
168    /// Get the protobuf type URL for this type.
169    /// For prost::Message types, this is provided automatically via blanket impl.
170    /// Custom types must implement this method.
171    fn type_url() -> String;
172}
173
174// Blanket implementation for all prost::Message types
175impl<T> AnyConvertible for T
176where
177    T: prost::Message + prost::Name + Default,
178{
179    fn to_any(&self) -> Result<Any, ExtensionError> {
180        Any::encode(self)
181    }
182
183    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
184        any.decode()
185    }
186
187    fn type_url() -> String {
188        T::type_url()
189    }
190}
191
192/// Trait for types that participate in text explanations.
193pub trait Explainable: Sized {
194    /// Canonical textual name for this extension. This is what appears in
195    /// Substrait text plans and how the registry identifies the type.
196    fn name() -> &'static str;
197
198    /// Parse extension arguments into this type
199    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError>;
200
201    /// Convert this type to extension arguments
202    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError>;
203}
204
205/// Internal trait that converts between ExtensionArgs and protobuf Any messages.
206///
207/// This trait exists because we need to store handlers for different extension types
208/// in a single HashMap. Since Rust doesn't allow trait objects with multiple traits
209/// (like `Box<dyn AnyConvertible + Explainable>`), we need a single trait that
210/// combines both operations.
211///
212/// The ExtensionConverter acts as a bridge between:
213/// - The text format representation (ExtensionArgs) used by the parser/formatter
214/// - The protobuf Any messages stored in Substrait extension relations
215///
216/// This design allows the registry to work with any type while maintaining type safety
217/// through the AnyConvertible and Explainable traits that users implement.
218trait ExtensionConverter: Send + Sync {
219    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError>;
220
221    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError>;
222}
223
224/// Type adapter that implements ExtensionConverter for any type T that implements
225/// both AnyConvertible and Explainable.
226///
227/// This struct exists to solve Rust's "trait object problem": we can't store
228/// `Box<dyn AnyConvertible + Explainable>` because that's two traits, not one.
229/// Instead, we store `Box<dyn ExtensionConverter>` and use this adapter to bridge
230/// from the two user-facing traits to our single internal trait.
231///
232/// The adapter pattern allows us to:
233/// 1. Keep a clean API where users only implement AnyConvertible and Explainable
234/// 2. Store different types in the same HashMap through type erasure
235/// 3. Maintain type safety - the concrete type T is known at registration time
236/// 4. Avoid any runtime type checking or unsafe code
237///
238/// The PhantomData is necessary because we don't actually store a T, but we need
239/// the type information to call T's static methods (from_args, from_any).
240struct ExtensionAdapter<T>(std::marker::PhantomData<T>);
241
242impl<T: AnyConvertible + Explainable + Send + Sync> ExtensionConverter for ExtensionAdapter<T> {
243    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError> {
244        // Convert: ExtensionArgs -> T -> Any
245        T::from_args(args)?.to_any()
246    }
247
248    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError> {
249        // Convert: AnyRef -> Any -> T -> ExtensionArgs
250        // Create an owned Any from the AnyRef to work with existing T::from_any
251        let owned_any = Any::new(detail.type_url.to_string(), detail.value.to_vec());
252        T::from_any(owned_any.as_ref())?.to_args()
253    }
254}
255
256pub trait Extension: AnyConvertible + Explainable + Send + Sync + 'static {}
257
258impl<T> Extension for T where T: AnyConvertible + Explainable + Send + Sync + 'static {}
259
260/// Registry for extension handlers
261#[derive(Default, Clone)]
262pub struct ExtensionRegistry {
263    // Composite key: (ExtensionType, name) -> handler
264    handlers: HashMap<(ExtensionType, String), Arc<dyn ExtensionConverter>>,
265    // Composite key: (ExtensionType, type_url) -> name
266    type_urls: HashMap<(ExtensionType, String), String>,
267}
268
269impl ExtensionRegistry {
270    /// Create a new empty extension registry
271    pub fn new() -> Self {
272        Self {
273            handlers: HashMap::new(),
274            type_urls: HashMap::new(),
275        }
276    }
277
278    /// Register an extension type with a specific ExtensionType
279    fn register<T>(&mut self, ext_type: ExtensionType) -> Result<(), RegistrationError>
280    where
281        T: Extension,
282    {
283        let canonical_name = T::name();
284        let type_url = T::type_url();
285        let handler: Arc<dyn ExtensionConverter> =
286            Arc::new(ExtensionAdapter::<T>(std::marker::PhantomData));
287
288        let key = (ext_type, canonical_name.to_string());
289        if self.handlers.contains_key(&key) {
290            return Err(RegistrationError::DuplicateName {
291                ext_type,
292                name: canonical_name.to_string(),
293            });
294        }
295
296        // Check for type URL conflicts before mutating any state
297        let type_url_key = (ext_type, type_url.clone());
298        if let Some(existing) = self.type_urls.get(&type_url_key)
299            && existing != canonical_name
300        {
301            return Err(RegistrationError::ConflictingTypeUrl {
302                type_url,
303                ext_type,
304                existing_name: existing.clone(),
305            });
306        }
307
308        // All checks passed — safe to mutate
309        self.handlers.insert(key, Arc::clone(&handler));
310        self.type_urls
311            .insert(type_url_key, canonical_name.to_string());
312        Ok(())
313    }
314
315    /// Register a relation extension type that implements both AnyConvertible and Explainable
316    ///
317    /// The canonical textual name comes from `T::name()`.
318    pub fn register_relation<T>(&mut self) -> Result<(), RegistrationError>
319    where
320        T: Extension,
321    {
322        self.register::<T>(ExtensionType::Relation)
323    }
324
325    /// Register an enhancement type that implements both AnyConvertible and Explainable
326    ///
327    /// Enhancements are registered in a separate namespace from relation extensions,
328    /// allowing the same type URL to exist in both namespaces without conflict.
329    ///
330    /// The canonical textual name comes from `T::name()`.
331    pub fn register_enhancement<T>(&mut self) -> Result<(), RegistrationError>
332    where
333        T: Extension,
334    {
335        self.register::<T>(ExtensionType::Enhancement)
336    }
337
338    /// Register an optimization type that implements both AnyConvertible and Explainable
339    ///
340    /// Optimizations are registered in a separate namespace from relation extensions,
341    /// allowing the same type URL to exist in both namespaces without conflict.
342    ///
343    /// The canonical textual name comes from `T::name()`.
344    pub fn register_optimization<T>(&mut self) -> Result<(), RegistrationError>
345    where
346        T: Extension,
347    {
348        self.register::<T>(ExtensionType::Optimization)
349    }
350
351    /// Parse extension arguments into a protobuf Any message
352    pub fn parse_extension(
353        &self,
354        extension_name: &str,
355        args: &ExtensionArgs,
356    ) -> Result<Any, ExtensionError> {
357        self.parse_with_type(ExtensionType::Relation, extension_name, args)
358    }
359
360    /// Parse enhancement arguments into a protobuf Any message
361    ///
362    /// Looks up the enhancement handler in the enhancement namespace and parses
363    /// the arguments into a protobuf Any message.
364    pub fn parse_enhancement(
365        &self,
366        enhancement_name: &str,
367        args: &ExtensionArgs,
368    ) -> Result<Any, ExtensionError> {
369        self.parse_with_type(ExtensionType::Enhancement, enhancement_name, args)
370    }
371
372    /// Parse optimization arguments into a protobuf Any message
373    ///
374    /// Looks up the optimization handler in the optimization namespace and parses
375    /// the arguments into a protobuf Any message.
376    pub fn parse_optimization(
377        &self,
378        optimization_name: &str,
379        args: &ExtensionArgs,
380    ) -> Result<Any, ExtensionError> {
381        self.parse_with_type(ExtensionType::Optimization, optimization_name, args)
382    }
383
384    /// Internal method to parse extension arguments with a specific ExtensionType
385    fn parse_with_type(
386        &self,
387        ext_type: ExtensionType,
388        name: &str,
389        args: &ExtensionArgs,
390    ) -> Result<Any, ExtensionError> {
391        let key = (ext_type, name.to_string());
392        let handler = self
393            .handlers
394            .get(&key)
395            .ok_or_else(|| ExtensionError::NotFound {
396                name: name.to_string(),
397            })?;
398        handler.parse_detail(args)
399    }
400
401    /// Decode extension detail to extension name and ExtensionArgs
402    /// This is the primary method for textification - given an AnyRef with extension detail,
403    /// decode it to the extension name and appropriate ExtensionArgs for display
404    pub fn decode(&self, detail: AnyRef<'_>) -> Result<(String, ExtensionArgs), ExtensionError> {
405        self.decode_with_type(ExtensionType::Relation, detail)
406    }
407
408    /// Decode enhancement detail to enhancement name and ExtensionArgs
409    ///
410    /// This is the primary method for textification of enhancements - given an AnyRef
411    /// with enhancement detail, decode it to the enhancement name and appropriate
412    /// ExtensionArgs for display.
413    ///
414    /// Looks up the enhancement handler in the enhancement namespace by type URL.
415    pub fn decode_enhancement(
416        &self,
417        detail: AnyRef<'_>,
418    ) -> Result<(String, ExtensionArgs), ExtensionError> {
419        self.decode_with_type(ExtensionType::Enhancement, detail)
420    }
421
422    /// Decode optimization detail to optimization name and ExtensionArgs
423    ///
424    /// This is the primary method for textification of optimizations - given an AnyRef
425    /// with optimization detail, decode it to the optimization name and appropriate
426    /// ExtensionArgs for display.
427    ///
428    /// Looks up the optimization handler in the optimization namespace by type URL.
429    pub fn decode_optimization(
430        &self,
431        detail: AnyRef<'_>,
432    ) -> Result<(String, ExtensionArgs), ExtensionError> {
433        self.decode_with_type(ExtensionType::Optimization, detail)
434    }
435
436    /// Internal method to decode extension detail with a specific ExtensionType
437    fn decode_with_type(
438        &self,
439        ext_type: ExtensionType,
440        detail: AnyRef<'_>,
441    ) -> Result<(String, ExtensionArgs), ExtensionError> {
442        // Find extension name by type URL in the specified namespace
443        let type_url_key = (ext_type, detail.type_url.to_string());
444        let extension_name =
445            self.type_urls
446                .get(&type_url_key)
447                .ok_or_else(|| ExtensionError::NotFound {
448                    name: detail.type_url.to_string(),
449                })?;
450
451        // Get handler and textify the detail
452        let name_key = (ext_type, extension_name.clone());
453        let handler = self
454            .handlers
455            .get(&name_key)
456            .ok_or_else(|| ExtensionError::NotFound {
457                name: extension_name.clone(),
458            })?;
459
460        let args = handler.textify_detail(detail)?;
461
462        Ok((extension_name.clone(), args))
463    }
464
465    /// Get all registered extension names for a specific ExtensionType
466    pub fn extension_names(&self, ext_type: ExtensionType) -> Vec<&str> {
467        let mut names: Vec<&str> = self
468            .type_urls
469            .iter()
470            .filter_map(|((t, _), name)| {
471                if *t == ext_type {
472                    Some(name.as_str())
473                } else {
474                    None
475                }
476            })
477            .collect();
478        names.sort_unstable();
479        names.dedup();
480        names
481    }
482
483    /// Check if an extension is registered for a specific ExtensionType
484    pub fn has_extension(&self, ext_type: ExtensionType, name: &str) -> bool {
485        self.handlers.contains_key(&(ext_type, name.to_string()))
486    }
487}
488
489impl fmt::Debug for ExtensionRegistry {
490    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
491        let mut keys: Vec<_> = self
492            .handlers
493            .keys()
494            .map(|(t, n)| (format!("{t:?}"), n.as_str()))
495            .collect();
496        keys.sort();
497        f.debug_struct("ExtensionRegistry")
498            .field("handlers", &keys)
499            .finish()
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ExtensionColumn, ExtensionRelationType, ExtensionValue};

    // Mock relation extension with a hand-rolled (non-prost) encoding.
    struct TestExtension {
        path: String,
        batch_size: i64,
    }

    // Manual AnyConvertible implementation (without prost). The payload is
    // "<path>\n<batch_size>". It decodes losslessly, so the tests below check a
    // real round trip instead of canned values.
    impl AnyConvertible for TestExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            let payload = format!("{}\n{}", self.path, self.batch_size);
            Ok(Any::new(Self::type_url(), payload.into_bytes()))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let payload = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;

            // Split on the last newline: the batch size never contains one,
            // while a path might.
            let (path, batch_size) = payload
                .rsplit_once('\n')
                .ok_or_else(|| ExtensionError::Custom("Missing fields".to_string()))?;
            let batch_size = batch_size
                .parse::<i64>()
                .map_err(|e| ExtensionError::Custom(format!("Invalid batch_size: {e}")))?;

            Ok(TestExtension {
                path: path.to_string(),
                batch_size,
            })
        }
    }

    impl Explainable for TestExtension {
        fn name() -> &'static str {
            "TestExtension"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let path = extractor.expect_named_arg::<&str>("path")?.to_string();
            let batch_size: i64 = extractor.expect_named_arg("batch_size")?;
            extractor.check_exhausted()?;

            Ok(TestExtension { path, batch_size })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named.insert(
                "path".to_string(),
                ExtensionValue::String(self.path.clone()),
            );
            args.named.insert(
                "batch_size".to_string(),
                ExtensionValue::Integer(self.batch_size),
            );
            Ok(args)
        }
    }

    #[test]
    fn test_extension_registry_basic() {
        let mut registry = ExtensionRegistry::new();

        // Initially empty
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 0);
        assert!(!registry.has_extension(ExtensionType::Relation, "TestExtension"));

        // Register extension type
        registry.register_relation::<TestExtension>().unwrap();

        // Now has extension
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));

        // Parse, then decode back: the values must survive the round trip.
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data.parquet".to_string()),
        );
        args.named
            .insert("batch_size".to_string(), ExtensionValue::Integer(2048));

        let any = registry.parse_extension("TestExtension", &args).unwrap();
        assert_eq!(any.type_url, "test.TestExtension");

        let (name, decoded) = registry.decode(any.as_ref()).unwrap();
        assert_eq!(name, "TestExtension");
        match decoded.named.get("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "data.parquet"),
            _ => panic!("Expected String for path"),
        }
        match decoded.named.get("batch_size") {
            Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 2048),
            _ => panic!("Expected Integer for batch_size"),
        }
    }

    #[test]
    fn test_extension_args() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);

        // Add named args
        args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data/*.parquet".to_string()),
        );
        args.named
            .insert("batch_size".to_string(), ExtensionValue::Integer(1024));

        // Add positional args
        args.positional.push(ExtensionValue::Reference(0));

        // Add output columns
        args.output_columns.push(ExtensionColumn::Named {
            name: "col1".to_string(),
            type_spec: "i32".to_string(),
        });

        // Test retrieval - use extractor
        let mut extractor = args.extractor();

        match extractor.get_named_arg("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "data/*.parquet"),
            _ => panic!("Expected String for path"),
        }

        match extractor.get_named_arg("batch_size") {
            Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 1024),
            _ => panic!("Expected Integer for batch_size"),
        }

        // Verify they were consumed
        assert!(extractor.check_exhausted().is_ok());

        assert_eq!(args.positional.len(), 1);
        assert_eq!(args.output_columns.len(), 1);
    }

    #[test]
    fn test_extension_error_cases() {
        let registry = ExtensionRegistry::new();

        // Extension not found
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let result = registry.parse_extension("NonExistent", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));

        // Missing argument
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("missing");
        assert!(result.is_none());
        assert!(extractor.check_exhausted().is_ok());

        // Type check example
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named
            .insert("test".to_string(), ExtensionValue::Integer(42));
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("test");
        match result {
            Some(ExtensionValue::Integer(42)) => {} // Expected
            _ => panic!("Expected Integer(42), got {result:?}"),
        }
        assert!(extractor.check_exhausted().is_ok());
    }

    // Mock enhancement type for testing namespace separation. Its payload is
    // the raw UTF-8 hint, so it round-trips exactly.
    struct TestEnhancement {
        hint: String,
    }

    impl AnyConvertible for TestEnhancement {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            Ok(Any::new(Self::type_url(), self.hint.clone().into_bytes()))
        }

        fn type_url() -> String {
            // Same type URL as TestExtension to test namespace separation
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let hint = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
            Ok(TestEnhancement { hint })
        }
    }

    impl Explainable for TestEnhancement {
        fn name() -> &'static str {
            "TestEnhancement"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_string();
            extractor.check_exhausted()?;
            Ok(TestEnhancement { hint })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named.insert(
                "hint".to_string(),
                ExtensionValue::String(self.hint.clone()),
            );
            Ok(args)
        }
    }

    #[test]
    fn test_namespace_separation() {
        let mut registry = ExtensionRegistry::new();

        // Register same type URL in both namespaces - should not conflict
        registry.register_relation::<TestExtension>().unwrap();
        registry.register_enhancement::<TestEnhancement>().unwrap();

        // Verify both are registered
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert_eq!(
            registry.extension_names(ExtensionType::Enhancement).len(),
            1
        );

        // Test that extension namespace works
        let mut ext_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        ext_args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data.parquet".to_string()),
        );
        ext_args
            .named
            .insert("batch_size".to_string(), ExtensionValue::Integer(2048));

        let ext_any = registry
            .parse_extension("TestExtension", &ext_args)
            .unwrap();
        assert_eq!(ext_any.type_url, "test.TestExtension");

        // Test that enhancement namespace works
        let mut enh_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        enh_args.named.insert(
            "hint".to_string(),
            ExtensionValue::String("optimize".to_string()),
        );

        let enh_any = registry
            .parse_enhancement("TestEnhancement", &enh_args)
            .unwrap();
        assert_eq!(enh_any.type_url, "test.TestExtension"); // Same type URL!

        // The same type URL resolves to a different handler per namespace.
        let (name, args) = registry.decode_enhancement(enh_any.as_ref()).unwrap();
        assert_eq!(name, "TestEnhancement");
        match args.named.get("hint") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "optimize"),
            _ => panic!("Expected String for hint"),
        }

        let (name, _) = registry.decode(ext_any.as_ref()).unwrap();
        assert_eq!(name, "TestExtension");
    }

    #[test]
    fn test_optimization_namespace_round_trip() {
        let mut registry = ExtensionRegistry::new();
        registry.register_optimization::<TestEnhancement>().unwrap();

        assert!(registry.has_extension(ExtensionType::Optimization, "TestEnhancement"));
        assert!(!registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
        assert_eq!(
            registry.extension_names(ExtensionType::Optimization),
            vec!["TestEnhancement"]
        );

        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named.insert(
            "hint".to_string(),
            ExtensionValue::String("fast".to_string()),
        );
        let any = registry
            .parse_optimization("TestEnhancement", &args)
            .unwrap();

        let (name, decoded) = registry.decode_optimization(any.as_ref()).unwrap();
        assert_eq!(name, "TestEnhancement");
        match decoded.named.get("hint") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "fast"),
            _ => panic!("Expected String for hint"),
        }
    }

    #[test]
    fn test_decode_in_unregistered_namespace_is_not_found() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        let any = TestExtension {
            path: "p.parquet".to_string(),
            batch_size: 1,
        }
        .to_any()
        .unwrap();

        // The type URL is registered only in the relation namespace.
        assert!(registry.decode(any.as_ref()).is_ok());
        assert!(matches!(
            registry.decode_enhancement(any.as_ref()),
            Err(ExtensionError::NotFound { .. })
        ));
        assert!(matches!(
            registry.decode_optimization(any.as_ref()),
            Err(ExtensionError::NotFound { .. })
        ));
    }

    #[test]
    fn test_enhancement_duplicate_registration_returns_error() {
        let mut registry = ExtensionRegistry::new();
        registry.register_enhancement::<TestEnhancement>().unwrap();
        let result = registry.register_enhancement::<TestEnhancement>();
        assert!(matches!(
            result,
            Err(RegistrationError::DuplicateName { .. })
        ));
    }

    #[test]
    fn test_enhancement_not_found_error() {
        let registry = ExtensionRegistry::new();
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let result = registry.parse_enhancement("NonExistentEnhancement", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    // Extension with same type URL as TestExtension but different name,
    // used to test that conflicting type URLs don't leave stale state.
    struct ConflictingExtension;

    impl AnyConvertible for ConflictingExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            Ok(Any::new(Self::type_url(), vec![]))
        }

        fn type_url() -> String {
            // Same type URL as TestExtension — will conflict in the same namespace
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(_any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }
    }

    impl Explainable for ConflictingExtension {
        fn name() -> &'static str {
            "ConflictingExtension"
        }

        fn from_args(_args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            Ok(ExtensionArgs::new(ExtensionRelationType::Leaf))
        }
    }

    #[test]
    fn test_conflicting_type_url_leaves_registry_unchanged() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        // Attempt to register a different extension with the same type URL
        let result = registry.register_relation::<ConflictingExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::ConflictingTypeUrl { .. })
        ));

        // Registry should still only know about the original extension
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(!registry.has_extension(ExtensionType::Relation, "ConflictingExtension"));
        assert_eq!(
            registry.extension_names(ExtensionType::Relation),
            vec!["TestExtension"]
        );
    }
}