Skip to main content

libddwaf/
serde.rs

1//! Implementations of [`serde::Deserialize`] for [`object::WafObject`](crate::object::WafObject) and
2//! [`object::WafMap`](crate::object::WafMap).
3//!
4//! This module also provides [`Limits`] for applying constraints during deserialization,
5//! similar to the PHP extension's `dd_mpack_limits` structure.
6
7use std::cell::Cell;
8
9use serde::{
10    de::Error,
11    ser::{SerializeMap, SerializeSeq},
12    Deserializer,
13};
14
15use crate::object::{
16    Keyed, WafArray, WafBool, WafFloat, WafMap, WafNull, WafObject, WafObjectType, WafSigned,
17    WafString, WafUnsigned,
18};
19
20impl<'de> serde::Deserialize<'de> for WafObject {
21    fn deserialize<D>(deserializer: D) -> Result<WafObject, D::Error>
22    where
23        D: Deserializer<'de>,
24    {
25        deserializer.deserialize_any(Visitor)
26    }
27}
28
29impl<'de> serde::Deserialize<'de> for WafMap {
30    fn deserialize<D>(deserializer: D) -> Result<WafMap, D::Error>
31    where
32        D: Deserializer<'de>,
33    {
34        let dobj = deserializer.deserialize_any(Visitor)?;
35        dobj.try_into()
36            .map_err(|_| serde::de::Error::custom("invalid type: not a map"))
37    }
38}
39
40struct Visitor;
41
42impl<'de> serde::de::Visitor<'de> for Visitor {
43    type Value = WafObject;
44
45    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
46        formatter.write_str(
47            "a valid WafObject (unsigned, signed, string, array, map, bool, float, or null)",
48        )
49    }
50
51    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
52    where
53        E: Error,
54    {
55        Ok(WafObject::from(v))
56    }
57
58    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
59    where
60        E: Error,
61    {
62        Ok(WafObject::from(v))
63    }
64
65    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
66    where
67        E: Error,
68    {
69        Ok(WafObject::from(v))
70    }
71
72    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
73    where
74        E: Error,
75    {
76        Ok(WafObject::from(v))
77    }
78
79    fn visit_unit<E>(self) -> Result<Self::Value, E>
80    where
81        E: Error,
82    {
83        Ok(WafObject::from(()))
84    }
85
86    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
87    where
88        E: Error,
89    {
90        Ok(WafObject::from(v))
91    }
92
93    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
94    where
95        E: Error,
96    {
97        Ok(WafObject::from(WafString::from(v)))
98    }
99
100    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
101    where
102        A: serde::de::SeqAccess<'de>,
103    {
104        let mut vec = seq.size_hint().map(Vec::with_capacity).unwrap_or_default();
105        while let Some(value) = seq.next_element()? {
106            vec.push(value);
107        }
108        let mut res = WafArray::new(vec.len().try_into().unwrap());
109        for (i, v) in vec.into_iter().enumerate() {
110            res[i] = v;
111        }
112        Ok(res.into())
113    }
114
115    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
116    where
117        A: serde::de::MapAccess<'de>,
118    {
119        let mut vec: Vec<(WafObject, WafObject)> =
120            map.size_hint().map(Vec::with_capacity).unwrap_or_default();
121        while let Some((key, value)) = map.next_entry::<WafObject, WafObject>()? {
122            vec.push((key, value));
123        }
124        let mut res = WafMap::new(vec.len().try_into().map_err(A::Error::custom)?);
125        for (i, (k, v)) in vec.into_iter().enumerate() {
126            res[i] = Keyed::new(k, v);
127        }
128        Ok(res.into())
129    }
130}
131
132impl serde::Serialize for WafObject {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: serde::Serializer,
136    {
137        match self.object_type() {
138            WafObjectType::Unsigned => {
139                unsafe { self.as_type_unchecked::<WafUnsigned>() }.serialize(serializer)
140            }
141            WafObjectType::Signed => {
142                unsafe { self.as_type_unchecked::<WafSigned>() }.serialize(serializer)
143            }
144            WafObjectType::Bool => {
145                unsafe { self.as_type_unchecked::<WafBool>() }.serialize(serializer)
146            }
147            WafObjectType::Float => {
148                unsafe { self.as_type_unchecked::<WafFloat>() }.serialize(serializer)
149            }
150            WafObjectType::String => {
151                unsafe { self.as_type_unchecked::<WafString>() }.serialize(serializer)
152            }
153            WafObjectType::Array => {
154                unsafe { self.as_type_unchecked::<WafArray>() }.serialize(serializer)
155            }
156            WafObjectType::Map => {
157                unsafe { self.as_type_unchecked::<WafMap>() }.serialize(serializer)
158            }
159            WafObjectType::Null | WafObjectType::Invalid => serializer.serialize_unit(),
160        }
161    }
162}
163
164impl serde::Serialize for WafUnsigned {
165    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
166    where
167        S: serde::Serializer,
168    {
169        serializer.serialize_u64(self.value())
170    }
171}
172
173impl serde::Serialize for WafSigned {
174    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
175    where
176        S: serde::Serializer,
177    {
178        serializer.serialize_i64(self.value())
179    }
180}
181
182impl serde::Serialize for WafBool {
183    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
184    where
185        S: serde::Serializer,
186    {
187        serializer.serialize_bool(self.value())
188    }
189}
190
191impl serde::Serialize for WafFloat {
192    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193    where
194        S: serde::Serializer,
195    {
196        serializer.serialize_f64(self.value())
197    }
198}
199
200impl serde::Serialize for WafString {
201    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
202    where
203        S: serde::Serializer,
204    {
205        serializer.serialize_str(&String::from_utf8_lossy(self.as_bytes()))
206    }
207}
208
209impl serde::Serialize for WafArray {
210    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
211    where
212        S: serde::Serializer,
213    {
214        let mut seq_serializer = serializer.serialize_seq(Some(self.len() as usize))?;
215        for value in self.iter() {
216            seq_serializer.serialize_element(value)?;
217        }
218        seq_serializer.end()
219    }
220}
221
222impl serde::Serialize for WafMap {
223    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
224    where
225        S: serde::Serializer,
226    {
227        let mut map_serializer = serializer.serialize_map(Some(self.len() as usize))?;
228        for keyed_val in self.iter() {
229            // Key is serialized as WafObject; formats requiring string keys (e.g. JSON)
230            // will error if the key is not a WafString
231            map_serializer.serialize_entry(keyed_val.key(), keyed_val.value())?;
232        }
233        map_serializer.end()
234    }
235}
236
/// Default maximum string length (4096 bytes).
pub const DEFAULT_MAX_STRING_LENGTH: u32 = 4096;

