Skip to main content

saluki_common/cache/
mod.rs

1use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration};
2
3use saluki_error::GenericError;
4use saluki_metrics::reexport::metrics::Counter;
5use saluki_metrics::static_metrics;
6use tokio::time::sleep;
7use tokio_util::sync::{CancellationToken, DropGuard};
8use tracing::debug;
9
10use crate::{hash::FastBuildHasher, task::spawn_traced};
11
12mod expiry;
13use self::expiry::{Expiration, ExpirationBuilder, ExpiryCapableLifecycle};
14
15pub mod weight;
16use self::weight::{ItemCountWeighter, Weighter, WrappedWeighter};
17
/// Concrete `quick_cache` type that backs [`Cache`].
///
/// The caller-supplied weighter `W` is wrapped in [`WrappedWeighter`], and `ExpiryCapableLifecycle<K>` is supplied as
/// the final type parameter of the underlying cache.
type RawCache<K, V, W, H> = quick_cache::sync::Cache<K, V, WrappedWeighter<W>, H, ExpiryCapableLifecycle<K>>;
19
// Declares the `Telemetry` type used throughout this module. It is constructed from the cache identifier (the
// `cache_id` label) and exposes one accessor per metric listed below, e.g. `telemetry.hits_total()`.
//
// NOTE(review): the `debug_counter`/`trace_histogram` kinds are presumably only emitted at more verbose telemetry
// levels, and metric names presumably carry the `cache` prefix -- confirm against `saluki_metrics::static_metrics!`.
static_metrics! {
    name => Telemetry,
    prefix => cache,
    labels => [cache_id: String],
    metrics => [
        gauge(current_items),
        gauge(current_weight),
        gauge(weight_limit),
        counter(hits_total),
        counter(misses_total),
        counter(items_evicted_total),
        debug_counter(items_inserted_total),
        debug_counter(items_removed_total),
        debug_counter(items_expired_total),
        trace_histogram(items_expired_batch_size),
    ],
}
37
/// State shared by every clone of a [`Cache`].
///
/// Pairs the underlying cache with the drop guard of the cancellation token that the background tasks wait on.
struct InnerCache<K, V, W, H> {
    /// The underlying cache; the background tasks spawned by `CacheBuilder::build` hold their own `Arc` to it.
    cache: Arc<RawCache<K, V, W, H>>,
    /// Cancels the background tasks' shutdown token when this struct is dropped.
    _task_shutdown_guard: DropGuard,
}
42
/// Builder for creating a [`Cache`].
///
/// `W` is the item weighter type and `H` is the hasher type; both default to the same types that [`Cache`] defaults
/// to.
pub struct CacheBuilder<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    /// Cache identifier, handed to the cache's telemetry when the cache is built.
    identifier: String,
    /// Capacity of the cache, interpreted together with the item weighter. Starts at `NonZeroUsize::MAX`.
    capacity: NonZeroUsize,
    /// Item weighter used to compute the weight of each cached item.
    weighter: W,
    /// How long an item may go without being accessed before it expires. `None` disables expiration.
    idle_period: Option<Duration>,
    /// How often the background expiration task runs.
    expiration_interval: Option<Duration>,
    /// Whether the background telemetry task is spawned when the cache is built.
    telemetry_enabled: bool,
    // Type markers only: the builder never stores a key, a value, or a hasher instance.
    _key: PhantomData<K>,
    _value: PhantomData<V>,
    _hasher: PhantomData<H>,
}
55
56impl<K, V> CacheBuilder<K, V> {
57    /// Creates a new `CacheBuilder` with the given cache identifier.
58    ///
59    /// The cache identifier _should_ be unique, but it isn't required to be. Metrics for the cache will be emitted
60    /// using the given identifier, so in cases where the identifier isn't unique, those metrics will be aggregated
61    /// together and it won't be possible to distinguish between the different caches.
62    ///
63    /// # Errors
64    ///
65    /// If the given cache identifier is empty, an error is returned.
66    pub fn from_identifier<N: Into<String>>(identifier: N) -> Result<CacheBuilder<K, V>, GenericError> {
67        let identifier = identifier.into();
68        if identifier.is_empty() {
69            return Err(GenericError::msg("cache identifier must not be empty"));
70        }
71
72        Ok(CacheBuilder {
73            identifier,
74            capacity: NonZeroUsize::MAX,
75            weighter: ItemCountWeighter,
76            idle_period: None,
77            expiration_interval: None,
78            telemetry_enabled: true,
79            _key: PhantomData,
80            _value: PhantomData,
81            _hasher: PhantomData,
82        })
83    }
84
85    /// Configures a [`CacheBuilder`] that's suitable for tests.
86    ///
87    /// This configures the builder with the following defaults:
88    ///
89    /// - cache identifier of "noop"
90    /// - unlimited cache size
91    /// - telemetry disabled
92    ///
93    /// This is generally only useful for testing purposes, and is exposed publicly in order to be used in cross-crate
94    /// testing scenarios.
95    pub fn for_tests() -> CacheBuilder<K, V> {
96        CacheBuilder::from_identifier("noop")
97            .expect("cache identifier not empty")
98            .with_telemetry(false)
99    }
100}
101
102impl<K, V, W, H> CacheBuilder<K, V, W, H> {
103    /// Sets the capacity for the cache.
104    ///
105    /// The capacity is used, in conjunction with the item weighter, to determine how many items can be held in the
106    /// cache and when items should be evicted to make room for new items.
107    ///
108    /// See [`with_item_weighter`][Self::with_item_weighter] for more information on how the item weighter is used.
109    ///
110    /// Defaults to unlimited capacity.
111    pub fn with_capacity(mut self, capacity: NonZeroUsize) -> Self {
112        self.capacity = capacity;
113        self
114    }
115
116    /// Enables expiration of cached items based on how long since they were last accessed.
117    ///
118    /// Items which haven't been accessed within the configured duration will be marked for expiration, and be removed
119    /// from the cache shortly thereafter. For the purposes of expiration, "accessed" is either when the item was first
120    /// inserted or when it was last read.
121    ///
122    /// If the given value is `None`, expiration is disabled.
123    ///
124    /// Defaults to no expiration.
125    pub fn with_time_to_idle(mut self, idle_period: Option<Duration>) -> Self {
126        self.idle_period = idle_period;
127
128        // Make sure we have an expiration interval set if expiration is enabled.
129        if self.idle_period.is_some() {
130            self.expiration_interval = self.expiration_interval.or(Some(Duration::from_secs(1)));
131        }
132
133        self
134    }
135
136    /// Sets the interval at which the expiration process will run.
137    ///
138    /// This controls how often the expiration process will run to check for expired items. While items become
139    /// _eligible_ for expiration after the configured duration, they're not _guaranteed_ to be
140    /// removed immediately: the expiration process must still run to actually find the expired items and remove them.
141    ///
142    /// This means that the rough upper bound for how long an item may be kept alive is the sum of
143    /// both the configured expiration duration and the expiration interval.
144    ///
145    /// This value is only relevant if expiration is enabled.
146    ///
147    /// Defaults to 1 second.
148    pub fn with_expiration_interval(mut self, expiration_interval: Duration) -> Self {
149        self.expiration_interval = Some(expiration_interval);
150        self
151    }
152
153    /// Sets the item weighter for the cache.
154    ///
155    /// The item weighter is used to determine the "weight" of each item in the cache, which is used during
156    /// insertion/eviction to determine if an item can be held in the cache without first having to evict other items to
157    /// stay within the configured capacity.
158    ///
159    /// For example, if the configured capacity is set to 10,000, and the "item count" weighter is used, then the cache
160    /// will operate in a way that aims to simply ensure that no more than 10,000 items are held in the cache at any given
161    /// time. This allows defining custom weighters that can be used to track other aspects of the items in the cache,
162    /// such as their size in bytes, or some other metric that's relevant to the intended caching behavior.
163    ///
164    /// Defaults to "item count" weighter.
165    pub fn with_item_weighter<W2>(self, weighter: W2) -> CacheBuilder<K, V, W2, H> {
166        CacheBuilder {
167            identifier: self.identifier,
168            capacity: self.capacity,
169            weighter,
170            idle_period: self.idle_period,
171            expiration_interval: self.expiration_interval,
172            telemetry_enabled: self.telemetry_enabled,
173            _key: PhantomData,
174            _value: PhantomData,
175            _hasher: PhantomData,
176        }
177    }
178
179    /// Sets the item hasher for the cache.
180    ///
181    /// As cache keys are hashed before performing any reads or writes, the chosen hasher can potentially impact the
182    /// performance of those operations. In some scenarios, it may be desirable to use a different hasher than the
183    /// default one in order to optimize for specific key types or access patterns.
184    ///
185    /// Defaults to a fast, non-cryptographic hasher: [`FastBuildHasher`].
186    pub fn with_hasher<H2>(self) -> CacheBuilder<K, V, W, H2> {
187        CacheBuilder {
188            identifier: self.identifier,
189            capacity: self.capacity,
190            weighter: self.weighter,
191            idle_period: self.idle_period,
192            expiration_interval: self.expiration_interval,
193            telemetry_enabled: self.telemetry_enabled,
194            _key: PhantomData,
195            _value: PhantomData,
196            _hasher: PhantomData,
197        }
198    }
199
200    /// Sets whether or not to enable telemetry for this cache.
201    ///
202    /// Reporting the telemetry of the cache requires running an asynchronous task to override adding additional
203    /// overhead in the hot path of reading or writing to the cache. In some cases, it may be cumbersome to always
204    /// create the cache in an asynchronous context so that the telemetry task can be spawned. This method allows
205    /// disabling telemetry reporting in those cases.
206    ///
207    /// Defaults to telemetry enabled.
208    pub fn with_telemetry(mut self, enabled: bool) -> Self {
209        self.telemetry_enabled = enabled;
210        self
211    }
212}
213
214impl<K, V, W, H> CacheBuilder<K, V, W, H>
215where
216    K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
217    V: Clone + Send + Sync + 'static,
218    W: Weighter<K, V> + Clone + Send + Sync + 'static,
219    H: std::hash::BuildHasher + Clone + Default + Send + Sync + 'static,
220{
221    /// Builds a [`Cache`] from the current configuration.
222    pub fn build(self) -> Cache<K, V, W, H> {
223        let capacity = self.capacity.get();
224
225        let telemetry = Telemetry::new(self.identifier);
226        telemetry.weight_limit().set(capacity as f64);
227
228        // Configure expiration if enabled.
229        let eviction_counter = if self.telemetry_enabled {
230            telemetry.items_evicted_total().clone()
231        } else {
232            Counter::noop()
233        };
234        let mut expiration_builder = ExpirationBuilder::new(eviction_counter);
235        if let Some(time_to_idle) = self.idle_period {
236            expiration_builder = expiration_builder.with_time_to_idle(time_to_idle);
237        }
238        let (expiration, expiry_lifecycle) = expiration_builder.build();
239
240        // Create the underlying cache and shutdown signal.
241        let shutdown_token = CancellationToken::new();
242        let raw_cache = Arc::new(RawCache::with(
243            capacity,
244            capacity as u64,
245            WrappedWeighter::from(self.weighter),
246            H::default(),
247            expiry_lifecycle,
248        ));
249
250        let cache = Cache {
251            inner: Arc::new(InnerCache {
252                cache: Arc::clone(&raw_cache),
253                _task_shutdown_guard: shutdown_token.clone().drop_guard(),
254            }),
255            expiration: expiration.clone(),
256            telemetry: telemetry.clone(),
257        };
258
259        // If expiration is enabled, spawn a background task to actually drive expiration.
260        if let Some(expiration_interval) = self.expiration_interval {
261            let expiration = expiration.clone();
262
263            spawn_traced(drive_expiration(
264                Arc::clone(&raw_cache),
265                telemetry.clone(),
266                expiration,
267                expiration_interval,
268                shutdown_token.clone(),
269            ));
270        }
271
272        // If telemetry is enabled, spawn a background task to drive telemetry reporting.
273        if self.telemetry_enabled {
274            spawn_traced(drive_telemetry(Arc::clone(&raw_cache), telemetry, shutdown_token));
275        }
276
277        cache
278    }
279}
280
/// A simple concurrent cache.
///
/// Clones share the same underlying cache: cloning copies handles to the shared state rather than the cached items.
#[derive(Clone)]
pub struct Cache<K, V, W = ItemCountWeighter, H = FastBuildHasher> {
    /// Shared underlying cache plus the guard that stops the background tasks once the last clone is dropped.
    inner: Arc<InnerCache<K, V, W, H>>,
    /// Handle used to record item accesses and removals for idle-based expiration.
    expiration: Expiration<K>,
    /// Metrics handle for this cache.
    telemetry: Telemetry,
}
288
289impl<K, V, W, H> Cache<K, V, W, H>
290where
291    K: Eq + std::hash::Hash + Clone,
292    V: Clone,
293    W: Weighter<K, V> + Clone,
294    H: std::hash::BuildHasher + Clone,
295{
296    /// Returns `true` if the cache is empty.
297    pub fn is_empty(&self) -> bool {
298        self.inner.cache.is_empty()
299    }
300
301    /// Returns the number of items currently in the cache.
302    pub fn len(&self) -> usize {
303        self.inner.cache.len()
304    }
305
306    /// Returns the total weight of all items in the cache.
307    pub fn weight(&self) -> u64 {
308        self.inner.cache.weight()
309    }
310
311    /// Inserts an item into the cache with the given key and value.
312    ///
313    /// If an item with the same key already exists, it will be overwritten and the old value will be dropped. If the
314    /// cache is full, one or more items will be evicted to make room for the new item, based on the configured item
315    /// weighter and the weight of the new item.
316    pub fn insert(&self, key: K, value: V) {
317        self.inner.cache.insert(key.clone(), value);
318        self.expiration.mark_entry_accessed(key);
319        self.telemetry.items_inserted_total().increment(1);
320    }
321
322    /// Gets an item from the cache by its key.
323    ///
324    /// If the item is found, it's cloned and `Some(value)` is returned. Otherwise, `None` is returned.
325    pub fn get(&self, key: &K) -> Option<V> {
326        let value = self.inner.cache.get(key);
327        if value.is_some() {
328            self.expiration.mark_entry_accessed(key.clone());
329            self.telemetry.hits_total().increment(1);
330        } else {
331            self.telemetry.misses_total().increment(1);
332        }
333        value
334    }
335
336    /// Removes an item from the cache by its key.
337    pub fn remove(&self, key: &K) {
338        self.inner.cache.remove(key);
339        self.expiration.mark_entry_removed(key.clone());
340        self.telemetry.items_removed_total().increment(1);
341    }
342}
343
344async fn drive_expiration<K, V, W, H>(
345    cache: Arc<RawCache<K, V, W, H>>, telemetry: Telemetry, expiration: Expiration<K>, expiration_interval: Duration,
346    shutdown: CancellationToken,
347) where
348    K: Eq + std::hash::Hash + Clone,
349    V: Clone,
350    W: Weighter<K, V> + Clone,
351    H: std::hash::BuildHasher + Clone,
352{
353    let mut expired_item_keys = Vec::new();
354
355    loop {
356        tokio::select! {
357            _ = shutdown.cancelled() => break,
358            _ = sleep(expiration_interval) => {}
359        }
360
361        // Drain all expired items that have been queued up for the cache.
362        expiration.drain_expired_items(&mut expired_item_keys);
363
364        let num_expired_items = expired_item_keys.len();
365        if num_expired_items != 0 {
366            telemetry.items_expired_total().increment(num_expired_items as u64);
367            telemetry.items_expired_batch_size().record(num_expired_items as f64);
368        }
369
370        debug!(num_expired_items, "Found expired items.");
371
372        for item_key in expired_item_keys.drain(..) {
373            cache.remove(&item_key);
374            telemetry.items_removed_total().increment(1);
375            expiration.mark_entry_removed(item_key);
376        }
377
378        debug!(num_expired_items, "Removed expired items.");
379    }
380}
381
382async fn drive_telemetry<K, V, W, H>(
383    cache: Arc<RawCache<K, V, W, H>>, telemetry: Telemetry, shutdown: CancellationToken,
384) where
385    K: Eq + std::hash::Hash + Clone,
386    V: Clone,
387    W: Weighter<K, V> + Clone,
388    H: std::hash::BuildHasher + Clone,
389{
390    loop {
391        tokio::select! {
392            _ = shutdown.cancelled() => break,
393            _ = sleep(Duration::from_secs(1)) => {}
394        }
395
396        telemetry.current_items().set(cache.len() as f64);
397        telemetry.current_weight().set(cache.weight() as f64);
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Test weighter that uses each `usize` value as its own weight.
    #[derive(Clone)]
    pub struct ItemValueWeighter;

    impl<K> Weighter<K, usize> for ItemValueWeighter {
        fn item_weight(&self, _key: &K, value: &usize) -> u64 {
            *value as u64
        }
    }

    /// Builds a test cache with the given capacity that weighs each item by its value.
    fn value_weighted_cache(capacity: usize) -> Cache<i32, usize, ItemValueWeighter> {
        CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(capacity).unwrap())
            .with_item_weighter(ItemValueWeighter)
            .build()
    }

    #[test]
    fn empty_cache_identifier() {
        let built = CacheBuilder::<u64, u64>::from_identifier("");
        assert!(built.is_err(), "expected error for empty cache identifier");
    }

    #[test]
    fn basic() {
        const KEY: usize = 42;
        const VALUE: &str = "value1";

        let cache = CacheBuilder::for_tests().build();

        // A freshly built cache holds nothing.
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);

        // With the default "item count" weighter, a single item has a weight of one.
        cache.insert(KEY, VALUE);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.weight(), 1);

        assert_eq!(cache.get(&KEY), Some(VALUE));

        // Removing the only item brings the cache back to empty.
        cache.remove(&KEY);
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);
    }

    #[test]
    fn evict_at_capacity() {
        const LIMIT: usize = 3;

        let cache = CacheBuilder::for_tests()
            .with_capacity(NonZeroUsize::new(LIMIT).unwrap())
            .build();

        // Fill the cache exactly to its capacity.
        (0..LIMIT).for_each(|key| cache.insert(key, "value"));

        assert_eq!(cache.len(), LIMIT);
        assert_eq!(cache.weight(), LIMIT as u64);

        // One more item has to push something else out, so the item count and weight must not change.
        cache.insert(LIMIT, "new_value");
        assert_eq!(cache.len(), LIMIT);
        assert_eq!(cache.weight(), LIMIT as u64);

        let evicted = (0..LIMIT).any(|key| cache.get(&key).is_none());
        assert!(evicted, "expected at least one original item to be evicted");
    }

    #[test]
    fn overweight_item() {
        const LIMIT: usize = 10;

        // The "item value" weighter uses the item value itself as the weight.
        let cache = value_weighted_cache(LIMIT);

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);

        // An item that is heavier than the entire cache must not be admitted.
        cache.insert(1, LIMIT + 1);
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);
        assert_eq!(cache.get(&1), None);
    }

    #[test]
    fn evict_on_insert_by_weight() {
        const LIMIT: usize = 10;

        // The "item value" weighter uses the item value itself as the weight.
        let cache = value_weighted_cache(LIMIT);

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.weight(), 0);

        // Three items whose weights add up to exactly the cache capacity.
        for (key, weight) in [(1, 3), (2, 4), (3, 3)] {
            cache.insert(key, weight);
        }
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.weight(), LIMIT as u64);

        // A single item that fits within the capacity on its own, but leaves too little room for any of the earlier
        // items to stay: all of them have to be evicted to admit it.
        let heavy = LIMIT - 1;
        cache.insert(4, heavy);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.weight(), heavy as u64);

        for evicted_key in [1, 2, 3] {
            assert_eq!(cache.get(&evicted_key), None);
        }
        assert_eq!(cache.get(&4), Some(heavy));
    }

    #[tokio::test]
    async fn tasks_stop_when_cache_dropped() {
        let builder = CacheBuilder::<u64, u64>::from_identifier("test-drop").expect("valid identifier");
        let cache = builder
            .with_time_to_idle(Some(Duration::from_secs(60)))
            .with_expiration_interval(Duration::from_millis(50))
            .build();

        // Keep only a weak handle to the raw cache, which the background tasks hold strong references to.
        let raw_cache = Arc::downgrade(&cache.inner.cache);

        drop(cache);

        // Dropping `InnerCache` drops the cancellation token's drop guard, which triggers cancellation. Both tasks
        // should then wake up immediately and exit, releasing their `Arc<RawCache>` references.
        //
        // TODO: There's no good way to assert the tasks have shutdown besides sleeping and checking the weak cache is
        // gone. It would be nice if there was a way to asynchronously _and_ fallibly shutdown the runtime with a
        // timeout, such that we could detect if they shutdown cleanly... but alas.
        sleep(Duration::from_millis(100)).await;

        assert!(
            raw_cache.upgrade().is_none(),
            "raw cache should be released after background tasks exit"
        );
    }
}