1use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration};
2
3use saluki_error::GenericError;
4use saluki_metrics::static_metrics;
5use tokio::time::sleep;
6use tokio_util::sync::{CancellationToken, DropGuard};
7use tracing::debug;
8
9use crate::{hash::FastBuildHasher, task::spawn_traced};
10
11mod expiry;
12use self::expiry::{Expiration, ExpirationBuilder, ExpiryCapableLifecycle};
13
14pub mod weight;
15use self::weight::{ItemCountWeighter, Weighter, WrappedWeighter};
16
/// Concrete `quick_cache` cache type used by this module.
///
/// The user-provided weighter `W` is wrapped in [`WrappedWeighter`], and [`ExpiryCapableLifecycle`] is used as the
/// cache lifecycle type.
type RawCache<K, V, W, H> = quick_cache::sync::Cache<K, V, WrappedWeighter<W>, H, ExpiryCapableLifecycle<K>>;
18
// Declares the `Telemetry` type used by the cache and its background tasks. Metrics are declared with the `cache`
// prefix and carry a `cache_id` label, which is populated from the cache identifier when the cache is built.
static_metrics! {
    name => Telemetry,
    prefix => cache,
    labels => [cache_id: String],
    metrics => [
        // Point-in-time values, refreshed periodically by the telemetry task.
        gauge(current_items),
        gauge(current_weight),
        // Set once at build time from the configured capacity.
        gauge(weight_limit),
        // Lookup outcomes recorded by `Cache::get`.
        counter(hits_total),
        counter(misses_total),
        // Insert/remove activity, including removals performed by the expiration task.
        counter(items_inserted_total),
        counter(items_removed_total),
        // Expiration activity recorded by the expiration task.
        counter(items_expired_total),
        trace_histogram(items_expired_batch_size),
    ],
}
35
/// Shared state behind every clone of a [`Cache`].
///
/// Dropping the last reference drops the shutdown guard, which cancels the token that the background tasks wait on.
struct InnerCache<K, V, W, H> {
    // The underlying cache. Background tasks hold their own `Arc` clones of it.
    cache: Arc<RawCache<K, V, W, H>>,
    // Cancels the shared shutdown token on drop so that spawned background tasks exit.
    _task_shutdown_guard: DropGuard,
}
40
/// Builder for creating a [`Cache`].
///
/// `K` and `V` are the key and value types, `W` is the item weighter, and `H` is the hasher builder type.
pub struct CacheBuilder<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    // Non-empty identifier, used as the `cache_id` telemetry label.
    identifier: String,
    // Maximum total weight the cache may hold.
    capacity: NonZeroUsize,
    // Computes the weight of each item.
    weighter: W,
    // Time-to-idle for items, if expiration is enabled.
    idle_period: Option<Duration>,
    // How often the background expiration task runs, if at all.
    expiration_interval: Option<Duration>,
    // Whether the periodic telemetry task is spawned when the cache is built.
    telemetry_enabled: bool,
    _key: PhantomData<K>,
    _value: PhantomData<V>,
    _hasher: PhantomData<H>,
}
53
54impl<K, V> CacheBuilder<K, V> {
55 pub fn from_identifier<N: Into<String>>(identifier: N) -> Result<CacheBuilder<K, V>, GenericError> {
65 let identifier = identifier.into();
66 if identifier.is_empty() {
67 return Err(GenericError::msg("cache identifier must not be empty"));
68 }
69
70 Ok(CacheBuilder {
71 identifier,
72 capacity: NonZeroUsize::MAX,
73 weighter: ItemCountWeighter,
74 idle_period: None,
75 expiration_interval: None,
76 telemetry_enabled: true,
77 _key: PhantomData,
78 _value: PhantomData,
79 _hasher: PhantomData,
80 })
81 }
82
83 pub fn for_tests() -> CacheBuilder<K, V> {
94 CacheBuilder::from_identifier("noop")
95 .expect("cache identifier not empty")
96 .with_telemetry(false)
97 }
98}
99
100impl<K, V, W, H> CacheBuilder<K, V, W, H> {
101 pub fn with_capacity(mut self, capacity: NonZeroUsize) -> Self {
110 self.capacity = capacity;
111 self
112 }
113
114 pub fn with_time_to_idle(mut self, idle_period: Option<Duration>) -> Self {
124 self.idle_period = idle_period;
125
126 if self.idle_period.is_some() {
128 self.expiration_interval = self.expiration_interval.or(Some(Duration::from_secs(1)));
129 }
130
131 self
132 }
133
134 pub fn with_expiration_interval(mut self, expiration_interval: Duration) -> Self {
147 self.expiration_interval = Some(expiration_interval);
148 self
149 }
150
151 pub fn with_item_weighter<W2>(self, weighter: W2) -> CacheBuilder<K, V, W2, H> {
164 CacheBuilder {
165 identifier: self.identifier,
166 capacity: self.capacity,
167 weighter,
168 idle_period: self.idle_period,
169 expiration_interval: self.expiration_interval,
170 telemetry_enabled: self.telemetry_enabled,
171 _key: PhantomData,
172 _value: PhantomData,
173 _hasher: PhantomData,
174 }
175 }
176
177 pub fn with_hasher<H2>(self) -> CacheBuilder<K, V, W, H2> {
185 CacheBuilder {
186 identifier: self.identifier,
187 capacity: self.capacity,
188 weighter: self.weighter,
189 idle_period: self.idle_period,
190 expiration_interval: self.expiration_interval,
191 telemetry_enabled: self.telemetry_enabled,
192 _key: PhantomData,
193 _value: PhantomData,
194 _hasher: PhantomData,
195 }
196 }
197
198 pub fn with_telemetry(mut self, enabled: bool) -> Self {
207 self.telemetry_enabled = enabled;
208 self
209 }
210}
211
212impl<K, V, W, H> CacheBuilder<K, V, W, H>
213where
214 K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
215 V: Clone + Send + Sync + 'static,
216 W: Weighter<K, V> + Clone + Send + Sync + 'static,
217 H: std::hash::BuildHasher + Clone + Default + Send + Sync + 'static,
218{
219 pub fn build(self) -> Cache<K, V, W, H> {
221 let capacity = self.capacity.get();
222
223 let telemetry = Telemetry::new(self.identifier);
224 telemetry.weight_limit().set(capacity as f64);
225
226 let mut expiration_builder = ExpirationBuilder::new();
228 if let Some(time_to_idle) = self.idle_period {
229 expiration_builder = expiration_builder.with_time_to_idle(time_to_idle);
230 }
231 let (expiration, expiry_lifecycle) = expiration_builder.build();
232
233 let shutdown_token = CancellationToken::new();
235 let raw_cache = Arc::new(RawCache::with(
236 capacity,
237 capacity as u64,
238 WrappedWeighter::from(self.weighter),
239 H::default(),
240 expiry_lifecycle,
241 ));
242
243 let cache = Cache {
244 inner: Arc::new(InnerCache {
245 cache: Arc::clone(&raw_cache),
246 _task_shutdown_guard: shutdown_token.clone().drop_guard(),
247 }),
248 expiration: expiration.clone(),
249 telemetry: telemetry.clone(),
250 };
251
252 if let Some(expiration_interval) = self.expiration_interval {
254 let expiration = expiration.clone();
255
256 spawn_traced(drive_expiration(
257 Arc::clone(&raw_cache),
258 telemetry.clone(),
259 expiration,
260 expiration_interval,
261 shutdown_token.clone(),
262 ));
263 }
264
265 if self.telemetry_enabled {
267 spawn_traced(drive_telemetry(Arc::clone(&raw_cache), telemetry, shutdown_token));
268 }
269
270 cache
271 }
272}
273
/// A weighted cache with optional idle expiration and telemetry.
///
/// The underlying cache state is held behind an `Arc`, so clones refer to the same underlying cache.
#[derive(Clone)]
pub struct Cache<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    // Shared cache state, including the guard that stops background tasks when dropped.
    inner: Arc<InnerCache<K, V, W, H>>,
    // Tracks item accesses and removals for expiration.
    expiration: Expiration<K>,
    // Metrics handle for this cache.
    telemetry: Telemetry,
}
281
282impl<K, V, W, H> Cache<K, V, W, H>
283where
284 K: Eq + std::hash::Hash + Clone,
285 V: Clone,
286 W: Weighter<K, V> + Clone,
287 H: std::hash::BuildHasher + Clone,
288{
289 pub fn is_empty(&self) -> bool {
291 self.inner.cache.is_empty()
292 }
293
294 pub fn len(&self) -> usize {
296 self.inner.cache.len()
297 }
298
299 pub fn weight(&self) -> u64 {
301 self.inner.cache.weight()
302 }
303
304 pub fn insert(&self, key: K, value: V) {
310 self.inner.cache.insert(key.clone(), value);
311 self.expiration.mark_entry_accessed(key);
312 self.telemetry.items_inserted_total().increment(1);
313 }
314
315 pub fn get(&self, key: &K) -> Option<V> {
319 let value = self.inner.cache.get(key);
320 if value.is_some() {
321 self.expiration.mark_entry_accessed(key.clone());
322 self.telemetry.hits_total().increment(1);
323 } else {
324 self.telemetry.misses_total().increment(1);
325 }
326 value
327 }
328
329 pub fn remove(&self, key: &K) {
331 self.inner.cache.remove(key);
332 self.expiration.mark_entry_removed(key.clone());
333 self.telemetry.items_removed_total().increment(1);
334 }
335}
336
337async fn drive_expiration<K, V, W, H>(
338 cache: Arc<RawCache<K, V, W, H>>, telemetry: Telemetry, expiration: Expiration<K>, expiration_interval: Duration,
339 shutdown: CancellationToken,
340) where
341 K: Eq + std::hash::Hash + Clone,
342 V: Clone,
343 W: Weighter<K, V> + Clone,
344 H: std::hash::BuildHasher + Clone,
345{
346 let mut expired_item_keys = Vec::new();
347
348 loop {
349 tokio::select! {
350 _ = shutdown.cancelled() => break,
351 _ = sleep(expiration_interval) => {}
352 }
353
354 expiration.drain_expired_items(&mut expired_item_keys);
356
357 let num_expired_items = expired_item_keys.len();
358 if num_expired_items != 0 {
359 telemetry.items_expired_total().increment(num_expired_items as u64);
360 telemetry.items_expired_batch_size().record(num_expired_items as f64);
361 }
362
363 debug!(num_expired_items, "Found expired items.");
364
365 for item_key in expired_item_keys.drain(..) {
366 cache.remove(&item_key);
367 telemetry.items_removed_total().increment(1);
368 expiration.mark_entry_removed(item_key);
369 }
370
371 debug!(num_expired_items, "Removed expired items.");
372 }
373}
374
375async fn drive_telemetry<K, V, W, H>(
376 cache: Arc<RawCache<K, V, W, H>>, telemetry: Telemetry, shutdown: CancellationToken,
377) where
378 K: Eq + std::hash::Hash + Clone,
379 V: Clone,
380 W: Weighter<K, V> + Clone,
381 H: std::hash::BuildHasher + Clone,
382{
383 loop {
384 tokio::select! {
385 _ = shutdown.cancelled() => break,
386 _ = sleep(Duration::from_secs(1)) => {}
387 }
388
389 telemetry.current_items().set(cache.len() as f64);
390 telemetry.current_weight().set(cache.weight() as f64);
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Weighter for `usize` values that uses the value itself as the item's weight.
    #[derive(Clone)]
    pub struct ItemValueWeighter;

    impl<K> Weighter<K, usize> for ItemValueWeighter {
        fn item_weight(&self, _: &K, weight: &usize) -> u64 {
            *weight as u64
        }
    }

    #[test]
    fn empty_cache_identifier() {
        assert!(
            CacheBuilder::<u64, u64>::from_identifier("").is_err(),
            "expected error for empty cache identifier"
        );
    }

    #[test]
    fn basic() {
        const KEY: usize = 42;
        const VALUE: &str = "value1";

        let cache = CacheBuilder::for_tests().build();

        // A freshly built cache holds nothing.
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);

        // With the default weighter, a single item has a weight of one.
        cache.insert(KEY, VALUE);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.weight(), 1);

        assert_eq!(cache.get(&KEY), Some(VALUE));

        // Removing the only item leaves the cache empty again.
        cache.remove(&KEY);
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);
    }

    #[test]
    fn evict_at_capacity() {
        const LIMIT: usize = 3;

        let cache = CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(LIMIT).unwrap())
            .build();

        // Fill the cache exactly to its limit.
        for key in 0..LIMIT {
            cache.insert(key, "value");
        }

        assert_eq!(cache.len(), LIMIT);
        assert_eq!(cache.weight(), LIMIT as u64);

        // One more item must not grow the cache beyond its limit.
        cache.insert(LIMIT, "new_value");
        assert_eq!(cache.len(), LIMIT);
        assert_eq!(cache.weight(), LIMIT as u64);

        // Stops probing at the first original key that is no longer present.
        let evicted = (0..LIMIT).any(|key| cache.get(&key).is_none());
        assert!(evicted, "expected at least one original item to be evicted");
    }

    #[test]
    fn overweight_item() {
        const LIMIT: usize = 10;

        let cache = CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(LIMIT).unwrap())
            .with_item_weighter(ItemValueWeighter)
            .build();

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);

        // An item heavier than the whole cache must not be retained.
        cache.insert(1, LIMIT + 1);
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);
        assert_eq!(cache.get(&1), None);
    }

    #[test]
    fn evict_on_insert_by_weight() {
        const LIMIT: usize = 10;
        const HEAVY: usize = LIMIT - 1;

        let cache = CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(LIMIT).unwrap())
            .with_item_weighter(ItemValueWeighter)
            .build();

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);

        // Three items whose weights add up to exactly the limit.
        cache.insert(1, 3);
        cache.insert(2, 4);
        cache.insert(3, 3);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.weight(), LIMIT as u64);

        // A single heavy item can only fit once all of the existing items are evicted.
        cache.insert(4, HEAVY);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.weight(), HEAVY as u64);

        for evicted_key in 1..=3 {
            assert_eq!(cache.get(&evicted_key), None);
        }
        assert_eq!(cache.get(&4), Some(HEAVY));
    }

    #[tokio::test]
    async fn tasks_stop_when_cache_dropped() {
        let builder = CacheBuilder::<u64, u64>::from_identifier("test-drop").expect("valid identifier");
        let cache = builder
            .with_time_to_idle(Some(Duration::from_secs(60)))
            .with_expiration_interval(Duration::from_millis(50))
            .build();

        // The background tasks hold strong references to the raw cache, so a weak handle tells us when they exit.
        let raw_cache_handle = Arc::downgrade(&cache.inner.cache);

        drop(cache);

        // Give the background tasks a chance to observe the cancellation and exit.
        sleep(Duration::from_millis(100)).await;

        assert!(
            raw_cache_handle.upgrade().is_none(),
            "raw cache should be released after background tasks exit"
        );
    }
}