Skip to main content

substrait_explain/extensions/
registry.rs

1//! Registry for custom Substrait advanced extension payloads.
2//!
//! This module lets users register handlers for advanced extensions that carry
//! `google.protobuf.Any` detail payloads: custom relation types, extension
//! table details, relation enhancements, and optimization hints.
6//!
7//! # Overview
8//!
9//! The extension registry allows users to:
//! - Register custom extension handlers in the relation, extension table,
//!   enhancement, or optimization namespaces
12//! - Parse extension arguments/named arguments into `google.protobuf.Any`
13//!   detail fields
14//! - Textify extension detail fields back into readable text format
15//! - Keep each registered payload's protobuf type URL associated with its
16//!   canonical text-format name
17//!
18//! # Architecture
19//!
//! The system is built around two user-implemented traits and one registry type:
//! - `AnyConvertible` (trait): converts types to/from protobuf Any messages
//! - `Explainable` (trait): converts types to/from `ExtensionArgs`
//! - `ExtensionRegistry` (struct): stores the registered extension handlers
24//!
25//! # Example Usage
26//!
27//! ```rust
28//! use substrait_explain::extensions::{
29//!     Any, AnyConvertible, AnyRef, Explainable, ExtensionArgs, ExtensionError, ExtensionRegistry,
30//! };
31//!
32//! // Define a custom extension type
33//! struct CustomScanConfig {
34//!     path: String,
35//! }
36//!
37//! // Implement AnyConvertible for protobuf serialization
38//! impl AnyConvertible for CustomScanConfig {
39//!     fn to_any(&self) -> Result<Any, ExtensionError> {
40//!         // For this example, we'll create a simple Any (protobuf details field) with the path
41//!         Ok(Any::new(Self::type_url(), self.path.as_bytes().to_vec()))
42//!     }
43//!
44//!     fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
45//!         // Deserialize from Any
46//!         let path = String::from_utf8(any.value.to_vec())
47//!             .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {}", e)))?;
48//!         Ok(CustomScanConfig { path })
49//!     }
50//!
51//!     fn type_url() -> String {
52//!         "type.googleapis.com/example.CustomScanConfig".to_string()
53//!     }
54//! }
55//!
56//! // Implement Explainable for text format conversion
57//! impl Explainable for CustomScanConfig {
58//!     fn name() -> &'static str {
59//!         "ParquetScan"
60//!     }
61//!
62//!     fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
63//!         let mut extractor = args.extractor();
64//!         let path: &str = extractor.expect_named_arg("path")?;
65//!         extractor.check_exhausted()?;
66//!         Ok(CustomScanConfig {
67//!             path: path.to_string(),
68//!         })
69//!     }
70//!
71//!     fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
72//!         let mut args = ExtensionArgs::default();
73//!         args.insert("path", self.path.clone());
74//!         Ok(args)
75//!     }
76//! }
77//!
78//! // Register the extension type
79//! let mut registry = ExtensionRegistry::new();
80//! registry.register_relation::<CustomScanConfig>().unwrap();
81//! ```
82
83use std::collections::HashMap;
84use std::fmt;
85use std::sync::Arc;
86
87use substrait::proto::NamedStruct;
88use substrait::proto::r#type::{Nullability, Struct};
89use thiserror::Error;
90
91use crate::extensions::any::{Any, AnyRef};
92use crate::extensions::args::{ExtensionArgs, ExtensionColumn, ExtensionValueKind};
93
/// Namespace that an extension is registered under.
///
/// The registry keys its handlers by `(ExtensionType, name)` and its type URLs
/// by `(ExtensionType, type_url)`, so each variant is an independent namespace.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtensionType {
    /// A relation extension (e.g., ExtensionLeaf, ExtensionSingle, ExtensionMulti).
    Relation,
    /// ExtensionTable detail attached to a ReadRel; written with the `+ Ext:`
    /// prefix in text format.
    ExtensionTable,
    /// Enhancement attached to a relation; written with the `+ Enh:` prefix in
    /// text format.
    Enhancement,
    /// Optimization attached to a relation; written with the `+ Opt:` prefix
    /// in text format.
    Optimization,
}
106
107/// Errors during extension registration (setup phase)
108#[derive(Debug, Error, Clone)]
109pub enum RegistrationError {
110    #[error("{ext_type:?} extension '{name}' already registered")]
111    DuplicateName {
112        ext_type: ExtensionType,
113        name: String,
114    },
115
116    #[error("Type URL '{type_url}' already registered to {ext_type:?} extension '{existing_name}'")]
117    ConflictingTypeUrl {
118        type_url: String,
119        ext_type: ExtensionType,
120        existing_name: String,
121    },
122}
123
124/// Errors during extension parsing, formatting, and argument extraction (runtime)
125#[derive(Debug, Error, Clone)]
126pub enum ExtensionError {
127    /// Extension not found in registry during lookup
128    #[error("Extension '{name}' not found in registry")]
129    NotFound { name: String },
130
131    /// Required argument not present (from ArgsExtractor)
132    #[error("Missing required argument: {name}")]
133    MissingArgument { name: String },
134
135    /// Invalid argument type found while extracting an extension argument.
136    #[error("Invalid argument: expected {expected}, got {actual}")]
137    InvalidArgumentType {
138        expected: ExtensionValueKind,
139        actual: ExtensionValueKind,
140    },
141
142    /// Invalid argument value with a custom diagnostic.
143    ///
144    /// Prefer structured variants for common mechanical validation failures.
145    /// Use this for domain-specific validation from `Explainable`
146    /// implementations.
147    #[error("Invalid argument: {0}")]
148    InvalidArgument(String),
149
150    /// Type URL mismatch during protobuf Any decode
151    #[error("Type URL mismatch: expected {expected}, got {actual}")]
152    TypeUrlMismatch { expected: String, actual: String },
153
154    /// Protobuf message decode failure
155    #[error("Failed to decode protobuf message")]
156    DecodeFailed(#[source] prost::DecodeError),
157
158    /// Protobuf message encode failure
159    #[error("Failed to encode protobuf message")]
160    EncodeFailed(#[source] prost::EncodeError),
161
162    /// Extension detail field is missing from the relation
163    #[error("Extension detail is missing")]
164    MissingDetail,
165
166    /// Error from a custom AnyConvertible implementation
167    #[error("{0}")]
168    Custom(String),
169}
170
171/// Trait for types that can be converted to/from protobuf Any messages. Note
172/// that this is already implemented for all prost::Message types. For custom
173/// types, implement this trait.
174pub trait AnyConvertible: Sized {
175    /// Convert this type to a protobuf Any message
176    fn to_any(&self) -> Result<Any, ExtensionError>;
177
178    /// Convert from a protobuf Any message to this type
179    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError>;
180
181    /// Get the protobuf type URL for this type.
182    /// For prost::Message types, this is provided automatically via blanket impl.
183    /// Custom types must implement this method.
184    fn type_url() -> String;
185}
186
187// Blanket implementation for all prost::Message types
188impl<T> AnyConvertible for T
189where
190    T: prost::Message + prost::Name + Default,
191{
192    fn to_any(&self) -> Result<Any, ExtensionError> {
193        Any::encode(self)
194    }
195
196    fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
197        any.decode()
198    }
199
200    fn type_url() -> String {
201        T::type_url()
202    }
203}
204
205/// Conversion between extension arguments and Substrait protobuf values.
206///
207/// Extension arguments are the structured values exposed to [`Explainable`]
208/// implementations, such as [`ExtensionArgs`],
209/// [`ExtensionValue`](crate::extensions::ExtensionValue), and
210/// [`ExtensionColumn`]. This trait adapts those values to and from protobuf
211/// types without going through text.
212///
213/// Implementations may convert in either direction; the target type `T`
214/// determines the direction.
215pub trait ExtensionProtoConvert<T> {
216    /// Convert this value into `T`.
217    fn convert(&self) -> Result<T, ExtensionError>;
218}
219
220impl ExtensionProtoConvert<NamedStruct> for [ExtensionColumn] {
221    fn convert(&self) -> Result<NamedStruct, ExtensionError> {
222        let mut names = Vec::with_capacity(self.len());
223        let mut types = Vec::with_capacity(self.len());
224        for col in self {
225            match col {
226                ExtensionColumn::Named { name, r#type: ty } => {
227                    names.push(name.clone());
228                    types.push(ty.clone());
229                }
230                other => {
231                    return Err(ExtensionError::InvalidArgument(format!(
232                        "Expected named column, got {other:?}"
233                    )));
234                }
235            }
236        }
237        Ok(NamedStruct {
238            names,
239            r#struct: Some(Struct {
240                types,
241                type_variation_reference: 0,
242                // In Substrait, the schema of a type is defined as
243                // non-nullable; you can have an empty schema (no columns), but
244                // not a null schema.
245                nullability: Nullability::Required as i32,
246            }),
247        })
248    }
249}
250
251impl ExtensionProtoConvert<Vec<ExtensionColumn>> for NamedStruct {
252    fn convert(&self) -> Result<Vec<ExtensionColumn>, ExtensionError> {
253        let types = self
254            .r#struct
255            .as_ref()
256            .map(|s| s.types.as_slice())
257            .unwrap_or_default();
258        if self.names.len() != types.len() {
259            return Err(ExtensionError::InvalidArgument(format!(
260                "NamedStruct has {} names but {} types",
261                self.names.len(),
262                types.len()
263            )));
264        }
265        Ok(self
266            .names
267            .iter()
268            .zip(types.iter())
269            .map(|(name, ty)| ExtensionColumn::Named {
270                name: name.clone(),
271                r#type: ty.clone(),
272            })
273            .collect())
274    }
275}
276
277/// Trait for types that participate in text explanations.
278pub trait Explainable: Sized {
279    /// Canonical textual name for this extension. This is what appears in
280    /// Substrait text plans and how the registry identifies the type.
281    fn name() -> &'static str;
282
283    /// Parse extension arguments into this type
284    fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError>;
285
286    /// Convert this type to extension arguments
287    fn to_args(&self) -> Result<ExtensionArgs, ExtensionError>;
288}
289
290/// Internal trait that converts between ExtensionArgs and protobuf Any messages.
291///
292/// This trait exists because we need to store handlers for different extension types
293/// in a single HashMap. Since Rust doesn't allow trait objects with multiple traits
294/// (like `Box<dyn AnyConvertible + Explainable>`), we need a single trait that
295/// combines both operations.
296///
297/// The ExtensionConverter acts as a bridge between:
298/// - The text format representation (ExtensionArgs) used by the parser/formatter
299/// - The protobuf Any messages stored in Substrait advanced extension payloads
300///
301/// This design allows the registry to work with any type while maintaining type safety
302/// through the AnyConvertible and Explainable traits that users implement.
303trait ExtensionConverter: Send + Sync {
304    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError>;
305
306    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError>;
307}
308
/// Zero-sized adapter that implements `ExtensionConverter` on behalf of a type
/// `T` implementing both `AnyConvertible` and `Explainable`.
///
/// The registry stores `Arc<dyn ExtensionConverter>` values; this adapter is
/// the concrete type behind each one. It lets the registry:
/// 1. Expose an API where users only implement `AnyConvertible` and
///    `Explainable`
/// 2. Keep handlers for different types in the same `HashMap` via type erasure
/// 3. Stay type safe, since `T` is fixed at registration time
/// 4. Avoid runtime type checks and unsafe code
///
/// No `T` value is ever stored. `PhantomData<T>` only carries the type so that
/// `T`'s associated functions (`from_args`, `from_any`) can be called.
struct ExtensionAdapter<T>(std::marker::PhantomData<T>);
326
327impl<T: AnyConvertible + Explainable + Send + Sync> ExtensionConverter for ExtensionAdapter<T> {
328    fn parse_detail(&self, args: &ExtensionArgs) -> Result<Any, ExtensionError> {
329        T::from_args(args)?.to_any()
330    }
331
332    fn textify_detail(&self, detail: AnyRef<'_>) -> Result<ExtensionArgs, ExtensionError> {
333        let owned_any = Any::new(detail.type_url.to_string(), detail.value.to_vec());
334        T::from_any(owned_any.as_ref())?.to_args()
335    }
336}
337
338pub trait Extension: AnyConvertible + Explainable + Send + Sync + 'static {}
339
340impl<T> Extension for T where T: AnyConvertible + Explainable + Send + Sync + 'static {}
341
342/// Registry for extension handlers
343#[derive(Default, Clone)]
344pub struct ExtensionRegistry {
345    // Composite key: (ExtensionType, name) -> handler
346    handlers: HashMap<(ExtensionType, String), Arc<dyn ExtensionConverter>>,
347    // Composite key: (ExtensionType, type_url) -> name
348    type_urls: HashMap<(ExtensionType, String), String>,
349    // Compiled proto FileDescriptorSet blobs for extension types.
350    // Used by the JSON parser to resolve google.protobuf.Any type URLs in Go
351    // protojson input. Register these alongside the Rust handler so that a
352    // single registry carries all extension knowledge for both formatting and
353    // JSON parsing.
354    descriptors: Vec<Vec<u8>>,
355}
356
357impl ExtensionRegistry {
358    /// Create a new empty extension registry
359    pub fn new() -> Self {
360        Self {
361            handlers: HashMap::new(),
362            type_urls: HashMap::new(),
363            descriptors: Vec::new(),
364        }
365    }
366
367    /// Register a compiled proto `FileDescriptorSet` blob for extension types.
368    ///
369    /// Required when parsing extensions for plans that contain
370    /// `google.protobuf.Any` fields that use standard JSON encoding (with
371    /// `@type` for the type_url) whose types are not part of the Substrait core
372    /// schema. Pass the bytes of a compiled `.bin` descriptor, e.g.
373    /// `include_bytes!("my_extensions.bin")`.
374    pub fn add_descriptor(&mut self, bytes: Vec<u8>) {
375        self.descriptors.push(bytes);
376    }
377
378    /// Returns slices of all registered descriptor blobs.
379    pub fn descriptors(&self) -> Vec<&[u8]> {
380        self.descriptors.iter().map(|b| b.as_slice()).collect()
381    }
382
383    /// Register an extension type with a specific ExtensionType
384    fn register<T>(&mut self, ext_type: ExtensionType) -> Result<(), RegistrationError>
385    where
386        T: Extension,
387    {
388        let canonical_name = T::name();
389        let type_url = T::type_url();
390        let handler: Arc<dyn ExtensionConverter> =
391            Arc::new(ExtensionAdapter::<T>(std::marker::PhantomData));
392
393        let key = (ext_type, canonical_name.to_string());
394        if self.handlers.contains_key(&key) {
395            return Err(RegistrationError::DuplicateName {
396                ext_type,
397                name: canonical_name.to_string(),
398            });
399        }
400
401        // Check for type URL conflicts before mutating any state
402        let type_url_key = (ext_type, type_url.clone());
403        if let Some(existing) = self.type_urls.get(&type_url_key)
404            && existing != canonical_name
405        {
406            return Err(RegistrationError::ConflictingTypeUrl {
407                type_url,
408                ext_type,
409                existing_name: existing.clone(),
410            });
411        }
412
413        // All checks passed — safe to mutate
414        self.handlers.insert(key, Arc::clone(&handler));
415        self.type_urls
416            .insert(type_url_key, canonical_name.to_string());
417        Ok(())
418    }
419
420    /// Register a relation extension type that implements both AnyConvertible and Explainable
421    ///
422    /// The canonical textual name comes from `T::name()`.
423    pub fn register_relation<T>(&mut self) -> Result<(), RegistrationError>
424    where
425        T: Extension,
426    {
427        self.register::<T>(ExtensionType::Relation)
428    }
429
430    /// Register an ExtensionTable detail type that implements both AnyConvertible and Explainable
431    ///
432    /// ExtensionTable details are registered in a separate namespace from
433    /// extension relations, allowing the same type URL to exist in both namespaces
434    /// without conflict.
435    ///
436    /// The canonical textual name comes from `T::name()`.
437    pub fn register_extension_table<T>(&mut self) -> Result<(), RegistrationError>
438    where
439        T: Extension,
440    {
441        self.register::<T>(ExtensionType::ExtensionTable)
442    }
443
444    /// Register an enhancement type that implements both AnyConvertible and Explainable
445    ///
446    /// Enhancements are registered in a separate namespace from relation extensions,
447    /// allowing the same type URL to exist in both namespaces without conflict.
448    ///
449    /// The canonical textual name comes from `T::name()`.
450    pub fn register_enhancement<T>(&mut self) -> Result<(), RegistrationError>
451    where
452        T: Extension,
453    {
454        self.register::<T>(ExtensionType::Enhancement)
455    }
456
457    /// Register an optimization type that implements both AnyConvertible and Explainable
458    ///
459    /// Optimizations are registered in a separate namespace from relation extensions,
460    /// allowing the same type URL to exist in both namespaces without conflict.
461    ///
462    /// The canonical textual name comes from `T::name()`.
463    pub fn register_optimization<T>(&mut self) -> Result<(), RegistrationError>
464    where
465        T: Extension,
466    {
467        self.register::<T>(ExtensionType::Optimization)
468    }
469
470    /// Parse extension arguments into a protobuf Any message
471    pub fn parse_extension(
472        &self,
473        extension_name: &str,
474        args: &ExtensionArgs,
475    ) -> Result<Any, ExtensionError> {
476        self.parse_with_type(ExtensionType::Relation, extension_name, args)
477    }
478
479    /// Parse ExtensionTable arguments into a protobuf Any message
480    ///
481    /// Looks up the ExtensionTable detail handler in the ExtensionTable namespace
482    /// and parses the arguments into a protobuf Any message.
483    pub fn parse_extension_table(
484        &self,
485        extension_table_name: &str,
486        args: &ExtensionArgs,
487    ) -> Result<Any, ExtensionError> {
488        self.parse_with_type(ExtensionType::ExtensionTable, extension_table_name, args)
489    }
490
491    /// Parse enhancement arguments into a protobuf Any message
492    ///
493    /// Looks up the enhancement handler in the enhancement namespace and parses
494    /// the arguments into a protobuf Any message.
495    pub fn parse_enhancement(
496        &self,
497        enhancement_name: &str,
498        args: &ExtensionArgs,
499    ) -> Result<Any, ExtensionError> {
500        self.parse_with_type(ExtensionType::Enhancement, enhancement_name, args)
501    }
502
503    /// Parse optimization arguments into a protobuf Any message
504    ///
505    /// Looks up the optimization handler in the optimization namespace and parses
506    /// the arguments into a protobuf Any message.
507    pub fn parse_optimization(
508        &self,
509        optimization_name: &str,
510        args: &ExtensionArgs,
511    ) -> Result<Any, ExtensionError> {
512        self.parse_with_type(ExtensionType::Optimization, optimization_name, args)
513    }
514
515    /// Internal method to parse extension arguments with a specific ExtensionType
516    fn parse_with_type(
517        &self,
518        ext_type: ExtensionType,
519        name: &str,
520        args: &ExtensionArgs,
521    ) -> Result<Any, ExtensionError> {
522        let key = (ext_type, name.to_string());
523        let handler = self
524            .handlers
525            .get(&key)
526            .ok_or_else(|| ExtensionError::NotFound {
527                name: name.to_string(),
528            })?;
529        handler.parse_detail(args)
530    }
531
532    /// Decode extension detail to extension name and ExtensionArgs
533    /// This is the primary method for textification - given an AnyRef with extension detail,
534    /// decode it to the extension name and appropriate ExtensionArgs for display
535    pub fn decode(&self, detail: AnyRef<'_>) -> Result<(String, ExtensionArgs), ExtensionError> {
536        self.decode_with_type(ExtensionType::Relation, detail)
537    }
538
539    /// Decode ExtensionTable detail to extension name and ExtensionArgs
540    ///
541    /// This is the primary method for textification of ExtensionTable reads -
542    /// given an AnyRef with ExtensionTable detail, decode it to the extension
543    /// name and appropriate ExtensionArgs for display.
544    pub fn decode_extension_table(
545        &self,
546        detail: AnyRef<'_>,
547    ) -> Result<(String, ExtensionArgs), ExtensionError> {
548        self.decode_with_type(ExtensionType::ExtensionTable, detail)
549    }
550
551    /// Decode enhancement detail to enhancement name and ExtensionArgs
552    ///
553    /// This is the primary method for textification of enhancements - given an AnyRef
554    /// with enhancement detail, decode it to the enhancement name and appropriate
555    /// ExtensionArgs for display.
556    ///
557    /// Looks up the enhancement handler in the enhancement namespace by type URL.
558    pub fn decode_enhancement(
559        &self,
560        detail: AnyRef<'_>,
561    ) -> Result<(String, ExtensionArgs), ExtensionError> {
562        self.decode_with_type(ExtensionType::Enhancement, detail)
563    }
564
565    /// Decode optimization detail to optimization name and ExtensionArgs
566    ///
567    /// This is the primary method for textification of optimizations - given an AnyRef
568    /// with optimization detail, decode it to the optimization name and appropriate
569    /// ExtensionArgs for display.
570    ///
571    /// Looks up the optimization handler in the optimization namespace by type URL.
572    pub fn decode_optimization(
573        &self,
574        detail: AnyRef<'_>,
575    ) -> Result<(String, ExtensionArgs), ExtensionError> {
576        self.decode_with_type(ExtensionType::Optimization, detail)
577    }
578
579    /// Internal method to decode extension detail with a specific ExtensionType
580    fn decode_with_type(
581        &self,
582        ext_type: ExtensionType,
583        detail: AnyRef<'_>,
584    ) -> Result<(String, ExtensionArgs), ExtensionError> {
585        // Find extension name by type URL in the specified namespace
586        let type_url_key = (ext_type, detail.type_url.to_string());
587        let extension_name =
588            self.type_urls
589                .get(&type_url_key)
590                .ok_or_else(|| ExtensionError::NotFound {
591                    name: detail.type_url.to_string(),
592                })?;
593
594        // Get handler and textify the detail
595        let name_key = (ext_type, extension_name.clone());
596        let handler = self
597            .handlers
598            .get(&name_key)
599            .ok_or_else(|| ExtensionError::NotFound {
600                name: extension_name.clone(),
601            })?;
602
603        let args = handler.textify_detail(detail)?;
604
605        Ok((extension_name.clone(), args))
606    }
607
608    /// Get all registered extension names for a specific ExtensionType
609    pub fn extension_names(&self, ext_type: ExtensionType) -> Vec<&str> {
610        let mut names: Vec<&str> = self
611            .type_urls
612            .iter()
613            .filter_map(|((t, _), name)| {
614                if *t == ext_type {
615                    Some(name.as_str())
616                } else {
617                    None
618                }
619            })
620            .collect();
621        names.sort_unstable();
622        names.dedup();
623        names
624    }
625
626    /// Check if an extension is registered for a specific ExtensionType
627    pub fn has_extension(&self, ext_type: ExtensionType, name: &str) -> bool {
628        self.handlers.contains_key(&(ext_type, name.to_string()))
629    }
630}
631
632impl fmt::Debug for ExtensionRegistry {
633    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
634        let mut keys: Vec<_> = self
635            .handlers
636            .keys()
637            .map(|(t, n)| (format!("{t:?}"), n.as_str()))
638            .collect();
639        keys.sort();
640        f.debug_struct("ExtensionRegistry")
641            .field("handlers", &keys)
642            .finish()
643    }
644}
645
646#[cfg(test)]
647mod tests {
648    use super::*;
649    use crate::extensions::ExtensionColumn;
650
651    // Mock type for testing
652    struct TestExtension {
653        path: String,
654        batch_size: i64,
655    }
656
657    // Manual implementation of AnyConvertible for testing (without prost)
658    impl AnyConvertible for TestExtension {
659        fn to_any(&self) -> Result<Any, ExtensionError> {
660            // Simple test implementation - create Any with JSON-like bytes
661            let json_str = format!(
662                r#"{{"path":"{}","batch_size":{}}}"#,
663                self.path, self.batch_size
664            );
665            Ok(Any::new(Self::type_url(), json_str.into_bytes()))
666        }
667
668        fn type_url() -> String {
669            "test.TestExtension".to_string()
670        }
671
672        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
673            // Simple test implementation - parse from JSON-like bytes
674            let json_str = String::from_utf8(any.value.to_vec())
675                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
676
677            // Simple manual parsing for test
678            if json_str.contains("path") && json_str.contains("batch_size") {
679                Ok(TestExtension {
680                    path: "test.parquet".to_string(),
681                    batch_size: 1024,
682                })
683            } else {
684                Err(ExtensionError::Custom("Missing fields".to_string()))
685            }
686        }
687    }
688
689    impl Explainable for TestExtension {
690        fn name() -> &'static str {
691            "TestExtension"
692        }
693
694        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
695            let mut extractor = args.extractor();
696            let path: String = extractor.expect_named_arg::<&str>("path")?.to_string();
697            let batch_size: i64 = extractor.expect_named_arg("batch_size")?;
698            extractor.check_exhausted()?;
699
700            Ok(TestExtension {
701                path: path.to_string(),
702                batch_size,
703            })
704        }
705
706        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
707            let mut args = ExtensionArgs::default();
708            args.insert("path", self.path.clone());
709            args.insert("batch_size", self.batch_size);
710            Ok(args)
711        }
712    }
713
714    #[test]
715    fn test_extension_registry_basic() {
716        let mut registry = ExtensionRegistry::new();
717
718        // Initially empty
719        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 0);
720        assert_eq!(
721            registry
722                .extension_names(ExtensionType::ExtensionTable)
723                .len(),
724            0
725        );
726        assert!(!registry.has_extension(ExtensionType::Relation, "TestExtension"));
727
728        // Register extension type
729        registry.register_relation::<TestExtension>().unwrap();
730
731        // Now has extension
732        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
733        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
734
735        // Test parse and textify
736        let mut args = ExtensionArgs::default();
737        args.insert("path", "data.parquet");
738        args.insert("batch_size", 2048_i64);
739
740        let any = registry.parse_extension("TestExtension", &args).unwrap();
741        assert_eq!(any.type_url, "test.TestExtension");
742
743        let any_ref = any.as_ref();
744        let result = registry.decode(any_ref).unwrap();
745        assert_eq!(result.0, "TestExtension");
746        assert_eq!(
747            <&str>::try_from(result.1.named.get("path").unwrap()).unwrap(),
748            "test.parquet"
749        );
750    }
751
752    #[test]
753    fn test_extension_table_registry_basic() {
754        let mut registry = ExtensionRegistry::new();
755
756        registry
757            .register_extension_table::<TestExtension>()
758            .unwrap();
759
760        assert_eq!(
761            registry.extension_names(ExtensionType::ExtensionTable),
762            vec!["TestExtension"]
763        );
764        assert!(registry.has_extension(ExtensionType::ExtensionTable, "TestExtension"));
765
766        let mut args = ExtensionArgs::default();
767        args.insert("path", "data.parquet");
768        args.insert("batch_size", 2048_i64);
769
770        let any = registry
771            .parse_extension_table("TestExtension", &args)
772            .unwrap();
773        assert_eq!(any.type_url, "test.TestExtension");
774
775        let (name, decoded_args) = registry.decode_extension_table(any.as_ref()).unwrap();
776        assert_eq!(name, "TestExtension");
777        assert_eq!(
778            <&str>::try_from(decoded_args.named.get("path").unwrap()).unwrap(),
779            "test.parquet"
780        );
781    }
782
783    #[test]
784    fn test_extension_args() {
785        let mut args = ExtensionArgs::default();
786
787        // Add named args
788        args.insert("path", "data/*.parquet");
789        args.insert("batch_size", 1024_i64);
790
791        // Add positional args
792        args.push(crate::textify::expressions::Reference(0));
793
794        // Add output columns
795        args.output_columns.push(ExtensionColumn::Named {
796            name: "col1".to_string(),
797            r#type: crate::fixtures::parse_type("i32"),
798        });
799
800        // Test retrieval - use extractor
801        let mut extractor = args.extractor();
802
803        let path = extractor.get_named_arg("path").unwrap();
804        assert_eq!(<&str>::try_from(path).unwrap(), "data/*.parquet");
805
806        let batch_size = extractor.get_named_arg("batch_size").unwrap();
807        assert_eq!(i64::try_from(batch_size).unwrap(), 1024);
808
809        // Verify they were consumed
810        assert!(extractor.check_exhausted().is_ok());
811
812        assert_eq!(args.positional.len(), 1);
813        assert_eq!(args.output_columns.len(), 1);
814    }
815
816    #[test]
817    fn test_extension_error_cases() {
818        let registry = ExtensionRegistry::new();
819
820        // Extension not found
821        let args = ExtensionArgs::default();
822        let result = registry.parse_extension("NonExistent", &args);
823        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
824
825        // Missing argument
826        let args = ExtensionArgs::default();
827        let mut extractor = args.extractor();
828        let result = extractor.get_named_arg("missing");
829        assert!(result.is_none());
830        assert!(extractor.check_exhausted().is_ok());
831
832        // Type check example
833        let mut args = ExtensionArgs::default();
834        args.insert("test", 42_i64);
835        let mut extractor = args.extractor();
836        let result = extractor.get_named_arg("test");
837        assert_eq!(i64::try_from(result.unwrap()).unwrap(), 42);
838        assert!(extractor.check_exhausted().is_ok());
839    }
840
841    // Mock enhancement type for testing namespace separation
842    struct TestEnhancement {
843        hint: String,
844    }
845
846    impl AnyConvertible for TestEnhancement {
847        fn to_any(&self) -> Result<Any, ExtensionError> {
848            let json_str = format!(r#"{{"hint":"{}"}}"#, self.hint);
849            Ok(Any::new(Self::type_url(), json_str.into_bytes()))
850        }
851
852        fn type_url() -> String {
853            // Same type URL as TestExtension to test namespace separation
854            "test.TestExtension".to_string()
855        }
856
857        fn from_any<'a>(any: AnyRef<'a>) -> Result<Self, ExtensionError> {
858            let json_str = String::from_utf8(any.value.to_vec())
859                .map_err(|e| ExtensionError::Custom(format!("Invalid UTF-8: {e}")))?;
860            if json_str.contains("hint") {
861                Ok(TestEnhancement {
862                    hint: "test_hint".to_string(),
863                })
864            } else {
865                Err(ExtensionError::Custom("Missing hint field".to_string()))
866            }
867        }
868    }
869
870    impl Explainable for TestEnhancement {
871        fn name() -> &'static str {
872            "TestEnhancement"
873        }
874
875        fn from_args(args: &ExtensionArgs) -> Result<Self, ExtensionError> {
876            let mut extractor = args.extractor();
877            let hint: String = extractor.expect_named_arg::<&str>("hint")?.to_string();
878            extractor.check_exhausted()?;
879            Ok(TestEnhancement { hint })
880        }
881
882        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
883            let mut args = ExtensionArgs::default();
884            args.insert("hint", self.hint.clone());
885            Ok(args)
886        }
887    }
888
889    #[test]
890    fn test_namespace_separation() {
891        let mut registry = ExtensionRegistry::new();
892
893        // Register same type URL in multiple namespaces - should not conflict
894        registry.register_relation::<TestExtension>().unwrap();
895        registry
896            .register_extension_table::<TestExtension>()
897            .unwrap();
898        registry.register_enhancement::<TestEnhancement>().unwrap();
899
900        // Verify all are registered
901        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
902        assert!(registry.has_extension(ExtensionType::ExtensionTable, "TestExtension"));
903        assert!(registry.has_extension(ExtensionType::Enhancement, "TestEnhancement"));
904        assert_eq!(registry.extension_names(ExtensionType::Relation).len(), 1);
905        assert_eq!(
906            registry
907                .extension_names(ExtensionType::ExtensionTable)
908                .len(),
909            1
910        );
911        assert_eq!(
912            registry.extension_names(ExtensionType::Enhancement).len(),
913            1
914        );
915
916        // Test that extension namespace works
917        let mut ext_args = ExtensionArgs::default();
918        ext_args.insert("path", "data.parquet");
919        ext_args.insert("batch_size", 2048_i64);
920
921        let ext_any = registry
922            .parse_extension("TestExtension", &ext_args)
923            .unwrap();
924        assert_eq!(ext_any.type_url, "test.TestExtension");
925
926        // Test that ExtensionTable namespace works independently
927        let table_any = registry
928            .parse_extension_table("TestExtension", &ext_args)
929            .unwrap();
930        assert_eq!(table_any.type_url, "test.TestExtension");
931
932        // Test that enhancement namespace works
933        let mut enh_args = ExtensionArgs::default();
934        enh_args.insert("hint", "optimize");
935
936        let enh_any = registry
937            .parse_enhancement("TestEnhancement", &enh_args)
938            .unwrap();
939        assert_eq!(enh_any.type_url, "test.TestExtension"); // Same type URL!
940
941        // Test decode_enhancement
942        let enh_ref = enh_any.as_ref();
943        let (name, args) = registry.decode_enhancement(enh_ref).unwrap();
944        assert_eq!(name, "TestEnhancement");
945        assert_eq!(
946            <&str>::try_from(args.named.get("hint").unwrap()).unwrap(),
947            "test_hint"
948        );
949    }
950
951    #[test]
952    fn test_enhancement_duplicate_registration_returns_error() {
953        let mut registry = ExtensionRegistry::new();
954        registry.register_enhancement::<TestEnhancement>().unwrap();
955        let result = registry.register_enhancement::<TestEnhancement>();
956        assert!(matches!(
957            result,
958            Err(RegistrationError::DuplicateName { .. })
959        ));
960    }
961
962    #[test]
963    fn test_extension_table_duplicate_registration_returns_error() {
964        let mut registry = ExtensionRegistry::new();
965        registry
966            .register_extension_table::<TestExtension>()
967            .unwrap();
968        let result = registry.register_extension_table::<TestExtension>();
969        assert!(matches!(
970            result,
971            Err(RegistrationError::DuplicateName { .. })
972        ));
973    }
974
975    #[test]
976    fn test_extension_table_not_found_error() {
977        let registry = ExtensionRegistry::new();
978        let args = ExtensionArgs::default();
979        let result = registry.parse_extension_table("NonExistentExtensionTable", &args);
980        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
981    }
982
983    #[test]
984    fn test_enhancement_not_found_error() {
985        let registry = ExtensionRegistry::new();
986        let args = ExtensionArgs::default();
987        let result = registry.parse_enhancement("NonExistentEnhancement", &args);
988        assert!(matches!(result, Err(ExtensionError::NotFound { .. })));
989    }
990
991    // Extension with same type URL as TestExtension but different name,
992    // used to test that conflicting type URLs don't leave stale state.
993    struct ConflictingExtension;
994
995    impl AnyConvertible for ConflictingExtension {
996        fn to_any(&self) -> Result<Any, ExtensionError> {
997            Ok(Any::new(Self::type_url(), vec![]))
998        }
999
1000        fn type_url() -> String {
1001            // Same type URL as TestExtension — will conflict in the same namespace
1002            "test.TestExtension".to_string()
1003        }
1004
1005        fn from_any<'a>(_any: AnyRef<'a>) -> Result<Self, ExtensionError> {
1006            Ok(ConflictingExtension)
1007        }
1008    }
1009
1010    impl Explainable for ConflictingExtension {
1011        fn name() -> &'static str {
1012            "ConflictingExtension"
1013        }
1014
1015        fn from_args(_args: &ExtensionArgs) -> Result<Self, ExtensionError> {
1016            Ok(ConflictingExtension)
1017        }
1018
1019        fn to_args(&self) -> Result<ExtensionArgs, ExtensionError> {
1020            Ok(ExtensionArgs::default())
1021        }
1022    }
1023
1024    #[test]
1025    fn test_conflicting_type_url_leaves_registry_unchanged() {
1026        let mut registry = ExtensionRegistry::new();
1027        registry.register_relation::<TestExtension>().unwrap();
1028
1029        // Attempt to register a different extension with the same type URL
1030        let result = registry.register_relation::<ConflictingExtension>();
1031        assert!(matches!(
1032            result,
1033            Err(RegistrationError::ConflictingTypeUrl { .. })
1034        ));
1035
1036        // Registry should still only know about the original extension
1037        assert!(registry.has_extension(ExtensionType::Relation, "TestExtension"));
1038        assert!(!registry.has_extension(ExtensionType::Relation, "ConflictingExtension"));
1039        assert_eq!(
1040            registry.extension_names(ExtensionType::Relation),
1041            vec!["TestExtension"]
1042        );
1043    }
1044
1045    #[test]
1046    fn test_extension_table_conflicting_type_url_leaves_registry_unchanged() {
1047        let mut registry = ExtensionRegistry::new();
1048        registry
1049            .register_extension_table::<TestExtension>()
1050            .unwrap();
1051
1052        // Attempt to register a different extension table with the same type URL
1053        let result = registry.register_extension_table::<ConflictingExtension>();
1054        assert!(matches!(
1055            result,
1056            Err(RegistrationError::ConflictingTypeUrl { .. })
1057        ));
1058
1059        // Registry should still only know about the original extension table
1060        assert!(registry.has_extension(ExtensionType::ExtensionTable, "TestExtension"));
1061        assert!(!registry.has_extension(ExtensionType::ExtensionTable, "ConflictingExtension"));
1062        assert_eq!(
1063            registry.extension_names(ExtensionType::ExtensionTable),
1064            vec!["TestExtension"]
1065        );
1066    }
1067}