1use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration};
2
3use saluki_error::GenericError;
4use saluki_metrics::reexport::metrics::Counter;
5use saluki_metrics::static_metrics;
6use tokio::time::sleep;
7use tokio_util::sync::{CancellationToken, DropGuard};
8use tracing::debug;
9
10use crate::{hash::FastBuildHasher, task::spawn_traced};
11
12mod expiry;
13use self::expiry::{Expiration, ExpirationBuilder, ExpiryCapableLifecycle};
14
15pub mod weight;
16use self::weight::{ItemCountWeighter, Weighter, WrappedWeighter};
17
/// The concrete `quick_cache` cache type used by this module.
///
/// The caller-provided weighter `W` is wrapped in [`WrappedWeighter`], and [`ExpiryCapableLifecycle`] is plugged in as
/// the cache's lifecycle hook.
type RawCache<K, V, W, H> = quick_cache::sync::Cache<K, V, WrappedWeighter<W>, H, ExpiryCapableLifecycle<K>>;
19
// Metric handles for a single cache, generated as the `Telemetry` type.
//
// All metrics are declared under the `cache` prefix with a `cache_id` label; the builder constructs `Telemetry` from
// the cache identifier. Within this file:
// - `weight_limit` is set once when the cache is built,
// - `current_items` / `current_weight` are refreshed by the background telemetry task,
// - hit/miss/insert/remove counters are updated by the `Cache` methods,
// - the expiration counter and batch-size histogram are updated by the background expiration task,
// - the eviction counter is handed to the expiration builder when the cache is built.
static_metrics! {
    name => Telemetry,
    prefix => cache,
    labels => [cache_id: String],
    metrics => [
        gauge(current_items),
        gauge(current_weight),
        gauge(weight_limit),
        counter(hits_total),
        counter(misses_total),
        counter(items_evicted_total),
        debug_counter(items_inserted_total),
        debug_counter(items_removed_total),
        debug_counter(items_expired_total),
        trace_histogram(items_expired_batch_size),
    ],
}
37
/// State shared by every clone of a [`Cache`].
struct InnerCache<K, V, W, H> {
    /// The underlying cache. Background tasks spawned by the builder hold their own `Arc` to it.
    cache: Arc<RawCache<K, V, W, H>>,

