Skip to main content

saluki_common/
deser.rs

1//! Deserialization helpers.
2//!
3//! This module provides various helpers for handling the deserialization of common data types in more flexible and
4//! permissive ways. These helpers are designed to be used with the `serde_with` crate.
5
6use std::fmt;
7
8use serde::{
9    de::{Error, Unexpected},
10    Deserializer,
11};
12use serde_with::DeserializeAs;
13
/// Permissively deserializes a boolean.
///
/// This helper type allows deserializing a `bool` from a number of possible data types:
///
/// - `true` or `false` as a native boolean
/// - `1` or `0` as a number (signed integer, unsigned integer, or floating point)
/// - `"true"` or `"false"` as a string, case insensitive
/// - `"1"`, `"t"`, or `"T"` as truthy strings; `"0"`, `"f"`, or `"F"` as falsy strings
///
/// Any other value (e.g. `2`, `0.5`, `"yes"`, or a string with surrounding whitespace) is rejected with an error.
///
/// This type implements [`DeserializeAs`] only (no serialization counterpart), so attach it to a `bool` field with
/// `#[serde_as(deserialize_as = "PermissiveBool")]`. Deserialization goes through `deserialize_any`, so the input
/// format must be self-describing.
#[derive(Clone, Copy, Debug)]
pub struct PermissiveBool;
23
24impl<'de> DeserializeAs<'de, bool> for PermissiveBool {
25    fn deserialize_as<D>(deserializer: D) -> Result<bool, D::Error>
26    where
27        D: Deserializer<'de>,
28    {
29        struct Visitor;
30
31        impl<'vde> serde::de::Visitor<'vde> for Visitor {
32            type Value = bool;
33
34            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
35                formatter.write_str("a boolean, string, integer, or floating-point number")
36            }
37
38            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
39            where
40                E: Error,
41            {
42                Ok(value)
43            }
44
45            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
46            where
47                E: Error,
48            {
49                // Check short forms first, then fall back to case-insensitive "true"/"false".
50                match value {
51                    "1" | "t" | "T" => Ok(true),
52                    "0" | "f" | "F" => Ok(false),
53                    _ => match value.to_lowercase().as_str() {
54                        "true" => Ok(true),
55                        "false" => Ok(false),
56                        _ => Err(Error::invalid_value(
57                            Unexpected::Str(value),
58                            &"a boolean string (\"true\" or \"false\", case insensitive, or short forms: \"1\", \"t\", \"T\", \"0\", \"f\", \"F\")",
59                        )),
60                    },
61                }
62            }
63
64            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
65            where
66                E: Error,
67            {
68                match value {
69                    0 => Ok(false),
70                    1 => Ok(true),
71                    _ => Err(Error::invalid_value(Unexpected::Signed(value), &"0 or 1")),
72                }
73            }
74
75            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
76            where
77                E: Error,
78            {
79                match value {
80                    0 => Ok(false),
81                    1 => Ok(true),
82                    _ => Err(Error::invalid_value(Unexpected::Unsigned(value), &"0 or 1")),
83                }
84            }
85
86            fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
87            where
88                E: Error,
89            {
90                match value {
91                    0.0 => Ok(false),
92                    1.0 => Ok(true),
93                    _ => Err(Error::invalid_value(Unexpected::Float(value), &"0.0 or 1.0")),
94                }
95            }
96        }
97
98        deserializer.deserialize_any(Visitor)
99    }
100}
101
#[cfg(test)]
mod tests {
    use serde::de::{value::StrDeserializer, IntoDeserializer};
    use serde_with::DeserializeAs;

    use super::PermissiveBool;

    /// Deserializes `v` via a `BoolDeserializer` (exercises `visit_bool`).
    fn parse_bool(v: bool) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    /// Deserializes `s` via a `StrDeserializer` (exercises `visit_str`).
    fn parse_str(s: &str) -> Result<bool, serde::de::value::Error> {
        let de: StrDeserializer<serde::de::value::Error> = s.into_deserializer();
        PermissiveBool::deserialize_as(de)
    }

    /// Deserializes `v` via an `I64Deserializer` (exercises `visit_i64`).
    fn parse_int(v: i64) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    /// Deserializes `v` via a `U64Deserializer` (exercises `visit_u64`).
    fn parse_uint(v: u64) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    /// Deserializes `v` via an `F64Deserializer` (exercises `visit_f64`).
    fn parse_float(v: f64) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    // Native boolean
    #[test]
    fn native_true() {
        assert!(parse_bool(true).unwrap());
    }

    #[test]
    fn native_false() {
        assert!(!parse_bool(false).unwrap());
    }

    // String variants
    #[test]
    fn str_truthy() {
        for s in &["1", "t", "T", "true", "True", "tRuE", "TRUE"] {
            assert!(parse_str(s).unwrap(), "expected {s:?} to be truthy");
        }
    }

    #[test]
    fn str_falsy() {
        for s in &["0", "f", "F", "false", "False", "fAlSe", "FALSE"] {
            assert!(!parse_str(s).unwrap(), "expected {s:?} to be falsy");
        }
    }

    // Invalid string
    #[test]
    fn str_invalid_rejected() {
        assert!(parse_str("yes").is_err());
        assert!(parse_str("no").is_err());
        assert!(parse_str("2").is_err());
        assert!(parse_str("").is_err());
    }

    // Surrounding whitespace is not trimmed, so padded values are rejected.
    #[test]
    fn str_whitespace_rejected() {
        assert!(parse_str(" true").is_err());
        assert!(parse_str("false\n").is_err());
        assert!(parse_str(" 1 ").is_err());
    }

    // Signed integer variants
    #[test]
    fn int_true() {
        assert!(parse_int(1).unwrap());
    }

    #[test]
    fn int_false() {
        assert!(!parse_int(0).unwrap());
    }

    #[test]
    fn int_invalid_rejected() {
        assert!(parse_int(2).is_err());
        assert!(parse_int(-1).is_err());
        assert!(parse_int(i64::MAX).is_err());
    }

    // Unsigned integer variants
    #[test]
    fn uint_variants() {
        assert!(parse_uint(1).unwrap());
        assert!(!parse_uint(0).unwrap());
        assert!(parse_uint(2).is_err());
        assert!(parse_uint(u64::MAX).is_err());
    }

    // Floating-point variants
    #[test]
    fn float_variants() {
        assert!(parse_float(1.0).unwrap());
        assert!(!parse_float(0.0).unwrap());
        assert!(parse_float(0.5).is_err());
        assert!(parse_float(2.0).is_err());
        assert!(parse_float(f64::NAN).is_err());
    }
}