1use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration};
2
3use saluki_error::GenericError;
4use saluki_metrics::static_metrics;
5use tokio::time::sleep;
6use tracing::debug;
7
8use crate::{hash::FastBuildHasher, task::spawn_traced};
9
10mod expiry;
11use self::expiry::{Expiration, ExpirationBuilder, ExpiryCapableLifecycle};
12
13pub mod weight;
14use self::weight::{ItemCountWeighter, Weighter, WrappedWeighter};
15
/// The underlying concurrent store backing `Cache`: a `quick_cache` sync cache whose item weighter is wrapped in
/// `WrappedWeighter`, and whose lifecycle is the `ExpiryCapableLifecycle` produced by `ExpirationBuilder::build`.
type InnerCache<K, V, W, H> = quick_cache::sync::Cache<K, V, WrappedWeighter<W>, H, ExpiryCapableLifecycle<K>>;
17
// Per-cache metrics, labeled by the cache identifier (`cache_id`) and emitted under the `cache` prefix.
//
// - `weight_limit` is set once when the cache is built.
// - `current_items` / `current_weight` are refreshed by the background telemetry task.
// - `hits_total`, `misses_total`, `items_inserted_total` and `items_removed_total` are updated by `Cache` methods.
// - `items_expired_total` / `items_expired_batch_size` are updated by the background expiration task.
static_metrics! {
    name => Telemetry,
    prefix => cache,
    labels => [cache_id: String],
    metrics => [
        gauge(current_items),
        gauge(current_weight),
        gauge(weight_limit),
        counter(hits_total),
        counter(misses_total),
        counter(items_inserted_total),
        counter(items_removed_total),
        counter(items_expired_total),
        trace_histogram(items_expired_batch_size),
    ],
}
34
/// Builder for [`Cache`].
///
/// Created with [`CacheBuilder::from_identifier`] (or [`CacheBuilder::for_tests`]) and finalized with
/// [`CacheBuilder::build`].
pub struct CacheBuilder<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    /// Identifier used as the `cache_id` telemetry label; must be non-empty.
    identifier: String,
    /// Maximum total item weight the cache may hold.
    capacity: NonZeroUsize,
    /// Computes the weight of each item.
    weighter: W,
    /// Time-to-idle for entries; `None` disables idle expiration.
    idle_period: Option<Duration>,
    /// How often the background expiration task runs; `None` means no expiration task is spawned.
    expiration_interval: Option<Duration>,
    /// Whether to spawn the background task that periodically refreshes the item/weight gauges.
    telemetry_enabled: bool,
    // Marker fields: the builder is generic over key, value, and hasher types without storing any of them.
    _key: PhantomData<K>,
    _value: PhantomData<V>,
    _hasher: PhantomData<H>,
}
47
48impl<K, V> CacheBuilder<K, V> {
49 pub fn from_identifier<N: Into<String>>(identifier: N) -> Result<CacheBuilder<K, V>, GenericError> {
59 let identifier = identifier.into();
60 if identifier.is_empty() {
61 return Err(GenericError::msg("cache identifier must not be empty"));
62 }
63
64 Ok(CacheBuilder {
65 identifier,
66 capacity: NonZeroUsize::MAX,
67 weighter: ItemCountWeighter,
68 idle_period: None,
69 expiration_interval: None,
70 telemetry_enabled: true,
71 _key: PhantomData,
72 _value: PhantomData,
73 _hasher: PhantomData,
74 })
75 }
76
77 pub fn for_tests() -> CacheBuilder<K, V> {
88 CacheBuilder::from_identifier("noop")
89 .expect("cache identifier not empty")
90 .with_telemetry(false)
91 }
92}
93
94impl<K, V, W, H> CacheBuilder<K, V, W, H> {
95 pub fn with_capacity(mut self, capacity: NonZeroUsize) -> Self {
104 self.capacity = capacity;
105 self
106 }
107
108 pub fn with_time_to_idle(mut self, idle_period: Option<Duration>) -> Self {
118 self.idle_period = idle_period;
119
120 if self.idle_period.is_some() {
122 self.expiration_interval = self.expiration_interval.or(Some(Duration::from_secs(1)));
123 }
124
125 self
126 }
127
128 pub fn with_expiration_interval(mut self, expiration_interval: Duration) -> Self {
141 self.expiration_interval = Some(expiration_interval);
142 self
143 }
144
145 pub fn with_item_weighter<W2>(self, weighter: W2) -> CacheBuilder<K, V, W2, H> {
158 CacheBuilder {
159 identifier: self.identifier,
160 capacity: self.capacity,
161 weighter,
162 idle_period: self.idle_period,
163 expiration_interval: self.expiration_interval,
164 telemetry_enabled: self.telemetry_enabled,
165 _key: PhantomData,
166 _value: PhantomData,
167 _hasher: PhantomData,
168 }
169 }
170
171 pub fn with_hasher<H2>(self) -> CacheBuilder<K, V, W, H2> {
179 CacheBuilder {
180 identifier: self.identifier,
181 capacity: self.capacity,
182 weighter: self.weighter,
183 idle_period: self.idle_period,
184 expiration_interval: self.expiration_interval,
185 telemetry_enabled: self.telemetry_enabled,
186 _key: PhantomData,
187 _value: PhantomData,
188 _hasher: PhantomData,
189 }
190 }
191
192 pub fn with_telemetry(mut self, enabled: bool) -> Self {
201 self.telemetry_enabled = enabled;
202 self
203 }
204}
205
206impl<K, V, W, H> CacheBuilder<K, V, W, H>
207where
208 K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
209 V: Clone + Send + Sync + 'static,
210 W: Weighter<K, V> + Clone + Send + Sync + 'static,
211 H: std::hash::BuildHasher + Clone + Default + Send + Sync + 'static,
212{
213 pub fn build(self) -> Cache<K, V, W, H> {
215 let capacity = self.capacity.get();
216
217 let telemetry = Telemetry::new(self.identifier);
218 telemetry.weight_limit().set(capacity as f64);
219
220 let mut expiration_builder = ExpirationBuilder::new();
222 if let Some(time_to_idle) = self.idle_period {
223 expiration_builder = expiration_builder.with_time_to_idle(time_to_idle);
224 }
225 let (expiration, expiry_lifecycle) = expiration_builder.build();
226
227 let cache = Cache {
229 inner: Arc::new(InnerCache::with(
230 capacity,
231 capacity as u64,
232 WrappedWeighter::from(self.weighter),
233 H::default(),
234 expiry_lifecycle,
235 )),
236 expiration: expiration.clone(),
237 telemetry: telemetry.clone(),
238 };
239
240 if let Some(expiration_interval) = self.expiration_interval {
242 let expiration = expiration.clone();
243
244 spawn_traced(drive_expiration(
245 cache.clone(),
246 telemetry.clone(),
247 expiration,
248 expiration_interval,
249 ));
250 }
251
252 if self.telemetry_enabled {
254 spawn_traced(drive_telemetry(cache.clone(), telemetry));
255 }
256
257 cache
258 }
259}
260
/// A concurrent, weight-bounded cache with optional idle-based expiration and built-in telemetry.
///
/// Built via [`CacheBuilder`]. Cloning produces another handle to the same underlying storage (`inner` is
/// reference-counted).
#[derive(Clone)]
pub struct Cache<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    /// Shared underlying storage.
    inner: Arc<InnerCache<K, V, W, H>>,
    /// Notified of entry accesses and removals for idle-based expiration.
    expiration: Expiration<K>,
    /// Metrics handle, labeled with this cache's identifier.
    telemetry: Telemetry,
}
268
269impl<K, V, W, H> Cache<K, V, W, H>
270where
271 K: Eq + std::hash::Hash + Clone,
272 V: Clone,
273 W: Weighter<K, V> + Clone,
274 H: std::hash::BuildHasher + Clone,
275{
276 pub fn is_empty(&self) -> bool {
278 self.inner.is_empty()
279 }
280
281 pub fn len(&self) -> usize {
283 self.inner.len()
284 }
285
286 pub fn weight(&self) -> u64 {
288 self.inner.weight()
289 }
290
291 pub fn insert(&self, key: K, value: V) {
297 self.inner.insert(key.clone(), value);
298 self.expiration.mark_entry_accessed(key);
299 self.telemetry.items_inserted_total().increment(1);
300 }
301
302 pub fn get(&self, key: &K) -> Option<V> {
306 let value = self.inner.get(key);
307 if value.is_some() {
308 self.expiration.mark_entry_accessed(key.clone());
309 self.telemetry.hits_total().increment(1);
310 } else {
311 self.telemetry.misses_total().increment(1);
312 }
313 value
314 }
315
316 pub fn remove(&self, key: &K) {
318 self.inner.remove(key);
319 self.expiration.mark_entry_removed(key.clone());
320 self.telemetry.items_removed_total().increment(1);
321 }
322}
323
324async fn drive_expiration<K, V, W, H>(
325 cache: Cache<K, V, W, H>, telemetry: Telemetry, expiration: Expiration<K>, expiration_interval: Duration,
326) where
327 K: Eq + std::hash::Hash + Clone,
328 V: Clone,
329 W: Weighter<K, V> + Clone,
330 H: std::hash::BuildHasher + Clone,
331{
332 let mut expired_item_keys = Vec::new();
333
334 loop {
335 sleep(expiration_interval).await;
336
337 expiration.drain_expired_items(&mut expired_item_keys);
339
340 let num_expired_items = expired_item_keys.len();
341 if num_expired_items != 0 {
342 telemetry.items_expired_total().increment(num_expired_items as u64);
343 telemetry.items_expired_batch_size().record(num_expired_items as f64);
344 }
345
346 debug!(num_expired_items, "Found expired items.");
347
348 for item_key in expired_item_keys.drain(..) {
349 cache.remove(&item_key);
350 }
351
352 debug!(num_expired_items, "Removed expired items.");
353 }
354}
355
356async fn drive_telemetry<K, V, W, H>(cache: Cache<K, V, W, H>, telemetry: Telemetry)
357where
358 K: Eq + std::hash::Hash + Clone,
359 V: Clone,
360 W: Weighter<K, V> + Clone,
361 H: std::hash::BuildHasher + Clone,
362{
363 loop {
364 sleep(Duration::from_secs(1)).await;
365
366 telemetry.current_items().set(cache.len() as f64);
367 telemetry.current_weight().set(cache.weight() as f64);
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Weighter that uses each `usize` value as its own weight, so tests can control item weights precisely.
    #[derive(Clone)]
    pub struct ItemValueWeighter;

    impl<K> Weighter<K, usize> for ItemValueWeighter {
        fn item_weight(&self, _key: &K, value: &usize) -> u64 {
            *value as u64
        }
    }

    /// Builds a test cache bounded to `capacity`, weighting every item by its value.
    fn value_weighted_cache<K>(capacity: usize) -> Cache<K, usize, ItemValueWeighter>
    where
        K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
    {
        CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(capacity).unwrap())
            .with_item_weighter(ItemValueWeighter)
            .build()
    }

    #[test]
    fn empty_cache_identifier() {
        let outcome = CacheBuilder::<u64, u64>::from_identifier("");
        assert!(outcome.is_err(), "expected error for empty cache identifier");
    }

    #[test]
    fn basic() {
        const CACHE_KEY: usize = 42;
        const CACHE_VALUE: &str = "value1";

        let cache = CacheBuilder::for_tests().build();
        assert_eq!((cache.len(), cache.weight()), (0, 0));

        cache.insert(CACHE_KEY, CACHE_VALUE);
        assert_eq!((cache.len(), cache.weight()), (1, 1));
        assert_eq!(cache.get(&CACHE_KEY), Some(CACHE_VALUE));

        cache.remove(&CACHE_KEY);
        assert_eq!((cache.len(), cache.weight()), (0, 0));
    }

    #[test]
    fn evict_at_capacity() {
        const CAPACITY: usize = 3;

        let cache = CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(CAPACITY).unwrap())
            .build();

        (0..CAPACITY).for_each(|key| cache.insert(key, "value"));
        assert_eq!((cache.len(), cache.weight()), (CAPACITY, CAPACITY as u64));

        // One item past capacity: the cache must stay at its limit by evicting something.
        cache.insert(CAPACITY, "new_value");
        assert_eq!((cache.len(), cache.weight()), (CAPACITY, CAPACITY as u64));

        let evicted = (0..CAPACITY).any(|key| cache.get(&key).is_none());
        assert!(evicted, "expected at least one original item to be evicted");
    }

    #[test]
    fn overweight_item() {
        const CAPACITY: usize = 10;

        let cache = value_weighted_cache(CAPACITY);
        assert_eq!((cache.len(), cache.weight()), (0, 0));

        // An item heavier than the entire capacity is never stored.
        cache.insert(1, CAPACITY + 1);
        assert_eq!((cache.len(), cache.weight()), (0, 0));
        assert_eq!(cache.get(&1), None);
    }

    #[test]
    fn evict_on_insert_by_weight() {
        const CAPACITY: usize = 10;

        let cache = value_weighted_cache(CAPACITY);
        assert_eq!((cache.len(), cache.weight()), (0, 0));

        // Fill the cache exactly to its weight limit.
        for (key, weight) in [(1, 3), (2, 4), (3, 3)] {
            cache.insert(key, weight);
        }
        assert_eq!((cache.len(), cache.weight()), (3, CAPACITY as u64));

        // A single heavy item displaces everything that was there before.
        cache.insert(4, CAPACITY - 1);
        assert_eq!((cache.len(), cache.weight()), (1, (CAPACITY - 1) as u64));

        for evicted_key in 1..=3 {
            assert_eq!(cache.get(&evicted_key), None);
        }
        assert_eq!(cache.get(&4), Some(CAPACITY - 1));
    }
}