    /// Guard tied to the shutdown token handed to the background tasks.
    ///
    /// Never read directly: it exists so that dropping the last `InnerCache` cancels the token, which is the signal
    /// the background tasks wait on to exit.
    _task_shutdown_guard: DropGuard,
}
42
/// Builder for creating a [`Cache`].
///
/// `W` is the item weighter (defaults to [`ItemCountWeighter`]) and `H` is the hash builder type (defaults to
/// [`FastBuildHasher`]).
pub struct CacheBuilder<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    /// Identifier of the cache; handed to `Telemetry::new` when the cache is built.
    identifier: String,
    /// Capacity handed to the underlying cache and reported through the `weight_limit` gauge.
    capacity: NonZeroUsize,
    /// Weighter used by the underlying cache to weigh items.
    weighter: W,
    /// Optional time-to-idle period forwarded to the expiration builder.
    idle_period: Option<Duration>,
    /// Interval at which the background expiration task wakes up. The task is only spawned when this is `Some`.
    expiration_interval: Option<Duration>,
    /// Whether the eviction counter and the background gauge-reporting task are wired up at build time.
    telemetry_enabled: bool,
    _key: PhantomData<K>,
    _value: PhantomData<V>,
    _hasher: PhantomData<H>,
}
55
56impl<K, V> CacheBuilder<K, V> {
57 pub fn from_identifier<N: Into<String>>(identifier: N) -> Result<CacheBuilder<K, V>, GenericError> {
67 let identifier = identifier.into();
68 if identifier.is_empty() {
69 return Err(GenericError::msg("cache identifier must not be empty"));
70 }
71
72 Ok(CacheBuilder {
73 identifier,
74 capacity: NonZeroUsize::MAX,
75 weighter: ItemCountWeighter,
76 idle_period: None,
77 expiration_interval: None,
78 telemetry_enabled: true,
79 _key: PhantomData,
80 _value: PhantomData,
81 _hasher: PhantomData,
82 })
83 }
84
85 pub fn for_tests() -> CacheBuilder<K, V> {
96 CacheBuilder::from_identifier("noop")
97 .expect("cache identifier not empty")
98 .with_telemetry(false)
99 }
100}
101
102impl<K, V, W, H> CacheBuilder<K, V, W, H> {
103 pub fn with_capacity(mut self, capacity: NonZeroUsize) -> Self {
112 self.capacity = capacity;
113 self
114 }
115
116 pub fn with_time_to_idle(mut self, idle_period: Option<Duration>) -> Self {
126 self.idle_period = idle_period;
127
128 if self.idle_period.is_some() {
130 self.expiration_interval = self.expiration_interval.or(Some(Duration::from_secs(1)));
131 }
132
133 self
134 }
135
136 pub fn with_expiration_interval(mut self, expiration_interval: Duration) -> Self {
149 self.expiration_interval = Some(expiration_interval);
150 self
151 }
152
153 pub fn with_item_weighter<W2>(self, weighter: W2) -> CacheBuilder<K, V, W2, H> {
166 CacheBuilder {
167 identifier: self.identifier,
168 capacity: self.capacity,
169 weighter,
170 idle_period: self.idle_period,
171 expiration_interval: self.expiration_interval,
172 telemetry_enabled: self.telemetry_enabled,
173 _key: PhantomData,
174 _value: PhantomData,
175 _hasher: PhantomData,
176 }
177 }
178
179 pub fn with_hasher<H2>(self) -> CacheBuilder<K, V, W, H2> {
187 CacheBuilder {
188 identifier: self.identifier,
189 capacity: self.capacity,
190 weighter: self.weighter,
191 idle_period: self.idle_period,
192 expiration_interval: self.expiration_interval,
193 telemetry_enabled: self.telemetry_enabled,
194 _key: PhantomData,
195 _value: PhantomData,
196 _hasher: PhantomData,
197 }
198 }
199
200 pub fn with_telemetry(mut self, enabled: bool) -> Self {
209 self.telemetry_enabled = enabled;
210 self
211 }
212}
213
214impl<K, V, W, H> CacheBuilder<K, V, W, H>
215where
216 K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
217 V: Clone + Send + Sync + 'static,
218 W: Weighter<K, V> + Clone + Send + Sync + 'static,
219 H: std::hash::BuildHasher + Clone + Default + Send + Sync + 'static,
220{
221 pub fn build(self) -> Cache<K, V, W, H> {
223 let capacity = self.capacity.get();
224
225 let telemetry = Telemetry::new(self.identifier);
226 telemetry.weight_limit().set(capacity as f64);
227
228 let eviction_counter = if self.telemetry_enabled {
230 telemetry.items_evicted_total().clone()
231 } else {
232 Counter::noop()
233 };
234 let mut expiration_builder = ExpirationBuilder::new(eviction_counter);
235 if let Some(time_to_idle) = self.idle_period {
236 expiration_builder = expiration_builder.with_time_to_idle(time_to_idle);
237 }
238 let (expiration, expiry_lifecycle) = expiration_builder.build();
239
240 let shutdown_token = CancellationToken::new();
242 let raw_cache = Arc::new(RawCache::with(
243 capacity,
244 capacity as u64,
245 WrappedWeighter::from(self.weighter),
246 H::default(),
247 expiry_lifecycle,
248 ));
249
250 let cache = Cache {
251 inner: Arc::new(InnerCache {
252 cache: Arc::clone(&raw_cache),
253 _task_shutdown_guard: shutdown_token.clone().drop_guard(),
254 }),
255 expiration: expiration.clone(),
256 telemetry: telemetry.clone(),
257 };
258
259 if let Some(expiration_interval) = self.expiration_interval {
261 let expiration = expiration.clone();
262
263 spawn_traced(drive_expiration(
264 Arc::clone(&raw_cache),
265 telemetry.clone(),
266 expiration,
267 expiration_interval,
268 shutdown_token.clone(),
269 ));
270 }
271
272 if self.telemetry_enabled {
274 spawn_traced(drive_telemetry(Arc::clone(&raw_cache), telemetry, shutdown_token));
275 }
276
277 cache
278 }
279}
280
/// A handle to a cache built by [`CacheBuilder`].
///
/// Cloning yields another handle to the same underlying cache, as the shared state sits behind an [`Arc`].
#[derive(Clone)]
pub struct Cache<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    /// Shared cache state, including the guard that signals background tasks to stop once the last handle is dropped.
    inner: Arc<InnerCache<K, V, W, H>>,
    /// Handle used to record entry accesses and removals for expiration tracking.
    expiration: Expiration<K>,
    /// Metric handles updated by the cache operations.
    telemetry: Telemetry,
}
288
289impl<K, V, W, H> Cache<K, V, W, H>
290where
291 K: Eq + std::hash::Hash + Clone,
292 V: Clone,
293 W: Weighter<K, V> + Clone,
294 H: std::hash::BuildHasher + Clone,
295{
296 pub fn is_empty(&self) -> bool {
298 self.inner.cache.is_empty()
299 }
300
301 pub fn len(&self) -> usize {
303 self.inner.cache.len()
304 }
305
306 pub fn weight(&self) -> u64 {
308 self.inner.cache.weight()
309 }
310
311 pub fn insert(&self, key: K, value: V) {
317 self.inner.cache.insert(key.clone(), value);
318 self.expiration.mark_entry_accessed(key);
319 self.telemetry.items_inserted_total().increment(1);
320 }
321
322 pub fn get(&self, key: &K) -> Option<V> {
326 let value = self.inner.cache.get(key);
327 if value.is_some() {
328 self.expiration.mark_entry_accessed(key.clone());
329 self.telemetry.hits_total().increment(1);
330 } else {
331 self.telemetry.misses_total().increment(1);
332 }
333 value
334 }
335
336 pub fn remove(&self, key: &K) {
338 self.inner.cache.remove(key);
339 self.expiration.mark_entry_removed(key.clone());
340 self.telemetry.items_removed_total().increment(1);
341 }
342}
343
344async fn drive_expiration<K, V, W, H>(
345 cache: Arc<RawCache<K, V, W, H>>, telemetry: Telemetry, expiration: Expiration<K>, expiration_interval: Duration,
346 shutdown: CancellationToken,
347) where
348 K: Eq + std::hash::Hash + Clone,
349 V: Clone,
350 W: Weighter<K, V> + Clone,
351 H: std::hash::BuildHasher + Clone,
352{
353 let mut expired_item_keys = Vec::new();
354
355 loop {
356 tokio::select! {
357 _ = shutdown.cancelled() => break,
358 _ = sleep(expiration_interval) => {}
359 }
360
361 expiration.drain_expired_items(&mut expired_item_keys);
363
364 let num_expired_items = expired_item_keys.len();
365 if num_expired_items != 0 {
366 telemetry.items_expired_total().increment(num_expired_items as u64);
367 telemetry.items_expired_batch_size().record(num_expired_items as f64);
368 }
369
370 debug!(num_expired_items, "Found expired items.");
371
372 for item_key in expired_item_keys.drain(..) {
373 cache.remove(&item_key);
374 telemetry.items_removed_total().increment(1);
375 expiration.mark_entry_removed(item_key);
376 }
377
378 debug!(num_expired_items, "Removed expired items.");
379 }
380}
381
382async fn drive_telemetry<K, V, W, H>(
383 cache: Arc<RawCache<K, V, W, H>>, telemetry: Telemetry, shutdown: CancellationToken,
384) where
385 K: Eq + std::hash::Hash + Clone,
386 V: Clone,
387 W: Weighter<K, V> + Clone,
388 H: std::hash::BuildHasher + Clone,
389{
390 loop {
391 tokio::select! {
392 _ = shutdown.cancelled() => break,
393 _ = sleep(Duration::from_secs(1)) => {}
394 }
395
396 telemetry.current_items().set(cache.len() as f64);
397 telemetry.current_weight().set(cache.weight() as f64);
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Test weighter that treats the stored `usize` value as the item's weight.
    #[derive(Clone)]
    pub struct ItemValueWeighter;

    impl<K> Weighter<K, usize> for ItemValueWeighter {
        fn item_weight(&self, _key: &K, value: &usize) -> u64 {
            *value as u64
        }
    }

    /// An empty identifier must be rejected by the builder.
    #[test]
    fn empty_cache_identifier() {
        assert!(
            CacheBuilder::<u64, u64>::from_identifier("").is_err(),
            "expected error for empty cache identifier"
        );
    }

    /// Insert, read back, and remove a single item.
    #[test]
    fn basic() {
        const CACHE_KEY: usize = 42;
        const CACHE_VALUE: &str = "value1";

        let cache = CacheBuilder::for_tests().build();
        assert_eq!((cache.len(), cache.weight()), (0, 0));

        cache.insert(CACHE_KEY, CACHE_VALUE);
        assert_eq!((cache.len(), cache.weight()), (1, 1));

        let fetched = cache.get(&CACHE_KEY);
        assert_eq!(fetched, Some(CACHE_VALUE));

        cache.remove(&CACHE_KEY);
        assert_eq!((cache.len(), cache.weight()), (0, 0));
    }

    /// Inserting past the capacity must evict at least one of the original items.
    #[test]
    fn evict_at_capacity() {
        const CAPACITY: usize = 3;

        let capacity = NonZeroUsize::new(CAPACITY).unwrap();
        let cache = CacheBuilder::for_tests().with_capacity(capacity).build();

        // Fill the cache exactly to capacity.
        (0..CAPACITY).for_each(|i| cache.insert(i, "value"));
        assert_eq!((cache.len(), cache.weight()), (CAPACITY, CAPACITY as u64));

        // One more item keeps the cache at capacity rather than growing it.
        cache.insert(CAPACITY, "new_value");
        assert_eq!((cache.len(), cache.weight()), (CAPACITY, CAPACITY as u64));

        let evicted = (0..CAPACITY).any(|i| cache.get(&i).is_none());
        assert!(evicted, "expected at least one original item to be evicted");
    }

    /// An item heavier than the whole capacity is not retained.
    #[test]
    fn overweight_item() {
        const CAPACITY: usize = 10;

        let capacity = NonZeroUsize::new(CAPACITY).unwrap();
        let cache = CacheBuilder::for_tests()
            .with_capacity(capacity)
            .with_item_weighter(ItemValueWeighter)
            .build();
        assert_eq!((cache.len(), cache.weight()), (0, 0));

        cache.insert(1, CAPACITY + 1);
        assert_eq!((cache.len(), cache.weight()), (0, 0));
        assert_eq!(cache.get(&1), None);
    }

    /// A heavy insert into a full cache evicts as many items as needed to make room.
    #[test]
    fn evict_on_insert_by_weight() {
        const CAPACITY: usize = 10;

        let capacity = NonZeroUsize::new(CAPACITY).unwrap();
        let cache = CacheBuilder::for_tests()
            .with_capacity(capacity)
            .with_item_weighter(ItemValueWeighter)
            .build();
        assert_eq!((cache.len(), cache.weight()), (0, 0));

        // Three items whose weights add up to exactly the capacity.
        for (key, weight) in [(1, 3), (2, 4), (3, 3)] {
            cache.insert(key, weight);
        }
        assert_eq!((cache.len(), cache.weight()), (3, CAPACITY as u64));

        // A single item weighing almost the full capacity pushes out all three.
        let heavy = CAPACITY - 1;
        cache.insert(4, heavy);
        assert_eq!((cache.len(), cache.weight()), (1, heavy as u64));

        for key in 1..=3 {
            assert_eq!(cache.get(&key), None);
        }
        assert_eq!(cache.get(&4), Some(heavy));
    }

    /// Dropping the cache must stop the background tasks so the raw cache is released.
    #[tokio::test]
    async fn tasks_stop_when_cache_dropped() {
        let builder = CacheBuilder::<u64, u64>::from_identifier("test-drop").expect("valid identifier");
        let cache = builder
            .with_time_to_idle(Some(Duration::from_secs(60)))
            .with_expiration_interval(Duration::from_millis(50))
            .build();

        // Only the background tasks hold strong references once the cache handle is gone.
        let weak_cache = Arc::downgrade(&cache.inner.cache);
        drop(cache);

        // Give the tasks a chance to observe the cancellation and exit.
        sleep(Duration::from_millis(100)).await;

        let still_alive = weak_cache.upgrade().is_some();
        assert!(!still_alive, "raw cache should be released after background tasks exit");
    }
}