/// Default maximum depth (21 levels).
pub const DEFAULT_MAX_DEPTH: usize = 21;

/// Default maximum elements (2048 elements).
pub const DEFAULT_MAX_ELEMENTS: usize = 2048;

/// Limits applied during deserialization to prevent excessive resource usage.
///
/// This is modeled after the PHP extension's `dd_mpack_limits` structure:
/// - `max_string_length`: Strings longer than this are truncated
/// - `max_depth`: Maximum nesting depth; deeper structures become null
/// - `max_elements`: Maximum total elements; excess elements are skipped
///
/// # Example
/// ```
/// use libddwaf::serde::Limits;
///
/// let limits = Limits::default();
/// assert_eq!(limits.max_string_length, 4096);
/// assert_eq!(limits.max_depth, 21);
/// assert_eq!(limits.max_elements, 2048);
///
/// // Create custom limits
/// let custom = Limits {
///     max_string_length: 1024,
///     max_depth: 10,
///     max_elements: 100,
/// };
/// assert_ne!(custom, limits);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Limits {
    /// Maximum length, in bytes, of a string or byte value; longer values are cut.
    pub max_string_length: u32,
    /// Maximum container nesting depth; arrays and maps beyond it are replaced by null.
    pub max_depth: usize,
    /// Maximum number of values (scalars and containers alike) kept in total.
    pub max_elements: usize,
}

