// substrait_explain/extensions/registry.rs

1//! Extension Registry for Custom Substrait Extension Relations
2//!
3//! This module provides a registry system for custom Substrait extension
4//! relations, allowing users to register their own extension types with custom
5//! parsing and textification logic.
6//!
7//! # Overview
8//!
9//! The extension registry allows users to:
10//! - Register custom extension handlers for specific extension names
11//! - Parse extension arguments/named arguments into `google.protobuf.Any`
12//!   detail fields
13//! - Textify extension detail fields back into readable text format
14//! - Support both compile-time and runtime extension registration
15//!
16//! # Architecture
17//!
//! The system is built around two user-facing traits and one registry type:
//! - `AnyConvertible`: For converting types to/from protobuf Any messages
//! - `Explainable`: For converting types to/from ExtensionArgs
//! - `ExtensionRegistry`: The registry struct that stores and looks up the
//!   handlers for registered extension types
22//!
23//! # Example Usage
24//!
25//! ```rust
26//! use substrait_explain::extensions::{
27//!     Any, AnyConvertible, AnyRef, Explainable, ExtensionArgs, ExtensionError, ExtensionRegistry,
28//!     ExtensionRelationType, ExtensionValue,
29//! };
30//!
31//! // Define a custom extension type
32//! struct CustomScanConfig {
33//!     path: String,
34//! }
35//!
36//! // Implement AnyConvertible for protobuf serialization
37//! impl AnyConvertible for CustomScanConfig {
38//!     fn to_any(&self) -> Result<Any, ExtensionError> {
39//!         // For this example, we'll create a simple Any (protobuf details field) with the path
40//!         Ok(Any::new(Self::type_url(), self.path.as_bytes().to_vec()))
41//!     }
42//!
43//!     fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
44//!         // Deserialize from Any
45//!         let path = String::from_utf8(any.value.to_vec())
46//!             .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {}", e)))?;
47//!         Ok(CustomScanConfig { path })
48//!     }
49//!
50//!     fn type_url() -> String {
51//!         "type.googleapis.com/example.CustomScanConfig".to_string()
52//!     }
53//! }
54//!
55//! // Implement Explainable for text format conversion
56//! impl Explainable for CustomScanConfig {
57//!     fn name() -> &'static str {
58//!         "ParquetScan"
59//!     }
60//!
61//!     fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
62//!         let mut extractor = args.extractor();
63//!         let path: &str = extractor.expect_named_arg("path")?;
64//!         extractor.check_exhausted()?;
65//!         Ok(CustomScanConfig {
66//!             path: path.to_string(),
67//!         })
68//!     }
69//!
70//!     fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
71//!         let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
72//!         args.named.insert(
73//!             "path".to_string(),
74//!             ExtensionValue::String(self.path.clone()),
75//!         );
76//!         Ok(args)
77//!     }
78//! }
79//!
80//! // Register the extension type
81//! let mut registry = ExtensionRegistry::new();
82//! registry.register_relation::<CustomScanConfig>().unwrap();
83//! ```
84
85use std::collections::HashMap;
86use std::fmt;
87use std::sync::Arc;
88
89use thiserror::Error;
90
91use crate::extensions::any::{Any, AnyRef};
92use crate::extensions::args::ExtensionArgs;
93
/// Namespace an extension is registered under.
///
/// Relations, enhancements and optimizations each live in their own
/// namespace inside [`ExtensionRegistry`], so the same name or type URL may be
/// reused across namespaces without clashing.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExtensionType {
    /// A relation extension (`ExtensionLeaf`, `ExtensionSingle` or `ExtensionMulti`).
    Relation,
    /// An enhancement attached to a relation; written with the `+ Enh:` prefix in text.
    Enhancement,
    /// An optimization attached to a relation; written with the `+ Opt:` prefix in text.
    Optimization,
}
104
/// Errors during extension registration (setup phase).
///
/// Returned by [`ExtensionRegistry::register_relation`],
/// [`ExtensionRegistry::register_enhancement`] and
/// [`ExtensionRegistry::register_optimization`]. A failed registration leaves
/// the registry unchanged.
#[derive(Debug, Error, Clone)]
pub enum RegistrationError {
    /// An extension with the same canonical name (`Explainable::name`) is
    /// already registered in the same namespace.
    #[error("{ext_type:?} extension '{name}' already registered")]
    DuplicateName {
        /// Namespace in which the clash occurred.
        ext_type: ExtensionType,
        /// The canonical name that was registered twice.
        name: String,
    },

