1use std::time::Instant;
4
/// A token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens and starts full. Tokens are
/// replenished continuously at `rate` tokens per second, and each successful
/// call to [`TokenBucket::allow`] consumes one token. This permits bursts of up
/// to `capacity` requests followed by a sustained throughput of `rate`
/// requests per second.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold (the burst size).
    capacity: f64,
    /// Tokens currently available; fractional amounts accumulate between calls.
    tokens: f64,
    /// Instant up to which refill has already been credited.
    last_refill: Instant,
    /// Refill rate in tokens per second.
    rate: f64,
}

impl TokenBucket {
    /// Creates a bucket that refills at `rate` tokens per second and holds at
    /// most `burst` tokens. The bucket starts full, so the first `burst` calls
    /// to [`allow`](Self::allow) succeed immediately.
    ///
    /// A `rate` of `0.0` yields a bucket that never refills; `f64::INFINITY`
    /// yields one that is refilled to capacity on every call.
    ///
    /// # Panics
    ///
    /// Panics if `rate` is negative or NaN. A negative rate would drain tokens
    /// over time without bound, and a NaN rate would silently behave like an
    /// unlimited bucket.
    pub fn new(rate: f64, burst: usize) -> Self {
        // `NaN >= 0.0` is false, so this single check rejects NaN as well.
        assert!(
            rate >= 0.0,
            "TokenBucket rate must be non-negative and not NaN, got {}",
            rate
        );
        Self {
            capacity: burst as f64,
            tokens: burst as f64,
            last_refill: Instant::now(),
            rate,
        }
    }

    /// Attempts to consume one token, returning `true` if the request is
    /// permitted and `false` if the bucket is currently empty.
    pub fn allow(&mut self) -> bool {
        self.allow_at(Instant::now())
    }

    /// Like [`allow`](Self::allow), but treats `now` as the current time
    /// instead of reading the clock.
    ///
    /// This lets callers drive the limiter from their own clock and lets tests
    /// exercise refill deterministically. A `now` earlier than the latest time
    /// already observed credits no refill and does not move the bucket's
    /// notion of time backwards.
    pub fn allow_at(&mut self, now: Instant) -> bool {
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        // `f64::min` returns the non-NaN operand, so an infinite rate with zero
        // elapsed time (`0.0 * inf == NaN`) still refills to capacity.
        self.tokens = (self.tokens + elapsed * self.rate).min(self.capacity);
        // Never rewind: a stale `now` must not cause the same interval to be
        // credited a second time on a later call.
        self.last_refill = self.last_refill.max(now);
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
43
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::TokenBucket;

    /// Simulates `by` of idle time without sleeping, by moving the bucket's
    /// last-refill timestamp into the past; the next `allow()` then credits at
    /// least `by` worth of refill.
    ///
    /// Tests use this instead of `thread::sleep` so their assertions do not
    /// depend on scheduler timing. It touches the private `last_refill` field,
    /// which is visible here because `tests` is a child module.
    fn age(bucket: &mut TokenBucket, by: Duration) {
        bucket.last_refill = bucket
            .last_refill
            .checked_sub(by)
            .expect("clock cannot represent an instant that far in the past");
    }

    #[test]
    fn full_bucket_allows_up_to_burst() {
        let burst = 5;
        let mut bucket = TokenBucket::new(1.0, burst);
        for _ in 0..burst {
            assert!(bucket.allow());
        }
        assert!(!bucket.allow());
    }

    #[test]
    fn empty_bucket_refills_over_time() {
        // A low rate keeps the "still empty" check robust: a spurious token
        // would need a full second to pass between consecutive calls.
        let mut bucket = TokenBucket::new(1.0, 1);
        assert!(bucket.allow());
        assert!(!bucket.allow());
        age(&mut bucket, Duration::from_secs(2));
        assert!(bucket.allow());
    }

    #[test]
    fn refills_using_real_elapsed_time() {
        // End-to-end check of `allow()` against real elapsed time. Only
        // positive assertions are made, so scheduler delays cannot fail it.
        let mut bucket = TokenBucket::new(100.0, 1);
        assert!(bucket.allow());
        std::thread::sleep(Duration::from_millis(20));
        assert!(bucket.allow());
    }

    #[test]
    fn refill_does_not_exceed_capacity() {
        let burst = 3;
        let mut bucket = TokenBucket::new(1.0, burst);
        for _ in 0..burst {
            assert!(bucket.allow());
        }
        assert!(!bucket.allow());

        // 60 s at 1 token/s would be 60 tokens; the bucket must cap at `burst`.
        age(&mut bucket, Duration::from_secs(60));
        for _ in 0..burst {
            assert!(bucket.allow());
        }
        assert!(!bucket.allow());
    }

    #[test]
    fn zero_rate_never_refills() {
        let mut bucket = TokenBucket::new(0.0, 1);
        assert!(bucket.allow());
        assert!(!bucket.allow());
        age(&mut bucket, Duration::from_secs(60));
        assert!(!bucket.allow());
    }
}