impl Default for Limits {
    /// Returns limits built from [`DEFAULT_MAX_STRING_LENGTH`],
    /// [`DEFAULT_MAX_DEPTH`] and [`DEFAULT_MAX_ELEMENTS`].
    fn default() -> Self {
        Self {
            max_string_length: DEFAULT_MAX_STRING_LENGTH,
            max_depth: DEFAULT_MAX_DEPTH,
            max_elements: DEFAULT_MAX_ELEMENTS,
        }
    }
}
285
/// The result of deserializing with limits.
///
/// Pairs the value that was built with a flag telling whether any limit
/// caused data to be cut or dropped along the way.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LimitedResult<T> {
    /// The deserialized value.
    pub value: T,
    /// Whether any limits were reached during deserialization.
    pub truncated: bool,
}
294
295/// Deserialize a [`WafObject`] from a deserializer with the specified limits.
296///
297/// Returns a [`LimitedResult`] containing the deserialized value and whether
298/// truncation occurred.
299///
300/// # Example
301/// ```
302/// use libddwaf::serde::{Limits, deserialize_with_limits};
303/// use libddwaf::object::WafObject;
304///
305/// let json = r#"{"key": "value"}"#;
306/// let limits = Limits {
307///     max_string_length: 3,
308///     max_depth: 10,
309///     max_elements: 100,
310/// };
311/// let mut deserializer = serde_json::Deserializer::from_str(json);
312/// let result = deserialize_with_limits(&mut deserializer, &limits).unwrap();
313///
314/// assert!(result.truncated); // "value" was truncated to "val"
315/// ```
316/// # Errors
317/// Returns an error if the deserializer returns an error.
318pub fn deserialize_with_limits<'de, D>(
319    deserializer: D,
320    limits: &Limits,
321) -> Result<LimitedResult<WafObject>, D::Error>
322where
323    D: Deserializer<'de>,
324{
325    let state = LimitedState::new(limits);
326    let visitor = LimitedVisitor { state: &state };
327    let value = deserializer.deserialize_any(visitor)?;
328    Ok(LimitedResult {
329        value,
330        truncated: state.truncated.get(),
331    })
332}
333
334struct LimitedState<'a> {
335    limits: &'a Limits,
336    depth_remaining: Cell<usize>,
337    elements_remaining: Cell<usize>,
338    truncated: Cell<bool>,
339}
340
341impl<'a> LimitedState<'a> {
342    fn new(limits: &'a Limits) -> Self {
343        Self {
344            limits,
345            depth_remaining: Cell::new(limits.max_depth),
346            elements_remaining: Cell::new(limits.max_elements),
347            truncated: Cell::new(false),
348        }
349    }
350
351    /// Consumes one element from the remaining count.
352    /// Returns true if the element can be processed, false if limit reached.
353    fn consume_element(&self) -> bool {
354        let remaining = self.elements_remaining.get();
355        if remaining == 0 {
356            self.truncated.set(true);
357            false
358        } else {
359            self.elements_remaining.set(remaining - 1);
360            true
361        }
362    }
363
364    fn can_descend(&self) -> bool {
365        self.depth_remaining.get() > 0
366    }
367
368    fn enter_depth(&self) {
369        let depth = self.depth_remaining.get();
370        if depth > 0 {
371            self.depth_remaining.set(depth - 1);
372        }
373    }
374
375    fn exit_depth(&self) {
376        self.depth_remaining.set(self.depth_remaining.get() + 1);
377    }
378
379    fn truncate_string<'b>(&self, s: &'b str) -> &'b str {
380        if s.len() > self.limits.max_string_length as usize {
381            self.truncated.set(true);
382            // Find a valid UTF-8 boundary
383            let mut end = self.limits.max_string_length as usize;
384            while end > 0 && !s.is_char_boundary(end) {
385                end -= 1;
386            }
387            &s[..end]
388        } else {
389            s
390        }
391    }
392
393    fn truncate_bytes<'b>(&self, b: &'b [u8]) -> &'b [u8] {
394        if b.len() > self.limits.max_string_length as usize {
395            self.truncated.set(true);
396            &b[..self.limits.max_string_length as usize]
397        } else {
398            b
399        }
400    }
401}
402
403/// A visitor that applies limits during deserialization.
404struct LimitedVisitor<'a> {
405    state: &'a LimitedState<'a>,
406}
407
408impl<'de> serde::de::Visitor<'de> for LimitedVisitor<'_> {
409    type Value = WafObject;
410
411    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
412        formatter.write_str(
413            "a valid WafObject (unsigned, signed, string, array, map, bool, float, or null)",
414        )
415    }
416
417    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
418    where
419        E: Error,
420    {
421        if !self.state.consume_element() {
422            return Ok(WafNull::new().into());
423        }
424        Ok(WafObject::from(v))
425    }
426
427    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
428    where
429        E: Error,
430    {
431        if !self.state.consume_element() {
432            return Ok(WafNull::new().into());
433        }
434        Ok(WafObject::from(v))
435    }
436
437    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
438    where
439        E: Error,
440    {
441        if !self.state.consume_element() {
442            return Ok(WafNull::new().into());
443        }
444        Ok(WafObject::from(v))
445    }
446
447    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
448    where
449        E: Error,
450    {
451        if !self.state.consume_element() {
452            return Ok(WafNull::new().into());
453        }
454        Ok(WafObject::from(v))
455    }
456
457    fn visit_unit<E>(self) -> Result<Self::Value, E>
458    where
459        E: Error,
460    {
461        if !self.state.consume_element() {
462            return Ok(WafNull::new().into());
463        }
464        Ok(WafObject::from(()))
465    }
466
467    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
468    where
469        E: Error,
470    {
471        if !self.state.consume_element() {
472            return Ok(WafNull::new().into());
473        }
474        let truncated = self.state.truncate_string(v);
475        Ok(WafObject::from(truncated))
476    }
477
478    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
479    where
480        E: Error,
481    {
482        if !self.state.consume_element() {
483            return Ok(WafNull::new().into());
484        }
485        let truncated = self.state.truncate_bytes(v);
486        Ok(WafObject::from(WafString::from(truncated)))
487    }
488
489    #[allow(clippy::cast_possible_truncation)]
490    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
491    where
492        A: serde::de::SeqAccess<'de>,
493    {
494        if !self.state.can_descend() {
495            self.state.truncated.set(true);
496            // Drain the sequence
497            while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {}
498            return Ok(WafNull::new().into());
499        }
500
501        // Consume element for the array itself
502        if !self.state.consume_element() {
503            // Drain the sequence
504            while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {}
505            return Ok(WafNull::new().into());
506        }
507
508        self.state.enter_depth();
509
510        let mut vec = seq.size_hint().map(Vec::with_capacity).unwrap_or_default();
511        while self.state.elements_remaining.get() > 0 {
512            match seq.next_element_seed(LimitedSeed { state: self.state })? {
513                Some(value) => vec.push(value),
514                None => break,
515            }
516        }
517
518        // If there are remaining elements, drain them and mark as truncated
519        if seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
520            self.state.truncated.set(true);
521            while seq.next_element::<serde::de::IgnoredAny>()?.is_some() {}
522        }
523
524        self.state.exit_depth();
525
526        let len = vec.len().min(u16::MAX as usize);
527        let mut res = WafArray::new(len as u16);
528        for (i, v) in vec.into_iter().take(len).enumerate() {
529            res[i] = v;
530        }
531        Ok(res.into())
532    }
533
534    #[allow(clippy::cast_possible_truncation)]
535    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
536    where
537        A: serde::de::MapAccess<'de>,
538    {
539        // Check if we can descend
540        if !self.state.can_descend() {
541            self.state.truncated.set(true);
542            // Drain the map
543            while map
544                .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
545                .is_some()
546            {}
547            return Ok(WafNull::new().into());
548        }
549
550        // Consume element for the map itself
551        if !self.state.consume_element() {
552            // Drain the map
553            while map
554                .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
555                .is_some()
556            {}
557            return Ok(WafNull::new().into());
558        }
559
560        self.state.enter_depth();
561
562        let mut vec: Vec<Keyed<WafObject>> =
563            map.size_hint().map(Vec::with_capacity).unwrap_or_default();
564
565        while self.state.elements_remaining.get() > 0 {
566            match map.next_entry_seed(
567                LimitedSeed { state: self.state },
568                LimitedSeed { state: self.state },
569            )? {
570                Some(pair @ (_, _)) => vec.push(pair.into()),
571                None => break,
572            }
573        }
574
575        // If there are remaining entries, drain them and mark as truncated
576        if map
577            .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
578            .is_some()
579        {
580            self.state.truncated.set(true);
581            while map
582                .next_entry::<serde::de::IgnoredAny, serde::de::IgnoredAny>()?
583                .is_some()
584            {}
585        }
586
587        self.state.exit_depth();
588
589        let len = vec.len().min(u16::MAX as usize);
590        let mut res = WafMap::new(len as u16);
591        for (i, keyed) in vec.into_iter().take(len).enumerate() {
592            res[i] = keyed;
593        }
594        Ok(res.into())
595    }
596}
597
598/// A `DeserializeSeed` that uses the limited visitor.
599struct LimitedSeed<'a> {
600    state: &'a LimitedState<'a>,
601}
602
603impl<'de, 'a> serde::de::DeserializeSeed<'de> for LimitedSeed<'a>
604where
605    'de: 'a,
606{
607    type Value = WafObject;
608
609    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
610    where
611        D: Deserializer<'de>,
612    {
613        let visitor = LimitedVisitor { state: self.state };
614        deserializer.deserialize_any(visitor)
615    }
616}