    /// The type URL (`AnyConvertible::type_url`) is already claimed by a
    /// differently named extension in the same namespace. Type URLs must be
    /// unique per namespace because decoding looks handlers up by type URL.
    #[error("Type URL '{type_url}' already registered to {ext_type:?} extension '{existing_name}'")]
    ConflictingTypeUrl {
        /// The contested type URL.
        type_url: String,
        /// Namespace in which the clash occurred.
        ext_type: ExtensionType,
        /// Name of the extension that already owns `type_url`.
        existing_name: String,
    },
}
121
/// Errors during extension parsing, formatting, and argument extraction (runtime).
///
/// Registration-time problems are reported separately via [`RegistrationError`].
#[derive(Debug, Error, Clone)]
pub enum ExtensionError {
    /// Extension not found in registry during lookup.
    ///
    /// For name-based lookups (the `parse_*` methods) `name` is the extension
    /// name. For the `decode*` methods it is usually the `Any` type URL that
    /// had no registered handler.
    #[error("Extension '{name}' not found in registry")]
    NotFound { name: String },

    /// Required argument not present (from ArgsExtractor)
    #[error("Missing required argument: {name}")]
    MissingArgument { name: String },

    /// Invalid argument value (from Explainable impls and ArgsExtractor)
    #[error("Invalid argument: {0}")]
    InvalidArgument(String),

    /// Type URL mismatch during protobuf Any decode: the `Any` carried a
    /// different type URL than the one expected by the decoder.
    #[error("Type URL mismatch: expected {expected}, got {actual}")]
    TypeUrlMismatch { expected: String, actual: String },

