Skip to main content

saluki_config/
duration_string.rs

//! A duration configuration value compatible with the Agent.
//!
//! The Agent loads configuration via [spf13/viper][viper], which uses [spf13/cast][cast] to coerce YAML/JSON/env
//! values into Go's [`time.Duration`][go-duration]. [`DurationString`] reproduces that coercion so ADP accepts the
//! same inputs and interprets them the same way.
//!
//! [viper]: https://github.com/spf13/viper
//! [cast]: https://github.com/spf13/cast
//! [go-duration]: https://pkg.go.dev/time#ParseDuration
10
11use std::fmt;
12use std::fmt::{Debug, Display, Formatter};
13use std::str::FromStr;
14use std::time::Duration;
15
16use serde::de::{self, Deserializer, Visitor};
17use serde::{Deserialize, Serialize, Serializer};
18use snafu::Snafu;
19
/// Newtype over [`Duration`] that reads and writes durations the way the Agent's configuration loader does.
///
/// # Deserialization
///
/// The following inputs are understood:
///
/// - Go-style duration strings built from one or more `<number><unit>` terms: `"30s"`, `"1h30m"`, `"250ms"`,
///   `"2h45m30s"`, `"1.5h"`. Valid unit suffixes: `ns`, `us`, `µs`, `ms`, `s`, `m`, `h`.
///
/// - Strings holding nothing but an integer: `"5"` is 5 **nanoseconds**, mirroring the fallback that
///   `cast.ToDurationE` (the coercion viper relies on) applies to unit-less string values.
///
/// - Integers: `5` is 5 **nanoseconds**.
///
/// - Floats: `5.0` is 5 **nanoseconds**; any fractional part is truncated toward zero.
///
/// Negative values such as `"-1h"` are refused, because [`std::time::Duration`] has no way to hold them.
///
/// # Bare numbers are nanoseconds, not seconds (!!)
///
/// Writing `expected_tags_duration: 30` configures 30 **nanoseconds**. For 30 seconds, write `"30s"`. This is the
/// same outcome the Agent's `time.Duration` coercion produces.
///
/// # Serialization
///
/// The serialized form is `"{seconds}s{nanoseconds}ns"`: 30 seconds is written as `"30s0ns"` and 30.5 seconds as
/// `"30s500000000ns"`. The seconds component carries every whole second, which leaves a nanosecond component
/// strictly below `1_000_000_000`, so the text is unambiguous, parses back to the same value here, and is also
/// valid input for the Agent.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct DurationString(Duration);
50
51impl Debug for DurationString {
52    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
53        Debug::fmt(&self.0, f)
54    }
55}
56
57impl DurationString {
58    /// Creates a new `DurationString` wrapping the given [`Duration`].
59    pub const fn new(d: Duration) -> Self {
60        Self(d)
61    }
62
63    /// Returns the underlying [`Duration`].
64    pub const fn as_duration(&self) -> Duration {
65        self.0
66    }
67}
68
69impl From<Duration> for DurationString {
70    fn from(d: Duration) -> Self {
71        Self(d)
72    }
73}
74
75impl From<DurationString> for Duration {
76    fn from(d: DurationString) -> Self {
77        d.0
78    }
79}
80
81impl std::ops::Deref for DurationString {
82    type Target = Duration;
83
84    fn deref(&self) -> &Duration {
85        &self.0
86    }
87}
88
89impl FromStr for DurationString {
90    type Err = ParseDurationError;
91
92    fn from_str(s: &str) -> Result<Self, Self::Err> {
93        parse_string(s).map(Self)
94    }
95}
96
97impl Display for DurationString {
98    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
99        write!(f, "{}s{}ns", self.0.as_secs(), self.0.subsec_nanos())
100    }
101}
102
103impl Serialize for DurationString {
104    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
105        serializer.collect_str(self)
106    }
107}
108
109impl<'de> Deserialize<'de> for DurationString {
110    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
111        deserializer.deserialize_any(DurationStringVisitor)
112    }
113}
114
115struct DurationStringVisitor;
116
117impl<'de> Visitor<'de> for DurationStringVisitor {
118    type Value = DurationString;
119
120    fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121        f.write_str("a duration string (e.g. \"30s\", \"1h30m\") or a non-negative number of nanoseconds")
122    }
123
124    fn visit_i64<E: de::Error>(self, n: i64) -> Result<DurationString, E> {
125        if n < 0 {
126            return Err(E::custom(ParseDurationError::Negative));
127        }
128        Ok(DurationString(Duration::from_nanos(n as u64)))
129    }
130
131    fn visit_i128<E: de::Error>(self, n: i128) -> Result<DurationString, E> {
132        if n < 0 {
133            return Err(E::custom(ParseDurationError::Negative));
134        }
135        if n > MAX_NANOS_U64 as i128 {
136            return Err(E::custom(ParseDurationError::Overflow));
137        }
138        Ok(DurationString(Duration::from_nanos(n as u64)))
139    }
140
141    fn visit_u64<E: de::Error>(self, n: u64) -> Result<DurationString, E> {
142        if n > MAX_NANOS_U64 {
143            return Err(E::custom(ParseDurationError::Overflow));
144        }
145        Ok(DurationString(Duration::from_nanos(n)))
146    }
147
148    fn visit_u128<E: de::Error>(self, n: u128) -> Result<DurationString, E> {
149        if n > MAX_NANOS_U64 as u128 {
150            return Err(E::custom(ParseDurationError::Overflow));
151        }
152        Ok(DurationString(Duration::from_nanos(n as u64)))
153    }
154
155    fn visit_f64<E: de::Error>(self, f: f64) -> Result<DurationString, E> {
156        if !f.is_finite() {
157            return Err(E::custom("duration nanoseconds must be finite"));
158        }
159        if f < 0.0 {
160            return Err(E::custom(ParseDurationError::Negative));
161        }
162        if f > MAX_NANOS_U64 as f64 {
163            return Err(E::custom(ParseDurationError::Overflow));
164        }
165        Ok(DurationString(Duration::from_nanos(f as u64)))
166    }
167
168    fn visit_str<E: de::Error>(self, s: &str) -> Result<DurationString, E> {
169        parse_string(s).map(DurationString).map_err(E::custom)
170    }
171
172    fn visit_string<E: de::Error>(self, s: String) -> Result<DurationString, E> {
173        self.visit_str(&s)
174    }
175}
176
/// Upper bound, in nanoseconds, on any duration this module accepts. Go's `time.Duration` is an `int64` count of
/// nanoseconds, so the Agent's largest representable value is `i64::MAX` nanoseconds and we cap at the same point.
const MAX_NANOS_U64: u64 = i64::MAX.unsigned_abs();
180
181/// Error returned when a duration value cannot be parsed.
182#[derive(Debug, Snafu)]
183pub enum ParseDurationError {
184    /// The value was syntactically invalid.
185    #[snafu(display("invalid duration '{}': {}", input, reason))]
186    Invalid {
187        /// The original input string.
188        input: String,
189        /// Reason the input was rejected.
190        reason: String,
191    },
192    /// The value parsed to a negative duration.
193    #[snafu(display("negative durations are not supported"))]
194    Negative,
195    /// The value exceeds the range of [`std::time::Duration`] as nanoseconds.
196    #[snafu(display("duration value exceeds supported range"))]
197    Overflow,
198}
199
200fn invalid(input: &str, reason: impl Into<String>) -> ParseDurationError {
201    ParseDurationError::Invalid {
202        input: input.to_string(),
203        reason: reason.into(),
204    }
205}
206
207/// Parses a string using viper/cast precedence: try matching Go's `time.ParseDuration` first (with our
208/// `parse_duration`, then fall back to a bare integer (treated as nanoseconds).
209fn parse_string(s: &str) -> Result<Duration, ParseDurationError> {
210    let trimmed = s.trim();
211    match parse_duration(trimmed) {
212        Ok(d) => Ok(d),
213        Err(err) => match trimmed.parse::<i128>() {
214            Ok(n) if n < 0 => Err(ParseDurationError::Negative),
215            Ok(n) => {
216                if n > MAX_NANOS_U64 as i128 {
217                    return Err(ParseDurationError::Overflow);
218                }
219                Ok(Duration::from_nanos(n as u64))
220            }
221            Err(_) => Err(err),
222        },
223    }
224}
225
226/// Parses a string in the exact format accepted by Go's `time.ParseDuration`, restricted to non-negative values
227/// (since [`std::time::Duration`] cannot represent negatives).
228fn parse_duration(s: &str) -> Result<Duration, ParseDurationError> {
229    let orig = s;
230    let mut rest = s;
231    let mut total_ns: u128 = 0;
232    let mut negative = false;
233
234    if let Some(c) = rest.chars().next() {
235        if c == '+' || c == '-' {
236            negative = c == '-';
237            rest = &rest[1..];
238        }
239    }
240
241    // Special case: "0" alone (possibly after a sign) is zero.
242    if rest == "0" {
243        return Ok(Duration::ZERO);
244    }
245    if rest.is_empty() {
246        return Err(invalid(orig, "empty duration"));
247    }
248
249    while !rest.is_empty() {
250        let (int_part, after_int) = consume_digits(rest);
251        let had_int = !int_part.is_empty();
252
253        let (frac_part, after_frac) = if let Some(stripped) = after_int.strip_prefix('.') {
254            consume_digits(stripped)
255        } else {
256            ("", after_int)
257        };
258        let consumed_dot = after_int.starts_with('.');
259        let had_frac = consumed_dot && !frac_part.is_empty();
260
261        if !had_int && !had_frac {
262            return Err(invalid(orig, "expected digits"));
263        }
264
265        rest = after_frac;
266
267        let unit_str = consume_unit(rest);
268        if unit_str.is_empty() {
269            return Err(invalid(orig, "missing unit"));
270        }
271        rest = &rest[unit_str.len()..];
272
273        let unit_ns: u128 = match unit_str {
274            "ns" => 1,
275            "us" | "µs" => 1_000,
276            "ms" => 1_000_000,
277            "s" => 1_000_000_000,
278            "m" => 60 * 1_000_000_000,
279            "h" => 3_600 * 1_000_000_000,
280            other => return Err(invalid(orig, format!("unknown unit '{}'", other))),
281        };
282
283        let int_val: u128 = if int_part.is_empty() {
284            0
285        } else {
286            int_part
287                .parse::<u128>()
288                .map_err(|_| invalid(orig, "integer overflow"))?
289        };
290
291        let mut ns = int_val.checked_mul(unit_ns).ok_or_else(|| invalid(orig, "overflow"))?;
292
293        if !frac_part.is_empty() {
294            // Truncate the fraction to at most 18 digits to keep the intermediate u128 math well within range. 18
295            // decimal digits of precision is well beyond nanoseconds for every supported unit.
296            let keep = frac_part.len().min(18);
297            let frac_digits = &frac_part[..keep];
298            let mut scale: u128 = 1;
299            for _ in 0..keep {
300                scale *= 10;
301            }
302            let f: u128 = frac_digits
303                .parse::<u128>()
304                .map_err(|_| invalid(orig, "invalid fractional"))?;
305            let frac_ns = f.checked_mul(unit_ns).ok_or_else(|| invalid(orig, "overflow"))? / scale;
306            ns = ns.checked_add(frac_ns).ok_or_else(|| invalid(orig, "overflow"))?;
307        }
308
309        total_ns = total_ns.checked_add(ns).ok_or_else(|| invalid(orig, "overflow"))?;
310    }
311
312    if negative && total_ns != 0 {
313        return Err(ParseDurationError::Negative);
314    }
315
316    if total_ns > MAX_NANOS_U64 as u128 {
317        return Err(ParseDurationError::Overflow);
318    }
319    Ok(Duration::from_nanos(total_ns as u64))
320}
321
322fn consume_digits(s: &str) -> (&str, &str) {
323    let end = s.bytes().take_while(|b| b.is_ascii_digit()).count();
324    s.split_at(end)
325}
326
327fn consume_unit(s: &str) -> &str {
328    let mut end = 0;
329    for (i, c) in s.char_indices() {
330        if c.is_ascii_alphabetic() || c == 'µ' {
331            end = i + c.len_utf8();
332        } else {
333            break;
334        }
335    }
336    &s[..end]
337}
338
#[cfg(test)]
mod tests {
    use anyhow::Context as _;
    use serde_json::json;

