saluki_io/
compression.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use async_compression::{
8    tokio::write::{ZlibEncoder, ZstdEncoder},
9    Level,
10};
11use http::HeaderValue;
12use pin_project::pin_project;
13use tokio::io::AsyncWrite;
14use tracing::trace;
15
// "Red zone" threshold factor.
//
// Applied multiplicatively to shrink the caller-provided threshold, leaving headroom for the final
// (more variable) compressed block. See `CompressionEstimator::would_write_exceed_threshold` for details.
const THRESHOLD_RED_ZONE: f64 = 0.99;

// Pre-built `Content-Encoding` header values handed out by `Compressor::content_encoding`.
static CONTENT_ENCODING_DEFLATE: HeaderValue = HeaderValue::from_static("deflate");
static CONTENT_ENCODING_ZSTD: HeaderValue = HeaderValue::from_static("zstd");
23
/// Compression schemes supported by `Compressor`.
#[derive(Copy, Clone, Debug)]
pub enum CompressionScheme {
    /// No compression.
    Noop,
    /// Zlib, at the given compression level.
    Zlib(Level),
    /// Zstd, at the given compression level.
    Zstd(Level),
}
34
35impl CompressionScheme {
36    /// No compression.
37    pub const fn noop() -> Self {
38        Self::Noop
39    }
40
41    /// Zlib compression, using the default compression level (6).
42    pub const fn zlib_default() -> Self {
43        Self::Zlib(Level::Default)
44    }
45
46    /// Zstd compression, using the default compression level (3).
47    pub const fn zstd_default() -> Self {
48        Self::Zstd(Level::Default)
49    }
50
51    /// Create a new compression scheme from a string and level.
52    ///
53    /// Level is only used if the scheme is `zstd`.
54    ///
55    /// Defaults to zstd with level 3.
56    pub fn new(scheme: &str, level: i32) -> Self {
57        match scheme {
58            "zlib" => CompressionScheme::zlib_default(),
59            "zstd" => Self::Zstd(Level::Precise(level)),
60            _ => Self::Zstd(Level::Default),
61        }
62    }
63}
64
/// An `AsyncWrite` wrapper that counts the bytes successfully written to the inner writer.
#[pin_project]
pub struct CountingWriter<W> {
    #[pin]
    inner: W,
    // Cumulative number of bytes the inner writer has reported as written.
    total_written: u64,
}
71
72impl<W> CountingWriter<W> {
73    fn new(inner: W) -> Self {
74        Self {
75            inner,
76            total_written: 0,
77        }
78    }
79
80    fn total_written(&self) -> u64 {
81        self.total_written
82    }
83
84    fn into_inner(self) -> W {
85        self.inner
86    }
87}
88
/// Statistics for a writer.
pub trait WriteStatistics {
    /// Returns the total number of bytes written.
    ///
    /// For compressing writers, this is the number of *compressed* bytes written to the output stream.
    fn total_written(&self) -> u64;
}
94
95impl<W: AsyncWrite> AsyncWrite for CountingWriter<W> {
96    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
97        let mut this = self.project();
98        this.inner.as_mut().poll_write(cx, buf).map(|result| {
99            if let Ok(written) = &result {
100                *this.total_written += *written as u64;
101            }
102
103            result
104        })
105    }
106
107    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
108        self.project().inner.poll_flush(cx)
109    }
110
111    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
112        self.project().inner.poll_shutdown(cx)
113    }
114}
115
/// Generic compressor.
///
/// Exposes a semi-type-erased compression stream, by allowing the compression to be configured via `CompressionScheme`,
/// and generically wrapping over a given writer.
#[pin_project(project = CompressorProjected)]
pub enum Compressor<W: AsyncWrite> {
    /// No-op compressor.
    Noop(#[pin] CountingWriter<W>),
    /// Zlib compressor.
    //
    // No `CountingWriter` here: the zlib encoder reports its own output total (see the
    // `WriteStatistics` impl below, which calls `total_out`).
    Zlib(#[pin] ZlibEncoder<W>),
    /// Zstd compressor.
    //
    // The writer is wrapped in `CountingWriter` so that compressed output can be measured.
    Zstd(#[pin] ZstdEncoder<CountingWriter<W>>),
}
129
130impl<W: AsyncWrite> Compressor<W> {
131    /// Creates a new compressor from a given compression scheme and writer.
132    pub fn from_scheme(scheme: CompressionScheme, writer: W) -> Self {
133        match scheme {
134            CompressionScheme::Noop => Self::Noop(CountingWriter::new(writer)),
135            CompressionScheme::Zlib(level) => Self::Zlib(ZlibEncoder::with_quality(writer, level)),
136            CompressionScheme::Zstd(level) => Self::Zstd(ZstdEncoder::with_quality(CountingWriter::new(writer), level)),
137        }
138    }
139
140    /// Consumes the compressor, returning the inner writer.
141    pub fn into_inner(self) -> W {
142        match self {
143            Self::Noop(encoder) => encoder.into_inner(),
144            Self::Zlib(encoder) => encoder.into_inner(),
145            Self::Zstd(encoder) => encoder.into_inner().into_inner(),
146        }
147    }
148
149    /// Returns the content encoding for this compressor.
150    pub fn content_encoding(&self) -> Option<HeaderValue> {
151        match self {
152            Self::Noop(_) => None,
153            Self::Zlib(_) => Some(CONTENT_ENCODING_DEFLATE.clone()),
154            Self::Zstd(_) => Some(CONTENT_ENCODING_ZSTD.clone()),
155        }
156    }
157}
158
impl<W: AsyncWrite> AsyncWrite for Compressor<W> {
    // All three methods delegate to whichever encoder variant is active; the pinned projection
    // (`CompressorProjected`) provides a pinned reference to the inner encoder.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        match self.project() {
            CompressorProjected::Noop(encoder) => encoder.poll_write(cx, buf),
            CompressorProjected::Zlib(encoder) => encoder.poll_write(cx, buf),
            CompressorProjected::Zstd(encoder) => encoder.poll_write(cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.project() {
            CompressorProjected::Noop(encoder) => encoder.poll_flush(cx),
            CompressorProjected::Zlib(encoder) => encoder.poll_flush(cx),
            CompressorProjected::Zstd(encoder) => encoder.poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.project() {
            CompressorProjected::Noop(encoder) => encoder.poll_shutdown(cx),
            CompressorProjected::Zlib(encoder) => encoder.poll_shutdown(cx),
            CompressorProjected::Zstd(encoder) => encoder.poll_shutdown(cx),
        }
    }
}
184
impl<W: AsyncWrite> WriteStatistics for Compressor<W> {
    fn total_written(&self) -> u64 {
        match self {
            // No-op writes pass straight through the counting wrapper, so its count is exact.
            Compressor::Noop(encoder) => encoder.total_written(),
            // The zlib encoder tracks its own total output.
            Compressor::Zlib(encoder) => encoder.total_out(),
            // For zstd we count at the wrapped writer instead.
            // NOTE(review): this presumably only reflects bytes the encoder has flushed to the
            // writer, not data still buffered inside the encoder — confirm against the
            // `CompressionEstimator` docs above, which appear to rely on exactly that behavior.
            Compressor::Zstd(encoder) => encoder.get_ref().total_written(),
        }
    }
}
194
195/// A streaming estimator for the size of compressed data.
196///
197/// For many compression algorithms, there is a large amount of buffering and state during compression. This allows
198/// compression algorithms to better compress data by finding patterns across the current and previous inputs, as well
199/// as amortize how often they write compressed data to the output stream, increasing the potential efficiency of the
200/// related function or system calls to do so.
201///
202/// However, this presents a problem when there is a need to ensure that the size of the compressed data does not exceed
203/// a certain threshold. As many inputs can be written to the compressor before the next chunk of compressed data is
204/// output, it is possible to write enough data that the compressed output exceeds the threshold. Further, many
205/// compression algorithms/implementations do not provide a way to query the size of the compressed data without
206/// expensive operations that either require doing multiple compression passes on different slices of the data, or early
207/// flushing of compressed data, potentially leading to abnormally low compression ratios.
208///
209/// This estimator provides a way to estimate the size of the compressed data by combining both the known size of data
210/// written to the compressor's output stream, as well as the inputs written to the compressor. We track the state
211/// changes of the compressor, observing when it writes compressed data to the output stream. We additionally track
212/// every write in terms of its uncompressed size. In combining the two, we estimate the worst-case size of the
213/// compressed data based on what we know has been compressed so far and what we've written since the last time the
214/// compressed flush to the output stream.
215///
216/// TODO: We should probably move this into `Compressor` itself, because it will also make it easier to do
217/// per-compression-algorithm tweaks to the estimation logic if that's a path we want to take, and it also would be
218/// cleaner and let us avoid any footguns around forgetting to update the necessary estimator state, etc.
#[derive(Debug, Default)]
pub struct CompressionEstimator {
    // Uncompressed bytes written since the compressor last flushed a block to the output stream.
    in_flight_uncompressed_len: usize,
    // Total uncompressed bytes written over the life of the estimator.
    total_uncompressed_len: usize,
    // Total compressed bytes observed as written to the output stream.
    total_compressed_len: u64,
    // Overall observed compression ratio (compressed len / uncompressed len); 0.0 until the first
    // flush is observed.
    current_compression_ratio: f64,
}
226
227impl CompressionEstimator {
228    /// Tracks a write to the compressor.
229    pub fn track_write<W>(&mut self, compressor: &W, uncompressed_len: usize)
230    where
231        W: WriteStatistics,
232    {
233        self.in_flight_uncompressed_len += uncompressed_len;
234        self.total_uncompressed_len += uncompressed_len;
235
236        let compressed_len = compressor.total_written();
237        let compressed_len_delta = (compressed_len - self.total_compressed_len) as usize;
238        if compressed_len_delta > 0 {
239            // We just observed the compressor flushing data, so we need to recalculate our compression ratio.
240            self.current_compression_ratio = compressed_len as f64 / self.total_uncompressed_len as f64;
241            self.total_compressed_len = compressed_len;
242            self.in_flight_uncompressed_len = 0;
243
244            trace!(
245                block_size = compressed_len_delta,
246                uncompressed_len = self.total_uncompressed_len,
247                compressed_len = self.total_compressed_len,
248                compression_ratio = self.current_compression_ratio,
249                "Compressor wrote block to output stream."
250            );
251        }
252    }
253
254    /// Resets the estimator.
255    pub fn reset(&mut self) {
256        self.in_flight_uncompressed_len = 0;
257        self.total_uncompressed_len = 0;
258        self.total_compressed_len = 0;
259        self.current_compression_ratio = 0.0;
260    }
261
262    /// Returns the estimated length of the compressor.
263    ///
264    /// This figure is the sum of the total bytes written by the compressor to the output stream and the number of
265    /// uncompressed bytes written to the compressor since the last time the compressor wrote to the output stream
266    /// when factoring in the estimated compression ratio over the overall output stream.
267    pub fn estimated_len(&self) -> usize {
268        let estimated_in_flight_compressed_len =
269            (self.in_flight_uncompressed_len as f64 * self.current_compression_ratio) as usize;
270
271        self.total_compressed_len as usize + estimated_in_flight_compressed_len
272    }
273
274    /// Estimates if writing `len` bytes to the compressor would cause the final compressed size to exceed `threshold`
275    /// bytes.
276    pub fn would_write_exceed_threshold(&self, len: usize, threshold: usize) -> bool {
277        // If the length of the data to be written exceeds the threshold, then it obviously would exceed the threshold.
278        if len > threshold {
279            return true;
280        }
281
282        // If we have yet to see any compressed data, we can't make a meaningful estimate, and this likely means that
283        // the compressor is still actively able to compress more data into the first block, which when eventually
284        // written, should never exceed the compressed size limit... so we choose to not block writes in this case.
285        if self.total_compressed_len == 0 {
286            return false;
287        }
288
289        // We adjust the given threshold down by a small amount to account for the fact that the final block written by
290        // the compressor has more variability in size than the rest, due to being more likely to be flushed before
291        // internal buffers are full and having the chance to most efficiently compress the data. Essentially, if we
292        // estimate that writing `len` more bytes would put our compressed length into the "red zone", then it's too
293        // risky to write those bytes.
294        //
295        // This is a bit of a fudge factor, but we arrived at the value through empirical testing with the regression
296        // detector benchmarks. Small enough to not have a major impact on payload size efficiency, but large enough to
297        // entirely get rid of compressed payload size limit violations.
298        let adjusted_threshold = (threshold as f64 * THRESHOLD_RED_ZONE) as usize;
299        self.estimated_len() + len > adjusted_threshold
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stand-in for a compressor: input bytes accumulate until `flush`, at which point they are
    /// "compressed" at a caller-chosen ratio and added to the compressed output total.
    struct MockCompressor {
        pending_uncompressed: u64,
        uncompressed_total: usize,
        compressed_total: u64,
    }

    impl MockCompressor {
        fn new() -> Self {
            Self {
                pending_uncompressed: 0,
                uncompressed_total: 0,
                compressed_total: 0,
            }
        }

        /// Buffers `n` uncompressed bytes.
        fn write(&mut self, n: usize) {
            self.pending_uncompressed += n as u64;
            self.uncompressed_total += n;
        }

        /// "Compresses" all pending bytes at the given ratio, emitting them to the output total.
        fn flush(&mut self, compression_ratio: f64) {
            let block_len = (self.pending_uncompressed as f64 * compression_ratio) as u64;
            self.compressed_total += block_len;
            self.pending_uncompressed = 0;
        }

        fn total_uncompressed_len(&self) -> usize {
            self.uncompressed_total
        }
    }

    impl WriteStatistics for MockCompressor {
        fn total_written(&self) -> u64 {
            self.compressed_total
        }
    }

    #[test]
    fn compression_estimator_no_output() {
        let estimator = CompressionEstimator::default();

        // With no compressed output observed yet, only writes that are larger than the threshold
        // outright should be rejected.
        assert!(!estimator.would_write_exceed_threshold(10, 100));
        assert!(estimator.would_write_exceed_threshold(100, 90));
    }

    #[test]
    fn compression_estimator_single_flush() {
        const MAX_COMPRESSED_LEN: usize = 100;
        const COMPRESSION_RATIO: f64 = 0.7;
        const WRITE_LEN: usize = 50;

        let mut estimator = CompressionEstimator::default();
        let mut compressor = MockCompressor::new();

        // Before any compressed output exists, the write should be allowed.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));

        // 50 uncompressed bytes at a 0.7 ratio -> 35 compressed bytes.
        compressor.write(WRITE_LEN);
        compressor.flush(COMPRESSION_RATIO);
        assert_eq!(compressor.total_written(), 35);

        estimator.track_write(&compressor, WRITE_LEN);

        // 100 more uncompressed bytes at the observed 0.7 ratio would add ~70 compressed bytes on
        // top of the 35 already written, exceeding the threshold.
        assert!(estimator.would_write_exceed_threshold(100, MAX_COMPRESSED_LEN));

        // Another 50-byte write would add ~35 compressed bytes for 85 total, which fits.
        assert!(!estimator.would_write_exceed_threshold(WRITE_LEN, MAX_COMPRESSED_LEN));
    }

    #[test]
    fn compression_estimator_multiple_flush_partial() {
        const MAX_COMPRESSED_LEN: usize = 5000;
        const FIRST_COMPRESSION_RATIO: f64 = 0.7;
        const FIRST_WRITE_LEN: usize = 5000;
        const SECOND_COMPRESSION_RATIO: f64 = 2.1;
        const SECOND_WRITE_LEN: usize = 300;
        const THIRD_WRITE_LEN: usize = 820;

        let mut estimator = CompressionEstimator::default();
        let mut compressor = MockCompressor::new();

        // First write should be allowed: no compressed output has been observed yet.
        assert!(!estimator.would_write_exceed_threshold(FIRST_WRITE_LEN, MAX_COMPRESSED_LEN));

        // 5,000 uncompressed bytes at a 0.7 ratio -> 3,500 compressed bytes.
        compressor.write(FIRST_WRITE_LEN);
        compressor.flush(FIRST_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN);
        assert_eq!(compressor.total_written(), 3500);

        estimator.track_write(&compressor, FIRST_WRITE_LEN);

        // Simulate a "short" flush that expands the input (ratio > 1), e.g. a partial block. The
        // estimator should key off the *overall* ratio rather than this single pathological block:
        // 300 uncompressed bytes at 2.1 -> 630 compressed bytes, bringing totals to 5,300
        // uncompressed / 4,130 compressed, for an overall ratio of ~0.78.
        compressor.write(SECOND_WRITE_LEN);
        compressor.flush(SECOND_COMPRESSION_RATIO);
        assert_eq!(compressor.total_uncompressed_len(), FIRST_WRITE_LEN + SECOND_WRITE_LEN);
        assert_eq!(compressor.total_written(), 4130);

        estimator.track_write(&compressor, SECOND_WRITE_LEN);

        // With a limit of 5,000 bytes and the "red zone" adjustment (0.99 -> 4,950 effective),
        // another 820 bytes should still fit: 4,130 + 820 = 4,950, which is not over the adjusted
        // threshold. We pass the compressed size directly since `would_write_exceed_threshold`
        // treats its argument as the worst case (no compression at all).
        assert!(!estimator.would_write_exceed_threshold(THIRD_WRITE_LEN, MAX_COMPRESSED_LEN));
    }
}
424}