1use std::fmt;
7
8use serde::{
9 de::{Error, Unexpected},
10 Deserializer,
11};
12use serde_with::DeserializeAs;
13
/// Lenient `bool` deserialization adapter for `serde_with`.
///
/// Accepts:
/// - native booleans;
/// - the strings `"true"` / `"false"` (case-insensitive) and the short forms
///   `"1"`, `"t"`, `"T"` (true) and `"0"`, `"f"`, `"F"` (false);
/// - the integers `0` / `1`;
/// - the floating-point numbers `0.0` / `1.0`.
///
/// Any other value is reported as a deserialization error.
///
/// Apply it to a `bool` field through `serde_with`'s `serde_as` attribute,
/// e.g. `#[serde_as(as = "PermissiveBool")]`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PermissiveBool;
23
24impl<'de> DeserializeAs<'de, bool> for PermissiveBool {
25 fn deserialize_as<D>(deserializer: D) -> Result<bool, D::Error>
26 where
27 D: Deserializer<'de>,
28 {
29 struct Visitor;
30
31 impl<'vde> serde::de::Visitor<'vde> for Visitor {
32 type Value = bool;
33
34 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
35 formatter.write_str("a boolean, string, integer, or floating-point number")
36 }
37
38 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
39 where
40 E: Error,
41 {
42 Ok(value)
43 }
44
45 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
46 where
47 E: Error,
48 {
49 match value {
51 "1" | "t" | "T" => Ok(true),
52 "0" | "f" | "F" => Ok(false),
53 _ => match value.to_lowercase().as_str() {
54 "true" => Ok(true),
55 "false" => Ok(false),
56 _ => Err(Error::invalid_value(
57 Unexpected::Str(value),
58 &"a boolean string (\"true\" or \"false\", case insensitive, or short forms: \"1\", \"t\", \"T\", \"0\", \"f\", \"F\")",
59 )),
60 },
61 }
62 }
63
64 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
65 where
66 E: Error,
67 {
68 match value {
69 0 => Ok(false),
70 1 => Ok(true),
71 _ => Err(Error::invalid_value(Unexpected::Signed(value), &"0 or 1")),
72 }
73 }
74
75 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
76 where
77 E: Error,
78 {
79 match value {
80 0 => Ok(false),
81 1 => Ok(true),
82 _ => Err(Error::invalid_value(Unexpected::Unsigned(value), &"0 or 1")),
83 }
84 }
85
86 fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
87 where
88 E: Error,
89 {
90 match value {
91 0.0 => Ok(false),
92 1.0 => Ok(true),
93 _ => Err(Error::invalid_value(Unexpected::Float(value), &"0.0 or 1.0")),
94 }
95 }
96 }
97
98 deserializer.deserialize_any(Visitor)
99 }
100}
101
#[cfg(test)]
mod tests {
    use serde::de::{value::StrDeserializer, IntoDeserializer};
    use serde_with::DeserializeAs;

    use super::PermissiveBool;

    // Each helper feeds a value through serde's primitive deserializers, so
    // the matching `visit_*` method of the visitor is exercised.

    fn parse_bool(v: bool) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    fn parse_str(s: &str) -> Result<bool, serde::de::value::Error> {
        let de: StrDeserializer<serde::de::value::Error> = s.into_deserializer();
        PermissiveBool::deserialize_as(de)
    }

    fn parse_int(v: i64) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    fn parse_uint(v: u64) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    fn parse_float(v: f64) -> Result<bool, serde::de::value::Error> {
        PermissiveBool::deserialize_as(v.into_deserializer())
    }

    #[test]
    fn native_true() {
        assert!(parse_bool(true).unwrap());
    }

    #[test]
    fn native_false() {
        assert!(!parse_bool(false).unwrap());
    }

    #[test]
    fn str_truthy() {
        for s in &["1", "t", "T", "true", "True", "tRuE"] {
            assert!(parse_str(s).unwrap(), "expected {s:?} to be truthy");
        }
    }

    #[test]
    fn str_falsy() {
        for s in &["0", "f", "F", "false", "False", "fAlSe"] {
            assert!(!parse_str(s).unwrap(), "expected {s:?} to be falsy");
        }
    }

    #[test]
    fn str_invalid_rejected() {
        assert!(parse_str("yes").is_err());
        assert!(parse_str("no").is_err());
        assert!(parse_str("2").is_err());
        assert!(parse_str("").is_err());
        // Surrounding whitespace is not trimmed.
        assert!(parse_str(" true").is_err());
        assert!(parse_str("false ").is_err());
    }

    #[test]
    fn int_true() {
        assert!(parse_int(1).unwrap());
    }

    #[test]
    fn int_false() {
        assert!(!parse_int(0).unwrap());
    }

    #[test]
    fn int_out_of_range_rejected() {
        for v in [-1, 2, i64::MIN, i64::MAX] {
            assert!(parse_int(v).is_err(), "expected {v} to be rejected");
        }
    }

    #[test]
    fn uint_values() {
        assert!(parse_uint(1).unwrap());
        assert!(!parse_uint(0).unwrap());
        for v in [2, u64::MAX] {
            assert!(parse_uint(v).is_err(), "expected {v} to be rejected");
        }
    }

    #[test]
    fn float_values() {
        assert!(parse_float(1.0).unwrap());
        assert!(!parse_float(0.0).unwrap());
        // -0.0 compares equal to 0.0 and is therefore falsy.
        assert!(!parse_float(-0.0).unwrap());
        for v in [0.5, -1.0, 2.0, f64::NAN, f64::INFINITY] {
            assert!(parse_float(v).is_err(), "expected {v} to be rejected");
        }
    }
}