    /// Protobuf message decode failure. The underlying prost error is
    /// reachable through `std::error::Error::source`, not the message text.
    #[error("Failed to decode protobuf message")]
    DecodeFailed(#[source] prost::DecodeError),

    /// Protobuf message encode failure. The underlying prost error is
    /// reachable through `std::error::Error::source`, not the message text.
    #[error("Failed to encode protobuf message")]
    EncodeFailed(#[source] prost::EncodeError),

    /// Extension detail field is missing from the relation
    #[error("Extension detail is missing")]
    MissingDetail,

    /// Error from a custom AnyConvertible implementation; displayed verbatim.
    #[error("{0}")]
    Custom(String),
}
157
158/// Trait for types that can be converted to/from protobuf Any messages. Note
159/// that this is already implemented for all prost::Message types. For custom
160/// types, implement this trait.
161pub trait AnyConvertible: Sized {
162    /// Convert this type to a protobuf Any message
163    fn to_any(&self) -> Result<Any, ExtensionError>;
164
165    /// Convert from a protobuf Any message to this type
166    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError>;
167
168    /// Get the protobuf type URL for this type.
169    /// For prost::Message types, this is provided automatically via blanket impl.
170    /// Custom types must implement this method.
171    fn type_url() -> String;
172}
173
174// Blanket implementation for all prost::Message types
175impl<T> AnyConvertible for T
176where
177    T: prost::Message + prost::Name + Default,
178{
179    fn to_any(&self) -> Result<Any, ExtensionError> {
180        Any::encode(self)
181    }
182
183    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
184        any.decode()
185    }
186
187    fn type_url() -> String {
188        T::type_url()
189    }
190}
191
192/// Trait for types that participate in text explanations.
193pub trait Explainable: Sized {
194    /// Canonical textual name for this extension. This is what appears in
195    /// Substrait text plans and how the registry identifies the type.
196    fn name() -> &'static str;
197
198    /// Parse extension arguments into this type
199    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError>;
200
201    /// Convert this type to extension arguments
202    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError>;
203}
204
205/// Internal trait that converts between ExtensionArgs and protobuf Any messages.
206///
207/// This trait exists because we need to store handlers for different extension types
208/// in a single HashMap. Since Rust doesn't allow trait objects with multiple traits
209/// (like `Box<dyn AnyConvertible + Explainable>`), we need a single trait that
210/// combines both operations.
211///
212/// The ExtensionConverter acts as a bridge between:
213/// - The text format representation (ExtensionArgs) used by the parser/formatter
214/// - The protobuf Any messages stored in Substrait extension relations
215///
216/// This design allows the registry to work with any type while maintaining type safety
217/// through the AnyConvertible and Explainable traits that users implement.
218trait ExtensionConverter: Send + Sync {
219    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError>;
220
221    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError>;
222}
223
224/// Type adapter that implements ExtensionConverter for any type T that implements
225/// both AnyConvertible and Explainable.
226///
227/// This struct exists to solve Rust's "trait object problem": we can't store
228/// `Box<dyn AnyConvertible + Explainable>` because that's two traits, not one.
229/// Instead, we store `Box<dyn ExtensionConverter>` and use this adapter to bridge
230/// from the two user-facing traits to our single internal trait.
231///
232/// The adapter pattern allows us to:
233/// 1. Keep a clean API where users only implement AnyConvertible and Explainable
234/// 2. Store different types in the same HashMap through type erasure
235/// 3. Maintain type safety - the concrete type T is known at registration time
236/// 4. Avoid any runtime type checking or unsafe code
237///
238/// The PhantomData is necessary because we don't actually store a T, but we need
239/// the type information to call T's static methods (from_args, from_any).
240struct ExtensionAdapter<T>(std::marker::PhantomData<T>);
241
242impl<T: AnyConvertible + Explainable + Send + Sync> ExtensionConverter for ExtensionAdapter<T> {
243    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError> {
244        // Convert: ExtensionArgs -> T -> Any
245        T::from_args(args)?.to_any()
246    }
247
248    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError> {
249        // Convert: AnyRef -> Any -> T -> ExtensionArgs
250        // Create an owned Any from the AnyRef to work with existing T::from_any
251        let owned_any = Any::new(detail.type_url.to_string(), detail.value.to_vec());
252        T::from_any(owned_any.as_ref())?.to_args()
253    }
254}
255
256pub trait Extension: AnyConvertible + Explainable + Send + Sync + 'static {}
257
258impl<T> Extension for T where T: AnyConvertible + Explainable + Send + Sync + 'static {}
259
260/// Registry for extension handlers
261#[derive(Default, Clone)]
262pub struct ExtensionRegistry {
263    // Composite key: (ExtensionType, name) -> handler
264    handlers: HashMap<(ExtensionType, String), Arc<dyn ExtensionConverter>>,
265    // Composite key: (ExtensionType, type_url) -> name
266    type_urls: HashMap<(ExtensionType, String), String>,
267    // Compiled proto FileDescriptorSet blobs for extension types.
268    // Used by the JSON parser to resolve google.protobuf.Any type URLs in Go
269    // protojson input. Register these alongside the Rust handler so that a
270    // single registry carries all extension knowledge for both formatting and
271    // JSON parsing.
272    descriptors: Vec<Vec<u8>>,
273}
274
275impl ExtensionRegistry {
276    /// Create a new empty extension registry
277    pub fn new() -> Self {
278        Self {
279            handlers: HashMap::new(),
280            type_urls: HashMap::new(),
281            descriptors: Vec::new(),
282        }
283    }
284
285    /// Register a compiled proto `FileDescriptorSet` blob for extension types.
286    ///
287    /// Required when parsing extensions for plans that contain
288    /// `google.protobuf.Any` fields that use standard JSON encoding (with
289    /// `@type` for the type_url) whose types are not part of the Substrait core
290    /// schema. Pass the bytes of a compiled `.bin` descriptor, e.g.
291    /// `include_bytes!("my_extensions.bin")`.
292    pub fn add_descriptor(&mut self, bytes: Vec<u8>) {
293        self.descriptors.push(bytes);
294    }
295
296    /// Returns slices of all registered descriptor blobs.
297    pub fn descriptors(&self) -> Vec<&[u8]> {
298        self.descriptors.iter().map(|b| b.as_slice()).collect()
299    }
300
301    /// Register an extension type with a specific ExtensionType
302    fn register<T>(&mut self, ext_type: ExtensionType) -> Result<(), RegistrationError>
303    where
304        T: Extension,
305    {
306        let canonical_name = T::name();
307        let type_url = T::type_url();
308        let handler: Arc<dyn ExtensionConverter> =
309            Arc::new(ExtensionAdapter::<T>(std::marker::PhantomData));
310
311        let key = (ext_type, canonical_name.to_string());
312        if self.handlers.contains_key(&key) {
313            return Err(RegistrationError::DuplicateName {
314                ext_type,
315                name: canonical_name.to_string(),
316            });
317        }
318
319        // Check for type URL conflicts before mutating any state
320        let type_url_key = (ext_type, type_url.clone());
321        if let Some(existing) = self.type_urls.get(&type_url_key)
322            && existing != canonical_name
323        {
324            return Err(RegistrationError::ConflictingTypeUrl {
325                type_url,
326                ext_type,
327                existing_name: existing.clone(),
328            });
329        }
330
331        // All checks passed — safe to mutate
332        self.handlers.insert(key, Arc::clone(&handler));
333        self.type_urls
334            .insert(type_url_key, canonical_name.to_string());
335        Ok(())
336    }
337
338    /// Register a relation extension type that implements both AnyConvertible and Explainable
339    ///
340    /// The canonical textual name comes from `T::name()`.
341    pub fn register_relation<T>(&mut self) -> Result<(), RegistrationError>
342    where
343        T: Extension,
344    {
345        self.register::<T>(ExtensionType::Relation)
346    }
347
348    /// Register an enhancement type that implements both AnyConvertible and Explainable
349    ///
350    /// Enhancements are registered in a separate namespace from relation extensions,
351    /// allowing the same type URL to exist in both namespaces without conflict.
352    ///
353    /// The canonical textual name comes from `T::name()`.
354    pub fn register_enhancement<T>(&mut self) -> Result<(), RegistrationError>
355    where
356        T: Extension,
357    {
358        self.register::<T>(ExtensionType::Enhancement)
359    }
360
361    /// Register an optimization type that implements both AnyConvertible and Explainable
362    ///
363    /// Optimizations are registered in a separate namespace from relation extensions,
364    /// allowing the same type URL to exist in both namespaces without conflict.
365    ///
366    /// The canonical textual name comes from `T::name()`.
367    pub fn register_optimization<T>(&mut self) -> Result<(), RegistrationError>
368    where
369        T: Extension,
370    {
371        self.register::<T>(ExtensionType::Optimization)
372    }
373
374    /// Parse extension arguments into a protobuf Any message
375    pub fn parse_extension(
376        &self,
377        extension_name: &str,
378        args: &ExtensionArgs,
379    ) -> Result<Any, ExtensionError> {
380        self.parse_with_type(ExtensionType::Relation, extension_name, args)
381    }
382
383    /// Parse enhancement arguments into a protobuf Any message
384    ///
385    /// Looks up the enhancement handler in the enhancement namespace and parses
386    /// the arguments into a protobuf Any message.
387    pub fn parse_enhancement(
388        &self,
389        enhancement_name: &str,
390        args: &ExtensionArgs,
391    ) -> Result<Any, ExtensionError> {
392        self.parse_with_type(ExtensionType::Enhancement, enhancement_name, args)
393    }
394
395    /// Parse optimization arguments into a protobuf Any message
396    ///
397    /// Looks up the optimization handler in the optimization namespace and parses
398    /// the arguments into a protobuf Any message.
399    pub fn parse_optimization(
400        &self,
401        optimization_name: &str,
402        args: &ExtensionArgs,
403    ) -> Result<Any, ExtensionError> {
404        self.parse_with_type(ExtensionType::Optimization, optimization_name, args)
405    }
406
407    /// Internal method to parse extension arguments with a specific ExtensionType
408    fn parse_with_type(
409        &self,
410        ext_type: ExtensionType,
411        name: &str,
412        args: &ExtensionArgs,
413    ) -> Result<Any, ExtensionError> {
414        let key = (ext_type, name.to_string());
415        let handler = self
416            .handlers
417            .get(&key)
418            .ok_or_else(|| ExtensionError::NotFound {
419                name: name.to_string(),
420            })?;
421        handler.parse_detail(args)
422    }
423
424    /// Decode extension detail to extension name and ExtensionArgs
425    /// This is the primary method for textification - given an AnyRef with extension detail,
426    /// decode it to the extension name and appropriate ExtensionArgs for display
427    pub fn decode(&self, detail: AnyRef<'_>) -> Result<(String, ExtensionArgs), ExtensionError> {
428        self.decode_with_type(ExtensionType::Relation, detail)
429    }
430
431    /// Decode enhancement detail to enhancement name and ExtensionArgs
432    ///
433    /// This is the primary method for textification of enhancements - given an AnyRef
434    /// with enhancement detail, decode it to the enhancement name and appropriate
435    /// ExtensionArgs for display.
436    ///
437    /// Looks up the enhancement handler in the enhancement namespace by type URL.
438    pub fn decode_enhancement(
439        &self,
440        detail: AnyRef<'_>,
441    ) -> Result<(String, ExtensionArgs), ExtensionError> {
442        self.decode_with_type(ExtensionType::Enhancement, detail)
443    }
444
445    /// Decode optimization detail to optimization name and ExtensionArgs
446    ///
447    /// This is the primary method for textification of optimizations - given an AnyRef
448    /// with optimization detail, decode it to the optimization name and appropriate
449    /// ExtensionArgs for display.
450    ///
451    /// Looks up the optimization handler in the optimization namespace by type URL.
452    pub fn decode_optimization(
453        &self,
454        detail: AnyRef<'_>,
455    ) -> Result<(String, ExtensionArgs), ExtensionError> {
456        self.decode_with_type(ExtensionType::Optimization, detail)
457    }
458
459    /// Internal method to decode extension detail with a specific ExtensionType
460    fn decode_with_type(
461        &self,
462        ext_type: ExtensionType,
463        detail: AnyRef<'_>,
464    ) -> Result<(String, ExtensionArgs), ExtensionError> {
465        // Find extension name by type URL in the specified namespace
466        let type_url_key = (ext_type, detail.type_url.to_string());
467        let extension_name =
468            self.type_urls
469                .get(&type_url_key)
470                .ok_or_else(|| ExtensionError::NotFound {
471                    name: detail.type_url.to_string(),
472                })?;
473
474        // Get handler and textify the detail
475        let name_key = (ext_type, extension_name.clone());
476        let handler = self
477            .handlers
478            .get(&name_key)
479            .ok_or_else(|| ExtensionError::NotFound {
480                name: extension_name.clone(),
481            })?;
482
483        let args = handler.textify_detail(detail)?;
484
485        Ok((extension_name.clone(), args))
486    }
487
488    /// Get all registered extension names for a specific ExtensionType
489    pub fn extension_names(&self, ext_type: ExtensionType) -> Vec<&str> {
490        let mut names: Vec<&str> = self
491            .type_urls
492            .iter()
493            .filter_map(|((t, _), name)| {
494                if *t == ext_type {
495                    Some(name.as_str())
496                } else {
497                    None
498                }
499            })
500            .collect();
501        names.sort_unstable();
502        names.dedup();
503        names
504    }
505
506    /// Check if an extension is registered for a specific ExtensionType
507    pub fn has_extension(&self, ext_type: ExtensionType, name: &str) -> bool {
508        self.handlers.contains_key(&(ext_type, name.to_string()))
509    }
510}
511
512impl fmt::Debug for ExtensionRegistry {
513    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514        let mut keys: Vec<_> = self
515            .handlers
516            .keys()
517            .map(|(t, n)| (format!("{t:?}"), n.as_str()))
518            .collect();
519        keys.sort();
520        f.debug_struct("ExtensionRegistry")
521            .field("handlers", &keys)
522            .finish()
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::{ExtensionColumn, ExtensionRelationType, ExtensionValue};

    /// Relation extension used throughout these tests.
    ///
    /// Its `AnyConvertible` impl faithfully round-trips the field values.
    /// Tests can therefore check that the registry carries data unchanged
    /// through parse -> encode -> decode -> textify, rather than asserting on
    /// hard-coded stub output.
    struct TestExtension {
        path: String,
        batch_size: i64,
    }

    // Manual AnyConvertible (no prost). The payload is `"<path>\n<batch_size>"`.
    impl AnyConvertible for TestExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            let payload = format!("{}\n{}", self.path, self.batch_size);
            Ok(Any::new(Self::type_url(), payload.into_bytes()))
        }

