1use std::fmt;
12use std::fmt::{Debug, Display, Formatter};
13use std::str::FromStr;
14use std::time::Duration;
15
16use serde::de::{self, Deserializer, Visitor};
17use serde::{Deserialize, Serialize, Serializer};
18use snafu::Snafu;
19
/// A [`Duration`] newtype that (de)serializes in a human-readable form.
///
/// Ordering, equality and hashing all follow the wrapped `Duration`; the
/// default value is a zero-length duration.
#[derive(Clone, Copy, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct DurationString(Duration);
50
51impl Debug for DurationString {
52 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
53 Debug::fmt(&self.0, f)
54 }
55}
56
57impl DurationString {
58 pub const fn new(d: Duration) -> Self {
60 Self(d)
61 }
62
63 pub const fn as_duration(&self) -> Duration {
65 self.0
66 }
67}
68
69impl From<Duration> for DurationString {
70 fn from(d: Duration) -> Self {
71 Self(d)
72 }
73}
74
75impl From<DurationString> for Duration {
76 fn from(d: DurationString) -> Self {
77 d.0
78 }
79}
80
81impl std::ops::Deref for DurationString {
82 type Target = Duration;
83
84 fn deref(&self) -> &Duration {
85 &self.0
86 }
87}
88
89impl FromStr for DurationString {
90 type Err = ParseDurationError;
91
92 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 parse_string(s).map(Self)
94 }
95}
96
97impl Display for DurationString {
98 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
99 write!(f, "{}s{}ns", self.0.as_secs(), self.0.subsec_nanos())
100 }
101}
102
103impl Serialize for DurationString {
104 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
105 serializer.collect_str(self)
106 }
107}
108
109impl<'de> Deserialize<'de> for DurationString {
110 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
111 deserializer.deserialize_any(DurationStringVisitor)
112 }
113}
114
115struct DurationStringVisitor;
116
117impl<'de> Visitor<'de> for DurationStringVisitor {
118 type Value = DurationString;
119
120 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 f.write_str("a duration string (e.g. \"30s\", \"1h30m\") or a non-negative number of nanoseconds")
122 }
123
124 fn visit_i64<E: de::Error>(self, n: i64) -> Result<DurationString, E> {
125 if n < 0 {
126 return Err(E::custom(ParseDurationError::Negative));
127 }
128 Ok(DurationString(Duration::from_nanos(n as u64)))
129 }
130
131 fn visit_i128<E: de::Error>(self, n: i128) -> Result<DurationString, E> {
132 if n < 0 {
133 return Err(E::custom(ParseDurationError::Negative));
134 }
135 if n > MAX_NANOS_U64 as i128 {
136 return Err(E::custom(ParseDurationError::Overflow));
137 }
138 Ok(DurationString(Duration::from_nanos(n as u64)))
139 }
140
141 fn visit_u64<E: de::Error>(self, n: u64) -> Result<DurationString, E> {
142 if n > MAX_NANOS_U64 {
143 return Err(E::custom(ParseDurationError::Overflow));
144 }
145 Ok(DurationString(Duration::from_nanos(n)))
146 }
147
148 fn visit_u128<E: de::Error>(self, n: u128) -> Result<DurationString, E> {
149 if n > MAX_NANOS_U64 as u128 {
150 return Err(E::custom(ParseDurationError::Overflow));
151 }
152 Ok(DurationString(Duration::from_nanos(n as u64)))
153 }
154
155 fn visit_f64<E: de::Error>(self, f: f64) -> Result<DurationString, E> {
156 if !f.is_finite() {
157 return Err(E::custom("duration nanoseconds must be finite"));
158 }
159 if f < 0.0 {
160 return Err(E::custom(ParseDurationError::Negative));
161 }
162 if f > MAX_NANOS_U64 as f64 {
163 return Err(E::custom(ParseDurationError::Overflow));
164 }
165 Ok(DurationString(Duration::from_nanos(f as u64)))
166 }
167
168 fn visit_str<E: de::Error>(self, s: &str) -> Result<DurationString, E> {
169 parse_string(s).map(DurationString).map_err(E::custom)
170 }
171
172 fn visit_string<E: de::Error>(self, s: String) -> Result<DurationString, E> {
173 self.visit_str(&s)
174 }
175}
176
/// Largest accepted duration, in nanoseconds: `i64::MAX` (about 292 years),
/// the largest count a signed 64-bit nanosecond value can hold.
const MAX_NANOS_U64: u64 = i64::MAX.unsigned_abs();
180
/// Error produced when text or a number cannot be converted into a
/// `DurationString`.
#[derive(Debug, Snafu)]
pub enum ParseDurationError {
    /// The input is not a well-formed duration.
    #[snafu(display("invalid duration '{}': {}", input, reason))]
    Invalid {
        /// The input text that failed to parse.
        input: String,
        /// Short explanation of the failure, e.g. `"missing unit"`.
        reason: String,
    },
    /// A negative, non-zero value was supplied.
    #[snafu(display("negative durations are not supported"))]
    Negative,
    /// The value is larger than the supported maximum (`MAX_NANOS_U64`
    /// nanoseconds).
    #[snafu(display("duration value exceeds supported range"))]
    Overflow,
}
199
200fn invalid(input: &str, reason: impl Into<String>) -> ParseDurationError {
201 ParseDurationError::Invalid {
202 input: input.to_string(),
203 reason: reason.into(),
204 }
205}
206
207fn parse_string(s: &str) -> Result<Duration, ParseDurationError> {
210 let trimmed = s.trim();
211 match parse_duration(trimmed) {
212 Ok(d) => Ok(d),
213 Err(err) => match trimmed.parse::<i128>() {
214 Ok(n) if n < 0 => Err(ParseDurationError::Negative),
215 Ok(n) => {
216 if n > MAX_NANOS_U64 as i128 {
217 return Err(ParseDurationError::Overflow);
218 }
219 Ok(Duration::from_nanos(n as u64))
220 }
221 Err(_) => Err(err),
222 },
223 }
224}
225
226fn parse_duration(s: &str) -> Result<Duration, ParseDurationError> {
229 let orig = s;
230 let mut rest = s;
231 let mut total_ns: u128 = 0;
232 let mut negative = false;
233
234 if let Some(c) = rest.chars().next() {
235 if c == '+' || c == '-' {
236 negative = c == '-';
237 rest = &rest[1..];
238 }
239 }
240
241 if rest == "0" {
243 return Ok(Duration::ZERO);
244 }
245 if rest.is_empty() {
246 return Err(invalid(orig, "empty duration"));
247 }
248
249 while !rest.is_empty() {
250 let (int_part, after_int) = consume_digits(rest);
251 let had_int = !int_part.is_empty();
252
253 let (frac_part, after_frac) = if let Some(stripped) = after_int.strip_prefix('.') {
254 consume_digits(stripped)
255 } else {
256 ("", after_int)
257 };
258 let consumed_dot = after_int.starts_with('.');
259 let had_frac = consumed_dot && !frac_part.is_empty();
260
261 if !had_int && !had_frac {
262 return Err(invalid(orig, "expected digits"));
263 }
264
265 rest = after_frac;
266
267 let unit_str = consume_unit(rest);
268 if unit_str.is_empty() {
269 return Err(invalid(orig, "missing unit"));
270 }
271 rest = &rest[unit_str.len()..];
272
273 let unit_ns: u128 = match unit_str {
274 "ns" => 1,
275 "us" | "µs" => 1_000,
276 "ms" => 1_000_000,
277 "s" => 1_000_000_000,
278 "m" => 60 * 1_000_000_000,
279 "h" => 3_600 * 1_000_000_000,
280 other => return Err(invalid(orig, format!("unknown unit '{}'", other))),
281 };
282
283 let int_val: u128 = if int_part.is_empty() {
284 0
285 } else {
286 int_part
287 .parse::<u128>()
288 .map_err(|_| invalid(orig, "integer overflow"))?
289 };
290
291 let mut ns = int_val.checked_mul(unit_ns).ok_or_else(|| invalid(orig, "overflow"))?;
292
293 if !frac_part.is_empty() {
294 let keep = frac_part.len().min(18);
297 let frac_digits = &frac_part[..keep];
298 let mut scale: u128 = 1;
299 for _ in 0..keep {
300 scale *= 10;
301 }
302 let f: u128 = frac_digits
303 .parse::<u128>()
304 .map_err(|_| invalid(orig, "invalid fractional"))?;
305 let frac_ns = f.checked_mul(unit_ns).ok_or_else(|| invalid(orig, "overflow"))? / scale;
306 ns = ns.checked_add(frac_ns).ok_or_else(|| invalid(orig, "overflow"))?;
307 }
308
309 total_ns = total_ns.checked_add(ns).ok_or_else(|| invalid(orig, "overflow"))?;
310 }
311
312 if negative && total_ns != 0 {
313 return Err(ParseDurationError::Negative);
314 }
315
316 if total_ns > MAX_NANOS_U64 as u128 {
317 return Err(ParseDurationError::Overflow);
318 }
319 Ok(Duration::from_nanos(total_ns as u64))
320}
321
322fn consume_digits(s: &str) -> (&str, &str) {
323 let end = s.bytes().take_while(|b| b.is_ascii_digit()).count();
324 s.split_at(end)
325}
326
327fn consume_unit(s: &str) -> &str {
328 let mut end = 0;
329 for (i, c) in s.char_indices() {
330 if c.is_ascii_alphabetic() || c == 'µ' {
331 end = i + c.len_utf8();
332 } else {
333 break;
334 }
335 }
336 &s[..end]
337}
338
#[cfg(test)]
mod tests {
    use anyhow::Context as _;
    use serde_json::json;