    use super::*;

    // Unit constants so expected values can be written as arithmetic (`2 * H + 45 * M`).
    const NS: Duration = Duration::from_nanos(1);
    const _US: Duration = Duration::from_micros(1);
    const MS: Duration = Duration::from_millis(1);
    const S: Duration = Duration::from_secs(1);
    const M: Duration = Duration::from_secs(60);
    const H: Duration = Duration::from_secs(3600);

    #[test]
    fn deserialize_integer_succeeds() {
        let json = r#"{ "value": 15 }"#;
        let deserialized: SerdeTest = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.value.as_duration(), 15 * NS);
    }

    /// Interesting test case because the 1.5 is interpreted as nanoseconds then truncated to 1ns in the Duration
    #[test]
    fn deserialize_float_succeeds() {
        let json = r#"{ "value": 1.5 }"#;
        let deserialized: SerdeTest = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.value.as_duration(), 1 * NS);
    }

    /// Minimal wrapper struct so the value is exercised as a field, the way it appears in real configuration.
    #[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
    struct SerdeTest {
        value: DurationString,
    }

    impl From<Duration> for SerdeTest {
        fn from(value: Duration) -> Self {
            Self { value: value.into() }
        }
    }

    /// Builds a JSON document whose `value` field is the input as a JSON string.
    fn test_json(input_value: &str) -> String {
        json!({"value": input_value}).to_string()
    }

    /// Builds a YAML document whose `value` field is the raw, unquoted input.
    fn test_yaml(input_value: &str) -> String {
        format!("value: {input_value}")
    }

    /// Checks one accepted input end to end: `FromStr`, JSON and YAML deserialization, JSON and YAML round trips,
    /// and the `Display` output.
    fn run_success_case(input: &str, expected: Duration, serialized: &str) -> anyhow::Result<()> {
        let expected_struct: SerdeTest = expected.into();
        let json = test_json(input);
        let yaml = test_yaml(input);
        let msg = format!("failure for duration test case '{input}'");
        let parsed_duration = DurationString::from_str(input).context(msg.clone())?;
        anyhow::ensure!(
            expected == parsed_duration.as_duration(),
            "{msg}, expected: {expected:?}, got {:?}",
            parsed_duration.as_duration()
        );
        let deserialized_from_json: SerdeTest = serde_json::from_str(&json).context(msg.clone())?;
        anyhow::ensure!(
            expected_struct == deserialized_from_json,
            "{msg}, expected: {expected_struct:?}, got {deserialized_from_json:?}"
        );
        let roundtrip_json = serde_json::from_str(&serde_json::to_string(&expected_struct)?)?;
        anyhow::ensure!(
            expected_struct == roundtrip_json,
            "{msg}, expected json roundtrip to produce {expected_struct:?}, but got {roundtrip_json:?}"
        );
        let deserialized_from_yaml: SerdeTest = serde_yaml::from_str(&yaml).context(msg.clone())?;
        anyhow::ensure!(
            expected_struct == deserialized_from_yaml,
            "{msg}, expected: {expected_struct:?}, got {deserialized_from_yaml:?}"
        );
        let roundtrip_yaml = serde_yaml::from_str(&serde_yaml::to_string(&expected_struct)?)?;
        anyhow::ensure!(
            expected_struct == roundtrip_yaml,
            "{msg}, expected yaml roundtrip to produce {expected_struct:?}, but got {roundtrip_yaml:?}"
        );
        let actual_serialized = parsed_duration.to_string();
        anyhow::ensure!(
            serialized == actual_serialized,
            "Expected the input '{input}' to be serialized as '{serialized}' but got '{actual_serialized}'"
        );
        Ok(())
    }

    #[test]
    fn duration_string_success_cases() {
        // (input, expected duration, expected `Display` output)
        let cases: &[(&str, Duration, &str)] = &[
            ("0", Duration::ZERO, "0s0ns"),
            ("-0", Duration::ZERO, "0s0ns"),
            ("+0", Duration::ZERO, "0s0ns"),
            ("+5h", Duration::from_hours(5), "18000s0ns"),
            (".5s", Duration::from_millis(500), "0s500000000ns"),
            ("5.s", Duration::from_secs(5), "5s0ns"),
            ("0.000000001s", Duration::from_nanos(1), "0s1ns"),
            ("1.5h", Duration::from_mins(90), "5400s0ns"),
            (
                "2h45m30.5s",
                (2 * H) + (45 * M) + (30 * S) + (500 * MS),
                "9930s500000000ns",
            ),
            ("12µs", Duration::from_micros(12), "0s12000ns"),
            ("0s", Duration::ZERO, "0s0ns"),
            ("1h1m1s1ms1us1ns", H + M + S + MS + (1000 * NS) + NS, "3661s1001001ns"),
            ("24h", Duration::from_hours(24), "86400s0ns"),
            (
                "9223372036854775807ns",
                Duration::from_nanos(9223372036854775807),
                "9223372036s854775807ns",
            ),
            (
                "9223372036854775.807us",
                Duration::from_secs(9223372036) + (854775807 * NS),
                "9223372036s854775807ns",
            ),
            (
                "2562047h47m16.854775807s",
                Duration::from_secs(9223372036) + (854775807 * NS),
                "9223372036s854775807ns",
            ),
            ("0.1ns", Duration::ZERO, "0s0ns"),
            ("05s", Duration::from_secs(5), "5s0ns"),
            ("1ns1s", S + NS, "1s1ns"),
            ("100h100m100s", (100 * H) + (100 * M) + (100 * S), "366100s0ns"),
            ("5m32s", (5 * M) + (32 * S), "332s0ns"),
            ("1m0s", M, "60s0ns"),
            ("5m0s", 5 * M, "300s0ns"),
            ("6m0s", 6 * M, "360s0ns"),
            ("10m0s", 10 * M, "600s0ns"),
            ("15m0s", 15 * M, "900s0ns"),
            ("30m0s", 30 * M, "1800s0ns"),
            ("40m0s", 40 * M, "2400s0ns"),
            ("50m0s", 50 * M, "3000s0ns"),
            ("87600h0m0s", 87600 * H, "315360000s0ns"),
            ("5", 5 * NS, "0s5ns"),
            (" 5s", 5 * S, "5s0ns"),
            ("5s", 5 * S, "5s0ns"),
        ];

        for (input, expected, serialized) in cases {
            run_success_case(input, *expected, serialized).unwrap();
        }
    }

    /// Checks that `input` is rejected by `FromStr` and that the error text contains `expected_msg`.
    fn run_failure_case(input: &str, expected_msg: &str) -> anyhow::Result<()> {
        let result = DurationString::from_str(input);
        match result {
            Ok(value) => {
                anyhow::bail!("Expected an error when parsing '{input}', but instead received the value '{value:?}'")
            }
            Err(e) => {
                anyhow::ensure!(
                    e.to_string().contains(expected_msg),
                    "Expected the error message when parsing '{input}' to contain {expected_msg:?}, but the message is {e}"
                );
            }
        }

        Ok(())
    }

    #[test]
    fn duration_string_failure_cases() {
        // (input, substring the error message must contain)
        let cases: &[(&str, &str)] = &[
            ("5m32sFOO", "unknown unit 'sFOO'"),
            ("", "empty duration"),
            (" ", "empty duration"),
            ("+", "empty duration"),
            ("-", "empty duration"),
            (".", "expected digits"),
            ("s", "expected digits"),
            (".s", "expected digits"),
            ("--5s", "expected digits"),
            ("5.5.5s", "missing unit"),
            ("1e3s", "unknown unit 'e'"),
            ("5ns5", "missing unit"),
            ("9223372036854775808ns", "exceeds"),
            ("-1s", "negative"),
            ("-0.5h", "negative"),
            ("1d", "unknown unit 'd'"),
            ("1w", "unknown unit 'w'"),
            ("1S", "unknown unit 'S'"),
            ("12 µs", "missing unit"),
            ("5 s", "missing unit"),
            ("5. s", "missing unit"),
        ];

        for (input, expected_msg) in cases {
            run_failure_case(input, expected_msg).unwrap();
        }
    }
}