        fn type_url() -> String {
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let payload = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
            // Split on the last newline so the path itself may contain newlines.
            let (path, batch_size) = payload
                .rsplit_once('\n')
                .ok_or_else(|| ExtensionError::Custom("Missing fields".to_string()))?;
            let batch_size = batch_size
                .parse::<i64>()
                .map_err(|e| ExtensionError::Custom(format!("Invalid batch_size: {e}")))?;
            Ok(TestExtension {
                path: path.to_string(),
                batch_size,
            })
        }
    }

    impl Explainable for TestExtension {
        fn name() -> &'static str {
            "TestExtension"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let path = extractor.expect_named_arg::<&str>("path")?.to_string();
            let batch_size: i64 = extractor.expect_named_arg("batch_size")?;
            extractor.check_exhausted()?;

            Ok(TestExtension { path, batch_size })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named.insert(
                "path".to_string(),
                ExtensionValue::String(self.path.clone()),
            );
            args.named.insert(
                "batch_size".to_string(),
                ExtensionValue::Integer(self.batch_size),
            );
            Ok(args)
        }
    }

    #[test]
    fn test_extension_registry_basic() {
        let mut registry = ExtensionRegistry::new();

        // Initially empty
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 0);
        assert!(!registry.has_extension(ExtensionType::Relation, "TestExtension"));

        // Register extension type
        registry.register_relation::<TestExtension>().unwrap();

        // Now has extension
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));

        // Parse, then decode back, and check the values survive the round trip.
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data.parquet".to_string()),
        );
        args.named
            .insert("batch_size".to_string(), ExtensionValue::Integer(2048));

        let any = registry.parse_extension("TestExtension", &args).unwrap();
        assert_eq!(any.type_url, "test.TestExtension");

        let (name, decoded) = registry.decode(any.as_ref()).unwrap();
        assert_eq!(name, "TestExtension");
        match decoded.named.get("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "data.parquet"),
            _ => panic!("Expected String for path"),
        }
        match decoded.named.get("batch_size") {
            Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 2048),
            _ => panic!("Expected Integer for batch_size"),
        }
    }

    #[test]
    fn test_extension_args() {
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);

        // Add named args
        args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data/*.parquet".to_string()),
        );
        args.named
            .insert("batch_size".to_string(), ExtensionValue::Integer(1024));

        // Add positional args
        args.positional.push(ExtensionValue::Reference(0));

        // Add output columns
        args.output_columns.push(ExtensionColumn::Named {
            name: "col1".to_string(),
            type_spec: "i32".to_string(),
        });

        // Test retrieval - use extractor
        let mut extractor = args.extractor();

        match extractor.get_named_arg("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "data/*.parquet"),
            _ => panic!("Expected String for path"),
        }

        match extractor.get_named_arg("batch_size") {
            Some(ExtensionValue::Integer(i)) => assert_eq!(*i, 1024),
            _ => panic!("Expected Integer for batch_size"),
        }

        // Verify they were consumed
        assert!(extractor.check_exhausted().is_ok());

        assert_eq!(args.positional.len(), 1);
        assert_eq!(args.output_columns.len(), 1);
    }

    #[test]
    fn test_extension_error_cases() {
        let registry = ExtensionRegistry::new();

        // Extension not found
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let result = registry.parse_extension("NonExistent", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));

        // Missing argument
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("missing");
        assert!(result.is_none());
        assert!(extractor.check_exhausted().is_ok());

        // Type check example
        let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        args.named
            .insert("test".to_string(), ExtensionValue::Integer(42));
        let mut extractor = args.extractor();
        let result = extractor.get_named_arg("test");
        match result {
            Some(ExtensionValue::Integer(42)) => {} // Expected
            _ => panic!("Expected Integer(42), got {result:?}"),
        }
        assert!(extractor.check_exhausted().is_ok());
    }

    /// Enhancement that deliberately shares `TestExtension`'s type URL, to
    /// test namespace separation. The payload is the raw UTF-8 hint.
    struct TestEnhancement {
        hint: String,
    }

    impl AnyConvertible for TestEnhancement {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            Ok(Any::new(Self::type_url(), self.hint.as_bytes().to_vec()))
        }

        fn type_url() -> String {
            // Same type URL as TestExtension to test namespace separation
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            let hint = String::from_utf8(any.value.to_vec())
                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
            Ok(TestEnhancement { hint })
        }
    }

    impl Explainable for TestEnhancement {
        fn name() -> &'static str {
            "TestEnhancement"
        }

        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            let mut extractor = args.extractor();
            let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_string();
            extractor.check_exhausted()?;
            Ok(TestEnhancement { hint })
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            let mut args = ExtensionArgs::new(ExtensionRelationType::Leaf);
            args.named.insert(
                "hint".to_string(),
                ExtensionValue::String(self.hint.clone()),
            );
            Ok(args)
        }
    }

    #[test]
    fn test_namespace_separation() {
        let mut registry = ExtensionRegistry::new();

        // Register same type URL in both namespaces - should not conflict
        registry.register_relation::<TestExtension>().unwrap();
        registry.register_enhancement::<TestEnhancement>().unwrap();

        // Verify both are registered
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
        assert_eq!(
            registry.extension_names(ExtensionType::Enhancement).len(),
            1
        );

        // Test that extension namespace works
        let mut ext_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        ext_args.named.insert(
            "path".to_string(),
            ExtensionValue::String("data.parquet".to_string()),
        );
        ext_args
            .named
            .insert("batch_size".to_string(), ExtensionValue::Integer(2048));

        let ext_any = registry
            .parse_extension("TestExtension", &ext_args)
            .unwrap();
        assert_eq!(ext_any.type_url, "test.TestExtension");

        // Test that enhancement namespace works
        let mut enh_args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        enh_args.named.insert(
            "hint".to_string(),
            ExtensionValue::String("optimize".to_string()),
        );

        let enh_any = registry
            .parse_enhancement("TestEnhancement", &enh_args)
            .unwrap();
        assert_eq!(enh_any.type_url, "test.TestExtension"); // Same type URL!

        // The same type URL decodes to a different handler per namespace.
        let (name, args) = registry.decode_enhancement(enh_any.as_ref()).unwrap();
        assert_eq!(name, "TestEnhancement");
        match args.named.get("hint") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "optimize"),
            _ => panic!("Expected String for hint"),
        }

        let (name, args) = registry.decode(ext_any.as_ref()).unwrap();
        assert_eq!(name, "TestExtension");
        match args.named.get("path") {
            Some(ExtensionValue::String(s)) => assert_eq!(s, "data.parquet"),
            _ => panic!("Expected String for path"),
        }
    }

    #[test]
    fn test_decode_respects_namespaces_and_unknown_type_urls() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        // Unknown type URL in a populated namespace.
        let unknown = Any::new("test.Unknown".to_string(), Vec::new());
        assert!(matches!(
            registry.decode(unknown.as_ref()),
            Err(ExtensionError::NotFound { .. })
        ));

        // A type URL registered only as a relation is invisible to the
        // enhancement and optimization namespaces.
        let known = TestExtension {
            path: "a.parquet".to_string(),
            batch_size: 1,
        }
        .to_any()
        .unwrap();
        assert!(registry.decode(known.as_ref()).is_ok());
        assert!(matches!(
            registry.decode_enhancement(known.as_ref()),
            Err(ExtensionError::NotFound { .. })
        ));
        assert!(matches!(
            registry.decode_optimization(known.as_ref()),
            Err(ExtensionError::NotFound { .. })
        ));
    }

    #[test]
    fn test_descriptors_are_kept_in_registration_order() {
        let mut registry = ExtensionRegistry::new();
        assert!(registry.descriptors().is_empty());

        registry.add_descriptor(vec![1, 2, 3]);
        registry.add_descriptor(vec![4]);
        assert_eq!(registry.descriptors(), vec![&[1u8, 2, 3][..], &[4u8][..]]);
    }

    #[test]
    fn test_enhancement_duplicate_registration_returns_error() {
        let mut registry = ExtensionRegistry::new();
        registry.register_enhancement::<TestEnhancement>().unwrap();
        let result = registry.register_enhancement::<TestEnhancement>();
        assert!(matches!(
            result,
            Err(RegistrationError::DuplicateName { .. })
        ));
    }

    #[test]
    fn test_enhancement_not_found_error() {
        let registry = ExtensionRegistry::new();
        let args = ExtensionArgs::new(ExtensionRelationType::Leaf);
        let result = registry.parse_enhancement("NonExistentEnhancement", &args);
        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
    }

    // Extension with same type URL as TestExtension but different name,
    // used to test that conflicting type URLs don't leave stale state.
    struct ConflictingExtension;

    impl AnyConvertible for ConflictingExtension {
        fn to_any(&self) -> Result<Any, ExtensionError> {
            Ok(Any::new(Self::type_url(), vec![]))
        }

        fn type_url() -> String {
            // Same type URL as TestExtension — will conflict in the same namespace
            "test.TestExtension".to_string()
        }

        fn from_any<'a>(_any: AnyRef<'a>) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }
    }

    impl Explainable for ConflictingExtension {
        fn name() -> &'static str {
            "ConflictingExtension"
        }

        fn from_args(_args: &ExtensionArgs) -> Result<Self, ExtensionError> {
            Ok(ConflictingExtension)
        }

        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
            Ok(ExtensionArgs::new(ExtensionRelationType::Leaf))
        }
    }

    #[test]
    fn test_conflicting_type_url_leaves_registry_unchanged() {
        let mut registry = ExtensionRegistry::new();
        registry.register_relation::<TestExtension>().unwrap();

        // Attempt to register a different extension with the same type URL
        let result = registry.register_relation::<ConflictingExtension>();
        assert!(matches!(
            result,
            Err(RegistrationError::ConflictingTypeUrl { .. })
        ));

        // Registry should still only know about the original extension
        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
        assert!(!registry.has_extension(ExtensionType::Relation, "ConflictingExtension"));
        assert_eq!(
            registry.extension_names(ExtensionType::Relation),
            vec!["TestExtension"]
        );
    }
}