    use super::*;

    const NS: Duration = Duration::from_nanos(1);
    const _US: Duration = Duration::from_micros(1);
    const MS: Duration = Duration::from_millis(1);
    const S: Duration = Duration::from_secs(1);
    const M: Duration = Duration::from_secs(60);
    const H: Duration = Duration::from_secs(3600);

    #[test]
    fn deserialize_integer_succeeds() {
        let json = r#"{ "value": 15 }"#;
        let deserialized: SerdeTest = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.value.as_duration(), 15 * NS);
    }

    /// Fractional nanosecond counts are truncated toward zero.
    #[test]
    fn deserialize_float_succeeds() {
        let json = r#"{ "value": 1.5 }"#;
        let deserialized: SerdeTest = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.value.as_duration(), 1 * NS);
    }

    #[derive(Default, Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize)]
    struct SerdeTest {
        value: DurationString,
    }

    impl From<Duration> for SerdeTest {
        fn from(value: Duration) -> Self {
            Self { value: value.into() }
        }
    }

    /// Wraps `input_value` as a JSON string field: `{"value": "<input>"}`.
    fn test_json(input_value: &str) -> String {
        json!({"value": input_value}).to_string()
    }

    /// Embeds `input_value` unquoted as a YAML scalar: `value: <input>`.
    fn test_yaml(input_value: &str) -> String {
        format!("value: {input_value}")
    }

    /// Checks that `input` parses to `expected` via `FromStr`, JSON and YAML,
    /// that `expected` survives JSON and YAML round trips, and that it
    /// displays as `serialized`.
    fn run_success_case(input: &str, expected: Duration, serialized: &str) -> anyhow::Result<()> {
        let expected_struct: SerdeTest = expected.into();
        let json = test_json(input);
        let yaml = test_yaml(input);
        let msg = format!("failure for duration test case '{input}'");
        let parsed_duration = DurationString::from_str(input).context(msg.clone())?;
        anyhow::ensure!(
            expected == parsed_duration.as_duration(),
            "{msg}, expected: {expected:?}, got {:?}",
            parsed_duration.as_duration()
        );
        let deserialized_from_json: SerdeTest = serde_json::from_str(&json).context(msg.clone())?;
        anyhow::ensure!(
            expected_struct == deserialized_from_json,
            "{msg}, expected: {expected_struct:?}, got {deserialized_from_json:?}"
        );
        let roundtrip_json = serde_json::from_str(&serde_json::to_string(&expected_struct)?)?;
        anyhow::ensure!(
            expected_struct == roundtrip_json,
            "{msg}, expected json roundtrip to produce {expected_struct:?}, but got {roundtrip_json:?}"
        );
        let deserialized_from_yaml: SerdeTest = serde_yaml::from_str(&yaml).context(msg.clone())?;
        anyhow::ensure!(
            expected_struct == deserialized_from_yaml,
            "{msg}, expected: {expected_struct:?}, got {deserialized_from_yaml:?}"
        );
        let roundtrip_yaml = serde_yaml::from_str(&serde_yaml::to_string(&expected_struct)?)?;
        anyhow::ensure!(
            expected_struct == roundtrip_yaml,
            "{msg}, expected yaml roundtrip to produce {expected_struct:?}, but got {roundtrip_yaml:?}"
        );
        let actual_serialized = parsed_duration.to_string();
        anyhow::ensure!(
            serialized == actual_serialized,
            "Expected the input '{input}' to be serialized as '{serialized}' but got '{actual_serialized}'"
        );
        Ok(())
    }

    #[test]
    fn duration_string_success_cases() {
        let cases: &[(&str, Duration, &str)] = &[
            ("0", Duration::ZERO, "0s0ns"),
            ("-0", Duration::ZERO, "0s0ns"),
            ("+0", Duration::ZERO, "0s0ns"),
            ("+5h", Duration::from_hours(5), "18000s0ns"),
            (".5s", Duration::from_millis(500), "0s500000000ns"),
            ("5.s", Duration::from_secs(5), "5s0ns"),
            ("0.000000001s", Duration::from_nanos(1), "0s1ns"),
            ("1.5h", Duration::from_mins(90), "5400s0ns"),
            (
                "2h45m30.5s",
                (2 * H) + (45 * M) + (30 * S) + (500 * MS),
                "9930s500000000ns",
            ),
            ("12µs", Duration::from_micros(12), "0s12000ns"),
            ("0s", Duration::ZERO, "0s0ns"),
            ("1h1m1s1ms1us1ns", H + M + S + MS + (1000 * NS) + NS, "3661s1001001ns"),
            ("24h", Duration::from_hours(24), "86400s0ns"),
            (
                "9223372036854775807ns",
                Duration::from_nanos(9223372036854775807),
                "9223372036s854775807ns",
            ),
            (
                "9223372036854775.807us",
                Duration::from_secs(9223372036) + (854775807 * NS),
                "9223372036s854775807ns",
            ),
            (
                "2562047h47m16.854775807s",
                Duration::from_secs(9223372036) + (854775807 * NS),
                "9223372036s854775807ns",
            ),
            ("0.1ns", Duration::ZERO, "0s0ns"),
            ("05s", Duration::from_secs(5), "5s0ns"),
            ("1ns1s", S + NS, "1s1ns"),
            ("100h100m100s", (100 * H) + (100 * M) + (100 * S), "366100s0ns"),
            ("5m32s", (5 * M) + (32 * S), "332s0ns"),
            ("1m0s", M, "60s0ns"),
            ("5m0s", 5 * M, "300s0ns"),
            ("6m0s", 6 * M, "360s0ns"),
            ("10m0s", 10 * M, "600s0ns"),
            ("15m0s", 15 * M, "900s0ns"),
            ("30m0s", 30 * M, "1800s0ns"),
            ("40m0s", 40 * M, "2400s0ns"),
            ("50m0s", 50 * M, "3000s0ns"),
            ("87600h0m0s", 87600 * H, "315360000s0ns"),
            ("5", 5 * NS, "0s5ns"),
            (" 5s", 5 * S, "5s0ns"),
            ("5s", 5 * S, "5s0ns"),
        ];

        for (input, expected, serialized) in cases {
            run_success_case(input, *expected, serialized).unwrap();
        }
    }

    /// Checks that parsing `input` fails with an error whose message
    /// contains `expected_msg`.
    fn run_failure_case(input: &str, expected_msg: &str) -> anyhow::Result<()> {
        let result = DurationString::from_str(input);
        match result {
            Ok(value) => {
                anyhow::bail!("Expected an error when parsing '{input}', but instead received the value '{value:?}'")
            }
            Err(e) => {
                anyhow::ensure!(
                    e.to_string().contains(expected_msg),
                    "Expected the error message when parsing '{input}' to contain {expected_msg:?}, but the message is {e}"
                );
            }
        }

        Ok(())
    }

    #[test]
    fn duration_string_failure_cases() {
        let cases: &[(&str, &str)] = &[
            ("5m32sFOO", "unknown unit 'sFOO'"),
            ("", "empty duration"),
            (" ", "empty duration"),
            ("+", "empty duration"),
            ("-", "empty duration"),
            (".", "expected digits"),
            ("s", "expected digits"),
            (".s", "expected digits"),
            ("--5s", "expected digits"),
            ("5.5.5s", "missing unit"),
            ("1e3s", "unknown unit 'e'"),
            ("5ns5", "missing unit"),
            ("9223372036854775808ns", "exceeds"),
            ("-1s", "negative"),
            ("-0.5h", "negative"),
            ("1d", "unknown unit 'd'"),
            ("1w", "unknown unit 'w'"),
            ("1S", "unknown unit 'S'"),
            ("12 µs", "missing unit"),
            ("5 s", "missing unit"),
            ("5. s", "missing unit"),
        ];

        for (input, expected_msg) in cases {
            run_failure_case(input, expected_msg).unwrap();
        }
    }
}